PsychicFireSong commited on
Commit
a2731d5
·
0 Parent(s):

Initial commit: Add Gradio app and model

Browse files
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ *.pth filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Plant Species Classification
3
+ emoji: 🌿
4
+ colorFrom: green
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 4.28.3
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ # 🌿 Plant Species Classification
14
+
15
+ This is a Gradio app for the AML Group Project by PsychicFireSong.
16
+
17
+ It uses a `ConvNextV2` model fine-tuned on the Herbarium Field dataset to classify plant species from images.
18
+
19
+ **Models Available:**
20
+ - **Herbarium Species Classifier:** The primary model for classification.
21
+ - **Future Model 1 (Placeholder):** Not yet implemented.
22
+ - **Future Model 2 (Placeholder):** Not yet implemented.
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import pandas as pd
4
+ import os
5
+ from torchvision import transforms
6
+ from PIL import Image
7
+ from transformers import ConvNextV2ForImageClassification
8
+
9
+ # --- Configuration ---
10
+ # Paths are relative to the app's root directory in the Hugging Face Space
11
+ DATA_DIR = '.'
12
+ LIST_DIR = os.path.join(DATA_DIR, 'list')
13
+ MODEL_PATH_HERBARIUM = os.path.join(DATA_DIR, 'herbarium_convnext_v2_base.pth')
14
+ SPECIES_LIST_TXT = os.path.join(LIST_DIR, 'species_list.txt')
15
+
16
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
+
18
+ # --- Load Species Information ---
19
+ try:
20
+ species_df = pd.read_csv(SPECIES_LIST_TXT, sep=';', header=None, names=['class_id', 'species_name'])
21
+ class_names = list(species_df['species_name'])
22
+ num_labels = len(class_names)
23
+ except FileNotFoundError:
24
+ # Fallback if the species list is not found
25
+ class_names = [f"Class {i}" for i in range(100)] # Assuming 100 classes as a fallback
26
+ num_labels = 100
27
+ print(f"Warning: '{SPECIES_LIST_TXT}' not found. Using generic class names.")
28
+
29
+
30
+ # --- Image Transformations ---
31
+ data_transforms = transforms.Compose([
32
+ transforms.Resize(256),
33
+ transforms.CenterCrop(224),
34
+ transforms.ToTensor(),
35
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
36
+ ])
37
+
38
+ # --- Model Loading ---
39
+ def load_herbarium_model():
40
+ """Loads the Herbarium ConvNextV2 model."""
41
+ model = ConvNextV2ForImageClassification.from_pretrained(
42
+ "facebook/convnextv2-base-22k-224",
43
+ num_labels=num_labels,
44
+ ignore_mismatched_sizes=True
45
+ )
46
+ try:
47
+ # Load the state dictionary
48
+ model.load_state_dict(torch.load(MODEL_PATH_HERBARIUM, map_location=DEVICE))
49
+ except FileNotFoundError:
50
+ print(f"Warning: Model weights not found at '{MODEL_PATH_HERBARIUM}'. The model is using pre-trained weights, not fine-tuned ones.")
51
+ except Exception as e:
52
+ print(f"Error loading model weights: {e}. The model is using pre-trained weights.")
53
+
54
+ model = model.to(DEVICE)
55
+ model.eval()
56
+ return model
57
+
58
+ # Load the primary model
59
+ herbarium_model = load_herbarium_model()
60
+
61
+ # --- Prediction Functions ---
62
+ def predict_herbarium(image):
63
+ """Runs inference on the herbarium model."""
64
+ if image is None:
65
+ return "Please upload an image."
66
+
67
+ # Preprocess the image
68
+ image = data_transforms(image).unsqueeze(0)
69
+ image = image.to(DEVICE)
70
+
71
+ # Get model predictions
72
+ with torch.no_grad():
73
+ outputs = herbarium_model(image).logits
74
+ probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
75
+
76
+ # Get top 5 predictions
77
+ top5_prob, top5_indices = torch.topk(probabilities, 5)
78
+
79
+ # Format results
80
+ results = {class_names[i]: f"{p:.3f}" for i, p in zip(top5_indices, top5_prob)}
81
+ return results
82
+
83
+ def predict_placeholder_1(image):
84
+ """Placeholder function for the second model."""
85
+ if image is None:
86
+ return "Please upload an image."
87
+ return "Model 2 is not available yet. Please check back later."
88
+
89
+ def predict_placeholder_2(image):
90
+ """Placeholder function for the third model."""
91
+ if image is None:
92
+ return "Please upload an image."
93
+ return "Model 3 is not available yet. Please check back later."
94
+
95
+ # --- Main Prediction Logic ---
96
+ def predict(model_choice, image):
97
+ """Routes the prediction to the chosen model."""
98
+ if model_choice == "Herbarium Species Classifier":
99
+ return predict_herbarium(image)
100
+ elif model_choice == "Future Model 1 (Placeholder)":
101
+ return predict_placeholder_1(image)
102
+ elif model_choice == "Future Model 2 (Placeholder)":
103
+ return predict_placeholder_2(image)
104
+ else:
105
+ return "Invalid model selected."
106
+
107
+ # --- Gradio Interface ---
108
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
109
+ gr.Markdown(
110
+ """
111
+ # 🌿 Plant Species Classification
112
+ ## AML Group Project - PsychicFireSong
113
+ Upload an image of a plant to classify it. Select a model from the dropdown below.
114
+ """
115
+ )
116
+
117
+ with gr.Row():
118
+ with gr.Column(scale=1):
119
+ model_selector = gr.Dropdown(
120
+ label="Select Model",
121
+ choices=[
122
+ "Herbarium Species Classifier",
123
+ "Future Model 1 (Placeholder)",
124
+ "Future Model 2 (Placeholder)"
125
+ ],
126
+ value="Herbarium Species Classifier"
127
+ )
128
+ image_input = gr.Image(type="pil", label="Upload Plant Image")
129
+ submit_button = gr.Button("Classify", variant="primary")
130
+
131
+ with gr.Column(scale=1):
132
+ output_label = gr.Label(label="Top 5 Predictions", num_top_classes=5)
133
+
134
+ submit_button.click(
135
+ fn=predict,
136
+ inputs=[model_selector, image_input],
137
+ outputs=output_label
138
+ )
139
+
140
+ gr.Examples(
141
+ examples=[
142
+ # Add paths to example images if you have any in your project
143
+ # e.g., os.path.join("examples", "example1.jpg")
144
+ ],
145
+ inputs=image_input,
146
+ outputs=output_label,
147
+ fn=lambda img: predict("Herbarium Species Classifier", img),
148
+ cache_examples=False
149
+ )
150
+
151
+ if __name__ == "__main__":
152
+ demo.launch()
herbarium_convnext_v2_base.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:837cca126e235c0ae822770470e38a3621b81b0ba7e915aaef2b15a7f66914e6
3
+ size 351335085
list/species_list.txt ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 105951; Maripa glabra Choisy
2
+ 106023; Merremia umbellata (L.) Hallier f.
3
+ 106387; Costus arabicus L.
4
+ 106461; Costus scaber Ruiz Pav.
5
+ 106466; Costus spiralis (Jacq.) Roscoe
6
+ 110432; Evodianthus funifer (Poit.) Lindm.
7
+ 116853; Pteridium arachnoideum (Kaulf.) Maxon
8
+ 119986; Olfersia cervina (L.) Kunze
9
+ 120497; Diospyros capreifolia Mart. ex Hiern
10
+ 121836; Sloanea grandiflora Sm.
11
+ 121841; Sloanea guianensis (Aubl.) Benth.
12
+ 12254; Anacardium occidentale L.
13
+ 12518; Mangifera indica L.
14
+ 125412; Sphyrospermum cordifolium Benth.
15
+ 126895; Syngonanthus caulescens (Poir.) Ruhland
16
+ 127007; Tonina fluviatilis Aubl.
17
+ 127097; Erythroxylum fimbriatum Peyr.
18
+ 127151; Erythroxylum macrophyllum Cav.
19
+ 127242; Erythroxylum squamatum Sw.
20
+ 12910; Spondias mombin L.
21
+ 12922; Tapirira guianensis Aubl.
22
+ 129645; Croton schiedeanus Schltdl.
23
+ 130657; Euphorbia cotinifolia L.
24
+ 131079; Euphorbia heterophylla L.
25
+ 131736; Euphorbia prostrata Aiton
26
+ 132107; Euphorbia thymifolia L.
27
+ 132113; Euphorbia tithymaloides L.
28
+ 132431; Hura crepitans L.
29
+ 132476; Jatropha curcas L.
30
+ 132501; Jatropha gossypiifolia L.
31
+ 13276; Annona ambotay Aubl.
32
+ 13325; Annona foetida Mart.
33
+ 13330; Annona glabra L.
34
+ 133595; Ricinus communis L.
35
+ 133617; Sapium glandulosum (L.) Morong
36
+ 13370; Annona muricata L.
37
+ 136761; Potalia amara Aubl.
38
+ 138662; Chrysothemis pulchella (Donn ex Sims) Decne.
39
+ 140367; Lembocarpus amoenus Leeuwenb.
40
+ 141068; Sinningia incarnata (Aubl.) D.L.Denham
41
+ 141332; Dicranopteris flexuosa (Schrad.) Underw.
42
+ 141336; Dicranopteris pectinata (Willd.) Underw.
43
+ 142550; Heliconia chartacea Lane ex Barreiros
44
+ 142736; Hernandia guianensis Aubl.
45
+ 143496; Hymenophyllum hirsutum (L.) Sw.
46
+ 14353; Guatteria ouregou (Aubl.) Dunal
47
+ 143706; Trichomanes diversifrons (Bory) Mett. ex Sadeb.
48
+ 143758; Trichomanes punctatum Poir.
49
+ 14401; Guatteria scandens Ducke
50
+ 144394; Didymochlaena truncatula (Sw.) J. Sm.
51
+ 145020; Cipura paludosa Aubl.
52
+ 148220; Aegiphila macrantha Ducke
53
+ 148977; Clerodendrum paniculatum L.
54
+ 149264; Congea tomentosa Roxb.
55
+ 149682; Gmelina philippensis Cham.
56
+ 149919; Holmskioldia sanguinea Retz.
57
+ 150135; Hyptis lanceolata Poir.
58
+ 15014; Rollinia mucosa (Jacq.) Baill.
59
+ 151469; Ocimum campechianum Mill.
60
+ 151593; Orthosiphon aristatus (Blume) Miq.
61
+ 15318; Xylopia aromatica (Lam.) Mart.
62
+ 15330; Xylopia cayennensis Maas
63
+ 15355; Xylopia frutescens Aubl.
64
+ 156516; Aniba guianensis Aubl.
65
+ 156526; Aniba megaphylla Mez
66
+ 158341; Nectandra cissiflora Nees
67
+ 158592; Ocotea cernua (Nees) Mez
68
+ 158653; Ocotea floribunda (Sw.) Mez
69
+ 158736; Ocotea longifolia Kunth
70
+ 158793; Ocotea oblonga (Meisn.) Mez
71
+ 158833; Ocotea puberula (Rich.) Nees
72
+ 159434; Couratari guianensis Aubl.
73
+ 159516; Eschweilera parviflora (Aubl.) Miers
74
+ 159518; Eschweilera pedicellata (Rich.) S.A.Mori
75
+ 160570; Acacia mangium Willd.
76
+ 166822; Caesalpinia pulcherrima (L.) Sw.
77
+ 166869; Cajanus cajan (L.) Millsp.
78
+ 169293; Crotalaria retusa L.
79
+ 171727; Erythrina fusca Lour.
80
+ 173914; Inga alba (Sw.) Willd.
81
+ 173972; Inga capitata Desv.
82
+ 174017; Inga edulis Mart.
83
+ 177730; Mimosa pigra L.
84
+ 177775; Mimosa pudica L.
85
+ 189669; Punica granatum L.
86
+ 191642; Adansonia digitata L.
87
+ 19165; Allamanda cathartica L.
88
+ 192311; Ceiba pentandra (L.) Gaertn.
89
+ 194035; Hibiscus rosa-sinensis L.
90
+ 19489; Asclepias curassavica L.
91
+ 209328; Psidium guineense Sw.
92
+ 211059; Nephrolepis biserrata (Sw.) Schott
93
+ 244705; Averrhoa carambola L.
94
+ 248392; Turnera ulmifolia L.
95
+ 254180; Piper peltatum L.
96
+ 275029; Eichhornia crassipes (Mart.) Solms
97
+ 280085; Ceratopteris thalictroides (L.) Brongn.
98
+ 280698; Pityrogramma calomelanos (L.) Link
99
+ 285398; Cassipourea guianensis Aubl.
100
+ 29686; Oreopanax capitatus (Jacq.) Decne. Planch.
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ transformers
4
+ pandas
5
+ gradio