ryoshimu commited on
Commit
760e16a
·
1 Parent(s): 9232bab
Files changed (2) hide show
  1. app.py +201 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import os
5
+ from typing import List, Tuple
6
+
7
+ # Hugging Face token from environment variable
8
+ HF_TOKEN = os.getenv("HF_TOKEN")
9
+
10
+ class ChatBot:
11
+ def __init__(self):
12
+ self.model = None
13
+ self.tokenizer = None
14
+ self.current_model = None
15
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
16
+
17
+ def load_model(self, model_name: str):
18
+ """モデルとトークナイザーをロード"""
19
+ if self.current_model == model_name:
20
+ return
21
+
22
+ # メモリクリア
23
+ if self.model is not None:
24
+ del self.model
25
+ del self.tokenizer
26
+ torch.cuda.empty_cache()
27
+
28
+ # モデルロード
29
+ self.tokenizer = AutoTokenizer.from_pretrained(
30
+ model_name,
31
+ use_auth_token=HF_TOKEN,
32
+ trust_remote_code=True
33
+ )
34
+
35
+ self.model = AutoModelForCausalLM.from_pretrained(
36
+ model_name,
37
+ use_auth_token=HF_TOKEN,
38
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
39
+ device_map="auto" if self.device == "cuda" else None,
40
+ trust_remote_code=True
41
+ )
42
+
43
+ if self.device == "cuda":
44
+ self.model = self.model.to(self.device)
45
+
46
+ self.current_model = model_name
47
+
48
+ def generate_response(self, message: str, history: List[Tuple[str, str]], model_name: str,
49
+ temperature: float = 0.7, max_tokens: int = 512) -> str:
50
+ """メッセージに対する応答を生成"""
51
+ # モデルロード
52
+ self.load_model(model_name)
53
+
54
+ # プロンプト構築
55
+ prompt = self._build_prompt(message, history)
56
+
57
+ # トークナイズ
58
+ inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
59
+
60
+ # 生成
61
+ with torch.no_grad():
62
+ outputs = self.model.generate(
63
+ inputs,
64
+ max_new_tokens=max_tokens,
65
+ temperature=temperature,
66
+ do_sample=True,
67
+ top_p=0.95,
68
+ pad_token_id=self.tokenizer.pad_token_id,
69
+ eos_token_id=self.tokenizer.eos_token_id
70
+ )
71
+
72
+ # デコード
73
+ response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
74
+ return response.strip()
75
+
76
+ def _build_prompt(self, message: str, history: List[Tuple[str, str]]) -> str:
77
+ """会話履歴からプロンプトを構築"""
78
+ prompt = ""
79
+
80
+ # 履歴を追加
81
+ for user_msg, assistant_msg in history[-5:]: # 最新5件の履歴のみ使用
82
+ prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n\n"
83
+
84
+ # 現在のメッセージを追加
85
+ prompt += f"User: {message}\nAssistant: "
86
+
87
+ return prompt
88
+
89
+ # ChatBotインスタンス
90
+ chatbot = ChatBot()
91
+
92
+ def respond(message: str, history: List[Tuple[str, str]], model_name: str,
93
+ temperature: float, max_tokens: int) -> Tuple[List[Tuple[str, str]], str]:
94
+ """Gradioのコールバック関数"""
95
+ if not message:
96
+ return history, ""
97
+
98
+ try:
99
+ # 応答生成
100
+ response = chatbot.generate_response(message, history, model_name, temperature, max_tokens)
101
+
102
+ # 履歴に追加
103
+ history.append((message, response))
104
+
105
+ return history, ""
106
+ except Exception as e:
107
+ error_msg = f"エラーが発生しました: {str(e)}"
108
+ history.append((message, error_msg))
109
+ return history, ""
110
+
111
+ def clear_chat() -> Tuple[List, str]:
112
+ """チャット履歴をクリア"""
113
+ return [], ""
114
+
115
+ # Gradio UI
116
+ with gr.Blocks(title="ChatGPT Clone", theme=gr.themes.Soft()) as app:
117
+ gr.Markdown("# 🤖 ChatGPT Clone")
118
+ gr.Markdown("日本語対応のLLMを使用したチャットボットです。")
119
+
120
+ with gr.Row():
121
+ with gr.Column(scale=3):
122
+ chatbot_ui = gr.Chatbot(
123
+ label="Chat",
124
+ height=500,
125
+ show_label=False,
126
+ container=True
127
+ )
128
+
129
+ with gr.Row():
130
+ msg_input = gr.Textbox(
131
+ label="メッセージを入力",
132
+ placeholder="ここにメッセージを入力してください...",
133
+ lines=2,
134
+ scale=4,
135
+ show_label=False
136
+ )
137
+ send_btn = gr.Button("送信", variant="primary", scale=1)
138
+
139
+ with gr.Row():
140
+ clear_btn = gr.Button("🗑️ 新しい会話", variant="secondary")
141
+
142
+ with gr.Column(scale=1):
143
+ model_select = gr.Dropdown(
144
+ choices=[
145
+ "elyza/Llama-3-ELYZA-JP-8B",
146
+ "Fugaku-LLM/Fugaku-LLM-13B"
147
+ ],
148
+ value="elyza/Llama-3-ELYZA-JP-8B",
149
+ label="モデル選択",
150
+ interactive=True
151
+ )
152
+
153
+ temperature = gr.Slider(
154
+ minimum=0.1,
155
+ maximum=1.0,
156
+ value=0.7,
157
+ step=0.1,
158
+ label="Temperature",
159
+ info="生成の創造性を調整"
160
+ )
161
+
162
+ max_tokens = gr.Slider(
163
+ minimum=64,
164
+ maximum=1024,
165
+ value=512,
166
+ step=64,
167
+ label="最大トークン数",
168
+ info="生成する最大トークン数"
169
+ )
170
+
171
+ gr.Markdown("""
172
+ ### 使い方
173
+ 1. モデルを選択
174
+ 2. メッセージを入力
175
+ 3. 送信ボタンをクリック
176
+
177
+ ### 注意事項
178
+ - 初回のモデル読み込みには時間がかかります
179
+ - GPU使用時は高速に動作します
180
+ """)
181
+
182
+ # イベントハンドラ
183
+ msg_input.submit(
184
+ fn=respond,
185
+ inputs=[msg_input, chatbot_ui, model_select, temperature, max_tokens],
186
+ outputs=[chatbot_ui, msg_input]
187
+ )
188
+
189
+ send_btn.click(
190
+ fn=respond,
191
+ inputs=[msg_input, chatbot_ui, model_select, temperature, max_tokens],
192
+ outputs=[chatbot_ui, msg_input]
193
+ )
194
+
195
+ clear_btn.click(
196
+ fn=clear_chat,
197
+ outputs=[chatbot_ui, msg_input]
198
+ )
199
+
200
+ if __name__ == "__main__":
201
+ app.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio==4.19.2
2
+ transformers==4.38.2
3
+ torch==2.2.1
4
+ accelerate==0.27.2
5
+ sentencepiece==0.2.0
6
+ protobuf==4.25.3