|
|
import argparse
import os
import random
import shutil

from tqdm import tqdm
|
|
|
|
|
# File extensions (compared case-insensitively) that are treated as images.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def _list_images(class_path):
    """Return the sorted names of image files located directly in *class_path*.

    Sorting removes the dependence on the arbitrary ``os.listdir`` order, so a
    seeded shuffle yields the same split on every machine.
    """
    return sorted(
        name for name in os.listdir(class_path)
        if name.lower().endswith(_IMAGE_EXTENSIONS)
        # Skip e.g. a sub-directory named "foo.jpg", which shutil.copy would reject.
        and os.path.isfile(os.path.join(class_path, name))
    )


def _copy_files(names, src_dir, dst_dir, desc):
    """Copy every file in *names* from *src_dir* to *dst_dir* behind a tqdm bar."""
    for name in tqdm(names, desc=desc, leave=False):
        shutil.copy(os.path.join(src_dir, name), os.path.join(dst_dir, name))


def create_train_val_split(source_dir, output_dir, val_split=0.2, seed=None):
    """Create a per-class train/val split from the PlantVillage dataset.

    Each immediate subdirectory of ``source_dir`` is treated as one class. Its
    image files (``.jpg``, ``.jpeg`` or ``.png``, in any letter case) are
    shuffled and copied into ``output_dir/train/<class>`` and
    ``output_dir/val/<class>``.

    Args:
        source_dir: Directory containing one subdirectory per class.
        output_dir: Directory to save the split dataset. ``train`` and ``val``
            subdirectories are created inside it.
        val_split: Proportion of each class's images to use for validation,
            in ``[0, 1]``. The train count is rounded down, so any fractional
            remainder goes to validation.
        seed: Optional seed for a reproducible split. When ``None``, the
            global ``random`` module state is used, as in earlier versions.

    Raises:
        ValueError: If ``val_split`` is outside ``[0, 1]``.

    Note:
        Existing files under ``output_dir`` are neither removed nor checked.
        Re-running into the same directory with a different shuffle (for
        example, without a seed) can leave an image in both train and val.
    """
    # Validate before touching the filesystem, so a bad call creates nothing.
    if not 0.0 <= val_split <= 1.0:
        raise ValueError(f"val_split must be between 0 and 1, got {val_split!r}")

    # A dedicated RNG makes the split reproducible when a seed is given. Without
    # one, keep using the module-level RNG so that callers who seed `random`
    # globally get the same behaviour as before.
    shuffle = random.Random(seed).shuffle if seed is not None else random.shuffle

    train_dir = os.path.join(output_dir, 'train')
    val_dir = os.path.join(output_dir, 'val')
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    # Sorted so that the sequence of RNG draws per class is deterministic.
    class_dirs = sorted(
        d for d in os.listdir(source_dir)
        if os.path.isdir(os.path.join(source_dir, d))
    )

    for class_dir in tqdm(class_dirs, desc="Processing classes"):
        class_src = os.path.join(source_dir, class_dir)
        class_train = os.path.join(train_dir, class_dir)
        class_val = os.path.join(val_dir, class_dir)
        os.makedirs(class_train, exist_ok=True)
        os.makedirs(class_val, exist_ok=True)

        images = _list_images(class_src)
        shuffle(images)

        split_idx = int(len(images) * (1 - val_split))
        _copy_files(images[:split_idx], class_src, class_train,
                    f"Copying {class_dir} train images")
        _copy_files(images[split_idx:], class_src, class_val,
                    f"Copying {class_dir} val images")

    # Counts every entry in the output dirs, including ones left by earlier runs.
    print(f"Dataset split complete. Train: {len(os.listdir(train_dir))} classes, Val: {len(os.listdir(val_dir))} classes")
|
|
|
|
|
def main(argv=None):
    """Command-line entry point: split the PlantVillage dataset into train/val.

    With no arguments this behaves exactly as before. It reads from
    ``PlantVillage-Dataset/raw/color``, writes to ``PlantVillage``, and uses a
    validation fraction of 0.2. Each of these can now be overridden on the
    command line.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. This is useful
            for calling the entry point programmatically.
    """
    parser = argparse.ArgumentParser(
        description="Create a train/val split of the PlantVillage dataset.")
    parser.add_argument('--source-dir', default='PlantVillage-Dataset/raw/color',
                        help="directory containing one subdirectory per class")
    parser.add_argument('--output-dir', default='PlantVillage',
                        help="directory that will receive train/ and val/")
    parser.add_argument('--val-split', type=float, default=0.2,
                        help="fraction of each class to place in val (default: 0.2)")
    args = parser.parse_args(argv)

    create_train_val_split(args.source_dir, args.output_dir, val_split=args.val_split)


if __name__ == "__main__":
    main()
|
|
|