Gowtham122 commited on
Commit
1a0fafb
·
verified ·
1 Parent(s): 1f183f7

Update app/models.py

Browse files
Files changed (1) hide show
  1. app/models.py +9 -10
app/models.py CHANGED
@@ -31,21 +31,20 @@ class DataLocation(BaseModel):
31
  """
32
  if not os.path.exists(self.local_path):
33
  if self.cloud_uri is not None:
34
- logger.warning(f"Downloading model from cloud URI: {self.cloud_uri}")
35
- # Implement cloud download logic here if needed
36
- else:
37
- logger.info(f"Downloading model from Hugging Face to: {self.local_path}")
38
  # Download from Hugging Face
39
  tokenizer = AutoTokenizer.from_pretrained(
40
- self.cloud_uri or self.local_path, use_auth_token=AUTH_TOKEN
41
  )
42
  model = AlbertForQuestionAnswering.from_pretrained(
43
- self.cloud_uri or self.local_path, use_auth_token=AUTH_TOKEN
44
  )
45
  # Save the model and tokenizer locally
46
  tokenizer.save_pretrained(self.local_path)
47
  model.save_pretrained(self.local_path)
48
  logger.info(f"Model saved to: {self.local_path}")
 
 
49
  return self.local_path
50
 
51
  # Define the model location
@@ -64,16 +63,16 @@ class QAModel:
64
  self.tokenizer = None
65
  self.model = None
66
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
67
- self.load_model()
68
 
69
- def load_model(self):
70
  """
71
  Load the tokenizer and model.
72
  """
73
  # Ensure the model is downloaded
74
  model_path = MODEL_LOCATION.exists_or_download()
75
 
76
- # Load the tokenizer and model
77
  self.tokenizer = AutoTokenizer.from_pretrained(model_path)
78
  self.model = AlbertForQuestionAnswering.from_pretrained(model_path).to(self.device)
79
  logger.info(f"Loaded QA model: {self.model_name}")
@@ -114,7 +113,7 @@ def load_qa_pipeline():
114
  Load the QA model and tokenizer.
115
  """
116
  global qa_model
117
- qa_model = QAModel()
118
  return qa_model
119
 
120
  def inference_qa(qa_pipeline, context: str, question: str):
 
31
  """
32
  if not os.path.exists(self.local_path):
33
  if self.cloud_uri is not None:
34
+ logger.warning(f"Downloading model from Hugging Face: {self.cloud_uri}")
 
 
 
35
  # Download from Hugging Face
36
  tokenizer = AutoTokenizer.from_pretrained(
37
+ self.cloud_uri, use_auth_token=AUTH_TOKEN
38
  )
39
  model = AlbertForQuestionAnswering.from_pretrained(
40
+ self.cloud_uri, use_auth_token=AUTH_TOKEN
41
  )
42
  # Save the model and tokenizer locally
43
  tokenizer.save_pretrained(self.local_path)
44
  model.save_pretrained(self.local_path)
45
  logger.info(f"Model saved to: {self.local_path}")
46
+ else:
47
+ raise ValueError(f"Model not found locally and no cloud URI provided: {self.local_path}")
48
  return self.local_path
49
 
50
  # Define the model location
 
63
  self.tokenizer = None
64
  self.model = None
65
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
+ self._load_model() # Call the method to load the model and tokenizer
67
 
68
+ def _load_model(self):
69
  """
70
  Load the tokenizer and model.
71
  """
72
  # Ensure the model is downloaded
73
  model_path = MODEL_LOCATION.exists_or_download()
74
 
75
+ # Load the tokenizer and model from the local path
76
  self.tokenizer = AutoTokenizer.from_pretrained(model_path)
77
  self.model = AlbertForQuestionAnswering.from_pretrained(model_path).to(self.device)
78
  logger.info(f"Loaded QA model: {self.model_name}")
 
113
  Load the QA model and tokenizer.
114
  """
115
  global qa_model
116
+ qa_model = QAModel() # This will automatically call `_load_model` during initialization
117
  return qa_model
118
 
119
  def inference_qa(qa_pipeline, context: str, question: str):