# NOTE: the original file began with a "Spaces: Running" Hugging Face page
# header and trailing " | |" artifacts on every line — extraction residue,
# removed here so the module is valid Python.
import torch
from src.model import CRF_Tagger
from src.preprocessing import process_demo_sentence
import os
def predict(model, loader, count_loss=True):
    """Run inference over a data loader and collect flat prediction/label lists.

    Args:
        model: CRF tagger exposing ``model(x, y, mask)`` -> loss tensor and
            ``model.decode(x, mask)`` -> list of tag-id lists, one per sentence.
        loader: iterable yielding ``(x, y, _)`` batches; ``y`` uses -1 to mark
            padding positions.
        count_loss: when True, accumulate the model loss over all batches.

    Returns:
        ``(all_preds, all_true, avg_loss)`` — two flat lists of tag ids and the
        average per-batch loss (0.0 when ``count_loss`` is False or the loader
        is empty).
    """
    model.eval()  # evaluation mode: disables dropout, freezes batch-norm stats
    all_preds, all_true = [], []
    total_loss = 0.0
    n_batches = 0
    with torch.no_grad():  # inference only — no gradient tracking needed
        for x, y, _ in loader:
            n_batches += 1
            mask = (y != -1)  # True on real tokens, False on padding
            if count_loss:
                total_loss += model(x, y, mask).item()
            preds = model.decode(x, mask)
            # Flatten per-sentence sequences; keep only non-padded gold labels
            # via boolean-mask indexing on the label tensor.
            for pred_seq, true_seq, m in zip(preds, y, mask):
                all_preds.extend(pred_seq)
                all_true.extend(true_seq[m].tolist())
    # Count batches ourselves: avoids ZeroDivisionError on an empty loader and
    # does not require the loader to support len().
    avg_loss = total_loss / n_batches if n_batches else 0.0
    return all_preds, all_true, avg_loss
def predict_demo(text):
    """Tag a single demo sentence with the pretrained CRF checkpoint.

    Args:
        text: raw input sentence.

    Returns:
        ``(tokens, labels)`` — the tokenized sentence and one BIO tag string
        per token.
    """
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    model_path = os.path.join(base_dir, "models", "best_epoch_16.pt")
    id_tag = {0: 'O', 1: 'B-PER', 2: 'I-PER', 3: 'B-ORG', 4: 'I-ORG', 5: 'B-LOC', 6: 'I-LOC'}
    x, tokens = process_demo_sentence(text)  # assumes x is 1 x seq_length x 768 — per upstream helper
    # Derive the tag count from the label map instead of a duplicate constant.
    model = CRF_Tagger(input_dim=x.size(2), num_tags=len(id_tag))
    # map_location lets a GPU-trained checkpoint load on CPU-only hosts.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()
    with torch.no_grad():
        preds = model.decode(x)
    # decode returns one list per batch element; we passed a single sentence,
    # so only preds[0] is relevant.
    labels = [id_tag[lab] for lab in preds[0]]
    return tokens, labels