# brain-tumor-segmentation / dataprep_multitask.py
import os
import h5py
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
def load_h5_file_multitask(file_path):
    """Load one BraTS-style .h5 slice and derive binary WT/TC/ET targets."""
    # tf.numpy_function passes string tensors to Python as bytes, so decode first.
    if isinstance(file_path, bytes):
        file_path = file_path.decode('utf-8')
    with h5py.File(file_path, 'r') as f:
        image = f['image'][()]  # (240, 240, 4)
        mask = f['mask'][()]    # (240, 240) or (240, 240, 3)
# Normalize image
image = (image - np.mean(image, axis=(0, 1), keepdims=True)) / \
(np.std(image, axis=(0, 1), keepdims=True) + 1e-6)
    # Handle the case where the mask is already one-hot encoded (3 channels)
if mask.ndim == 3 and mask.shape[-1] == 3:
# One-hot format (NCR, ED, ET) → derive binary targets
ncr = mask[..., 0]
ed = mask[..., 1]
et = mask[..., 2]
wt = ((ncr + ed + et) > 0).astype(np.float32)[..., np.newaxis] # Whole tumor
tc = ((ncr + et) > 0).astype(np.float32)[..., np.newaxis] # Tumor core
et = (et > 0).astype(np.float32)[..., np.newaxis] # Enhancing tumor
    else:
        # Single-channel BraTS label map: 1 = NCR/NET, 2 = ED (edema), 4 = ET (enhancing)
        mask = mask.astype(np.uint8)
        wt = (mask > 0).astype(np.float32)[..., np.newaxis]             # Whole tumor: labels 1, 2, 4
        tc = np.isin(mask, [1, 4]).astype(np.float32)[..., np.newaxis]  # Tumor core: labels 1, 4
        et = (mask == 4).astype(np.float32)[..., np.newaxis]            # Enhancing tumor: label 4
return image.astype(np.float32), wt, tc, et
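# A quick standalone check of the loader (hedged sketch; the path below is a
# placeholder for any single BraTS-style .h5 slice). Useful for confirming which
# mask-format branch is taken and that all four outputs have the expected shapes:
#
#   image, wt, tc, et = load_h5_file_multitask('data/volume_1_slice_75.h5')
#   print(image.shape, wt.shape, tc.shape, et.shape)
#   # -> (240, 240, 4) (240, 240, 1) (240, 240, 1) (240, 240, 1)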
def _parse_multitask_function(path):
    """Wrap the Python loader with tf.numpy_function so it can run inside a tf.data pipeline."""
    image, wt, tc, et = tf.numpy_function(
        load_h5_file_multitask, [path],
        [tf.float32, tf.float32, tf.float32, tf.float32]
    )
image.set_shape((240, 240, 4))
wt.set_shape((240, 240, 1))
tc.set_shape((240, 240, 1))
et.set_shape((240, 240, 1))
masks = {
'wt_head': wt,
'tc_head': tc,
'et_head': et
}
return image, masks
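# Returning the masks as a dict keyed by head name lets a multi-output Keras model
# consume this dataset directly, as long as the model's output layers use the same
# names. A hedged sketch (the model itself is assumed, not defined in this file):
#
#   model.compile(
#       optimizer='adam',
#       loss={'wt_head': 'binary_crossentropy',
#             'tc_head': 'binary_crossentropy',
#             'et_head': 'binary_crossentropy'},
#   )
#   model.fit(train_dataset, validation_data=val_dataset, epochs=10)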
def get_dataset(file_paths, batch_size=8, shuffle=False, num_workers=4):
    """Build a shuffled/batched/prefetched tf.data pipeline from a list of .h5 slice paths."""
dataset = tf.data.Dataset.from_tensor_slices(file_paths)
if shuffle:
dataset = dataset.shuffle(buffer_size=1000)
dataset = dataset.map(_parse_multitask_function, num_parallel_calls=num_workers)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
return dataset
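# The pipeline's structure can be inspected without reading any files, since the
# shapes are fixed with set_shape above (the file names here are placeholders):
#
#   ds = get_dataset(['slice_0.h5', 'slice_1.h5'], batch_size=2)
#   print(ds.element_spec)
#   # -> (TensorSpec((None, 240, 240, 4), tf.float32),
#   #     {'wt_head': ..., 'tc_head': ..., 'et_head': ...})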
def get_train_val_datasets(data_dir, batch_size=8, test_size=0.2, random_state=42):
    """Split the .h5 files in data_dir into train/val tf.data pipelines."""
    # Sort so the split is reproducible (os.listdir order is not guaranteed).
    all_files = sorted(os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.h5'))
    train_files, val_files = train_test_split(all_files, test_size=test_size, random_state=random_state)
train_dataset = get_dataset(train_files, batch_size=batch_size, shuffle=True, num_workers=4)
val_dataset = get_dataset(val_files, batch_size=batch_size, shuffle=False, num_workers=4)
print(f"Total files: {len(all_files)}")
print(f"Train files: {len(train_files)}")
print(f"Val files: {len(val_files)}")
return train_dataset, val_dataset
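if __name__ == "__main__":
    # Minimal usage sketch. DATA_DIR is a placeholder; point it at a folder of
    # per-slice .h5 files that each contain 'image' and 'mask' datasets.
    DATA_DIR = "data/BraTS2020_training_data"  # hypothetical path

    train_ds, val_ds = get_train_val_datasets(DATA_DIR, batch_size=8)

    # Pull a single batch to verify shapes before wiring the pipeline into a model.
    images, masks = next(iter(train_ds))
    print("images:", images.shape)  # (batch, 240, 240, 4)
    for name, m in masks.items():
        print(name, m.shape)        # (batch, 240, 240, 1) for each head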