ryoshimu commited on
Commit
2445678
·
1 Parent(s): 06de6e5
Files changed (1) hide show
  1. app.py +88 -36
app.py CHANGED
@@ -1,3 +1,12 @@
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -7,12 +16,18 @@ from typing import List, Tuple
7
  # Hugging Face token from environment variable
8
  HF_TOKEN = os.getenv("HF_TOKEN")
9
 
 
 
 
 
10
  # Check if running on ZeroGPU
11
  try:
12
  import spaces
13
  IS_ZEROGPU = True
 
14
  except ImportError:
15
  IS_ZEROGPU = False
 
16
 
17
  class ChatBot:
18
  def __init__(self):
@@ -25,32 +40,45 @@ class ChatBot:
25
  if self.current_model == model_name and self.model is not None:
26
  return
27
 
28
- # メモリクリア
29
- if self.model is not None:
30
- del self.model
31
- torch.cuda.empty_cache()
32
-
33
- # トークナイザーロード
34
- self.tokenizer = AutoTokenizer.from_pretrained(
35
- model_name,
36
- use_auth_token=HF_TOKEN,
37
- trust_remote_code=True
38
- )
39
-
40
- # パッドトークンの設定
41
- if self.tokenizer.pad_token is None:
42
- self.tokenizer.pad_token = self.tokenizer.eos_token
43
-
44
- # モデルロード(ZeroGPU対応)
45
- self.model = AutoModelForCausalLM.from_pretrained(
46
- model_name,
47
- use_auth_token=HF_TOKEN,
48
- torch_dtype=torch.float16,
49
- low_cpu_mem_usage=True,
50
- trust_remote_code=True
51
- )
52
 
53
- self.current_model = model_name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  def _generate_response_gpu(self, message: str, history: List[Tuple[str, str]], model_name: str,
56
  temperature: float = 0.7, max_tokens: int = 512) -> str:
@@ -75,6 +103,8 @@ class ChatBot:
75
  temperature=temperature,
76
  do_sample=True,
77
  top_p=0.95,
 
 
78
  pad_token_id=self.tokenizer.pad_token_id,
79
  eos_token_id=self.tokenizer.eos_token_id
80
  )
@@ -85,6 +115,7 @@ class ChatBot:
85
  # CPUに戻す(メモリ節約)
86
  self.model.to('cpu')
87
  torch.cuda.empty_cache()
 
88
 
89
  return response.strip()
90
 
@@ -112,6 +143,8 @@ class ChatBot:
112
  temperature=temperature,
113
  do_sample=True,
114
  top_p=0.95,
 
 
115
  pad_token_id=self.tokenizer.pad_token_id,
116
  eos_token_id=self.tokenizer.eos_token_id
117
  )
@@ -123,8 +156,8 @@ class ChatBot:
123
  """会話履歴からプロンプトを構築"""
124
  prompt = ""
125
 
126
- # 履歴を追加
127
- for user_msg, assistant_msg in history[-5:]: # 最新5件の履歴のみ使用
128
  prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n\n"
129
 
130
  # 現在のメッセージを追加
@@ -137,7 +170,7 @@ chatbot = ChatBot()
137
 
138
  # ZeroGPU環境の場合、GPUデコレータを適用
139
  if IS_ZEROGPU:
140
- chatbot._generate_response_gpu = spaces.GPU(duration=60)(chatbot._generate_response_gpu)
141
 
142
  def respond(message: str, history: List[Tuple[str, str]], model_name: str,
143
  temperature: float, max_tokens: int) -> Tuple[List[Tuple[str, str]], str]:
