Spaces:
Running
Running
File size: 1,598 Bytes
fff452e 95062a5 fff452e 95062a5 fff452e 95062a5 fff452e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import torch
from src.model import CRF_Tagger
from src.preprocessing import process_demo_sentence
import os
def predict(model, loader, count_loss=True):
    """Run the CRF tagger over every batch in *loader* and collect results.

    Args:
        model: tagger exposing ``model(x, y, mask)`` (a scalar loss tensor)
            and ``model.decode(x, mask)`` (list of per-sentence tag lists).
        loader: iterable of ``(x, y, _)`` batches; padded positions in
            ``y`` are marked with the label ``-1``.
        count_loss: when False, skip the loss forward pass (the returned
            average loss is then 0.0).

    Returns:
        tuple: ``(all_preds, all_true, avg_loss)`` — flat list of predicted
        tag ids, flat list of gold tag ids, and the mean loss per batch.
    """
    model.eval()  # evaluation mode: disables dropout, batch-norm updates, ...
    all_preds, all_true = [], []
    total_loss = 0.0
    n_batches = 0
    with torch.no_grad():  # no gradient tracking during inference
        for x, y, _ in loader:
            n_batches += 1
            mask = (y != -1)  # True on real tokens, False on padding
            if count_loss:
                total_loss += model(x, y, mask).item()
            preds = model.decode(x, mask)
            # Flatten per-sentence sequences. NOTE(review): assumes decode
            # returns only the unmasked positions, so pred_seq and
            # true_labels have equal length — confirm against the model.
            for pred_seq, true_seq, m in zip(preds, y, mask):
                all_preds.extend(pred_seq)
                all_true.extend(true_seq[m].tolist())
    # Guard against an empty loader (original divided by len(loader) and
    # raised ZeroDivisionError); also works for loaders without __len__.
    avg_loss = total_loss / n_batches if n_batches else 0.0
    return all_preds, all_true, avg_loss
def predict_demo(text):
    """Tag a single raw sentence with the pretrained NER model.

    Args:
        text: the input sentence as a plain string.

    Returns:
        tuple: ``(tokens, labels)`` — the tokenized sentence and one
        BIO tag string (PER/ORG/LOC) per token.
    """
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    model_path = os.path.join(base_dir, "models", "best_epoch_16.pt")
    id_tag = {0: 'O', 1: 'B-PER', 2: 'I-PER', 3: 'B-ORG', 4: 'I-ORG', 5: 'B-LOC', 6: 'I-LOC'}
    x, tokens = process_demo_sentence(text)  # x: 1 x seq_length x 768
    NUM_TAGS = 7  # must match len(id_tag)
    model = CRF_Tagger(input_dim=x.size(2), num_tags=NUM_TAGS)
    # map_location="cpu" lets a checkpoint saved on GPU load on CPU-only
    # hosts; the original bare torch.load fails there with a CUDA
    # deserialization error. The model itself is built on CPU anyway.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()
    with torch.no_grad():
        preds = model.decode(x)
    # decode returns one sequence per batch item; we only have one sentence.
    labels = [id_tag[lab] for lab in preds[0]]
    return tokens, labels
|