ryoshimu committed on
Commit
125a238
·
1 Parent(s): 6246717
Files changed (1) hide show
  1. app.py +52 -9
app.py CHANGED
@@ -3,11 +3,17 @@ import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import os
5
  from typing import List, Tuple
6
- import spaces
7
 
8
  # Hugging Face token from environment variable
9
  HF_TOKEN = os.getenv("HF_TOKEN")
10
 
 
 
 
 
 
 
 
11
  class ChatBot:
12
  def __init__(self):
13
  self.model = None
@@ -46,10 +52,9 @@ class ChatBot:
46
 
47
  self.current_model = model_name
48
 
49
- @spaces.GPU(duration=60)
50
- def generate_response(self, message: str, history: List[Tuple[str, str]], model_name: str,
51
- temperature: float = 0.7, max_tokens: int = 512) -> str:
52
- """メッセージに対する応答を生成"""
53
  # モデルロード
54
  self.load_model(model_name)
55
 
@@ -82,6 +87,37 @@ class ChatBot:
82
  torch.cuda.empty_cache()
83
 
84
  return response.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  def _build_prompt(self, message: str, history: List[Tuple[str, str]]) -> str:
87
  """会話履歴からプロンプトを構築"""
@@ -99,6 +135,10 @@ class ChatBot:
99
  # ChatBotインスタンス
100
  chatbot = ChatBot()
101
 
 
 
 
 
102
  def respond(message: str, history: List[Tuple[str, str]], model_name: str,
103
  temperature: float, max_tokens: int) -> Tuple[List[Tuple[str, str]], str]:
104
  """Gradioのコールバック関数"""
@@ -152,10 +192,10 @@ with gr.Blocks(title="ChatGPT Clone", theme=gr.themes.Soft()) as app:
152
  with gr.Column(scale=1):
153
  model_select = gr.Dropdown(
154
  choices=[
155
- "elyza/Llama-3-ELYZA-JP-8B",
156
- "cyberagent/open-calm-7b"
157
  ],
158
- value="elyza/Llama-3-ELYZA-JP-8B",
159
  label="モデル選択",
160
  interactive=True
161
  )
@@ -209,4 +249,7 @@ with gr.Blocks(title="ChatGPT Clone", theme=gr.themes.Soft()) as app:
209
  )
210
 
211
  if __name__ == "__main__":
212
- app.launch()
 
 
 
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import os
5
  from typing import List, Tuple
 
6
 
7
  # Hugging Face token from environment variable
8
  HF_TOKEN = os.getenv("HF_TOKEN")
9
 
10
+ # Check if running on ZeroGPU
11
+ try:
12
+ import spaces
13
+ IS_ZEROGPU = True
14
+ except ImportError:
15
+ IS_ZEROGPU = False
16
+
17
  class ChatBot:
18
  def __init__(self):
19
  self.model = None
 
52
 
53
  self.current_model = model_name
54
 
55
+ def _generate_response_gpu(self, message: str, history: List[Tuple[str, str]], model_name: str,
56
+ temperature: float = 0.7, max_tokens: int = 512) -> str:
57
+ """GPU上で応答を生成する実際の処理"""
 
58
  # モデルロード
59
  self.load_model(model_name)
60
 
 
87
  torch.cuda.empty_cache()
88
 
89
  return response.strip()
90
+
91
+ def generate_response(self, message: str, history: List[Tuple[str, str]], model_name: str,
92
+ temperature: float = 0.7, max_tokens: int = 512) -> str:
93
+ """メッセージに対する応答を生成"""
94
+ if IS_ZEROGPU:
95
+ # ZeroGPU環境の場合
96
+ return self._generate_response_gpu(message, history, model_name, temperature, max_tokens)
97
+ else:
98
+ # 通常環境の場合
99
+ self.load_model(model_name)
100
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
101
+
102
+ if device == 'cuda':
103
+ self.model.to(device)
104
+
105
+ prompt = self._build_prompt(message, history)
106
+ inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(device)
107
+
108
+ with torch.no_grad():
109
+ outputs = self.model.generate(
110
+ inputs,
111
+ max_new_tokens=max_tokens,
112
+ temperature=temperature,
113
+ do_sample=True,
114
+ top_p=0.95,
115
+ pad_token_id=self.tokenizer.pad_token_id,
116
+ eos_token_id=self.tokenizer.eos_token_id
117
+ )
118
+
119
+ response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
120
+ return response.strip()
121
 
122
  def _build_prompt(self, message: str, history: List[Tuple[str, str]]) -> str:
123
  """会話履歴からプロンプトを構築"""
 
135
  # ChatBotインスタンス
136
  chatbot = ChatBot()
137
 
138
+ # ZeroGPU環境の場合、GPUデコレータを適用
139
+ if IS_ZEROGPU:
140
+ chatbot._generate_response_gpu = spaces.GPU(duration=60)(chatbot._generate_response_gpu)
141
+
142
  def respond(message: str, history: List[Tuple[str, str]], model_name: str,
143
  temperature: float, max_tokens: int) -> Tuple[List[Tuple[str, str]], str]:
144
  """Gradioのコールバック関数"""
 
192
  with gr.Column(scale=1):
193
  model_select = gr.Dropdown(
194
  choices=[
195
+ "rinna/japanese-gpt2-medium",
196
+ "cyberagent/open-calm-small"
197
  ],
198
+ value="rinna/japanese-gpt2-medium",
199
  label="モデル選択",
200
  interactive=True
201
  )
 
249
  )
250
 
251
  if __name__ == "__main__":
252
+ app.launch(
253
+ share=False,
254
+ show_error=True
255
+ )