iqramukhtiar committed on
Commit 91cb873 · verified · 1 Parent(s): c82182e

Create train_model.py

Files changed (1)
  1. train_model.py +145 -0
train_model.py ADDED
@@ -0,0 +1,145 @@
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
import json
from tqdm import tqdm

def train_model():
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Data transformations
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    # Load the dataset
    data_dir = 'PlantVillage'  # Update this to your dataset path

    try:
        image_datasets = {
            'train': datasets.ImageFolder(os.path.join(data_dir, 'train'), data_transforms['train']),
            'val': datasets.ImageFolder(os.path.join(data_dir, 'val'), data_transforms['val'])
        }

        dataloaders = {
            'train': DataLoader(image_datasets['train'], batch_size=32, shuffle=True, num_workers=4),
            'val': DataLoader(image_datasets['val'], batch_size=32, shuffle=False, num_workers=4)
        }

        dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
        class_names = image_datasets['train'].classes

        # Save class names to a JSON file
        with open('class_names.json', 'w') as f:
            json.dump(class_names, f)

        print(f"Dataset loaded successfully with {len(class_names)} classes")
        print(f"Training set size: {dataset_sizes['train']}")
        print(f"Validation set size: {dataset_sizes['val']}")

        # Load a pre-trained model
        model = models.resnet50(pretrained=True)

        # Modify the final layer for our number of classes
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, len(class_names))

        model = model.to(device)

        # Define loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

        # Train the model
        num_epochs = 10
        best_acc = 0.0

        for epoch in range(num_epochs):
            print(f'Epoch {epoch+1}/{num_epochs}')
            print('-' * 10)

            # Each epoch has a training and validation phase
            for phase in ['train', 'val']:
                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                running_loss = 0.0
                running_corrects = 0

                # Iterate over data
                for inputs, labels in tqdm(dataloaders[phase]):
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # Zero the parameter gradients
                    optimizer.zero_grad()

                    # Forward pass
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                        # Backward + optimize only if in training phase
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    # Statistics
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels.data)

                if phase == 'train':
                    scheduler.step()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects.double() / dataset_sizes[phase]

                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

                # Save the best model
                if phase == 'val' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), 'plant_disease_model.pth')

            print()

        print(f'Best val Acc: {best_acc:.4f}')
        print('Model saved as plant_disease_model.pth')

    except Exception as e:
        print(f"Error during training: {e}")
        print("Please make sure the dataset is properly organized in the following structure:")
        print("PlantVillage/")
        print("├── train/")
        print("│   ├── Apple___Apple_scab/")
        print("│   ├── Apple___Black_rot/")
        print("│   └── ... (other classes)")
        print("└── val/")
        print("    ├── Apple___Apple_scab/")
        print("    ├── Apple___Black_rot/")
        print("    └── ... (other classes)")

if __name__ == "__main__":
    train_model()
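
A minimal inference sketch, not part of this commit, showing how the artifacts the script writes (plant_disease_model.pth and class_names.json) could be loaded back for prediction. It assumes the same ResNet-50 architecture and the 'val' preprocessing defined above; the image path 'leaf.jpg' is a placeholder.

import json
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image

# Rebuild the architecture used in train_model.py, then load the saved weights.
with open('class_names.json') as f:
    class_names = json.load(f)

model = models.resnet50(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, len(class_names))
model.load_state_dict(torch.load('plant_disease_model.pth', map_location='cpu'))
model.eval()

# Same preprocessing as the 'val' transforms in train_model.py.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# 'leaf.jpg' is a placeholder path for a test image.
image = preprocess(Image.open('leaf.jpg').convert('RGB')).unsqueeze(0)
with torch.no_grad():
    probs = torch.softmax(model(image), dim=1)
    conf, idx = probs.max(dim=1)
print(f"Predicted: {class_names[idx.item()]} ({conf.item():.2%})")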