Gowtham122 committed
Commit 46990e9 · verified · 1 Parent(s): 8613386

create models.py with AlbertQA model

Files changed (1)
1. app/models.py +70 -0
app/models.py ADDED
@@ -0,0 +1,70 @@
+ import logging
+
+ import torch
+ from transformers import AutoTokenizer, AlbertForQuestionAnswering
+
+ logger = logging.getLogger(__name__)
+
+ class QAModel:
+     def __init__(self, model_name: str = "twmkn9/albert-base-v2-squad2"):
+         """
+         Initialize the QA model and tokenizer.
+         """
+         self.model_name = model_name
+         self.tokenizer = None
+         self.model = None
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.load_model()
+
+     def load_model(self):
+         """
+         Load the tokenizer and model, moving the model to the available device.
+         """
+         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+         self.model = AlbertForQuestionAnswering.from_pretrained(self.model_name).to(self.device)
+         logger.info(f"Loaded QA model: {self.model_name}")
+
+     def inference_qa(self, context: str, question: str) -> str:
+         """
+         Perform question-answering inference.
+         Args:
+             context (str): The text passage or document.
+             question (str): The question to be answered.
+         Returns:
+             str: The predicted answer.
+         """
+         if self.tokenizer is None or self.model is None:
+             raise ValueError("Model or tokenizer is not loaded.")
+
+         # Tokenize the question/context pair, truncating passages that exceed the model's maximum length
+         inputs = self.tokenizer(question, context, return_tensors="pt", truncation=True, padding=True)
+         inputs = {key: value.to(self.device) for key, value in inputs.items()}
+
+         # Run the forward pass without tracking gradients
+         with torch.no_grad():
+             outputs = self.model(**inputs)
+
+         # Take the most likely start/end positions and decode that token span back to text
+         answer_start_index = outputs.start_logits.argmax()
+         answer_end_index = outputs.end_logits.argmax()
+         predict_answer_tokens = inputs["input_ids"][0, answer_start_index : answer_end_index + 1]
+         answer = self.tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)
+
+         return answer
+
+ # Global instance of the QA model, created at import time
+ qa_model = QAModel()
+
+ def load_qa_pipeline(model_name: str = "twmkn9/albert-base-v2-squad2"):
+     """
+     Load the QA model and tokenizer, replacing the global instance.
+     """
+     global qa_model
+     qa_model = QAModel(model_name)
+     return qa_model
+
+ def inference_qa(qa_pipeline, context: str, question: str):
+     """
+     Perform QA inference using the loaded pipeline.
+     """
+     return qa_pipeline.inference_qa(context, question)
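
For reference, a minimal usage sketch of this module follows; the context, question, and printed answer are illustrative assumptions, not part of the commit:

    from app.models import load_qa_pipeline, inference_qa

    # Load the ALBERT checkpoint fine-tuned on SQuAD 2.0 (downloaded on first use)
    qa = load_qa_pipeline("twmkn9/albert-base-v2-squad2")

    context = "ALBERT is a lite variant of BERT that shares parameters across its transformer layers."
    question = "What does ALBERT share across its layers?"

    # Prints the decoded answer span, e.g. "parameters"
    print(inference_qa(qa, context, question))

Note that importing app.models already instantiates QAModel(), so the default checkpoint is downloaded and loaded onto the available device at import time, before load_qa_pipeline is ever called.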