import os
import json
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, models, transforms
from tqdm import tqdm


def train_model_with_huggingface_data():
    """
    Trains a model using the PlantVillage dataset downloaded from Hugging Face.
    """
    print("Starting model training with Hugging Face dataset...")

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Data transformations
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    # Load the dataset
    print("Loading dataset...")
    try:
        dataset_path = 'PlantVillage'
        if not os.path.exists(dataset_path):
            print(f"Error: Dataset directory {dataset_path} not found.")
            print("Please run download_huggingface_dataset.py first.")
            return

        # Build two ImageFolder views of the same directory so the training
        # and validation splits can use different transforms. (Assigning
        # .transform on the subsets returned by random_split mutates the one
        # shared underlying dataset, so the last-assigned transform would
        # silently apply to both splits.)
        train_base = datasets.ImageFolder(dataset_path, transform=data_transforms['train'])
        val_base = datasets.ImageFolder(dataset_path, transform=data_transforms['val'])

        # Split into train and validation sets (80/20) over shared indices
        num_images = len(train_base)
        train_size = int(0.8 * num_images)
        indices = torch.randperm(num_images).tolist()
        train_dataset = Subset(train_base, indices[:train_size])
        val_dataset = Subset(val_base, indices[train_size:])

        # Create data loaders
        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
        val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

        # Save class names so inference code can map indices back to labels
        class_names = train_base.classes
        with open('class_names.json', 'w') as f:
            json.dump(class_names, f)

        print(f"Dataset loaded with {len(class_names)} classes")
        print(f"Training set: {len(train_dataset)} images")
        print(f"Validation set: {len(val_dataset)} images")

        # Load a pre-trained model
        print("Loading pre-trained model...")
        model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

        # Replace the final fully connected layer to match our number of classes
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, len(class_names))
        model = model.to(device)

        # Define loss function, optimizer, and learning-rate schedule
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

        # Train the model
        num_epochs = 10
        best_acc = 0.0

        print(f"Starting training for {num_epochs} epochs...")

        for epoch in range(num_epochs):
            print(f'Epoch {epoch + 1}/{num_epochs}')
            print('-' * 10)

            # Training phase
            model.train()
            running_loss = 0.0
            running_corrects = 0

            # Iterate over data
            for inputs, labels in tqdm(train_loader, desc="Training"):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward pass
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                # Backward pass + optimizer step
                loss.backward()
                optimizer.step()

                # Statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            scheduler.step()

            epoch_loss = running_loss / len(train_dataset)
            epoch_acc = running_corrects.double() / len(train_dataset)
            print(f'Training Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            # Validation phase
            model.eval()
            running_loss = 0.0
            running_corrects = 0

            # Iterate over data
            for inputs, labels in tqdm(val_loader, desc="Validation"):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Forward pass only; no gradients needed during evaluation
                with torch.no_grad():
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                # Statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(val_dataset)
            epoch_acc = running_corrects.double() / len(val_dataset)
            print(f'Validation Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Save the best model (by validation accuracy)
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                torch.save(model.state_dict(), 'plant_disease_model.pth')
                print(f"Saved new best model with accuracy: {best_acc:.4f}")

            print()

        print(f'Best val Acc: {best_acc:.4f}')
        print('Model saved as plant_disease_model.pth')

    except Exception as e:
        print(f"Error during training: {e}")


if __name__ == "__main__":
    start_time = time.time()
    train_model_with_huggingface_data()
    end_time = time.time()
    print(f"Training completed in {(end_time - start_time) / 60:.2f} minutes")
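
# --- Inference sketch ---
# A minimal, self-contained example of using the artifacts the script above
# saves ('plant_disease_model.pth' and 'class_names.json') to classify a
# single image. This is a sketch, not part of the training script: the
# predict_image helper and the 'leaf.jpg' path are hypothetical, and the
# preprocessing simply mirrors the 'val' transforms used during training.

import json

import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms


def predict_image(image_path: str) -> str:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Restore the class names saved during training
    with open('class_names.json') as f:
        class_names = json.load(f)

    # Rebuild the same architecture and head, then load the trained weights
    model = models.resnet50(weights=None)
    model.fc = nn.Linear(model.fc.in_features, len(class_names))
    model.load_state_dict(torch.load('plant_disease_model.pth', map_location=device))
    model = model.to(device)
    model.eval()

    # Same preprocessing as the 'val' transforms in the training script
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    image = Image.open(image_path).convert('RGB')
    batch = preprocess(image).unsqueeze(0).to(device)  # add a batch dimension

    with torch.no_grad():
        outputs = model(batch)
        pred = outputs.argmax(dim=1).item()

    return class_names[pred]

# Example usage (path is a hypothetical placeholder):
# print(predict_image('leaf.jpg'))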