Alysha Creelman commited on
Commit
918a9fb
Β·
unverified Β·
1 Parent(s): 6abafc6

Putting back old button code

Browse files
Files changed (1) hide show
  1. app.py +21 -5
app.py CHANGED
@@ -4,12 +4,15 @@ import torch
4
  from transformers import pipeline
5
  import os
6
 
 
7
  token = os.getenv('HF_TOKEN')
8
  client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta", token=token)
9
  pipe = pipeline("text-generation", "microsoft/Phi-3-mini-4k-instruct", torch_dtype=torch.bfloat16, device_map="auto")
10
 
 
11
  stop_inference = False
12
 
 
13
  def respond(
14
  message,
15
  history: list[tuple[str, str]],
@@ -20,12 +23,14 @@ def respond(
20
  use_local_model=False,
21
  ):
22
  global stop_inference
23
- stop_inference = False
24
 
 
25
  if history is None:
26
  history = []
27
 
28
  if use_local_model:
 
29
  messages = [{"role": "system", "content": system_message}]
30
  for val in history:
31
  if val[0]:
@@ -48,9 +53,10 @@ def respond(
48
  return
49
  token = output['generated_text'][-1]['content']
50
  response += token
51
- yield history + [(message, response)]
52
 
53
  else:
 
54
  messages = [{"role": "system", "content": system_message}]
55
  for val in history:
56
  if val[0]:
@@ -76,12 +82,14 @@ def respond(
76
  break
77
  token = message_chunk.choices[0].delta.content
78
  response += token
79
- yield history + [(message, response)]
 
80
 
81
  def cancel_inference():
82
  global stop_inference
83
  stop_inference = True
84
 
 
85
  custom_css = """
86
  #main-container {
87
  background: #cdebc5;
@@ -127,6 +135,7 @@ custom_css = """
127
  }
128
  """
129
 
 
130
  def update_system_message(level):
131
  if level == "Elementary School":
132
  return "Your name is Wormington. You are a friendly Chatbot that can help answer questions from elementary school students. Please respond with the vocabulary that a seven-year-old can understand."
@@ -137,6 +146,7 @@ def update_system_message(level):
137
  elif level == "College":
138
  return "Your name is Wormington. You are a friendly Chatbot that can help answer questions from college students. Please respond using very advanced, college-level vocabulary."
139
 
 
140
  with gr.Blocks(css=custom_css) as demo:
141
  gr.Markdown("<h2 style='text-align: center;'>🍎✏️ School AI Chatbot ✏️🍎</h2>")
142
  gr.Image("wormington_headshot.jpg", elem_id="school_ai_image", show_label=False, interactive=False)
@@ -148,8 +158,10 @@ with gr.Blocks(css=custom_css) as demo:
148
  high_button = gr.Button("High School", elem_id="high", variant="primary")
149
  college_button = gr.Button("College", elem_id="college", variant="primary")
150
 
 
151
  system_message_display = gr.Textbox(label="System Message", value="", interactive=False)
152
 
 
153
  elementary_button.click(fn=lambda: update_system_message("Elementary School"), inputs=None, outputs=system_message_display)
154
  middle_button.click(fn=lambda: update_system_message("Middle School"), inputs=None, outputs=system_message_display)
155
  high_button.click(fn=lambda: update_system_message("High School"), inputs=None, outputs=system_message_display)
@@ -157,6 +169,7 @@ with gr.Blocks(css=custom_css) as demo:
157
 
158
  with gr.Row():
159
  use_local_model = gr.Checkbox(label="Use Local Model", value=False)
 
160
 
161
  with gr.Row():
162
  max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
@@ -169,9 +182,12 @@ with gr.Blocks(css=custom_css) as demo:
169
 
170
  cancel_button = gr.Button("Cancel Inference", variant="danger")
171
 
172
- user_input.submit(fn=respond, inputs=[user_input, chat_history, system_message_display, max_tokens, temperature, top_p, use_local_model], outputs=chat_history)
 
173
 
174
  cancel_button.click(cancel_inference)
175
 
 
 
176
  if __name__ == "__main__":
177
- demo.launch(share=False)
 
4
  from transformers import pipeline
5
  import os
6
 
7
+ # Inference client setup with token from environment
8
  token = os.getenv('HF_TOKEN')
9
  client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta", token=token)
10
  pipe = pipeline("text-generation", "microsoft/Phi-3-mini-4k-instruct", torch_dtype=torch.bfloat16, device_map="auto")
11
 
12
+ # Global flag to handle cancellation
13
  stop_inference = False
14
 
15
+
16
  def respond(
17
  message,
18
  history: list[tuple[str, str]],
 
23
  use_local_model=False,
24
  ):
25
  global stop_inference
26
+ stop_inference = False # Reset cancellation flag
27
 
28
+ # Initialize history if it's None
29
  if history is None:
30
  history = []
31
 
32
  if use_local_model:
33
+ # local inference
34
  messages = [{"role": "system", "content": system_message}]
35
  for val in history:
36
  if val[0]:
 
53
  return
54
  token = output['generated_text'][-1]['content']
55
  response += token
56
+ yield history + [(message, response)] # Yield history + new response
57
 
58
  else:
59
+ # API-based inference
60
  messages = [{"role": "system", "content": system_message}]
61
  for val in history:
62
  if val[0]:
 
82
  break
83
  token = message_chunk.choices[0].delta.content
84
  response += token
85
+ yield history + [(message, response)] # Yield history + new response
86
+
87
 
88
  def cancel_inference():
89
  global stop_inference
90
  stop_inference = True
91
 
92
+ # Custom CSS for a fancy look
93
  custom_css = """
94
  #main-container {
95
  background: #cdebc5;
 
135
  }
136
  """
137
 
138
+ # Define system messages for each level
139
  def update_system_message(level):
140
  if level == "Elementary School":
141
  return "Your name is Wormington. You are a friendly Chatbot that can help answer questions from elementary school students. Please respond with the vocabulary that a seven-year-old can understand."
 
146
  elif level == "College":
147
  return "Your name is Wormington. You are a friendly Chatbot that can help answer questions from college students. Please respond using very advanced, college-level vocabulary."
148
 
149
+ # Define interface
150
  with gr.Blocks(css=custom_css) as demo:
151
  gr.Markdown("<h2 style='text-align: center;'>🍎✏️ School AI Chatbot ✏️🍎</h2>")
152
  gr.Image("wormington_headshot.jpg", elem_id="school_ai_image", show_label=False, interactive=False)
 
158
  high_button = gr.Button("High School", elem_id="high", variant="primary")
159
  college_button = gr.Button("College", elem_id="college", variant="primary")
160
 
161
+ # Display area for the selected system message
162
  system_message_display = gr.Textbox(label="System Message", value="", interactive=False)
163
 
164
+ # Update the system message when a button is clicked
165
  elementary_button.click(fn=lambda: update_system_message("Elementary School"), inputs=None, outputs=system_message_display)
166
  middle_button.click(fn=lambda: update_system_message("Middle School"), inputs=None, outputs=system_message_display)
167
  high_button.click(fn=lambda: update_system_message("High School"), inputs=None, outputs=system_message_display)
 
169
 
170
  with gr.Row():
171
  use_local_model = gr.Checkbox(label="Use Local Model", value=False)
172
+
173
 
174
  with gr.Row():
175
  max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
 
182
 
183
  cancel_button = gr.Button("Cancel Inference", variant="danger")
184
 
185
+ # Adjusted to ensure history is maintained and passed correctly
186
+ user_input.submit(respond, [user_input, chat_history, system_message_display, max_tokens, temperature, top_p, use_local_model], chat_history)
187
 
188
  cancel_button.click(cancel_inference)
189
 
190
+
191
+
192
  if __name__ == "__main__":
193
+ demo.launch(share=False) # Remove share=True because it's not supported on HF Spaces