iqramukhtiar's picture
Update app.py
800cbbf verified
raw
history blame
16.7 kB
import json
import os
import string

import gradio as gr
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torchvision import models, transforms
# --- Model Setup ---
def load_model(weights_path='plant_disease_model.pth', num_classes=38):
    """Build a ResNet-50 classifier and load fine-tuned weights if available.

    The backbone is created with ``pretrained=False`` (no ImageNet download) and
    its final fully-connected layer is replaced so the network emits
    ``num_classes`` scores. If the weights cannot be loaded, the randomly
    initialised model is still returned so the UI can start; its predictions
    are then meaningless.

    Args:
        weights_path: Path to a state dict saved with ``torch.save``.
        num_classes: Number of output classes (38 for PlantVillage).

    Returns:
        The model, switched to evaluation mode.
    """
    # NOTE(review): torchvision >= 0.13 prefers `weights=None` over the
    # deprecated `pretrained=False`; kept for compatibility with older versions.
    model = models.resnet50(pretrained=False)
    # Replace the classification head with one sized for our label set.
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    try:
        # NOTE(review): torch.load unpickles the file - only load trusted weights.
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
        print("Model loaded successfully")
    except FileNotFoundError:
        print("Model weights not found, using untrained model for demonstration")
    except Exception as err:
        # A file that exists but is corrupt or doesn't match the architecture
        # (e.g. a different num_classes) must not be reported as "not found".
        print(f"Failed to load model weights from {weights_path!r} ({err}); "
              "using untrained model for demonstration")
    model.eval()
    return model
