# Streamlit front end for the astronomy image classifier.
#
# The page lets the user upload an image, passes it to the model object
# returned by ``inference.get_inference_model`` and displays the ensemble
# prediction, the optional per-model results and the per-class probabilities.
# (A comment header is used instead of a module docstring so that nothing at
# the top level can be picked up by Streamlit's "magic" rendering.)

import json

import numpy as np
import streamlit as st
import torch
from PIL import Image

from inference import get_inference_model

# A confidence strictly above these values is labelled high / medium.
HIGH_CONFIDENCE = 0.8
MEDIUM_CONFIDENCE = 0.6

# Page config (must be the first Streamlit command executed by the script)
st.set_page_config(
    page_title="🌌 Astronomy Image Classification",
    page_icon="🌌",
    layout="wide",
)

# Title
st.title("🌌 Astronomy Image Classification")
st.markdown(
    "Classify astronomy images into 6 categories using ensemble of "
    "ResNet50 and DenseNet121 models"
)

# Sidebar
st.sidebar.title("📊 Model Info")
# Fields are joined with blank lines so each one renders as its own paragraph
# instead of a single run-on line.
st.sidebar.markdown(
    "\n\n".join(
        [
            "**Models**: ResNet50 + DenseNet121 Ensemble",
            "**ResNet50 Accuracy**: 64.86%",
            "**DenseNet121 Accuracy**: 63.96%",
            "**Ensemble**: Higher accuracy than individual models",
            "**Classes**: 6 astronomy categories",
            "**Input Size**: 224x224 pixels",
        ]
    )
)


def _confidence_badge(confidence):
    """Return the ``(emoji, status text)`` pair for a confidence value.

    Values above ``HIGH_CONFIDENCE`` are green/high, values above
    ``MEDIUM_CONFIDENCE`` are yellow/medium, everything else is red/low.
    """
    if confidence > HIGH_CONFIDENCE:
        return "🟢", "High Confidence"
    if confidence > MEDIUM_CONFIDENCE:
        return "🟡", "Medium Confidence"
    return "🔴", "Low Confidence"


def _progress_fraction(value):
    """Coerce ``value`` to a built-in float clamped to ``[0.0, 1.0]``.

    ``st.progress`` only accepts an int in 0..100 or a float in 0.0..1.0.
    NOTE(review): the type returned by ``model.predict`` is not visible here;
    this guards against array-library scalars and rounding slightly past 1.0.
    """
    return min(1.0, max(0.0, float(value)))


@st.cache_resource
def _load_cached_model():
    """Create the inference model; the returned object is cached by Streamlit."""
    return get_inference_model()


# Load model
def load_model():
    """Return the inference model, or ``None`` when it cannot be loaded.

    The ``try``/``except`` deliberately sits outside the cached function: if
    the failure were handled inside it, the ``None`` result would itself be
    cached and a single transient error would persist until the cache is
    cleared. An exception raised by the cached function leaves nothing
    stored, so the next rerun simply tries again.
    """
    try:
        return _load_cached_model()
    except Exception as e:  # UI boundary: surface any load failure to the user
        st.error(f"Error loading model: {e}")
        return None


# Main interface
model = load_model()

if model is not None:
    # Upload image
    uploaded_file = st.file_uploader(
        "Upload an astronomy image",
        type=['jpg', 'jpeg', 'png'],
        help="Upload an image of constellation, cosmos, galaxies, nebula, planets, or stars",
    )

    if uploaded_file is not None:
        # Display image
        col1, col2 = st.columns([1, 1])

        with col1:
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image", use_column_width=True)

        with col2:
            # Make prediction
            with st.spinner("Analyzing image with ensemble models..."):
                result = model.predict(image)

            # Display results
            st.subheader("🎯 Ensemble Prediction Results")

            # Main prediction
            predicted_class = result["predicted_class"]
            confidence = float(result["confidence"])

            # Color code based on confidence
            color, status = _confidence_badge(confidence)

            st.markdown(
                "\n\n".join(
                    [
                        f"**{color} Predicted Class**: {predicted_class}",
                        f"**Confidence**: {confidence:.3f}",
                        f"**Status**: {status}",
                    ]
                )
            )

            # Progress bar
            st.progress(_progress_fraction(confidence))

            # Individual model results (only present for some result payloads)
            if "individual_results" in result:
                st.subheader("🔍 Individual Model Results")
                individual_results = result["individual_results"]

                for model_name, model_result in individual_results.items():
                    model_confidence = float(model_result["confidence"])
                    model_prediction = model_result["predicted_class"]

                    # Color code individual results with the same thresholds
                    model_color, _ = _confidence_badge(model_confidence)

                    st.write(
                        f"**{model_name}**: {model_color} {model_prediction} "
                        f"({model_confidence:.3f})"
                    )

            # All probabilities, most likely class first
            st.subheader("📋 All Class Probabilities")
            probabilities = result["probabilities"]

            ranked = sorted(
                probabilities.items(), key=lambda item: item[1], reverse=True
            )
            for class_name, prob in ranked:
                # One row per class: label on the left, bar and value on the right
                col_prob, col_bar = st.columns([2, 3])
                with col_prob:
                    st.write(f"**{class_name}**")
                with col_bar:
                    st.progress(_progress_fraction(prob))
                    st.write(f"{prob:.3f}")

    # Sample images section
    st.markdown("---")
    st.subheader("📸 Sample Images")

    # Category descriptions, first row
    sample_cols = st.columns(3)

    with sample_cols[0]:
        st.markdown("**🌟 Constellation**")
        st.info("Star patterns forming recognizable shapes like Orion, Big Dipper, etc.")

    with sample_cols[1]:
        st.markdown("**🌌 Galaxies**")
        st.info("Spiral, elliptical, or irregular galaxies like Andromeda, Milky Way")

    with sample_cols[2]:
        st.markdown("**🌫️ Nebula**")
        st.info("Gas clouds and stellar nurseries like Orion Nebula, Eagle Nebula")

    # Second row
    sample_cols2 = st.columns(3)

    with sample_cols2[0]:
        st.markdown("**🪐 Planets**")
        st.info("Solar system planets like Jupiter, Saturn, Mars, Earth")

    with sample_cols2[1]:
        st.markdown("**⭐ Stars**")
        st.info("Individual stars, stellar objects, and stellar phenomena")

    with sample_cols2[2]:
        st.markdown("**🌠 Cosmos**")
        st.info("General space scenes, cosmic phenomena, and deep space")

    # Model comparison
    st.markdown("---")
    st.subheader("📈 Model Performance Comparison")

    perf_col1, perf_col2 = st.columns(2)

    with perf_col1:
        st.metric("ResNet50 Accuracy", "64.86%", "Base Model")

    with perf_col2:
        st.metric("DenseNet121 Accuracy", "63.96%", "Base Model")

    st.info("🎯 **Ensemble Method**: Combines both models for higher accuracy than individual models")

else:
    st.error("❌ Model could not be loaded. Please check the model files.")
    # Explicit newlines so the "- " items render as a Markdown list.
    st.markdown(
        "**Required files:**\n\n"
        "- `best_resnet50.pth` (ResNet50 model weights)\n"
        "- `best_densenet121.pth` (DenseNet121 model weights)"
    )

# Footer
st.markdown("---")
st.markdown("""

🔭 Astronomy Image Classification System | Built with PyTorch & Streamlit

Ensemble of ResNet50 + DenseNet121 | Target Accuracy: >95% | Current: 64.86%

🚀 Deployed on Hugging Face Spaces

""", unsafe_allow_html=True)