import os

import albumentations as A
import numpy as np
import torch
import torch.nn.functional as F
from albumentations.pytorch import ToTensorV2
from PIL import Image

from model import AstronomyClassifier, MODEL_CONFIG


class AstronomyInference:
    """Astronomy Image Classification Inference with Ensemble Support.

    Builds ``AstronomyClassifier`` instances (ResNet50 and DenseNet121) from
    local checkpoint files. When ``use_ensemble`` is true and more than one
    model loaded, predictions average the models' softmax probabilities
    (soft voting); otherwise the first successfully loaded model is used.
    """

    # (architecture name, checkpoint path, display name) for each model.
    # Order matters: models are registered in this order, and the first one
    # that loads is the single-model fallback used by ``predict``.
    MODEL_CHECKPOINTS = (
        ("resnet50", "best_resnet50.pth", "ResNet50"),
        ("densenet121", "best_densenet121.pth", "DenseNet121"),
    )

    def __init__(self, use_ensemble=True, device="cpu"):
        """Load the models and build the preprocessing pipeline.

        Args:
            use_ensemble: Average probabilities across all loaded models when
                more than one is available.
            device: Torch device string the models and inputs are placed on.
        """
        self.device = torch.device(device)
        self.class_names = MODEL_CONFIG["class_names"]
        self.num_classes = MODEL_CONFIG["num_classes"]
        self.use_ensemble = use_ensemble

        # Load models (missing/broken checkpoints are reported, not raised).
        self.models = {}
        self.load_models()

        # Resize -> normalize -> CHW tensor, using the training-time config.
        self.transform = A.Compose([
            A.Resize(MODEL_CONFIG["input_size"][0], MODEL_CONFIG["input_size"][1]),
            A.Normalize(mean=MODEL_CONFIG["mean"], std=MODEL_CONFIG["std"]),
            ToTensorV2(),
        ])

    def load_models(self):
        """Load every model listed in ``MODEL_CHECKPOINTS`` that is available.

        Each model is loaded independently: a failure is printed and the
        remaining models are still attempted, so ``self.models`` may end up
        holding fewer models than listed (possibly none).
        """
        for model_name, checkpoint_path, display_name in self.MODEL_CHECKPOINTS:
            self._load_model(model_name, checkpoint_path, display_name)

    def _load_model(self, model_name, checkpoint_path, display_name):
        """Build one classifier, load its weights and register it in eval mode."""
        try:
            model = AstronomyClassifier(
                model_name=model_name,
                num_classes=self.num_classes,
                pretrained=False,
            )
            # NOTE(review): torch.load unpickles the file; only load trusted
            # checkpoints (consider weights_only=True on torch >= 1.13).
            state_dict = torch.load(checkpoint_path, map_location=self.device)
            model.load_state_dict(state_dict)
            model.to(self.device)
            model.eval()
        except Exception as e:  # best-effort: keep whichever models do load
            print(f"❌ Failed to load {display_name}: {e}")
            return
        self.models[model_name] = model
        print(f"✅ {display_name} model loaded successfully")

    def preprocess_image(self, image):
        """Preprocess an image for inference.

        Args:
            image: A file path (``str`` or path-like), a NumPy array, or a PIL
                image. All of these are converted to 3-channel RGB.

        Returns:
            A ``(1, C, H, W)`` tensor on ``self.device``.
        """
        if isinstance(image, (str, os.PathLike)):
            # Context manager closes the underlying file handle promptly.
            with Image.open(image) as img:
                image = img.convert("RGB")
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert("RGB")
        elif isinstance(image, Image.Image):
            # Grayscale/RGBA/palette images would otherwise reach Normalize
            # with the wrong number of channels.
            image = image.convert("RGB")

        transformed = self.transform(image=np.array(image))
        image_tensor = transformed["image"].unsqueeze(0)  # add batch dim
        return image_tensor.to(self.device)

    def _probabilities_to_dict(self, probabilities):
        """Map each class name to its probability as a plain Python float."""
        return {
            name: float(prob)
            for name, prob in zip(self.class_names, probabilities)
        }

    def _require_models(self):
        """Raise a descriptive error when no model could be loaded."""
        if not self.models:
            raise RuntimeError(
                "No models are loaded; check the checkpoint files listed in "
                "AstronomyInference.MODEL_CHECKPOINTS."
            )

    def predict_single_model(self, model, image_tensor):
        """Predict using a single model.

        Returns:
            ``(predicted_class, confidence, all_probs)`` where ``all_probs`` is
            the softmax output for the first (only) batch item as a NumPy array.
        """
        with torch.no_grad():
            outputs = model(image_tensor)
            probabilities = F.softmax(outputs, dim=1)
            confidence, predicted = torch.max(probabilities, 1)
            predicted_class = self.class_names[predicted.item()]
            confidence_score = confidence.item()
            all_probs = probabilities[0].cpu().numpy()

        return predicted_class, confidence_score, all_probs

    def predict_ensemble(self, image_tensor):
        """Predict by soft voting (mean of softmax probabilities) over all models.

        Returns:
            Dict with ``predicted_class``, ``confidence``, per-class
            ``probabilities`` and ``individual_results`` (per-model top-1).

        Raises:
            RuntimeError: if no models are loaded.
        """
        self._require_models()

        all_probabilities = []
        individual_results = {}
        for model_name, model in self.models.items():
            predicted_class, confidence, probs = self.predict_single_model(model, image_tensor)
            all_probabilities.append(probs)
            individual_results[model_name] = {
                "predicted_class": predicted_class,
                "confidence": confidence,
            }

        avg_probabilities = np.mean(all_probabilities, axis=0)
        best = int(np.argmax(avg_probabilities))

        return {
            "predicted_class": self.class_names[best],
            "confidence": float(avg_probabilities[best]),
            "probabilities": self._probabilities_to_dict(avg_probabilities),
            "individual_results": individual_results,
        }

    def predict(self, image, return_probabilities=True):
        """Predict the class of ``image``.

        Uses the ensemble when enabled and more than one model is loaded;
        otherwise the first loaded model (reported under ``model_used``).

        Args:
            image: Anything accepted by ``preprocess_image``.
            return_probabilities: Include the per-class probability dict
                (and, for the ensemble, per-model results).

        Raises:
            RuntimeError: if no models are loaded.
        """
        self._require_models()
        image_tensor = self.preprocess_image(image)

        if self.use_ensemble and len(self.models) > 1:
            result = self.predict_ensemble(image_tensor)
            if return_probabilities:
                return result
            return {
                "predicted_class": result["predicted_class"],
                "confidence": result["confidence"],
            }

        # Single model: the first one registered by load_models.
        model_name, model = next(iter(self.models.items()))
        predicted_class, confidence, all_probs = self.predict_single_model(model, image_tensor)

        if return_probabilities:
            return {
                "predicted_class": predicted_class,
                "confidence": confidence,
                "probabilities": self._probabilities_to_dict(all_probs),
                "model_used": model_name,
            }
        return {
            "predicted_class": predicted_class,
            "confidence": confidence,
            "model_used": model_name,
        }


# Global inference instance (created lazily by get_inference_model).
inference_model = None


def get_inference_model():
    """Get or create the shared ensemble ``AstronomyInference`` instance."""
    global inference_model
    if inference_model is None:
        inference_model = AstronomyInference(use_ensemble=True)
    return inference_model