# --- Data Loading ---
def load_class_names(path='class_names.json'):
    """Return the class labels, indexed by the model's output position.

    Args:
        path: JSON file holding a list of "Crop___Disease" label strings.

    Returns:
        The parsed JSON if the file can be read. If it is missing, unreadable,
        or not valid JSON, a built-in fallback list is returned instead.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            class_names = json.load(f)
    except FileNotFoundError:
        print("Class names file not found, using default classes")
    except (OSError, ValueError) as err:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Could not read class names from {path!r} ({err}), using default classes")
    else:
        print("Class names loaded successfully")
        return class_names

    # Default class names from PlantVillage dataset.
    # NOTE(review): this fallback has 35 entries while load_model builds a
    # 38-way classifier, so predicted indices >= 35 map to "Unknown". The order
    # must also match the label order used in training - TODO confirm, or ship
    # class_names.json alongside the weights.
    return [
        "Apple___Apple_scab", "Apple___Black_rot", "Apple___Cedar_apple_rust", "Apple___healthy",
        "Cherry___healthy", "Cherry___Powdery_mildew", "Corn___Cercospora_leaf_spot",
        "Corn___Common_rust", "Corn___healthy", "Corn___Northern_Leaf_Blight",
        "Grape___Black_rot", "Grape___Esca_(Black_Measles)", "Grape___healthy",
        "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)", "Orange___Haunglongbing_(Citrus_greening)",
        "Peach___Bacterial_spot", "Peach___healthy", "Pepper,_bell___Bacterial_spot",
        "Pepper,_bell___healthy", "Potato___Early_blight", "Potato___healthy",
        "Potato___Late_blight", "Squash___Powdery_mildew", "Strawberry___healthy",
        "Strawberry___Leaf_scorch", "Tomato___Bacterial_spot", "Tomato___Early_blight",
        "Tomato___healthy", "Tomato___Late_blight", "Tomato___Leaf_Mold",
        "Tomato___Septoria_leaf_spot", "Tomato___Spider_mites Two-spotted_spider_mite",
        "Tomato___Target_Spot", "Tomato___Tomato_mosaic_virus",
        "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
    ]
def load_treatments_data(path='crop_diseases_treatments.csv'):
    """Load the crop disease / treatment lookup table.

    Args:
        path: CSV file to read.

    Returns:
        A DataFrame with (at least) the columns 'Crop', 'Disease', 'Symptoms',
        'Treatment' and 'Medicine/Chemical Control'. If the file is missing,
        empty, unparsable, or lacks any of those columns, a small built-in
        demonstration table with the same columns is returned instead.
    """
    # find_treatment / answer_question index these columns directly.
    required = ['Crop', 'Disease', 'Symptoms', 'Treatment', 'Medicine/Chemical Control']
    try:
        context = pd.read_csv(path)
    except FileNotFoundError:
        print(f"Error: {path} not found.")
    except (pd.errors.EmptyDataError, pd.errors.ParserError) as err:
        print(f"Error: could not parse {path} ({err}).")
    else:
        missing = [col for col in required if col not in context.columns]
        if not missing:
            print("CSV file loaded successfully.")
            return context
        print(f"Error: {path} is missing columns {missing}.")

    # Create a minimal dataframe for demonstration
    return pd.DataFrame({
        'Crop': ['Tomato', 'Apple', 'Corn', 'Potato', 'Grape', 'Cherry', 'Peach', 'Strawberry'],
        'Disease': ['Early Blight', 'Apple Scab', 'Common Rust', 'Late Blight', 'Black Rot', 'Powdery Mildew', 'Bacterial Spot', 'Leaf Scorch'],
        'Symptoms': [
            'Brown spots with concentric rings on leaves',
            'Olive-green to brown spots on leaves and fruit',
            'Rust-colored pustules on leaves',
            'Dark water-soaked spots on leaves that turn brown',
            'Reddish-brown spots on leaves and fruit',
            'White powdery coating on leaves and stems',
            'Small brown spots on leaves and fruit',
            'Scorched appearance on leaf margins',
        ],
        'Treatment': [
            'Remove affected leaves, improve air circulation, rotate crops',
            'Prune affected branches, remove fallen leaves, apply fungicide',
            'Remove infected plants, apply fungicide early in season',
            'Remove infected plants, avoid overhead watering, apply fungicide',
            'Prune infected areas, remove mummified fruit, apply fungicide',
            'Improve air circulation, apply fungicide, remove infected parts',
            'Copper-based sprays, crop rotation, remove infected plants',
            'Ensure proper watering, add mulch, improve soil drainage',
        ],
        'Medicine/Chemical Control': [
            'Chlorothalonil, Mancozeb, Copper fungicides',
            'Captan, Myclobutanil, Sulfur sprays',
            'Propiconazole, Azoxystrobin, Mancozeb',
            'Chlorothalonil, Mancozeb, Copper-based fungicides',
            'Captan, Myclobutanil, Mancozeb',
            'Sulfur, Potassium bicarbonate, Neem oil',
            'Copper hydroxide, Streptomycin sulfate',
            'Calcium nitrate sprays, Fungicides with Captan',
        ],
    })
# --- Image Processing ---
def prepare_image_transform():
    """Return the preprocessing pipeline applied to every uploaded image.

    The image is resized to 224x224, converted to a tensor, and normalised
    per channel with the ImageNet mean / standard deviation.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    return transforms.Compose(pipeline)
# --- Disease Classification ---
def classify_disease(image, model, class_names, transform):
    """Predict the crop and disease shown in a plant-leaf image.

    Args:
        image: The uploaded image (a PIL image from the Gradio component), or None.
        model: Classifier called on a batch of one; class scores lie along dim 1.
        class_names: Labels indexed by class id, formatted "Crop___Disease".
        transform: Callable mapping the image to a tensor (a batch dimension
            is added here).

    Returns:
        ``(crop, disease)`` with underscores turned into spaces and surrounding
        whitespace removed; ``disease`` is "healthy" when the label has no
        "___" part. Returns ``("Unknown", "Unknown")`` when the predicted index
        has no label, and ``(None, None)`` when ``image`` is None.
    """
    if image is None:
        return None, None

    # RGBA (e.g. PNG with transparency) or greyscale uploads break the
    # 3-channel Normalize step; convert anything PIL-like that isn't RGB.
    if hasattr(image, "convert") and getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    img_tensor = transform(image).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        outputs = model(img_tensor)
        _, predicted = torch.max(outputs, 1)
    class_idx = predicted.item()

    if class_idx >= len(class_names):
        return "Unknown", "Unknown"

    # Labels look like "Tomato___Early_blight". strip() stops a trailing "_"
    # (e.g. "Common_rust_") from leaving a trailing space that would defeat
    # the substring lookup in the treatments table.
    parts = class_names[class_idx].split("___")
    crop = parts[0].replace("_", " ").strip()
    disease = parts[1].replace("_", " ").strip() if len(parts) > 1 else "healthy"
    return crop, disease
