# Source: Hugging Face upload by iqramukhtiar — "Upload 3 files" (commit 4382bbc, 6.72 kB)
import json
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, models, transforms
from tqdm import tqdm
def _ensure_demo_dataset(dataset_path):
    """Create a tiny synthetic ImageFolder-style dataset if none exists.

    Builds four class directories under ``dataset_path``, each containing
    10 random 224x224 RGB JPEGs, so the training demo can run without the
    real PlantVillage data. Does nothing if ``dataset_path`` already exists.
    """
    if os.path.exists(dataset_path):
        return
    os.makedirs(dataset_path, exist_ok=True)
    print(f"Created directory {dataset_path}")
    classes = ["Tomato_Early_blight", "Tomato_healthy", "Apple_scab", "Apple_healthy"]
    for cls in classes:
        cls_dir = os.path.join(dataset_path, cls)
        os.makedirs(cls_dir, exist_ok=True)
        print(f"Created class directory: {cls}")
        # Random-noise images are enough to exercise the training loop.
        for i in range(10):
            img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)
            transforms.ToPILImage()(img).save(os.path.join(cls_dir, f"image_{i}.jpg"))
        print(f"Created 10 dummy images for {cls}")


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over ``loader``.

    Trains (with gradient updates) when ``optimizer`` is given; otherwise
    evaluates with gradients disabled.

    Returns:
        (mean_loss, accuracy) averaged over all samples seen.
    """
    training = optimizer is not None
    if training:
        model.train()
    else:
        model.eval()
    running_loss = 0.0
    running_corrects = 0.0
    total = 0
    for inputs, labels in tqdm(loader, desc="Training" if training else "Validation"):
        inputs = inputs.to(device)
        labels = labels.to(device)
        # set_grad_enabled(False) doubles as torch.no_grad() for validation.
        with torch.set_grad_enabled(training):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        # Weight batch statistics by batch size (last batch may be smaller).
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data).item()
        total += inputs.size(0)
    total = max(total, 1)  # guard against an empty split
    return running_loss / total, running_corrects / total


def train_quick_model():
    """
    Trains a quick model on a small subset of data for demonstration purposes.
    This will show training epochs in the console.

    Side effects: creates a synthetic 'PlantVillage' dataset if absent, writes
    'class_names.json' and 'plant_disease_model.pth' to the working directory,
    and prints per-epoch progress. Returns None.
    """
    print("Starting quick model training...")
    # Prefer GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # ImageNet normalization statistics — what ResNet-18 was designed around.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            normalize,
        ]),
        'val': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            normalize,
        ]),
    }

    print("Creating a small synthetic dataset for demonstration...")
    dataset_path = 'PlantVillage'
    _ensure_demo_dataset(dataset_path)

    print("Loading dataset...")
    try:
        # BUG FIX: the original used random_split on a single ImageFolder and
        # then assigned `.dataset.transform` on both subsets. Both subsets
        # share the SAME underlying dataset object, so the 'val' transform
        # silently replaced the 'train' transform for both splits and data
        # augmentation never ran. Instead, build two ImageFolder views of the
        # same directory — each with its own transform — and split them with
        # one shared index permutation.
        train_view = datasets.ImageFolder(dataset_path, transform=data_transforms['train'])
        val_view = datasets.ImageFolder(dataset_path, transform=data_transforms['val'])
        indices = torch.randperm(len(train_view)).tolist()
        train_size = int(0.8 * len(train_view))
        train_dataset = Subset(train_view, indices[:train_size])
        val_dataset = Subset(val_view, indices[train_size:])

        train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

        # Persist class order so inference code can map logits back to labels.
        class_names = train_view.classes
        with open('class_names.json', 'w') as f:
            json.dump(class_names, f)
        print(f"Dataset loaded with {len(class_names)} classes")
        print(f"Training set: {len(train_dataset)} images")
        print(f"Validation set: {len(val_dataset)} images")

        # NOTE: despite the message, weights=None trains ResNet-18 from
        # scratch — it is merely the smaller architecture, chosen for speed.
        print("Loading pre-trained model...")
        model = models.resnet18(weights=None)
        # Replace the classifier head to match our class count.
        model.fc = nn.Linear(model.fc.in_features, len(class_names))
        model = model.to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001)

        num_epochs = 5  # just a few epochs for demonstration
        print(f"Starting training for {num_epochs} epochs...")
        for epoch in range(num_epochs):
            print(f'Epoch {epoch+1}/{num_epochs}')
            print('-' * 10)
            train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
            print(f'Training Loss: {train_loss:.4f} Acc: {train_acc:.4f}')
            val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)
            print(f'Validation Loss: {val_loss:.4f} Acc: {val_acc:.4f}')
            print()

        torch.save(model.state_dict(), 'plant_disease_model.pth')
        print("Model saved as plant_disease_model.pth")
    except Exception as e:
        # Best-effort demo script: report the failure rather than crash.
        print(f"Error during training: {e}")
if __name__ == "__main__":
start_time = time.time()
train_quick_model()
end_time = time.time()
print(f"Training completed in {end_time - start_time:.2f} seconds")