| from models.pairwise_model import * |
| from features.text_utils import * |
| import regex as re |
| from models.bm25_utils import BM25Gensim |
| from models.qa_model import * |
| from tqdm.auto import tqdm |
| tqdm.pandas() |
| from datasets import load_dataset |
| |
| from typing import Dict, List, Any |
|
|
class InferencePipeline():
    """End-to-end question answering over a Wikipedia dump.

    A query goes through three steps (see ``get_answer_e2e``):

    1. ``bm25_model_stage1`` retrieves candidate text windows, which are
       re-scored by ``pairwise_model_stage1``.
    2. ``qa_model`` extracts an answer from the best-scoring windows.
    3. Unless ``check_number`` accepts the answer, it is mapped to a
       Wikipedia title with the stage-2 retrievers and ranker.
    """

    def __init__(self, path=""):
        """Load the corpora, the entity dictionary and every model.

        Args:
            path: Currently unused. It is kept so callers that pass a
                directory keep working; checkpoints and index folders are
                opened with paths relative to the working directory.
        """
        df_wiki_windows = load_dataset(
            "foxxy-hm/e2eqa-wiki",
            data_files="processed/wikipedia_20220620_cleaned_v2.csv",
        )["train"].to_pandas()
        df_wiki = load_dataset(
            "foxxy-hm/e2eqa-wiki",
            data_files="wikipedia_20220620_short.csv",
        )["train"].to_pandas()
        df_wiki.title = df_wiki.title.apply(str)

        entity_dict = load_dataset(
            "foxxy-hm/e2eqa-wiki",
            data_files="processed/entities.json",
        )["train"].to_dict()
        # ``to_dict`` yields one list per column; the first element is the
        # entity link, reduced here to a plain title ("wiki/A_b" -> "A b").
        new_dict = dict()
        for key, val in entity_dict.items():
            val = val[0].replace("wiki/", "").replace("_", " ")
            # Only values of existing keys are replaced, so the dict size is
            # stable while it is being iterated.
            entity_dict[key] = val
            key = preprocess(key)
            new_dict[key.lower()] = val
        # Register the preprocessed, lower-cased aliases next to the raw keys.
        entity_dict.update(new_dict)
        title2idx = {
            title.strip(): idx
            for title, idx in zip(df_wiki.title, df_wiki.index.values)
        }

        qa_model = QAEnsembleModel(
            "nguyenvulebinh/vi-mrc-large", ["qa_model_robust.bin"], entity_dict
        )
        pairwise_model_stage1 = PairwiseModel("nguyenvulebinh/vi-mrc-base")
        pairwise_model_stage1.load_state_dict(
            torch.load("pairwise_v2.bin", map_location=torch.device('cpu'))
        )
        pairwise_model_stage1.eval()

        pairwise_model_stage2 = PairwiseModel("nguyenvulebinh/vi-mrc-base")
        pairwise_model_stage2.load_state_dict(
            torch.load("pairwise_stage2_seed0.bin", map_location=torch.device('cpu'))
        )
        # Put the stage-2 ranker in inference mode like stage 1.
        # NOTE(review): assumes PairwiseModel is a torch module (it already
        # exposes load_state_dict/eval above) -- confirm.
        pairwise_model_stage2.eval()

        bm25_model_stage1 = BM25Gensim("bm25_stage1/", entity_dict, title2idx)
        bm25_model_stage2_full = BM25Gensim("bm25_stage2/full_text/", entity_dict, title2idx)
        bm25_model_stage2_title = BM25Gensim("bm25_stage2/title/", entity_dict, title2idx)

        # The data frames must live on the instance: get_answer_e2e reads
        # them, and as plain locals they would be gone once __init__ returns.
        self.df_wiki_windows = df_wiki_windows
        self.df_wiki = df_wiki
        self.qa_model = qa_model
        self.pairwise_model_stage1 = pairwise_model_stage1
        self.pairwise_model_stage2 = pairwise_model_stage2
        self.bm25_model_stage1 = bm25_model_stage1
        self.bm25_model_stage2_full = bm25_model_stage2_full
        self.bm25_model_stage2_title = bm25_model_stage2_title

    def get_answer_e2e(self, question):
        """Answer ``question`` and return the answer as a string.

        Args:
            question: Raw question text.

        Returns:
            ``"Chịu"`` when the QA model returns ``None``; the normalized,
            lower-cased answer when ``check_number`` accepts it; otherwise
            ``"wiki/<Title>"`` with spaces in the title replaced by
            underscores.
        """
        # Stage 1: retrieve candidate windows, then re-score them.
        query = preprocess(question).lower()
        top_n, bm25_scores = self.bm25_model_stage1.get_topk_stage1(query, topk=200)
        window_texts = self.df_wiki_windows.text.values
        texts = [preprocess(window_texts[i]) for i in top_n]

        question = preprocess(question)
        ranking_preds = self.pairwise_model_stage1.stage1_ranking(question, texts)
        ranking_scores = ranking_preds * bm25_scores

        # Question answering on the 10 best-scoring windows.
        best_idxs = np.argsort(ranking_scores)[-10:]
        ranking_scores = np.array(ranking_scores)[best_idxs]
        texts = np.array(texts)[best_idxs]
        best_answer = self.qa_model(question, texts, ranking_scores)
        if best_answer is None:
            return "Chịu"
        bm25_answer = preprocess(str(best_answer).lower(), max_length=128, remove_puncts=True)

        # Answers accepted by check_number are returned as text, without
        # being mapped to a Wikipedia title.
        if check_number(bm25_answer):
            return bm25_answer.lower()

        # Stage 2: map the answer to a title, searching titles first.
        wiki_titles = self.df_wiki.title.values
        wiki_texts = self.df_wiki.text.values
        candidates, _ = self.bm25_model_stage2_title.get_topk_stage2(
            bm25_answer, raw_answer=best_answer
        )
        titles = [wiki_titles[i] for i in candidates]
        texts = [wiki_texts[i] for i in candidates]
        ranking_preds = self.pairwise_model_stage2.stage2_ranking(
            question, best_answer, titles, texts
        )

        if ranking_preds.max() < 0.1:
            # No title candidate scored at least 0.1: add candidates from the
            # full-text index and rank the combined list.
            bm25_question = preprocess(str(question).lower(), max_length=128, remove_puncts=True)
            bm25_question_answer = bm25_question + " " + bm25_answer
            candidates, _ = self.bm25_model_stage2_full.get_topk_stage2(bm25_question_answer)
            titles = [wiki_titles[i] for i in candidates] + titles
            texts = [wiki_texts[i] for i in candidates] + texts
            # The scores are used as-is so they stay index-aligned with
            # ``titles``; appending the earlier scores again would make the
            # array longer than ``titles`` and let argmax run off the end.
            # NOTE(review): assumes stage2_ranking returns one score per
            # (title, text) pair -- confirm.
            ranking_preds = self.pairwise_model_stage2.stage2_ranking(
                question, best_answer, titles, texts
            )

        return "wiki/" + titles[ranking_preds.argmax()].replace(" ", "_")
|
|
| |
class EndpointHandler():
    """Callable adapter that answers request dictionaries with InferencePipeline."""

    def __init__(self, path=""):
        """Create the underlying pipeline.

        Args:
            path: Accepted but not used; the pipeline is always built with
                ``"."``.
        """
        self.inference_pipeline = InferencePipeline(".")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Answer the question carried by ``data``.

        Args:
            data: Request payload. The value under ``"inputs"`` is removed
                from the dict and used as the question; when the key is
                absent the whole dict is forwarded to the pipeline instead.

        Returns:
            A single-element list holding ``{"generated_text": answer}``.
        """
        inputs = data.pop("inputs", data)

        answer = self.inference_pipeline.get_answer_e2e(inputs)
        return [{"generated_text": answer}]
|
|