---
license: mit
library_name: diffusers
tags:
- diffusion
- ddpm
- retinal-fundus
- image-generation
---

# Model Card for ddpm-unet-retinal-fundus-image-generator

A **U-Net-based Denoising Diffusion Probabilistic Model (DDPM)** trained to unconditionally generate **retinal fundus images** at 128x128 resolution. It was trained only on the *Moderate Diabetic Retinopathy* training subset, so its samples are expected to resemble that class. Generated images can be used to augment datasets for training diagnostic models or for other biomedical research tasks.

---

## Model Architecture

- **Base**: [`UNet2DModel`](https://huggingface.co/docs/diffusers/main/en/api/models/unet2d)
- **Scheduler**: [`DDPMScheduler`](https://huggingface.co/docs/diffusers/main/en/api/schedulers/ddpm)
- **Resolution**: `128x128`
- **Channels**: `RGB (3)`
- **Attention**: Spatial self-attention in mid-resolution blocks
- **Framework**: [🤗 Diffusers](https://github.com/huggingface/diffusers) + PyTorch
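
During training, the `DDPMScheduler` corrupts each image with Gaussian noise in closed form, $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$, and the U-Net learns to predict $\epsilon$ (in Diffusers this is `DDPMScheduler.add_noise`). The NumPy sketch below illustrates that forward process. It assumes the scheduler's default linear beta schedule (`beta_start=0.0001`, `beta_end=0.02`) together with the 1000 timesteps listed under Training Configuration; the card does not state which beta schedule was actually used, and the input is a random placeholder rather than a real fundus image.

```python
import numpy as np

# Assumption: linear beta schedule with DDPMScheduler's default endpoints.
# The card only states that 1000 diffusion timesteps were used.
NUM_TRAIN_TIMESTEPS = 1000
betas = np.linspace(1e-4, 0.02, NUM_TRAIN_TIMESTEPS, dtype=np.float64)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)  # \bar{alpha}_t for t = 0..999


def add_noise(x0: np.ndarray, noise: np.ndarray, t: int) -> np.ndarray:
    """Sample x_t ~ q(x_t | x_0) in closed form for a single timestep t."""
    sqrt_ab = np.sqrt(alphas_cumprod[t])
    sqrt_one_minus_ab = np.sqrt(1.0 - alphas_cumprod[t])
    return sqrt_ab * x0 + sqrt_one_minus_ab * noise


rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(3, 128, 128))  # placeholder for a normalized CHW image
noise = rng.standard_normal(x0.shape)
x_early = add_noise(x0, noise, t=10)   # still close to x0
x_late = add_noise(x0, noise, t=999)   # almost pure Gaussian noise
```

Because $\bar\alpha_t$ shrinks monotonically toward zero, late timesteps carry almost no trace of the original image, which is why sampling can start from pure noise.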

---

## Dataset

- **Source**: [Kaggle - Retinal Fundus Images](https://www.kaggle.com/datasets/kssanjaynithish03/retinal-fundus-images)
- **Subset Used**: `train/Moderate Diabetic Retinopathy`
- **Preprocessing**:
  - Resized to `128x128`
  - Normalized to `[-1, 1]`
  - Random horizontal flip
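
The preprocessing steps above can be sketched in NumPy as follows. This is an illustrative stand-in, not the original training script: it assumes images arrive as `uint8` arrays of shape `(H, W, 3)`, uses nearest-neighbour resizing for simplicity (the card does not say which interpolation was used), and flips with probability 0.5, which is also an assumption. In a PyTorch pipeline the same steps are commonly written with torchvision's `Resize`, `RandomHorizontalFlip`, `ToTensor`, and `Normalize([0.5], [0.5])` transforms.

```python
import numpy as np


def resize_nearest(img: np.ndarray, size: int = 128) -> np.ndarray:
    """Nearest-neighbour resize of an (H, W, C) array to (size, size, C)."""
    h, w = img.shape[:2]
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    return img[rows[:, None], cols[None, :]]


def preprocess(img_uint8: np.ndarray, rng: np.random.Generator, size: int = 128) -> np.ndarray:
    """Resize, randomly flip, and scale a uint8 RGB image to float32 in [-1, 1], CHW layout."""
    img = resize_nearest(img_uint8, size)
    # Assumption: flip probability 0.5 (the usual default for random horizontal flips).
    if rng.random() < 0.5:
        img = img[:, ::-1]  # reverse the width axis
    x = img.astype(np.float32) / 255.0   # [0, 1]
    x = x * 2.0 - 1.0                    # [-1, 1], equivalent to Normalize(mean=0.5, std=0.5)
    return np.transpose(x, (2, 0, 1))    # HWC -> CHW, the layout PyTorch models expect


rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(512, 640, 3), dtype=np.uint8)  # synthetic placeholder image
x = preprocess(raw, rng)
```

Scaling to `[-1, 1]` matches the range the sampler clips its predictions to, so generated pixels map back to `[0, 255]` with the inverse transform `(x / 2 + 0.5) * 255`.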

---

## Training Configuration

| Setting             | Value              |
|---------------------|--------------------|
| Epochs              | 35                 |
| Batch size          | 16                 |
| Optimizer           | AdamW              |
| Learning rate       | 1e-4               |
| LR scheduler        | Cosine with warmup |
| Precision           | Mixed (fp16)       |
| Diffusion timesteps | 1000               |
| Image samples saved | Every 10 epochs    |
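
A cosine schedule with warmup raises the learning rate linearly from 0 to its peak (`1e-4` here) over a warmup period, then decays it along a half cosine to 0. The sketch below mirrors the formula of `get_cosine_schedule_with_warmup` from `diffusers.optimization` with its default half cycle. The warmup length and the size of the training subset are not given in this card, so `NUM_IMAGES` and `WARMUP_STEPS` are hypothetical values chosen only for illustration, and one optimizer step per batch (no gradient accumulation) is assumed.

```python
import math

PEAK_LR = 1e-4       # from the training table
EPOCHS = 35          # from the training table
BATCH_SIZE = 16      # from the training table
NUM_IMAGES = 1600    # hypothetical: the card does not state the subset size
WARMUP_STEPS = 500   # hypothetical: the card does not state the warmup length

steps_per_epoch = math.ceil(NUM_IMAGES / BATCH_SIZE)
total_steps = steps_per_epoch * EPOCHS


def lr_at(step: int, warmup: int = WARMUP_STEPS, total: int = total_steps,
          peak: float = PEAK_LR) -> float:
    """Linear warmup followed by a half-cycle cosine decay to zero."""
    if step < warmup:
        return peak * step / max(1, warmup)
    progress = (step - warmup) / max(1, total - warmup)
    return peak * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
```

With these placeholder values, training runs 3500 steps: the rate reaches `1e-4` at step 500, falls to half of that at the midpoint of the decay (step 2000), and reaches 0 at the final step.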

Training used 🤗 Accelerate, with TensorBoard for logging.

---

## How to Use
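
```python
from diffusers import DDPMPipeline
import torch
import matplotlib.pyplot as plt

# Download the trained U-Net and scheduler from the Hugging Face Hub.
pipeline = DDPMPipeline.from_pretrained("GS-23/ddpm-unet-retinal-fundus-image-generator")

# Sampling runs 1000 denoising steps by default, so a GPU is strongly recommended.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = pipeline.to(device)

# A seeded CPU generator makes the initial noise, and hence the output, reproducible.
generator = torch.Generator(device="cpu").manual_seed(0)
images = pipeline(batch_size=1, generator=generator).images  # list of 128x128 PIL images

for i, img in enumerate(images):
    img.save(f"fundus_sample_{i}.png")
    plt.imshow(img)
    plt.axis("off")
    plt.show()
```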
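
Internally, `DDPMPipeline` starts from Gaussian noise and calls the scheduler's reverse step once per timestep (`num_inference_steps=1000` by default), feeding it the U-Net's noise prediction. The NumPy sketch below re-implements a simplified version of that loop to show the arithmetic. It assumes the default linear beta schedule, the "fixed small" (posterior) variance, and clipping of the predicted clean image to `[-1, 1]`; `predict_noise` is a hypothetical placeholder for the trained `UNet2DModel`, so the result is not a real fundus image.

```python
import numpy as np

# Assumption: same linear schedule as DDPMScheduler's defaults.
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)


def predict_noise(x_t: np.ndarray, t: int) -> np.ndarray:
    """Hypothetical stand-in for the trained U-Net's epsilon prediction."""
    return np.zeros_like(x_t)


def ddpm_step(x_t: np.ndarray, eps: np.ndarray, t: int, rng: np.random.Generator) -> np.ndarray:
    """One ancestral sampling step x_t -> x_{t-1}."""
    ab_t = alphas_cumprod[t]
    ab_prev = alphas_cumprod[t - 1] if t > 0 else 1.0
    # Estimate the clean image from the noise prediction, then clip to the data range.
    x0_pred = np.clip((x_t - np.sqrt(1.0 - ab_t) * eps) / np.sqrt(ab_t), -1.0, 1.0)
    # Mean of the posterior q(x_{t-1} | x_t, x_0).
    coef_x0 = np.sqrt(ab_prev) * betas[t] / (1.0 - ab_t)
    coef_xt = np.sqrt(alphas[t]) * (1.0 - ab_prev) / (1.0 - ab_t)
    mean = coef_x0 * x0_pred + coef_xt * x_t
    if t == 0:
        return mean  # no noise is added at the final step
    # "Fixed small" variance: the variance of q(x_{t-1} | x_t, x_0).
    var = betas[t] * (1.0 - ab_prev) / (1.0 - ab_t)
    return mean + np.sqrt(var) * rng.standard_normal(x_t.shape)


def sample(shape=(3, 128, 128), seed: int = 0) -> np.ndarray:
    """Run the full reverse chain and return an (H, W, 3) uint8 image."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(shape)
    for t in reversed(range(T)):
        x = ddpm_step(x, predict_noise(x, t), t, rng)
    # Map [-1, 1] back to 8-bit RGB in HWC layout, as the pipeline does before building PIL images.
    img = np.clip(x / 2.0 + 0.5, 0.0, 1.0)
    return (np.transpose(img, (1, 2, 0)) * 255).round().astype(np.uint8)
```

For example, `sample(shape=(3, 128, 128))` returns a `(128, 128, 3)` `uint8` array. With the zero-noise placeholder the output is not meaningful; substituting a trained noise predictor for `predict_noise` turns the loop into a standard DDPM sampler.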

---

## 📌 Use Cases

- Data augmentation for diabetic retinopathy classifiers
- Retinal pathology simulation and training
- Medical generative AI research
- Domain-specific image synthesis