# Spaces status: Runtime error
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
# Load the PlantVillage dataset through `datasets.load_dataset` and read the
# class names from the "label" column of its "train" split.
# NOTE(review): confirm that "plant_village" resolves to an existing dataset and
# that its "train" split has a ClassLabel "label" feature (`.names` needs one).
dataset = load_dataset("plant_village")
label_names = dataset["train"].features["label"].names

# Preprocessing: resize every image to 224x224, then convert it to a float
# tensor (ToTensor scales pixel values into [0, 1]).
# NOTE(review): no ImageNet mean/std `transforms.Normalize` step is applied even
# though the ResNet-18 is initialised from ImageNet weights; if one is added,
# the inference-side preprocessing must be changed to match.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
class PlantVillageDataset(Dataset):
    """Expose a Hugging Face dataset split as a PyTorch ``Dataset``.

    Each item is an ``(image, label)`` pair. ``image`` is the row's
    ``"image"`` value, converted to RGB if needed and passed through
    ``transform``. ``label`` is the row's ``"label"`` value, unchanged.

    Args:
        hf_data: Sized, integer-indexable collection of rows (e.g. a
            ``datasets`` split), each a mapping with ``"image"`` and
            ``"label"`` keys.
        transform: Callable applied to each (RGB) image.
    """

    def __init__(self, hf_data, transform):
        self.dataset = hf_data
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Fetch the row once. Indexing a `datasets` split materialises (and
        # decodes) the whole row, so looking it up separately for the image
        # and the label decoded the image twice per sample.
        row = self.dataset[idx]
        image = row["image"]
        # Assumes the "image" column decodes to PIL images (as the `datasets`
        # Image feature does). Grayscale, palette or RGBA images would
        # otherwise become 1- or 4-channel tensors, which a 3-channel
        # ImageNet-style backbone rejects.
        if image.mode != "RGB":
            image = image.convert("RGB")
        return self.transform(image), row["label"]
# Wrap the "train" split and batch it for training: 32 samples per batch,
# reshuffled at every epoch.
train_data = PlantVillageDataset(dataset["train"], transform)
train_loader = DataLoader(
    dataset=train_data,
    batch_size=32,
    shuffle=True,
)
# Define the model: an ImageNet-pretrained ResNet-18 whose final fully
# connected layer is replaced by a freshly initialised one with one output per
# class. `weights=ResNet18_Weights.IMAGENET1K_V1` selects the same checkpoint
# as the `pretrained=True` flag, which is deprecated since torchvision 0.13.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, len(label_names))

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Train: Adam (learning rate 1e-3) minimising cross-entropy loss.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

model.train()  # put layers such as batch norm into training-mode behaviour
for epoch in range(2):  # number of passes over the training set; raise for more training
    for images, labels in train_loader:
        # Move the batch onto the same device as the model.
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()  # drop gradients accumulated by the previous step
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1} completed.")
# Persist only the learned parameters (the state_dict), not the module object.
# Reloading therefore means rebuilding the same architecture (ResNet-18 with the
# replaced `fc` layer) and calling `load_state_dict` on it.
torch.save(obj=model.state_dict(), f="model.pth")
print("Model saved as model.pth")