from typing import List, Tuple

import torch
import gradio as gr
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)

# Causal LMs only
DEFAULT_MODELS = [
    "microsoft/DialoGPT-medium",
    "gpt2",
    "EleutherAI/gpt-neo-125M",
    "EleutherAI/pythia-350m",
    "facebook/opt-350m",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "Qwen/Qwen1.5-1.8B-Chat",
]
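# Note: the two chat-tuned entries above (TinyLlama-Chat, Qwen1.5-Chat) ship a
# tokenizer chat template, which build_inputs() below picks up automatically;
# the remaining base/dialogue models have no template and fall back to
# EOS-joined turns (see the sketch after build_inputs).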
# Cache loaded (tokenizer, model) pairs so switching back to a previously used
# model in the dropdown does not reload it from disk.
_MODEL_CACHE = {}


def load_model(model_name: str):
    key = model_name
    if key in _MODEL_CACHE:
        return _MODEL_CACHE[key]
    tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, use_safetensors=True)
    model.eval()
    _MODEL_CACHE[key] = (tok, model)
    return tok, model
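# Assumption worth noting: use_safetensors=True loads only .safetensors
# weights; a checkpoint that ships just pytorch_model.bin would need the flag
# removed to load.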

def build_inputs(tokenizer, history: List[Tuple[str, str]], user_message: str):
    # If a chat template exists, use it (best for Qwen/TinyLlama).
    use_chat_template = bool(getattr(tokenizer, "chat_template", None))
    if use_chat_template:
        conv = []
        for u, b in (history or [])[-6:]:
            conv.append({"role": "user", "content": u})
            conv.append({"role": "assistant", "content": b})
        conv.append({"role": "user", "content": user_message})
        prompt_ids = tokenizer.apply_chat_template(
            conv,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors=None,
        )
        input_ids = torch.tensor([prompt_ids], dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)
        return input_ids, attention_mask

    eos = tokenizer.eos_token or ""
    ids = []
    for u, b in (history or [])[-6:]:
        ids.extend(tokenizer.encode(u + eos))
        ids.extend(tokenizer.encode(b + eos))
    # Current user message; add EOS to mark turn boundary for all non-templated LMs
    ids.extend(tokenizer.encode(user_message + eos))
    input_ids = torch.tensor([ids], dtype=torch.long)
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask
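
# Illustrative sketch of the fallback prompt layout (assuming eos_token is
# "<|endoftext|>", as it is for GPT-2/DialoGPT):
#   "Hi there<|endoftext|>Hello! How can I help?<|endoftext|>Tell me a joke<|endoftext|>"
# DialoGPT was trained on EOS-joined turns like this, so the model is expected
# to continue with the next (assistant) turn.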

def chat_fn(user_message: str, history: List[Tuple[str, str]], model_name: str):
    if not user_message or not user_message.strip():
        return "", history, history
    tokenizer, model = load_model(model_name)
    input_ids, attention_mask = build_inputs(
        tokenizer, history or [], user_message.strip()
    )
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=48,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            no_repeat_ngram_size=3,
            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        )
    # Decode only the newly generated part
    new_tokens = output_ids[0, input_ids.shape[1] :]
    reply = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    if not reply:
        reply = tokenizer.decode(output_ids[0], skip_special_tokens=True)[-200:].strip()
    new_hist = (history or []) + [(user_message, reply)]
    return "", new_hist, new_hist

with gr.Blocks(title="🌮 TacoGPT") as demo:
    gr.Markdown(
        """
        # 🌮 TacoGPT - Your Spicy AI Assistant
        """
    )
    with gr.Row():
        model_name = gr.Dropdown(
            choices=DEFAULT_MODELS,
            value=DEFAULT_MODELS[0],
            label="Model",
            scale=2,
        )
    chatbot = gr.Chatbot(height=460, label="Chat", type="tuples")
    msg = gr.Textbox(
        placeholder="Type your message and press Enter...",
        label="Message",
        submit_btn="Send",
    )
    state = gr.State([])
    gr.Examples(
        examples=[
            "Who created Python?",
            "Write a taco recipe.",
            "Give me a fun fact about tacos.",
            "What's the history of tacos?",
        ],
        inputs=msg,
    )
    msg.submit(
        chat_fn,
        inputs=[msg, state, model_name],
        outputs=[msg, chatbot, state],
        queue=True,
        api_name=False,
    )
    gr.Markdown(
        """
        **Tips**
        - Tiny causal models can be quirky; TinyLlama or Qwen1.5-1.8B-Chat usually give better chat quality.
        - Keep questions short and direct, e.g. "Who created Python?"
        - The first run of each model is slower while its weights download.
        """.strip()
    )
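    # Optional warm-up (a sketch, not part of the original flow): pre-loading
    # the default model at startup hides the first-message download delay, at
    # the cost of a slower launch and memory held up front.
    # load_model(DEFAULT_MODELS[0])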

if __name__ == "__main__":
    demo.queue()
    demo.launch(server_name="0.0.0.0", show_error=True)