File size: 10,464 Bytes
9bb53cb 95617f0 9bb53cb 61aeb8a b78ac6b ddadd74 61aeb8a 9bb53cb 61aeb8a b78ac6b 61aeb8a de713cb 61aeb8a de713cb 61aeb8a de713cb 61aeb8a de713cb 53cb889 de713cb 61aeb8a de713cb 61aeb8a de713cb 61aeb8a de713cb 61aeb8a de713cb 61aeb8a de713cb 61aeb8a de713cb 61aeb8a de713cb 61aeb8a 6ad7bf5 5ec35b6 6ad7bf5 ddadd74 53cb889 ddadd74 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 |
# ===================== IMPORTS =====================
import difflib
import json
import os
import pickle
import tempfile

import streamlit as st
import torch
import numpy as np
from PIL import Image
from tf_keras.models import load_model # Use tf.keras for legacy compatibility
from tf_keras.preprocessing.image import img_to_array, load_img # Use tf.keras
from transformers import CLIPProcessor, CLIPModel
from sentence_transformers import SentenceTransformer, CrossEncoder
from langdetect import detect
# ===================== PATHS =====================
# Directory holding the trained artifacts (assumed to be "Main_py" inside the Space).
save_dir = "Main_py"

# CNN weights, pickled label encoder and JSON knowledge base, in that order.
model_path, label_path, json_path = (
    os.path.join(save_dir, artifact_name)
    for artifact_name in (
        "best_cnn_model_finetuned.keras",
        "label_encoder.pkl",
        "banana_disease_knowledge_base_updated_shuffled.json",
    )
)
# ===================== LOAD MODELS & DATA =====================
@st.cache_resource
def load_cnn_clip_kb():
    """Load the CNN classifier, label encoder, knowledge base and CLIP pair.

    Returns:
        A 5-tuple ``(model, le, kb_data, clip_model, clip_processor)``. If any
        step fails, the error is reported through ``st.error`` and a tuple of
        five ``None`` values is returned instead.
    """
    clip_checkpoint = "openai/clip-vit-base-patch32"
    try:
        cnn = load_model(model_path)
        # NOTE: pickle executes code embedded in the file while loading; only
        # ship a label encoder that comes from a trusted source.
        with open(label_path, "rb") as encoder_file:
            label_encoder = pickle.load(encoder_file)
        with open(json_path, "r", encoding="utf-8") as kb_file:
            knowledge_base = json.load(kb_file)
        clip = CLIPModel.from_pretrained(clip_checkpoint)
        processor = CLIPProcessor.from_pretrained(clip_checkpoint)
    except Exception as e:
        # NOTE(review): the function is wrapped in st.cache_resource, so this
        # failure tuple is presumably cached as well -- confirm.
        st.error(f"Error loading models or data: {str(e)}")
        return None, None, None, None, None
    else:
        return cnn, label_encoder, knowledge_base, clip, processor
@st.cache_resource
def load_nlp_models():
    """Load the sentence embedder and the cross-encoder used for text queries.

    Returns:
        ``(embedder, cross_encoder)`` on success. On any failure the error is
        reported through ``st.error`` and ``(None, None)`` is returned.
    """
    try:
        sentence_embedder = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")
        reranker = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1")
    except Exception as e:
        st.error(f"Error loading NLP models: {str(e)}")
        return None, None
    else:
        return sentence_embedder, reranker
# Load every model and the knowledge base once (both loaders are cached).
model, le, kb_data, clip_model, clip_processor = load_cnn_clip_kb()
embedder, cross_encoder = load_nlp_models()

# Abort the script run if any vision-side resource is missing.
if any(resource is None for resource in (model, le, kb_data, clip_model, clip_processor)):
    st.error("Failed to load CNN/CLIP models or knowledge base. Please check file paths and model files.")
    st.stop()

# Same guard for the text-side models.
if any(resource is None for resource in (embedder, cross_encoder)):
    st.error("Failed to load NLP models. Check dependencies.")
    st.stop()
