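# PlantDiseaseTreatmentAssistant / create_dataset.py
# Web file-viewer residue, kept for provenance only (not part of the script):
#   "iqramukhtiar's picture"; commit "Create create_dataset.py" (4fa4171, verified);
#   raw / history / blame links; file size 2.45 kB.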
import os
import shutil
import random
from tqdm import tqdm
def create_train_val_split(source_dir, output_dir, val_split=0.2, seed=None):
    """
    Creates a train/val split from the PlantVillage dataset.

    Each immediate sub-directory of ``source_dir`` is treated as one class.
    Its image files (``.jpg``, ``.jpeg``, ``.png``, matched case-insensitively)
    are shuffled and copied into ``output_dir/train/<class>`` and
    ``output_dir/val/<class>``. The source files are left in place.

    Args:
        source_dir: Directory containing the PlantVillage dataset
        output_dir: Directory to save the split dataset
        val_split: Proportion of data to use for validation; must lie in [0, 1]
        seed: Optional seed for a reproducible shuffle. When None, the
            module-level ``random`` generator is used.

    Raises:
        ValueError: If ``val_split`` is outside [0, 1].
        OSError: If ``source_dir`` cannot be listed or a copy fails.
    """
    # A value outside [0, 1] would yield a negative or oversized slice index
    # below and silently produce a wrong split.
    if not 0 <= val_split <= 1:
        raise ValueError(f"val_split must be between 0 and 1, got {val_split!r}")

    # Get all class directories. This happens before anything is created so
    # that an invalid source path fails without leaving an empty output tree.
    # Sorting makes a seeded run independent of os.listdir ordering.
    class_dirs = sorted(
        d for d in os.listdir(source_dir)
        if os.path.isdir(os.path.join(source_dir, d))
    )

    # Create output directories
    train_dir = os.path.join(output_dir, 'train')
    val_dir = os.path.join(output_dir, 'val')
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    image_exts = ('.jpg', '.jpeg', '.png')
    # Keep using the global generator when no seed is given so callers that
    # seed ``random`` themselves still influence the shuffle.
    rng = random if seed is None else random.Random(seed)

    for class_dir in tqdm(class_dirs, desc="Processing classes"):
        class_src = os.path.join(source_dir, class_dir)
        class_train = os.path.join(train_dir, class_dir)
        class_val = os.path.join(val_dir, class_dir)

        # Create class directories in train and val
        os.makedirs(class_train, exist_ok=True)
        os.makedirs(class_val, exist_ok=True)

        # Get all images in the class (regular files only, any letter case)
        images = sorted(
            f for f in os.listdir(class_src)
            if f.lower().endswith(image_exts)
            and os.path.isfile(os.path.join(class_src, f))
        )

        # Shuffle images
        rng.shuffle(images)

        # Split into train and val
        split_idx = int(len(images) * (1 - val_split))
        train_images = images[:split_idx]
        val_images = images[split_idx:]

        # Copy images to train directory
        for img in tqdm(train_images, desc=f"Copying {class_dir} train images", leave=False):
            shutil.copy(os.path.join(class_src, img), os.path.join(class_train, img))

        # Copy images to val directory
        for img in tqdm(val_images, desc=f"Copying {class_dir} val images", leave=False):
            shutil.copy(os.path.join(class_src, img), os.path.join(class_val, img))

    print(f"Dataset split complete. Train: {len(os.listdir(train_dir))} classes, Val: {len(os.listdir(val_dir))} classes")
def main():
    """Run the train/val split using the script's hard-coded paths."""
    dataset_root = 'PlantVillage-Dataset/raw/color'  # Update this to your dataset path
    split_root = 'PlantVillage'

    # Write the train/ and val/ trees under split_root
    create_train_val_split(dataset_root, split_root)
# Only run the split when executed as a script, not when imported as a module
if __name__ == "__main__":
    main()