AnilNiraula committed
Commit 66c6745 · verified · 1 Parent(s): caedbee

Update app.py

Files changed (1)
  1. app.py +19 -3
app.py CHANGED
@@ -60,7 +60,7 @@ from datasets import load_dataset
 
 MAX_MAX_NEW_TOKENS = 512
 DEFAULT_MAX_NEW_TOKENS = 128
-MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "256"))
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "1024"))
 
 DESCRIPTION = """\
 # FinChat: Investing Q&A (CPU-Only, Ultra-Fast Optimization)
@@ -165,7 +165,7 @@ try:
     )
     llm = Llama(
         model_path=model_path,
-        n_ctx=256,
+        n_ctx=1024,
         n_batch=512,
         n_threads=multiprocessing.cpu_count(),
         n_gpu_layers=0,
@@ -249,10 +249,20 @@ def generate(
             conversation.append({"role": "assistant", "content": msg["content"]})
     conversation.append({"role": "user", "content": message})
 
-    # Approximate token length check
+    # Approximate token length check and truncate if necessary
     prompt_text = "\n".join(d["content"] for d in conversation)
     input_tokens = llm.tokenize(prompt_text.encode("utf-8"), add_bos=False)
 
+    while len(input_tokens) > MAX_INPUT_TOKEN_LENGTH:
+        logger.warning(f"Input tokens ({len(input_tokens)}) exceed limit ({MAX_INPUT_TOKEN_LENGTH}). Truncating history.")
+        if len(conversation) > 2:  # Preserve system prompt and current user message
+            conversation.pop(1)  # Remove oldest user/assistant pair
+            prompt_text = "\n".join(d["content"] for d in conversation)
+            input_tokens = llm.tokenize(prompt_text.encode("utf-8"), add_bos=False)
+        else:
+            yield "Error: Input is too long even after truncation. Please shorten your query."
+            return
+
     # Generate response
     try:
         response = ""
@@ -273,6 +283,12 @@
             if chunk["choices"][0]["finish_reason"] is not None:
                 break
         logger.info("Response generation completed.")
+    except ValueError as ve:
+        if "exceed context window" in str(ve):
+            yield "Error: Prompt too long for context window. Please try a shorter query or clear history."
+        else:
+            logger.error(f"Error during response generation: {str(ve)}")
+            yield f"Error generating response: {str(ve)}"
     except Exception as e:
         logger.error(f"Error during response generation: {str(e)}")
         yield f"Error generating response: {str(e)}"
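
The core of this commit is the history-trimming loop: when the tokenized prompt exceeds MAX_INPUT_TOKEN_LENGTH, the oldest non-system turns are popped until it fits, and the request is refused only if trimming alone cannot help. Below is a minimal, self-contained sketch of that pattern, not the app's exact code: trim_history and count_tokens are hypothetical names, with count_tokens standing in for len(llm.tokenize(prompt_text.encode("utf-8"), add_bos=False)) so the logic can be exercised without loading the GGUF model, and the sketch raises ValueError where app.py instead yields an error message to the user.

# Minimal sketch of the history-truncation pattern added in this commit.
# `count_tokens` is a hypothetical stand-in for len(llm.tokenize(...)) so the
# trimming logic can be exercised without loading a GGUF model.
from typing import Callable, Dict, List

def trim_history(
    conversation: List[Dict[str, str]],
    count_tokens: Callable[[str], int],
    max_input_tokens: int = 1024,
) -> List[Dict[str, str]]:
    """Drop the oldest non-system turns until the joined prompt fits the budget."""
    def prompt_tokens(conv: List[Dict[str, str]]) -> int:
        return count_tokens("\n".join(d["content"] for d in conv))

    while prompt_tokens(conversation) > max_input_tokens:
        if len(conversation) > 2:  # keep system prompt and the current user message
            conversation.pop(1)    # drop the oldest user/assistant turn
        else:
            # app.py yields an error string to the user here instead of raising
            raise ValueError("Input is too long even after truncation.")
    return conversation

if __name__ == "__main__":
    # A crude whitespace count stands in for the real tokenizer.
    convo = [
        {"role": "system", "content": "You are FinChat, an investing Q&A assistant."},
        {"role": "user", "content": "old question " * 400},
        {"role": "assistant", "content": "old answer " * 400},
        {"role": "user", "content": "What is dollar-cost averaging?"},
    ]
    trimmed = trim_history(convo, count_tokens=lambda text: len(text.split()))
    print(f"{len(trimmed)} messages kept")

Popping at index 1 assumes the system prompt sits at index 0 and the current user message is last, which matches how conversation is assembled in generate() before the token-length check.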