Lev McKinney committed · Commit c35da92 · 1 parent: 7a724e0

upgraded app to use tuned_lens=0.1.0
Files changed:
- README.md (+1, -0)
- app.py (+21, -15)
- requirements.txt (+1, -1)
README.md CHANGED

@@ -3,6 +3,7 @@ title: Tuned Lens
 emoji: π
 colorFrom: pink
 colorTo: blue
+port: 7860
 sdk: docker
 pinned: false
 license: mit
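The new `port: 7860` entry declares which port the Docker Space serves on; 7860 is Gradio's default. As an illustrative sketch only (the launch call is outside the hunks in this commit, and the trivial interface below stands in for the real app), the Gradio app would then be started so the container listens on that port:

```python
import gradio as gr

# Minimal stand-in for the real app's interface; the actual app.py builds a
# richer UI around make_plot. The point is the launch binding: the container
# listens on the port declared in the README (7860).
demo = gr.Interface(fn=lambda text: text, inputs="text", outputs="text")
demo.launch(server_name="0.0.0.0", server_port=7860)
```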
app.py CHANGED

@@ -1,17 +1,20 @@
 import torch
 from tuned_lens.nn.lenses import TunedLens, LogitLens
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from tuned_lens.plotting import
+from tuned_lens.plotting import PredictionTrajectory
 import gradio as gr
 from plotly import graph_objects as go

 device = torch.device("cpu")
 print(f"Using device {device} for inference")
-model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped
+model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped")
 model = model.to(device)
-tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped
-tuned_lens = TunedLens.
-
+tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
+tuned_lens = TunedLens.from_model_and_pretrained(
+    model=model,
+    map_location=device,
+)
+logit_lens = LogitLens.from_model(model)

 lens_options_dict = {
     "Tuned Lens": tuned_lens,
@@ -20,13 +23,15 @@ lens_options_dict = {

 statistic_options_dict = {
     "Entropy": "entropy",
-    "Cross Entropy": "
+    "Cross Entropy": "cross_entropy",
     "Forward KL": "forward_kl",
 }


 def make_plot(lens, text, statistic, token_cutoff):
     input_ids = tokenizer.encode(text, return_tensors="pt")
+    input_ids = [tokenizer.bos_token_id] + input_ids
+    targets = input_ids[1:] + [tokenizer.eos_token_id]

     if len(input_ids[0]) == 0:
         return go.Figure(layout=dict(title="Please enter some text."))
@@ -34,18 +39,19 @@ def make_plot(lens, text, statistic, token_cutoff):
     if token_cutoff < 1:
         return go.Figure(layout=dict(title="Please provide valid token cut off."))

-
-
-
-
-        layer_stride=2,
+    start_pos = max(len(input_ids[0]) - token_cutoff, 0)
+    pred_traj = PredictionTrajectory.from_lens_and_model(
+        lens=lens,
+        model=model,
         input_ids=input_ids,
-
-
+        tokenizer=tokenizer,
+        targets=targets,
+        start_pos=start_pos,
     )

-    return
-
+    return getattr(pred_traj, statistic)().figure(
+        title=f"{lens.__class__.__name__} ({model.name_or_path}) {statistic}",
+    )

 preamble = """
 # The Tuned Lens π
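For readers upgrading their own code, here is a minimal, self-contained sketch of the tuned_lens 0.1.0 flow the new app.py uses. The lens and trajectory calls mirror the lines added above; the example prompt, the plain Python list of token ids, and the final fig.show() call are illustrative assumptions rather than part of this commit.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tuned_lens.nn.lenses import TunedLens
from tuned_lens.plotting import PredictionTrajectory

device = torch.device("cpu")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped").to(device)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")

# 0.1.0 API: load the tuned-lens weights that match `model` (as in the diff above).
tuned_lens = TunedLens.from_model_and_pretrained(model=model, map_location=device)

# Token ids with an explicit BOS, and next-token targets, as make_plot builds them.
text = "The Tuned Lens is a tool for"  # illustrative prompt
input_ids = [tokenizer.bos_token_id] + tokenizer.encode(text)
targets = input_ids[1:] + [tokenizer.eos_token_id]

# Build the per-layer prediction trajectory and plot one statistic.
pred_traj = PredictionTrajectory.from_lens_and_model(
    lens=tuned_lens,
    model=model,
    input_ids=input_ids,
    tokenizer=tokenizer,
    targets=targets,
    start_pos=0,  # the app instead clips to the last `token_cutoff` positions
)
fig = pred_traj.entropy().figure(title="TunedLens entropy")  # plotly figure
fig.show()
```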
requirements.txt CHANGED

@@ -1,2 +1,2 @@
-tuned_lens==0.0
+tuned_lens==0.1.0
 gradio
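A quick way to confirm that a rebuilt Space (or a local environment) actually picked up the new pin; this check is not part of the commit, and it assumes the installed distribution is discoverable under the name used in requirements.txt:

```python
from importlib.metadata import version

# Illustrative sanity check, not part of this commit: the requirement
# `tuned_lens==0.1.0` should resolve to a 0.1.0 install.
print(version("tuned_lens"))  # expected: 0.1.0
```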