|
|
import os |
|
|
import requests |
|
|
from tqdm import tqdm |
|
|
from datasets import load_dataset |
|
|
import shutil |
|
|
|
|
|
def download_plantvillage_from_huggingface():
    """Download the PlantVillage dataset from Hugging Face and save it as image folders.

    Loads ``GVJahnavi/PlantVillage_dataset`` with ``load_dataset`` and writes
    every sample of its ``train`` split to
    ``PlantVillage/<class name>/<class name>_<i>.jpg``, where ``i`` counts the
    samples of that class in dataset order. The list of class names is then
    written to ``class_names.json`` in the current working directory.

    Returns:
        bool: True when all images and the class-name file were written;
        False when loading or exporting raised an exception (the error is
        printed, not re-raised).
    """
    import json  # function-scope import: the module header does not import json

    # Image modes Pillow's JPEG writer accepts as-is. Any other mode
    # (e.g. RGBA or P) makes Image.save() raise OSError, so those images are
    # converted to RGB first.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    print("Downloading PlantVillage dataset from Hugging Face...")

    output_dir = 'PlantVillage'
    os.makedirs(output_dir, exist_ok=True)

    try:
        dataset = load_dataset("GVJahnavi/PlantVillage_dataset")
        train_split = dataset['train']
        print(f"Dataset loaded successfully with {len(train_split)} training samples")

        labels = train_split.features['label'].names
        print(f"Found {len(labels)} classes: {labels}")

        # Create every class directory up front so a class without samples
        # still gets its (empty) folder, as before.
        label_dirs = []
        for label_name in labels:
            label_dir = os.path.join(output_dir, label_name)
            os.makedirs(label_dir, exist_ok=True)
            label_dirs.append(label_dir)

        # One pass over the split instead of one full filter() scan per class.
        # A per-class counter reproduces the "<class>_<i>.jpg" numbering that
        # enumerating each filtered subset would give.
        saved_per_class = [0] * len(labels)
        for sample in tqdm(train_split, desc="Saving images"):
            label_idx = sample['label']
            label_name = labels[label_idx]

            # NOTE(review): assumes sample['image'] is a PIL image (it is used
            # via .mode/.convert/.save) -- confirm against the dataset schema.
            img = sample['image']
            if img.mode not in jpeg_modes:
                img = img.convert("RGB")

            position = saved_per_class[label_idx]
            img_path = os.path.join(label_dirs[label_idx], f"{label_name}_{position}.jpg")
            img.save(img_path)
            saved_per_class[label_idx] = position + 1

        for label_name, count in zip(labels, saved_per_class):
            print(f"Processed class {label_name} with {count} samples")

        # Persist the class order so label indices can be mapped back to names.
        with open('class_names.json', 'w', encoding='utf-8') as f:
            json.dump(labels, f)

        print("Dataset downloaded and organized successfully")
        return True

    except Exception as e:
        # Deliberate catch-all at the script boundary: callers rely on a
        # True/False result rather than an exception.
        print(f"Error downloading dataset from Hugging Face: {e}")
        return False
|
|
|
|
|
if __name__ == "__main__":
    # Propagate the outcome as the process exit status: the download function
    # returns False on failure, and discarding it would make a failed run
    # exit with status 0.
    raise SystemExit(0 if download_plantvillage_from_huggingface() else 1)
|
|
|