File size: 5,655 Bytes
91cb873
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06ee41a
 
91cb873
 
 
 
 
 
 
 
 
 
 
 
 
17e606f
91cb873
 
 
 
06ee41a
91cb873
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
import json
from tqdm import tqdm

def _build_transforms():
    """Return the per-split preprocessing pipelines keyed by 'train' / 'val'.

    The training pipeline applies random augmentation (resized crop, flip,
    rotation, colour jitter). The validation pipeline is deterministic
    (resize to 256, centre-crop to 224). Both normalize with the commonly used
    ImageNet channel mean/std.
    """
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    return {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
            transforms.ToTensor(),
            normalize,
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),
    }


def _load_datasets(data_dir, data_transforms):
    """Build ImageFolder datasets for ``<data_dir>/train`` and ``<data_dir>/val``.

    Raises:
        OSError: if a split directory cannot be read (torchvision reports a
            missing split, no class folders, or an empty class folder as
            FileNotFoundError).
        ValueError: if the two splits do not contain the same class folders.
    """
    image_datasets = {
        phase: datasets.ImageFolder(os.path.join(data_dir, phase), data_transforms[phase])
        for phase in ('train', 'val')
    }
    # ImageFolder derives label indices from each split's own sorted folder
    # names, so differing class sets would silently misalign val labels.
    train_classes = image_datasets['train'].classes
    val_classes = image_datasets['val'].classes
    if train_classes != val_classes:
        raise ValueError(
            "train/ and val/ must contain the same class folders; "
            f"only in train: {sorted(set(train_classes) - set(val_classes))}, "
            f"only in val: {sorted(set(val_classes) - set(train_classes))}"
        )
    return image_datasets


def _print_dataset_layout_hint(data_dir):
    """Print the train/val class-folder layout expected under *data_dir*."""
    print("Please make sure the dataset is properly organized in the following structure:")
    print(f"{data_dir}/")
    print("├── train/")
    print("│   ├── Apple___Apple_scab/")
    print("│   ├── Apple___Black_rot/")
    print("│   └── ... (other classes)")
    print("└── val/")
    print("    ├── Apple___Apple_scab/")
    print("    ├── Apple___Black_rot/")
    print("    └── ... (other classes)")


def _run_epoch(model, loader, criterion, optimizer, device, train):
    """Run one pass over *loader* and return ``(mean_loss, accuracy)`` as floats.

    When *train* is true, the model is put in training mode and *optimizer* is
    stepped after every batch. Otherwise, the model runs in eval mode with
    gradients disabled.

    Raises:
        ValueError: if *loader* yields no samples (the averages would be undefined).
    """
    model.train(train)  # model.train(False) is equivalent to model.eval()
    running_loss = 0.0
    running_corrects = 0
    num_samples = 0

    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.set_grad_enabled(train):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        _, preds = torch.max(outputs, 1)
        batch_size = inputs.size(0)
        # The criterion's default reduction is a batch mean; re-weight by batch
        # size so the epoch loss is a true per-sample mean.
        running_loss += loss.item() * batch_size
        running_corrects += (preds == labels).sum().item()
        num_samples += batch_size

    if num_samples == 0:
        raise ValueError("data loader yielded no samples")
    return running_loss / num_samples, running_corrects / num_samples


def train_model(data_dir='PlantVillage', num_epochs=15, batch_size=32,
                learning_rate=0.001, num_workers=4,
                model_path='plant_disease_model.pth',
                class_names_path='class_names.json'):
    """Fine-tune a pretrained ResNet-50 on an ImageFolder-style dataset.

    Expects ``<data_dir>/train/<class>/...`` and ``<data_dir>/val/<class>/...``.
    Writes the training class names (in label-index order) as JSON to
    *class_names_path*. After each validation phase that improves on the best
    accuracy seen so far, saves the model ``state_dict`` to *model_path*.
    If the dataset cannot be loaded, it prints the error and the expected
    layout, then returns without training. Other errors propagate.

    Args:
        data_dir: Root directory containing ``train`` and ``val`` splits.
        num_epochs: Number of train + validate epochs.
        batch_size: Mini-batch size for both splits.
        learning_rate: Initial Adam learning rate (decayed 10x every 7 epochs).
        num_workers: DataLoader worker processes per split.
        model_path: Output path for the best checkpoint.
        class_names_path: Output path for the JSON list of class names.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Only dataset-layout problems get the layout hint. Any other failure
    # (e.g. CUDA out of memory) propagates with its traceback instead of
    # being misreported as a dataset problem.
    try:
        image_datasets = _load_datasets(data_dir, _build_transforms())
    except (OSError, ValueError) as e:
        print(f"Error during training: {e}")
        _print_dataset_layout_hint(data_dir)
        return

    dataloaders = {
        phase: DataLoader(
            image_datasets[phase],
            batch_size=batch_size,
            shuffle=(phase == 'train'),
            num_workers=num_workers,
            pin_memory=(device.type == 'cuda'),
        )
        for phase in ('train', 'val')
    }
    dataset_sizes = {phase: len(ds) for phase, ds in image_datasets.items()}
    class_names = image_datasets['train'].classes

    # List position == label index assigned by ImageFolder to the train split.
    with open(class_names_path, 'w', encoding='utf-8') as f:
        json.dump(class_names, f)

    print(f"Dataset loaded successfully with {len(class_names)} classes")
    print(f"Training set size: {dataset_sizes['train']}")
    print(f"Validation set size: {dataset_sizes['val']}")

    # Replace the classifier head with one sized for this dataset's classes.
    from torchvision.models import ResNet50_Weights
    model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
    model.fc = nn.Linear(model.fc.in_features, len(class_names))
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Stepped once per epoch (after the train phase): lr *= 0.1 every 7 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    best_acc = None  # None until the first validation phase completes
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        for phase in ('train', 'val'):
            is_train = phase == 'train'
            epoch_loss, epoch_acc = _run_epoch(
                model, tqdm(dataloaders[phase], desc=phase),
                criterion, optimizer, device, train=is_train,
            )
            if is_train:
                scheduler.step()

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Keep only the checkpoint with the best validation accuracy.
            if not is_train and (best_acc is None or epoch_acc > best_acc):
                best_acc = epoch_acc
                torch.save(model.state_dict(), model_path)

        print()

    if best_acc is None:
        print('No epochs were run; no model was saved.')
        return
    print(f'Best val Acc: {best_acc:.4f}')
    print(f'Model saved as {model_path}')

if __name__ == '__main__':
    # Script entry point: start training only when run directly, not on import.
    train_model()