# --- Treatment Lookup ---
def _treatment_record(row):
    """Convert one treatments-table row into the dict shape find_treatment returns."""
    return {
        "crop": row['Crop'],
        "disease": row['Disease'],
        "symptoms": row['Symptoms'],
        "treatment": row['Treatment'],
        "medicine": row['Medicine/Chemical Control'],
    }


def find_treatment(crop, disease, df):
    """Look up symptoms and treatment for a (crop, disease) pair.

    Matching is case-insensitive and literal. The function first looks for a
    row with the same crop whose Disease contains ``disease``. If none exists,
    it takes the first row for any crop whose Disease contains it.

    Args:
        crop: Crop name as produced by classify_disease (e.g. "Tomato").
        disease: Disease name, or "healthy".
        df: Treatments table with 'Crop', 'Disease', 'Symptoms', 'Treatment'
            and 'Medicine/Chemical Control' columns.

    Returns:
        A dict with keys "crop", "disease", "symptoms", "treatment" and
        "medicine". For healthy plants it holds a fixed "no treatment" record.
        When nothing matches, it holds generic advice.
    """
    if disease.lower() == "healthy":
        return {
            "crop": crop,
            "disease": "healthy",
            "symptoms": "No symptoms - plant is healthy",
            "treatment": "No treatment needed",
            "medicine": "N/A",
        }

    needle = disease.lower().strip()
    # An empty needle would be a substring of every row, so skip the search.
    if needle:
        # regex=False: names like "Esca (Black Measles)" contain regex
        # metacharacters and must match literally. na=False keeps missing
        # cells from producing a non-boolean mask.
        disease_hits = df['Disease'].str.lower().str.contains(needle, regex=False, na=False)
        crop_hits = df['Crop'].str.lower() == crop.lower()
        # Prefer the same crop; otherwise fall back to any crop with that disease.
        for mask in (crop_hits & disease_hits, disease_hits):
            matches = df[mask]
            if not matches.empty:
                return _treatment_record(matches.iloc[0])

    return {
        "crop": crop,
        "disease": disease,
        "symptoms": "Unknown",
        "treatment": "No specific treatment information found. General advice: Remove affected parts, ensure proper spacing for air circulation, and consider organic or chemical fungicides appropriate for your region.",
        "medicine": "Consult a local agricultural extension office or plant pathologist for specific recommendations for your region.",
    }
# --- Q&A Function ---
def _format_qa_section(heading, row):
    """Render one treatments-table row as a level-3 Markdown section titled *heading*."""
    return (
        f"### {heading}\n"
        f"**Symptoms:** {row['Symptoms']}\n\n"
        f"**Treatment:** {row['Treatment']}\n\n"
        f"**Products:** {row['Medicine/Chemical Control']}\n\n"
    )


def _first_mentioned(values, question):
    """Return the first non-empty string in *values* whose lowercase form occurs in *question*."""
    for value in values:
        # Skip NaN/None cells and "" (which is a substring of everything).
        if isinstance(value, str) and value and value.lower() in question:
            return value
    return None


# (column, weight) pairs used to score rows in the keyword fallback search.
_QA_FIELD_WEIGHTS = (
    ('Crop', 3),
    ('Disease', 5),
    ('Symptoms', 2),
    ('Treatment', 1),
    ('Medicine/Chemical Control', 1),
)