# ===================== CLIP FILTER =====================
def verify_image_with_clip(image_path, rejection_factor=3.0):
    """Decide with zero-shot CLIP scoring whether an image shows a banana leaf.

    The image is scored against a "banana leaf" prompt and a "not a banana
    leaf" prompt; the two logits are turned into probabilities with a softmax.

    Args:
        image_path: Location of the image, passed to ``PIL.Image.open``.
        rejection_factor: The banana-leaf probability must be at least this
            many times the not-banana-leaf probability for the image to be
            accepted. Defaults to 3.0.

    Returns:
        A ``(status, reason, score)`` tuple. ``status`` is ``'ACCEPTED'`` or
        ``'REJECTED'``; ``score`` is the probability (as a plain ``float``) of
        the prompt that decided the outcome, or ``0.0`` when the image could
        not be opened.
    """
    prompts = ["a photo of a banana leaf", "a photo of something that is not a banana leaf"]
    try:
        # Close the file handle as soon as the pixels are copied: convert()
        # returns an independent in-memory image, and a lingering open handle
        # can prevent the caller from deleting the file afterwards.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
    except Exception as e:
        return 'REJECTED', f'Invalid image file: {e}', 0.0

    inputs = clip_processor(text=prompts, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = clip_model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1).cpu().numpy()[0]
    # Plain floats keep the return value independent of NumPy scalar types.
    banana_score, not_banana_score = float(probs[0]), float(probs[1])

    if banana_score >= not_banana_score * rejection_factor:
        return 'ACCEPTED', 'banana leaf', banana_score
    return 'REJECTED', 'Not a banana leaf', not_banana_score
# ===================== CNN PREDICTION =====================
def predict_disease(image_path, target_size=(224, 224)):
    """Classify the disease shown in an image with the CNN model.

    Args:
        image_path: Location of the image to classify.
        target_size: Size the image is loaded at before prediction.

    Returns:
        ``(label, confidence, image)`` where ``label`` is the decoded class
        name, ``confidence`` the model output for that class and ``image`` the
        loaded image. On any failure the error is shown with ``st.error`` and
        ``(None, None, None)`` is returned.
    """
    try:
        image = load_img(image_path, target_size=target_size)
        # Scale pixel values to [0, 1] and add a leading batch axis.
        batch = np.expand_dims(img_to_array(image) / 255.0, axis=0)
        class_scores = model.predict(batch)[0]
        best_class = np.argmax(class_scores)
        label = le.inverse_transform([best_class])[0]
        return label, class_scores[best_class], image
    except Exception as e:
        st.error(f"Error predicting disease: {str(e)}")
        return None, None, None
# ===================== FUZZY MARATHI OUTPUT =====================
def match_disease_name_fuzzy(predicted_name, cutoff=0.5):
    """Return the knowledge-base entry whose "Disease" name best matches a label.

    Names are compared case-insensitively with surrounding whitespace removed,
    using ``difflib.get_close_matches``.

    Args:
        predicted_name: Label to look up (typically the CNN class name).
        cutoff: Minimum similarity ratio in [0, 1] a name must reach to count
            as a match. Defaults to 0.5.

    Returns:
        The first entry of ``kb_data`` carrying the best-matching name, or
        ``None`` when the label is empty/``None`` or nothing is close enough.
    """
    if predicted_name is None:
        return None
    wanted = str(predicted_name).strip().lower()
    if not wanted:
        return None

    # Normalised name -> first entry carrying it. Entries without a usable
    # "Disease" field are skipped instead of raising KeyError.
    entries_by_name = {}
    for entry in kb_data:
        if not isinstance(entry, dict):
            continue
        name = entry.get("Disease")
        if isinstance(name, str):
            entries_by_name.setdefault(name.strip().lower(), entry)

    matches = difflib.get_close_matches(wanted, list(entries_by_name), n=1, cutoff=cutoff)
    return entries_by_name[matches[0]] if matches else None
