BigSalmon committed on
Commit
8d6238f
·
1 Parent(s): 07977f4

Create new file

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
3
+ import torch
4
+
5
# Default prompt pre-filled in the sentence text area further down.
first = "It is a wonderful day to"

# Model identifier typed by the user; the box starts out as "gpt2".
name_of_model = st.text_input(
    label="Name of the model you want to run",
    value="gpt2",
)
9
+
10
@st.cache(allow_output_mutation=True)
def get_model(name_of_model):
    """Load the causal language model called ``name_of_model``.

    The result is cached by Streamlit so reruns of the script do not reload
    the weights. The tokenizer is always the stock "gpt2" one, whatever model
    name is requested.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    causal_lm = AutoModelForCausalLM.from_pretrained(name_of_model)
    return causal_lm, gpt2_tokenizer
15
+
16
# Load the model named in the text box. get_model requires the model name as
# its argument; calling it with no arguments raises TypeError.
model, tokenizer = get_model(name_of_model)

# Generation controls. Each slider is given its (min, max) bounds.
temp = st.sidebar.slider("Temperature", 0.7, 1.5)
number_of_outputs = st.sidebar.slider("Number of Outputs", 5, 50)
lengths = st.sidebar.slider("Length", 3, 500)
# Whitespace-separated words that generation must avoid.
bad_words = st.text_input("Words You Do Not Want Generated", " core lemon height time ")
# How many top next-token candidates the "Submit Log Probs" button lists.
logs_outputs = st.sidebar.slider("Logit Outputs", 50, 300)
22
+
23
def run_generate(text, bad_words):
    """Sample continuations of ``text`` from the loaded model.

    Reads the module-level ``model``, ``tokenizer``, ``lengths``, ``temp`` and
    ``number_of_outputs`` values.

    Args:
        text: Prompt to continue.
        bad_words: Whitespace-separated words that must not be generated.
            May be empty, in which case nothing is banned.

    Returns:
        A list with one decoded continuation per returned sequence; the
        prompt itself is not included in the strings.
    """
    input_ids = tokenizer.encode(text, return_tensors='pt')
    # Number of prompt tokens; taken from the tensor instead of encoding twice.
    prompt_length = input_ids.shape[-1]

    # Each word is tokenized with a leading space (presumably so the ids match
    # the word as it appears mid-sentence -- confirm for non-GPT-2 tokenizers).
    bad_word_ids = [
        tokenizer(" " + word).input_ids for word in bad_words.split()
    ]

    sample_outputs = model.generate(
        input_ids,
        do_sample=True,
        # min == max forces exactly `lengths` new tokens per sample.
        max_length=prompt_length + lengths,
        min_length=prompt_length + lengths,
        top_k=50,
        temperature=temp,
        num_return_sequences=number_of_outputs,
        # generate() rejects an empty list, so pass None when nothing is banned.
        bad_words_ids=bad_word_ids or None,
    )

    # Drop the prompt by slicing off its tokens before decoding. A string
    # replace would also delete any later repetition of the prompt inside the
    # continuation and relies on decode(encode(text)) == text.
    return [
        tokenizer.decode(sequence[prompt_length:])
        for sequence in sample_outputs
    ]
48
# Main form: one prompt box and two actions (sample continuations, or list
# the most likely next tokens).
with st.form(key='my_form'):
    text = st.text_area(label='Enter sentence', value=first)
    submit_button = st.form_submit_button(label='Submit')
    submit_button2 = st.form_submit_button(label='Submit Log Probs')

    if (submit_button or submit_button2) and not text:
        # With no prompt there are no tokens to condition on, so bail out with
        # a message instead of letting the model call fail.
        st.warning("Please enter some text first.")
    elif submit_button:
        translated_text = run_generate(text, bad_words)
        st.write(translated_text if translated_text else "No translation found")
    elif submit_button2:
        # Inference only: no gradients needed.
        with torch.no_grad():
            text2 = str(text)
            print(text2)  # echo the prompt to the server log
            myinput = torch.tensor([tokenizer.encode(text2)])
            # Logits for the next token: first (only) batch row, last position.
            # Reading `.logits` avoids assuming how many items the model
            # returns as a tuple.
            logits = model(myinput, return_dict=True).logits[0, -1]
            # Indices of the `logs_outputs` highest-scoring next tokens.
            _, best_indices = logits.topk(logs_outputs)
            best_words = [tokenizer.decode([idx.item()]) for idx in best_indices]
            st.write(best_words)