#!/usr/bin/env python
# coding: utf-8
"""Train a multi-head U-Net (WT / TC / ET outputs) on BraTS 2020 slices and
visualize its predictions on one validation batch.

Exported from a Jupyter notebook; the ``# In[..]`` markers denote the
original cell boundaries.
"""

# In[1]:

import os

# TensorFlow reads these variables when it is first imported (the allocator
# setting when the GPU device is initialized), so they must be set BEFORE
# `tensorflow` -- or any project module that may import it -- is loaded.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # 3=errors only
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async'

from glob import glob

import h5py
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping

from dataprep_multitask import get_train_val_datasets
from models.unet_multitask import build_unet_multioutput
from utils.losses import dice_loss, dice_coefficient, focal_tversky_loss  # custom metrics
from visualize import visualize_batch

# Set seeds for reproducibility
tf.random.set_seed(42)
np.random.seed(42)

# Paths
DATA_DIR = "BraTS20/BraTS2020_training_data/content/data"
# NOTE(review): MODEL_SAVE_PATH is not used by this script; the best model is
# written to CHECKPOINT_PATH by the ModelCheckpoint callback below.
MODEL_SAVE_PATH = "models/unet_brats.keras"
CHECKPOINT_PATH = "models/unet_multihead_brats.keras"

# Hyperparameters
BATCH_SIZE = 8
EPOCHS = 30

# Output heads. The visualization code assumes model.predict() returns them
# as preds[0], preds[1], preds[2] in this order -- TODO confirm against
# build_unet_multioutput.
HEAD_NAMES = ('wt_head', 'tc_head', 'et_head')

# Metric driving checkpointing, LR scheduling and early stopping (higher is better).
MONITOR = 'val_et_head_dice_coefficient'

# NOTE(review): assumes input channel 3 is FLAIR -- confirm against dataprep_multitask.
FLAIR_CHANNEL = 3

# Probability cut-off used to binarize the sigmoid-style head outputs for display.
PRED_THRESHOLD = 0.5

# Maximum number of samples from the validation batch to plot per cell.
NUM_SAMPLES_TO_SHOW = 3

# Load datasets
train_dataset, val_dataset = get_train_val_datasets(DATA_DIR, batch_size=BATCH_SIZE)

# ------------------- Model Setup -------------------
model = build_unet_multioutput(input_shape=(240, 240, 4))

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss={name: focal_tversky_loss for name in HEAD_NAMES},
    # ET is the smallest / hardest region, so it is weighted double.
    loss_weights={'wt_head': 1.0, 'tc_head': 1.0, 'et_head': 2.0},
    metrics={name: dice_coefficient for name in HEAD_NAMES},
)

# ------------------- Callbacks -------------------
callbacks = [
    ModelCheckpoint(CHECKPOINT_PATH, save_best_only=True, monitor=MONITOR, mode='max'),
    # mode='max' is required: in 'auto' mode ReduceLROnPlateau only maximizes
    # metrics whose name contains "acc", so it would otherwise treat a rising
    # Dice score as "no improvement" and cut the LR while the model improves.
    ReduceLROnPlateau(monitor=MONITOR, mode='max', factor=0.5, patience=4,
                      min_lr=1e-6, verbose=1),
    EarlyStopping(monitor=MONITOR, mode='max', patience=8, restore_best_weights=True),
]

# ------------------- Training -------------------
# get_train_val_datasets() is given batch_size and the cells below index
# individual samples inside each element, so both datasets already yield
# batches. Keras therefore iterates each one in full per epoch; the former
# `len(ds) // BATCH_SIZE` step counts divided the batch count by the batch
# size a second time and used only ~1/BATCH_SIZE of the data per epoch.
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=EPOCHS,
    callbacks=callbacks,
)


# ------------------- Visualization helpers -------------------

def _flair_to_rgb(flair):
    """Min-max normalize a 2-D FLAIR slice to [0, 1] and replicate it into 3 RGB channels.

    The 1e-8 term avoids division by zero on a constant (e.g. empty) slice.
    """
    flair_norm = (flair - flair.min()) / (flair.max() - flair.min() + 1e-8)
    return np.stack([flair_norm] * 3, axis=-1)


def _threshold_sample(preds, i, threshold=PRED_THRESHOLD):
    """Binarize sample ``i`` of every head's prediction into a float32 0/1 mask dict keyed by head name."""
    return {
        name: (preds[k][i] > threshold).astype(np.float32)
        for k, name in enumerate(HEAD_NAMES)
    }


def _ground_truth_sample(y_batch, i):
    """Extract sample ``i`` of every head's ground-truth mask as a NumPy array, keyed by head name."""
    return {name: y_batch[name][i].numpy() for name in HEAD_NAMES}


def overlay_mask(flair, wt, tc, et):
    """Build RGB overlay on FLAIR using WT (red), TC (green), ET (blue) masks.

    Args:
        flair: 2-D FLAIR slice used as the grayscale background.
        wt, tc, et: masks indexed with ``[..., 0]`` (a trailing channel axis is
            expected); pixels equal to 1 are painted.

    Returns:
        An ``(H, W, 3)`` float RGB image. Masks are painted in WT, TC, ET
        order, so where regions overlap the later color wins.
    """
    overlay = _flair_to_rgb(flair)
    overlay[wt[..., 0] == 1] = [1, 0, 0]  # Red: WT
    overlay[tc[..., 0] == 1] = [0, 1, 0]  # Green: TC
    overlay[et[..., 0] == 1] = [0, 0, 1]  # Blue: ET
    return overlay


