File size: 10,464 Bytes
9bb53cb
95617f0
9bb53cb
61aeb8a
 
 
 
 
 
b78ac6b
ddadd74
61aeb8a
 
 
9bb53cb
61aeb8a
b78ac6b
 
 
 
61aeb8a
 
 
 
de713cb
 
 
 
 
 
 
 
 
 
 
 
61aeb8a
 
 
de713cb
 
 
 
 
 
 
61aeb8a
de713cb
61aeb8a
 
 
de713cb
 
 
 
 
53cb889
de713cb
 
61aeb8a
 
 
 
 
 
de713cb
61aeb8a
 
 
 
 
 
 
 
 
 
de713cb
61aeb8a
de713cb
61aeb8a
 
 
de713cb
 
 
 
 
 
 
 
 
 
 
 
61aeb8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de713cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61aeb8a
de713cb
 
 
 
 
 
 
 
61aeb8a
de713cb
 
 
61aeb8a
 
6ad7bf5
 
 
 
 
 
 
 
 
5ec35b6
6ad7bf5
 
 
 
 
 
 
 
ddadd74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53cb889
ddadd74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
# ===================== IMPORTS =====================
import difflib
import json
import os
import pickle
import tempfile

import streamlit as st
import torch
import numpy as np
from PIL import Image
from tf_keras.models import load_model  # Use tf.keras for legacy compatibility
from tf_keras.preprocessing.image import img_to_array, load_img  # Use tf.keras
from transformers import CLIPProcessor, CLIPModel
from sentence_transformers import SentenceTransformer, CrossEncoder
from langdetect import detect

# ===================== PATHS =====================
# Directory that holds the model artifacts and the knowledge base
# ("Main_py" is the location assumed inside the Space).
save_dir = "Main_py"
model_path, label_path, json_path = [
    os.path.join(save_dir, artifact_name)
    for artifact_name in (
        "best_cnn_model_finetuned.keras",  # fine-tuned CNN classifier
        "label_encoder.pkl",  # pickled label encoder for the CNN classes
        "banana_disease_knowledge_base_updated_shuffled.json",  # disease knowledge base
    )
]

# ===================== LOAD MODELS & DATA =====================
@st.cache_resource
def _load_cnn_clip_kb_or_raise():
    """Load the CNN, label encoder, knowledge base and CLIP; raise on failure.

    Kept separate from ``load_cnn_clip_kb`` so that only a *successful* load
    is stored by ``st.cache_resource``. An exception propagates without
    storing a value, so the next script run retries the load.
    """
    cnn_model = load_model(model_path)
    # NOTE(security): pickle.load can execute arbitrary code while
    # deserializing; label_path must only ever point at a trusted file.
    with open(label_path, "rb") as f:
        label_encoder = pickle.load(f)
    with open(json_path, "r", encoding="utf-8") as f:
        knowledge_base = json.load(f)
    clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    return cnn_model, label_encoder, knowledge_base, clip, processor


def load_cnn_clip_kb():
    """Return ``(model, le, kb_data, clip_model, clip_processor)``.

    On any loading error the error is shown with ``st.error`` and a tuple of
    five ``None`` values is returned, so callers can test for ``None``.
    The failure itself is not cached: previously the ``try/except`` lived
    inside the cached function, so a transient error (for example a failed
    model download) froze the ``None`` result until the process restarted.
    """
    try:
        return _load_cnn_clip_kb_or_raise()
    except Exception as e:
        st.error(f"Error loading models or data: {str(e)}")
        return None, None, None, None, None

@st.cache_resource
def _load_nlp_models_or_raise():
    """Load the sentence embedder and the cross-encoder; raise on failure.

    Kept separate from ``load_nlp_models`` so that only a *successful* load
    is stored by ``st.cache_resource``. An exception propagates without
    storing a value, so the next script run retries the load.
    """
    sentence_embedder = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")
    reranker = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1")
    return sentence_embedder, reranker


def load_nlp_models():
    """Return ``(embedder, cross_encoder)``.

    On any loading error the error is shown with ``st.error`` and
    ``(None, None)`` is returned, so callers can test for ``None``.
    The failure itself is not cached: previously the ``try/except`` lived
    inside the cached function, so a transient error froze the ``None``
    result until the process restarted.
    """
    try:
        return _load_nlp_models_or_raise()
    except Exception as e:
        st.error(f"Error loading NLP models: {str(e)}")
        return None, None

