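"""Evaluation script for ModernBertForSentiment on the IMDB test set.

The `evaluate` function below is a generator: it yields textual progress
updates while batches are processed and finally a dict of metrics (accuracy,
F1, AUC-ROC, precision, recall, MCC, average loss).
"""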
import torch
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, precision_score, recall_score, matthews_corrcoef
from models import ModernBertForSentiment  # Assumes models.py is in the same directory
from tqdm import tqdm  # Progress bar used by the __main__ evaluation loop
def evaluate(model, dataloader, device):
    """Evaluate `model` on `dataloader`.

    Generator: yields progress strings while batches are processed, then a
    final dict with accuracy, f1, roc_auc, precision, recall, mcc and
    average_loss.
    """
    model.eval()
    all_preds = []
    all_labels = []
    all_probs_for_auc = []
    total_loss = 0
    num_batches = len(dataloader)
    processed_batches = 0

    with torch.no_grad():
        for batch in dataloader:  # The caller should not pre-wrap this dataloader in tqdm; progress is yielded instead
            processed_batches += 1

            # Move all model inputs to the target device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # 'lengths' is optional here; if the model's pooling or weighted loss
            # requires it and it is absent, the model itself will raise.
            lengths = batch.get('lengths')
            if lengths is not None:
                lengths = lengths.to(device)

            # Pass all necessary parts of the batch to the model
            model_inputs = {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'labels': labels
            }
            if lengths is not None:
                model_inputs['lengths'] = lengths

            outputs = model(**model_inputs)
            loss = outputs.loss
            logits = outputs.logits
            total_loss += loss.item()

            if logits.shape[1] > 1:
                preds = torch.argmax(logits, dim=1)
            else:
                preds = (torch.sigmoid(logits).squeeze(-1) > 0.5).long()
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            # Collect probabilities for AUC-ROC
            if logits.shape[1] > 1:
                # Two or more output classes: use the probability of the positive class (index 1)
                probs_for_auc = torch.softmax(logits, dim=1)[:, 1]
            else:
                # Binary classification with a single logit output
                probs_for_auc = torch.sigmoid(logits).squeeze(-1)
            all_probs_for_auc.extend(probs_for_auc.cpu().numpy())

            # Yield a progress update roughly 20 times, plus once for the final batch
            progress_update_frequency = max(1, num_batches // 20)  # At least 1 to avoid modulo by zero
            if processed_batches % progress_update_frequency == 0 or processed_batches == num_batches:
                yield f"Processed {processed_batches}/{num_batches} batches ({processed_batches/num_batches*100:.2f}%)"

    avg_loss = total_loss / num_batches
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    mcc = matthews_corrcoef(all_labels, all_preds)
    try:
        roc_auc = roc_auc_score(all_labels, all_probs_for_auc)
    except ValueError as e:
        print(f"Could not calculate AUC-ROC: {e}. Labels: {list(set(all_labels))[:10]}. Probs example: {all_probs_for_auc[:5]}. Setting to 0.0")
        roc_auc = 0.0

    results = {
        'accuracy': accuracy,
        'f1': f1,
        'roc_auc': roc_auc,
        'precision': precision,
        'recall': recall,
        'mcc': mcc,
        'average_loss': avg_loss
    }

    yield f"Processed {processed_batches}/{num_batches} batches (100.00%)"  # Final progress update
    yield "Evaluation complete. Compiling results..."
    yield results
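# Minimal sketch of driving the `evaluate` generator directly (without tqdm),
# kept as a comment so it is not executed at import time; `model`, `dataloader`
# and `device` stand for objects built by the caller:
#
#     results = None
#     for update in evaluate(model, dataloader, device):
#         if isinstance(update, dict):
#             results = update      # final yield: the metrics dict
#         else:
#             print(update)         # intermediate yield: a progress string
#
# The __main__ block below drives the same protocol through a tqdm progress bar.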
if __name__ == "__main__":
    import argparse
    from torch.utils.data import DataLoader
    from datasets import load_dataset
    import yaml
    from transformers import AutoTokenizer, AutoConfig
    from models import ModernBertForSentiment  # Assumes models.py is in the same directory or on PYTHONPATH
    class SentimentInference:
        """Loads the sentiment model and tokenizer described by a YAML config file."""

        def __init__(self, config_path):
            with open(config_path, 'r') as f:
                config_data = yaml.safe_load(f)
            self.config_path = config_path
            self.config_data = config_data

            # All settings live under the nested 'model' key of the config
            self.model_hf_repo_id = config_data['model']['name_or_path']
            self.tokenizer_name_or_path = config_data['model'].get('tokenizer_name_or_path', self.model_hf_repo_id)
            self.local_model_weights_path = config_data['model'].get('local_model_weights_path', None)
            self.load_from_local_pt = config_data['model'].get('load_from_local_pt', False)
            self.trust_remote_code_for_config = config_data['model'].get('trust_remote_code_for_config', True)  # Default to True for custom model code
            self.max_length = config_data['model']['max_length']

            self.device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

            try:
                if self.load_from_local_pt and self.local_model_weights_path:
                    print(f"Loading model from local path: {self.local_model_weights_path}")
                    # Assumes the config was saved alongside the local checkpoint
                    self.config = AutoConfig.from_pretrained(self.local_model_weights_path, trust_remote_code=self.trust_remote_code_for_config)
                    self.model = ModernBertForSentiment.from_pretrained(self.local_model_weights_path, config=self.config, trust_remote_code=True)
                else:
                    print(f"Loading base ModernBertConfig from: {self.model_hf_repo_id}")
                    self.config = AutoConfig.from_pretrained(self.model_hf_repo_id, trust_remote_code=self.trust_remote_code_for_config)
                    print(f"Instantiating and loading model weights for {self.model_hf_repo_id} using ModernBertForSentiment...")
                    self.model = ModernBertForSentiment.from_pretrained(self.model_hf_repo_id, config=self.config, trust_remote_code=True)
                    print(f"Model {self.model_hf_repo_id} loaded successfully from Hugging Face Hub using ModernBertForSentiment.")
                self.model.to(self.device)
            except Exception as e:
                print(f"Failed to load model: {e}")
                import traceback
                traceback.print_exc()
                exit()

            self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name_or_path, trust_remote_code=self.trust_remote_code_for_config)

        def print_debug_info(self):
            print(f"Model HF Repo ID: {self.model_hf_repo_id}")
            print(f"Tokenizer Name or Path: {self.tokenizer_name_or_path}")
            print(f"Local Model Weights Path: {self.local_model_weights_path}")
            print(f"Load from Local PT: {self.load_from_local_pt}")
    parser = argparse.ArgumentParser(description="Evaluate a sentiment analysis model on the IMDB test set.")
    parser.add_argument(
        "--config_path",
        type=str,
        default="local_test_config.yaml",
        help="Path to the configuration file for SentimentInference (e.g., local_test_config.yaml or config.yaml)."
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Batch size for evaluation."
    )
    args = parser.parse_args()
| print(f"Using configuration: {args.config_path}") | |
| print("Loading sentiment model and tokenizer...") | |
| inferer = SentimentInference(config_path=args.config_path) | |
| model = inferer.model | |
| tokenizer = inferer.tokenizer | |
| max_length = inferer.max_length | |
| device = inferer.device | |
| print("Loading IMDB test dataset...") | |
| try: | |
| imdb_dataset_test = load_dataset("imdb", split="test") | |
| except Exception as e: | |
| print(f"Failed to load IMDB dataset: {e}") | |
| exit() | |
| print("Tokenizing dataset...") | |
| def tokenize_function(examples): | |
| tokenized_output = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=max_length) | |
| tokenized_output["lengths"] = [sum(mask) for mask in tokenized_output["attention_mask"]] | |
| return tokenized_output | |
| tokenized_imdb_test = imdb_dataset_test.map(tokenize_function, batched=True) | |
| tokenized_imdb_test = tokenized_imdb_test.remove_columns(["text"]) | |
| tokenized_imdb_test = tokenized_imdb_test.rename_column("label", "labels") | |
| tokenized_imdb_test.set_format("torch", columns=["input_ids", "attention_mask", "labels", "lengths"]) | |
| test_dataloader = DataLoader(tokenized_imdb_test, batch_size=args.batch_size) | |
| print("Starting evaluation...") | |
| progress_bar = tqdm(evaluate(model, test_dataloader, device), desc="Evaluating") | |
| for update in progress_bar: | |
| if isinstance(update, dict): | |
| results = update | |
| break | |
| else: | |
| progress_bar.set_postfix_str(update) | |
| print("\n--- Evaluation Results ---") | |
| for key, value in results.items(): | |
| if isinstance(value, float): | |
| print(f"{key.capitalize()}: {value:.4f}") | |
| else: | |
| print(f"{key.capitalize()}: {value}") |