Spaces:
Sleeping
Sleeping
new test inference
Browse files- inference.py +60 -46
inference.py
CHANGED
|
@@ -29,78 +29,92 @@ class SentimentInference:
|
|
| 29 |
if not tokenizer_hf_repo_id and not model_hf_repo_id:
|
| 30 |
raise ValueError("Either model.tokenizer_name_or_path or model.name_or_path (as fallback for tokenizer) must be specified in config.yaml")
|
| 31 |
effective_tokenizer_repo_id = tokenizer_hf_repo_id or model_hf_repo_id
|
| 32 |
-
print(f"Loading tokenizer from: {effective_tokenizer_repo_id}")
|
| 33 |
self.tokenizer = AutoTokenizer.from_pretrained(effective_tokenizer_repo_id)
|
| 34 |
|
| 35 |
# --- Model Loading --- #
|
| 36 |
# Determine if we are loading from a local .pt file or from Hugging Face Hub
|
| 37 |
load_from_local_pt = False
|
| 38 |
if local_model_weights_path and os.path.isfile(local_model_weights_path):
|
| 39 |
-
print(f"Found local model weights path: {local_model_weights_path}")
|
| 40 |
print(f"--- Debug: Found local model weights path: {local_model_weights_path} ---") # Add this
|
| 41 |
load_from_local_pt = True
|
| 42 |
elif not model_hf_repo_id:
|
| 43 |
raise ValueError("No local model_path found and model.name_or_path (for Hub) is not specified in config.yaml")
|
| 44 |
|
|
|
|
| 45 |
print(f"--- Debug: load_from_local_pt is: {load_from_local_pt} ---") # Add this
|
| 46 |
|
| 47 |
if load_from_local_pt:
|
| 48 |
-
print("Attempting to load model from
|
| 49 |
print("--- Debug: Entering LOCAL .pt loading path ---") # Add this
|
| 50 |
# Base BERT config must still be loaded, usually from a Hub ID (e.g., original base model)
|
| 51 |
# This base_model_for_config_id is crucial for building the correct ModernBertForSentiment structure.
|
| 52 |
-
base_model_for_config_id = model_yaml_cfg.get('base_model_for_config',
|
| 53 |
-
print(f"--- Debug: base_model_for_config_id (for local .pt): {base_model_for_config_id} ---") # Add this
|
| 54 |
if not base_model_for_config_id:
|
| 55 |
-
raise ValueError("
|
| 56 |
|
| 57 |
-
print(f"
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
|
| 67 |
-
print("
|
| 68 |
-
self.model = ModernBertForSentiment(
|
| 69 |
|
| 70 |
-
print(f"Loading
|
| 71 |
checkpoint = torch.load(local_model_weights_path, map_location=torch.device('cpu'))
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
self.model.load_state_dict(model_state_to_load)
|
| 77 |
-
print(f"Model loaded successfully from local checkpoint: {local_model_weights_path}.")
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
print(f"
|
| 82 |
-
print(f"--- Debug: model_hf_repo_id (for Hub loading): {model_hf_repo_id} ---") # Add this
|
| 83 |
-
if not model_hf_repo_id:
|
| 84 |
-
raise ValueError("model.name_or_path must be specified in config.yaml for Hub loading.")
|
| 85 |
|
| 86 |
-
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
-
#
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
-
print(f"Instantiating and loading model weights for {model_hf_repo_id}...")
|
| 96 |
-
self.model = AutoModelForSequenceClassification.from_pretrained(
|
| 97 |
-
model_hf_repo_id,
|
| 98 |
-
config=loaded_config,
|
| 99 |
-
trust_remote_code=True,
|
| 100 |
-
force_download=True # <--- TEMPORARY - remove when everything is working
|
| 101 |
-
)
|
| 102 |
-
print(f"Model {model_hf_repo_id} loaded successfully from Hugging Face Hub.")
|
| 103 |
-
|
| 104 |
self.model.eval()
|
| 105 |
|
| 106 |
def predict(self, text: str) -> Dict[str, Any]:
|
|
|
|
| 29 |
if not tokenizer_hf_repo_id and not model_hf_repo_id:
|
| 30 |
raise ValueError("Either model.tokenizer_name_or_path or model.name_or_path (as fallback for tokenizer) must be specified in config.yaml")
|
| 31 |
effective_tokenizer_repo_id = tokenizer_hf_repo_id or model_hf_repo_id
|
| 32 |
+
print(f"[INFERENCE_LOG] Loading tokenizer from: {effective_tokenizer_repo_id}") # Logging
|
| 33 |
self.tokenizer = AutoTokenizer.from_pretrained(effective_tokenizer_repo_id)
|
| 34 |
|
| 35 |
# --- Model Loading --- #
|
| 36 |
# Determine if we are loading from a local .pt file or from Hugging Face Hub
|
| 37 |
load_from_local_pt = False
|
| 38 |
if local_model_weights_path and os.path.isfile(local_model_weights_path):
|
| 39 |
+
print(f"[INFERENCE_LOG] Found local model weights path: {local_model_weights_path}") # Logging
|
| 40 |
print(f"--- Debug: Found local model weights path: {local_model_weights_path} ---") # Add this
|
| 41 |
load_from_local_pt = True
|
| 42 |
elif not model_hf_repo_id:
|
| 43 |
raise ValueError("No local model_path found and model.name_or_path (for Hub) is not specified in config.yaml")
|
| 44 |
|
| 45 |
+
print(f"[INFERENCE_LOG] load_from_local_pt: {load_from_local_pt}") # Logging
|
| 46 |
print(f"--- Debug: load_from_local_pt is: {load_from_local_pt} ---") # Add this
|
| 47 |
|
| 48 |
if load_from_local_pt:
|
| 49 |
+
print("[INFERENCE_LOG] Attempting to load model from LOCAL .pt checkpoint...") # Logging
|
| 50 |
print("--- Debug: Entering LOCAL .pt loading path ---") # Add this
|
| 51 |
# Base BERT config must still be loaded, usually from a Hub ID (e.g., original base model)
|
| 52 |
# This base_model_for_config_id is crucial for building the correct ModernBertForSentiment structure.
|
| 53 |
+
base_model_for_config_id = model_yaml_cfg.get('base_model_for_config', model_yaml_cfg.get('name_or_path'))
|
|
|
|
| 54 |
if not base_model_for_config_id:
|
| 55 |
+
raise ValueError("model.base_model_for_config or model.name_or_path must be specified in config.yaml when loading local .pt for ModernBertForSentiment structure.")
|
| 56 |
|
| 57 |
+
print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: base_model_for_config_id: {base_model_for_config_id}") # Logging
|
| 58 |
+
|
| 59 |
+
model_config = ModernBertConfig.from_pretrained(
|
| 60 |
+
base_model_for_config_id,
|
| 61 |
+
num_labels=model_yaml_cfg.get('num_labels', 1), # from config.yaml via model_yaml_cfg
|
| 62 |
+
pooling_strategy=model_yaml_cfg.get('pooling_strategy', 'mean'), # from config.yaml via model_yaml_cfg
|
| 63 |
+
num_weighted_layers=model_yaml_cfg.get('num_weighted_layers', 4) # from config.yaml via model_yaml_cfg
|
| 64 |
+
)
|
| 65 |
+
print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Loaded ModernBertConfig: {model_config.to_diff_dict()}") # Logging
|
| 66 |
|
| 67 |
+
print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Initializing ModernBertForSentiment with this config.") # Logging
|
| 68 |
+
self.model = ModernBertForSentiment(config=model_config)
|
| 69 |
|
| 70 |
+
print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Loading weights from checkpoint: {local_model_weights_path}") # Logging
|
| 71 |
checkpoint = torch.load(local_model_weights_path, map_location=torch.device('cpu'))
|
| 72 |
+
|
| 73 |
+
state_dict_to_load = checkpoint.get('model_state_dict', checkpoint.get('state_dict', checkpoint))
|
| 74 |
+
if not isinstance(state_dict_to_load, dict):
|
| 75 |
+
raise TypeError(f"Loaded checkpoint from {local_model_weights_path} is not a dict or does not contain 'model_state_dict' or 'state_dict'.")
|
|
|
|
|
|
|
| 76 |
|
| 77 |
+
# Log first few keys for debugging
|
| 78 |
+
first_few_keys = list(state_dict_to_load.keys())[:5]
|
| 79 |
+
print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: First few keys from checkpoint state_dict: {first_few_keys}") # Logging
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
+
self.model.load_state_dict(state_dict_to_load)
|
| 82 |
+
print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Weights loaded successfully into ModernBertForSentiment from {local_model_weights_path}.") # Logging
|
| 83 |
+
else:
|
| 84 |
+
# Load from Hugging Face Hub
|
| 85 |
+
print(f"[INFERENCE_LOG] Attempting to load model from HUGGING_FACE_HUB: {model_hf_repo_id}") # Logging
|
| 86 |
|
| 87 |
+
# Here, we use the config that's packaged with the model on the Hub by default.
|
| 88 |
+
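# NOTE: the num_weighted_layers fallback used here (6) differs from the one used on the local .pt path (4).
# We always pass num_labels, pooling_strategy and num_weighted_layers as overrides (taken from our local config.yaml when present, otherwise from the fallback defaults below),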
|
| 89 |
+
# as these might be specific to our fine-tuning and not in the Hub's default config.json.
|
| 90 |
+
hub_config_overrides = {
|
| 91 |
+
"num_labels": model_yaml_cfg.get('num_labels', 1),
|
| 92 |
+
"pooling_strategy": model_yaml_cfg.get('pooling_strategy', 'mean'),
|
| 93 |
+
"num_weighted_layers": model_yaml_cfg.get('num_weighted_layers', 6) # Default to 6 now
|
| 94 |
+
}
|
| 95 |
+
print(f"[INFERENCE_LOG] HUB_LOAD: Overrides for Hub config: {hub_config_overrides}") # Logging
|
| 96 |
+
|
| 97 |
+
try:
|
| 98 |
+
# Using ModernBertForSentiment.from_pretrained directly.
|
| 99 |
+
# This assumes the config.json on the Hub for 'model_hf_repo_id' is compatible
|
| 100 |
+
# or that from_pretrained can correctly initialize ModernBertForSentiment with it.
|
| 101 |
+
self.model = ModernBertForSentiment.from_pretrained(
|
| 102 |
+
model_hf_repo_id,
|
| 103 |
+
**hub_config_overrides
|
| 104 |
+
)
|
| 105 |
+
print(f"[INFERENCE_LOG] HUB_LOAD: Model ModernBertForSentiment loaded successfully from {model_hf_repo_id}.") # Logging
|
| 106 |
+
except Exception as e:
|
| 107 |
+
print(f"[INFERENCE_LOG] HUB_LOAD: Error loading ModernBertForSentiment from {model_hf_repo_id}: {e}") # Logging
|
| 108 |
+
print(f"[INFERENCE_LOG] HUB_LOAD: Falling back to AutoModelForSequenceClassification for {model_hf_repo_id}.") # Logging
|
| 109 |
+
# Fallback: Try with AutoModelForSequenceClassification if ModernBertForSentiment fails
|
| 110 |
+
# This might happen if the Hub model isn't strictly saved as a ModernBertForSentiment type
|
| 111 |
+
# or if its config.json doesn't have _custom_class set, etc.
|
| 112 |
+
self.model = AutoModelForSequenceClassification.from_pretrained(
|
| 113 |
+
model_hf_repo_id,
|
| 114 |
+
**hub_config_overrides
|
| 115 |
+
)
|
| 116 |
+
print(f"[INFERENCE_LOG] HUB_LOAD: AutoModelForSequenceClassification loaded for {model_hf_repo_id}.") # Logging
|
| 117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
self.model.eval()
|
| 119 |
|
| 120 |
def predict(self, text: str) -> Dict[str, Any]:
|