foxxy-hm commited on
Commit
16bbc11
·
1 Parent(s): 06db71f

Rename models/predict_model.py to models/handler.py

Browse files
Files changed (2) hide show
  1. models/handler.py +92 -0
  2. models/predict_model.py +0 -76
models/handler.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models.pairwise_model import *
2
+ from features.text_utils import *
3
+ import regex as re
4
+ from models.bm25_utils import BM25Gensim
5
+ from models.qa_model import *
6
+ from tqdm.auto import tqdm
7
+ tqdm.pandas()
8
+ from datasets import load_dataset
9
+ # from typing import Dict, List, Any
10
+ # from transformers import pipeline, AutoTokenizer
11
+
12
+
13
+ class EndpointHandler():
14
+ def __init__(self, path=""):
15
+ df_wiki_windows = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/wikipedia_20220620_cleaned_v2.csv")["train"].to_pandas()
16
+ df_wiki = load_dataset("foxxy-hm/e2eqa-wiki", data_files="wikipedia_20220620_short.csv")["train"].to_pandas()
17
+ df_wiki.title = df_wiki.title.apply(str)
18
+
19
+ entity_dict = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/entities.json")["train"].to_dict()
20
+ new_dict = dict()
21
+ for key, val in entity_dict.items():
22
+ val = val[0].replace("wiki/", "").replace("_", " ")
23
+ entity_dict[key] = val
24
+ key = preprocess(key)
25
+ new_dict[key.lower()] = val
26
+ entity_dict.update(new_dict)
27
+ title2idx = dict([(x.strip(), y) for x, y in zip(df_wiki.title, df_wiki.index.values)])
28
+ # load the optimized model
29
+ qa_model = QAEnsembleModel("nguyenvulebinh/vi-mrc-large", ["qa_model_robust.bin"], entity_dict)
30
+ pairwise_model_stage1 = PairwiseModel("nguyenvulebinh/vi-mrc-base")#.half()
31
+ pairwise_model_stage1.load_state_dict(torch.load("pairwise_v2.bin", map_location=torch.device('cpu')))
32
+ pairwise_model_stage1.eval()
33
+
34
+ pairwise_model_stage2 = PairwiseModel("nguyenvulebinh/vi-mrc-base")#.half()
35
+ pairwise_model_stage2.load_state_dict(torch.load("pairwise_stage2_seed0.bin", map_location=torch.device('cpu')))
36
+
37
+ bm25_model_stage1 = BM25Gensim("bm25_stage1/", entity_dict, title2idx)
38
+ bm25_model_stage2_full = BM25Gensim("bm25_stage2/full_text/", entity_dict, title2idx)
39
+ bm25_model_stage2_title = BM25Gensim("bm25_stage2/title/", entity_dict, title2idx)
40
+ # # create inference pipeline
41
+ # self.pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
42
+
43
+ def get_answer_e2e(self, question):
44
+ #Bm25 retrieval for top200 candidates
45
+ query = preprocess(question).lower()
46
+ top_n, bm25_scores = bm25_model_stage1.get_topk_stage1(query, topk=200)
47
+ titles = [preprocess(df_wiki_windows.title.values[i]) for i in top_n]
48
+ texts = [preprocess(df_wiki_windows.text.values[i]) for i in top_n]
49
+
50
+ #Reranking with pairwise model for top10
51
+ question = preprocess(question)
52
+ ranking_preds = pairwise_model_stage1.stage1_ranking(question, texts)
53
+ ranking_scores = ranking_preds * bm25_scores
54
+
55
+ #Question answering
56
+ best_idxs = np.argsort(ranking_scores)[-10:]
57
+ ranking_scores = np.array(ranking_scores)[best_idxs]
58
+ texts = np.array(texts)[best_idxs]
59
+ best_answer = qa_model(question, texts, ranking_scores)
60
+ if best_answer is None:
61
+ return "Chịu"
62
+ bm25_answer = preprocess(str(best_answer).lower(), max_length=128, remove_puncts=True)
63
+
64
+ #Entity mapping
65
+ if not check_number(bm25_answer):
66
+ bm25_question = preprocess(str(question).lower(), max_length=128, remove_puncts=True)
67
+ bm25_question_answer = bm25_question + " " + bm25_answer
68
+ candidates, scores = bm25_model_stage2_title.get_topk_stage2(bm25_answer, raw_answer=best_answer)
69
+ titles = [df_wiki.title.values[i] for i in candidates]
70
+ texts = [df_wiki.text.values[i] for i in candidates]
71
+ ranking_preds = pairwise_model_stage2.stage2_ranking(question, best_answer, titles, texts)
72
+ if ranking_preds.max() >= 0.1:
73
+ final_answer = titles[ranking_preds.argmax()]
74
+ else:
75
+ candidates, scores = bm25_model_stage2_full.get_topk_stage2(bm25_question_answer)
76
+ titles = [df_wiki.title.values[i] for i in candidates] + titles
77
+ texts = [df_wiki.text.values[i] for i in candidates] + texts
78
+ ranking_preds = np.concatenate(
79
+ [pairwise_model_stage2.stage2_ranking(question, best_answer, titles, texts), ranking_preds])
80
+ final_answer = "wiki/"+titles[ranking_preds.argmax()].replace(" ","_")
81
+ else:
82
+ final_answer = bm25_answer.lower()
83
+ return final_answer
84
+
85
+
86
+ def __call__(self, question):
87
+ """
88
+ """
89
+ # Call the get_answer_e2e method with the question
90
+ answer = self.get_answer_e2e(question)
91
+ # Return the answer as a dictionary
92
+ return {"answer": answer}
models/predict_model.py DELETED
@@ -1,76 +0,0 @@
1
- from src.models.pairwise_model import *
2
- from src.features.text_utils import *
3
- import regex as re
4
- from src.models.bm25_utils import BM25Gensim
5
- from src.models.qa_model import *
6
- from tqdm.auto import tqdm
7
- tqdm.pandas()
8
- from datasets import load_dataset
9
-
10
- df_wiki_windows = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/wikipedia_20220620_cleaned_v2.csv")["train"].to_pandas()
11
- df_wiki = load_dataset("foxxy-hm/e2eqa-wiki", data_files="wikipedia_20220620_short.csv")["train"].to_pandas()
12
- df_wiki.title = df_wiki.title.apply(str)
13
-
14
- entity_dict = load_dataset("foxxy-hm/e2eqa-wiki", data_files="processed/entities.json")["train"].to_dict()
15
- new_dict = dict()
16
- for key, val in entity_dict.items():
17
- val = val[0].replace("wiki/", "").replace("_", " ")
18
- entity_dict[key] = val
19
- key = preprocess(key)
20
- new_dict[key.lower()] = val
21
- entity_dict.update(new_dict)
22
- title2idx = dict([(x.strip(), y) for x, y in zip(df_wiki.title, df_wiki.index.values)])
23
-
24
- qa_model = QAEnsembleModel("nguyenvulebinh/vi-mrc-large", ["models/qa_model_robust.bin"], entity_dict)
25
- pairwise_model_stage1 = PairwiseModel("nguyenvulebinh/vi-mrc-base")#.half()
26
- pairwise_model_stage1.load_state_dict(torch.load("models/pairwise_v2.bin", map_location=torch.device('cpu')))
27
- pairwise_model_stage1.eval()
28
-
29
- pairwise_model_stage2 = PairwiseModel("nguyenvulebinh/vi-mrc-base")#.half()
30
- pairwise_model_stage2.load_state_dict(torch.load("models/pairwise_stage2_seed0.bin", map_location=torch.device('cpu')))
31
-
32
- bm25_model_stage1 = BM25Gensim("models/bm25_stage1/", entity_dict, title2idx)
33
- bm25_model_stage2_full = BM25Gensim("models/bm25_stage2/full_text/", entity_dict, title2idx)
34
- bm25_model_stage2_title = BM25Gensim("models/bm25_stage2/title/", entity_dict, title2idx)
35
-
36
- def get_answer_e2e(question):
37
- #Bm25 retrieval for top200 candidates
38
- query = preprocess(question).lower()
39
- top_n, bm25_scores = bm25_model_stage1.get_topk_stage1(query, topk=200)
40
- titles = [preprocess(df_wiki_windows.title.values[i]) for i in top_n]
41
- texts = [preprocess(df_wiki_windows.text.values[i]) for i in top_n]
42
-
43
- #Reranking with pairwise model for top10
44
- question = preprocess(question)
45
- ranking_preds = pairwise_model_stage1.stage1_ranking(question, texts)
46
- ranking_scores = ranking_preds * bm25_scores
47
-
48
- #Question answering
49
- best_idxs = np.argsort(ranking_scores)[-10:]
50
- ranking_scores = np.array(ranking_scores)[best_idxs]
51
- texts = np.array(texts)[best_idxs]
52
- best_answer = qa_model(question, texts, ranking_scores)
53
- if best_answer is None:
54
- return "Chịu"
55
- bm25_answer = preprocess(str(best_answer).lower(), max_length=128, remove_puncts=True)
56
-
57
- #Entity mapping
58
- if not check_number(bm25_answer):
59
- bm25_question = preprocess(str(question).lower(), max_length=128, remove_puncts=True)
60
- bm25_question_answer = bm25_question + " " + bm25_answer
61
- candidates, scores = bm25_model_stage2_title.get_topk_stage2(bm25_answer, raw_answer=best_answer)
62
- titles = [df_wiki.title.values[i] for i in candidates]
63
- texts = [df_wiki.text.values[i] for i in candidates]
64
- ranking_preds = pairwise_model_stage2.stage2_ranking(question, best_answer, titles, texts)
65
- if ranking_preds.max() >= 0.1:
66
- final_answer = titles[ranking_preds.argmax()]
67
- else:
68
- candidates, scores = bm25_model_stage2_full.get_topk_stage2(bm25_question_answer)
69
- titles = [df_wiki.title.values[i] for i in candidates] + titles
70
- texts = [df_wiki.text.values[i] for i in candidates] + texts
71
- ranking_preds = np.concatenate(
72
- [pairwise_model_stage2.stage2_ranking(question, best_answer, titles, texts), ranking_preds])
73
- final_answer = "wiki/"+titles[ranking_preds.argmax()].replace(" ","_")
74
- else:
75
- final_answer = bm25_answer.lower()
76
- return final_answer