Mekhron committed on
Commit
fc30cb8
·
1 Parent(s): 210faa8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -5,7 +5,7 @@ import transformers
5
  bert_mlm_positive, bert_mlm_negative, bert_clf, tokenizer = torch.load('model.pt', map_location='cpu')
6
 
7
  def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
8
- tokens_info = tokenizer(sentence, return_tensors="pt").to(device)
9
  probs_positive = bert_mlm_positive(**tokens_info).logits.softmax(dim=-1)[0]
10
  probs_neagtive = bert_mlm_negative(**tokens_info).logits.softmax(dim=-1)[0]
11
 
@@ -34,7 +34,7 @@ def beamsearch(sentence: str, max_word_changes: int, k_best: int, beam_size:int)
34
  new_beam = list(set(new_beam))
35
  clf_probs = []
36
  for line in new_beam:
37
- tokens_info = tokenizer(line, return_tensors="pt").to(device)
38
  clf_probs.append(bert_clf(**tokens_info).logits.softmax(dim=-1)[0][1])
39
  top_indexes = np.argsort(clf_probs)[-beam_size:]
40
  beam = [new_beam[i] for i in top_indexes]
 
5
  bert_mlm_positive, bert_mlm_negative, bert_clf, tokenizer = torch.load('model.pt', map_location='cpu')
6
 
7
  def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
8
+ tokens_info = tokenizer(sentence, return_tensors="pt")
9
  probs_positive = bert_mlm_positive(**tokens_info).logits.softmax(dim=-1)[0]
10
  probs_neagtive = bert_mlm_negative(**tokens_info).logits.softmax(dim=-1)[0]
11
 
 
34
  new_beam = list(set(new_beam))
35
  clf_probs = []
36
  for line in new_beam:
37
+ tokens_info = tokenizer(line, return_tensors="pt")
38
  clf_probs.append(bert_clf(**tokens_info).logits.softmax(dim=-1)[0][1])
39
  top_indexes = np.argsort(clf_probs)[-beam_size:]
40
  beam = [new_beam[i] for i in top_indexes]