apply fix to config dict passing for inference
inference.py CHANGED (+33 -21)
@@ -1,5 +1,5 @@
 import torch
-from transformers import AutoTokenizer, AutoModelForSequenceClassification, ModernBertConfig
+from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig, ModernBertConfig
 from typing import Dict, Any
 import yaml
 import os
@@ -84,36 +84,48 @@ class SentimentInference:
         # Load from Hugging Face Hub
         print(f"[INFERENCE_LOG] Attempting to load model from HUGGING_FACE_HUB: {model_hf_repo_id}") # Logging

-
-        # We just add/override num_labels, pooling_strategy, num_weighted_layers if they are in our local config.yaml
-        # as these might be specific to our fine-tuning and not in the Hub's default config.json.
-        hub_config_overrides = {
+        hub_config_params = {
             "num_labels": model_yaml_cfg.get('num_labels', 1),
             "pooling_strategy": model_yaml_cfg.get('pooling_strategy', 'mean'),
-            "num_weighted_layers": model_yaml_cfg.get('num_weighted_layers', 6)
+            "num_weighted_layers": model_yaml_cfg.get('num_weighted_layers', 6)
         }
-        print(f"[INFERENCE_LOG] HUB_LOAD:
+        print(f"[INFERENCE_LOG] HUB_LOAD: Parameters to update Hub config: {hub_config_params}") # Logging

         try:
-            #
-
-            #
+            # Step 1: Load config from Hub, allowing for our custom ModernBertConfig
+            config = ModernBertConfig.from_pretrained(model_hf_repo_id)
+            # Step 2: Update the loaded config with our specific parameters
+            for key, value in hub_config_params.items():
+                setattr(config, key, value)
+            print(f"[INFERENCE_LOG] HUB_LOAD: Updated config: {config.to_diff_dict()}")
+
+            # Step 3: Load model with the updated config
             self.model = ModernBertForSentiment.from_pretrained(
                 model_hf_repo_id,
-
+                config=config
             )
-            print(f"[INFERENCE_LOG] HUB_LOAD: Model ModernBertForSentiment loaded successfully from {model_hf_repo_id}.") # Logging
+            print(f"[INFERENCE_LOG] HUB_LOAD: Model ModernBertForSentiment loaded successfully from {model_hf_repo_id} with updated config.") # Logging
         except Exception as e:
-            print(f"[INFERENCE_LOG] HUB_LOAD: Error loading ModernBertForSentiment from {model_hf_repo_id}: {e}") # Logging
+            print(f"[INFERENCE_LOG] HUB_LOAD: Error loading ModernBertForSentiment from {model_hf_repo_id} with explicit config: {e}") # Logging
             print(f"[INFERENCE_LOG] HUB_LOAD: Falling back to AutoModelForSequenceClassification for {model_hf_repo_id}.") # Logging
-
-            #
-            #
-
-
-
-
-
+
+            # Fallback: Try with AutoModelForSequenceClassification
+            # Load its config (could be BertConfig or ModernBertConfig if auto-detected)
+            # AutoConfig should ideally resolve to ModernBertConfig if architectures field is set in Hub's config.json
+            try:
+                config_fallback = AutoConfig.from_pretrained(model_hf_repo_id)
+                for key, value in hub_config_params.items():
+                    setattr(config_fallback, key, value)
+                print(f"[INFERENCE_LOG] HUB_LOAD_FALLBACK: Updated fallback config: {config_fallback.to_diff_dict()}")
+
+                self.model = AutoModelForSequenceClassification.from_pretrained(
+                    model_hf_repo_id,
+                    config=config_fallback
+                )
+                print(f"[INFERENCE_LOG] HUB_LOAD_FALLBACK: AutoModelForSequenceClassification loaded for {model_hf_repo_id} with updated config.") # Logging
+            except Exception as e_fallback:
+                print(f"[INFERENCE_LOG] HUB_LOAD_FALLBACK: Critical error during fallback load: {e_fallback}")
+                raise e_fallback # Re-raise if fallback also fails catastrophically

         self.model.eval()

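The fix relies on a generic transformers pattern: load the checkpoint's config, override fields locally, then pass the updated config back into from_pretrained. Below is a minimal standalone sketch of that pattern, assuming a placeholder repo id and a checkpoint whose classifier head already matches num_labels=1; it uses the stock AutoModelForSequenceClassification rather than the Space's custom ModernBertForSentiment class, so the custom attributes (pooling_strategy, num_weighted_layers) are carried on the config but ignored by the model.

import torch
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

REPO_ID = "your-username/your-sentiment-model"  # placeholder, not the Space's actual repo

# Local overrides layered on top of the Hub config (values here are illustrative).
overrides = {
    "num_labels": 1,
    "pooling_strategy": "mean",     # custom field; only a custom model class would read it
    "num_weighted_layers": 6,       # custom field; only a custom model class would read it
}

# Step 1: load the config that ships with the Hub checkpoint.
config = AutoConfig.from_pretrained(REPO_ID)

# Step 2: apply the local overrides on top of the Hub defaults.
for key, value in overrides.items():
    setattr(config, key, value)

# Step 3: load the weights against the updated config instead of the Hub default.
model = AutoModelForSequenceClassification.from_pretrained(REPO_ID, config=config)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
inputs = tokenizer("This movie was surprisingly good.", return_tensors="pt", truncation=True)
with torch.no_grad():
    logits = model(**inputs).logits
print(torch.sigmoid(logits).item())  # single-logit head -> positive-class probability

In the Space itself the same three steps run inside SentimentInference, with ModernBertForSentiment consuming the extra config fields and AutoModelForSequenceClassification serving only as the fallback path.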