def answer_question(question, df):
    """Answer a free-text question from the treatments table, as Markdown.

    Strategy, in order:
      1. If the question names both a crop and a disease from the table and
         that pair exists, describe that entry. If the pair is not in the
         table, skip straight to step 4.
      2. If it names only a crop, list every disease recorded for that crop.
      3. If it names only a disease, list every crop it is recorded for.
      4. Otherwise, score each row by keyword overlap and show the top three.
         Keywords are words of 3+ characters with surrounding punctuation
         stripped.

    Args:
        question: The user's question; None is treated as empty.
        df: Treatments table with 'Crop', 'Disease', 'Symptoms', 'Treatment'
            and 'Medicine/Chemical Control' columns.

    Returns:
        A Markdown string, or a fixed "couldn't find" message when nothing
        matches.
    """
    question = (question or "").lower()

    crop_match = _first_mentioned(df['Crop'].unique(), question)
    disease_match = _first_mentioned(df['Disease'].unique(), question)

    if crop_match and disease_match:
        matches = df[(df['Crop'] == crop_match) & (df['Disease'] == disease_match)]
        if not matches.empty:
            match = matches.iloc[0]
            # Blank lines between sections so Markdown renders separate paragraphs.
            return (
                f"## {match['Crop']} - {match['Disease']}\n\n"
                f"**Symptoms:**\n{match['Symptoms']}\n\n"
                f"**Treatment:**\n{match['Treatment']}\n\n"
                f"**Recommended Products:**\n{match['Medicine/Chemical Control']}\n"
            )
    elif crop_match:
        matches = df[df['Crop'] == crop_match]
        if not matches.empty:
            sections = [_format_qa_section(row['Disease'], row) for _, row in matches.iterrows()]
            return f"## Common Diseases Affecting {crop_match}\n\n" + "".join(sections)
    elif disease_match:
        matches = df[df['Disease'] == disease_match]
        if not matches.empty:
            sections = [_format_qa_section(row['Crop'], row) for _, row in matches.iterrows()]
            return f"## {disease_match} in Different Crops\n\n" + "".join(sections)

    # General keyword search. Strip punctuation so "mildew?" matches "mildew";
    # words shorter than 3 characters are ignored.
    keywords = [word.strip(string.punctuation) for word in question.split()]
    keywords = [word for word in keywords if len(word) >= 3]
    relevant_rows = []
    for _, row in df.iterrows():
        texts = [(str(row[column]).lower(), weight) for column, weight in _QA_FIELD_WEIGHTS]
        score = sum(weight for keyword in keywords for text, weight in texts if keyword in text)
        if score > 0:
            relevant_rows.append((score, row))

    if relevant_rows:
        # Highest score first; the sort is stable, so ties keep table order.
        relevant_rows.sort(key=lambda item: item[0], reverse=True)
        sections = [
            _format_qa_section(f"{row['Crop']} - {row['Disease']}", row)
            for _, row in relevant_rows[:3]
        ]
        return "## Here's what I found based on your question:\n\n" + "".join(sections)

    return "I couldn't find specific information related to your question in my database. Please try asking about specific crops or diseases, or upload an image for analysis."
# --- Image Processing Function ---
def process_image(image, model, class_names, transform, df):
    """Handle the "Analyze Image" button: classify the leaf and look up its treatment.

    Returns a 5-tuple matching the Gradio outputs:
    (Markdown analysis, crop, disease, treatment, medicine). When no image was
    given, or it could not be classified, only the message is set and the other
    four fields are None.
    """
    no_details = (None, None, None, None)
    if image is None:
        return ("Please upload an image to analyze.",) + no_details

    crop, disease = classify_disease(image, model, class_names, transform)
    if crop is None or disease is None:
        return ("Unable to analyze the image. Please try a clearer image of a plant leaf.",) + no_details

    info = find_treatment(crop, disease, df)

    if disease.lower() != "healthy":
        analysis = "".join([
            f"## Analysis Results\n\nThis {crop} plant appears to be affected by **{disease}**.\n\n",
            f"**Typical symptoms:**\n{info['symptoms']}\n\n",
            f"**Recommended treatment:**\n{info['treatment']}\n\n",
            f"**Recommended products:**\n{info['medicine']}",
        ])
    else:
        analysis = f"## Good news! \nThis {crop} plant appears to be healthy."

    return analysis, info["crop"], info["disease"], info["treatment"], info["medicine"]
