Gowtham122 commited on
Commit
05dc08c
·
verified ·
1 Parent(s): 32213f1

feature to save model locally

Browse files
Files changed (1) hide show
  1. app/models.py +64 -6
app/models.py CHANGED
@@ -1,12 +1,66 @@
 
 
 
 
 
1
  from transformers import AutoTokenizer, AlbertForQuestionAnswering
2
  import torch
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  class QAModel:
5
- def __init__(self, model_name: str = "Gowtham122/albertqa"):
6
  """
7
  Initialize the QA model and tokenizer.
8
  """
9
- self.model_name = model_name
10
  self.tokenizer = None
11
  self.model = None
12
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -16,8 +70,12 @@ class QAModel:
16
  """
17
  Load the tokenizer and model.
18
  """
19
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
20
- self.model = AlbertForQuestionAnswering.from_pretrained(self.model_name).to(self.device)
 
 
 
 
21
  logger.info(f"Loaded QA model: {self.model_name}")
22
 
23
  def inference_qa(self, context: str, question: str):
@@ -51,12 +109,12 @@ class QAModel:
51
  # Global instance of the QA model
52
  qa_model = QAModel()
53
 
54
- def load_qa_pipeline(model_name: str = "Gowtham122/albertqa"):
55
  """
56
  Load the QA model and tokenizer.
57
  """
58
  global qa_model
59
- qa_model = QAModel(model_name)
60
  return qa_model
61
 
62
  def inference_qa(qa_pipeline, context: str, question: str):
 
1
+ import os
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import Optional
5
+ from pydantic import BaseModel
6
  from transformers import AutoTokenizer, AlbertForQuestionAnswering
7
  import torch
8
 
9
+ # Set up logging
10
+ logger = logging.getLogger(__name__)
11
+
12
+ # Define the model directory
13
+ MODEL_DIR = Path(__file__).parent.parent / "_models"
14
+ MODEL_DIR.mkdir(parents=True, exist_ok=True) # Create the directory if it doesn't exist
15
+
16
+ # Hugging Face authentication token (from environment variable)
17
+ AUTH_TOKEN = os.getenv("auth_token")
18
+ if not AUTH_TOKEN:
19
+ raise ValueError("Hugging Face auth_token environment variable is not set.")
20
+
21
+ class DataLocation(BaseModel):
22
+ """
23
+ Represents the location of a model (local path and optional cloud URI).
24
+ """
25
+ local_path: str
26
+ cloud_uri: Optional[str] = None
27
+
28
+ def exists_or_download(self):
29
+ """
30
+ Check if the model exists locally. If not, download it from Hugging Face.
31
+ """
32
+ if not os.path.exists(self.local_path):
33
+ if self.cloud_uri is not None:
34
+ logger.warning(f"Downloading model from cloud URI: {self.cloud_uri}")
35
+ # Implement cloud download logic here if needed
36
+ else:
37
+ logger.info(f"Downloading model from Hugging Face to: {self.local_path}")
38
+ # Download from Hugging Face
39
+ tokenizer = AutoTokenizer.from_pretrained(
40
+ self.cloud_uri or self.local_path, use_auth_token=AUTH_TOKEN
41
+ )
42
+ model = AlbertForQuestionAnswering.from_pretrained(
43
+ self.cloud_uri or self.local_path, use_auth_token=AUTH_TOKEN
44
+ )
45
+ # Save the model and tokenizer locally
46
+ tokenizer.save_pretrained(self.local_path)
47
+ model.save_pretrained(self.local_path)
48
+ logger.info(f"Model saved to: {self.local_path}")
49
+ return self.local_path
50
+
51
+ # Define the model location
52
+ MODEL_NAME = "twmkn9/albert-base-v2-squad2"
53
+ MODEL_LOCATION = DataLocation(
54
+ local_path=str(MODEL_DIR / MODEL_NAME.replace("/", "-")),
55
+ cloud_uri=MODEL_NAME, # Hugging Face model ID
56
+ )
57
+
58
  class QAModel:
59
+ def __init__(self):
60
  """
61
  Initialize the QA model and tokenizer.
62
  """
63
+ self.model_name = MODEL_NAME
64
  self.tokenizer = None
65
  self.model = None
66
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
70
  """
71
  Load the tokenizer and model.
72
  """
73
+ # Ensure the model is downloaded
74
+ model_path = MODEL_LOCATION.exists_or_download()
75
+
76
+ # Load the tokenizer and model
77
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
78
+ self.model = AlbertForQuestionAnswering.from_pretrained(model_path).to(self.device)
79
  logger.info(f"Loaded QA model: {self.model_name}")
80
 
81
  def inference_qa(self, context: str, question: str):
 
109
  # Global instance of the QA model
110
  qa_model = QAModel()
111
 
112
+ def load_qa_pipeline():
113
  """
114
  Load the QA model and tokenizer.
115
  """
116
  global qa_model
117
+ qa_model = QAModel()
118
  return qa_model
119
 
120
  def inference_qa(qa_pipeline, context: str, question: str):