# Saqib772's picture
# image classification
# 7992750 verified
import streamlit as st
import torch
from PIL import Image
import numpy as np
from inference import get_inference_model
import json
# Browser-tab metadata and wide layout for the whole app.
st.set_page_config(page_title="🌌 Astronomy Image Classification", page_icon="🌌", layout="wide")

# Page heading followed by a one-line description of what the app does.
st.title("🌌 Astronomy Image Classification")
st.markdown(
    "Classify astronomy images into 6 categories using ensemble of ResNet50 and DenseNet121 models"
)

# Sidebar: static summary of the models behind the classifier.
with st.sidebar:
    st.title("📊 Model Info")
    # Same text as a single triple-quoted literal, spelled out line by line
    # (including the leading and trailing newline).
    st.markdown(
        "\n"
        "**Models**: ResNet50 + DenseNet121 Ensemble\n"
        "**ResNet50 Accuracy**: 64.86%\n"
        "**DenseNet121 Accuracy**: 63.96%\n"
        "**Ensemble**: Higher accuracy than individual models\n"
        "**Classes**: 6 astronomy categories\n"
        "**Input Size**: 224x224 pixels\n"
    )
# Load model
@st.cache_resource
def _load_model_cached():
    """Build the inference model and keep it in Streamlit's resource cache.

    Exceptions are deliberately not caught here: only return values are
    cached, so a failed call is retried on the next rerun instead of a
    ``None`` result being pinned in the cache.
    """
    return get_inference_model()


def load_model():
    """Return the cached inference model, or ``None`` if it cannot be loaded.

    A load failure is reported on the page with ``st.error`` rather than
    raised, so callers only need to check the result for ``None``.
    """
    try:
        return _load_model_cached()
    except Exception as e:  # UI boundary: surface any load failure to the user
        st.error(f"Error loading model: {e}")
        return None


# Keep the cache-control handle that the decorated function used to expose.
load_model.clear = _load_model_cached.clear
# Main interface
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been flattened — confirm it against the intended layout.
def _confidence_badge(confidence):
    """Map a confidence score to its traffic-light emoji and status label.

    Scores above 0.8 count as high, above 0.6 as medium, anything else as low.

    Returns:
        A ``(emoji, status)`` tuple of strings.
    """
    if confidence > 0.8:
        return "🟢", "High Confidence"
    if confidence > 0.6:
        return "🟡", "Medium Confidence"
    return "🔴", "Low Confidence"


def _as_progress_value(score):
    """Convert a score to a built-in float clamped to the range [0.0, 1.0].

    ``st.progress`` only accepts a built-in int/float inside its valid range.
    The scores in the prediction result may be NumPy/torch scalars or carry
    rounding error (their exact type is not visible here — TODO confirm).
    """
    return min(max(float(score), 0.0), 1.0)


model = load_model()
if model is not None:
    # Upload image
    uploaded_file = st.file_uploader(
        "Upload an astronomy image",
        type=['jpg', 'jpeg', 'png'],
        help="Upload an image of constellation, cosmos, galaxies, nebula, planets, or stars",
    )
    if uploaded_file is not None:
        # Uploaded image on the left, prediction details on the right.
        col1, col2 = st.columns([1, 1])
        with col1:
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image", use_column_width=True)
        with col2:
            # Make prediction
            with st.spinner("Analyzing image with ensemble models..."):
                result = model.predict(image)

            # Headline ensemble prediction.
            st.subheader("🎯 Ensemble Prediction Results")
            predicted_class = result["predicted_class"]
            confidence = result["confidence"]
            color, status = _confidence_badge(confidence)
            st.markdown(f"""
**{color} Predicted Class**: {predicted_class}
**Confidence**: {confidence:.3f}
**Status**: {status}
""")
            st.progress(_as_progress_value(confidence))

            # Per-model breakdown, shown only when the result provides one.
            if "individual_results" in result:
                st.subheader("🔍 Individual Model Results")
                for model_name, model_result in result["individual_results"].items():
                    model_confidence = model_result["confidence"]
                    model_prediction = model_result["predicted_class"]
                    model_color, _ = _confidence_badge(model_confidence)
                    st.write(f"**{model_name}**: {model_color} {model_prediction} ({model_confidence:.3f})")

            # Every class probability, most likely first, as label + bar.
            st.subheader("📊 All Class Probabilities")
            probabilities = result["probabilities"]
            for class_name, prob in sorted(probabilities.items(), key=lambda x: x[1], reverse=True):
                col_prob, col_bar = st.columns([2, 3])
                with col_prob:
                    st.write(f"**{class_name}**")
                with col_bar:
                    st.progress(_as_progress_value(prob))
                    st.write(f"{prob:.3f}")

    # Sample images section: two rows of three category cards.
    st.markdown("---")
    st.subheader("📸 Sample Images")
    sample_rows = (
        (
            ("**🌟 Constellation**", "Star patterns forming recognizable shapes like Orion, Big Dipper, etc."),
            ("**🌌 Galaxies**", "Spiral, elliptical, or irregular galaxies like Andromeda, Milky Way"),
            ("**🌫️ Nebula**", "Gas clouds and stellar nurseries like Orion Nebula, Eagle Nebula"),
        ),
        (
            ("**🪐 Planets**", "Solar system planets like Jupiter, Saturn, Mars, Earth"),
            ("**⭐ Stars**", "Individual stars, stellar objects, and stellar phenomena"),
            ("**🌠 Cosmos**", "General space scenes, cosmic phenomena, and deep space"),
        ),
    )
    for sample_row in sample_rows:
        # A fresh set of columns per row keeps the cards aligned in a grid.
        for sample_col, (heading, description) in zip(st.columns(3), sample_row):
            with sample_col:
                st.markdown(heading)
                st.info(description)

    # Model comparison
    st.markdown("---")
    st.subheader("📈 Model Performance Comparison")
    perf_col1, perf_col2 = st.columns(2)
    with perf_col1:
        st.metric("ResNet50 Accuracy", "64.86%", "Base Model")
    with perf_col2:
        st.metric("DenseNet121 Accuracy", "63.96%", "Base Model")
    st.info("🎯 **Ensemble Method**: Combines both models for higher accuracy than individual models")
else:
    # load_model() returned None: tell the user which weight files are expected.
    st.error("❌ Model could not be loaded. Please check the model files.")
    st.markdown("""
**Required files:**
- `best_resnet50.pth` (ResNet50 model weights)
- `best_densenet121.pth` (DenseNet121 model weights)
""")
# Footer: centered credits rendered as raw HTML (hence unsafe_allow_html).
st.markdown("---")
st.markdown("""
<div style='text-align: center'>
<p>🌌 Astronomy Image Classification System | Built with PyTorch & Streamlit</p>
<p>Ensemble of ResNet50 + DenseNet121 | Target Accuracy: >95% | Current: 64.86%</p>
<p>🚀 Deployed on Hugging Face Spaces</p>
</div>
""", unsafe_allow_html=True)