llm_chat_app / app.py
ryoshimu
commit
760e16a
raw
history blame
6.6 kB
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
from typing import List, Tuple
# Hugging Face token from environment variable
HF_TOKEN = os.getenv("HF_TOKEN")
class ChatBot:
def __init__(self):
self.model = None
self.tokenizer = None
self.current_model = None
self.device = "cuda" if torch.cuda.is_available() else "cpu"
def load_model(self, model_name: str):
"""モデルとトークナイザーをロード"""
if self.current_model == model_name:
return
# メモリクリア
if self.model is not None:
del self.model
del self.tokenizer
torch.cuda.empty_cache()
# モデルロード
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
use_auth_token=HF_TOKEN,
trust_remote_code=True
)
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
use_auth_token=HF_TOKEN,
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
device_map="auto" if self.device == "cuda" else None,
trust_remote_code=True
)
if self.device == "cuda":
self.model = self.model.to(self.device)
self.current_model = model_name
def generate_response(self, message: str, history: List[Tuple[str, str]], model_name: str,
temperature: float = 0.7, max_tokens: int = 512) -> str:
"""メッセージに対する応答を生成"""
# モデルロード
self.load_model(model_name)
# プロンプト構築
prompt = self._build_prompt(message, history)
# トークナイズ
inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
# 生成
with torch.no_grad():
outputs = self.model.generate(
inputs,
max_new_tokens=max_tokens,
temperature=temperature,
do_sample=True,
top_p=0.95,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id
)
# デコード
response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
return response.strip()
def _build_prompt(self, message: str, history: List[Tuple[str, str]]) -> str:
"""会話履歴からプロンプトを構築"""
prompt = ""
# 履歴を追加
for user_msg, assistant_msg in history[-5:]: # 最新5件の履歴のみ使用
prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n\n"
# 現在のメッセージを追加
prompt += f"User: {message}\nAssistant: "
return prompt
# ChatBotインスタンス
chatbot = ChatBot()
def respond(message: str, history: List[Tuple[str, str]], model_name: str,
temperature: float, max_tokens: int) -> Tuple[List[Tuple[str, str]], str]:
"""Gradioのコールバック関数"""
if not message:
return history, ""
try:
# 応答生成
response = chatbot.generate_response(message, history, model_name, temperature, max_tokens)
# 履歴に追加
history.append((message, response))
return history, ""
except Exception as e:
error_msg = f"エラーが発生しました: {str(e)}"
history.append((message, error_msg))
return history, ""
def clear_chat() -> Tuple[List, str]:
"""チャット履歴をクリア"""
return [], ""
# Gradio UI
with gr.Blocks(title="ChatGPT Clone", theme=gr.themes.Soft()) as app:
gr.Markdown("# 🤖 ChatGPT Clone")
gr.Markdown("日本語対応のLLMを使用したチャットボットです。")
with gr.Row():
with gr.Column(scale=3):
chatbot_ui = gr.Chatbot(
label="Chat",
height=500,
show_label=False,
container=True
)
with gr.Row():
msg_input = gr.Textbox(
label="メッセージを入力",
placeholder="ここにメッセージを入力してください...",
lines=2,
scale=4,
show_label=False
)
send_btn = gr.Button("送信", variant="primary", scale=1)
with gr.Row():
clear_btn = gr.Button("🗑️ 新しい会話", variant="secondary")
with gr.Column(scale=1):
model_select = gr.Dropdown(
choices=[
"elyza/Llama-3-ELYZA-JP-8B",
"Fugaku-LLM/Fugaku-LLM-13B"
],
value="elyza/Llama-3-ELYZA-JP-8B",
label="モデル選択",
interactive=True
)
temperature = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.7,
step=0.1,
label="Temperature",
info="生成の創造性を調整"
)
max_tokens = gr.Slider(
minimum=64,
maximum=1024,
value=512,
step=64,
label="最大トークン数",
info="生成する最大トークン数"
)
gr.Markdown("""
### 使い方
1. モデルを選択
2. メッセージを入力
3. 送信ボタンをクリック
### 注意事項
- 初回のモデル読み込みには時間がかかります
- GPU使用時は高速に動作します
""")
# イベントハンドラ
msg_input.submit(
fn=respond,
inputs=[msg_input, chatbot_ui, model_select, temperature, max_tokens],
outputs=[chatbot_ui, msg_input]
)
send_btn.click(
fn=respond,
inputs=[msg_input, chatbot_ui, model_select, temperature, max_tokens],
outputs=[chatbot_ui, msg_input]
)
clear_btn.click(
fn=clear_chat,
outputs=[chatbot_ui, msg_input]
)
if __name__ == "__main__":
app.launch()