# tacogpt/app.py
from typing import List, Tuple
import torch
import gradio as gr
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
)
# Causal LMs only
DEFAULT_MODELS = [
"microsoft/DialoGPT-medium",
"gpt2",
"EleutherAI/gpt-neo-125M",
"EleutherAI/pythia-350m",
"facebook/opt-350m",
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"Qwen/Qwen1.5-1.8B-Chat",
]
# Cache (tokenizer, model) pairs keyed by model name so switching models in the UI doesn't reload weights
_MODEL_CACHE = {}
def load_model(model_name: str):
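    # Return the cached (tokenizer, model) pair, loading and caching it on first use.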
key = model_name
if key in _MODEL_CACHE:
return _MODEL_CACHE[key]
tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(model_name, use_safetensors=True)
model.eval()
_MODEL_CACHE[key] = (tok, model)
return tok, model
def build_inputs(tokenizer, history: List[Tuple[str, str]], user_message: str):
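    # Build input_ids/attention_mask from the recent history plus the new message,
    # using the tokenizer's chat template when available and EOS-joined turns otherwise.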
# If a chat template exists, use it (best for Qwen/TinyLlama).
use_chat_template = bool(getattr(tokenizer, "chat_template", None))
if use_chat_template:
conv = []
for u, b in (history or [])[-6:]:
conv.append({"role": "user", "content": u})
conv.append({"role": "assistant", "content": b})
conv.append({"role": "user", "content": user_message})
prompt_ids = tokenizer.apply_chat_template(
conv,
tokenize=True,
add_generation_prompt=True,
return_tensors=None,
)
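        # With tokenize=True and return_tensors=None this is a plain Python list of token ids.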
input_ids = torch.tensor([prompt_ids], dtype=torch.long)
attention_mask = torch.ones_like(input_ids)
return input_ids, attention_mask
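    # No chat template: fall back to DialoGPT-style turns joined by the EOS token.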
eos = tokenizer.eos_token or ""
ids = []
for u, b in (history or [])[-6:]:
ids.extend(tokenizer.encode(u + eos))
ids.extend(tokenizer.encode(b + eos))
    # Append the current user message; EOS marks the turn boundary for non-templated models
ids.extend(tokenizer.encode(user_message + eos))
input_ids = torch.tensor([ids], dtype=torch.long)
attention_mask = torch.ones_like(input_ids)
return input_ids, attention_mask
def chat_fn(user_message: str, history: List[Tuple[str, str]], model_name: str):
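    # Gradio handler: returns "" to clear the textbox, the updated history for the Chatbot,
    # and the same history again for the State component.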
if not user_message or not user_message.strip():
return "", history, history
tokenizer, model = load_model(model_name)
input_ids, attention_mask = build_inputs(
tokenizer, history or [], user_message.strip()
)
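    # Sample with moderate temperature/top-p and light repetition penalties;
    # 48 new tokens keeps replies short and generation fast.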
with torch.inference_mode():
output_ids = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=48,
do_sample=True,
temperature=0.7,
top_p=0.9,
repetition_penalty=1.1,
no_repeat_ngram_size=3,
pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
)
# Decode only the newly generated part
new_tokens = output_ids[0, input_ids.shape[1] :]
reply = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
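    # Fallback: if only special tokens came back, show the tail of the full decoded output.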
if not reply:
reply = tokenizer.decode(output_ids[0], skip_special_tokens=True)[-200:].strip()
new_hist = (history or []) + [(user_message, reply)]
return "", new_hist, new_hist
with gr.Blocks(title="🌮 TacoGPT") as demo:
gr.Markdown(
"""
# 🌮 TacoGPT - Your Spicy AI Assistant
"""
)
with gr.Row():
model_name = gr.Dropdown(
choices=DEFAULT_MODELS,
value=DEFAULT_MODELS[0],
label="Model",
scale=2,
)
chatbot = gr.Chatbot(height=460, label="Chat", type="tuples")
msg = gr.Textbox(
placeholder="Type your message and press Enter...",
label="Message",
submit_btn="Send",
)
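    # Conversation history as a list of (user, bot) tuples, shared with chat_fn via State.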
state = gr.State([])
gr.Examples(
examples=[
"Who created Python?",
"Write a taco recipe.",
"Give me a fun fact about tacos.",
"What's the history of tacos?",
],
inputs=msg,
)
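    # Submitting the textbox runs chat_fn: it clears the box, updates the Chatbot, and stores the history.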
msg.submit(
chat_fn,
inputs=[msg, state, model_name],
outputs=[msg, chatbot, state],
queue=True,
api_name=False,
)
gr.Markdown(
"""
**Tips**
- Tiny causal models can be quirky; try TinyLlama or Qwen 1.5 1.8B for better chat quality.
- Keep questions short and to the point, e.g. "Who created Python?"
- The first run may be slower while downloading the LM weights.
""".strip()
)
if __name__ == "__main__":
demo.queue()
demo.launch(server_name="0.0.0.0", show_error=True)