import json
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from tqdm import tqdm

# Local root folder that must contain `train/` and `val/` sub-folders, one
# directory per class. The script previously pointed at the Hugging Face rows
# API URL for `ButterChicken98/plantvillage-image-text-pairs`; ImageFolder can
# only read local directories, so that dataset has to be exported to disk in
# this layout first.
DEFAULT_DATA_DIR = "PlantVillage"

# Output artefacts written to the current working directory.
CLASS_NAMES_PATH = 'class_names.json'
CHECKPOINT_PATH = 'plant_disease_model.pth'

# Per-channel mean/std passed to Normalize for both splits (the standard
# ImageNet statistics documented for torchvision's pre-trained models).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Order matters: each epoch trains first, then validates.
_PHASES = ('train', 'val')


def _build_transforms():
    """Return the preprocessing pipelines keyed by 'train' and 'val'."""
    normalize = transforms.Normalize(_NORM_MEAN, _NORM_STD)
    return {
        # Training uses random augmentation.
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
            transforms.ToTensor(),
            normalize,
        ]),
        # Validation is deterministic: resize then centre crop.
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),
    }


def _load_datasets(data_dir, data_transforms):
    """Build one ImageFolder per phase from `<data_dir>/<phase>`."""
    return {
        phase: datasets.ImageFolder(os.path.join(data_dir, phase), data_transforms[phase])
        for phase in _PHASES
    }


def _print_dataset_layout_help():
    """Print the folder structure the dataset directory is expected to have."""
    print("Please make sure the dataset is properly organized in the following structure:")
    print("PlantVillage/")
    print("├── train/")
    print("│ ├── Apple___Apple_scab/")
    print("│ ├── Apple___Black_rot/")
    print("│ └── ... (other classes)")
    print("└── val/")
    print(" ├── Apple___Apple_scab/")
    print(" ├── Apple___Black_rot/")
    print(" └── ... (other classes)")


def _run_phase(model, loader, criterion, device, optimizer=None):
    """Run one full pass over `loader`.

    When `optimizer` is given the model is put in training mode and updated
    after every batch; otherwise the model is put in eval mode and gradients
    are disabled.

    Returns:
        A `(loss_sum, num_correct)` tuple: the batch losses weighted by batch
        size and summed, and the number of correctly classified samples.
    """
    is_training = optimizer is not None
    model.train(is_training)  # train(False) is equivalent to eval()

    loss_sum = 0.0
    num_correct = 0
    for inputs, labels in tqdm(loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        if is_training:
            # Gradients only exist (and need clearing) while training.
            optimizer.zero_grad()

        with torch.set_grad_enabled(is_training):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            if is_training:
                loss.backward()
                optimizer.step()

        # criterion returns the batch mean, so weight by batch size to be able
        # to compute an exact per-sample average at the end of the epoch.
        loss_sum += loss.item() * inputs.size(0)
        num_correct += (preds == labels).sum().item()

    return loss_sum, num_correct


def train_model(data_dir=DEFAULT_DATA_DIR, num_epochs=15, batch_size=32, num_workers=4):
    """Fine-tune a pre-trained ResNet-50 on an ImageFolder-style dataset.

    Writes the class names to `class_names.json` and the weights of the epoch
    with the best validation accuracy to `plant_disease_model.pth`, both in the
    current working directory.

    Args:
        data_dir: Local directory containing `train/` and `val/` sub-folders,
            each with one sub-directory per class.
        num_epochs: Number of train+validate epochs to run.
        batch_size: Batch size used for both data loaders.
        num_workers: Worker processes used by each data loader.

    Raises:
        ValueError: If the `train/` and `val/` folders do not define the same
            classes (label indices would otherwise be misaligned).

    If the dataset folders are missing or empty, the error and the expected
    folder layout are printed and the function returns without training.
    Errors raised during training itself are not swallowed.
    """
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Only dataset loading is guarded: the layout hint below is misleading for
    # any other kind of failure (out-of-memory, shape errors, ...).
    try:
        image_datasets = _load_datasets(data_dir, _build_transforms())
    except (FileNotFoundError, NotADirectoryError) as e:
        print(f"Error during training: {e}")
        _print_dataset_layout_help()
        return

    class_names = image_datasets['train'].classes
    if image_datasets['val'].classes != class_names:
        # ImageFolder assigns label indices per folder listing, so differing
        # class sets would silently map the same index to different classes.
        raise ValueError(
            "train/ and val/ must contain the same class sub-directories"
        )

    dataloaders = {
        phase: DataLoader(
            image_datasets[phase],
            batch_size=batch_size,
            shuffle=(phase == 'train'),
            num_workers=num_workers,
        )
        for phase in _PHASES
    }
    dataset_sizes = {phase: len(image_datasets[phase]) for phase in _PHASES}

    # Save class names to a JSON file
    with open(CLASS_NAMES_PATH, 'w', encoding='utf-8') as f:
        json.dump(class_names, f)

    print(f"Dataset loaded successfully with {len(class_names)} classes")
    print(f"Training set size: {dataset_sizes['train']}")
    print(f"Validation set size: {dataset_sizes['val']}")

    # Load a pre-trained model and replace the final layer so that it outputs
    # one logit per dataset class.
    model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, len(class_names))
    model = model.to(device)

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    best_acc = 0.0
    checkpoint_saved = False

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in _PHASES:
            is_train = phase == 'train'
            loss_sum, num_correct = _run_phase(
                model,
                dataloaders[phase],
                criterion,
                device,
                optimizer if is_train else None,
            )
            if is_train:
                # The learning-rate schedule advances once per training epoch.
                scheduler.step()

            epoch_loss = loss_sum / dataset_sizes[phase]
            epoch_acc = num_correct / dataset_sizes[phase]
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Save the best model. The first validation result is always
            # saved so a checkpoint exists even if accuracy never exceeds 0.
            if not is_train and (epoch_acc > best_acc or not checkpoint_saved):
                best_acc = epoch_acc
                torch.save(model.state_dict(), CHECKPOINT_PATH)
                checkpoint_saved = True

        print()

    print(f'Best val Acc: {best_acc:.4f}')
    if checkpoint_saved:
        print('Model saved as plant_disease_model.pth')


if __name__ == "__main__":
    train_model()