@@ -152,6 +185,13 @@ def respond(message: str, history: List[Tuple[str, str]], model_name: str,
152
  # 履歴に追加
153
  history.append((message, response))
154
 
 
 
 
 
 
 
 
155
  return history, ""
156
  except Exception as e:
157
  error_msg = f"エラーが発生しました: {str(e)}"
@@ -165,7 +205,13 @@ def clear_chat() -> Tuple[List, str]:
165
  # Gradio UI
166
  with gr.Blocks(title="ChatGPT Clone", theme=gr.themes.Soft()) as app:
167
  gr.Markdown("# 🤖 ChatGPT Clone")
168
- gr.Markdown("日本語対応のLLMを使用したチャットボットです。")
 
 
 
 
 
 
169
 
170
  with gr.Row():
171
  with gr.Column(scale=3):
@@ -192,10 +238,10 @@ with gr.Blocks(title="ChatGPT Clone", theme=gr.themes.Soft()) as app:
192
  with gr.Column(scale=1):
193
  model_select = gr.Dropdown(
194
  choices=[
195
- "rinna/japanese-gpt2-medium",
196
- "cyberagent/open-calm-small"
197
  ],
198
- value="rinna/japanese-gpt2-medium",
199
  label="モデル選択",
200
  interactive=True
201
  )
@@ -211,8 +257,8 @@ with gr.Blocks(title="ChatGPT Clone", theme=gr.themes.Soft()) as app:
211
 
212
  max_tokens = gr.Slider(
213
  minimum=64,
214
- maximum=1024,
215
- value=512,
216
  step=64,
217
  label="最大トークン数",
218
  info="生成する最大トークン数"
@@ -227,7 +273,8 @@ with gr.Blocks(title="ChatGPT Clone", theme=gr.themes.Soft()) as app:
227
  ### 注意事項
228
  - 初回のモデル読み込みには時間がかかります
229
  - ZeroGPU使用により高速推論が可能
230
- - 1回の生成は60秒以内に完了します
 
231
  """)
232
 
233
  # イベントハンドラ
@@ -249,7 +296,12 @@ with gr.Blocks(title="ChatGPT Clone", theme=gr.themes.Soft()) as app:
249
  )
250
 
251
  if __name__ == "__main__":
 
 
 
252
  app.launch(
253
  share=False,
254
- show_error=True
 
 
255
  )
 
1
+ """
2
+ ChatGPT Clone - 日本語対応チャットボット
3
+ Hugging Face Spaces (ZeroGPU) 対応版
4
+
5
+ 使用モデル:
6
+ - elyza/Llama-3-ELYZA-JP-8B
7
+ - Fugaku-LLM/Fugaku-LLM-13B
8
+ """
9
+
10
  import gradio as gr
11
  import torch
12
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
16
  # Hugging Face token from environment variable
17
  HF_TOKEN = os.getenv("HF_TOKEN")
18
 
19
+ # トークンのチェック
20
+ if not HF_TOKEN:
21
+ print("警告: HF_TOKENが設定されていません。プライベートモデルへのアクセスが制限される場合があります。")
22
+
23
  # Check if running on ZeroGPU
24
  try:
25
  import spaces
26
  IS_ZEROGPU = True
27
+ print("ZeroGPU環境を検出しました。")
28
  except ImportError:
29
  IS_ZEROGPU = False
30
+ print("通常のGPU/CPU環境で実行しています。")
31
 
32
  class ChatBot:
33
  def __init__(self):
 
40
  if self.current_model == model_name and self.model is not None:
41
  return
42
 
43
+ try:
44
+ # メモリクリア
45
+ if self.model is not None:
46
+ del self.model
47
+ del self.tokenizer
48
+ if torch.cuda.is_available():
49
+ torch.cuda.empty_cache()
50
+ torch.cuda.synchronize()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ # トークナイザーロード
53
+ self.tokenizer = AutoTokenizer.from_pretrained(
54
+ model_name,
55
+ token=HF_TOKEN,
56
+ trust_remote_code=True,
57
+ padding_side="left"
58
+ )
59
+
60
+ # パッドトークンの設定
61
+ if self.tokenizer.pad_token is None:
62
+ self.tokenizer.pad_token = self.tokenizer.eos_token
63
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
64
+
65
+ # モデルロード(ZeroGPU対応)
66
+ self.model = AutoModelForCausalLM.from_pretrained(
67
+ model_name,
68
+ token=HF_TOKEN,
69
+ torch_dtype=torch.float16,
70
+ low_cpu_mem_usage=True,
71
+ trust_remote_code=True,
72
+ load_in_8bit=False, # ZeroGPU環境では8bit量子化は使わない
73
+ device_map=None # ZeroGPU環境では自動マッピングしない
74
+ )
75
+
76
+ self.current_model = model_name
77
+ print(f"モデル {model_name} のロードが完了しました。")
78
+
79
+ except Exception as e:
80
+ print(f"モデルのロード中にエラーが発生しました: {str(e)}")
81
+ raise
82
 
83
  def _generate_response_gpu(self, message: str, history: List[Tuple[str, str]], model_name: str,
84
  temperature: float = 0.7, max_tokens: int = 512) -> str:
 
103
  temperature=temperature,
104
  do_sample=True,
105
  top_p=0.95,
106
+ top_k=50,
107
+ repetition_penalty=1.1,
108
  pad_token_id=self.tokenizer.pad_token_id,
109
  eos_token_id=self.tokenizer.eos_token_id
110
  )
 
115
  # CPUに戻す(メモリ節約)
116
  self.model.to('cpu')
117
  torch.cuda.empty_cache()
118
+ torch.cuda.synchronize()
119
 
120
  return response.strip()
121
 
 
143
  temperature=temperature,
144
  do_sample=True,
145
  top_p=0.95,
146
+ top_k=50,
147
+ repetition_penalty=1.1,
148
  pad_token_id=self.tokenizer.pad_token_id,
149
  eos_token_id=self.tokenizer.eos_token_id
150
  )
 
156
  """会話履歴からプロンプトを構築"""
157
  prompt = ""
158
 
159
+ # 履歴を追加(最新3件のみ使用 - メモリ効率のため)
160
+ for user_msg, assistant_msg in history[-3:]:
161
  prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n\n"
162
 
163
  # 現在のメッセージを追加
 
170
 
171
  # ZeroGPU環境の場合、GPUデコレータを適用
172
  if IS_ZEROGPU:
173
+ chatbot._generate_response_gpu = spaces.GPU(duration=120)(chatbot._generate_response_gpu)
174
 
175
  def respond(message: str, history: List[Tuple[str, str]], model_name: str,
176
  temperature: float, max_tokens: int) -> Tuple[List[Tuple[str, str]], str]:
 
185
  # 履歴に追加
186
  history.append((message, response))
187
 
188
+ return history, ""
189
+ except RuntimeError as e:
190
+ if "out of memory" in str(e).lower():
191
+ error_msg = "メモリ不足エラー: より小さいモデルを使用するか、最大トークン数を減らしてください。"
192
+ else:
193
+ error_msg = f"実行時エラー: {str(e)}"
194
+ history.append((message, error_msg))
195
  return history, ""
196
  except Exception as e:
197
  error_msg = f"エラーが発生しました: {str(e)}"
 
205
  # Gradio UI
206
  with gr.Blocks(title="ChatGPT Clone", theme=gr.themes.Soft()) as app:
207
  gr.Markdown("# 🤖 ChatGPT Clone")
208
+ gr.Markdown("""
209
+ 日本語対応のLLMを使用したチャットボットです。
210
+
211
+ **使用可能モデル:**
212
+ - [elyza/Llama-3-ELYZA-JP-8B](https://huggingface.co/elyza/Llama-3-ELYZA-JP-8B)
213
+ - [Fugaku-LLM/Fugaku-LLM-13B](https://huggingface.co/Fugaku-LLM/Fugaku-LLM-13B)
214
+ """)
215
 
216
  with gr.Row():
217
  with gr.Column(scale=3):
 
238
  with gr.Column(scale=1):
239
  model_select = gr.Dropdown(
240
  choices=[
241
+ "elyza/Llama-3-ELYZA-JP-8B",
242
+ "Fugaku-LLM/Fugaku-LLM-13B"
243
  ],
244
+ value="elyza/Llama-3-ELYZA-JP-8B",
245
  label="モデル選択",
246
  interactive=True
247
  )
 
257
 
258
  max_tokens = gr.Slider(
259
  minimum=64,
260
+ maximum=512,
261
+ value=256,
262
  step=64,
263
  label="最大トークン数",
264
  info="生成する最大トークン数"
 
273
  ### 注意事項
274
  - 初回のモデル読み込みには時間がかかります
275
  - ZeroGPU使用により高速推論が可能
276
+ - 1回の生成は120秒以内に完了します
277
+ - 大きなモデル使用時は、短めの応答になる場合があります
278
  """)
279
 
280
  # イベントハンドラ
 
296
  )
297
 
298
  if __name__ == "__main__":
299
+ # Hugging Face Spaces環境かどうかを確認
300
+ is_hf_spaces = os.getenv("SPACE_ID") is not None
301
+
302
  app.launch(
303
  share=False,
304
+ show_error=True,
305
+ server_name="0.0.0.0" if is_hf_spaces else "127.0.0.1",
306
+ server_port=7860
307
  )