ma7583 commited on
Commit
0a4648d
·
verified ·
1 Parent(s): 3a307f4

Delete sft_gradio.py

Browse files
Files changed (1) hide show
  1. sft_gradio.py +0 -150
sft_gradio.py DELETED
@@ -1,150 +0,0 @@
1
- import streamlit as st
2
- import random
3
- import torch, os
4
- import pandas as pd
5
- import torch.nn.functional as F
6
- from transformers import AutoTokenizer, AutoModelForCausalLM
7
- from datasets import Dataset
8
- from sft_dataset import extract_rollouts, load_config
9
-
10
# Load model and tokenizer with caching and progress
@st.cache_resource(show_spinner=True)
def load_model_and_tokenizer(path):
    """Load the tokenizer and causal-LM model from `path`.

    Cached as a Streamlit resource so the heavy objects are created once
    per server process, not on every rerun.
    """
    with st.spinner("Loading model and tokenizer..."):
        tok = AutoTokenizer.from_pretrained(path, use_fast=True)
        lm = AutoModelForCausalLM.from_pretrained(
            path, dtype="auto", device_map="auto"
        )
    return tok, lm
17
-
18
# Load and prepare dataset with caching and progress
@st.cache_data(show_spinner=True)
def load_dataset(path):
    """Extract rollouts from `path` and return only the held-out test split.

    The train split is discarded here; this app only browses test examples.
    Cached as Streamlit data so the extraction runs once per `path`.
    """
    with st.spinner("Loading and processing dataset..."):
        _train_split, test_split = extract_rollouts(path)
    return test_split
24
-
25
# Log probability scorer
def get_logprob(model, tokenizer, text):
    """Return the total log-probability the model assigns to `text`.

    Accepts a single string or a list of strings (padded as a batch) and
    returns a 1-D tensor with one summed token log-prob per sequence.
    """
    inputs = tokenizer(text, return_tensors='pt', max_length=2048, truncation=True, padding=True).to(model.device)
    with torch.no_grad():
        outputs = model(**inputs, labels=inputs['input_ids'])
    # Causal-LM alignment fix: logits at position t predict the token at
    # position t+1, so drop the last logit column and the first label column
    # before scoring (the previous unshifted version mis-scored every token).
    logits = outputs.logits[:, :-1, :]
    labels = inputs['input_ids'][:, 1:]
    # Mask out padding so batched inputs of different lengths are scored
    # only on their real tokens.
    mask = inputs['attention_mask'][:, 1:].to(logits.dtype)
    token_log_probs = -F.cross_entropy(
        logits.reshape(-1, logits.size(-1)), labels.reshape(-1), reduction='none'
    ).view(labels.shape)
    return (token_log_probs * mask).sum(dim=1)
36
-
37
# Completion ranker
def recommend_top_k_steps(model, tokenizer, prompt, top_k=3):
    """Sample `top_k` continuations of `prompt` and rank them by
    length-normalized log-probability (highest first).

    Returns a list of (generated_text, mean_token_logprob) tuples.
    """
    inputs = tokenizer(prompt, max_length=2048, truncation=True, return_tensors='pt').to(model.device)

    # Stop generation at EOS or at the custom "END" marker token.
    stop_ids = set()
    if tokenizer.eos_token_id is not None:
        stop_ids.add(tokenizer.eos_token_id)
    # for token in ["END","\n"]:
    for token in ["END"]:
        token_id = tokenizer.convert_tokens_to_ids(token)
        # convert_tokens_to_ids returns None (or the unk id) when the token
        # is absent from the vocabulary; passing either to `generate` would
        # crash or make every unknown token act as a stop token.
        if token_id is not None and token_id != tokenizer.unk_token_id:
            stop_ids.add(token_id)

    model.eval()
    with torch.no_grad():
        gen = model.generate(
            **inputs,
            do_sample=True,
            num_return_sequences=top_k,
            top_k=50,
            top_p=0.9,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=list(stop_ids),
            output_scores=True,
            return_dict_in_generate=True,
        )

    sequences = gen.sequences
    # NOTE(review): `gen.scores` are the post-processing scores (after the
    # temperature/top-k/top-p warpers), i.e. the sampling distribution, not
    # the raw model logits — fine for ranking samples against each other.
    scores = gen.scores
    prompt_len = inputs["input_ids"].shape[1]

    suggestions_with_logprob = []
    for i in range(sequences.size(0)):
        gen_text = tokenizer.decode(sequences[i, prompt_len:], skip_special_tokens=True).strip()
        gen_ids = sequences[i, prompt_len:]
        total_logprob, token_count = 0.0, 0

        for t in range(min(len(scores), gen_ids.numel())):
            token_id = int(gen_ids[t].item())
            if token_id in stop_ids:
                break
            step_logits = scores[t][i]
            step_logprobs = F.log_softmax(step_logits, dim=-1)
            total_logprob += float(step_logprobs[token_id].item())
            token_count += 1

        # Normalize by length so longer completions are not penalized.
        length_norm_logprob = total_logprob / max(token_count, 1)
        suggestions_with_logprob.append((gen_text, length_norm_logprob))

    suggestions_ranked = sorted(suggestions_with_logprob, key=lambda x: x[1], reverse=True)
    return suggestions_ranked
86
-
87
- # WEBSOCKET = 111334
88
- # @reqest.method("GET", "/recoommen")
89
- # # def recommend(sequent: str, previous_comamnds: list):
90
- # result = model( seqent, prev_commands)
91
- # ## write to websocket port
92
- # return websocket.send(json.dumps(result))
93
-
94
-
95
if __name__ == "__main__":

    # Streamlit UI
    st.title("PVS Step Recommender")

    config = load_config("pvs_v5.yaml")
    SAVE_PATH = config.save_path
    ROLLOUTS = config.rollouts_dir
    NUM_SUGGESTIONS = 3

    # Initialize model, tokenizer, and dataset
    test_ds = load_dataset(ROLLOUTS)
    tokenizer, model = load_model_and_tokenizer(SAVE_PATH)

    if st.button("Sample Random Proof Step"):
        idx = random.randrange(len(test_ds))
        st.session_state["current_idx"] = idx
        st.session_state["example"] = test_ds[idx]
        # One-shot flag: trigger a recommendation on the rerun this click causes.
        st.session_state["auto_recommend"] = True

    if "example" in st.session_state:
        example = st.session_state["example"]
        prompt = example["text"] + "\nNext Command:\n"
        st.text(f"Theorem: {example['id']}")
        st.text(f"True: {example['label']}")
        # NOTE(review): height="content" requires a recent Streamlit release;
        # older versions only accept an integer pixel height here — confirm.
        user_input = st.text_area("Input", prompt, height="content")

        recommend_pressed = st.button("Recommend Next Steps")
        if recommend_pressed or st.session_state.get("auto_recommend", False):
            # Consume the one-shot flag so unrelated reruns (e.g. editing the
            # text area) do not re-trigger an expensive generation pass.
            st.session_state["auto_recommend"] = False
            ranked_completions = recommend_top_k_steps(model, tokenizer, user_input, top_k=NUM_SUGGESTIONS)
            st.subheader("Top Suggestions")
            # Build table data: one row per ranked suggestion.
            table_data = [
                {
                    "Log_prob": round(score, 2),
                    "Command": completion.split("\n")[0],
                    "Output": completion.replace('\n', '\\n'),  # escaped for single-row display
                }
                for completion, score in ranked_completions
            ]
            st.table(pd.DataFrame(table_data))