Eliezer Oliveira committed on
Commit 5e2cfbe · 1 Parent(s): baac0ee

Add initial implementation of TacoGPT chatbot

Files changed (3)
  1. .gitignore +1 -0
  2. app.py +160 -0
  3. requirements.txt +5 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ .idea
app.py ADDED
@@ -0,0 +1,160 @@
+ from typing import List, Tuple
+
+ import torch
+ import gradio as gr
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+ )
+
+ # Causal LMs only
+ DEFAULT_MODELS = [
+     "microsoft/DialoGPT-medium",
+     "gpt2",
+     "EleutherAI/gpt-neo-125M",
+     "EleutherAI/pythia-350m",
+     "facebook/opt-350m",
+     "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+     "Qwen/Qwen1.5-1.8B-Chat",
+ ]
+
+ # Cache loaded (tokenizer, model) pairs so switching between models doesn't reload weights.
+ _MODEL_CACHE = {}
+
+
+ def load_model(model_name: str):
+     key = model_name
+
+     if key in _MODEL_CACHE:
+         return _MODEL_CACHE[key]
+
+     tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+     model = AutoModelForCausalLM.from_pretrained(model_name, use_safetensors=True)
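+     # Inference only: eval() disables dropout; no device placement is done, so everything runs on CPU.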
+     model.eval()
+
+     _MODEL_CACHE[key] = (tok, model)
+     return tok, model
+
+
+ def build_inputs(tokenizer, history: List[Tuple[str, str]], user_message: str):
+     # If a chat template exists, use it (best for Qwen/TinyLlama).
+     use_chat_template = bool(getattr(tokenizer, "chat_template", None))
+     if use_chat_template:
+         conv = []
+         for u, b in (history or [])[-6:]:
+             conv.append({"role": "user", "content": u})
+             conv.append({"role": "assistant", "content": b})
+         conv.append({"role": "user", "content": user_message})
+
+         prompt_ids = tokenizer.apply_chat_template(
+             conv,
+             tokenize=True,
+             add_generation_prompt=True,
+             return_tensors=None,
+         )
+         input_ids = torch.tensor([prompt_ids], dtype=torch.long)
+         attention_mask = torch.ones_like(input_ids)
+         return input_ids, attention_mask
+
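+     # No chat template (e.g. DialoGPT, GPT-2): join the last turns with the EOS token as separator.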
+     eos = tokenizer.eos_token or ""
+
+     ids = []
+     for u, b in (history or [])[-6:]:
+         ids.extend(tokenizer.encode(u + eos))
+         ids.extend(tokenizer.encode(b + eos))
+
+     # Current user message; add EOS to mark turn boundary for all non-templated LMs
+     ids.extend(tokenizer.encode(user_message + eos))
+
+     input_ids = torch.tensor([ids], dtype=torch.long)
+     attention_mask = torch.ones_like(input_ids)
+     return input_ids, attention_mask
+
+
+ def chat_fn(user_message: str, history: List[Tuple[str, str]], model_name: str):
+     if not user_message or not user_message.strip():
+         return "", history, history
+
+     tokenizer, model = load_model(model_name)
+     input_ids, attention_mask = build_inputs(
+         tokenizer, history or [], user_message.strip()
+     )
+
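+     # Short, sampled completions; the repetition penalties keep small models from looping.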
+     with torch.inference_mode():
+         output_ids = model.generate(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             max_new_tokens=48,
+             do_sample=True,
+             temperature=0.7,
+             top_p=0.9,
+             repetition_penalty=1.1,
+             no_repeat_ngram_size=3,
+             pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
+         )
+
+     # Decode only the newly generated part
+     new_tokens = output_ids[0, input_ids.shape[1] :]
+     reply = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
+
+     if not reply:
+         reply = tokenizer.decode(output_ids[0], skip_special_tokens=True)[-200:].strip()
+
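+     # The same updated history feeds both the Chatbot display and the State.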
+     new_hist = (history or []) + [(user_message, reply)]
+     return "", new_hist, new_hist
+
+
+ with gr.Blocks(title="🌮 TacoGPT") as demo:
+     gr.Markdown(
+         """
+         # 🌮 TacoGPT - Your Spicy AI Assistant
+         """
+     )
+
+     with gr.Row():
+         model_name = gr.Dropdown(
+             choices=DEFAULT_MODELS,
+             value=DEFAULT_MODELS[0],
+             label="Model",
+             scale=2,
+         )
+
+     chatbot = gr.Chatbot(height=460, label="Chat")
+     msg = gr.Textbox(
+         placeholder="Type your message and press Enter...",
+         label="Message",
+         submit_btn="Send",
+     )
+
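+     # Conversation history as (user, bot) tuples, carried across submissions.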
+     state = gr.State([])
+
+     gr.Examples(
+         examples=[
+             "Who created Python?",
+             "Write a taco recipe.",
+             "Give me a fun fact about tacos.",
+             "What's the history of tacos?",
+         ],
+         inputs=msg,
+     )
+
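+     # Enter (or the Send button) routes the message through chat_fn and clears the textbox.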
+     msg.submit(
+         chat_fn,
+         inputs=[msg, state, model_name],
+         outputs=[msg, chatbot, state],
+         queue=True,
+         api_name=False,
+     )
+
+     gr.Markdown(
+         """
+         **Tips**
+         - Tiny causal models can be quirky; try TinyLlama or Qwen 1.5 1.8B for better chat quality.
+         - Keep questions short and direct, e.g. "Who created Python?"
+         - The first run of a model may be slow while its weights download.
+         """.strip()
+     )
+
+
+ if __name__ == "__main__":
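+     # Bound to localhost; set server_name="0.0.0.0" if the app must be reachable from other hosts.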
+     demo.launch(server_name="127.0.0.1", server_port=7860, show_error=True)
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ --extra-index-url https://download.pytorch.org/whl/cpu
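+ # The extra index above provides CPU-only torch wheels; the remaining packages resolve from PyPI.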
+ torch
+ transformers
+ gradio
+ safetensors