# --- Main Function ---
def main():
    """Build the Plant Disease Assistant Gradio app.

    Loads the model, class labels, image transform and treatments table once.
    It then wires them into a two-tab Blocks UI (image analysis and Q&A) plus
    example inputs.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for ``launch()``.
    """
    # Load model and data once; the event handlers below close over them.
    model = load_model()
    class_names = load_class_names()
    transform = prepare_image_transform()
    df = load_treatments_data()

    # Only offer example images that actually exist on disk. Listing missing
    # paths can make gr.Examples fail at startup or show broken entries.
    example_images = [
        path
        for path in (
            "example_images/tomato_early_blight.jpg",
            "example_images/apple_scab.jpg",
            "example_images/corn_rust.jpg",
        )
        if os.path.isfile(path)
    ]

    # Create Gradio interface
    with gr.Blocks(title="Plant Disease Assistant", css="footer {visibility: hidden}") as app:
        gr.Markdown("""
        # 🌱 Plant Disease Treatment Assistant
        Upload a plant image or ask a question to get disease identification and treatment information.
        This tool uses a machine learning model trained on the PlantVillage dataset to identify plant diseases.
        """)

        with gr.Tabs():
            with gr.TabItem("Image Analysis"):
                with gr.Row():
                    with gr.Column(scale=1):
                        image_input = gr.Image(type="pil", label="Upload Plant Image")
                        image_submit = gr.Button("Analyze Image", variant="primary")
                        gr.Markdown("""
                        ### Tips for best results:
                        - Upload a clear, well-lit image
                        - Focus on affected leaves or plant parts
                        - Include multiple symptoms if possible
                        """)
                    with gr.Column(scale=2):
                        analysis_output = gr.Markdown(label="Analysis")
                        with gr.Group():
                            gr.Markdown("### Plant Information")
                            with gr.Row():
                                crop_output = gr.Textbox(label="Crop")
                                disease_output = gr.Textbox(label="Disease")
                            treatment_output = gr.Textbox(label="Treatment", lines=3)
                            medicine_output = gr.Textbox(label="Recommended Products", lines=3)

            with gr.TabItem("Q&A"):
                with gr.Row():
                    with gr.Column():
                        question_input = gr.Textbox(
                            lines=2,
                            placeholder="Ask a question like 'How do I treat early blight in tomatoes?'",
                            label="Your Question",
                        )
                        question_submit = gr.Button("Get Answer", variant="primary")
                        gr.Markdown("""
                        ### Example questions:
                        - How do I treat early blight in tomatoes?
                        - What are the symptoms of powdery mildew?
                        - What chemical controls work on apple scab?
                        - What causes leaf curl in peach trees?
                        """)
                    with gr.Column():
                        qa_output = gr.Markdown(label="Answer")

        # Set up event handlers
        image_submit.click(
            fn=lambda img: process_image(img, model, class_names, transform, df),
            inputs=[image_input],
            outputs=[analysis_output, crop_output, disease_output, treatment_output, medicine_output],
        )
        # Answering a question also clears the image-analysis detail fields.
        question_submit.click(
            fn=lambda q: (answer_question(q, df), "", "", "", ""),
            inputs=[question_input],
            outputs=[qa_output, crop_output, disease_output, treatment_output, medicine_output],
        )

        # Example questions
        gr.Examples(
            examples=[
                ["How do I treat early blight in tomatoes?"],
                ["What are the symptoms of powdery mildew?"],
                ["What chemical controls work on apple scab?"],
                ["How can I prevent late blight in potatoes?"],
                ["What causes black spots on rose leaves?"],
            ],
            inputs=question_input,
        )

        # Example images (only those present on disk)
        if example_images:
            gr.Examples(
                examples=example_images,
                inputs=image_input,
            )

    return app
# Launch the app
if __name__ == "__main__":
    # Build the Blocks UI first, then start the Gradio server on it.
    app = main()
    app.launch()