File size: 4,415 Bytes
05dc08c
 
 
 
 
ed18722
0b52243
46990e9
 
05dc08c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a0fafb
05dc08c
 
fdfb527
05dc08c
0b52243
ed18722
fdfb527
05dc08c
 
 
 
 
1a0fafb
 
05dc08c
 
 
2712129
05dc08c
46990e9
2712129
46990e9
 
 
2712129
 
46990e9
 
 
1a0fafb
46990e9
1a0fafb
46990e9
 
 
05dc08c
2712129
05dc08c
1a0fafb
05dc08c
ed18722
46990e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2712129
46990e9
521c52d
46990e9
 
 
2712129
 
 
 
c8e3a10
46990e9
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import logging
from pathlib import Path
from typing import Optional
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

import torch

# Set up logging
# Module-level logger; handler/level configuration is left to the application.
logger = logging.getLogger(__name__)

# Define the model directory
# Models are cached two levels above this file, under "_models".
MODEL_DIR = Path(__file__).parent.parent / "_models"
MODEL_DIR.mkdir(parents=True, exist_ok=True)  # Create the directory if it doesn't exist

# Hugging Face authentication token (from environment variable)
# NOTE(review): the env var name is lowercase "auth_token" — confirm this matches deployment config.
# NOTE(review): this raises at import time, so the module cannot be imported at all
# without the token set, even when every model is already cached locally.
AUTH_TOKEN = os.getenv("auth_token")
if not AUTH_TOKEN:
    raise ValueError("Hugging Face auth_token environment variable is not set.")

class DataLocation(BaseModel):
    """
    Represents the location of a model (local path and optional cloud URI).
    """
    local_path: str
    cloud_uri: Optional[str] = None

    def exists_or_download(self):
        """
        Return the local model path, fetching the model from Hugging Face
        first when it is not already present on disk.

        Raises:
            ValueError: If the model is missing locally and no cloud URI is set.
        """
        # Fast path: the model is already cached on disk.
        if os.path.exists(self.local_path):
            return self.local_path

        # No local copy and nowhere to fetch it from — fail loudly.
        if self.cloud_uri is None:
            raise ValueError(f"Model not found locally and no cloud URI provided: {self.local_path}")

        logger.warning(f"Downloading model from Hugging Face: {self.cloud_uri}")
        # Pull both artifacts from the hub, then persist them side by side
        # so subsequent calls hit the fast path above.
        tokenizer = AutoTokenizer.from_pretrained(self.cloud_uri, token=AUTH_TOKEN)
        model = AutoModelForQuestionAnswering.from_pretrained(self.cloud_uri, token=AUTH_TOKEN)
        tokenizer.save_pretrained(self.local_path)
        model.save_pretrained(self.local_path)
        logger.info(f"Model saved to: {self.local_path}")
        return self.local_path

# The model location is constructed per-model inside load_qa_pipeline() below.


class QAModel:
    """
    Extractive question-answering wrapper around a Hugging Face
    AutoModelForQuestionAnswering / AutoTokenizer pair.
    """

    def __init__(self, model_name: str, model_location: "DataLocation"):
        """
        Initialize the QA model and tokenizer.

        Args:
            model_name: Model identifier, used for logging.
            model_location: DataLocation that resolves (and if needed
                downloads) the model files.
        """
        self.model_name = model_name
        self.model_location = model_location  # fixed parameter typo: was `model_locaton`
        self.tokenizer = None
        self.model = None
        # Run on GPU when one is available, otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model()  # Load eagerly so failures surface at construction time.

    def _load_model(self):
        """
        Resolve the model files (downloading if necessary) and load the
        tokenizer and model onto the selected device.
        """
        # Ensure the model is downloaded
        model_path = self.model_location.exists_or_download()

        # Load the tokenizer and model from the local path
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForQuestionAnswering.from_pretrained(
            model_path, return_dict=True
        ).to(self.device)
        logger.info(f"Loaded QA model: {self.model_name}")

    def inference_qa(self, context: str, question: str) -> str:
        """
        Perform question-answering inference.

        Args:
            context (str): The text passage or document.
            question (str): The question to be answered.
        Returns:
            str: The predicted answer span decoded from the input tokens
                (empty when the predicted end index precedes the start).
        Raises:
            ValueError: If the model or tokenizer has not been loaded.
        """
        if self.tokenizer is None or self.model is None:
            raise ValueError("Model or tokenizer is not loaded.")

        # Tokenize the (question, context) pair and move tensors to the device.
        inputs = self.tokenizer(
            question, context, return_tensors="pt", truncation=True, padding=True
        )
        inputs = {key: value.to(self.device) for key, value in inputs.items()}

        # Inference only — no gradients needed.
        with torch.no_grad():
            outputs = self.model(**inputs)

        # The highest-scoring start/end logits delimit the answer span.
        answer_start_index = outputs.start_logits.argmax()
        answer_end_index = outputs.end_logits.argmax()
        predict_answer_tokens = inputs["input_ids"][0, answer_start_index : answer_end_index + 1]
        answer = self.tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)

        return answer



def load_qa_pipeline(model_name: str = "Gowtham122/albertqa"):
    """
    Load the QA model and tokenizer.
    """
    model_location = DataLocation(
        local_path=str(MODEL_DIR / model_name.replace("/", "-")),
        cloud_uri=model_name,  # Hugging Face model ID
    )
    qa_model = QAModel(model_name,model_location)
    return qa_model

def inference_qa(qa_pipeline, context: str, question: str):
    """
    Convenience wrapper: delegate QA inference to *qa_pipeline*.
    """
    answer = qa_pipeline.inference_qa(context, question)
    return answer