"""Gradio app for plant species classification with a fine-tuned ConvNeXt V2 model.

The script loads a species list and a fine-tuned checkpoint from disk, builds a
small Gradio UI with a model selector, and serves top-k species predictions for
an uploaded image.
"""

import os

import gradio as gr
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
from transformers import ConvNextV2ForImageClassification

# --- Configuration ---
# Paths are relative to the app's root directory in the Hugging Face Space
DATA_DIR = '.'
LIST_DIR = os.path.join(DATA_DIR, 'list')
MODEL_PATH_HERBARIUM = os.path.join(DATA_DIR, 'herbarium_convnext_v2_base.pth')
SPECIES_LIST_TXT = os.path.join(LIST_DIR, 'species_list.txt')
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Number of generic classes assumed when the species list file is missing.
FALLBACK_NUM_CLASSES = 100
# Maximum number of predictions shown to the user.
TOP_K = 5

# Dropdown labels; shared by the UI and the router so they cannot drift apart.
MODEL_HERBARIUM = "Herbarium Species Classifier"
MODEL_PLACEHOLDER_1 = "Future Model 1 (Placeholder)"
MODEL_PLACEHOLDER_2 = "Future Model 2 (Placeholder)"

# --- Load Species Information ---
try:
    species_df = pd.read_csv(
        SPECIES_LIST_TXT,
        sep=';',
        header=None,
        names=['class_id', 'species_name'],
    )
    class_names = list(species_df['species_name'])
    num_labels = len(class_names)
except FileNotFoundError:
    # Fallback if the species list is not found
    class_names = [f"Class {i}" for i in range(FALLBACK_NUM_CLASSES)]
    num_labels = FALLBACK_NUM_CLASSES
    print(f"Warning: '{SPECIES_LIST_TXT}' not found. Using generic class names.")

# --- Image Transformations ---
data_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


# --- Model Loading ---
def load_herbarium_model():
    """Build the ConvNeXt V2 classifier and load the fine-tuned weights if present.

    The base checkpoint is fetched with ``from_pretrained`` and its head is
    resized to ``num_labels``. If the fine-tuned state dict cannot be loaded,
    a warning is printed and the model is returned without it (best effort).

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.
    """
    model = ConvNextV2ForImageClassification.from_pretrained(
        "facebook/convnextv2-base-22k-224",
        num_labels=num_labels,
        ignore_mismatched_sizes=True,
    )
    try:
        # NOTE(security): torch.load unpickles the file; only ship checkpoints
        # from a trusted source alongside this app.
        state_dict = torch.load(MODEL_PATH_HERBARIUM, map_location=DEVICE)
        model.load_state_dict(state_dict)
    except FileNotFoundError:
        print(
            f"Warning: Model weights not found at '{MODEL_PATH_HERBARIUM}'. "
            "The model is using pre-trained weights, not fine-tuned ones."
        )
    except Exception as e:
        # Deliberately broad: any load failure (e.g. mismatched keys/shapes)
        # should degrade to the base weights rather than crash the app.
        print(
            f"Error loading model weights: {e}. "
            "The model is using pre-trained weights."
        )

    model = model.to(DEVICE)
    model.eval()
    return model


# Load the primary model
herbarium_model = load_herbarium_model()


# --- Prediction Functions ---
def predict_herbarium(image):
    """Run inference with the herbarium model on a PIL image.

    Args:
        image: PIL image supplied by the Gradio image input, or ``None``.

    Returns:
        A message string when no image was given; otherwise a dict mapping
        species name to its softmax probability (float) for the top
        predictions, in the format ``gr.Label`` accepts.
    """
    if image is None:
        return "Please upload an image."

    # Normalize is configured for three channels, so force RGB first.
    batch = data_transforms(image.convert("RGB")).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = herbarium_model(batch).logits
        probabilities = torch.nn.functional.softmax(logits, dim=1)[0]

    # topk raises if k exceeds the number of classes, so clamp it.
    k = min(TOP_K, probabilities.numel())
    top_prob, top_indices = torch.topk(probabilities, k)

    # Plain Python ints/floats: list indexing and gr.Label both expect them.
    return {
        class_names[idx]: prob
        for idx, prob in zip(top_indices.tolist(), top_prob.tolist())
    }


def predict_placeholder_1(image):
    """Placeholder for the second model; returns a status message only."""
    if image is None:
        return "Please upload an image."
    return "Model 2 is not available yet. Please check back later."


def predict_placeholder_2(image):
    """Placeholder for the third model; returns a status message only."""
    if image is None:
        return "Please upload an image."
    return "Model 3 is not available yet. Please check back later."


# Maps each dropdown label to the function that serves it.
_PREDICTORS = {
    MODEL_HERBARIUM: predict_herbarium,
    MODEL_PLACEHOLDER_1: predict_placeholder_1,
    MODEL_PLACEHOLDER_2: predict_placeholder_2,
}


# --- Main Prediction Logic ---
def predict(model_choice, image):
    """Route the prediction request to the model selected in the dropdown.

    Args:
        model_choice: Label of the selected model.
        image: PIL image from the upload widget, or ``None``.

    Returns:
        Whatever the selected predictor returns, or an error message string
        when ``model_choice`` is not a known model.
    """
    handler = _PREDICTORS.get(model_choice)
    if handler is None:
        return "Invalid model selected."
    return handler(image)


# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🌿 Plant Species Classification
        ## AML Group Project - PsychicFireSong
        Upload an image of a plant to classify it. Select a model from the dropdown below.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                label="Select Model",
                choices=[
                    MODEL_HERBARIUM,
                    MODEL_PLACEHOLDER_1,
                    MODEL_PLACEHOLDER_2,
                ],
                value=MODEL_HERBARIUM,
            )
            image_input = gr.Image(type="pil", label="Upload Plant Image")
            submit_button = gr.Button("Classify", variant="primary")
        with gr.Column(scale=1):
            output_label = gr.Label(label="Top 5 Predictions", num_top_classes=5)

    submit_button.click(
        fn=predict,
        inputs=[model_selector, image_input],
        outputs=output_label,
    )

    gr.Examples(
        examples=[
            # Add paths to example images if you have any in your project
            # e.g., os.path.join("examples", "example1.jpg")
        ],
        inputs=image_input,
        outputs=output_label,
        fn=lambda img: predict(MODEL_HERBARIUM, img),
        cache_examples=False,
    )

# Launch exactly once, and only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()