---
title: Brain tumor segmentation
emoji: 🧠
colorFrom: green
colorTo: red
sdk: streamlit
sdk_version: 1.45.1
app_file: app.py
pinned: false
---
# Brain Tumor Segmentation using Multi-Output U-Net
## Table of Contents

- Project Overview
- Model Architecture
- Installation
- Dataset Preparation
- Training
- Visualization
- Performance Metrics
- Customization
- Demo App
- Troubleshooting
- License
- Citation
## Project Overview
This project implements a multi-task U-Net model for brain tumor segmentation from MRI scans. The model simultaneously predicts three tumor sub-regions:
- Whole Tumor (WT)
- Tumor Core (TC)
- Enhancing Tumor (ET)
The implementation uses TensorFlow/Keras with custom loss functions and visualization tools for model evaluation.
## Model Architecture

The model is based on a U-Net architecture with the following key components:

### Encoder Path

- 4 encoding blocks with [64, 128, 256, 512] filters
- Each block consists of:
  - Two 3×3 convolutional layers with BatchNorm and ReLU
  - Max pooling (2×2) for downsampling

### Bottleneck

- 1024 filters with 50% dropout for regularization

### Decoder Path

- 4 decoding blocks with [512, 256, 128, 64] filters
- Each block consists of:
  - Transposed convolution (2×2) for upsampling
  - Concatenation with the corresponding encoder skip connection
  - Two 3×3 convolutional layers with BatchNorm and ReLU

### Multi-Task Heads

- Three parallel output heads (1×1 convolution + sigmoid):
  - `wt_head` (whole tumor)
  - `tc_head` (tumor core)
  - `et_head` (enhancing tumor)
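A minimal Keras sketch of this architecture, assuming `build_unet_multioutput` follows the blocks described above (the actual implementation in `model.py` may differ in detail):

```python
from tensorflow.keras import layers, Model

def conv_block(x, filters):
    # Two 3x3 convs, each followed by BatchNorm and ReLU
    for _ in range(2):
        x = layers.Conv2D(filters, 3, padding="same", use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
    return x

def build_unet_multioutput(input_shape=(240, 240, 4), dropout_rate=0.5):
    inputs = layers.Input(shape=input_shape)

    # Encoder: 4 blocks with [64, 128, 256, 512] filters
    skips, x = [], inputs
    for f in [64, 128, 256, 512]:
        x = conv_block(x, f)
        skips.append(x)                # saved for the skip connection
        x = layers.MaxPooling2D(2)(x)  # 2x2 downsampling

    # Bottleneck: 1024 filters with dropout for regularization
    x = layers.Dropout(dropout_rate)(conv_block(x, 1024))

    # Decoder: 4 blocks with [512, 256, 128, 64] filters
    for f, skip in zip([512, 256, 128, 64], reversed(skips)):
        x = layers.Conv2DTranspose(f, 2, strides=2, padding="same")(x)  # upsample
        x = layers.Concatenate()([x, skip])                             # skip connection
        x = conv_block(x, f)

    # Three parallel heads: 1x1 conv + sigmoid, one per tumor sub-region
    wt = layers.Conv2D(1, 1, activation="sigmoid", name="wt_head")(x)
    tc = layers.Conv2D(1, 1, activation="sigmoid", name="tc_head")(x)
    et = layers.Conv2D(1, 1, activation="sigmoid", name="et_head")(x)

    return Model(inputs, [wt, tc, et], name="unet_multioutput")
```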
## Installation

1. Clone this repository:

```bash
git clone https://github.com/yourusername/brain-tumor-segmentation.git
cd brain-tumor-segmentation
```

2. Create and activate a virtual environment:

```bash
python -m venv venv
source venv/bin/activate  # Linux/Mac
venv\Scripts\activate     # Windows
```

3. Install dependencies:

```bash
pip install -r requirements.txt
```
## Dataset Preparation

The model expects input images of shape (240, 240, 4), where the channels are:

- T1-weighted
- T1-weighted with contrast
- T2-weighted
- FLAIR

### Preprocessing

- Normalize each modality separately (zero mean, unit variance; see the sketch after this list)
- Resample all images to 1 mm isotropic resolution
- Register all modalities to a common space
- Crop/pad to (240, 240) size
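A minimal sketch of the per-modality normalization step, assuming a stacked `(H, W, 4)` NumPy array (the helper name `normalize_modalities` is illustrative, not part of the package):

```python
import numpy as np

def normalize_modalities(volume):
    """Z-score normalize each MRI modality (channel) independently.

    volume: array of shape (H, W, 4) holding T1, T1c, T2, FLAIR.
    """
    out = np.empty_like(volume, dtype=np.float32)
    for c in range(volume.shape[-1]):
        channel = volume[..., c].astype(np.float32)
        # Zero mean, unit variance per channel; epsilon guards flat channels
        out[..., c] = (channel - channel.mean()) / (channel.std() + 1e-7)
    return out
```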
## Training

To train the model:

```python
from model import build_unet_multioutput
from losses import focal_tversky_loss, dice_coefficient  # adjust if dice_coefficient lives elsewhere

# Build model
model = build_unet_multioutput(input_shape=(240, 240, 4))

# Compile with one loss and metric per output head
model.compile(
    optimizer="adam",
    loss={
        "wt_head": focal_tversky_loss,
        "tc_head": focal_tversky_loss,
        "et_head": focal_tversky_loss,
    },
    metrics={
        "wt_head": dice_coefficient,
        "tc_head": dice_coefficient,
        "et_head": dice_coefficient,
    },
)

# Train
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=num_epochs,
    callbacks=[...],
)
```
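Because the model has three named output heads, Keras expects the targets as a dictionary keyed by head name. A minimal `tf.data` sketch of such a pipeline, assuming the images and per-region masks are already loaded as NumPy arrays (all variable names here are illustrative):

```python
import tensorflow as tf

# images: (N, 240, 240, 4); each mask array: (N, 240, 240, 1) with values in {0, 1}
train_dataset = (
    tf.data.Dataset.from_tensor_slices(
        (images, {"wt_head": wt_masks, "tc_head": tc_masks, "et_head": et_masks})
    )
    .shuffle(256)
    .batch(8)
    .prefetch(tf.data.AUTOTUNE)
)
```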
## Visualization

The package includes visualization utilities to compare predictions with ground truth.

### Overlay Types

**Ground Truth Overlay:**

- WT: red
- TC: green
- ET: blue

**Prediction Overlay:**

- Same color scheme as the ground truth overlay

**Error Overlay:**

- True positives: original colors
- False negatives (missed tumor): yellow
- False positives (extra predictions): magenta

Example visualization code:

```python
from visualization import overlay_mask, overlay_errors

# For one sample
flair = x_batch[i, ..., 3]  # FLAIR channel
gt_overlay = overlay_mask(flair, gt_wt, gt_tc, gt_et)
pred_overlay = overlay_mask(flair, pred_wt, pred_tc, pred_et)
error_overlay = overlay_errors(flair, gt_masks, pred_masks)
```
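Conceptually, `overlay_mask` blends each binary mask onto the grayscale slice in its region color. A minimal sketch of that idea, assuming 2D binary masks (the packaged implementation may differ):

```python
import numpy as np

def overlay_mask(image, wt, tc, et, alpha=0.5):
    """Blend WT/TC/ET masks onto a grayscale slice: WT=red, TC=green, ET=blue."""
    # Normalize the slice to [0, 1] and promote it to RGB
    img = (image - image.min()) / (image.max() - image.min() + 1e-7)
    rgb = np.stack([img, img, img], axis=-1)
    colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]  # red, green, blue
    for mask, color in zip([wt, tc, et], colors):
        m = mask.astype(bool)
        rgb[m] = (1 - alpha) * rgb[m] + alpha * np.array(color, dtype=np.float32)
    return rgb
```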
## Performance Metrics

Dice scores:

| Region | Dice Score |
|--------|------------|
| WT     | 0.8548     |
| TC     | 0.7484     |
| ET     | 0.7242     |
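For reference, the Dice score is `2 * |A ∩ B| / (|A| + |B|)` for predicted and ground-truth masks A and B. A minimal sketch of a `dice_coefficient` metric in Keras, assuming binary masks (the packaged version may differ in smoothing details):

```python
from tensorflow.keras import backend as K

def dice_coefficient(y_true, y_pred, smooth=1e-7):
    # Flatten to 1D so the metric works for any batch/spatial shape
    y_true = K.flatten(K.cast(y_true, "float32"))
    y_pred = K.flatten(y_pred)
    intersection = K.sum(y_true * y_pred)
    # smooth avoids 0/0 when both masks are empty
    return (2.0 * intersection + smooth) / (K.sum(y_true) + K.sum(y_pred) + smooth)
```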
## Customization

### Model Parameters

Modify `build_unet_multioutput()` to change:

- Input shape
- Number of filters
- Dropout rate
- Depth of the network

### Loss Functions

Available loss functions:

- `focal_tversky_loss`: focuses on hard examples
- `dice_loss`: standard Dice loss implementation
- `binary_crossentropy`: traditional BCE
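A minimal sketch of a focal Tversky loss in its standard formulation, with illustrative defaults (alpha > beta penalizes missed tumor pixels more than false alarms; the packaged version may use different parameters):

```python
from tensorflow.keras import backend as K

def focal_tversky_loss(y_true, y_pred, alpha=0.7, beta=0.3, gamma=0.75, eps=1e-7):
    y_true = K.flatten(K.cast(y_true, "float32"))
    y_pred = K.flatten(y_pred)
    tp = K.sum(y_true * y_pred)
    fn = K.sum(y_true * (1.0 - y_pred))
    fp = K.sum((1.0 - y_true) * y_pred)
    # Tversky index: TP / (TP + alpha*FN + beta*FP)
    tversky = (tp + eps) / (tp + alpha * fn + beta * fp + eps)
    # Raising (1 - TI) to gamma < 1 up-weights hard examples
    return K.pow(1.0 - tversky, gamma)
```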
## Demo App

Sample overlay screenshots (prediction vs. ground truth, using the error-overlay color scheme above) and a Streamlit demo app (`app.py`) are included in this Space.
## Troubleshooting

**Out of Memory Errors:**

- Reduce batch size
- Use mixed precision training
- Crop images to a smaller size

**Poor Convergence:**

- Check data normalization
- Adjust the learning rate
- Try different loss weights

**NaN Losses:**

- Add a small epsilon (1e-7) to denominators
- Clip predictions (e.g., to [1e-7, 1 - 1e-7])
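Two of these fixes are one-liners in TensorFlow. A minimal sketch (`safe_bce` is an illustrative example, not part of the package):

```python
from tensorflow.keras import backend as K, mixed_precision

# Mixed precision: compute in float16 to roughly halve activation memory
mixed_precision.set_global_policy("mixed_float16")

# Clipping predictions inside a loss avoids log(0) and zero denominators
def safe_bce(y_true, y_pred):
    y_pred = K.clip(y_pred, 1e-7, 1.0 - 1e-7)
    return K.mean(K.binary_crossentropy(y_true, y_pred))
```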
## License
This project is licensed under the MIT License. See LICENSE file for details.
## Citation

If you use this code in your research, please cite:

```bibtex
@misc{brain-tumor-segmentation-unet,
  author       = {Muzenda K},
  title        = {Multi-Output U-Net for Brain Tumor Segmentation},
  year         = {2025},
  publisher    = {GitHub},
  journal      = {GitHub repository},
  howpublished = {\url{https://github.com/Muzenda-K/brain-tumor-segmentation}}
}
```