# Load every model and the knowledge base before any UI is rendered.
model, le, kb_data, clip_model, clip_processor = load_cnn_clip_kb()
embedder, cross_encoder = load_nlp_models()

# Stop this script run early if any vision artifact failed to load.
if any(
    artifact is None
    for artifact in (model, le, kb_data, clip_model, clip_processor)
):
    st.error("Failed to load CNN/CLIP models or knowledge base. Please check file paths and model files.")
    st.stop()

# Same guard for the NLP models used by the text flow.
if not (embedder is not None and cross_encoder is not None):
    st.error("Failed to load NLP models. Check dependencies.")
    st.stop()

# ===================== CLIP FILTER =====================
def verify_image_with_clip(image_path, rejection_factor=3.0):
    """Use CLIP zero-shot scoring to decide whether an image is a banana leaf.

    The image is scored against two text prompts ("banana leaf" and "not a
    banana leaf"); the softmax over the two image-text logits gives the two
    scores.

    Args:
        image_path: Path (or file-like object) accepted by ``Image.open``.
        rejection_factor: The banana-leaf score must be at least this many
            times the not-banana-leaf score for the image to be accepted.
            Defaults to 3.0, the value that was previously hard-coded.

    Returns:
        A ``(status, reason, score)`` tuple. ``status`` is ``'ACCEPTED'`` or
        ``'REJECTED'``. ``score`` is a plain ``float``: the banana-leaf score
        when accepted, the not-banana-leaf score when rejected, and ``0.0``
        when the file could not be opened as an image.
    """
    prompts = ["a photo of a banana leaf", "a photo of something that is not a banana leaf"]
    try:
        # The context manager releases the file handle even if decoding
        # fails; the RGB copy made by convert() stays valid after closing.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
    except Exception as e:
        return 'REJECTED', f'Invalid image file: {e}', 0.0

    inputs = clip_processor(text=prompts, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():  # inference only, no gradients needed
        outputs = clip_model(**inputs)
        probs = outputs.logits_per_image.softmax(dim=1).cpu().numpy()[0]

    # Convert NumPy scalars to plain floats so callers get ordinary numbers.
    banana_score, not_banana_score = float(probs[0]), float(probs[1])

    if banana_score >= not_banana_score * rejection_factor:
        return 'ACCEPTED', 'banana leaf', banana_score
    return 'REJECTED', 'Not a banana leaf', not_banana_score

# ===================== CNN PREDICTION =====================
def predict_disease(image_path, target_size=(224, 224)):
    """Classify a leaf image with the CNN.

    The image is loaded at ``target_size``, scaled by 1/255 and passed to the
    model as a batch of one.

    Returns:
        ``(label, confidence, image)`` where ``label`` is the decoded class
        name, ``confidence`` is the model output for that class and ``image``
        is the loaded image. On any error the error is shown with
        ``st.error`` and ``(None, None, None)`` is returned.
    """
    try:
        pil_image = load_img(image_path, target_size=target_size)
        # Scale pixel values and add the leading batch axis in one step.
        batch = np.expand_dims(img_to_array(pil_image) / 255.0, axis=0)
        class_scores = model.predict(batch)[0]
        best_class = np.argmax(class_scores)
        disease_label = le.inverse_transform([best_class])[0]
        return disease_label, class_scores[best_class], pil_image
    except Exception as exc:
        st.error(f"Error predicting disease: {str(exc)}")
        return None, None, None

# ===================== FUZZY MARATHI OUTPUT =====================
def match_disease_name_fuzzy(predicted_name):
    """Return the knowledge-base entry whose name best matches ``predicted_name``.

    Names are compared after stripping whitespace and lower-casing, using
    ``difflib.get_close_matches`` with a 0.5 cutoff. Returns the first
    matching entry, or ``None`` when nothing is close enough.
    """
    def _normalize(name):
        return name.strip().lower()

    known_names = [_normalize(record["Disease"]) for record in kb_data]
    closest = difflib.get_close_matches(
        _normalize(predicted_name), known_names, n=1, cutoff=0.5
    )
    if not closest:
        return None

    best_name = closest[0]
    # First entry wins when several entries normalize to the same name.
    return next(
        (record for record in kb_data if _normalize(record["Disease"]) == best_name),
        None,
    )

def get_marathi_recommendation_fuzzy(predicted_disease, confidence=None):
    """Build the Marathi recommendation card for a predicted disease.

    Args:
        predicted_disease: Disease name to look up (fuzzy-matched against the
            knowledge base via ``match_disease_name_fuzzy``).
        confidence: Optional prediction confidence; formatted as a
            percentage. ``None`` is rendered as ``"N/A"``.

    Returns:
        A dict mapping Marathi field labels to display values, or ``None``
        when no knowledge-base entry matches.
    """
    entry = match_disease_name_fuzzy(predicted_disease)
    if not entry:
        return None

    # Test against None explicitly: a confidence of 0.0 is a real value and
    # must be shown as "0.00%", not as "N/A".
    confidence_text = f"{confidence:.2%}" if confidence is not None else "N/A"
    return {
        "पिक": entry.get("Crop", "केळी"),
        "रोग": entry.get("Local_Name", {}).get("mr", predicted_disease),
        "लक्षणे": entry.get("Symptoms_MR", ""),
        "कारण": entry.get("Cause_MR", ""),
        "किटकनाशके": entry.get("Pesticide_MR", ""),
        "किटकनाशक शिफारस": entry.get("Pesticide_Recommendation", {}).get("mr", ""),
        "नियंत्रण पद्धती": entry.get("Management_MR", ""),
        "रोगजन्य घटक": entry.get("Pathogen", ""),
        "विश्वासार्हता": confidence_text,
    }

# ===================== NLP PREDICTION =====================
def detect_language(query: str) -> str:
    """Return the language code used to pick localized knowledge-base text.

    Returns ``"mr"`` or ``"hi"`` when the detector reports Marathi or Hindi;
    every other outcome (another language, blank or non-string input, or a
    detector error) falls back to ``"en"``.
    """
    # Nothing to detect in blank or non-text input; skip the detector.
    if not isinstance(query, str) or not query.strip():
        return "en"
    try:
        lang = detect(query)
    except Exception:
        # Best-effort fallback. Unlike a bare ``except:``, this does not
        # swallow KeyboardInterrupt or SystemExit.
        return "en"
    return lang if lang in ("mr", "hi") else "en"

def predict_disease_from_text(query: str):
    """Pick the knowledge-base entry whose symptoms best match ``query``.

    The query language is detected, every entry's symptom text is scored
    against the query with the cross-encoder, and the highest-scoring entry
    is returned if its score reaches 0.2.

    Returns:
        A dict of Marathi field labels to display values for the best entry,
        or ``{"message": ...}`` when nothing scores high enough, the
        knowledge base is empty, or an error occurs (the error is also shown
        with ``st.error``).
    """
    not_found_messages = {
        "mr": "हा रोग आमच्या डेटाबेसमध्ये नाही.",
        "hi": "यह रोग हमारे डेटाबेस में नहीं है।",
        "en": "This disease is not in our database."
    }
    try:
        lang = detect_language(query)
        lang_suffix = lang.upper()
        symptom_key = f"Symptoms_{lang_suffix}" if lang != "en" else "Symptoms"

        def symptom_text(entry):
            # Fall back to the English symptoms when an entry has no text in
            # the detected language, so it is not scored against "".
            return entry.get(symptom_key) or entry.get("Symptoms", "")

        # An empty knowledge base has nothing to rank (argmax would raise).
        if not kb_data:
            return {"message": not_found_messages[lang]}

        # The previously computed sentence embedding of the query was never
        # used, so it is no longer computed; ranking uses the cross-encoder.
        pairs = [[query, symptom_text(entry)] for entry in kb_data]
        scores = cross_encoder.predict(pairs)
        best_idx = int(np.argmax(scores))

        if scores[best_idx] < 0.2:
            return {"message": not_found_messages[lang]}

        entry = kb_data[best_idx]
        # Tolerate entries without "Local_Name" / "Disease" instead of
        # raising KeyError and discarding an otherwise valid match.
        local_names = entry.get("Local_Name") or {}
        return {
            "पिक": entry.get("Crop", "केळी"),
            "रोग": local_names.get(lang, entry.get("Disease", "")),
            "लक्षणे": symptom_text(entry),
            "कारण": entry.get(f"Cause_{lang_suffix}", entry.get("Cause", "")),
            "किटकनाशक शिफारस": entry.get("Pesticide_Recommendation", {}).get(lang, ""),
            # Prefer the localized pesticide text, like the other fields.
            "किटकनाशके": entry.get(f"Pesticide_{lang_suffix}", entry.get("Pesticide", "")),
            "रोगजन्य घटक": entry.get("Pathogen", ""),
            "नियंत्रण पद्धती": entry.get(f"Management_{lang_suffix}", entry.get("Management_Practices", "")),
        }
    except Exception as e:
        st.error(f"Error in text prediction: {str(e)}")
        return {"message": "Error processing text input."}

# ===================== STREAMLIT UI =====================
# Page chrome: browser-tab title, centered layout, heading and intro text.
# NOTE(review): Streamlit documents set_page_config as something to call
# before other Streamlit commands; here it runs after the model-loading
# calls earlier in the file. Confirm this is fine on the deployed version.
st.set_page_config(
    page_title="🍌 Banana Disease Detection (CNN + NLP)",
    layout="centered",
)
st.title("केळीच्या पानांवरील रोगांचे निदान")
st.markdown(
    "प्रतिमा किंवा लक्षणे वापरून केळीवरील रोगांचे निदान करा (मराठी, हिंदी, इंग्रजी भाषांमध्ये)."
)

# Input-mode selector; the image and text flows below branch on this value.
option = st.radio(
    "इनपुट पद्धत निवडा:",
    options=["Image Only", "Text Only", "Both"],
)

# ===================== IMAGE FLOW =====================
# Image flow: upload -> CLIP banana-leaf check -> CNN prediction -> Marathi
# recommendation from the knowledge base.
if option in ["Image Only", "Both"]:
    st.subheader("प्रतिमा अपलोड करा")
    uploaded_img = st.file_uploader("JPG / PNG / JPEG", type=["png", "jpg", "jpeg"])
    if uploaded_img:
        # Use a unique temporary file per upload. A fixed filename would be
        # shared by every concurrent session, so one session could overwrite
        # or delete the image another session is still processing.
        temp_path = None
        try:
            upload_suffix = os.path.splitext(uploaded_img.name)[1] or ".jpg"
            with tempfile.NamedTemporaryFile(delete=False, suffix=upload_suffix) as tmp_file:
                tmp_file.write(uploaded_img.getbuffer())
                temp_path = tmp_file.name

            st.info("CLIP मॉडेलद्वारे पडताळणी करत आहे...")

            status, reason, clip_conf = verify_image_with_clip(temp_path)
            if status == "REJECTED":
                st.error(f"CLIP नकार: {reason} [विश्वासार्हता: {clip_conf:.2f}]")
            else:
                st.success(f"CLIP मंजूरी: शक्यतो केळीचे पान [विश्वासार्हता: {clip_conf:.2f}]")
                pred_disease, cnn_conf, img = predict_disease(temp_path)
                if pred_disease:
                    st.markdown(f"**ओळखलेला रोग:** {pred_disease} (विश्वासार्हता: {cnn_conf:.2%})")

                    marathi_info = get_marathi_recommendation_fuzzy(predicted_disease=pred_disease, confidence=cnn_conf)
                    if marathi_info:
                        st.subheader("मराठी शिफारस:")
                        for k, v in marathi_info.items():
                            st.markdown(f"**{k}**: {v}")
                    else:
                        st.warning("ज्ञानतळात रोगासाठी माहिती नाही.")
                    st.image(img, caption=f"अपलोड केलेली प्रतिमा - {pred_disease} ({cnn_conf:.2%})", width=300)
                else:
                    st.error("Failed to predict disease from image.")
        except Exception as e:
            st.error(f"Error processing image: {str(e)}")
        finally:
            # Always remove the temporary copy, whether or not processing
            # succeeded; temp_path is None if the file was never created.
            if temp_path and os.path.exists(temp_path):
                os.remove(temp_path)

# ===================== TEXT FLOW =====================
# Text flow: free-text symptoms -> best-matching knowledge-base entry.
if option in ("Text Only", "Both"):
    st.subheader("लक्षणे लिहा")
    symptoms = st.text_area("लक्षणे (मराठी / हिंदी / इंग्रजी):")
    # Short-circuit `and`: the button is only rendered once text is entered.
    if symptoms and st.button("रोग ओळखा"):
        result = predict_disease_from_text(symptoms)
        if "message" not in result:
            # A matched entry: show each labelled field on its own line.
            st.subheader("शिफारस:")
            for field_label, field_value in result.items():
                st.markdown(f"**{field_label}**: {field_value}")
        else:
            # No match or an error: the result carries only a message.
            st.warning(result["message"])