"""Gradio demo that classifies fish images as "Fresh" or "Not Fresh".

The EfficientNet-B0 based classifier is built and its weights are loaded from
``MODEL_PATH`` at import time; a Blocks UI then lets the user upload, capture
or paste an image and shows the predicted class with its probability.
"""
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import gradio as gr
import numpy as np
import cv2

# Load the model class definition
from models.efficientnet_b0 import EfficientNetB0Classifier

# Constants
MODEL_PATH = "efficientnet_best9912.pth"
# CLASS_NAMES[0] is the class whose probability the model's scalar output is
# interpreted as (see process_prediction).
CLASS_NAMES = ["Fresh", "Not Fresh"]
INPUT_SIZE = 380
MODEL_ACCURACY = "99.12%"  # Validation accuracy displayed in the UI header

# Preprocessing pipeline: square resize to INPUT_SIZE x INPUT_SIZE (aspect
# ratio is not preserved), convert to a float tensor, then normalize with the
# standard ImageNet per-channel mean/std.
preprocess = transforms.Compose([
    transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def load_model():
    """Build the classifier, load its weights from MODEL_PATH and set eval mode.

    Weights are mapped onto the CPU so the demo runs without a GPU.

    Returns:
        The ``EfficientNetB0Classifier`` instance in evaluation mode.
    """
    model = EfficientNetB0Classifier(train_base=False)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot execute arbitrary code (torch >= 1.13).
    state_dict = torch.load(
        MODEL_PATH,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model


model = load_model()


def process_prediction(confidence_score):
    """Turn the model's scalar output into per-class probabilities and a verdict.

    Args:
        confidence_score: Model output, interpreted as the probability of
            ``CLASS_NAMES[0]`` ("Fresh"). NOTE(review): this assumes the model
            ends in a sigmoid over the "Fresh" class — confirm against the
            training code; if the positive class was "Not Fresh", the labels
            shown here are inverted.

    Returns:
        ``(probabilities, prediction, confidence)``:
        ``probabilities`` maps each class name to its probability,
        ``prediction`` is the winning class name (a score of exactly 0.5 goes
        to "Not Fresh"), and ``confidence`` is the winner's probability.
    """
    fresh_label, not_fresh_label = CLASS_NAMES
    fresh_prob = float(confidence_score)
    not_fresh_prob = float(1.0 - confidence_score)

    if fresh_prob > 0.5:
        prediction, confidence = fresh_label, fresh_prob
    else:
        prediction, confidence = not_fresh_label, not_fresh_prob

    probabilities = {fresh_label: fresh_prob, not_fresh_label: not_fresh_prob}
    return probabilities, prediction, confidence


def _ensure_rgb(image):
    """Return *image* as a 3-channel RGB array, converting grayscale or RGBA."""
    # A trailing singleton channel axis is treated like plain grayscale.
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[:, :, 0]
    if image.ndim == 2:  # Grayscale
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if image.shape[2] == 4:  # RGBA
        return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    return image


def analyze_image(image):
    """Classify one image and build the values shown in the result panel.

    Args:
        image: Array delivered by the ``gr.Image(type="numpy")`` input, or
            ``None`` when nothing has been provided.

    Returns:
        ``(probabilities, message, display_image, confidence_percentage)``,
        feeding the Label, the "Detailed Result" textbox, the image component
        and the "Confidence Level" textbox. All four are ``None`` when *image*
        is ``None``.
    """
    if image is None:
        return None, None, None, None

    image = _ensure_rgb(image)

    # PIL image -> normalized tensor with a leading batch dimension of 1.
    pil_image = Image.fromarray(image).convert("RGB")
    input_tensor = preprocess(pil_image).unsqueeze(0)

    with torch.no_grad():
        output = model(input_tensor)
    # .item() only works for a single-element output; this assumes the model
    # returns one probability per image — TODO confirm if the head changes.
    confidence_score = output.item()

    probabilities, prediction, confidence = process_prediction(confidence_score)

    confidence_percentage = f"{confidence * 100:.2f}%"
    message = f"Prediction: {prediction} (Confidence: {confidence_percentage})"

    # Square resize to the model input size; this image is written back into
    # the input component and does not keep the original aspect ratio.
    display_image = cv2.resize(image, (INPUT_SIZE, INPUT_SIZE))

    return probabilities, message, display_image, confidence_percentage


# Custom CSS for better styling
custom_css = """
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button {
    color: white;
    border-radius: 8px;
    background: linear-gradient(45deg, #4CAF50, #45a049);
    border: none;
    font-size: 1.2em;
    padding: 10px 20px;
}
.gr-button:hover {
    background: linear-gradient(45deg, #45a049, #4CAF50);
    transform: translateY(-2px);
    box-shadow: 0 5px 15px rgba(0,0,0,0.1);
}
.footer {
    margin-top: 20px;
    text-align: center;
    font-size: 0.8em;
}
.confidence {
    font-size: 1.2em;
    font-weight: bold;
    margin-top: 10px;
}
.container {
    max-width: 1200px;
    margin: 0 auto;
    padding: 20px;
}
.result-box {
    background: #f8f9fa;
    border-radius: 10px;
    padding: 20px;
    margin-top: 20px;
    box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
"""

# Create Gradio interface
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown(
        f"""
        # 🐟 Fish Freshness Classifier
        Upload a fish image and get instant freshness analysis using our advanced AI model.

        ### Model Performance
        - Architecture: EfficientNet-B0
        - Validation Accuracy: {MODEL_ACCURACY}
        """
    )

    with gr.Row():
        # Left column: image input and the trigger button.
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Upload Fish Image",
                type="numpy",
                height=400,
                sources=["upload", "webcam", "clipboard"],
            )
            upload_button = gr.Button(
                "📸 Analyze Freshness", variant="primary", size="lg"
            )

        # Right column: prediction outputs.
        with gr.Column(scale=1):
            with gr.Group(elem_classes="result-box"):
                output_label = gr.Label(
                    num_top_classes=2,
                    label="Freshness Analysis",
                    show_label=True,
                )
                result_message = gr.Textbox(
                    label="Detailed Result",
                    show_copy_button=True,
                )
                confidence_indicator = gr.Textbox(
                    label="Confidence Level",
                    show_copy_button=True,
                )

    gr.Markdown(
        """
        ### 📝 Best Practices
        - Use clear, well-lit images
        - Ensure the fish is clearly visible
        - Include key features (eyes, gills, skin)
        - Avoid blurry or dark photos
        """
    )

    # Output order must match analyze_image's return tuple; the resized image
    # replaces the contents of the input component.
    upload_button.click(
        fn=analyze_image,
        inputs=input_image,
        outputs=[output_label, result_message, input_image, confidence_indicator],
    )

if __name__ == "__main__":
    # share=True additionally serves the app through a public Gradio link.
    demo.launch(share=True)