import torch
import streamlit as st
import transformers  # noqa: F401  # presumably needed so torch.load can unpickle the HF model classes — confirm

# Load the pickled (positive MLM, negative MLM, classifier, tokenizer) tuple onto the CPU.
# torch.load unpickles arbitrary objects: only ever point this at a trusted file.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects pickled
# modules; loading may need weights_only=False there — confirm against the pinned torch.
bert_mlm_positive, bert_mlm_negative, bert_clf, tokenizer = torch.load('model.pt', map_location='cpu')

# Inference only: disable dropout so scores are deterministic even if the models
# were saved while in training mode.
for _model in (bert_mlm_positive, bert_mlm_negative, bert_clf):
    _model.eval()


@torch.no_grad()
def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """Propose variants of *sentence* with one "negative" token swapped out.

    Every token is scored by the ratio
    ``(P_positive(token) + epsilon) / (P_negative(token) + epsilon)``, using the
    positive and negative masked-LM heads evaluated on the unmasked sentence.
    The ``num_tokens`` positions with the smallest ratio (the tokens the negative
    model likes most relative to the positive one) are each replaced, one at a
    time, by the ``k_best`` tokens the positive model scores highest there.

    Args:
        sentence: Input text.
        num_tokens: How many positions to try replacing.
        k_best: How many replacement tokens to try per position.
        epsilon: Smoothing added to both probabilities so the ratio is defined
            when either is ~0.

    Returns:
        A list of decoded sentences, each differing from *sentence* in one token
        position. Its length is at most ``num_tokens * k_best``, and it may
        contain duplicates.
    """
    tokens_info = tokenizer(sentence, return_tensors="pt")
    input_ids = tokens_info['input_ids'][0]
    all_positions = torch.arange(len(input_ids))

    probs_positive = bert_mlm_positive(**tokens_info).logits.softmax(dim=-1)[0]
    probs_negative = bert_mlm_negative(**tokens_info).logits.softmax(dim=-1)[0]
    # Probability each model assigns to the token actually present at each position.
    token_pos_prob = probs_positive[all_positions, input_ids]
    token_neg_prob = probs_negative[all_positions, input_ids]
    ratio = (token_pos_prob + epsilon) / (token_neg_prob + epsilon)

    # Rank only the inner positions. The first and last tokens are stripped by the
    # decode below (assumes BERT-style [CLS] ... [SEP] framing — confirm), so
    # editing them would just reproduce the original sentence.
    positions = (torch.argsort(ratio[1:-1])[:num_tokens] + 1).tolist()

    special_ids = torch.tensor(tokenizer.all_special_ids, dtype=torch.long)
    result = []
    for pos in positions:
        scores = probs_positive[pos].clone()
        # Never propose [CLS]/[SEP]/[PAD]/... as a word; they decode to literal junk.
        scores[special_ids] = -1.0
        # topk (unlike argsort()[-k:]) returns nothing for k == 0 instead of everything.
        k = max(0, min(k_best, scores.numel()))
        for new_id in torch.topk(scores, k).indices.tolist():
            # Copy so each candidate differs from the input in exactly one position.
            new_ids = input_ids.clone()
            new_ids[pos] = new_id
            result.append(tokenizer.decode(new_ids[1:-1]))
    return result


@torch.no_grad()
def _positive_prob(text: str) -> float:
    """Return bert_clf's softmax probability of class index 1 for *text*.

    Class 1 is presumably the "positive" label — confirm against the training setup.
    """
    tokens_info = tokenizer(text, return_tensors="pt")
    return bert_clf(**tokens_info).logits.softmax(dim=-1)[0][1].item()


@torch.no_grad()
def beamsearch(sentence: str, max_word_changes: int, k_best: int, beam_size: int):
    """Beam-search single-token edits of *sentence* that raise the classifier score.

    Each of up to ``max_word_changes`` rounds expands every beam sentence with
    ``get_replacements(..., num_tokens=1, k_best=k_best)``. It then keeps the
    ``beam_size`` distinct candidates with the highest class-1 probability.
    If a round produces no candidates, the search stops early and returns the
    current beam.

    Returns:
        The final beam as a list of distinct sentences, best-scoring first.
    """
    beam = [sentence]
    for _ in range(max_word_changes):
        # dict.fromkeys de-duplicates while keeping a deterministic order,
        # unlike set(), whose order varies between runs under hash randomization.
        candidates = list(dict.fromkeys(
            candidate
            for line in beam
            for candidate in get_replacements(line, num_tokens=1, k_best=k_best)
        ))
        if not candidates:
            break
        scores = torch.tensor([_positive_prob(candidate) for candidate in candidates])
        best = torch.argsort(scores, descending=True)[:beam_size].tolist()
        beam = [candidates[idx] for idx in best]
    return beam


st.markdown("Hello")
user_input = st.text_input("Please enter some sentence")
if user_input:
    results = beamsearch(user_input, max_word_changes=5, k_best=3, beam_size=5)
    # Render one suggestion per bullet rather than a Python list repr.
    st.markdown("\n".join(f"- {line}" for line in results))