def overlay_prediction(img, pred_masks):
    """Overlay predicted masks on the FLAIR channel of a multi-channel input image.

    Args:
        img: ``(H, W, C)`` input slice; channel ``FLAIR_CHANNEL`` is used as background.
        pred_masks: dict mapping each name in ``HEAD_NAMES`` to a binary mask.
    """
    flair = img[:, :, FLAIR_CHANNEL]
    return overlay_mask(flair, pred_masks['wt_head'], pred_masks['tc_head'], pred_masks['et_head'])


def overlay_errors(flair, gt, pred):
    """
    Compare prediction vs ground truth:
    - TP (correct) regions keep their color
    - FP (predicted but not GT): Magenta
    - FN (GT but not predicted): Yellow

    ``gt`` and ``pred`` are dicts keyed by the names in ``HEAD_NAMES``. Heads are
    drawn in WT, TC, ET order, so a later head's colors overwrite an earlier
    head's at the same pixel.
    """
    overlay = _flair_to_rgb(flair)
    head_colors = ([1, 0, 0], [0, 1, 0], [0, 0, 1])
    for mask_name, color in zip(HEAD_NAMES, head_colors):
        gt_mask = gt[mask_name][..., 0]
        pred_mask = pred[mask_name][..., 0]
        tp = np.logical_and(gt_mask == 1, pred_mask == 1)
        fn = np.logical_and(gt_mask == 1, pred_mask == 0)  # Missed
        fp = np.logical_and(gt_mask == 0, pred_mask == 1)  # Extra
        overlay[tp] = color      # Correct
        overlay[fn] = [1, 1, 0]  # Yellow for FN
        overlay[fp] = [1, 0, 1]  # Magenta for FP
    return overlay


def _show_panels(panels, figsize):
    """Plot ``(image, title, cmap)`` panels side by side in one figure and show it."""
    plt.figure(figsize=figsize)
    for col, (image, title, cmap) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), col)
        plt.imshow(image, cmap=cmap)
        plt.title(title)
        plt.axis('off')
    plt.tight_layout()
    plt.show()


# In[4]:

print("\nGenerating predictions and visualizing...\n")

# FLAIR next to the predicted-mask overlay, for the first few samples of one validation batch.
for x_batch, y_batch in val_dataset.take(1):
    preds = model.predict(x_batch)
    for i in range(min(NUM_SAMPLES_TO_SHOW, x_batch.shape[0])):
        image = x_batch[i].numpy()
        flair = image[:, :, FLAIR_CHANNEL]
        pred_overlay = overlay_prediction(image, _threshold_sample(preds, i))
        _show_panels(
            [(flair, "FLAIR", 'gray'), (pred_overlay, "Predicted Mask Overlay", None)],
            figsize=(12, 4),
        )


# In[5]:

# Ground truth vs prediction overlays.
for x_batch, y_batch in val_dataset.take(1):
    preds = model.predict(x_batch)
    # Guard against validation batches with fewer samples than we want to show.
    for i in range(min(NUM_SAMPLES_TO_SHOW, x_batch.shape[0])):
        flair = x_batch[i].numpy()[:, :, FLAIR_CHANNEL]
        pred_masks = _threshold_sample(preds, i)
        gt_masks = _ground_truth_sample(y_batch, i)

        gt_overlay = overlay_mask(flair, gt_masks['wt_head'], gt_masks['tc_head'], gt_masks['et_head'])
        pred_overlay = overlay_mask(flair, pred_masks['wt_head'], pred_masks['tc_head'], pred_masks['et_head'])

        _show_panels(
            [(gt_overlay, "Ground Truth Overlay", None), (pred_overlay, "Predicted Overlay", None)],
            figsize=(10, 5),
        )


# In[6]:

# Ground truth, prediction, and per-pixel error overlays.
for x_batch, y_batch in val_dataset.take(1):
    preds = model.predict(x_batch)
    for i in range(min(NUM_SAMPLES_TO_SHOW, x_batch.shape[0])):
        flair = x_batch[i].numpy()[:, :, FLAIR_CHANNEL]
        pred_masks = _threshold_sample(preds, i)
        gt_masks = _ground_truth_sample(y_batch, i)

        gt_overlay = overlay_mask(flair, gt_masks['wt_head'], gt_masks['tc_head'], gt_masks['et_head'])
        pred_overlay = overlay_mask(flair, pred_masks['wt_head'], pred_masks['tc_head'], pred_masks['et_head'])
        error_overlay = overlay_errors(flair, gt_masks, pred_masks)

        _show_panels(
            [
                (gt_overlay, "Ground Truth Overlay", None),
                (pred_overlay, "Predicted Overlay", None),
                (error_overlay, "Error Overlay\nFN=Yellow | FP=Magenta", None),
            ],
            figsize=(18, 5),
        )


# In[ ]: