AnilNiraula committed on
Commit
ecb3e47
·
verified ·
1 Parent(s): f395a8f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +269 -53
app.py CHANGED
@@ -14,6 +14,15 @@ import torch
14
  import yfinance as yf
15
  from datetime import datetime, timedelta
16
  from math import sqrt
 
 
 
 
 
 
 
 
 
17
 
18
  # Set up logging
19
  logging.basicConfig(level=logging.INFO)
@@ -48,21 +57,20 @@ try:
48
  from PIL import Image
49
  import io
50
  except ModuleNotFoundError:
51
- subprocess.check_call([sys.executable, "-m", "pip", "install", "matplotlib", "pillow"])
52
  import matplotlib.pyplot as plt
53
  from PIL import Image
54
  import io
 
55
 
56
  MAX_MAX_NEW_TOKENS = 512
57
  DEFAULT_MAX_NEW_TOKENS = 512
58
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "1024"))
59
 
60
  DESCRIPTION = """# FinChat: Investing Q&A (Optimized for Speed)
61
- This application delivers an interactive chat interface powered by a highly efficient, small AI model adapted for addressing investing and finance inquiries through specialized prompt engineering. It ensures rapid, reasoned responses to user queries. Duplicate this Space for customization or queue-free deployment.
62
- <p>Running on CPU or GPU if available. Using Phi-2 model for faster inference. Inference is heavily optimized for responses in under 10 seconds for simple queries, with output limited to 250 tokens maximum. For longer responses, increase 'Max New Tokens' in Advanced Settings. Brief delays may occur in free-tier environments due to shared resources, but typical generation speeds are improved with the smaller model.</p>"""
63
 
64
- LICENSE = """<p/>
65
- ---
66
  This application employs the Phi-2 model, governed by Microsoft's Terms of Use. Refer to the [model card](https://huggingface.co/TheBloke/phi-2-GGUF) for details."""
67
 
68
  # Load the model (skip fine-tuning for faster startup)
@@ -75,12 +83,12 @@ try:
75
  llm = Llama(
76
  model_path=model_path,
77
  n_ctx=1024,
78
- n_batch=1024, # Increased for faster processing
79
  n_threads=multiprocessing.cpu_count(),
80
  n_gpu_layers=n_gpu_layers,
81
- chat_format="chatml" # Phi-2 uses ChatML format in llama.cpp
82
  )
83
- logger.info(f"Model loaded successfully with n_gpu_layers= {n_gpu_layers}.")
84
  # Warm up the model for faster initial inference
85
  llm("Warm-up prompt", max_tokens=1, echo=False)
86
  logger.info("Model warm-up completed.")
@@ -101,6 +109,8 @@ Assistant:
101
  - Represents average annual return with compounding
102
  - Past performance is not indicative of future results."""
103
 
 
 
104
  # Function to calculate CAGR using yfinance
105
  def calculate_cagr(ticker, start_date, end_date):
106
  try:
@@ -116,7 +126,7 @@ def calculate_cagr(ticker, start_date, end_date):
116
  logger.error(f"Error calculating CAGR for {ticker}: {str(e)}")
117
  return None
118
 
119
- # New function to calculate risk metrics using yfinance
120
  def calculate_risk_metrics(ticker, years=5):
121
  try:
122
  end_date = datetime.now().strftime('%Y-%m-%d')
@@ -134,6 +144,68 @@ def calculate_risk_metrics(ticker, years=5):
134
  logger.error(f"Error calculating risk metrics for {ticker}: {str(e)}")
135
  return None, None
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  # Assuming the generate function handles the chat logic (extended to include risk comparison)
138
  def generate(
139
  message: str,
@@ -143,10 +215,13 @@ def generate(
143
  temperature: float,
144
  top_p: float,
145
  top_k: int,
146
- ) -> Iterator[str]:
 
 
147
  if not system_prompt:
148
  system_prompt = DEFAULT_SYSTEM_PROMPT
149
 
 
150
  # Detect CAGR query
151
  cagr_match = re.search(r'average return for (\w+) between (\d{4}) and (\d{4})', message.lower())
152
  if cagr_match:
@@ -157,11 +232,23 @@ def generate(
157
  end_date = f"{end_year}-12-31"
158
  cagr = calculate_cagr(ticker, start_date, end_date)
159
  if cagr is not None:
160
- yield f"- {ticker} CAGR ({start_year}-{end_year}): ~{cagr:.2f}%\n- Represents average annual return with compounding\n- Past performance is not indicative of future results.\n- Consult a financial advisor for personalized advice."
161
- return
 
162
  else:
163
- yield "Unable to calculate CAGR for the specified period."
164
- return
 
 
 
 
 
 
 
 
 
 
 
165
 
166
  # Detect risk comparison query
167
  risk_match = re.search(r'which stock is riskier (\w+) or (\w+)', message.lower())
@@ -171,29 +258,120 @@ def generate(
171
  vol1, sharpe1 = calculate_risk_metrics(ticker1)
172
  vol2, sharpe2 = calculate_risk_metrics(ticker2)
173
  if vol1 is None or vol2 is None:
174
- yield "Unable to fetch risk metrics for one or both tickers."
175
- return
176
- if vol1 > vol2:
177
- riskier = ticker1
178
- less_risky = ticker2
179
- higher_vol = vol1
180
- lower_vol = vol2
181
- riskier_sharpe = sharpe1
182
- less_sharpe = sharpe2
183
  else:
184
- riskier = ticker2
185
- less_risky = ticker1
186
- higher_vol = vol2
187
- lower_vol = vol1
188
- riskier_sharpe = sharpe2
189
- less_sharpe = sharpe1
190
- yield f"- {riskier} is riskier compared to {less_risky}.\n- It has a higher annualized standard deviation ({higher_vol:.2f}% vs {lower_vol:.2f}%) and a lower Sharpe ratio ({riskier_sharpe:.2f} vs {less_sharpe:.2f}), indicating greater volatility and potentially lower risk-adjusted returns.\n- Calculations based on the past 5 years of data.\n- Past performance is not indicative of future results. Consult a financial advisor for personalized advice."
191
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
  # For other queries, fall back to LLM generation
194
  conversation = [{"role": "system", "content": system_prompt}]
195
  for user, assistant in history:
196
- conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
 
 
 
197
  conversation.append({"role": "user", "content": message})
198
 
199
  # Generate response using LLM (streamed)
@@ -205,35 +383,73 @@ def generate(
205
  top_k=top_k,
206
  stream=True
207
  )
208
-
209
  partial_text = ""
210
  for chunk in response:
211
  if "content" in chunk["choices"][0]["delta"]:
212
  partial_text += chunk["choices"][0]["delta"]["content"]
213
  yield partial_text
 
 
 
 
 
 
 
 
 
 
 
214
 
215
- # Gradio interface setup (assuming this is part of the original code)
 
 
 
 
 
 
216
  with gr.Blocks(theme=themes.Default()) as demo:
217
  gr.Markdown(DESCRIPTION)
218
  gr.Markdown(LICENSE)
219
-
220
- chatbot = gr.Chatbot()
221
- msg = gr.Textbox(label="Enter your question")
222
- with gr.Row():
223
- submit = gr.Button("Submit")
224
- clear = gr.Button("Clear")
225
-
226
- advanced = gr.Accordion("Advanced Settings", open=False)
227
- with advanced:
228
- system_prompt = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=6)
229
- max_new_tokens = gr.Slider(minimum=1, maximum=MAX_MAX_NEW_TOKENS, value=DEFAULT_MAX_NEW_TOKENS, step=1, label="Max New Tokens")
230
- temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, label="Temperature")
231
- top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.1, label="Top P")
232
- top_k = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top K")
233
-
234
- submit.click(generate, [msg, chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k], chatbot, queue=False).then(
235
- lambda: "", None, msg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  )
237
- clear.click(lambda: None, None, chatbot)
 
238
 
239
- demo.launch()
 
14
  import yfinance as yf
15
  from datetime import datetime, timedelta
16
  from math import sqrt
17
+ import time
18
+ import base64
19
+ import io
20
+ import numpy as np
21
+ try:
22
+ import scipy.optimize as opt
23
+ except ModuleNotFoundError:
24
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "scipy"])
25
+ import scipy.optimize as opt
26
 
27
  # Set up logging
28
  logging.basicConfig(level=logging.INFO)
 
57
  from PIL import Image
58
  import io
59
  except ModuleNotFoundError:
60
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "matplotlib", "pillow", "numpy"])
61
  import matplotlib.pyplot as plt
62
  from PIL import Image
63
  import io
64
+ import numpy as np
65
 
66
  MAX_MAX_NEW_TOKENS = 512
67
  DEFAULT_MAX_NEW_TOKENS = 512
68
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "1024"))
69
 
70
  DESCRIPTION = """# FinChat: Investing Q&A (Optimized for Speed)
71
+ This application delivers an interactive chat interface powered by a highly efficient, small AI model adapted for addressing investing and finance inquiries through specialized prompt engineering. It ensures rapid, reasoned responses to user queries. Duplicate this Space for customization or queue-free deployment.<p>Running on CPU or GPU if available. Using Phi-2 model for faster inference. Inference is heavily optimized for responses in under 10 seconds for simple queries, with output limited to 250 tokens maximum. For longer responses, increase 'Max New Tokens' in Advanced Settings. Brief delays may occur in free-tier environments due to shared resources, but typical generation speeds are improved with the smaller model.</p>"""
 
72
 
73
+ LICENSE = """<p/>---
 
74
  This application employs the Phi-2 model, governed by Microsoft's Terms of Use. Refer to the [model card](https://huggingface.co/TheBloke/phi-2-GGUF) for details."""
75
 
76
  # Load the model (skip fine-tuning for faster startup)
 
83
  llm = Llama(
84
  model_path=model_path,
85
  n_ctx=1024,
86
+ n_batch=1024, # Increased for faster processing
87
  n_threads=multiprocessing.cpu_count(),
88
  n_gpu_layers=n_gpu_layers,
89
+ chat_format="chatml" # Phi-2 uses ChatML format in llama.cpp
90
  )
91
+ logger.info(f"Model loaded successfully with n_gpu_layers={n_gpu_layers}.")
92
  # Warm up the model for faster initial inference
93
  llm("Warm-up prompt", max_tokens=1, echo=False)
94
  logger.info("Model warm-up completed.")
 
109
  - Represents average annual return with compounding