def get_marathi_recommendation_fuzzy(predicted_disease, confidence=None):
    """Build the Marathi recommendation card for a predicted disease.

    Args:
        predicted_disease: Disease label to look up in the knowledge base via
            ``match_disease_name_fuzzy``.
        confidence: Optional prediction confidence as a fraction; rendered as
            a percentage. ``None`` is rendered as ``"N/A"``.

    Returns:
        A dict mapping Marathi field labels to the entry's values, or ``None``
        when no knowledge-base entry matches the label.
    """
    entry = match_disease_name_fuzzy(predicted_disease)
    if not entry:
        return None

    # `or {}` also covers fields that are present but set to null in the JSON.
    local_names = entry.get("Local_Name") or {}
    pesticide_advice = entry.get("Pesticide_Recommendation") or {}
    # Compare against None explicitly: a confidence of 0.0 is a real value and
    # must not be reported as missing.
    confidence_text = "N/A" if confidence is None else f"{confidence:.2%}"

    return {
        "पिक": entry.get("Crop", "केळी"),
        "रोग": local_names.get("mr", predicted_disease),
        "लक्षणे": entry.get("Symptoms_MR", ""),
        "कारण": entry.get("Cause_MR", ""),
        "किटकनाशके": entry.get("Pesticide_MR", ""),
        "किटकनाशक शिफारस": pesticide_advice.get("mr", ""),
        "नियंत्रण पद्धती": entry.get("Management_MR", ""),
        "रोगजन्य घटक": entry.get("Pathogen", ""),
        "विश्वासार्हता": confidence_text,
    }
# ===================== NLP PREDICTION =====================
def detect_language(query: str) -> str:
    """Classify a query as Marathi, Hindi or (by default) English.

    Args:
        query: Free-text user input.

    Returns:
        ``"mr"`` or ``"hi"`` when the detector reports one of them, otherwise
        ``"en"``. Blank input and detection failures also yield ``"en"``.
    """
    # Nothing to detect on empty / whitespace-only input.
    if not query or not query.strip():
        return "en"
    try:
        lang = detect(query)
    except Exception:
        # Detection is best-effort: fall back to English on any detector error,
        # but (unlike a bare `except:`) let KeyboardInterrupt/SystemExit through.
        return "en"
    return lang if lang in ("mr", "hi") else "en"
def predict_disease_from_text(query: str, min_score: float = 0.2):
    """Find the knowledge-base disease whose symptoms best match a description.

    The query language is detected first; every entry's symptom text in that
    language is then scored against the query with the cross-encoder and the
    highest-scoring entry is reported.

    Args:
        query: Symptom description typed by the user.
        min_score: Lowest cross-encoder score accepted as a match. Defaults
            to 0.2.

    Returns:
        A dict of Marathi field labels to values for the best entry, or a dict
        with a single ``"message"`` key when nothing scores high enough, the
        knowledge base is empty, or an error occurs (the error is also shown
        with ``st.error``).
    """
    not_found_messages = {
        "mr": "हा रोग आमच्या डेटाबेसमध्ये नाही.",
        "hi": "यह रोग हमारे डेटाबेस में नहीं है।",
        "en": "This disease is not in our database.",
    }
    try:
        lang = detect_language(query)
        lang_suffix = lang.upper()
        symptom_key = "Symptoms" if lang == "en" else f"Symptoms_{lang_suffix}"

        # Only the cross-encoder scores are used for ranking, so no separate
        # query embedding is computed here.
        pairs = [[query, entry.get(symptom_key, "")] for entry in kb_data]
        if not pairs:
            # An empty knowledge base cannot contain the disease.
            return {"message": not_found_messages[lang]}

        scores = cross_encoder.predict(pairs)
        best_idx = int(np.argmax(scores))
        if scores[best_idx] < min_score:
            return {"message": not_found_messages[lang]}

        entry = kb_data[best_idx]
        # `or {}` also covers fields that are present but set to null.
        local_names = entry.get("Local_Name") or {}
        pesticide_advice = entry.get("Pesticide_Recommendation") or {}
        return {
            "पिक": entry.get("Crop", "केळी"),
            "रोग": local_names.get(lang, entry.get("Disease", "")),
            "लक्षणे": entry.get(symptom_key, ""),
            "कारण": entry.get(f"Cause_{lang_suffix}", entry.get("Cause", "")),
            "किटकनाशक शिफारस": pesticide_advice.get(lang, ""),
            # Prefer the localized pesticide text, like Cause/Management do,
            # and fall back to the generic field when it is absent.
            "किटकनाशके": entry.get(f"Pesticide_{lang_suffix}", entry.get("Pesticide", "")),
            "रोगजन्य घटक": entry.get("Pathogen", ""),
            "नियंत्रण पद्धती": entry.get(f"Management_{lang_suffix}", entry.get("Management_Practices", "")),
        }
    except Exception as e:
        st.error(f"Error in text prediction: {str(e)}")
        return {"message": "Error processing text input."}
