File size: 1,598 Bytes
fff452e
95062a5
 
 
 
fff452e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95062a5
 
 
fff452e
 
 
 
 
 
 
95062a5
fff452e
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
from src.model import CRF_Tagger
from src.preprocessing import process_demo_sentence
import os

def predict(model, loader, count_loss=True):
    """Run inference over `loader` and collect flattened predictions.

    Args:
        model: a tagger exposing `eval()`, `model(x, y, mask)` returning a
            scalar loss tensor, and `decode(x, mask)` returning a list of
            label-id sequences (one per sentence in the mini-batch).
        loader: iterable of (x, y, _) mini-batches; `y` uses -1 to mark
            padding positions.
        count_loss: when True, accumulate the model loss over batches.

    Returns:
        (all_preds, all_true, avg_loss): flat lists of predicted / gold
        label ids over all non-padded tokens, and the mean per-batch loss
        (0.0 when `count_loss` is False or the loader yields no batches).
    """
    model.eval()  # evaluation mode: disables dropout, batch-norm updates, ...
    all_preds, all_true = [], []
    loss = 0.0
    num_batches = 0

    with torch.no_grad():  # no gradient tracking during inference
        for x, y, _ in loader:
            num_batches += 1
            mask = (y != -1)  # True on real tokens, False on padding

            # Accumulate loss
            if count_loss:
                loss += model(x, y, mask).item()

            # Get predictions for the whole mini-batch
            preds = model.decode(x, mask)

            # Flatten per-sentence predictions / gold labels, dropping padding
            for pred_seq, true_seq, m in zip(preds, y, mask):
                true_labels = true_seq[m].tolist()  # boolean-mask indexing
                all_preds.extend(pred_seq)
                all_true.extend(true_labels)

    # Guard against an empty loader (original divided by len(loader),
    # which raised ZeroDivisionError and required len() support).
    avg_loss = loss / num_batches if num_batches else 0.0
    return all_preds, all_true, avg_loss

def predict_demo(text):
    """Tag a single raw sentence with the pretrained CRF model.

    Loads the best checkpoint from the project's `models/` directory,
    embeds the sentence, and decodes one BIO tag per token.

    Args:
        text: the raw input sentence to tag.

    Returns:
        (tokens, labels): the tokenized sentence and one BIO tag string
        per token.
    """
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    model_path = os.path.join(base_dir, "models", "best_epoch_16.pt")

    # Mapping from label id to BIO tag; its length is the tag-set size.
    id_tag = {0: 'O', 1: 'B-PER', 2: 'I-PER', 3: 'B-ORG', 4: 'I-ORG', 5: 'B-LOC', 6: 'I-LOC'}

    x, tokens = process_demo_sentence(text)  # x: 1 x seq_length x 768

    # num_tags derived from id_tag so the two cannot drift apart.
    model = CRF_Tagger(input_dim=x.size(2), num_tags=len(id_tag))
    # map_location lets a checkpoint saved on GPU load on CPU-only hosts.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()
    with torch.no_grad():
        preds = model.decode(x)

    # decode returns one sequence per batch element; we only have one sentence.
    labels = [id_tag[lab] for lab in preds[0]]

    return tokens, labels