"""QA model loading and inference utilities built on Hugging Face transformers."""
import os
import logging
from pathlib import Path
from typing import Optional
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch
# Module-level logger; handlers/levels are expected to be configured by the application.
logger = logging.getLogger(__name__)

# Local cache directory for downloaded models: <package parent>/_models.
MODEL_DIR = Path(__file__).parent.parent / "_models"
MODEL_DIR.mkdir(parents=True, exist_ok=True) # Create the directory if it doesn't exist

# Hugging Face authentication token (from environment variable `auth_token`).
# NOTE(review): this raises at import time when the variable is unset — confirm
# every consumer of this module can tolerate that hard failure.
AUTH_TOKEN = os.getenv("auth_token")
if not AUTH_TOKEN:
    raise ValueError("Hugging Face auth_token environment variable is not set.")
class DataLocation(BaseModel):
    """
    Represents the location of a model (local path and optional cloud URI).

    Attributes:
        local_path: Filesystem path where the model is (or will be) cached.
        cloud_uri: Optional Hugging Face model ID to download from on a
            local cache miss.
    """

    local_path: str
    cloud_uri: Optional[str] = None

    def exists_or_download(self) -> str:
        """
        Check if the model exists locally. If not, download it from Hugging Face.

        Returns:
            str: The local path containing the tokenizer and model files.

        Raises:
            ValueError: If the model is missing locally and no cloud URI
                is available to download it from.
        """
        if not os.path.exists(self.local_path):
            # Guard clause: without a cloud URI there is nothing we can do.
            if self.cloud_uri is None:
                raise ValueError(f"Model not found locally and no cloud URI provided: {self.local_path}")
            # A cache miss is expected behavior, so log at INFO (was WARNING)
            # and use lazy %-style args instead of an eager f-string.
            logger.info("Downloading model from Hugging Face: %s", self.cloud_uri)
            tokenizer = AutoTokenizer.from_pretrained(
                self.cloud_uri, token=AUTH_TOKEN
            )
            model = AutoModelForQuestionAnswering.from_pretrained(
                self.cloud_uri, token=AUTH_TOKEN
            )
            # Persist both artifacts so subsequent runs hit the local cache.
            tokenizer.save_pretrained(self.local_path)
            model.save_pretrained(self.local_path)
            logger.info("Model saved to: %s", self.local_path)
        return self.local_path
# Wrapper class around the QA model and its tokenizer
class QAModel:
    """
    Extractive question-answering wrapper around a Hugging Face
    ``AutoModelForQuestionAnswering`` checkpoint and its tokenizer.
    """

    def __init__(self, model_name: str, model_location: "DataLocation"):
        """
        Initialize the QA model and tokenizer.

        Args:
            model_name: Identifier used for logging.
            model_location: Resolver that returns (downloading if necessary)
                the local path of the model files. (Parameter renamed from
                the misspelled ``model_locaton``; the in-file caller passes
                it positionally, so the rename is backward-compatible.)
        """
        self.model_name = model_name
        self.model_location = model_location
        self.tokenizer = None
        self.model = None
        # Prefer GPU when one is available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model()  # Load the model and tokenizer eagerly.

    def _load_model(self):
        """Download the model if needed, then load tokenizer and model."""
        # Ensure the model files exist locally (may trigger a download).
        model_path = self.model_location.exists_or_download()
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForQuestionAnswering.from_pretrained(
            model_path, return_dict=True
        ).to(self.device)
        logger.info("Loaded QA model: %s", self.model_name)

    def inference_qa(self, context: str, question: str) -> str:
        """
        Perform question-answering inference.

        Args:
            context (str): The text passage or document.
            question (str): The question to be answered.

        Returns:
            str: The predicted answer. May be empty when the predicted end
                index precedes the start index (empty token slice).

        Raises:
            ValueError: If the model or tokenizer is not loaded.
        """
        if self.tokenizer is None or self.model is None:
            raise ValueError("Model or tokenizer is not loaded.")
        # Tokenize the (question, context) pair and move tensors to the device.
        inputs = self.tokenizer(question, context, return_tensors="pt", truncation=True, padding=True)
        inputs = {key: value.to(self.device) for key, value in inputs.items()}
        # Inference only: disable autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Most likely start/end token positions for the answer span.
        answer_start_index = outputs.start_logits.argmax()
        answer_end_index = outputs.end_logits.argmax()
        predict_answer_tokens = inputs["input_ids"][0, answer_start_index : answer_end_index + 1]
        answer = self.tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)
        return answer
def load_qa_pipeline(model_name: str = "Gowtham122/albertqa"):
    """
    Build and return a ready-to-use QA model.

    The Hugging Face model ID doubles as the cloud URI; the local cache
    directory name is derived from it by replacing "/" with "-".
    """
    cache_dir_name = model_name.replace("/", "-")
    location = DataLocation(
        local_path=str(MODEL_DIR / cache_dir_name),
        cloud_uri=model_name,  # Hugging Face model ID
    )
    return QAModel(model_name, location)
def inference_qa(qa_pipeline, context: str, question: str):
    """
    Convenience wrapper: delegate QA inference to a loaded pipeline.

    Args:
        qa_pipeline: Object exposing an ``inference_qa(context, question)``
            method (e.g. a ``QAModel`` instance).
        context: The passage to search for an answer.
        question: The question to answer.

    Returns:
        The pipeline's predicted answer.
    """
    answer = qa_pipeline.inference_qa(context, question)
    return answer