Spaces:
Runtime error
Runtime error
import torch
import streamlit as st
import transformers  # not referenced below; presumably kept because model.pt pickles transformers objects — TODO confirm


# Streamlit re-executes this whole script on every widget interaction, so without
# caching the four objects below would be reloaded from disk on each keystroke.
# st.cache_resource (Streamlit >= 1.18) keeps one loaded copy across reruns; on older
# Streamlit versions we fall back to the original behaviour (no caching).
_cache_resource = getattr(st, "cache_resource", lambda fn: fn)


@_cache_resource
def _load_models(path="model.pt"):
    """Load and return (positive MLM, negative MLM, classifier, tokenizer) from *path*.

    The file stores whole pickled objects (the models are called directly and the
    tokenizer is a callable), not a state_dict. ``weights_only=False`` is therefore
    required: PyTorch >= 2.6 changed the default to ``weights_only=True``, which
    refuses to unpickle arbitrary objects.

    SECURITY: this unpickles arbitrary Python objects — only ever point it at a
    trusted, locally shipped file, never at user-supplied data.
    """
    mlm_positive, mlm_negative, clf, tok = torch.load(
        path, map_location="cpu", weights_only=False
    )
    for model in (mlm_positive, mlm_negative, clf):
        # Inference only: switch off dropout etc. in case the models were saved in train mode.
        model.eval()
    return mlm_positive, mlm_negative, clf, tok


bert_mlm_positive, bert_mlm_negative, bert_clf, tokenizer = _load_models()
def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """Propose variants of *sentence* with a single token swapped.

    Every interior token position is scored by the ratio
    ``(P_pos(token) + epsilon) / (P_neg(token) + epsilon)``, where the probabilities
    come from the positive and negative masked-LM models. The ``num_tokens``
    positions with the LOWEST ratio are the tokens the negative model prefers most,
    relative to the positive one. Each of them is replaced, one position at a time,
    by the ``k_best`` tokens the positive MLM rates most likely at that position.

    Args:
        sentence: input text.
        num_tokens: how many of the lowest-ratio positions to try replacing.
        k_best: how many replacement tokens to try at each chosen position.
        epsilon: smoothing added to both probabilities (avoids division by zero).

    Returns:
        A list of decoded sentences (at most ``num_tokens * k_best``), each equal to
        the input with at most one token changed. The list may contain duplicates.
    """
    tokens_info = tokenizer(sentence, return_tensors="pt")
    input_ids = tokens_info["input_ids"][0]

    with torch.no_grad():  # inference only: don't build an autograd graph
        probs_positive = bert_mlm_positive(**tokens_info).logits.softmax(dim=-1)[0]
        probs_negative = bert_mlm_negative(**tokens_info).logits.softmax(dim=-1)[0]

    positions = torch.arange(len(input_ids))
    token_pos_prob = probs_positive[positions, input_ids]
    token_neg_prob = probs_negative[positions, input_ids]
    probs_ratio = (token_pos_prob + epsilon) / (token_neg_prob + epsilon)

    # The first and last ids are the special tokens that decode(ids[1:-1]) strips
    # below. Replacing them would only reproduce the input, so rank interior
    # positions only (+1 maps back to indices into input_ids).
    candidate_positions = torch.argsort(probs_ratio[1:-1]) + 1

    # topk fails if k exceeds the vocabulary size; clamp to [0, vocab_size].
    k = min(max(k_best, 0), probs_positive.shape[-1])

    result = []
    for ind in candidate_positions[:num_tokens]:
        top_positive = torch.topk(probs_positive[ind], k).indices  # best first
        for new_ind in top_positive:
            # clone(): input_ids is a view into tokens_info. Writing into it directly
            # would leak this replacement into every later candidate.
            new_sentence = input_ids.clone()
            new_sentence[ind] = new_ind
            result.append(tokenizer.decode(new_sentence[1:-1]))
    return result
def _class1_probs(sentences):
    """Return bert_clf's softmax probability of class index 1 for each sentence (1-D tensor)."""
    with torch.no_grad():  # inference only: don't build an autograd graph
        return torch.stack([
            bert_clf(**tokenizer(text, return_tensors="pt")).logits.softmax(dim=-1)[0, 1]
            for text in sentences
        ])


def beamsearch(sentence: str, max_word_changes: int, k_best: int, beam_size: int, num_tokens: int = 1):
    """Beam-search rewrites of *sentence* that maximise bert_clf's class-1 probability.

    Each of the ``max_word_changes`` rounds does the following:
      1. Expand every beam entry with ``get_replacements`` (one token swapped per
         candidate).
      2. De-duplicate the candidates.
      3. Score each candidate with the classifier.
      4. Keep the ``beam_size`` highest-scoring candidates.

    Class index 1 is presumably the "positive" class — confirm against training.

    Args:
        sentence: input text.
        max_word_changes: number of expansion rounds.
        k_best: replacement tokens tried per position (forwarded to get_replacements).
        beam_size: number of candidates kept after each round.
        num_tokens: positions tried per sentence per round (forwarded to get_replacements).

    Returns:
        The final beam as a list of unique sentences, best-scoring first. If no
        round produced any candidate, the current beam (initially ``[sentence]``)
        is returned.
    """
    beam = [sentence]
    for _ in range(max_word_changes):
        # dict.fromkeys de-duplicates while preserving a deterministic order.
        # set() order would vary between processes because of string hash
        # randomization.
        candidates = list(dict.fromkeys(
            variant
            for line in beam
            for variant in get_replacements(line, num_tokens=num_tokens, k_best=k_best)
        ))
        if not candidates:
            break  # nothing left to rewrite; keep the current beam
        scores = _class1_probs(candidates)
        best_first = torch.argsort(scores, descending=True)[:max(beam_size, 0)]
        beam = [candidates[idx] for idx in best_first.tolist()]
    return beam
st.markdown("Hello")
user_input = st.text_input("Please enter some sentence")
if user_input:
    # The search runs several MLM + classifier passes per round; show progress while it works.
    with st.spinner():
        rewrites = beamsearch(user_input, max_word_changes=5, k_best=3, beam_size=5)
    if rewrites:
        # st.markdown expects a string. Passing the list directly displayed its Python repr.
        st.markdown("\n".join(f"- {line}" for line in rewrites))
    else:
        st.markdown("No rewrites found.")