Spaces:
Sleeping
Sleeping
""" | |
ChatGPT Clone - 日本語対応チャットボット | |
Hugging Face Spaces (ZeroGPU) 対応版 | |
使用モデル: | |
- elyza/Llama-3-ELYZA-JP-8B | |
- Fugaku-LLM/Fugaku-LLM-13B | |
""" | |
import gradio as gr | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import os | |
from typing import List, Tuple | |
# Hugging Face token from environment variable | |
HF_TOKEN = os.getenv("HF_TOKEN") | |
# トークンのチェック | |
if not HF_TOKEN: | |
print("警告: HF_TOKENが設定されていません。プライベートモデルへのアクセスが制限される場合があります。") | |
# Check if running on ZeroGPU | |
try: | |
import spaces | |
IS_ZEROGPU = True | |
print("ZeroGPU環境を検出しました。") | |
except ImportError: | |
IS_ZEROGPU = False | |
print("通常のGPU/CPU環境で実行しています。") | |
class ChatBot: | |
def __init__(self): | |
self.model = None | |
self.tokenizer = None | |
self.current_model = None | |
def load_model(self, model_name: str): | |
"""モデルとトークナイザーをロード""" | |
if self.current_model == model_name and self.model is not None: | |
return | |
try: | |
# メモリクリア | |
if self.model is not None: | |
del self.model | |
del self.tokenizer | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
torch.cuda.synchronize() | |
# トークナイザーロード | |
self.tokenizer = AutoTokenizer.from_pretrained( | |
model_name, | |
token=HF_TOKEN, | |
trust_remote_code=True, | |
padding_side="left" | |
) | |
# パッドトークンの設定 | |
if self.tokenizer.pad_token is None: | |
self.tokenizer.pad_token = self.tokenizer.eos_token | |
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id | |
# モデルロード(ZeroGPU対応) | |
self.model = AutoModelForCausalLM.from_pretrained( | |
model_name, | |
token=HF_TOKEN, | |
torch_dtype=torch.float16, | |
low_cpu_mem_usage=True, | |
trust_remote_code=True, | |
load_in_8bit=False, # ZeroGPU環境では8bit量子化は使わない | |
device_map=None # ZeroGPU環境では自動マッピングしない | |
) | |
self.current_model = model_name | |
print(f"モデル {model_name} のロードが完了しました。") | |
except Exception as e: | |
print(f"モデルのロード中にエラーが発生しました: {str(e)}") | |
raise | |
def _generate_response_gpu(self, message: str, history: List[Tuple[str, str]], model_name: str, | |
temperature: float = 0.7, max_tokens: int = 512) -> str: | |
"""GPU上で応答を生成する実際の処理""" | |
# モデルロード | |
self.load_model(model_name) | |
# GPUに移動 | |
self.model.to('cuda') | |
# プロンプト構築 | |
prompt = self._build_prompt(message, history) | |
# トークナイズ | |
inputs = self.tokenizer.encode(prompt, return_tensors="pt").to('cuda') | |
# 生成 | |
with torch.no_grad(): | |
outputs = self.model.generate( | |
inputs, | |
max_new_tokens=max_tokens, | |
temperature=temperature, | |
do_sample=True, | |
top_p=0.95, | |
top_k=50, | |
repetition_penalty=1.1, | |
pad_token_id=self.tokenizer.pad_token_id, | |
eos_token_id=self.tokenizer.eos_token_id | |
) | |
# デコード | |
response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True) | |
# CPUに戻す(メモリ節約) | |
self.model.to('cpu') | |
torch.cuda.empty_cache() | |
torch.cuda.synchronize() | |
return response.strip() | |
def generate_response(self, message: str, history: List[Tuple[str, str]], model_name: str, | |
temperature: float = 0.7, max_tokens: int = 512) -> str: | |
"""メッセージに対する応答を生成""" | |
if IS_ZEROGPU: | |
# ZeroGPU環境の場合 | |
return self._generate_response_gpu(message, history, model_name, temperature, max_tokens) | |
else: | |
# 通常環境の場合 | |
self.load_model(model_name) | |
device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
if device == 'cuda': | |
self.model.to(device) | |
prompt = self._build_prompt(message, history) | |
inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(device) | |
with torch.no_grad(): | |
outputs = self.model.generate( | |
inputs, | |
max_new_tokens=max_tokens, | |
temperature=temperature, | |
do_sample=True, | |
top_p=0.95, | |
top_k=50, | |
repetition_penalty=1.1, | |
pad_token_id=self.tokenizer.pad_token_id, | |
eos_token_id=self.tokenizer.eos_token_id | |
) | |
response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True) | |
return response.strip() | |
def _build_prompt(self, message: str, history: List[Tuple[str, str]]) -> str: | |
"""会話履歴からプロンプトを構築""" | |
prompt = "" | |
# 履歴を追加(最新3件のみ使用 - メモリ効率のため) | |
for user_msg, assistant_msg in history[-3:]: | |
prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n\n" | |
# 現在のメッセージを追加 | |
prompt += f"User: {message}\nAssistant: " | |
return prompt | |
# ChatBotインスタンス | |
chatbot = ChatBot() | |
# ZeroGPU環境の場合、GPUデコレータを適用 | |
if IS_ZEROGPU: | |
chatbot._generate_response_gpu = spaces.GPU(duration=120)(chatbot._generate_response_gpu) | |
def respond(message: str, history: List[Tuple[str, str]], model_name: str, | |
temperature: float, max_tokens: int) -> Tuple[List[Tuple[str, str]], str]: | |
"""Gradioのコールバック関数""" | |
if not message: | |
return history, "" | |
try: | |
# 応答生成 | |
response = chatbot.generate_response(message, history, model_name, temperature, max_tokens) | |
# 履歴に追加 | |
history.append((message, response)) | |
return history, "" | |
except RuntimeError as e: | |
if "out of memory" in str(e).lower(): | |
error_msg = "メモリ不足エラー: より小さいモデルを使用するか、最大トークン数を減らしてください。" | |
else: | |
error_msg = f"実行時エラー: {str(e)}" | |
history.append((message, error_msg)) | |
return history, "" | |
except Exception as e: | |
error_msg = f"エラーが発生しました: {str(e)}" | |
history.append((message, error_msg)) | |
return history, "" | |
def clear_chat() -> Tuple[List, str]: | |
"""チャット履歴をクリア""" | |
return [], "" | |
# Gradio UI | |
with gr.Blocks(title="ChatGPT Clone", theme=gr.themes.Soft()) as app: | |
gr.Markdown("# 🤖 ChatGPT Clone") | |
gr.Markdown(""" | |
日本語対応のLLMを使用したチャットボットです。 | |
**使用可能モデル:** | |
- [elyza/Llama-3-ELYZA-JP-8B](https://huggingface.co/elyza/Llama-3-ELYZA-JP-8B) | |
- [Fugaku-LLM/Fugaku-LLM-13B](https://huggingface.co/Fugaku-LLM/Fugaku-LLM-13B) | |
""") | |
with gr.Row(): | |
with gr.Column(scale=3): | |
chatbot_ui = gr.Chatbot( | |
label="Chat", | |
height=500, | |
show_label=False, | |
container=True | |
) | |
with gr.Row(): | |
msg_input = gr.Textbox( | |
label="メッセージを入力", | |
placeholder="ここにメッセージを入力してください...", | |
lines=2, | |
scale=4, | |
show_label=False | |
) | |
send_btn = gr.Button("送信", variant="primary", scale=1) | |
with gr.Row(): | |
clear_btn = gr.Button("🗑️ 新しい会話", variant="secondary") | |
with gr.Column(scale=1): | |
model_select = gr.Dropdown( | |
choices=[ | |
"elyza/Llama-3-ELYZA-JP-8B", | |
"Fugaku-LLM/Fugaku-LLM-13B", | |
], | |
value="elyza/Llama-3-ELYZA-JP-8B", | |
label="モデル選択", | |
interactive=True | |
) | |
temperature = gr.Slider( | |
minimum=0.1, | |
maximum=1.0, | |
value=0.7, | |
step=0.1, | |
label="Temperature", | |
info="生成の創造性を調整" | |
) | |
max_tokens = gr.Slider( | |
minimum=64, | |
maximum=512, | |
value=256, | |
step=64, | |
label="最大トークン数", | |
info="生成する最大トークン数" | |
) | |
gr.Markdown(""" | |
### 使い方 | |
1. モデルを選択 | |
2. メッセージを入力 | |
3. 送信ボタンをクリック | |
### 注意事項 | |
- 初回のモデル読み込みには時間がかかります | |
- ZeroGPU使用により高速推論が可能 | |
- 1回の生成は120秒以内に完了します | |
- 大きなモデル使用時は、短めの応答になる場合があります | |
""") | |
# イベントハンドラ | |
msg_input.submit( | |
fn=respond, | |
inputs=[msg_input, chatbot_ui, model_select, temperature, max_tokens], | |
outputs=[chatbot_ui, msg_input] | |
) | |
send_btn.click( | |
fn=respond, | |
inputs=[msg_input, chatbot_ui, model_select, temperature, max_tokens], | |
outputs=[chatbot_ui, msg_input] | |
) | |
clear_btn.click( | |
fn=clear_chat, | |
outputs=[chatbot_ui, msg_input] | |
) | |
if __name__ == "__main__": | |
# Hugging Face Spaces環境かどうかを確認 | |
is_hf_spaces = os.getenv("SPACE_ID") is not None | |
app.launch( | |
share=False, | |
show_error=True, | |
server_name="0.0.0.0" if is_hf_spaces else "127.0.0.1", | |
server_port=7860 | |
) |