# crop-disease-detectotor / train-model.py
# Author: sadiaafzaal
# Commit c0b6788 (verified): "Create train-model.py"
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch
import torch.nn as nn
import torch.optim as optim
# Load dataset
dataset = load_dataset("plant_village")
# NOTE(review): assumes the "label" column is a ClassLabel feature, whose
# `.names` lists class names in label-index order — confirm on the dataset card.
label_names = dataset["train"].features["label"].names


def _to_rgb(image):
    """Return *image* converted to 3-channel RGB (a plain copy if already RGB).

    A module-level function (not a lambda) so the transform stays picklable
    if the DataLoader is later given ``num_workers > 0``.
    """
    return image.convert("RGB")


# Preprocessing: force 3 channels (torchvision's ResNet stem expects 3 input
# channels, so a grayscale or RGBA image would otherwise crash the forward
# pass), resize to 224x224, then convert to a float tensor in [0, 1].
transform = transforms.Compose([
    transforms.Lambda(_to_rgb),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
class PlantVillageDataset(Dataset):
    """Adapt a Hugging Face dataset split to the PyTorch ``Dataset`` protocol.

    Each row of ``hf_data`` is expected to be a mapping with an ``"image"``
    entry (passed through ``transform``) and a ``"label"`` entry (returned
    unchanged).
    """

    def __init__(self, hf_data, transform):
        """Store the indexable, sized split *hf_data* and the image *transform*."""
        self.dataset = hf_data
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the wrapped split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(transform(image), label)`` for row *idx*."""
        # Fetch the row once: each integer index into a Hugging Face split
        # materialises (and decodes) the whole row, so indexing it twice
        # would decode the image twice per sample.
        row = self.dataset[idx]
        return self.transform(row["image"]), row["label"]
# Wrap the training split for PyTorch and batch it; shuffling re-orders the
# samples at the start of every epoch.
train_data = PlantVillageDataset(dataset["train"], transform)
train_loader = DataLoader(
    train_data,
    batch_size=32,
    shuffle=True,
)
# Define model
# ImageNet-pretrained ResNet-18 whose final fully connected layer is replaced
# by a freshly initialised one with one output per dataset class.
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in
# favour of `weights=models.ResNet18_Weights.DEFAULT`; kept as-is so older
# torchvision releases keep working.
model = models.resnet18(pretrained=True)
num_classes = len(label_names)
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
# Train
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
NUM_EPOCHS = 2  # You can increase this

model.train()
for epoch in range(NUM_EPOCHS):
    # Sum of per-sample losses, kept on the training device so the loop does
    # not force a host synchronisation on every step.
    running_loss = torch.zeros((), device=device)
    seen = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # CrossEntropyLoss defaults to reduction="mean" (per-batch average),
        # so weight by batch size to get an exact per-sample epoch mean even
        # when the last batch is smaller.
        running_loss += loss.detach() * labels.size(0)
        seen += labels.size(0)
    print(f"Epoch {epoch+1} completed.")
    if seen:  # guard against an empty training split
        print(f"  mean training loss: {running_loss.item() / seen:.4f}")
# Save the model: only the learned parameters (the state_dict) are written,
# so loading requires first rebuilding the same ResNet-18 with a matching
# number of output classes.
weights = model.state_dict()
torch.save(weights, "model.pth")
print("Model saved as model.pth")