import torch
import torch.nn as nn
import torchvision.models as models

# torchvision >= 0.13 replaced the boolean ``pretrained=`` flag with the
# ``weights=`` API and deprecated the old flag. Detect the new API once so the
# model builds without deprecation warnings on new releases and still works on
# older ones.
_HAS_WEIGHTS_API = hasattr(models, "ResNet50_Weights")

# Weight tag that the legacy ``pretrained=True`` flag resolved to for both
# supported backbones; requesting it explicitly keeps the exact same weights
# (``"DEFAULT"`` would silently switch resnet50 to a different checkpoint).
_LEGACY_PRETRAINED_WEIGHTS = "IMAGENET1K_V1"

# Supported backbones: name -> (torchvision constructor, name of the attribute
# holding the original classification head, which is replaced by Identity).
_BACKBONES = {
    "resnet50": (models.resnet50, "fc"),
    "densenet121": (models.densenet121, "classifier"),
}


def _load_backbone(model_name, pretrained):
    """Build a torchvision backbone with its classification head removed.

    Args:
        model_name: Key of ``_BACKBONES`` selecting the architecture.
        pretrained: If truthy, load the ImageNet weights the legacy
            ``pretrained=True`` flag used; otherwise random initialisation.

    Returns:
        ``(backbone, num_features)`` where ``backbone`` has its head replaced
        by ``nn.Identity`` and ``num_features`` is the input width that head
        had (i.e. the size of the feature vector the backbone now outputs).

    Raises:
        ValueError: If ``model_name`` is not a supported backbone.
    """
    if model_name not in _BACKBONES:
        raise ValueError(f"Unsupported model: {model_name}")
    factory, head_attr = _BACKBONES[model_name]

    if _HAS_WEIGHTS_API:
        weights = _LEGACY_PRETRAINED_WEIGHTS if pretrained else None
        backbone = factory(weights=weights)
    else:
        # Older torchvision: only the legacy boolean flag exists.
        backbone = factory(pretrained=pretrained)

    num_features = getattr(backbone, head_attr).in_features
    setattr(backbone, head_attr, nn.Identity())
    return backbone, num_features


class AstronomyClassifier(nn.Module):
    """Astronomy image classifier: a torchvision backbone plus an MLP head.

    Args:
        model_name: Backbone architecture, ``"resnet50"`` or ``"densenet121"``.
        num_classes: Number of output logits produced by the head.
        pretrained: If true, initialise the backbone with ImageNet weights.
        dropout: Dropout probability applied before each linear layer of the
            head (keyword-only; defaults to the previously hard-coded 0.5).

    Raises:
        ValueError: If ``model_name`` is not a supported backbone.
    """

    def __init__(self, model_name='resnet50', num_classes=6, pretrained=False,
                 *, dropout=0.5):
        super().__init__()
        self.model_name = model_name
        self.num_classes = num_classes

        # Backbone with its own classifier stripped; it now emits raw features.
        self.backbone, num_features = _load_backbone(model_name, pretrained)

        # Custom classifier head. The layer order defines the state_dict keys
        # (``classifier.<index>.*``), so keep it stable for checkpoint loading.
        # NOTE: BatchNorm1d raises in train mode if a batch has one sample.
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return unnormalised class logits for a batch of images.

        ``x`` is presumably an image batch of shape (N, 3, H, W) normalised
        with ``MODEL_CONFIG["mean"]``/``["std"]`` — TODO confirm with callers.
        The output has ``num_classes`` entries per sample; no softmax is
        applied here.
        """
        features = self.backbone(x)
        return self.classifier(features)


# Class labels in logit-index order; ``num_classes`` is derived from this list
# so the two can never drift apart.
_CLASS_NAMES = ["constellation", "cosmos", "galaxies", "nebula", "planets", "stars"]

# Model configuration
MODEL_CONFIG = {
    "model_name": "resnet50",
    "num_classes": len(_CLASS_NAMES),
    "class_names": _CLASS_NAMES,
    "input_size": (224, 224),
    # Standard ImageNet normalisation statistics, matching the backbones'
    # pretrained weights.
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225],
}