|
|
import json
import os
import time
import traceback

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, models, transforms
from tqdm import tqdm
|
|
|
|
|
def _build_transforms():
    """Return the 'train' (augmented) and 'val' (deterministic) pipelines."""
    # ImageNet channel statistics -- shared by both pipelines.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    return {
        'train': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            normalize,
        ]),
        'val': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            normalize,
        ]),
    }


def _ensure_dummy_dataset(dataset_path, images_per_class=10):
    """Create class folders of random dummy JPEGs under *dataset_path*.

    A class folder that already contains files is left untouched, so
    repeated runs do not regenerate (and re-randomize) the data.
    """
    if not os.path.exists(dataset_path):
        os.makedirs(dataset_path, exist_ok=True)
        print(f"Created directory {dataset_path}")

    classes = ["Tomato_Early_blight", "Tomato_healthy", "Apple_scab", "Apple_healthy"]
    for cls in classes:
        cls_dir = os.path.join(dataset_path, cls)
        os.makedirs(cls_dir, exist_ok=True)
        print(f"Created class directory: {cls}")
        if os.listdir(cls_dir):
            # Already populated -- skip regeneration.
            continue
        for i in range(images_per_class):
            # Random uint8 CHW tensor -> PIL image -> JPEG on disk.
            img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
            transforms.ToPILImage()(img).save(os.path.join(cls_dir, f"image_{i}.jpg"))
        print(f"Created {images_per_class} dummy images for {cls}")


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over *loader*; return ``(mean_loss, accuracy)``.

    When *optimizer* is given the model is put in train mode and updated;
    otherwise it is evaluated under ``torch.no_grad()`` semantics.
    """
    training = optimizer is not None
    model.train(training)

    running_loss = 0.0
    running_corrects = 0
    total = 0

    for inputs, labels in tqdm(loader, desc="Training" if training else "Validation"):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Enable autograd only while training; saves memory during eval.
        with torch.set_grad_enabled(training):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # loss.item() is the batch mean -- weight by batch size so the
        # epoch average is correct even for a ragged final batch.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
        total += inputs.size(0)

    if total == 0:
        # Empty loader: report zeros instead of dividing by zero.
        return 0.0, 0.0
    return running_loss / total, running_corrects.double() / total


def train_quick_model():
    """Train a quick ResNet-18 demo classifier on a tiny (dummy) dataset.

    Builds a ``PlantVillage`` directory of random dummy images if needed,
    trains for a few epochs while printing per-epoch loss/accuracy, and
    writes ``plant_disease_model.pth`` (weights) and ``class_names.json``
    (label names). Errors are reported with a full traceback but do not
    propagate, preserving the original best-effort behaviour.
    """
    print("Starting quick model training...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    data_transforms = _build_transforms()

    print("Creating a small synthetic dataset for demonstration...")
    dataset_path = 'PlantVillage'
    _ensure_dummy_dataset(dataset_path)

    print("Loading dataset...")
    try:
        # BUGFIX: the previous version called random_split on ONE ImageFolder
        # and then assigned `.transform` on both subsets' `.dataset` -- but
        # random_split returns views of the SAME underlying object, so the
        # second ('val') assignment silently overwrote the training
        # augmentations. Use two ImageFolder views, one per transform, and
        # split them by a shared index permutation instead.
        train_base = datasets.ImageFolder(dataset_path, transform=data_transforms['train'])
        val_base = datasets.ImageFolder(dataset_path, transform=data_transforms['val'])

        train_size = int(0.8 * len(train_base))
        # Fixed seed so the train/val partition is reproducible run-to-run.
        gen = torch.Generator().manual_seed(42)
        indices = torch.randperm(len(train_base), generator=gen).tolist()
        train_dataset = Subset(train_base, indices[:train_size])
        val_dataset = Subset(val_base, indices[train_size:])

        train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

        # Persist the label order so inference code can decode predictions.
        class_names = train_base.classes
        with open('class_names.json', 'w') as f:
            json.dump(class_names, f)

        print(f"Dataset loaded with {len(class_names)} classes")
        print(f"Training set: {len(train_dataset)} images")
        print(f"Validation set: {len(val_dataset)} images")

        print("Loading pre-trained model...")
        # weights=None: randomly initialised -- adequate for a smoke-test demo.
        model = models.resnet18(weights=None)
        # Replace the ImageNet head with one sized to our class count.
        model.fc = nn.Linear(model.fc.in_features, len(class_names))
        model = model.to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001)

        num_epochs = 5
        print(f"Starting training for {num_epochs} epochs...")
        for epoch in range(num_epochs):
            print(f'Epoch {epoch+1}/{num_epochs}')
            print('-' * 10)

            train_loss, train_acc = _run_epoch(
                model, train_loader, criterion, device, optimizer=optimizer)
            print(f'Training Loss: {train_loss:.4f} Acc: {train_acc:.4f}')

            val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)
            print(f'Validation Loss: {val_loss:.4f} Acc: {val_acc:.4f}')
            print()

        torch.save(model.state_dict(), 'plant_disease_model.pth')
        print("Model saved as plant_disease_model.pth")

    except Exception as e:
        # Best-effort script: report and continue, but keep the traceback
        # instead of discarding all diagnostic context.
        print(f"Error during training: {e}")
        traceback.print_exc()
|
|
|
|
|
if __name__ == "__main__":
    # perf_counter() is monotonic and immune to system-clock adjustments,
    # making it the correct choice for measuring elapsed wall time
    # (time.time() can jump backwards/forwards with NTP corrections).
    start_time = time.perf_counter()
    train_quick_model()
    elapsed = time.perf_counter() - start_time
    print(f"Training completed in {elapsed:.2f} seconds")
|
|
|