# ===================== STREAMLIT UI =====================
# Page configuration, Marathi title and short description of the app.
# NOTE(review): Streamlit documents set_page_config as a call that should come
# before other Streamlit commands; here it runs after the model-loading code
# earlier in the file, which can already call st.error -- confirm this is
# accepted by the deployed Streamlit version.
st.set_page_config(page_title="🍌 Banana Disease Detection (CNN + NLP)", layout="centered")
st.title("केळीच्या पानांवरील रोगांचे निदान")
st.markdown("प्रतिमा किंवा लक्षणे वापरून केळीवरील रोगांचे निदान करा (मराठी, हिंदी, इंग्रजी भाषांमध्ये).")
# The selected input mode decides which of the image / text sections are shown.
option = st.radio("इनपुट पद्धत निवडा:", ["Image Only", "Text Only", "Both"])
# ===================== IMAGE FLOW =====================
# Upload an image, gate it through the CLIP banana-leaf check, then classify it
# with the CNN and show the Marathi recommendation for the predicted disease.
if option in ["Image Only", "Both"]:
    st.subheader("प्रतिमा अपलोड करा")
    uploaded_img = st.file_uploader("JPG / PNG / JPEG", type=["png", "jpg", "jpeg"])
    if uploaded_img:
        # Give every upload its own temporary file: a fixed file name would be
        # shared by all sessions, so concurrent users could overwrite or delete
        # each other's image in the middle of a prediction.
        upload_suffix = os.path.splitext(uploaded_img.name)[1] or ".jpg"
        temp_fd, temp_path = tempfile.mkstemp(prefix="banana_upload_", suffix=upload_suffix)
        try:
            with os.fdopen(temp_fd, "wb") as f:
                f.write(uploaded_img.getbuffer())
            st.info("CLIP मॉडेलद्वारे पडताळणी करत आहे...")
            status, reason, clip_conf = verify_image_with_clip(temp_path)
            if status == "REJECTED":
                st.error(f"CLIP नकार: {reason} [विश्वासार्हता: {clip_conf:.2f}]")
            else:
                st.success(f"CLIP मंजूरी: शक्यतो केळीचे पान [विश्वासार्हता: {clip_conf:.2f}]")
                pred_disease, cnn_conf, img = predict_disease(temp_path)
                # predict_disease signals failure with None; compare explicitly
                # so a falsy but valid label (e.g. 0) is not treated as failure.
                if pred_disease is not None:
                    st.markdown(f"**ओळखलेला रोग:** {pred_disease} (विश्वासार्हता: {cnn_conf:.2%})")
                    marathi_info = get_marathi_recommendation_fuzzy(predicted_disease=pred_disease, confidence=cnn_conf)
                    if marathi_info:
                        st.subheader("मराठी शिफारस:")
                        for k, v in marathi_info.items():
                            st.markdown(f"**{k}**: {v}")
                    else:
                        st.warning("ज्ञानतळात रोगासाठी माहिती नाही.")
                    st.image(img, caption=f"अपलोड केलेली प्रतिमा - {pred_disease} ({cnn_conf:.2%})", width=300)
                else:
                    st.error("Failed to predict disease from image.")
        except Exception as e:
            st.error(f"Error processing image: {str(e)}")
        finally:
            # Always remove the temporary copy, whatever happened above.
            if os.path.exists(temp_path):
                os.remove(temp_path)
# ===================== TEXT FLOW =====================
# Free-text symptom description -> best matching knowledge-base entry.
if option in ("Text Only", "Both"):
    st.subheader("लक्षणे लिहा")
    symptoms = st.text_area("लक्षणे (मराठी / हिंदी / इंग्रजी):")
    # The button is only rendered (and evaluated) once some text was entered.
    if symptoms:
        if st.button("रोग ओळखा"):
            result = predict_disease_from_text(symptoms)
            if "message" not in result:
                st.subheader("शिफारस:")
                for field_label, field_value in result.items():
                    st.markdown(f"**{field_label}**: {field_value}")
            else:
                # A lone "message" key means no match or an error.
                st.warning(result["message"])