|
|
|
|
|
import torch |
|
|
import random |
|
|
|
|
|
def classify_disease(image, model, class_names, transform):
    """Predict the crop and disease shown in *image*.

    The image is passed through ``transform``, given a leading batch
    dimension with ``unsqueeze(0)`` and fed to ``model``. The index of the
    highest-scoring output along dimension 1 selects an entry of
    ``class_names``, which is split on ``"___"`` into a crop part and a
    disease part; underscores in both parts are replaced by spaces.

    Args:
        image: Input image, or ``None`` when nothing was supplied.
            NOTE(review): the expected image type depends on ``transform``
            and is not visible here.
        model: Callable producing class scores for a batched input.
            If it has no truthy ``_is_trained`` attribute its prediction is
            discarded and a random class is used instead.
        class_names: Sequence of labels, expected in the form
            ``"<crop>___<disease>"``. A label without ``"___"`` is reported
            with the disease ``"healthy"``.
        transform: Callable converting ``image`` into an object supporting
            ``unsqueeze(0)`` (presumably a ``torch.Tensor`` - confirm
            against the caller).

    Returns:
        A ``(crop, disease)`` tuple of strings; ``(None, None)`` when
        ``image`` is ``None``; ``("Unknown", "Unknown")`` when
        ``class_names`` is empty or the chosen index is out of range.
    """
    if image is None:
        return None, None

    num_classes = len(class_names)
    if num_classes == 0:
        # Without labels nothing can be reported, and the random fallback
        # below would call random.randint(0, -1), which raises ValueError.
        return "Unknown", "Unknown"

    img_tensor = transform(image).unsqueeze(0)

    with torch.no_grad():
        try:
            outputs = model(img_tensor)
            _, predicted = torch.max(outputs, 1)
            class_idx = predicted.item()

            # A model not flagged as trained gives meaningless scores, so
            # its prediction is replaced by a random class.
            if not getattr(model, "_is_trained", False):
                print("Using random prediction since model is untrained")
                class_idx = random.randint(0, num_classes - 1)
        except Exception as e:
            # Deliberate best-effort: report the failure and still return
            # a (random) answer rather than propagating the error.
            print(f"Error during prediction: {e}")
            class_idx = random.randint(0, num_classes - 1)

    # The model may emit more classes than there are labels; also reject
    # negative indices so they cannot wrap around to a wrong label.
    if not 0 <= class_idx < num_classes:
        return "Unknown", "Unknown"

    parts = class_names[class_idx].split("___")
    crop = parts[0].replace("_", " ")
    disease = parts[1].replace("_", " ") if len(parts) > 1 else "healthy"
    return crop, disease
|
|
|