import os
import random
import shutil
from typing import Optional

try:
    from tqdm import tqdm
except ImportError:
    # The progress bar is purely cosmetic; degrade to plain iteration so the
    # split can still be produced on machines without tqdm installed.
    def tqdm(iterable, **kwargs):
        """Fallback for a missing tqdm: return the iterable unchanged."""
        return iterable


# Lower-case suffixes of the files treated as images. Matching is done on the
# lower-cased file name so '.JPG', '.PNG', '.Jpeg', ... are all picked up.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def _copy_images(names, src_dir, dst_dir, desc):
    """Copy every file in *names* from *src_dir* into *dst_dir*."""
    for name in tqdm(names, desc=desc, leave=False):
        shutil.copy(os.path.join(src_dir, name), os.path.join(dst_dir, name))


def create_train_val_split(
    source_dir: str,
    output_dir: str,
    val_split: float = 0.2,
    seed: Optional[int] = None,
) -> None:
    """
    Creates a train/val split from the PlantVillage dataset.

    Every sub-directory of ``source_dir`` is treated as one class. The image
    files of each class are shuffled and copied into
    ``output_dir/train/<class>`` and ``output_dir/val/<class>``; the source
    files are left untouched.

    Args:
        source_dir: Directory containing the PlantVillage dataset
            (one sub-directory per class)
        output_dir: Directory to save the split dataset
        val_split: Proportion of data to use for validation, in [0, 1]
        seed: Optional seed for the shuffle. With the same seed and the same
            source files the split is identical between runs. When None, the
            global ``random`` module state is used.

    Raises:
        ValueError: If ``val_split`` is not within [0, 1].
        OSError: If ``source_dir`` cannot be listed or a copy fails.
    """
    if not 0 <= val_split <= 1:
        raise ValueError(f"val_split must be between 0 and 1, got {val_split!r}")

    # A dedicated generator keeps a seeded run from disturbing the global
    # random state; without a seed the module-level generator is used.
    rng = random if seed is None else random.Random(seed)

    # List the classes before creating anything so that a wrong source path
    # fails without leaving empty output directories behind. Sorting makes the
    # processing order independent of the filesystem's listing order.
    class_dirs = sorted(
        d for d in os.listdir(source_dir)
        if os.path.isdir(os.path.join(source_dir, d))
    )

    # Create output directories
    train_dir = os.path.join(output_dir, 'train')
    val_dir = os.path.join(output_dir, 'val')
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    for class_dir in tqdm(class_dirs, desc="Processing classes"):
        class_src = os.path.join(source_dir, class_dir)
        class_train = os.path.join(train_dir, class_dir)
        class_val = os.path.join(val_dir, class_dir)
        os.makedirs(class_train, exist_ok=True)
        os.makedirs(class_val, exist_ok=True)

        # Sort first: os.listdir order is arbitrary, so a seeded shuffle is
        # only reproducible when it starts from a deterministic order.
        images = sorted(
            f for f in os.listdir(class_src)
            if f.lower().endswith(_IMAGE_EXTENSIONS)
        )
        rng.shuffle(images)

        # The first part of the shuffled list goes to train, the rest to val.
        split_idx = int(len(images) * (1 - val_split))

        _copy_images(images[:split_idx], class_src, class_train,
                     f"Copying {class_dir} train images")
        _copy_images(images[split_idx:], class_src, class_val,
                     f"Copying {class_dir} val images")

    print(f"Dataset split complete. Train: {len(os.listdir(train_dir))} classes, Val: {len(os.listdir(val_dir))} classes")


def main():
    """Split the dataset using the default source and output locations."""
    # Set paths
    source_dir = 'PlantVillage-Dataset/raw/color'  # Update this to your dataset path
    output_dir = 'PlantVillage'

    # Create the split
    create_train_val_split(source_dir, output_dir)


if __name__ == "__main__":
    main()