import torch
import streamlit as st
import transformers
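# `model.pt` is assumed to be a tuple saved with torch.save: two BERT masked-LM
# heads (presumably fine-tuned on positive and negative reviews), a BERT
# sentiment classifier, and the matching tokenizer. The top-level
# `import transformers` keeps that unpickling dependency explicit even though
# the module is never referenced directly below.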
bert_mlm_positive, bert_mlm_negative, bert_clf, tokenizer = torch.load('model.pt', map_location='cpu')
def get_replacements(sentence: str, num_tokens: int, k_best: int, epsilon: float = 1e-3):
    """Suggest sentence variants by replacing the `num_tokens` most "negative-looking"
    tokens with the `k_best` most likely tokens under the positive masked LM."""
    tokens_info = tokenizer(sentence, return_tensors="pt")
    with torch.no_grad():
        probs_positive = bert_mlm_positive(**tokens_info).logits.softmax(dim=-1)[0]
        probs_negative = bert_mlm_negative(**tokens_info).logits.softmax(dim=-1)[0]
    # Probability of each original token under the positive and negative LMs
    sentence_tokens_pos_prob = probs_positive[torch.arange(len(probs_positive)), tokens_info['input_ids'][0]]
    sentence_tokens_neg_prob = probs_negative[torch.arange(len(probs_positive)), tokens_info['input_ids'][0]]
    # Token positions sorted by how much less likely the positive LM finds them than the negative LM
    replace_order = torch.argsort((sentence_tokens_pos_prob + epsilon) / (sentence_tokens_neg_prob + epsilon))
    result = []
    for ind in replace_order[:num_tokens]:
        # k_best most probable substitutes at this position, best first
        top_positive = torch.flip(torch.argsort(probs_positive[ind])[-k_best:], dims=(0,))
        for new_ind in top_positive:
            new_sentence = tokens_info['input_ids'][0].clone()  # clone so the original ids stay intact
            new_sentence[ind] = new_ind
            result.append(tokenizer.decode(new_sentence[1:-1]))  # drop [CLS] and [SEP]
    return result
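# Illustrative usage (hypothetical input and output; actual suggestions depend
# entirely on how the two masked language models were fine-tuned):
#   get_replacements("the food was terrible", num_tokens=1, k_best=2)
#   might return variants such as ["the food was good", "the food was great"].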
def beamsearch(sentence: str, max_word_changes: int, k_best: int, beam_size: int):
    """Iteratively replace up to `max_word_changes` words, keeping the `beam_size`
    candidates the sentiment classifier rates as most positive."""
    beam = [sentence]
    for _ in range(max_word_changes):
        # Expand every candidate in the beam by one single-token replacement
        new_beam = []
        for line in beam:
            new_beam += get_replacements(line, num_tokens=1, k_best=k_best)
        new_beam = list(set(new_beam))  # drop duplicates
        # Score each candidate with the classifier (probability of the positive class, assumed to be index 1)
        clf_probs = []
        for line in new_beam:
            tokens_info = tokenizer(line, return_tensors="pt")
            with torch.no_grad():
                clf_probs.append(bert_clf(**tokens_info).logits.softmax(dim=-1)[0][1])
        # Keep only the beam_size highest-scoring candidates
        top_indexes = torch.argsort(torch.stack(clf_probs))[-beam_size:]
        beam = [new_beam[idx] for idx in top_indexes]
    return list(set(beam))
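# Illustrative usage (hypothetical output): starting from a review such as
# "the service was slow and rude", beamsearch(sentence, max_word_changes=2,
# k_best=3, beam_size=5) returns up to beam_size rewrites that the classifier
# scores as most positive, e.g. "the service was fast and friendly".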
st.markdown("Hello")
user_input = st.text_input("Please enter some sentence")
if user_input:
    # Show each rewritten variant on its own line
    for variant in beamsearch(user_input, max_word_changes=5, k_best=3, beam_size=5):
        st.markdown(variant)
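# To try the app locally (assuming Streamlit is installed and model.pt is in
# the working directory): streamlit run app.py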