import os
import json
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, models, transforms
from tqdm import tqdm


def train_quick_model():
    """
    Trains a quick model on a small subset of data for demonstration purposes.
    This will show training epochs in the console.
    """
    print("Starting quick model training...")

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Data transformations
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    # Create a small synthetic dataset for demonstration
    print("Creating a small synthetic dataset for demonstration...")

    # Check if the PlantVillage dataset exists; if not, build a tiny synthetic one
    dataset_path = 'PlantVillage'
    if not os.path.exists(dataset_path):
        os.makedirs(dataset_path, exist_ok=True)
        print(f"Created directory {dataset_path}")

        # Create some example class directories
        classes = ["Tomato_Early_blight", "Tomato_healthy", "Apple_scab", "Apple_healthy"]
        for cls in classes:
            os.makedirs(os.path.join(dataset_path, cls), exist_ok=True)
            print(f"Created class directory: {cls}")

            # Create a few dummy 224x224 RGB noise images for each class
            for i in range(10):
                img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
                pil_img = transforms.ToPILImage()(img)
                pil_img.save(os.path.join(dataset_path, cls, f"image_{i}.jpg"))
            print(f"Created 10 dummy images for {cls}")

    # Load the dataset
    print("Loading dataset...")
    try:
        # random_split returns Subsets that share one underlying ImageFolder, so
        # setting .dataset.transform on one split would change both. Split the
        # indices once, then wrap two ImageFolder instances (one per transform).
        base_dataset = datasets.ImageFolder(dataset_path)

        train_size = int(0.8 * len(base_dataset))
        val_size = len(base_dataset) - train_size
        train_split, val_split = random_split(base_dataset, [train_size, val_size])

        train_dataset = Subset(
            datasets.ImageFolder(dataset_path, transform=data_transforms['train']),
            train_split.indices
        )
        val_dataset = Subset(
            datasets.ImageFolder(dataset_path, transform=data_transforms['val']),
            val_split.indices
        )

        # Create data loaders
        train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

        # Save class names
        class_names = base_dataset.classes
        with open('class_names.json', 'w') as f:
            json.dump(class_names, f)

        print(f"Dataset loaded with {len(class_names)} classes")
        print(f"Training set: {len(train_dataset)} images")
        print(f"Validation set: {len(val_dataset)} images")

        # Build the model (ResNet-18 is small enough for quick training).
        # weights=None gives randomly initialized weights; pass
        # weights=models.ResNet18_Weights.DEFAULT to start from ImageNet weights.
        print("Building model...")
        model = models.resnet18(weights=None)

        # Replace the final layer to match our number of classes
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, len(class_names))
        model = model.to(device)

        # Define loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001)

        # Train the model for a few epochs
        num_epochs = 5  # Just a few epochs for demonstration
        print(f"Starting training for {num_epochs} epochs...")

        for epoch in range(num_epochs):
            print(f'Epoch {epoch+1}/{num_epochs}')
            print('-' * 10)

            # Training phase
            model.train()
            running_loss = 0.0
            running_corrects = 0

            # Iterate over data
            for inputs, labels in tqdm(train_loader, desc="Training"):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward pass
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                # Backward + optimize
                loss.backward()
                optimizer.step()

                # Statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(train_dataset)
            epoch_acc = running_corrects.double() / len(train_dataset)
            print(f'Training Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Validation phase
            model.eval()
            running_loss = 0.0
            running_corrects = 0

            # Iterate over data
            for inputs, labels in tqdm(val_loader, desc="Validation"):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Forward pass (no gradients needed for evaluation)
                with torch.no_grad():
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                # Statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(val_dataset)
            epoch_acc = running_corrects.double() / len(val_dataset)
            print(f'Validation Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            print()

        # Save the model
        torch.save(model.state_dict(), 'plant_disease_model.pth')
        print("Model saved as plant_disease_model.pth")

    except Exception as e:
        print(f"Error during training: {e}")


if __name__ == "__main__":
    start_time = time.time()
    train_quick_model()
    end_time = time.time()
    print(f"Training completed in {end_time - start_time:.2f} seconds")
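

# ---------------------------------------------------------------------------
# Illustrative inference sketch, not part of the training run above. It only
# assumes the two files this script writes (plant_disease_model.pth and
# class_names.json) already exist; the image path argument is hypothetical.
# ---------------------------------------------------------------------------
def predict_image(image_path):
    """Loads the saved checkpoint and returns the predicted class name."""
    from PIL import Image

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with open('class_names.json') as f:
        class_names = json.load(f)

    # Rebuild the same architecture used for training, then load the weights
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, len(class_names))
    model.load_state_dict(torch.load('plant_disease_model.pth', map_location=device))
    model = model.to(device)
    model.eval()

    # Same preprocessing as the validation transform above
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    image = Image.open(image_path).convert('RGB')
    inputs = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(inputs)
        _, pred = torch.max(outputs, 1)

    return class_names[pred.item()]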