File size: 1,952 Bytes
dc396d1
00fce7b
dc396d1
 
210faa8
 
 
fc30cb8
dc396d1
 
 
2df4a0e
 
dc396d1
 
 
 
 
 
 
 
 
 
 
 
 
 
210faa8
dc396d1
 
 
 
210faa8
35f3dbe
dc396d1
 
fc30cb8
35f3dbe
7c1a1dd
dc396d1
35f3dbe
dc396d1
 
 
 
 
 
 
255ba51
210faa8
dc396d1
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import streamlit as st
import transformers

# Unpack the four objects stored in 'model.pt', mapping any tensors to the CPU.
# Below, the first two are called on tokenizer output and read via `.logits`
# per token, the third is used as a sequence classifier, and the fourth as the
# tokenizer.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code — only ship a 'model.pt' that comes from a trusted source.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# may refuse to unpickle whole model objects like these — confirm against the
# pinned torch version.
bert_mlm_positive, bert_mlm_negative, bert_clf, tokenizer = torch.load('model.pt', map_location='cpu')

def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """Propose single-token rewrites of ``sentence``.

    Every token of the tokenized sentence is scored by the ratio of the
    probability ``bert_mlm_positive`` assigns to it over the probability
    ``bert_mlm_negative`` assigns to it (both smoothed by ``epsilon``).
    The ``num_tokens`` positions with the lowest ratio are selected, and
    for each one the ``k_best`` most probable tokens under
    ``bert_mlm_positive`` are substituted, one substitution per candidate.

    Args:
        sentence: Text to rewrite.
        num_tokens: Number of token positions to substitute.
        k_best: Number of substitutes tried per selected position; ``0``
            yields no candidates.
        epsilon: Additive smoothing applied to both probabilities before
            taking their ratio.

    Returns:
        A list of decoded candidate sentences, each produced by replacing
        exactly one token position of the original tokenization. Positions
        are visited from lowest ratio upwards and substitutes from most to
        least probable.
    """
    tokens_info = tokenizer(sentence, return_tensors="pt")
    input_ids = tokens_info['input_ids'][0]

    # Inference only: skip building autograd graphs for the two forward passes.
    with torch.no_grad():
        probs_positive = bert_mlm_positive(**tokens_info).logits.softmax(dim=-1)[0]
        probs_negative = bert_mlm_negative(**tokens_info).logits.softmax(dim=-1)[0]

    positions = torch.arange(len(input_ids))
    sentence_tokens_pos_prob = probs_positive[positions, input_ids]
    sentence_tokens_neg_prob = probs_negative[positions, input_ids]
    probs_ratio = (sentence_tokens_pos_prob + epsilon) / (sentence_tokens_neg_prob + epsilon)

    # The first and last ids are stripped before decoding (see [1:-1] below),
    # so replacing them could never change the output. Rank inner positions
    # only and shift the indices back into full-sequence coordinates.
    ranked_positions = torch.argsort(probs_ratio[1:-1]) + 1

    result = []
    for ind in ranked_positions[:num_tokens]:
        top_positive = torch.argsort(probs_positive[ind], descending=True)[:k_best]
        for new_ind in top_positive:
            # Clone so each candidate starts from the untouched original ids;
            # writing into the shared tensor would leak earlier substitutions
            # into later candidates.
            new_sentence = input_ids.clone()
            new_sentence[ind] = new_ind
            result.append(tokenizer.decode(new_sentence[1:-1]))

    return result

      
def beamsearch(sentence: str, max_word_changes: int, k_best: int, beam_size:int):
    """Iteratively rewrite ``sentence`` with a beam search.

    Each round expands every sentence in the beam with
    ``get_replacements(..., num_tokens=1, k_best=k_best)``, scores every
    distinct candidate with ``bert_clf`` (softmax probability of class
    index 1) and keeps the ``beam_size`` highest-scoring candidates.

    Args:
        sentence: Starting text.
        max_word_changes: Number of expansion rounds to run.
        k_best: Substitutes requested per round from ``get_replacements``.
        beam_size: Maximum number of candidates kept after each round.

    Returns:
        The final beam as a list of distinct sentences, ordered from
        highest to lowest classifier score. If a round produces no
        candidates the search stops and the previous beam is returned.
    """
    beam = [sentence]
    for _ in range(max_word_changes):
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # candidate order (and therefore tie-breaking) is the same every run.
        candidates = list(dict.fromkeys(
            candidate
            for line in beam
            for candidate in get_replacements(line, num_tokens=1, k_best=k_best)
        ))
        if not candidates:
            break

        # Inference only: no autograd graph, and store plain floats rather
        # than tensors that still require grad.
        with torch.no_grad():
            clf_probs = [
                bert_clf(**tokenizer(candidate, return_tensors="pt"))
                .logits.softmax(dim=-1)[0][1]
                .item()
                for candidate in candidates
            ]

        # sorted() is stable, so equal scores keep their candidate order.
        ranked = sorted(
            range(len(candidates)), key=clf_probs.__getitem__, reverse=True
        )
        beam = [candidates[index] for index in ranked[:beam_size]]

    return beam
    
    
# Streamlit page: greet, read one sentence, and show the rewritten candidates.
st.markdown("Hello")

user_input = st.text_input("Please enter some sentence")

if user_input:
    candidates = beamsearch(user_input, max_word_changes=5, k_best=3, beam_size=5)
    # st.markdown takes a single markdown string, so format the candidates as
    # one bullet per sentence instead of handing it the Python list itself.
    st.markdown("\n".join(f"- {candidate}" for candidate in candidates))