import os
import h5py
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split

def load_h5_file_multitask(file_path):
    # tf.numpy_function hands the path over as bytes; decode it before use.
    if isinstance(file_path, bytes):
        file_path = file_path.decode('utf-8')

    with h5py.File(file_path, 'r') as f:
        image = f['image'][()]  # (240, 240, 4)
        mask = f['mask'][()]    # (240, 240) OR (240, 240, 3)

    # Per-channel z-score normalization
    image = (image - np.mean(image, axis=(0, 1), keepdims=True)) / \
            (np.std(image, axis=(0, 1), keepdims=True) + 1e-6)

    if mask.ndim == 3 and mask.shape[-1] == 3:
        # One-hot format (NCR, ED, ET) → derive binary region targets
        ncr = mask[..., 0]
        ed = mask[..., 1]
        et = mask[..., 2]
        wt = ((ncr + ed + et) > 0).astype(np.float32)[..., np.newaxis]  # Whole tumor
        tc = ((ncr + et) > 0).astype(np.float32)[..., np.newaxis]      # Tumor core
        et = (et > 0).astype(np.float32)[..., np.newaxis]              # Enhancing tumor
    else:
        # Single-channel label map (BraTS convention: 1 = NCR/NET, 2 = ED, 4 = ET)
        mask = mask.astype(np.uint8)
        wt = (mask > 0).astype(np.float32)[..., np.newaxis]             # Whole tumor
        tc = np.isin(mask, [1, 4]).astype(np.float32)[..., np.newaxis]  # Tumor core
        et = (mask == 4).astype(np.float32)[..., np.newaxis]            # Enhancing tumor

    return image.astype(np.float32), wt, tc, et
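
# Quick illustrative check of the single-channel mapping above (BraTS label
# convention: 1 = NCR/NET, 2 = ED, 4 = ET); not part of the pipeline:
#   m = np.array([[0, 1], [2, 4]], dtype=np.uint8)
#   (m > 0)             # WT -> [[0, 1], [1, 1]]
#   np.isin(m, [1, 4])  # TC -> [[0, 1], [0, 1]]
#   (m == 4)            # ET -> [[0, 0], [0, 1]]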

def _parse_multitask_function(path):
    image, wt, tc, et = tf.numpy_function(
        load_h5_file_multitask, [path],
        [tf.float32, tf.float32, tf.float32, tf.float32]
    )
    # tf.numpy_function erases static shape information; restore it so
    # downstream layers can infer their input shapes.
    image.set_shape((240, 240, 4))
    wt.set_shape((240, 240, 1))
    tc.set_shape((240, 240, 1))
    et.set_shape((240, 240, 1))
    masks = {
        'wt_head': wt,
        'tc_head': tc,
        'et_head': et
    }
    return image, masks
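
# The dict keys above are meant to match the output-layer names of a
# multi-head Keras model. A minimal compile sketch, assuming a model whose
# heads are named 'wt_head', 'tc_head' and 'et_head' (no such model is
# defined in this file):
#
#   model.compile(
#       optimizer='adam',
#       loss={'wt_head': 'binary_crossentropy',
#             'tc_head': 'binary_crossentropy',
#             'et_head': 'binary_crossentropy'},
#   )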

def get_dataset(file_paths, batch_size=8, shuffle=False, num_workers=4):
    dataset = tf.data.Dataset.from_tensor_slices(file_paths)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=1000)
    dataset = dataset.map(_parse_multitask_function, num_parallel_calls=num_workers)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset

def get_train_val_datasets(data_dir, batch_size=8, test_size=0.2, random_state=42):
    all_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.h5')]
    train_files, val_files = train_test_split(all_files, test_size=test_size, random_state=random_state)

    train_dataset = get_dataset(train_files, batch_size=batch_size, shuffle=True, num_workers=4)
    val_dataset = get_dataset(val_files, batch_size=batch_size, shuffle=False, num_workers=4)

    print(f"Total files: {len(all_files)}")
    print(f"Train files: {len(train_files)}")
    print(f"Val files: {len(val_files)}")
    return train_dataset, val_dataset
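
# Minimal usage sketch; 'data/brats_slices' is a placeholder path, not a
# directory this file guarantees to exist.
if __name__ == '__main__':
    train_ds, val_ds = get_train_val_datasets('data/brats_slices', batch_size=8)
    images, masks = next(iter(train_ds))
    print(images.shape)            # (8, 240, 240, 4)
    print(masks['wt_head'].shape)  # (8, 240, 240, 1)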