Spaces:
Runtime error
Runtime error
File size: 1,952 Bytes
dc396d1 00fce7b dc396d1 210faa8 fc30cb8 dc396d1 2df4a0e dc396d1 210faa8 dc396d1 210faa8 35f3dbe dc396d1 fc30cb8 35f3dbe 7c1a1dd dc396d1 35f3dbe dc396d1 255ba51 210faa8 dc396d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import torch
import streamlit as st
import transformers
# Load the four objects this app works with from a single checkpoint file,
# mapping any stored tensors onto the CPU. The unpacking requires the
# checkpoint to hold exactly four items, in this order. As used below: the two
# `bert_mlm_*` objects and `bert_clf` are called with tokenizer output and
# expose `.logits`; `tokenizer` is callable and provides `.decode`.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- 'model.pt' must come from a trusted source.
# NOTE(review): this runs at import time, and Streamlit presumably re-executes
# the script on each interaction -- consider caching the load; confirm against
# the Streamlit version in use.
bert_mlm_positive, bert_mlm_negative, bert_clf, tokenizer = torch.load('model.pt', map_location='cpu')
def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """Propose single-token rewrites of *sentence*.

    The sentence is tokenized once and scored by both masked-LM models
    (``bert_mlm_positive`` and ``bert_mlm_negative``). Every position is
    ranked by ``(p_pos + epsilon) / (p_neg + epsilon)``, where ``p_pos`` and
    ``p_neg`` are the probabilities each model assigns to the token currently
    at that position. The ``num_tokens`` positions with the lowest ratio are
    selected, and for each of them the ``k_best`` tokens the positive model
    rates most probable are substituted, one at a time.

    Args:
        sentence: Text to rewrite.
        num_tokens: How many positions (lowest ratio first) to rewrite.
        k_best: How many replacement tokens to try per position.
        epsilon: Smoothing term added to both probabilities before dividing.

    Returns:
        A list of decoded sentences, one per (position, replacement) pair,
        each differing from the input in at most one token.
    """
    tokens_info = tokenizer(sentence, return_tensors="pt")
    input_ids = tokens_info['input_ids'][0]

    # Inference only: skip autograd bookkeeping for the two forward passes.
    with torch.no_grad():
        probs_positive = bert_mlm_positive(**tokens_info).logits.softmax(dim=-1)[0]
        probs_negative = bert_mlm_negative(**tokens_info).logits.softmax(dim=-1)[0]

    # Probability each model gives to the token actually present at each position.
    positions = torch.arange(len(probs_positive))
    sentence_tokens_pos_prob = probs_positive[positions, input_ids]
    sentence_tokens_neg_prob = probs_negative[positions, input_ids]

    # Ascending sort: positions whose current token the positive model likes
    # least, relative to the negative model, come first.
    # NOTE(review): the ranking also covers the first and last positions, which
    # the [1:-1] slice below drops (presumably special tokens) -- confirm
    # whether they should be excluded from the candidates.
    probs_ratio = torch.argsort(
        (sentence_tokens_pos_prob + epsilon) / (sentence_tokens_neg_prob + epsilon)
    )

    result = []
    for ind in probs_ratio[:num_tokens]:
        # k_best most probable tokens under the positive model, best first.
        top_positive = torch.flip(torch.argsort(probs_positive[ind])[-k_best:], dims=(0,))
        for new_ind in top_positive:
            # Work on a copy: writing into the shared input_ids tensor would
            # carry this substitution over into every later candidate.
            new_sentence = input_ids.clone()
            new_sentence[ind] = new_ind
            result.append(tokenizer.decode(new_sentence[1:-1]))
    return result
def beamsearch(sentence: str, max_word_changes: int, k_best: int, beam_size:int):
    """Rewrite *sentence* by beam search over single-token replacements.

    Each round expands every sentence in the beam with ``get_replacements``
    (one position, ``k_best`` substitutions), scores all distinct candidates
    with ``bert_clf`` and keeps the ``beam_size`` candidates whose softmax
    probability at class index 1 is highest.

    Args:
        sentence: Starting text.
        max_word_changes: Number of expansion rounds; each round changes at
            most one token per sentence.
        k_best: Replacement tokens tried per sentence in each round.
        beam_size: Number of candidates kept after each round.

    Returns:
        The final beam as a list of distinct sentences, highest classifier
        probability first. With ``max_word_changes == 0`` this is
        ``[sentence]``.
    """
    beam = [sentence]
    for _ in range(max_word_changes):
        # dict.fromkeys removes duplicates while keeping first-seen order, so
        # ties are broken the same way on every run (a set's iteration order
        # depends on string hashing).
        candidates = list(dict.fromkeys(
            candidate
            for line in beam
            for candidate in get_replacements(line, num_tokens=1, k_best=k_best)
        ))

        # Inference only: no autograd graph is needed for scoring.
        with torch.no_grad():
            clf_probs = [
                bert_clf(**tokenizer(candidate, return_tensors="pt"))
                .logits.softmax(dim=-1)[0][1]
                .item()
                for candidate in candidates
            ]

        # argsort is ascending, so the best candidates sit at the tail;
        # reverse them to store the beam best-first.
        top_indexes = torch.argsort(torch.tensor(clf_probs))[-beam_size:]
        beam = [candidates[i] for i in reversed(top_indexes.tolist())]
    return beam
# Streamlit page: greet the user, read one sentence and show its rewrites.
st.markdown("Hello")
user_input = st.text_input("Please enter some sentence")
if user_input:
    rewrites = beamsearch(user_input, max_word_changes=5, k_best=3, beam_size=5)
    # st.markdown takes a single string; passing the list directly would show
    # its Python repr, so render one Markdown bullet per rewritten sentence.
    st.markdown("\n".join(f"- {text}" for text in rewrites))
|