110
  - Past performance is not indicative of future results."""
111
 
112
+ logs = []
113
+
114
  # Function to calculate CAGR using yfinance
115
  def calculate_cagr(ticker, start_date, end_date):
116
  try:
 
126
  logger.error(f"Error calculating CAGR for {ticker}: {str(e)}")
127
  return None
128
 
129
+ # Function to calculate risk metrics using yfinance
130
  def calculate_risk_metrics(ticker, years=5):
131
  try:
132
  end_date = datetime.now().strftime('%Y-%m-%d')
 
144
  logger.error(f"Error calculating risk metrics for {ticker}: {str(e)}")
145
  return None, None
146
 
147
# Function for inline plot
def generate_plot(ticker, period='5y'):
    """Render a price-history chart for *ticker* as an inline Markdown image.

    Parameters:
        ticker: stock symbol understood by yfinance (e.g. "AAPL").
        period: yfinance period string (default '5y').

    Returns:
        A Markdown string embedding the PNG as a base64 data URI, or a
        plain error-message string if data cannot be fetched or plotted.
    """
    try:
        data = yf.download(ticker, period=period)
        if data.empty:
            return "Unable to fetch data for plotting."
        # Newer yfinance releases default to auto_adjust=True and omit the
        # 'Adj Close' column; fall back to 'Close' so the plot still works.
        price_col = 'Adj Close' if 'Adj Close' in data.columns else 'Close'
        plt.figure(figsize=(10, 5))
        try:
            plt.plot(data[price_col], label='Adjusted Close')
            plt.title(f'{ticker} Price History ({period})')
            plt.xlabel('Date')
            plt.ylabel('Price (USD)')
            plt.legend()
            plt.grid(True)
            buf = io.BytesIO()
            plt.savefig(buf, format='png', bbox_inches='tight')
        finally:
            # Always release the figure, even on error, so repeated calls
            # do not leak matplotlib state/memory.
            plt.close()
        buf.seek(0)
        b64 = base64.b64encode(buf.read()).decode('utf-8')
        return f"![{ticker} Price Chart](data:image/png;base64,{b64})"
    except Exception as e:
        logger.error(f"Error generating plot for {ticker}: {str(e)}")
        return "Error generating plot."
169
+
170
# Function for portfolio optimization using scipy
def portfolio_optimization(tickers, target_return=None):
    """Suggest portfolio weights via mean-variance optimization.

    Parameters:
        tickers: list of ticker symbols.
        target_return: optional annualized return target. When given, the
            negative Sharpe ratio is minimized subject to achieving that
            return; otherwise portfolio volatility is minimized.

    Returns:
        dict mapping ticker -> weight. Falls back to equal weights if the
        optimizer fails or data cannot be fetched.
    """
    try:
        data = yf.download(tickers, period='5y')
        # Newer yfinance auto-adjusts prices and drops 'Adj Close';
        # fall back to 'Close' in that case.
        data = data['Adj Close'] if 'Adj Close' in data else data['Close']
        returns = data.pct_change().dropna()
        mean_returns = returns.mean() * 252  # annualize daily mean returns
        cov_matrix = returns.cov() * 252     # annualize daily covariance
        num_assets = len(tickers)

        def portfolio_volatility(weights):
            # Annualized standard deviation of the weighted portfolio.
            return np.sqrt(np.dot(weights.T, np.dot(cov_matrix, weights)))

        weights_sum_to_one = {'type': 'eq', 'fun': lambda x: np.sum(x) - 1}
        bounds = tuple((0, 1) for _ in range(num_assets))  # long-only weights
        initial_guess = np.array(num_assets * [1. / num_assets])

        # Compare against None explicitly: the original truthiness test
        # silently ignored a legitimate target_return of 0.0.
        if target_return is not None:
            def objective(weights):
                ret = np.sum(mean_returns * weights)
                vol = portfolio_volatility(weights)
                # Negative Sharpe ratio; 0.02 is the assumed risk-free rate.
                return -(ret - 0.02) / vol if vol != 0 else np.inf
            cons = [
                weights_sum_to_one,
                {'type': 'eq', 'fun': lambda x: np.sum(mean_returns * x) - target_return},
            ]
            result = opt.minimize(objective, initial_guess, method='SLSQP',
                                  bounds=bounds, constraints=cons)
        else:
            # No target: simply minimize volatility.
            result = opt.minimize(portfolio_volatility, initial_guess, method='SLSQP',
                                  bounds=bounds, constraints=[weights_sum_to_one])

        if result.success:
            return dict(zip(tickers, result.x))
        # Optimizer did not converge: fall back to an equal-weight portfolio.
        return {ticker: 1 / len(tickers) for ticker in tickers}
    except Exception as e:
        logger.error(f"Error in portfolio optimization: {str(e)}")
        return {ticker: 1 / len(tickers) for ticker in tickers}
208
+
209
  # Assuming the generate function handles the chat logic (extended to include risk comparison)
210
  def generate(
211
  message: str,
 
215
  temperature: float,
216
  top_p: float,
217
  top_k: int,
218
+ logs_state: list
219
+ ) -> tuple[Iterator[str], list]:
220
+ start_time = time.time()
221
  if not system_prompt:
222
  system_prompt = DEFAULT_SYSTEM_PROMPT
223
 
224
+ full_response = ""
225
  # Detect CAGR query
226
  cagr_match = re.search(r'average return for (\w+) between (\d{4}) and (\d{4})', message.lower())
227
  if cagr_match:
 
232
  end_date = f"{end_year}-12-31"
233
  cagr = calculate_cagr(ticker, start_date, end_date)
234
  if cagr is not None:
235
+ response = f"- {ticker} CAGR ({start_year}-{end_year}): ~{cagr:.2f}%\n- Represents average annual return with compounding\n- Past performance is not indicative of future results.\n- Consult a financial advisor for personalized advice."
236
+ yield response
237
+ full_response = response
238
  else:
239
+ response = "Unable to calculate CAGR for the specified period."
240
+ yield response
241
+ full_response = response
242
+ end_time = time.time()
243
+ logs_state.append({
244
+ 'timestamp': datetime.now().isoformat(),
245
+ 'query': message,
246
+ 'response': full_response,
247
+ 'response_length': len(full_response.split()),
248
+ 'generation_time': end_time - start_time,
249
+ 'token_efficiency': len(full_response.split()) / max_new_tokens
250
+ })
251
+ return iter([]), logs_state # No more yield
252
 
253
  # Detect risk comparison query
254
  risk_match = re.search(r'which stock is riskier (\w+) or (\w+)', message.lower())
 
258
  vol1, sharpe1 = calculate_risk_metrics(ticker1)
259
  vol2, sharpe2 = calculate_risk_metrics(ticker2)
260
  if vol1 is None or vol2 is None:
261
+ response = "Unable to fetch risk metrics for one or both tickers."
262
+ yield response
263
+ full_response = response
 
 
 
 
 
 
264
  else:
265
+ if vol1 > vol2:
266
+ riskier = ticker1
267
+ less_risky = ticker2
268
+ higher_vol = vol1
269
+ lower_vol = vol2
270
+ riskier_sharpe = sharpe1
271
+ less_sharpe = sharpe2
272
+ else:
273
+ riskier = ticker2
274
+ less_risky = ticker1
275
+ higher_vol = vol2
276
+ lower_vol = vol1
277
+ riskier_sharpe = sharpe2
278
+ less_sharpe = sharpe1
279
+ response = f"- {riskier} is riskier compared to {less_risky}.\n- It has a higher annualized standard deviation ({higher_vol:.2f}% vs {lower_vol:.2f}%) and a lower Sharpe ratio ({riskier_sharpe:.2f} vs {less_sharpe:.2f}), indicating greater volatility and potentially lower risk-adjusted returns.\n- Calculations based on the past 5 years of data.\n- Past performance is not indicative of future results. Consult a financial advisor for personalized advice."
280
+ yield response
281
+ full_response = response
282
+ end_time = time.time()
283
+ logs_state.append({
284
+ 'timestamp': datetime.now().isoformat(),
285
+ 'query': message,
286
+ 'response': full_response,
287
+ 'response_length': len(full_response.split()),
288
+ 'generation_time': end_time - start_time,
289
+ 'token_efficiency': len(full_response.split()) / max_new_tokens
290
+ })
291
+ return iter([]), logs_state
292
+
293
+ # Detect plot/chart query
294
+ plot_match = re.search(r'(plot|chart)\s+(\w+)(?:\s+(historical|price|volatility))?', message.lower())
295
+ if plot_match:
296
+ ticker = plot_match.group(2).upper()
297
+ plot_type = plot_match.group(3) if plot_match.group(3) else 'price'
298
+ if plot_type == 'volatility':
299
+ # Simple volatility plot (returns histogram)
300
+ try:
301
+ data = yf.download(ticker, period='1y')
302
+ returns = data['Adj Close'].pct_change().dropna()
303
+ plt.figure(figsize=(10, 5))
304
+ plt.hist(returns, bins=50, alpha=0.7)
305
+ plt.title(f'{ticker} Daily Returns Distribution (1Y)')
306
+ plt.xlabel('Return')
307
+ plt.ylabel('Frequency')
308
+ except:
309
+ plot_type = 'price' # Fallback
310
+ if plot_type != 'volatility':
311
+ plot_md = generate_plot(ticker)
312
+ response = f"Price chart for {ticker}:\n{plot_md}\n- This visualizes the historical adjusted close prices.\n- Past performance is not indicative of future results. Consult a financial advisor."
313
+ yield response
314
+ full_response = response
315
+ else:
316
+ # For volatility, similar
317
+ buf = io.BytesIO()
318
+ plt.savefig(buf, format='png', bbox_inches='tight')
319
+ buf.seek(0)
320
+ b64 = base64.b64encode(buf.read()).decode('utf-8')
321
+ plt.close()
322
+ plot_md = f"![{ticker} Volatility](data:image/png;base64,{b64})"
323
+ response = f"Volatility chart for {ticker}:\n{plot_md}\n- Histogram of daily returns over the past year."
324
+ yield response
325
+ full_response = response
326
+ end_time = time.time()
327
+ logs_state.append({
328
+ 'timestamp': datetime.now().isoformat(),
329
+ 'query': message,
330
+ 'response': full_response,
331
+ 'response_length': len(full_response.split()),
332
+ 'generation_time': end_time - start_time,
333
+ 'token_efficiency': len(full_response.split()) / max_new_tokens
334
+ })
335
+ return iter([]), logs_state
336
+
337
+ # Detect portfolio optimization query
338
+ port_match = re.search(r'optimize\s+portfolio\s+for\s+([\w,\s]+)(?:\s+with\s+(risk|return)\s+tolerance\s+([\d.]+))?', message.lower())
339
+ if port_match:
340
+ tickers_str = port_match.group(1).strip()
341
+ tickers = [t.strip().upper() for t in re.split(r'[,;]', tickers_str) if t.strip()]
342
+ target = None
343
+ if port_match.group(3):
344
+ target = float(port_match.group(3))
345
+ if port_match.group(2) == 'risk':
346
+ # For risk tolerance, min vol with vol <= target (but simplify to min vol)
347
+ pass # Use default min vol
348
+ else:
349
+ target_return = target
350
+ weights = portfolio_optimization(tickers, target_return=target if 'return' in (port_match.group(2) or '') else None)
351
+ df = pd.DataFrame(list(weights.items()), columns=['Ticker', 'Weight'])
352
+ df['Weight'] = df['Weight'].round(4)
353
+ table_md = df.to_markdown(index=False)
354
+ response = f"- Suggested portfolio weights for {', '.join(tickers)}:\n{table_md}\n- Based on minimum variance optimization (or target return if specified).\n- Assumes 5-year historical data for means and covariances.\n- Past performance is not indicative of future results. Consult a financial advisor for personalized advice."
355
+ yield response
356
+ full_response = response
357
+ end_time = time.time()
358
+ logs_state.append({
359
+ 'timestamp': datetime.now().isoformat(),
360
+ 'query': message,
361
+ 'response': full_response,
362
+ 'response_length': len(full_response.split()),
363
+ 'generation_time': end_time - start_time,
364
+ 'token_efficiency': len(full_response.split()) / max_new_tokens
365
+ })
366
+ return iter([]), logs_state
367
 
368
  # For other queries, fall back to LLM generation
369
  conversation = [{"role": "system", "content": system_prompt}]
370
  for user, assistant in history:
371
+ conversation.extend([
372
+ {"role": "user", "content": user},
373
+ {"role": "assistant", "content": assistant}
374
+ ])
375
  conversation.append({"role": "user", "content": message})
376
 
377
  # Generate response using LLM (streamed)
 
383
  top_k=top_k,
384
  stream=True
385
  )
 
386
  partial_text = ""
387
  for chunk in response:
388
  if "content" in chunk["choices"][0]["delta"]:
389
  partial_text += chunk["choices"][0]["delta"]["content"]
390
  yield partial_text
391
+ full_response = partial_text
392
+ end_time = time.time()
393
+ logs_state.append({
394
+ 'timestamp': datetime.now().isoformat(),
395
+ 'query': message,
396
+ 'response': full_response,
397
+ 'response_length': len(full_response.split()),
398
+ 'generation_time': end_time - start_time,
399
+ 'token_efficiency': len(full_response.split()) / max_new_tokens
400
+ })
401
+ return iter([]), logs_state
402
 
403
def update_logs(logs_state):
    """Convert the accumulated per-query log entries into a DataFrame
    for display in the Metrics tab; empty DataFrame when no logs exist."""
    return pd.DataFrame(logs_state) if logs_state else pd.DataFrame()
408
+
409
# Gradio interface setup
with gr.Blocks(theme=themes.Default()) as demo:
    gr.Markdown(DESCRIPTION)
    gr.Markdown(LICENSE)

    with gr.Tabs():
        with gr.TabItem("Chat"):
            chatbot = gr.Chatbot()
            msg = gr.Textbox(label="Enter your question")
            with gr.Row():
                submit = gr.Button("Submit")
                clear = gr.Button("Clear")
            with gr.Accordion("Advanced Settings", open=False):
                system_prompt = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=6)
                max_new_tokens = gr.Slider(minimum=1, maximum=MAX_MAX_NEW_TOKENS, value=DEFAULT_MAX_NEW_TOKENS, step=1, label="Max New Tokens")
                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, label="Temperature")
                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.1, label="Top P")
                top_k = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top K")

        with gr.TabItem("Metrics"):
            metrics_df = gr.Dataframe(headers=['timestamp', 'query', 'response', 'response_length', 'generation_time', 'token_efficiency'])

    # Per-session log store, seeded from the module-level `logs` list.
    logs_state = gr.State(logs)

    def submit_fn(msg, history, system_prompt, max_new_tokens, temperature, top_p, top_k, logs_state):
        """Stream the assistant reply into the chat history.

        BUG FIX: generate() contains `yield`, so calling it returns a single
        generator object — the original `gen, new_logs = generate(...)`
        unpacking failed at runtime. Iterate the generator instead;
        generate() appends its log entries to logs_state in place.
        """
        # Snapshot prior turns BEFORE appending the streaming placeholder,
        # so the lazily-evaluated generate() does not see the half-finished
        # turn as part of the conversation history.
        prior_history = list(history)
        history.append((msg, ""))
        for partial in generate(msg, prior_history, system_prompt, max_new_tokens,
                                temperature, top_p, top_k, logs_state):
            history[-1] = (history[-1][0], partial)
            yield history, "", logs_state

    submit.click(
        submit_fn,
        inputs=[msg, chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k, logs_state],
        outputs=[chatbot, msg, logs_state],
        queue=False,
    ).then(
        update_logs,
        inputs=[logs_state],
        outputs=[metrics_df],
    )

    # Clearing the chat also resets the collected metrics.
    clear.click(lambda: ([], []), None, [chatbot, logs_state])

demo.launch()