manu02 committed on
Commit 0d6e478 · verified · 1 Parent(s): 25ea14a

Update app.py

Files changed (1): app.py +527 -502
app.py CHANGED
@@ -1,502 +1,527 @@
- # app.py
- """
- Gradio word-level attention visualizer with:
- - Paragraph-style wrapping and semi-transparent backgrounds per word
- - Proper detokenization to words (regex)
- - Ability to pick from many causal LMs
- - Trailing EOS/PAD special tokens removed (no <|endoftext|> shown)
- - FIX: safely reset Radio with value=None to avoid Gradio choices error
- """
-
- import re
- from typing import List, Tuple
-
- import gradio as gr
- from transformers import AutoModelForCausalLM, AutoTokenizer
- import torch
- import numpy as np
-
- # =========================
- # Config
- # =========================
- ALLOWED_MODELS = [
-     # ---- GPT-2 family
-     "gpt2", "distilgpt2", "gpt2-medium", "gpt2-large", "gpt2-xl",
-     # ---- EleutherAI (Neo/J/NeoX/Pythia)
-     "EleutherAI/gpt-neo-125M", "EleutherAI/gpt-neo-1.3B", "EleutherAI/gpt-neo-2.7B",
-     "EleutherAI/gpt-j-6B", "EleutherAI/gpt-neox-20b",
-     "EleutherAI/pythia-70m", "EleutherAI/pythia-160m", "EleutherAI/pythia-410m",
-     "EleutherAI/pythia-1b", "EleutherAI/pythia-1.4b", "EleutherAI/pythia-2.8b",
-     "EleutherAI/pythia-6.9b", "EleutherAI/pythia-12b",
-     # ---- Meta OPT
-     "facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b",
-     "facebook/opt-6.7b", "facebook/opt-13b", "facebook/opt-30b",
-     # ---- Mistral
-     "mistralai/Mistral-7B-v0.1", "mistralai/Mistral-7B-v0.3", "mistralai/Mistral-7B-Instruct-v0.2",
-     # ---- TinyLlama / OpenLLaMA
-     "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
-     "openlm-research/open_llama_3b", "openlm-research/open_llama_7b",
-     # ---- Microsoft Phi
-     "microsoft/phi-1", "microsoft/phi-1_5", "microsoft/phi-2",
-     # ---- Qwen
-     "Qwen/Qwen1.5-0.5B", "Qwen/Qwen1.5-1.8B", "Qwen/Qwen1.5-4B", "Qwen/Qwen1.5-7B",
-     "Qwen/Qwen2-1.5B", "Qwen/Qwen2-7B",
-     # ---- MPT
-     "mosaicml/mpt-7b", "mosaicml/mpt-7b-instruct",
-     # ---- Falcon
-     "tiiuae/falcon-7b", "tiiuae/falcon-7b-instruct", "tiiuae/falcon-40b",
-     # ---- Cerebras GPT
-     "cerebras/Cerebras-GPT-111M", "cerebras/Cerebras-GPT-256M",
-     "cerebras/Cerebras-GPT-590M", "cerebras/Cerebras-GPT-1.3B", "cerebras/Cerebras-GPT-2.7B",
- ]
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model = None
- tokenizer = None
-
- # Word regex (words + punctuation)
- WORD_RE = re.compile(r"\w+(?:'\w+)?|[^\w\s]")
-
- # =========================
- # Model loading
- # =========================
- def _safe_set_attn_impl(m):
-     try:
-         m.config._attn_implementation = "eager"
-     except Exception:
-         pass
-
- def load_model(model_name: str):
-     """Load tokenizer+model globally."""
-     global model, tokenizer
-     try:
-         del model
-         torch.cuda.empty_cache()
-     except Exception:
-         pass
-
-     tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
-     # Ensure pad token id
-     if tokenizer.pad_token_id is None:
-         if tokenizer.eos_token_id is not None:
-             tokenizer.pad_token_id = tokenizer.eos_token_id
-         else:
-             tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
-
-     model = AutoModelForCausalLM.from_pretrained(model_name)
-     _safe_set_attn_impl(model)
-     if hasattr(model, "resize_token_embeddings") and tokenizer.pad_token_id >= model.get_input_embeddings().num_embeddings:
-         model.resize_token_embeddings(len(tokenizer))
-     model.eval()
-     model.to(device)
-
- def model_heads_layers():
-     try:
-         L = int(getattr(model.config, "num_hidden_layers", 12))
-     except Exception:
-         L = 12
-     try:
-         H = int(getattr(model.config, "num_attention_heads", 12))
-     except Exception:
-         H = 12
-     return max(1, L), max(1, H)
-
- # =========================
- # Attention utils
- # =========================
- def get_attention_for_token_layer(
-     attentions,
-     token_index,
-     layer_index,
-     batch_index=0,
-     head_index=0,
-     mean_across_layers=True,
-     mean_across_heads=True,
- ):
-     """
-     attentions: tuple length = #generated tokens
-     attentions[t] -> tuple of len = num_layers, each: (batch, heads, q, k)
-     """
-     token_attention = attentions[token_index]
-
-     if mean_across_layers:
-         layer_attention = torch.stack(token_attention).mean(dim=0)  # (batch, heads, q, k)
-     else:
-         layer_attention = token_attention[int(layer_index)]  # (batch, heads, q, k)
-
-     batch_attention = layer_attention[int(batch_index)]  # (heads, q, k)
-
-     if mean_across_heads:
-         head_attention = batch_attention.mean(dim=0)  # (q, k)
-     else:
-         head_attention = batch_attention[int(head_index)]  # (q, k)
-
-     return head_attention.squeeze(0)  # q==1 -> (k,)
-
- # =========================
- # Tokens -> words mapping
- # =========================
- def _words_and_map_from_tokens(gen_token_ids: List[int]) -> Tuple[List[str], List[int]]:
-     """
-     From *generated* token ids, return:
-     - words: detokenized words (regex-split)
-     - word2tok: list where word2tok[i] = index (relative to generated) of the
-       LAST token that composes that word.
-     """
-     if not gen_token_ids:
-         return [], []
-
-     gen_tokens_str = tokenizer.convert_ids_to_tokens(gen_token_ids)
-     detok_text = tokenizer.convert_tokens_to_string(gen_tokens_str)
-
-     words = WORD_RE.findall(detok_text)
-
-     enc = tokenizer(detok_text, return_offsets_mapping=True, add_special_tokens=False)
-     tok_offsets = enc["offset_mapping"]
-     n = min(len(tok_offsets), len(gen_token_ids))
-
-     spans = [m.span() for m in re.finditer(WORD_RE, detok_text)]
-
-     word2tok: List[int] = []
-     t = 0
-     for (ws, we) in spans:
-         last_t = None
-         while t < n:
-             ts, te = tok_offsets[t]
-             if not (te <= ws or ts >= we):
-                 last_t = t
-                 t += 1
-             else:
-                 if te <= ws:
-                     t += 1
-                 else:
-                     break
-         if last_t is None:
-             last_t = max(0, min(n - 1, t - 1))
-         word2tok.append(int(last_t))
-
-     return words, word2tok
-
- # =========================
- # Helpers
- # =========================
- def _strip_trailing_special(ids: List[int]) -> List[int]:
-     """Remove trailing EOS/PAD/other special tokens from the generated ids."""
-     specials = set(getattr(tokenizer, "all_special_ids", []) or [])
-     j = len(ids)
-     while j > 0 and ids[j - 1] in specials:
-         j -= 1
-     return ids[:j]
-
- def clamp01(x: float) -> float:
-     x = float(x)
-     return 0.0 if x < 0 else 1.0 if x > 1 else x
-
- # =========================
- # Visualization (WORD-LEVEL)
- # =========================
- def generate_word_visualization(words: List[str],
-                                 abs_word_ends: List[int],
-                                 attention_values: np.ndarray,
-                                 selected_token_abs_idx: int) -> str:
-     """
-     Paragraph-style visualization over words.
-     For each word, aggregate attention over its composing tokens (sum),
-     normalize across words, and render opacity as a semi-transparent background.
-     """
-     if not words or attention_values is None or len(attention_values) == 0:
-         return (
-             "<div style='width:100%;'>"
-             " <div style='background:#444;border:1px solid #eee;border-radius:8px;padding:10px;'>"
-             " <div style='color:#ddd;'>No attention values.</div>"
-             " </div>"
-             "</div>"
-         )
-
-     # Start..end spans from ends
-     starts = []
-     for i, end in enumerate(abs_word_ends):
-         if i == 0:
-             starts.append(0)
-         else:
-             starts.append(min(abs_word_ends[i - 1] + 1, end))
-
-     # Sum attention per word
-     word_scores = []
-     for i, end in enumerate(abs_word_ends):
-         start = starts[i]
-         if start > end:
-             start = end
-         s = max(0, min(start, len(attention_values) - 1))
-         e = max(0, min(end, len(attention_values) - 1))
-         if e < s:
-             s, e = e, s
-         word_scores.append(float(attention_values[s:e + 1].sum()))
-
-     max_attn = max(0.1, float(max(word_scores)) if word_scores else 0.0)
-
-     # Which word holds the selected token?
-     selected_word_idx = None
-     for i, end in enumerate(abs_word_ends):
-         if selected_token_abs_idx <= end:
-             selected_word_idx = i
-             break
-     if selected_word_idx is None and abs_word_ends:
-         selected_word_idx = len(abs_word_ends) - 1
-
-     spans = []
-     for i, w in enumerate(words):
-         alpha = min(1.0, word_scores[i] / max_attn) if max_attn > 0 else 0.0
-         bg = f"rgba(66,133,244,{alpha:.3f})"
-         border = "2px solid #fff" if i == selected_word_idx else "1px solid transparent"
-         spans.append(
-             f"<span style='display:inline-block;background:{bg};border:{border};"
-             f"border-radius:6px;padding:2px 6px;margin:2px 4px 4px 0;color:#fff;'>"
-             f"{w}</span>"
-         )
-
-     return (
-         "<div style='width:100%;'>"
-         " <div style='background:#444;border:1px solid #eee;border-radius:8px;padding:10px;'>"
-         " <div style='white-space:normal;line-height:1.8;'>"
-         f" {''.join(spans)}"
-         " </div>"
-         " </div>"
-         "</div>"
-     )
-
- # =========================
- # Core functions
- # =========================
- def run_generation(prompt, max_new_tokens, temperature, top_p):
-     """Generate and prepare word-level selector + initial visualization."""
-     inputs = tokenizer(prompt or "", return_tensors="pt").to(device)
-     prompt_len = inputs["input_ids"].shape[1]
-
-     with torch.no_grad():
-         outputs = model.generate(
-             **inputs,
-             max_new_tokens=int(max_new_tokens),
-             temperature=float(temperature),
-             top_p=float(top_p),
-             do_sample=True,
-             pad_token_id=tokenizer.pad_token_id,
-             output_attentions=True,
-             return_dict_in_generate=True,
-         )
-
-     all_token_ids = outputs.sequences[0].tolist()
-     generated_token_ids = _strip_trailing_special(all_token_ids[prompt_len:])
-
-     # Words and map (word -> last generated token index)
-     words, word2tok = _words_and_map_from_tokens(generated_token_ids)
-
-     display_choices = [(w, i) for i, w in enumerate(words)]
-     if not display_choices:
-         return {
-             state_attentions: None,
-             state_all_token_ids: None,
-             state_prompt_len: 0,
-             state_words: None,
-             state_word2tok: None,
-             # SAFE RADIO RESET
-             radio_word_selector: gr.update(choices=[], value=None),
-             html_visualization: "<div style='text-align:center;padding:20px;'>No new tokens generated.</div>",
-         }
-
-     first_word_idx = 0
-     html_init = update_visualization(
-         first_word_idx,
-         outputs.attentions,
-         all_token_ids,
-         prompt_len,
-         0, 0, True, True,
-         words,
-         word2tok,
-     )
-
-     return {
-         state_attentions: outputs.attentions,
-         state_all_token_ids: all_token_ids,
-         state_prompt_len: prompt_len,
-         state_words: words,
-         state_word2tok: word2tok,
-         radio_word_selector: gr.update(choices=display_choices, value=first_word_idx),
-         html_visualization: html_init,
-     }
-
- def update_visualization(
-     selected_word_index,
-     attentions,
-     all_token_ids,
-     prompt_len,
-     layer,
-     head,
-     mean_layers,
-     mean_heads,
-     words,
-     word2tok,
- ):
-     """Recompute visualization for the chosen word (maps to its last token)."""
-     if selected_word_index is None or attentions is None or word2tok is None:
-         return "<div style='text-align:center;padding:20px;'>Generate text first.</div>"
-
-     widx = int(selected_word_index)
-     if not (0 <= widx < len(word2tok)):
-         return "<div style='text-align:center;padding:20px;'>Invalid selection.</div>"
-
-     token_index_relative = int(word2tok[widx])
-     token_index_absolute = int(prompt_len) + token_index_relative
-
-     token_attn = get_attention_for_token_layer(
-         attentions,
-         token_index=token_index_relative,
-         layer_index=int(layer),
-         head_index=int(head),
-         mean_across_layers=bool(mean_layers),
-         mean_across_heads=bool(mean_heads),
-     )
-
-     attn_vals = token_attn.detach().cpu().numpy()
-
-     # Pad attention to full (prompt + generated) length
-     total_tokens = len(all_token_ids)
-     padded = np.zeros(total_tokens, dtype=float)
-     if attn_vals.ndim == 2:
-         attn_vals = attn_vals[-1]
-     padded[: len(attn_vals)] = attn_vals
-
-     # Absolute word ends (prompt offset + relative token index)
-     abs_word_ends = [int(prompt_len) + int(t) for t in (word2tok or [])]
-
-     return generate_word_visualization(words, abs_word_ends, padded, token_index_absolute)
-
- def toggle_slider(is_mean):
-     return gr.update(interactive=not bool(is_mean))
-
- # =========================
- # Gradio UI
- # =========================
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
-     gr.Markdown("# 🤖 Word-Level Attention Visualizer — choose a model & explore")
-     gr.Markdown(
-         "Pick a model, generate text, then select a **generated word** to see where it attends. "
-         "Words wrap in a paragraph; opacity is the summed attention over the word’s tokens. "
-         "EOS tokens are stripped so `<|endoftext|>` doesn’t appear."
-     )
-
-     # States
-     state_attentions = gr.State(None)
-     state_all_token_ids = gr.State(None)
-     state_prompt_len = gr.State(None)
-     state_words = gr.State(None)
-     state_word2tok = gr.State(None)
-     state_model_name = gr.State(None)
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             gr.Markdown("### 0) Model")
-             dd_model = gr.Dropdown(
-                 ALLOWED_MODELS, value=ALLOWED_MODELS[0], label="Causal LM",
-                 info="Models that work with AutoModelForCausalLM + attentions"
-             )
-             btn_load = gr.Button("Load / Switch Model", variant="secondary")
-
-             gr.Markdown("### 1) Generation")
-             txt_prompt = gr.Textbox("In a distant future, humanity", label="Prompt")
-             btn_generate = gr.Button("Generate", variant="primary")
-             slider_max_tokens = gr.Slider(10, 200, value=50, step=10, label="Max New Tokens")
-             slider_temp = gr.Slider(0.0, 1.5, value=0.7, step=0.1, label="Temperature")
-             slider_top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top P")
-
-             gr.Markdown("### 2) Attention")
-             check_mean_layers = gr.Checkbox(True, label="Mean Across Layers")
-             check_mean_heads = gr.Checkbox(True, label="Mean Across Heads")
-             slider_layer = gr.Slider(0, 11, value=0, step=1, label="Layer", interactive=False)
-             slider_head = gr.Slider(0, 11, value=0, step=1, label="Head", interactive=False)
-
-         with gr.Column(scale=3):
-             radio_word_selector = gr.Radio(
-                 [], label="Select Generated Word to Visualize",
-                 info="Click Generate to populate"
-             )
-             html_visualization = gr.HTML(
-                 "<div style='text-align:center;padding:20px;color:#888;border:1px dashed #888;border-radius:8px;'>"
-                 "Attention visualization will appear here.</div>"
-             )
-
-     # Load/switch model
-     def on_load_model(selected_name, mean_layers, mean_heads):
-         load_model(selected_name)
-         L, H = model_heads_layers()
-         return (
-             selected_name,  # state_model_name
-             gr.update(minimum=0, maximum=L - 1, value=0, interactive=not bool(mean_layers)),
-             gr.update(minimum=0, maximum=H - 1, value=0, interactive=not bool(mean_heads)),
-             # SAFE RADIO RESET (avoid Value: [] not in choices)
-             gr.update(choices=[], value=None),
-             "<div style='text-align:center;padding:20px;'>Model loaded. Generate to visualize.</div>",
-         )
-
-     btn_load.click(
-         fn=on_load_model,
-         inputs=[dd_model, check_mean_layers, check_mean_heads],
-         outputs=[state_model_name, slider_layer, slider_head, radio_word_selector, html_visualization],
-     )
-
-     # Load default model at app start
-     def _init_model(_):
-         load_model(ALLOWED_MODELS[0])
-         L, H = model_heads_layers()
-         return (
-             ALLOWED_MODELS[0],
-             gr.update(minimum=0, maximum=L - 1, value=0, interactive=False if check_mean_layers.value else True),
-             gr.update(minimum=0, maximum=H - 1, value=0, interactive=False if check_mean_heads.value else True),
-             # Also ensure radio is clean at start
-             gr.update(choices=[], value=None),
-         )
-     demo.load(_init_model, inputs=[gr.State(None)], outputs=[state_model_name, slider_layer, slider_head, radio_word_selector])
-
-     # Generate
-     btn_generate.click(
-         fn=run_generation,
-         inputs=[txt_prompt, slider_max_tokens, slider_temp, slider_top_p],
-         outputs=[
-             state_attentions,
-             state_all_token_ids,
-             state_prompt_len,
-             state_words,
-             state_word2tok,
-             radio_word_selector,
-             html_visualization,
-         ],
-     )
-
-     # Update viz on any control
-     for control in [radio_word_selector, slider_layer, slider_head, check_mean_layers, check_mean_heads]:
-         control.change(
-             fn=update_visualization,
-             inputs=[
-                 radio_word_selector,
-                 state_attentions,
-                 state_all_token_ids,
-                 state_prompt_len,
-                 slider_layer,
-                 slider_head,
-                 check_mean_layers,
-                 check_mean_heads,
-                 state_words,
-                 state_word2tok,
-             ],
-             outputs=html_visualization,
-         )
-
-     # Toggle slider interactivity
-     check_mean_layers.change(toggle_slider, check_mean_layers, slider_layer)
-     check_mean_heads.change(toggle_slider, check_mean_heads, slider_head)
-
- if __name__ == "__main__":
-     print(f"Device: {device}")
-     # Ensure a default model is ready
-     load_model(ALLOWED_MODELS[0])
-     demo.launch(debug=True)
+ # app.py
+ """
+ Gradio word-level attention visualizer with:
+ - Paragraph-style wrapping and semi-transparent backgrounds per word
+ - Proper detokenization to words (regex)
+ - Trailing EOS/PAD special tokens removed (no <|endoftext|> shown)
+ - Selection ONLY from generated words; prompt is hidden from selector
+ - Viewer shows attention over BOTH prompt and generated words (context)
+ """
+
+ import re
+ from typing import List, Tuple
+
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch
+ import numpy as np
+
+ # =========================
+ # Config
+ # =========================
+ ALLOWED_MODELS = [
+     # ---- GPT-2 family
+     "gpt2", "distilgpt2", "gpt2-medium", "gpt2-large", "gpt2-xl",
+     # ---- EleutherAI (Neo/J/NeoX/Pythia)
+     "EleutherAI/gpt-neo-125M", "EleutherAI/gpt-neo-1.3B", "EleutherAI/gpt-neo-2.7B",
+     "EleutherAI/gpt-j-6B", "EleutherAI/gpt-neox-20b",
+     "EleutherAI/pythia-70m", "EleutherAI/pythia-160m", "EleutherAI/pythia-410m",
+     "EleutherAI/pythia-1b", "EleutherAI/pythia-1.4b", "EleutherAI/pythia-2.8b",
+     "EleutherAI/pythia-6.9b", "EleutherAI/pythia-12b",
+     # ---- Meta OPT
+     "facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b",
+     "facebook/opt-6.7b", "facebook/opt-13b", "facebook/opt-30b",
+     # ---- Mistral
+     "mistralai/Mistral-7B-v0.1", "mistralai/Mistral-7B-v0.3", "mistralai/Mistral-7B-Instruct-v0.2",
+     # ---- TinyLlama / OpenLLaMA
+     "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
+     "openlm-research/open_llama_3b", "openlm-research/open_llama_7b",
+     # ---- Microsoft Phi
+     "microsoft/phi-1", "microsoft/phi-1_5", "microsoft/phi-2",
+     # ---- Qwen
+     "Qwen/Qwen1.5-0.5B", "Qwen/Qwen1.5-1.8B", "Qwen/Qwen1.5-4B", "Qwen/Qwen1.5-7B",
+     "Qwen/Qwen2-1.5B", "Qwen/Qwen2-7B",
+     # ---- MPT
+     "mosaicml/mpt-7b", "mosaicml/mpt-7b-instruct",
+     # ---- Falcon
+     "tiiuae/falcon-7b", "tiiuae/falcon-7b-instruct", "tiiuae/falcon-40b",
+     # ---- Cerebras GPT
+     "cerebras/Cerebras-GPT-111M", "cerebras/Cerebras-GPT-256M",
+     "cerebras/Cerebras-GPT-590M", "cerebras/Cerebras-GPT-1.3B", "cerebras/Cerebras-GPT-2.7B",
+ ]
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = None
+ tokenizer = None
+
+ # Word regex (words + punctuation)
+ WORD_RE = re.compile(r"\w+(?:'\w+)?|[^\w\s]")
+
+ # =========================
+ # Model loading
+ # =========================
+ def _safe_set_attn_impl(m):
+     try:
+         m.config._attn_implementation = "eager"
+     except Exception:
+         pass
+
+ def load_model(model_name: str):
+     """Load tokenizer+model globally."""
+     global model, tokenizer
+     try:
+         del model
+         torch.cuda.empty_cache()
+     except Exception:
+         pass
+
+     tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+     # Ensure pad token id
+     if tokenizer.pad_token_id is None:
+         if tokenizer.eos_token_id is not None:
+             tokenizer.pad_token_id = tokenizer.eos_token_id
+         else:
+             tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
+
+     model = AutoModelForCausalLM.from_pretrained(model_name)
+     _safe_set_attn_impl(model)
+     if hasattr(model, "resize_token_embeddings") and tokenizer.pad_token_id >= model.get_input_embeddings().num_embeddings:
+         model.resize_token_embeddings(len(tokenizer))
+     model.eval()
+     model.to(device)
+
+ def model_heads_layers():
+     try:
+         L = int(getattr(model.config, "num_hidden_layers", 12))
+     except Exception:
+         L = 12
+     try:
+         H = int(getattr(model.config, "num_attention_heads", 12))
+     except Exception:
+         H = 12
+     return max(1, L), max(1, H)
+
+ # =========================
+ # Attention utils
+ # =========================
+ def get_attention_for_token_layer(
+     attentions,
+     token_index,
+     layer_index,
+     batch_index=0,
+     head_index=0,
+     mean_across_layers=True,
+     mean_across_heads=True,
+ ):
+     """
+     attentions: tuple length = #generated tokens
+     attentions[t] -> tuple of len = num_layers, each: (batch, heads, q, k)
+     """
+     token_attention = attentions[token_index]
+
+     if mean_across_layers:
+         layer_attention = torch.stack(token_attention).mean(dim=0)  # (batch, heads, q, k)
+     else:
+         layer_attention = token_attention[int(layer_index)]  # (batch, heads, q, k)
+
+     batch_attention = layer_attention[int(batch_index)]  # (heads, q, k)
+
+     if mean_across_heads:
+         head_attention = batch_attention.mean(dim=0)  # (q, k)
+     else:
+         head_attention = batch_attention[int(head_index)]  # (q, k)
+
+     return head_attention.squeeze(0)  # q==1 -> (k,)
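For orientation, a minimal standalone sketch of the attentions layout this function indexes; "gpt2" and the short prompt are arbitrary stand-ins, not part of the app's flow:

# Sketch: layout of generate(..., output_attentions=True) results (gpt2 assumed).
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2")
enc = tok("Hello there", return_tensors="pt")
out = lm.generate(**enc, max_new_tokens=3, do_sample=False,
                  output_attentions=True, return_dict_in_generate=True)

print(len(out.attentions))         # 3: one entry per generated token
print(len(out.attentions[0]))      # 12: one tensor per layer for gpt2
print(out.attentions[0][0].shape)  # (batch, heads, q, k); q spans the prompt at step 0
print(out.attentions[1][0].shape)  # with the KV cache, later steps have q == 1

That step-0 shape is why `update_visualization` keeps an `ndim == 2` guard and takes the last query row.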
+
+ # =========================
+ # Tokens -> words mapping
+ # =========================
+ def _words_and_map_from_tokens_simple(token_ids: List[int]) -> Tuple[List[str], List[int]]:
+     """
+     Given token_ids (in-order), return:
+     - words: regex-split words from detokenized text
+     - word2tok: indices (relative to `token_ids`) of the LAST token composing each word
+     """
+     if not token_ids:
+         return [], []
+     toks = tokenizer.convert_ids_to_tokens(token_ids)
+     detok = tokenizer.convert_tokens_to_string(toks)
+     words = WORD_RE.findall(detok)
+
+     enc = tokenizer(detok, return_offsets_mapping=True, add_special_tokens=False)
+     tok_offsets = enc["offset_mapping"]
+     n = min(len(tok_offsets), len(token_ids))
+     spans = [m.span() for m in re.finditer(WORD_RE, detok)]
+
+     word2tok: List[int] = []
+     t = 0
+     for (ws, we) in spans:
+         last_t = None
+         while t < n:
+             ts, te = tok_offsets[t]
+             if not (te <= ws or ts >= we):
+                 last_t = t
+                 t += 1
+             else:
+                 if te <= ws:
+                     t += 1
+                 else:
+                     break
+         if last_t is None:
+             last_t = max(0, min(n - 1, t - 1))
+         word2tok.append(int(last_t))
+     return words, word2tok
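A quick illustration of the offset-mapping alignment this helper relies on, assuming gpt2's fast tokenizer (any fast tokenizer exposes the same mapping):

# Sketch: regex words vs. token character spans (gpt2 assumed).
import re
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
WORD_RE = re.compile(r"\w+(?:'\w+)?|[^\w\s]")

text = "Hello, world!"
enc = tok(text, return_offsets_mapping=True, add_special_tokens=False)
print(enc["offset_mapping"])  # e.g. [(0, 5), (5, 6), (6, 12), (12, 13)]
print(WORD_RE.findall(text))  # ['Hello', ',', 'world', '!']
# A word's last token is the last token whose (start, end) span overlaps
# the word's regex span; that index is what word2tok records.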
+
+ def _strip_trailing_special(ids: List[int]) -> List[int]:
+     """Remove trailing EOS/PAD/other special tokens from the generated ids."""
+     specials = set(getattr(tokenizer, "all_special_ids", []) or [])
+     j = len(ids)
+     while j > 0 and ids[j - 1] in specials:
+         j -= 1
+     return ids[:j]
+
+ def _words_and_maps_for_full_and_gen(all_token_ids: List[int], prompt_len: int):
+     """
+     Returns:
+       words_all: list[str] (prompt + generated, in order)
+       abs_ends_all: list[int] absolute last-token index per word (over all_token_ids)
+       words_gen: list[str] (generated only)
+       abs_ends_gen: list[int] absolute last-token index per generated word
+     """
+     if not all_token_ids:
+         return [], [], [], []
+
+     prompt_ids = all_token_ids[:prompt_len]
+     gen_ids = _strip_trailing_special(all_token_ids[prompt_len:])
+
+     p_words, p_map_rel = _words_and_map_from_tokens_simple(prompt_ids)
+     g_words, g_map_rel = _words_and_map_from_tokens_simple(gen_ids)
+
+     p_abs = [int(i) for i in p_map_rel]  # prompt starts at absolute 0
+     g_abs = [prompt_len + int(i) for i in g_map_rel]
+
+     words_all = p_words + g_words
+     abs_ends_all = p_abs + g_abs
+
+     return words_all, abs_ends_all, g_words, g_abs
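A tiny worked example of the two maps, with made-up token counts:

# Hypothetical: a 3-token prompt forming 2 words, then 4 generated tokens forming 3 words.
prompt_len = 3
p_map_rel = [0, 2]        # last-token index per prompt word (already absolute)
g_map_rel = [0, 1, 3]     # last-token index per generated word, relative to gen_ids
g_abs = [prompt_len + i for i in g_map_rel]
print(p_map_rel + g_abs)  # [0, 2, 3, 4, 6] -> abs_ends_all
print(g_abs)              # [3, 4, 6]       -> the selector's domain (generated only)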
+
+ # =========================
+ # Visualization (WORD-LEVEL)
+ # =========================
+ def generate_word_visualization(words_all: List[str],
+                                 abs_word_ends_all: List[int],
+                                 attention_values: np.ndarray,
+                                 selected_token_abs_idx: int) -> str:
+     """
+     Paragraph-style visualization over words (prompt + generated).
+     For each word, aggregate attention over its composing tokens (sum),
+     normalize across words, and render opacity as a semi-transparent background.
+     """
+     if not words_all or attention_values is None or len(attention_values) == 0:
+         return (
+             "<div style='width:100%;'>"
+             " <div style='background:#444;border:1px solid #eee;border-radius:8px;padding:10px;'>"
+             " <div style='color:#ddd;'>No attention values.</div>"
+             " </div>"
+             "</div>"
+         )
+
+     # Build word starts from ends (inclusive token indices)
+     starts = []
+     for i, end in enumerate(abs_word_ends_all):
+         if i == 0:
+             starts.append(0)
+         else:
+             starts.append(min(abs_word_ends_all[i - 1] + 1, end))
+
+     # Sum attention per word
+     word_scores = []
+     for i, end in enumerate(abs_word_ends_all):
+         start = starts[i]
+         if start > end:
+             start = end
+         s = max(0, min(start, len(attention_values) - 1))
+         e = max(0, min(end, len(attention_values) - 1))
+         if e < s:
+             s, e = e, s
+         word_scores.append(float(attention_values[s:e + 1].sum()))
+
+     max_attn = max(0.1, float(max(word_scores)) if word_scores else 0.0)
+
+     # Which word holds the selected token?
+     selected_word_idx = None
+     for i, end in enumerate(abs_word_ends_all):
+         if selected_token_abs_idx <= end:
+             selected_word_idx = i
+             break
+     if selected_word_idx is None and abs_word_ends_all:
+         selected_word_idx = len(abs_word_ends_all) - 1
+
+     spans = []
+     for i, w in enumerate(words_all):
+         alpha = min(1.0, word_scores[i] / max_attn) if max_attn > 0 else 0.0
+         bg = f"rgba(66,133,244,{alpha:.3f})"
+         border = "2px solid #fff" if i == selected_word_idx else "1px solid transparent"
+         spans.append(
+             f"<span style='display:inline-block;background:{bg};border:{border};"
+             f"border-radius:6px;padding:2px 6px;margin:2px 4px 4px 0;color:#fff;'>"
+             f"{w}</span>"
+         )
+
+     return (
+         "<div style='width:100%;'>"
+         " <div style='background:#444;border:1px solid #eee;border-radius:8px;padding:10px;'>"
+         " <div style='white-space:normal;line-height:1.8;'>"
+         f" {''.join(spans)}"
+         " </div>"
+         " </div>"
+         "</div>"
+     )
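The opacity rule reduces to a few lines: sum token attention per word, then scale by the max word score (floored at 0.1). A worked example with made-up weights:

import numpy as np

attn = np.array([0.05, 0.40, 0.15, 0.40])  # hypothetical per-token attention
starts, ends = [0, 2], [1, 3]              # word 0 = tokens 0..1, word 1 = tokens 2..3
scores = [float(attn[s:e + 1].sum()) for s, e in zip(starts, ends)]  # [0.45, 0.55]
max_attn = max(0.1, max(scores))
alphas = [min(1.0, s / max_attn) for s in scores]
print(alphas)  # [0.8181..., 1.0]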
+
+ # =========================
+ # Core functions
+ # =========================
+ def run_generation(prompt, max_new_tokens, temperature, top_p):
+     """Generate and prepare word-level selector + initial visualization."""
+     inputs = tokenizer(prompt or "", return_tensors="pt").to(device)
+     prompt_len = inputs["input_ids"].shape[1]
+
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=int(max_new_tokens),
+             temperature=float(temperature),
+             top_p=float(top_p),
+             do_sample=True,
+             pad_token_id=tokenizer.pad_token_id,
+             output_attentions=True,
+             return_dict_in_generate=True,
+         )
+
+     all_token_ids = outputs.sequences[0].tolist()
+
+     # Build mappings for (prompt+generated) and for generated-only
+     words_all, abs_all, words_gen, abs_gen = _words_and_maps_for_full_and_gen(all_token_ids, prompt_len)
+
+     # Radio choices: ONLY generated words
+     display_choices = [(w, i) for i, w in enumerate(words_gen)]
+
+     if not display_choices:
+         return {
+             state_attentions: None,
+             state_all_token_ids: None,
+             state_prompt_len: 0,
+             state_words_all: None,
+             state_abs_all: None,
+             state_gen_abs: None,
+             radio_word_selector: gr.update(choices=[], value=None),
+             html_visualization: "<div style='text-align:center;padding:20px;'>No generated tokens to visualize.</div>",
+         }
+
+     first_gen_idx = 0
+     html_init = update_visualization(
+         first_gen_idx,
+         outputs.attentions,
+         all_token_ids,
+         prompt_len,
+         0, 0, True, True,
+         words_all,
+         abs_all,
+         abs_gen,  # map selector index -> absolute token end
+     )
+
+     return {
+         state_attentions: outputs.attentions,
+         state_all_token_ids: all_token_ids,
+         state_prompt_len: prompt_len,
+         state_words_all: words_all,
+         state_abs_all: abs_all,
+         state_gen_abs: abs_gen,
+         radio_word_selector: gr.update(choices=display_choices, value=first_gen_idx),
+         html_visualization: html_init,
+     }
+
+ def update_visualization(
+     selected_gen_index,
+     attentions,
+     all_token_ids,
+     prompt_len,
+     layer,
+     head,
+     mean_layers,
+     mean_heads,
+     words_all,
+     abs_all,
+     gen_abs_list,  # absolute last-token indices for generated words (selector domain)
+ ):
+     """Recompute visualization for the chosen GENERATED word, over full context."""
+     if selected_gen_index is None or attentions is None or gen_abs_list is None:
+         return "<div style='text-align:center;padding:20px;'>Generate text first.</div>"
+
+     gidx = int(selected_gen_index)
+     if not (0 <= gidx < len(gen_abs_list)):
+         return "<div style='text-align:center;padding:20px;'>Invalid selection.</div>"
+
+     token_index_abs = int(gen_abs_list[gidx])
+
+     # Map absolute generated index -> generation step
+     # step = abs_idx - prompt_len (clamped)
+     if len(attentions) == 0:
+         return "<div style='text-align:center;padding:20px;'>No attention steps available.</div>"
+
+     step_index = token_index_abs - prompt_len
+     step_index = max(0, min(step_index, len(attentions) - 1))
+
+     token_attn = get_attention_for_token_layer(
+         attentions,
+         token_index=step_index,  # index by generation step
+         layer_index=int(layer),
+         head_index=int(head),
+         mean_across_layers=bool(mean_layers),
+         mean_across_heads=bool(mean_heads),
+     )
+
+     attn_vals = token_attn.detach().cpu().numpy()
+     if attn_vals.ndim == 2:
+         attn_vals = attn_vals[-1]
+
+     total_tokens = len(all_token_ids)
+     padded = np.zeros(total_tokens, dtype=float)
+     k_len = min(len(attn_vals), total_tokens)
+     padded[:k_len] = attn_vals[:k_len]
+
+     # Absolute word ends for FULL sequence (prompt + generated)
+     abs_word_ends = [int(i) for i in (abs_all or [])]
+
+     return generate_word_visualization(words_all, abs_word_ends, padded, token_index_abs)
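The absolute-index-to-step conversion above is just a clamped subtraction, since `attentions[t]` belongs to the token at absolute position `prompt_len + t`. With hypothetical numbers:

# 4-token prompt, selected word ends at absolute index 6, 5 recorded steps.
prompt_len, token_index_abs, n_steps = 4, 6, 5
step_index = max(0, min(token_index_abs - prompt_len, n_steps - 1))
print(step_index)  # 2 -> attentions[2] carries this token's attention row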
+
+ def toggle_slider(is_mean):
+     return gr.update(interactive=not bool(is_mean))
+
+ # =========================
+ # Gradio UI
+ # =========================
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown("# 🤖 Word-Level Attention Visualizer — choose a model & explore")
+     gr.Markdown(
+         "Generate text, then select a **generated word** to see where it attends. "
+         "The viewer below shows attention over both the **prompt** and the **generated** continuation. "
+         "EOS tokens are stripped so `<|endoftext|>` doesn’t appear."
+     )
+
+     # States
+     state_attentions = gr.State(None)
+     state_all_token_ids = gr.State(None)
+     state_prompt_len = gr.State(None)
+     state_words_all = gr.State(None)  # full (prompt + gen) words
+     state_abs_all = gr.State(None)    # full abs ends
+     state_gen_abs = gr.State(None)    # generated-only abs ends
+     state_model_name = gr.State(None)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.Markdown("### 0) Model")
+             dd_model = gr.Dropdown(
+                 ALLOWED_MODELS, value=ALLOWED_MODELS[0], label="Causal LM",
+                 info="Models that work with AutoModelForCausalLM + attentions"
+             )
+             btn_load = gr.Button("Load / Switch Model", variant="secondary")
+
+             gr.Markdown("### 1) Generation")
+             txt_prompt = gr.Textbox("In a distant future, humanity", label="Prompt")
+             btn_generate = gr.Button("Generate", variant="primary")
+             slider_max_tokens = gr.Slider(10, 200, value=50, step=10, label="Max New Tokens")
+             slider_temp = gr.Slider(0.0, 1.5, value=0.7, step=0.1, label="Temperature")
+             slider_top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top P")
+
+             gr.Markdown("### 2) Attention")
+             check_mean_layers = gr.Checkbox(True, label="Mean Across Layers")
+             check_mean_heads = gr.Checkbox(True, label="Mean Across Heads")
+             slider_layer = gr.Slider(0, 11, value=0, step=1, label="Layer", interactive=False)
+             slider_head = gr.Slider(0, 11, value=0, step=1, label="Head", interactive=False)
+
+         with gr.Column(scale=3):
+             radio_word_selector = gr.Radio(
+                 [], label="Select Generated Word",
+                 info="Selector lists only generated words"
+             )
+             html_visualization = gr.HTML(
+                 "<div style='text-align:center;padding:20px;color:#888;border:1px dashed #888;border-radius:8px;'>"
+                 "Attention visualization will appear here.</div>"
+             )
+
+     # Load/switch model
+     def on_load_model(selected_name, mean_layers, mean_heads):
+         load_model(selected_name)
+         L, H = model_heads_layers()
+         return (
+             selected_name,  # state_model_name
+             gr.update(minimum=0, maximum=L - 1, value=0, interactive=not bool(mean_layers)),
+             gr.update(minimum=0, maximum=H - 1, value=0, interactive=not bool(mean_heads)),
+             # SAFE RADIO RESET
+             gr.update(choices=[], value=None),
+             "<div style='text-align:center;padding:20px;'>Model loaded. Generate to visualize.</div>",
+         )
+
+     btn_load.click(
+         fn=on_load_model,
+         inputs=[dd_model, check_mean_layers, check_mean_heads],
+         outputs=[state_model_name, slider_layer, slider_head, radio_word_selector, html_visualization],
+     )
+
+     # Load default model at app start
+     def _init_model(_):
+         load_model(ALLOWED_MODELS[0])
+         L, H = model_heads_layers()
+         return (
+             ALLOWED_MODELS[0],
+             gr.update(minimum=0, maximum=L - 1, value=0, interactive=False),
+             gr.update(minimum=0, maximum=H - 1, value=0, interactive=False),
+             gr.update(choices=[], value=None),
+         )
+     demo.load(_init_model, inputs=[gr.State(None)], outputs=[state_model_name, slider_layer, slider_head, radio_word_selector])
+
+     # Generate
+     btn_generate.click(
+         fn=run_generation,
+         inputs=[txt_prompt, slider_max_tokens, slider_temp, slider_top_p],
+         outputs=[
+             state_attentions,
+             state_all_token_ids,
+             state_prompt_len,
+             state_words_all,
+             state_abs_all,
+             state_gen_abs,
+             radio_word_selector,
+             html_visualization,
+         ],
+     )
+
+     # Update viz on any control
+     for control in [radio_word_selector, slider_layer, slider_head, check_mean_layers, check_mean_heads]:
+         control.change(
+             fn=update_visualization,
+             inputs=[
+                 radio_word_selector,
+                 state_attentions,
+                 state_all_token_ids,
+                 state_prompt_len,
+                 slider_layer,
+                 slider_head,
+                 check_mean_layers,
+                 check_mean_heads,
+                 state_words_all,
+                 state_abs_all,
+                 state_gen_abs,
+             ],
+             outputs=html_visualization,
+         )
+
+     # Toggle slider interactivity
+     check_mean_layers.change(toggle_slider, check_mean_layers, slider_layer)
+     check_mean_heads.change(toggle_slider, check_mean_heads, slider_head)
+
+ if __name__ == "__main__":
+     print(f"Device: {device}")
+     load_model(ALLOWED_MODELS[0])
+     demo.launch(debug=True)
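For reference, the safe Radio reset pattern both versions lean on, sketched in isolation (the component names here are illustrative):

import gradio as gr

with gr.Blocks() as sketch:
    radio = gr.Radio([], label="Words")
    btn = gr.Button("Reset")
    # Returning value=None alongside empty choices avoids Gradio's
    # "value is not in choices" error that a stale value would trigger.
    btn.click(lambda: gr.update(choices=[], value=None), None, radio)

# sketch.launch()  # uncomment to try it standalone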