import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
from typing import List, Tuple

# Hugging Face token from environment variable
HF_TOKEN = os.getenv("HF_TOKEN")
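# NOTE (not in the original file): on Hugging Face Spaces this is typically set
# as a Space secret (Settings -> Variables and secrets); the token is only
# needed if the selected models are gated or private.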
class ChatBot:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.current_model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_model(self, model_name: str):
        """Load the model and tokenizer."""
        if self.current_model == model_name:
            return

        # Free the memory held by the previous model
        if self.model is not None:
            del self.model
            del self.tokenizer
            if self.device == "cuda":
                torch.cuda.empty_cache()

        # Load the tokenizer and model; `token` replaces the deprecated
        # `use_auth_token` argument in recent transformers releases
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            token=HF_TOKEN,
            trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            token=HF_TOKEN,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            device_map="auto" if self.device == "cuda" else None,
            trust_remote_code=True
        )
        # No explicit .to(self.device) here: with device_map="auto" the model is
        # already dispatched to the GPU, and moving a dispatched model raises an
        # error; on CPU (device_map=None) it already loads on the CPU.
        self.current_model = model_name
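    # NOTE (rough estimate, not from the original code): in float16 a
    # 13B-parameter model needs on the order of 26 GB for the weights alone,
    # so Fugaku-LLM-13B will not fit on most free-tier Spaces hardware.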
    def generate_response(self, message: str, history: List[Tuple[str, str]], model_name: str,
                          temperature: float = 0.7, max_tokens: int = 512) -> str:
        """Generate a response to the given message."""
        # Load the selected model (no-op if it is already loaded)
        self.load_model(model_name)

        # Build the prompt from the conversation history
        prompt = self._build_prompt(message, history)

        # Tokenize; using the tokenizer's __call__ also returns the attention
        # mask, which silences the corresponding generate() warning
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        # Generate; fall back to the EOS token when the tokenizer has no pad token
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                top_p=0.95,
                pad_token_id=(self.tokenizer.pad_token_id
                              if self.tokenizer.pad_token_id is not None
                              else self.tokenizer.eos_token_id),
                eos_token_id=self.tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens
        response = self.tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
        )
        return response.strip()
    def _build_prompt(self, message: str, history: List[Tuple[str, str]]) -> str:
        """Build a prompt from the conversation history."""
        prompt = ""
        # Include only the 5 most recent exchanges
        for user_msg, assistant_msg in history[-5:]:
            prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n\n"
        # Append the current message
        prompt += f"User: {message}\nAssistant: "
        return prompt
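    # NOTE (hedged suggestion, not in the original app): instruction-tuned models
    # such as elyza/Llama-3-ELYZA-JP-8B expect the chat format they were trained
    # on, which the plain "User:/Assistant:" prompt above does not reproduce.
    # A sketch using the tokenizer's built-in template, assuming one is shipped:
    #
    #     def _build_chat_prompt(self, message, history):
    #         messages = []
    #         for user_msg, assistant_msg in history[-5:]:
    #             messages.append({"role": "user", "content": user_msg})
    #             messages.append({"role": "assistant", "content": assistant_msg})
    #         messages.append({"role": "user", "content": message})
    #         return self.tokenizer.apply_chat_template(
    #             messages, tokenize=False, add_generation_prompt=True
    #         )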
# Shared ChatBot instance
chatbot = ChatBot()

def respond(message: str, history: List[Tuple[str, str]], model_name: str,
            temperature: float, max_tokens: int) -> Tuple[List[Tuple[str, str]], str]:
    """Gradio callback: generate a reply and append it to the chat history."""
    if not message:
        return history, ""
    try:
        # Generate a response
        response = chatbot.generate_response(message, history, model_name, temperature, max_tokens)
        # Append the exchange to the history
        history.append((message, response))
        return history, ""
    except Exception as e:
        error_msg = f"An error occurred: {str(e)}"
        history.append((message, error_msg))
        return history, ""

def clear_chat() -> Tuple[List, str]:
    """Clear the chat history."""
    return [], ""
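# NOTE (version caveat, not in the original): recent Gradio releases deprecate the
# list-of-tuples chat history used below in favor of gr.Chatbot(type="messages")
# with {"role": ..., "content": ...} dicts; the tuple format is kept here as written.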
# Gradio UI
with gr.Blocks(title="ChatGPT Clone", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🤖 ChatGPT Clone")
    gr.Markdown("A chatbot built on Japanese-capable LLMs.")

    with gr.Row():
        with gr.Column(scale=3):
            chatbot_ui = gr.Chatbot(
                label="Chat",
                height=500,
                show_label=False,
                container=True
            )
            with gr.Row():
                msg_input = gr.Textbox(
                    label="Message",
                    placeholder="Type your message here...",
                    lines=2,
                    scale=4,
                    show_label=False
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)
            with gr.Row():
                clear_btn = gr.Button("🗑️ New conversation", variant="secondary")

        with gr.Column(scale=1):
            model_select = gr.Dropdown(
                choices=[
                    "elyza/Llama-3-ELYZA-JP-8B",
                    "Fugaku-LLM/Fugaku-LLM-13B"
                ],
                value="elyza/Llama-3-ELYZA-JP-8B",
                label="Model",
                interactive=True
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Controls how creative the output is"
            )
            max_tokens = gr.Slider(
                minimum=64,
                maximum=1024,
                value=512,
                step=64,
                label="Max tokens",
                info="Maximum number of tokens to generate"
            )
            gr.Markdown("""
            ### How to use
            1. Select a model
            2. Type a message
            3. Click Send

            ### Notes
            - The first model load takes some time
            - Runs much faster on a GPU
            """)
    # Event handlers: Enter in the textbox and the Send button both call respond()
    msg_input.submit(
        fn=respond,
        inputs=[msg_input, chatbot_ui, model_select, temperature, max_tokens],
        outputs=[chatbot_ui, msg_input]
    )
    send_btn.click(
        fn=respond,
        inputs=[msg_input, chatbot_ui, model_select, temperature, max_tokens],
        outputs=[chatbot_ui, msg_input]
    )
    clear_btn.click(
        fn=clear_chat,
        outputs=[chatbot_ui, msg_input]
    )
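# NOTE (hedged suggestion, untested here): generation blocks a worker for the whole
# forward pass, so enabling Gradio's request queue, e.g. app.queue() before
# app.launch(), may handle concurrent users more gracefully.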
if __name__ == "__main__":
    app.launch()