import os
import json

import gradio as gr
import numpy as np
import torch
import transformers
import tokenizers

from model import BertAD
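
# Artefacts loaded once at startup: the acronym -> candidate-expansions
# dictionary, the WordPiece vocabulary, and the fine-tuned span-selection model.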
DICTIONARY = json.load(open('model/dict.json'))
TOKENIZER = tokenizers.BertWordPieceTokenizer('model/vocab.txt', lowercase=True)
MAX_LEN = 256

MODEL = BertAD()
# Overwrite the checkpoint's position_ids with the freshly initialised
# buffer to sidestep a shape/version mismatch when loading.
vec = MODEL.state_dict()['bert.embeddings.position_ids']
chkp = torch.load(os.path.join('model', 'model_0.bin'), map_location='cpu')
chkp['bert.embeddings.position_ids'] = vec
MODEL.load_state_dict(chkp)
MODEL.eval()  # inference only: disable dropout
del chkp, vec
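
# sample_text: trim a long context to a window of about `max_len` whitespace
# tokens centred on the acronym, so the mention survives truncation. Assumes
# the acronym appears as a standalone token in the text.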
def sample_text(text, acronym, max_len):
    text = text.split()
    idx = text.index(acronym)
    left_idx = max(0, idx - max_len // 2)
    right_idx = min(len(text), idx + max_len // 2)
    sampled_text = text[left_idx:right_idx]
    return ' '.join(sampled_text)
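
# process_data: encode one example in a QA-style layout,
#   [CLS] acronym + candidate expansions [SEP] context [SEP]
# `start`/`end` index the gold expansion's tokens inside the candidate
# segment, and `offset` maps tokens back to characters for decoding.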
def process_data(text, acronym, expansion, tokenizer, max_len):
    text = str(text)
    expansion = str(expansion)
    acronym = str(acronym)

    # Keep the context short enough to fit next to the candidate segment.
    n_tokens = len(text.split())
    if n_tokens > 120:
        text = sample_text(text, acronym, 120)

    # "answers" holds the acronym followed by every candidate expansion;
    # the gold span is located by character offsets within this string.
    answers = acronym + ' ' + ' '.join(DICTIONARY[acronym])
    start = answers.find(expansion)
    end = start + len(expansion)

    char_mask = [0] * len(answers)
    for i in range(start, end):
        char_mask[i] = 1

    tok_answer = tokenizer.encode(answers)
    answer_ids = tok_answer.ids[1:-1]          # drop [CLS]/[SEP]
    answer_offsets = tok_answer.offsets[1:-1]

    # Tokens whose character span overlaps the gold expansion.
    target_idx = []
    for i, (off1, off2) in enumerate(answer_offsets):
        if sum(char_mask[off1:off2]) > 0:
            target_idx.append(i)

    start = target_idx[0]
    end = target_idx[-1]

    text_ids = tokenizer.encode(text).ids[1:-1]
    # 101 = [CLS], 102 = [SEP] in the BERT vocabulary.
    token_ids = [101] + answer_ids + [102] + text_ids + [102]
    offsets = [(0, 0)] + answer_offsets + [(0, 0)] * (len(text_ids) + 2)
    mask = [1] * len(token_ids)
    token_type = [0] * (len(answer_ids) + 1) + [1] * (2 + len(text_ids))

    text = answers + text
    start = start + 1  # shift for the leading [CLS]
    end = end + 1

    padding = max_len - len(token_ids)
    if padding >= 0:
        token_ids = token_ids + [0] * padding
        token_type = token_type + [1] * padding
        mask = mask + [0] * padding
        offsets = offsets + [(0, 0)] * padding
    else:
        token_ids = token_ids[:max_len]
        token_type = token_type[:max_len]
        mask = mask[:max_len]
        offsets = offsets[:max_len]

    assert len(token_ids) == max_len
    assert len(mask) == max_len
    assert len(offsets) == max_len
    assert len(token_type) == max_len

    return {
        'ids': token_ids,
        'mask': mask,
        'token_type': token_type,
        'offset': offsets,
        'start': start,
        'end': end,
        'text': text,
        'expansion': expansion,
        'acronym': acronym,
    }
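
# Word-level Jaccard similarity between two strings.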
def jaccard(str1, str2):
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    c = a.intersection(b)
    return float(len(c)) / (len(a) + len(b) - len(c))
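
# Decode the predicted token span back to characters, then snap it to the
# closest candidate expansion for the acronym by Jaccard similarity.
# (`selected_text` is accepted for signature compatibility but unused.)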
def evaluate_jaccard(text, selected_text, acronym, offsets, idx_start, idx_end):
    filtered_output = ""
    for ix in range(idx_start, idx_end + 1):
        filtered_output += text[offsets[ix][0]:offsets[ix][1]]
        if (ix + 1) < len(offsets) and offsets[ix][1] < offsets[ix + 1][0]:
            filtered_output += " "
    candidates = DICTIONARY[acronym]
    candidate_jaccards = [jaccard(w.strip(), filtered_output.strip()) for w in candidates]
    idx = np.argmax(candidate_jaccards)
    return candidate_jaccards[idx], candidates[idx]
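
# End-to-end inference for the app: encode the (context, acronym) pair,
# predict a span, and map it to a dictionary expansion.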
def disambiguate(text, acronym):
    inputs = process_data(text, acronym, acronym, TOKENIZER, MAX_LEN)
    ids = torch.tensor(inputs['ids']).view(1, -1)
    mask = torch.tensor(inputs['mask']).view(1, -1)
    token_type = torch.tensor(inputs['token_type']).view(1, -1)
    offsets = inputs['offset']
    expansion = inputs['expansion']
    acronym = inputs['acronym']

    start_logits, end_logits = MODEL(ids, mask, token_type)
    start_prob = torch.softmax(start_logits, dim=-1).detach().numpy()
    end_prob = torch.softmax(end_logits, dim=-1).detach().numpy()
    start_idx = np.argmax(start_prob[0, :])
    end_idx = np.argmax(end_prob[0, :])

    # Single (unbatched) example: pass the concatenated candidates+context
    # string, whose characters the stored offsets index into.
    js, exp = evaluate_jaccard(inputs['text'], expansion, acronym, offsets,
                               start_idx, end_idx)
    return exp
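
# Example (hypothetical inputs; assumes "CNN" is a key in model/dict.json):
#   disambiguate("We trained a CNN on the image dataset.", "CNN")
#   might return "convolutional neural network", depending on the checkpoint.

# Gradio UI: a context box and an acronym box wired to disambiguate().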
text = gr.inputs.Textbox(lines=5, label="Context", placeholder="Type a sentence or paragraph here.")
acronym = gr.inputs.Textbox(lines=2, label="Acronym", placeholder="Type the acronym to disambiguate.")
expansion = gr.outputs.Textbox(label="Expansion")

iface = gr.Interface(fn=disambiguate, inputs=[text, acronym], outputs=expansion)
iface.launch()