Upload 55 files
This view is limited to 50 files because it contains too many changes.
- README.md +59 -20
- dnafiber/__init__.py +1 -0
- dnafiber/__pycache__/__init__.cpython-312.pyc +0 -0
- dnafiber/__pycache__/deployment.cpython-312.pyc +0 -0
- dnafiber/__pycache__/inference.cpython-312.pyc +0 -0
- dnafiber/__pycache__/metric.cpython-312.pyc +0 -0
- dnafiber/__pycache__/post_process.cpython-312.pyc +0 -0
- dnafiber/__pycache__/trainee.cpython-312.pyc +0 -0
- dnafiber/analysis/__init__.py +0 -0
- dnafiber/analysis/chart.py +61 -0
- dnafiber/analysis/const.py +3 -0
- dnafiber/analysis/utils.py +21 -0
- dnafiber/callbacks.py +50 -0
- dnafiber/data/__init__.py +0 -0
- dnafiber/data/__pycache__/__init__.cpython-312.pyc +0 -0
- dnafiber/data/__pycache__/utils.cpython-312.pyc +0 -0
- dnafiber/data/dataset.py +271 -0
- dnafiber/data/intergrader/__init__.py +1 -0
- dnafiber/data/intergrader/__pycache__/__init__.cpython-312.pyc +0 -0
- dnafiber/data/intergrader/__pycache__/analysis.cpython-312.pyc +0 -0
- dnafiber/data/intergrader/__pycache__/const.cpython-312.pyc +0 -0
- dnafiber/data/intergrader/__pycache__/io.cpython-312.pyc +0 -0
- dnafiber/data/intergrader/__pycache__/plot.cpython-312.pyc +0 -0
- dnafiber/data/intergrader/analysis.py +120 -0
- dnafiber/data/intergrader/auto.py +3 -0
- dnafiber/data/intergrader/const.py +21 -0
- dnafiber/data/intergrader/io.py +27 -0
- dnafiber/data/intergrader/plot.py +172 -0
- dnafiber/data/utils.py +80 -0
- dnafiber/deployment.py +44 -0
- dnafiber/inference.py +105 -0
- dnafiber/metric.py +150 -0
- dnafiber/model/maskrcnn.py +0 -0
- dnafiber/postprocess/__init__.py +1 -0
- dnafiber/postprocess/__pycache__/__init__.cpython-312.pyc +0 -0
- dnafiber/postprocess/__pycache__/core.cpython-312.pyc +0 -0
- dnafiber/postprocess/__pycache__/fiber.cpython-312.pyc +0 -0
- dnafiber/postprocess/__pycache__/skan.cpython-312.pyc +0 -0
- dnafiber/postprocess/core.py +274 -0
- dnafiber/postprocess/fiber.py +129 -0
- dnafiber/postprocess/skan.py +211 -0
- dnafiber/start.py +22 -0
- dnafiber/trainee.py +148 -0
- dnafiber/ui/Welcome.py +47 -0
- dnafiber/ui/__init__.py +0 -0
- dnafiber/ui/__pycache__/__init__.cpython-312.pyc +0 -0
- dnafiber/ui/__pycache__/inference.cpython-312.pyc +0 -0
- dnafiber/ui/__pycache__/utils.cpython-312.pyc +0 -0
- dnafiber/ui/inference.py +69 -0
- dnafiber/ui/pages/1_Load.py +196 -0
README.md
CHANGED
@@ -1,20 +1,59 @@
# DN-AI

This is the official repository for DN-AI, an automated tool for the measurement of differentiated DNA replication in fluorescence microscopy images.

DN-AI offers different solutions for biologists to measure DNA replication in fluorescence microscopy images without requiring programming skills. See the [Installation](#installation) section for instructions on how to install DN-AI.

## Features

- **Automated DNA replication measurement**: DN-AI automatically measures the amount of DNA replication in fluorescence microscopy images, using a deep learning model to segment the images.
- **User-friendly interface**: DN-AI provides a user-friendly, web-based interface that lets users upload images and view the results. Both JPEG and TIFF images are supported.
- **Batch processing**: DN-AI can process multiple images at once, making it easy to analyze large datasets. It also supports comparing ratios between different batches of images.

## Installation

DN-AI relies on Python. We recommend installing a recent version (3.10 or higher) and using a virtual environment to avoid conflicts with other packages.

### Prerequisites

Before installing DN-AI, make sure you have the following prerequisites installed:

- [Python 3.10 or higher](https://www.python.org/downloads/)
- [pip](https://pip.pypa.io/en/stable/installation/) (Python package installer)

### Python Package

To install DN-AI as a Python package, you can use pip:

```bash
pip install git+https://github.com/ClementPla/DeepFiberQ.git
```
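
The package can also be used directly from Python. A minimal sketch based on the modules in this repository (`example.jpg` and the CUDA device are placeholders):

```python
import cv2
from dnafiber.inference import infer

# "segformer_mit_b2" is one of the revisions listed in dnafiber/deployment.py.
image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
mask = infer("segformer_mit_b2", image, device="cuda", scale=0.13)
# mask contains 0 (background), 1 (first analog) and 2 (second analog)
```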

### Graphical User Interface (GUI)

To run the DN-AI graphical user interface, use the following command:

```bash
DNAI
```

Make sure you run this command in the terminal (and environment) where you installed DN-AI; it will start a local web server.
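The exact startup output depends on the underlying server; the pages under `dnafiber/ui/` and the default port 8501 suggest Streamlit, whose banner looks roughly like:

```
  You can now view your Streamlit app in your browser.

  Local URL: http://localhost:8501
```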

Then open your web browser and go to `http://localhost:8501` to access the DN-AI interface.

Screenshots of the GUI:



### Docker

A Docker image is available for DN-AI. You can pull the image from Docker Hub:

```bash
docker pull clementpla/dnafiber
```
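
The entrypoint of the image is not shown in this commit; assuming it serves the same web UI on port 8501, a typical invocation would be:

```bash
docker run --rm -p 8501:8501 clementpla/dnafiber
```

Then browse to `http://localhost:8501` as above.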

### Google Colab

We also provide a Google Colab notebook for DN-AI. You can access it [here](https://colab.research.google.com/github/ClementPla/DeepFiberQ/blob/main/Colab/DNA_Fiber_Q.ipynb).
dnafiber/__init__.py
ADDED
@@ -0,0 +1 @@

```python
from dnafiber.deployment import _get_model
```
dnafiber/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (219 Bytes).

dnafiber/__pycache__/deployment.cpython-312.pyc
ADDED
Binary file (2.46 kB).

dnafiber/__pycache__/inference.cpython-312.pyc
ADDED
Binary file (4.95 kB).

dnafiber/__pycache__/metric.cpython-312.pyc
ADDED
Binary file (6.79 kB).

dnafiber/__pycache__/post_process.cpython-312.pyc
ADDED
Binary file (5.36 kB).

dnafiber/__pycache__/trainee.cpython-312.pyc
ADDED
Binary file (7.73 kB).
dnafiber/analysis/__init__.py
ADDED
File without changes
dnafiber/analysis/chart.py
ADDED
@@ -0,0 +1,61 @@

```python
import pandas as pd
from dnafiber.analysis.const import palette
import plotly.express as px


def get_color_association(df):
    """
    Get the color association for each image in the dataframe.
    """
    unique_name = df["image_name"].unique()
    color_association = {i: p for (i, p) in zip(unique_name, palette)}
    return color_association


def plot_ratio(df, color_association=None, only_bilateral=True):
    df = df[["ratio", "image_name", "fiber_type"]].copy()

    df["Image"] = df["image_name"]
    df["Fiber Type"] = df["fiber_type"]
    df["Ratio"] = df["ratio"]
    if only_bilateral:
        df = df[df["Fiber Type"] == "double"]

    df = df.sort_values(
        by=["Image", "Fiber Type"],
        ascending=[True, True],
    )

    # Order the dataframe by the median ratio of each image
    image_order = (
        df.groupby("Image")["Ratio"].median().sort_values(ascending=True).index
    )
    df["Image"] = pd.Categorical(df["Image"], categories=image_order, ordered=True)
    df.sort_values("Image", inplace=True)
    # unique_name must be defined even when a color association is provided.
    unique_name = df["image_name"].unique()
    if color_association is None:
        color_association = get_color_association(df)

    this_palette = [color_association[i] for i in unique_name]
    fig = px.violin(
        df,
        y="Ratio",
        x="Image",
        color="Image",
        color_discrete_sequence=this_palette,
        box=True,  # draw box plot inside the violin
        points="all",  # can be 'outliers', or False
    )

    # Make the fig taller
    fig.update_layout(
        height=500,
        width=1000,
        title="Ratio of green to red",
        yaxis_title="Ratio",
        xaxis_title="Image",
        legend_title="Image",
    )
    return fig
```
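A hedged usage sketch for `plot_ratio`; the column names follow what the function reads, and the values are made up:

```python
import pandas as pd
from dnafiber.analysis.chart import plot_ratio

df = pd.DataFrame(
    {
        "image_name": ["A", "A", "A", "B", "B", "B"],
        "fiber_type": ["double"] * 6,
        "ratio": [1.2, 0.9, 1.1, 1.5, 1.3, 1.6],
    }
)
fig = plot_ratio(df)  # returns a plotly Figure with one violin per image
fig.show()
```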
dnafiber/analysis/const.py
ADDED
@@ -0,0 +1,3 @@

```python
from catppuccin.palette import PALETTE

palette = [c.hex for c in PALETTE.latte.colors]
```
dnafiber/analysis/utils.py
ADDED
@@ -0,0 +1,21 @@

```python
from tqdm.auto import tqdm
from dnafiber.data.utils import read_colormask
import numpy as np


def build_consensus_map(intergraders, root_img, list_img):
    all_masks = []
    for img_path in tqdm(list_img):
        path_from_root = img_path.relative_to(root_img)
        masks = []
        for intergrader in intergraders:
            intergrader_path = (intergrader / path_from_root).with_suffix(".png")
            if not intergrader_path.exists():
                print(f"Missing {intergrader_path}")
                continue
            mask = read_colormask(intergrader_path)
            masks.append(mask)
        masks = np.array(masks)

        all_masks.append(masks)
    return np.array(all_masks)
```
dnafiber/callbacks.py
ADDED
@@ -0,0 +1,50 @@

```python
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.utilities import rank_zero_only
import wandb


class LogPredictionSamplesCallback(Callback):
    def __init__(self, wandb_logger, n_images=8):
        self.n_images = n_images
        self.wandb_logger = wandb_logger
        super().__init__()

    @rank_zero_only
    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if batch_idx < 1 and trainer.is_global_zero:
            n = self.n_images
            x = batch["image"][:n].float()
            h, w = x.shape[-2:]
            y = batch["mask"][:n]
            pred = outputs[:n]
            pred = pred.argmax(dim=1)

            if len(y.shape) == 4:
                y = y.squeeze(1)
            if len(pred.shape) == 4:
                pred = pred.squeeze(1)
            y = y.clamp(0, 2)
            columns = ["image"]
            class_labels = {0: "Background", 1: "Red", 2: "Green"}

            data = [
                [
                    wandb.Image(
                        x_i,
                        masks={
                            "Prediction": {
                                "mask_data": p_i.cpu().numpy(),
                                "class_labels": class_labels,
                            },
                            "Groundtruth": {
                                "mask_data": y_i.cpu().numpy(),
                                "class_labels": class_labels,
                            },
                        },
                    )
                ]
                for x_i, y_i, p_i in list(zip(x, y, pred))
            ]
            self.wandb_logger.log_table(
                data=data, key=f"Validation Batch {batch_idx}", columns=columns
            )
```
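A hedged sketch of wiring this callback into a Lightning trainer with a Weights & Biases logger; the project name is a placeholder:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger
from dnafiber.callbacks import LogPredictionSamplesCallback

logger = WandbLogger(project="dnafiber")
trainer = Trainer(logger=logger, callbacks=[LogPredictionSamplesCallback(logger)])
```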
dnafiber/data/__init__.py
ADDED
File without changes

dnafiber/data/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (167 Bytes).

dnafiber/data/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (4.66 kB).
dnafiber/data/dataset.py
ADDED
@@ -0,0 +1,271 @@

```python
import albumentations as A
import nntools.dataset as D
import numpy as np
from albumentations.pytorch import ToTensorV2
from lightning import LightningDataModule
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from skimage.measure import label, regionprops
from skimage.morphology import skeletonize, dilation
from skimage.segmentation import expand_labels
import torch
from nntools.dataset.composer import CacheBullet


@D.nntools_wrapper
def convert_mask(mask):
    output = np.zeros(mask.shape[:2], dtype=np.uint8)
    output[mask[:, :, 0] > 200] = 1
    output[mask[:, :, 1] > 200] = 2
    binary_mask = output > 0
    skeleton = skeletonize(binary_mask) * output
    output = expand_labels(skeleton, 3)
    output = np.clip(output, 0, 2)
    return {"mask": output}


@D.nntools_wrapper
def extract_bbox(mask):
    binary_mask = mask > 0
    labelled = label(binary_mask)
    props = regionprops(labelled, intensity_image=mask)
    skeleton = skeletonize(binary_mask) * mask
    mask = dilation(skeleton, np.ones((3, 3)))
    bboxes = []
    masks = []
    # We want the XYXY format
    for prop in props:
        minr, minc, maxr, maxc = prop.bbox
        bboxes.append([minc, minr, maxc, maxr])
        masks.append((labelled == prop.label).astype(np.uint8))
    if not masks:
        masks = np.zeros_like(mask)[np.newaxis, :, :]
    masks = np.array(masks)
    masks = np.moveaxis(masks, 0, -1)

    return {
        "bboxes": np.array(bboxes),
        "mask": masks,
        "fiber_ids": np.array([p.label for p in props]),
    }


class FiberDatamodule(LightningDataModule):
    def __init__(
        self,
        root_img,
        crop_size=(256, 256),
        shape=1024,
        batch_size=32,
        num_workers=8,
        use_bbox=False,
        **kwargs,
    ):
        self.shape = shape
        self.root_img = str(root_img)
        self.crop_size = crop_size
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.kwargs = kwargs
        self.use_bbox = use_bbox

        super().__init__()

    def setup(self, *args, **kwargs):
        def _get_dataset(version):
            dataset = D.MultiImageDataset(
                {
                    "image": f"{self.root_img}/{version}/images/",
                    "mask": f"{self.root_img}/{version}/annotations/",
                },
                shape=(self.shape, self.shape),
                use_cache=self.kwargs.get("use_cache", False),
                cache_option=self.kwargs.get("cache_option", None),
            )  # type: ignore
            dataset.img_filepath["image"] = np.asarray(  # type: ignore
                sorted(
                    list(dataset.img_filepath["image"]),
                    key=lambda x: (x.parent.stem, x.stem),
                )
            )
            dataset.img_filepath["mask"] = np.asarray(  # type: ignore
                sorted(
                    list(dataset.img_filepath["mask"]),
                    key=lambda x: (x.parent.stem, x.stem),
                )
            )
            dataset.composer = D.Composition()
            dataset.composer << convert_mask  # type: ignore
            if self.use_bbox:
                dataset.composer << extract_bbox

            return dataset

        self.train = _get_dataset("train")
        self.val = _get_dataset("train")
        self.test = _get_dataset("test")
        self.train.composer << CacheBullet()
        self.val.use_cache = False
        self.test.use_cache = False

        stratify = []
        for f in self.train.img_filepath["image"]:
            if "tile" in f.stem:
                stratify.append(int(f.parent.stem))
            else:
                stratify.append(25)
        train_idx, val_idx = train_test_split(
            np.arange(len(self.train)),  # type: ignore
            stratify=stratify,
            test_size=0.2,
            random_state=42,
        )
        self.train.subset(train_idx)
        self.val.subset(val_idx)

        self.train.composer.add(*self.get_train_composer())
        self.val.composer.add(*self.cast_operators())
        self.test.composer.add(*self.cast_operators())

    def get_train_composer(self):
        transforms = []
        if self.crop_size is not None:
            transforms.append(
                A.CropNonEmptyMaskIfExists(
                    width=self.crop_size[0], height=self.crop_size[1]
                ),
            )
        return [
            A.Compose(
                transforms
                + [
                    A.HorizontalFlip(),
                    A.VerticalFlip(),
                    A.Affine(),
                    A.ElasticTransform(),
                    A.RandomRotate90(),
                    A.OneOf(
                        [
                            A.RandomBrightnessContrast(
                                brightness_limit=(-0.2, 0.1),
                                contrast_limit=(-0.2, 0.1),
                                p=0.5,
                            ),
                            A.HueSaturationValue(
                                hue_shift_limit=(-5, 5),
                                sat_shift_limit=(-20, 20),
                                val_shift_limit=(-20, 20),
                                p=0.5,
                            ),
                        ]
                    ),
                    A.GaussNoise(std_range=(0.0, 0.1), p=0.5),
                ],
                bbox_params=A.BboxParams(
                    format="pascal_voc", label_fields=["fiber_ids"], min_visibility=0.95
                )
                if self.use_bbox
                else None,
            ),
            *self.cast_operators(),
        ]

    def cast_operators(self):
        return [
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
            if not self.use_bbox
            else A.Normalize(
                mean=(0.0, 0.0, 0.0),
                std=(1.0, 1.0, 1.0),
                max_pixel_value=255,
            ),
            ToTensorV2(),
        ]

    def train_dataloader(self):
        if self.use_bbox:
            return DataLoader(
                self.train,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=self.num_workers,
                pin_memory=True,
                persistent_workers=True,
                collate_fn=bbox_collate_fn,
            )
        else:
            return DataLoader(
                self.train,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=self.num_workers,
                pin_memory=True,
                persistent_workers=True,
            )

    def val_dataloader(self):
        if self.use_bbox:
            return DataLoader(
                self.val,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.num_workers,
                pin_memory=True,
                persistent_workers=True,
                collate_fn=bbox_collate_fn,
            )
        return DataLoader(
            self.val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        if self.use_bbox:
            return DataLoader(
                self.test,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.num_workers,
                pin_memory=True,
                persistent_workers=True,
                collate_fn=bbox_collate_fn,
            )
        return DataLoader(
            self.test,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )


def bbox_collate_fn(batch):
    images = []
    targets = []

    for b in batch:
        target = dict()

        target["boxes"] = torch.from_numpy(b["bboxes"])
        if target["boxes"].shape[0] == 0:
            target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
        images.append(b["image"])
        target["masks"] = b["mask"].permute(2, 0, 1)
        if target["boxes"].shape[0] == 0:
            target["labels"] = torch.zeros(1, dtype=torch.int64)
        else:
            target["labels"] = torch.ones_like(target["boxes"][:, 0], dtype=torch.int64)

        targets.append(target)

    return {
        "image": torch.stack(images),
        "targets": targets,
    }
```
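A hedged sketch of driving the datamodule above; the root path is a placeholder and must follow the `train`/`test` layout with `images/` and `annotations/` subfolders that `setup` expects:

```python
from dnafiber.data.dataset import FiberDatamodule

dm = FiberDatamodule(root_img="data/", crop_size=(256, 256), batch_size=8)
dm.setup()
batch = next(iter(dm.train_dataloader()))
print(batch["image"].shape, batch["mask"].shape)
```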
dnafiber/data/intergrader/__init__.py
ADDED
@@ -0,0 +1 @@

```python
from .const import *
```
dnafiber/data/intergrader/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (206 Bytes).

dnafiber/data/intergrader/__pycache__/analysis.cpython-312.pyc
ADDED
Binary file (6.06 kB).

dnafiber/data/intergrader/__pycache__/const.cpython-312.pyc
ADDED
Binary file (994 Bytes).

dnafiber/data/intergrader/__pycache__/io.cpython-312.pyc
ADDED
Binary file (1.72 kB).

dnafiber/data/intergrader/__pycache__/plot.cpython-312.pyc
ADDED
Binary file (6.6 kB).
dnafiber/data/intergrader/analysis.py
ADDED
@@ -0,0 +1,120 @@

```python
from skimage.morphology import skeletonize
import numpy as np
from skimage.measure import label
from tqdm.contrib.concurrent import thread_map  # or process_map


def extract_fiber_properties(mask):
    binary_mask = mask > 0
    skeleton = skeletonize(binary_mask)
    r = mask == 1
    g = mask == 2
    labeled_skeleton = label(skeleton, connectivity=2)
    properties = {"R": [], "G": [], "ratio": []}
    for i in range(1, labeled_skeleton.max() + 1):
        fiber_mask = labeled_skeleton == i
        sum_r = np.sum(r & fiber_mask)
        sum_g = np.sum(g & fiber_mask)
        if sum_r == 0 or sum_g == 0:
            continue
        properties["R"].append(sum_r)
        properties["G"].append(sum_g)

    properties["R"] = np.array(properties["R"])
    properties["G"] = np.array(properties["G"])
    properties["ratio"] = properties["R"] / properties["G"]
    properties["label"] = labeled_skeleton
    return properties


def filter_non_commons_fibers(properties):
    # `properties` is a list of dicts. Each dict has a labelmap and arrays of
    # red lengths, green lengths and ratios.
    # We want to filter out the fibers that are not common to all images.
    binary_labels = [p["label"] > 0 for p in properties]
    common_labels = np.logical_and.reduce(binary_labels)
    filtered_properties = []
    for p in properties:
        # We keep the labels that intersect the region common to all images.
        good_labels = common_labels * p["label"]
        indices = np.unique(good_labels[good_labels > 0])
        # NOTE: this assumes entry k of the per-fiber arrays corresponds to
        # label k + 1 of the labelmap.
        keep = indices.astype(int) - 1
        keep = keep[keep < len(p["R"])]
        filtered_properties.append(
            {
                "R": p["R"][keep],
                "G": p["G"][keep],
                "ratio": p["ratio"][keep],
                "label": np.where(np.isin(p["label"], indices), p["label"], 0),
            }
        )
    return filtered_properties


def skeletonize_mask(mask):
    # Skeletonize the mask and return the skeleton
    binary_mask = mask > 0
    skeleton = skeletonize(binary_mask) * mask
    return skeleton


def skeletonize_data_dict(data_dict):
    skeletons = dict()
    for annotator, images in data_dict.items():
        skeletons[annotator] = dict()
        for image_type, masks in images.items():
            skeletons[annotator][image_type] = thread_map(
                skeletonize_mask, masks, max_workers=8
            )

    return skeletons


def extract_properties_from_datadict(data_dict, with_common_analysis=True):
    """
    Extract the properties of the fibers from the data dictionary.
    The data dictionary is a dict of annotators. Each value is a dict of images.
    Each image is a list of masks.
    """
    properties = dict(
        annotator=[], image_type=[], red=[], green=[], ratio=[], fiber_type=[]
    )
    all_annotators = list(data_dict.keys())

    found_by = {a: [] for a in all_annotators}
    properties.update(found_by)
    for annotator, images in data_dict.items():
        for image_type, masks in images.items():
            for i, mask in enumerate(masks):
                if with_common_analysis:
                    others_masks = []
                    other_annotators = []
                    for other in all_annotators:
                        if other == annotator:
                            continue
                        other_annotators.append(other)
                        others_masks.append(data_dict[other][image_type][i] > 0)

                labels, num = label(mask > 0, connectivity=2, return_num=True)
                for l in range(1, num + 1):
                    fiber = labels == l
                    if np.sum(fiber) < 10:
                        continue

                    # Skip single-color fibers before appending anything, so
                    # the per-key lists all stay the same length.
                    red_length = np.sum(fiber & (mask == 1))
                    green_length = np.sum(fiber & (mask == 2))
                    if red_length == 0 or green_length == 0:
                        continue

                    properties["annotator"].append(annotator)
                    properties["image_type"].append(image_type)

                    # Check for common fibers
                    properties[annotator].append(True)
                    if with_common_analysis:
                        for other_mask, other_annotator in zip(
                            others_masks, other_annotators
                        ):
                            properties[other_annotator].append(
                                np.any(fiber & other_mask)
                            )

                    properties["ratio"].append(
                        green_length / (red_length + 1e-7)  # Avoid division by zero
                    )
                    properties["red"].append(red_length)
                    properties["green"].append(green_length)

                    segments, count = label(mask[fiber], connectivity=1, return_num=True)
                    if count == 1:
                        properties["fiber_type"].append("single")
                    elif count == 2:
                        properties["fiber_type"].append("double")
                    elif count > 2:
                        properties["fiber_type"].append("multiple")
                    else:
                        properties["fiber_type"].append("unknown")

    return properties
```
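A toy example for `extract_fiber_properties`: one horizontal fiber whose left half is the red analog (1) and right half the green analog (2).

```python
import numpy as np
from dnafiber.data.intergrader.analysis import extract_fiber_properties

mask = np.zeros((64, 64), dtype=np.uint8)
mask[32, 10:30] = 1
mask[32, 30:50] = 2
props = extract_fiber_properties(mask)
print(props["R"], props["G"], props["ratio"])
```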
dnafiber/data/intergrader/auto.py
ADDED
@@ -0,0 +1,3 @@

```python
def inference_model(model, path, use_cuda=False):
    pass
```
dnafiber/data/intergrader/const.py
ADDED
@@ -0,0 +1,21 @@

```python
BLIND_MAPPING = {
    "siB+M-01": "0",
    "siB+M-04": "1",
    "siBRCA2-02": "5",
    "siBRCA2-03": "15",
    "siTONSL-03": "11",
    "siTONSL-04": "14",
    "HLTF ko+si MMS22L-01": "8",
    "HLTF ko+si MMS22L-02": "13",
    "siBRCA2+SMARCAL KO-01": "2",
    "siBRCA2+SMARCAL KO-03": "9",
    "siBRCA2+SMARCAL KO-04": "16",
    "siBRCA2-01": "4",
    "59_siBRCA2-02": "7",
    "siNT-01": "10",
    "siNT-02": "12",
    "siMMS22L_+dox-01": "3",
    "siMMS22L_+dox-02": "6",
}

REVERSE_BLIND_MAPPING = {v: k for k, v in BLIND_MAPPING.items()}
```
dnafiber/data/intergrader/io.py
ADDED
@@ -0,0 +1,27 @@

```python
import cv2
import numpy as np
from skimage.segmentation import expand_labels


def read_to_mask(f):
    img = cv2.imread(str(f), cv2.IMREAD_UNCHANGED)[:, :, ::-1]
    mask = np.zeros(img.shape[:2], dtype=np.uint8)
    mask[img[:, :, 0] > 200] = 1
    mask[img[:, :, 1] > 200] = 2

    return mask


def read_mask_from_path_gens(dict_gens, mapping=None):
    output = {k: dict() for k in dict_gens.keys()}
    for k, files in dict_gens.items():
        for file in files:
            name = file.parent.stem
            if mapping is not None:
                name = mapping.get(name, name)
            mask = read_to_mask(file)
            mask = expand_labels(mask, 1)
            if output[k].get(name) is None:
                output[k][name] = []
            output[k][name].append(mask)
    return output
```
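A hedged sketch of `read_mask_from_path_gens`: `dict_gens` maps an annotator name to an iterable of mask file paths; the directory layout here is a placeholder.

```python
from pathlib import Path
from dnafiber.data.intergrader.io import read_mask_from_path_gens
from dnafiber.data.intergrader.const import REVERSE_BLIND_MAPPING

dict_gens = {"grader_A": Path("graders/A").rglob("*.png")}
masks = read_mask_from_path_gens(dict_gens, mapping=REVERSE_BLIND_MAPPING)
```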
dnafiber/data/intergrader/plot.py
ADDED
@@ -0,0 +1,172 @@

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from skimage.measure import label, regionprops
import base64
from typing import Callable


def imshow_compare(data_dict, ax_size=4, draw_bbox=False, max_images=None):
    """
    Display the images in a grid format for comparison.
    Each key is an annotator, each value is another dict, where the key is the
    image type and the value the list of corresponding images.
    """
    # 0 is black, 1 is red, 2 is green
    cmap = ListedColormap(["black", "red", "green"])

    # Convert the data dictionary to a dict of annotators: list of images
    data = dict()
    for annotator, images in data_dict.items():
        if annotator not in data:
            data[annotator] = []
        for image_type, masks in images.items():
            for mask in masks:
                data[annotator].append(mask)

    annotators = list(data.keys())
    num_images = len(data[annotators[0]])
    if max_images is not None and num_images > max_images:
        num_images = max_images
    num_annotators = len(annotators)

    fig_size = (ax_size * num_annotators, ax_size * num_images)
    fig, axes = plt.subplots(num_images, num_annotators, figsize=fig_size, squeeze=False)

    for i, annotator in enumerate(annotators):
        for j in range(num_images):
            if max_images is not None and j > max_images:
                break
            ax = axes[j, i]
            mask = data[annotator][j]
            ax.imshow(mask, cmap=cmap, interpolation="nearest")
            ax.axis("off")
            ax.set_xticks([])
            ax.set_yticks([])
            if draw_bbox:
                mask = mask > 0
                labeled_mask = label(mask, connectivity=2)
                regions = regionprops(labeled_mask)
                for region in regions:
                    minr, minc, maxr, maxc = region.bbox
                    rect = plt.Rectangle(
                        (minc, minr),
                        maxc - minc,
                        maxr - minr,
                        fill=False,
                        edgecolor="yellow",
                        linewidth=0.5,
                    )
                    ax.add_patch(rect)

            if j == 0:
                ax.set_title(annotator)

    fig.tight_layout()
    return fig, axes


def add_p_value_annotation(
    fig,
    array_columns,
    stats_test,
    subplot=None,
    _format=dict(interline=0.07, text_height=1.07, color="black"),
):
    """Adds notations giving the p-value between two box plot data
    (t-test two-sided comparison).

    Parameters:
    ----------
    fig: figure
        plotly boxplot figure
    array_columns: np.array
        array of which columns to compare
        e.g.: [[0,1], [1,2]] compares column 0 with 1 and 1 with 2
    subplot: None or int
        specifies if the figure has subplots and what subplot to add the notation to
    _format: dict
        format characteristics for the lines

    Returns:
    -------
    fig: figure
        figure with the added notation
    """
    # Specify in what y_range to plot for each pair of columns
    y_range = np.zeros([len(array_columns), 2])
    for i in range(len(array_columns)):
        y_range[i] = [1.01 + i * _format["interline"], 1.02 + i * _format["interline"]]

    # Get values from figure
    fig_dict = fig.to_dict()
    # Get indices if working with subplots
    if subplot:
        if subplot == 1:
            subplot_str = ""
        else:
            subplot_str = str(subplot)
        indices = []  # Change the box index to the indices of the data for that subplot
        for index, data in enumerate(fig_dict["data"]):
            if data["xaxis"] == "x" + subplot_str:
                indices = np.append(indices, index)
        indices = [int(i) for i in indices]
    else:
        subplot_str = ""

    # Draw the p-value annotations
    for index, column_pair in enumerate(array_columns):
        if subplot:
            data_pair = [indices[column_pair[0]], indices[column_pair[1]]]
        else:
            data_pair = column_pair

        if isinstance(stats_test, Callable):
            # Get the p-value; plotly stores the y data base64-encoded in 'bdata'
            d1 = fig_dict["data"][data_pair[0]]["y"]
            d2 = fig_dict["data"][data_pair[1]]["y"]
            d1 = base64.b64decode(d1["bdata"])
            d2 = base64.b64decode(d2["bdata"])
            d1 = np.frombuffer(d1, dtype=np.float64)
            d2 = np.frombuffer(d2, dtype=np.float64)
            pvalue = stats_test(d1, d2)[1]
        else:
            pvalue = stats_test[index]
        if pvalue >= 0.05:
            symbol = "ns"
        elif pvalue >= 0.01:
            symbol = "*"
        elif pvalue >= 0.001:
            symbol = "**"
        else:
            symbol = "***"
        # Vertical line
        fig.add_shape(
            type="line",
            xref="x" + subplot_str,
            yref="y" + subplot_str + " domain",
            x0=column_pair[0],
            y0=y_range[index][0],
            x1=column_pair[0],
            y1=y_range[index][1],
            line=dict(color=_format["color"], width=2),
        )
        # Horizontal line
        fig.add_shape(
            type="line",
            xref="x" + subplot_str,
            yref="y" + subplot_str + " domain",
            x0=column_pair[0],
            y0=y_range[index][1],
            x1=column_pair[1],
            y1=y_range[index][1],
            line=dict(color=_format["color"], width=2),
        )
        # Vertical line
        fig.add_shape(
            type="line",
            xref="x" + subplot_str,
            yref="y" + subplot_str + " domain",
            x0=column_pair[1],
            y0=y_range[index][0],
            x1=column_pair[1],
            y1=y_range[index][1],
            line=dict(color=_format["color"], width=2),
        )
        # Add text at the correct x, y coordinates;
        # for bars, there is a direct mapping from the bar number to 0, 1, 2...
        fig.add_annotation(
            dict(
                font=dict(color=_format["color"], size=14),
                x=(column_pair[0] + column_pair[1]) / 2,
                y=y_range[index][1] * _format["text_height"],
                showarrow=False,
                text=symbol,
                textangle=0,
                xref="x" + subplot_str,
                yref="y" + subplot_str + " domain",
            )
        )
    return fig
```
dnafiber/data/utils.py
ADDED
@@ -0,0 +1,80 @@

```python
import base64

from xml.dom import minidom

import cv2
import numpy as np
from czifile import CziFile
from tifffile import imread


def read_svg(svg_path):
    doc = minidom.parse(str(svg_path))
    img_strings = {
        path.getAttribute("id"): path.getAttribute("href")
        for path in doc.getElementsByTagName("image")
    }
    doc.unlink()

    red = img_strings["Red"]
    green = img_strings["Green"]
    red = base64.b64decode(red.split(",")[1])
    green = base64.b64decode(green.split(",")[1])
    red = cv2.imdecode(np.frombuffer(red, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
    green = cv2.imdecode(np.frombuffer(green, dtype=np.uint8), cv2.IMREAD_UNCHANGED)

    red = cv2.cvtColor(red, cv2.COLOR_BGRA2GRAY)
    green = cv2.cvtColor(green, cv2.COLOR_BGRA2GRAY)
    mask = np.zeros_like(red)
    mask[red > 0] = 1
    mask[green > 0] = 2
    return mask


def extract_bboxes(mask):
    mask = np.array(mask)
    mask = mask.astype(np.uint8)

    # Find connected components
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        mask, connectivity=8
    )
    bboxes = []
    for i in range(1, num_labels):
        x, y, w, h, area = stats[i]
        bboxes.append([x, y, x + w, y + h])
    return bboxes


def preprocess(raw_data, reverse_channels=False):
    MAX_VALUE = 2**16 - 1
    if raw_data.ndim == 2:
        raw_data = raw_data[np.newaxis, :, :]
    h, w = raw_data.shape[1:3]
    orders = np.arange(raw_data.shape[0])[::-1]  # Reverse channel order
    result = np.zeros((h, w, 3), dtype=np.uint8)

    for i, chan in enumerate(raw_data):
        # Stretch each channel's contrast: clip at the intensity where the
        # cumulative histogram reaches 99%, then rescale to 8 bits.
        hist, bins = np.histogram(chan.ravel(), MAX_VALUE + 1, (0, MAX_VALUE + 1))
        cdf = hist.cumsum()
        cdf_normalized = cdf / cdf[-1]
        bmax = np.searchsorted(cdf_normalized, 0.99, side="left")
        clip = np.clip(chan, 0, bmax).astype(np.float32)
        clip = (clip - clip.min()) / (bmax - clip.min()) * 255
        result[:, :, orders[i]] = clip
    if reverse_channels:
        # Swap channels 0 and 1
        result = result[:, :, [1, 0, 2]]
    return result


def read_czi(filepath):
    data = CziFile(filepath)

    return data.asarray().squeeze()


def read_tiff(filepath):
    data = imread(filepath).squeeze()

    return data
```
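A hedged sketch: convert a 16-bit CZI acquisition into the 8-bit RGB layout used downstream; `example.czi` is a placeholder path.

```python
from dnafiber.data.utils import read_czi, preprocess

rgb = preprocess(read_czi("example.czi"))
print(rgb.shape, rgb.dtype)  # (H, W, 3), uint8
```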
dnafiber/deployment.py
ADDED
@@ -0,0 +1,44 @@

```python
from dnafiber.trainee import Trainee
from dnafiber.postprocess.fiber import FiberProps
import pandas as pd


def _get_model(revision, device="cuda"):
    if revision is None:
        model = Trainee.from_pretrained(
            "ClementP/DeepFiberQ", arch="unet", encoder_name="mit_b0"
        )
    else:
        model = Trainee.from_pretrained(
            "ClementP/DeepFiberQ",
            revision=revision,
        )
    return model.eval().to(device)


def format_results(results: list[FiberProps], pixel_size: float) -> pd.DataFrame:
    """
    Format the results for display in the UI.
    """
    results = [fiber for fiber in results if fiber.is_valid]
    all_results = dict(
        FirstAnalog=[], SecondAnalog=[], length=[], ratio=[], fiber_type=[]
    )
    all_results["FirstAnalog"].extend([fiber.red * pixel_size for fiber in results])
    all_results["SecondAnalog"].extend([fiber.green * pixel_size for fiber in results])
    all_results["length"].extend(
        [fiber.red * pixel_size + fiber.green * pixel_size for fiber in results]
    )
    all_results["ratio"].extend([fiber.ratio for fiber in results])
    all_results["fiber_type"].extend([fiber.fiber_type for fiber in results])

    return pd.DataFrame.from_dict(all_results)


MODELS_ZOO = {
    "Ensemble": "ensemble",
    "SegFormer MiT-B4": "segformer_mit_b4",
    "SegFormer MiT-B2": "segformer_mit_b2",
    "U-Net SE-ResNet50": "unet_se_resnet50",
}
```
dnafiber/inference.py
ADDED
@@ -0,0 +1,105 @@

```python
import torch.nn.functional as F
import numpy as np
import torch
from torchvision.transforms._functional_tensor import normalize
import pandas as pd
from skimage.segmentation import expand_labels
from skimage.measure import label
import albumentations as A
from albumentations.pytorch import ToTensorV2
from monai.inferers import SlidingWindowInferer
from dnafiber.deployment import _get_model
from dnafiber.postprocess import refine_segmentation

transform = A.Compose(
    [
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)


def preprocess_image(image):
    image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
    image = normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    return image


def convert_to_dataset(counts):
    data = {"index": [], "red": [], "green": [], "ratio": []}
    for k, v in counts.items():
        data["index"].append(k)
        data["green"].append(v["green"])
        data["red"].append(v["red"])
        if v["red"] == 0:
            data["ratio"].append(np.nan)
        else:
            data["ratio"].append(v["green"] / v["red"])
    df = pd.DataFrame(data)
    return df


def convert_mask_to_image(mask, expand=False):
    if expand:
        mask = expand_labels(mask, distance=expand)
    h, w = mask.shape
    image = np.zeros((h, w, 3), dtype=np.uint8)
    GREEN = np.array([0, 255, 0])
    RED = np.array([255, 0, 0])

    image[mask == 1] = RED
    image[mask == 2] = GREEN

    return image


@torch.inference_mode()
def infer(model, image, device, scale=0.13, to_numpy=True, only_probabilities=False):
    if isinstance(model, str):
        model = _get_model(device=device, revision=model)
    model_pixel_size = 0.26

    # Rescale the input so its pixel size matches the one the model was trained on.
    scale = scale / model_pixel_size
    tensor = transform(image=image)["image"].unsqueeze(0).to(device)
    h, w = tensor.shape[2], tensor.shape[3]
    device = torch.device(device)
    with torch.autocast(device_type=device.type):
        tensor = F.interpolate(
            tensor,
            size=(int(h * scale), int(w * scale)),
            mode="bilinear",
        )
        if tensor.shape[2] > 1024 or tensor.shape[3] > 1024:
            inferer = SlidingWindowInferer(
                roi_size=(1024, 1024),
                sw_batch_size=4,
                overlap=0.25,
                mode="gaussian",
                device=device,
                progress=True,
            )
            output = inferer(tensor, model)
        else:
            output = model(tensor)

        probabilities = F.softmax(output, dim=1)
        if only_probabilities:
            probabilities = probabilities.cpu()

            probabilities = F.interpolate(
                probabilities,
                size=(h, w),
                mode="bilinear",
            )
            return probabilities

        output = F.interpolate(
            probabilities.argmax(dim=1, keepdim=True).float(),
            size=(h, w),
            mode="nearest",
        )

    output = output.squeeze().byte()
    if to_numpy:
        output = output.cpu().numpy()

    return output
```
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import kornia as K
|
2 |
+
import torch
|
3 |
+
import torchmetrics.functional as F
|
4 |
+
from skimage.measure import label
|
5 |
+
from torchmetrics import Metric
|
6 |
+
|
7 |
+
|
8 |
+
class DNAFIBERMetric(Metric):
|
9 |
+
def __init__(self, **kwargs):
|
10 |
+
super().__init__(**kwargs)
|
11 |
+
|
12 |
+
self.add_state(
|
13 |
+
"detection_tp",
|
14 |
+
default=torch.tensor(0, dtype=torch.int64),
|
15 |
+
dist_reduce_fx="sum",
|
16 |
+
)
|
17 |
+
self.add_state(
|
18 |
+
"fiber_red_dice",
|
19 |
+
default=torch.tensor(0, dtype=torch.float32),
|
20 |
+
dist_reduce_fx="sum",
|
21 |
+
)
|
22 |
+
self.add_state(
|
23 |
+
"fiber_green_dice",
|
24 |
+
default=torch.tensor(0, dtype=torch.float32),
|
25 |
+
dist_reduce_fx="sum",
|
26 |
+
)
|
27 |
+
self.add_state(
|
28 |
+
"fiber_red_recall",
|
29 |
+
default=torch.tensor(0, dtype=torch.float32),
|
30 |
+
dist_reduce_fx="sum",
|
31 |
+
)
|
32 |
+
self.add_state(
|
33 |
+
"fiber_green_recall",
|
34 |
+
default=torch.tensor(0, dtype=torch.float32),
|
35 |
+
dist_reduce_fx="sum",
|
36 |
+
)
|
37 |
+
# Specificity
|
38 |
+
self.add_state(
|
39 |
+
"fiber_red_precision",
|
40 |
+
default=torch.tensor(0, dtype=torch.float32),
|
41 |
+
dist_reduce_fx="sum",
|
42 |
+
)
|
43 |
+
self.add_state(
|
44 |
+
"fiber_green_precision",
|
45 |
+
default=torch.tensor(0, dtype=torch.float32),
|
46 |
+
dist_reduce_fx="sum",
|
47 |
+
)
|
48 |
+
|
49 |
+
self.add_state(
|
50 |
+
"detection_fp",
|
51 |
+
default=torch.tensor(0, dtype=torch.int64),
|
52 |
+
dist_reduce_fx="sum",
|
53 |
+
)
|
54 |
+
self.add_state(
|
55 |
+
"N",
|
56 |
+
default=torch.tensor(0, dtype=torch.int64),
|
57 |
+
dist_reduce_fx="sum",
|
58 |
+
)
|
59 |
+
|
60 |
+
def update(self, preds, target):
|
61 |
+
if preds.ndim == 4:
|
62 |
+
preds = preds.argmax(dim=1)
|
63 |
+
if target.ndim == 4:
|
64 |
+
target = target.squeeze(1)
|
65 |
+
B, H, W = preds.shape
|
66 |
+
preds_labels = []
|
67 |
+
target_labels = []
|
68 |
+
binary_preds = preds > 0
|
69 |
+
binary_target = target > 0
|
70 |
+
N_true_labels = 0
|
71 |
+
for i in range(B):
|
72 |
+
pred = binary_preds[i].detach().cpu().numpy()
|
73 |
+
target_np = binary_target[i].detach().cpu().numpy()
|
74 |
+
pred_labels = label(pred, connectivity=2)
|
75 |
+
target_labels_np = label(target_np, connectivity=2)
|
76 |
+
preds_labels.append(torch.from_numpy(pred_labels).to(preds.device))
|
77 |
+
target_labels.append(torch.from_numpy(target_labels_np).to(preds.device))
|
78 |
+
N_true_labels += target_labels_np.max()
|
79 |
+
|
80 |
+
preds_labels = torch.stack(preds_labels)
|
81 |
+
target_labels = torch.stack(target_labels)
|
82 |
+
|
83 |
+
for i, plab in enumerate(preds_labels):
|
84 |
+
labels = torch.unique(plab)
|
85 |
+
for blob in labels:
|
86 |
+
if blob == 0:
|
87 |
+
continue
|
88 |
+
pred_mask = plab == blob
|
89 |
+
pixels_in_common = torch.any(pred_mask & binary_target[i])
|
90 |
+
if pixels_in_common:
|
91 |
+
self.detection_tp += 1
|
92 |
+
gt_label = target_labels[i][pred_mask].unique()[-1]
|
93 |
+
gt_mask = target_labels[i] == gt_label
|
94 |
+
common_mask = pred_mask | gt_mask
|
95 |
+
pred_fiber = preds[i][common_mask]
|
96 |
+
gt_fiber = target[i][common_mask]
|
97 |
+
dices = F.dice(
|
98 |
+
pred_fiber,
|
99 |
+
gt_fiber,
|
100 |
+
num_classes=3,
|
101 |
+
ignore_index=0,
|
102 |
+
average=None,
|
103 |
+
)
|
104 |
+
dices = torch.nan_to_num(dices, nan=0.0)
|
105 |
+
self.fiber_red_dice += dices[1]
|
106 |
+
self.fiber_green_dice += dices[2]
|
107 |
+
recalls = F.recall(
|
108 |
+
pred_fiber,
|
109 |
+
gt_fiber,
|
110 |
+
num_classes=3,
|
111 |
+
ignore_index=0,
|
112 |
+
task="multiclass",
|
113 |
+
average=None,
|
114 |
+
)
|
115 |
+
recalls = torch.nan_to_num(recalls, nan=0.0)
|
116 |
+
self.fiber_red_recall += recalls[1]
|
117 |
+
self.fiber_green_recall += recalls[2]
|
118 |
+
|
119 |
+
# Specificity
|
120 |
+
specificity = F.precision(
|
121 |
+
pred_fiber,
|
122 |
+
gt_fiber,
|
123 |
+
num_classes=3,
|
124 |
+
ignore_index=0,
|
125 |
+
task="multiclass",
|
126 |
+
average=None,
|
127 |
+
)
|
128 |
+
specificity = torch.nan_to_num(specificity, nan=0.0)
|
129 |
+
self.fiber_red_precision += specificity[1]
|
130 |
+
self.fiber_green_precision += specificity[2]
|
131 |
+
|
132 |
+
else:
|
133 |
+
self.detection_fp += 1
|
134 |
+
|
135 |
+
self.N += N_true_labels
|
136 |
+
|
137 |
+
def compute(self):
|
138 |
+
return {
|
139 |
+
"detection_precision": self.detection_tp
|
140 |
+
/ (self.detection_tp + self.detection_fp + 1e-7),
|
141 |
+
"detection_recall": self.detection_tp / (self.N + 1e-7),
|
142 |
+
"fiber_red_dice": self.fiber_red_dice / (self.detection_tp + 1e-7),
|
143 |
+
"fiber_green_dice": self.fiber_green_dice / (self.detection_tp + 1e-7),
|
144 |
+
"fiber_red_recall": self.fiber_red_recall / (self.detection_tp + 1e-7),
|
145 |
+
"fiber_green_recall": self.fiber_green_recall / (self.detection_tp + 1e-7),
|
146 |
+
"fiber_red_precision": self.fiber_red_precision
|
147 |
+
/ (self.detection_tp + 1e-7),
|
148 |
+
"fiber_green_precision": self.fiber_green_precision
|
149 |
+
/ (self.detection_tp + 1e-7),
|
150 |
+
}
|
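A quick sanity check of `DNAFIBERMetric` on random class maps (0 = background, 1 = red, 2 = green); the values are illustrative only:

```python
import torch
from dnafiber.metric import DNAFIBERMetric

metric = DNAFIBERMetric()
preds = torch.randint(0, 3, (2, 128, 128))
target = torch.randint(0, 3, (2, 128, 128))
metric.update(preds, target)
print(metric.compute())
```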
dnafiber/model/maskrcnn.py
ADDED
File without changes

dnafiber/postprocess/__init__.py
ADDED
@@ -0,0 +1 @@

```python
from .core import refine_segmentation
```

dnafiber/postprocess/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (225 Bytes).

dnafiber/postprocess/__pycache__/core.cpython-312.pyc
ADDED
Binary file (11.9 kB).

dnafiber/postprocess/__pycache__/fiber.cpython-312.pyc
ADDED
Binary file (6.39 kB).

dnafiber/postprocess/__pycache__/skan.cpython-312.pyc
ADDED
Binary file (9.43 kB).
dnafiber/postprocess/core.py
ADDED
@@ -0,0 +1,274 @@

```python
import numpy as np
import cv2
from typing import List, Tuple
from dnafiber.postprocess.skan import find_endpoints, compute_points_angle
from scipy.spatial.distance import cdist

from scipy.sparse.csgraph import connected_components
from scipy.sparse import csr_array
from skimage.morphology import skeletonize
from dnafiber.postprocess.skan import find_line_intersection
from dnafiber.postprocess.fiber import Fiber, FiberProps, Bbox
from itertools import compress
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

cmlabel = ListedColormap(["black", "red", "green"])

MIN_ANGLE = 20
MIN_BRANCH_LENGTH = 10
MIN_BRANCH_DISTANCE = 30


def handle_multiple_fiber_in_cc(fiber, junctions_fiber, coordinates):
    for y, x in junctions_fiber:
        fiber[y - 1 : y + 2, x - 1 : x + 2] = 0

    endpoints = find_endpoints(fiber > 0)
    endpoints = np.asarray(endpoints)
    # We only keep the endpoints that are close to the junctions.
    # We compute the distance between the endpoints and the junctions
    distances = np.linalg.norm(
        np.expand_dims(endpoints, axis=1) - np.expand_dims(junctions_fiber, axis=0),
        axis=2,
    )
    distances = distances < 5
    endpoints = endpoints[distances.any(axis=1)]

    retval, branches, branches_stats, _ = cv2.connectedComponentsWithStatsWithAlgorithm(
        fiber, connectivity=8, ccltype=cv2.CCL_BOLELLI, ltype=cv2.CV_16U
    )
    branches_bboxes = branches_stats[
        :,
        [cv2.CC_STAT_LEFT, cv2.CC_STAT_TOP, cv2.CC_STAT_WIDTH, cv2.CC_STAT_HEIGHT],
    ]

    num_branches = branches_bboxes.shape[0] - 1
    # We associate the endpoints to the branches
    endpoints_ids = np.zeros((len(endpoints),), dtype=np.uint16)
    endpoints_color = np.zeros((len(endpoints),), dtype=np.uint8)
    for i, endpoint in enumerate(endpoints):
        # Get the branch id
        branch_id = branches[endpoint[0], endpoint[1]]
        # Check if the branch id is not 0
        if branch_id != 0:
            endpoints_ids[i] = branch_id
            endpoints_color[i] = fiber[endpoint[0], endpoint[1]]

    # We remove the small branches
    kept_branches = set()
    for i in range(1, num_branches + 1):
        # Get the branch
        branch = branches == i
        # Compute the area of the branch
        area = np.sum(branch.astype(np.uint8))
        # If the area is less than MIN_BRANCH_LENGTH pixels, remove the branch
        if area < MIN_BRANCH_LENGTH:
            branches[branch] = 0
        else:
            kept_branches.add(i)

    # We remove the endpoints that are in the filtered branches
    remaining_idxs = np.isin(endpoints_ids, np.asarray(list(kept_branches)))
    if remaining_idxs.sum() == 0:
        return []
    endpoints = endpoints[remaining_idxs]

    endpoints_color = endpoints_color[remaining_idxs]
    endpoints_ids = endpoints_ids[remaining_idxs]

    # We compute the angles of the endpoints
    angles = compute_points_angle(fiber, endpoints, steps=15)
    angles = np.rad2deg(angles)
    # We compute the difference of angles between all the endpoints
    endpoints_angles_diff = cdist(angles[:, None], angles[:, None], metric="cityblock")

    # Put inf on the diagonal
    endpoints_angles_diff[range(len(endpoints)), range(len(endpoints))] = np.inf
    endpoints_distances = cdist(endpoints, endpoints, metric="euclidean")

    endpoints_distances[range(len(endpoints)), range(len(endpoints))] = np.inf

    # We sort by the distance
    endpoints_distances[endpoints_distances > MIN_BRANCH_DISTANCE] = np.inf
    endpoints_distances[endpoints_angles_diff > MIN_ANGLE] = np.inf

    matchB = np.argmin(endpoints_distances, axis=1)
    values = np.take_along_axis(endpoints_distances, matchB[:, None], axis=1)

    added_edges = dict()
    N = len(endpoints)
    A = np.eye(N, dtype=np.uint8)
    for i in range(N):
        for j in range(N):
            if i == j:
                continue
            if endpoints_ids[i] == endpoints_ids[j]:
                A[i, j] = 1
                A[j, i] = 1

            if matchB[i] == j and values[i, 0] < np.inf:
                added_edges[i] = j
                A[i, j] = 1
                A[j, i] = 1

    A = csr_array(A)
    n, ccs = connected_components(A, directed=False, return_labels=True)
    unique_clusters = np.unique(ccs)
    results = []
    for c in unique_clusters:
        idx = np.where(ccs == c)[0]
        branches_ids = np.unique(endpoints_ids[idx])

        unique_branches = np.logical_or.reduce(
            [branches == i for i in branches_ids], axis=0
        )

        commons_bboxes = branches_bboxes[branches_ids]
        # Compute the union of the bboxes
        min_x = np.min(commons_bboxes[:, 0])
        min_y = np.min(commons_bboxes[:, 1])
        max_x = np.max(commons_bboxes[:, 0] + commons_bboxes[:, 2])
        max_y = np.max(commons_bboxes[:, 1] + commons_bboxes[:, 3])

        new_fiber = fiber[min_y:max_y, min_x:max_x]
        new_fiber = unique_branches[min_y:max_y, min_x:max_x] * new_fiber
        for cidx in idx:
            if cidx not in added_edges:
                continue
            pointA = endpoints[cidx]
            pointB = endpoints[added_edges[cidx]]
            pointA = (
                pointA[1] - min_x,
                pointA[0] - min_y,
            )
            pointB = (
                pointB[1] - min_x,
                pointB[0] - min_y,
            )
            colA = endpoints_color[cidx]
            colB = endpoints_color[added_edges[cidx]]
            new_fiber = cv2.line(
                new_fiber,
                pointA,
                pointB,
                color=2 if colA != colB else int(colA),
                thickness=1,
            )
        # We express the bbox in the original image
        bbox = (
            coordinates[0] + min_x,
            coordinates[1] + min_y,
            max_x - min_x,
            max_y - min_y,
        )
        bbox = Bbox(x=bbox[0], y=bbox[1], width=bbox[2], height=bbox[3])
        result = Fiber(bbox=bbox, data=new_fiber)
        results.append(result)
    return results


def handle_ccs_with_junctions(
    ccs: List[np.ndarray],
    junctions: List[List[Tuple[int, int]]],
    coordinates: List[Tuple[int, int]],
):
    """
    Handle the connected components with junctions.
    The function takes a list of connected components, a list of lists of
    junctions, and a list of coordinates.
    The coordinates correspond to the top left corner of each connected component.
    """
    jncts_fibers = []
    for fiber, junction, coordinate in zip(ccs, junctions, coordinates):
        jncts_fibers += handle_multiple_fiber_in_cc(fiber, junction, coordinate)
# (The remaining lines of this 274-line file are cut off by the 50-file view limit.)
```
|
186 |
+
|
187 |
+
return jncts_fibers
|
188 |
+
|
189 |
+
|
190 |
+
def refine_segmentation(segmentation, fix_junctions=True, show=False):
|
191 |
+
skeleton = skeletonize(segmentation > 0, method="lee").astype(np.uint8)
|
192 |
+
skeleton_gt = skeleton * segmentation
|
193 |
+
retval, labels, stats, centroids = cv2.connectedComponentsWithStatsWithAlgorithm(
|
194 |
+
skeleton, connectivity=8, ccltype=cv2.CCL_BOLELLI, ltype=cv2.CV_16U
|
195 |
+
)
|
196 |
+
|
197 |
+
bboxes = stats[
|
198 |
+
:,
|
199 |
+
[
|
200 |
+
cv2.CC_STAT_LEFT,
|
201 |
+
cv2.CC_STAT_TOP,
|
202 |
+
cv2.CC_STAT_WIDTH,
|
203 |
+
cv2.CC_STAT_HEIGHT,
|
204 |
+
],
|
205 |
+
]
|
206 |
+
|
207 |
+
local_fibers = []
|
208 |
+
coordinates = []
|
209 |
+
junctions = []
|
210 |
+
for i in range(1, retval):
|
211 |
+
bbox = bboxes[i]
|
212 |
+
x1, y1, w, h = bbox
|
213 |
+
local_gt = skeleton_gt[y1 : y1 + h, x1 : x1 + w]
|
214 |
+
local_label = (labels[y1 : y1 + h, x1 : x1 + w] == i).astype(np.uint8)
|
215 |
+
local_fiber = local_gt * local_label
|
216 |
+
local_fibers.append(local_fiber)
|
217 |
+
coordinates.append(np.asarray([x1, y1, w, h]))
|
218 |
+
local_junctions = find_line_intersection(local_fiber > 0)
|
219 |
+
local_junctions = np.where(local_junctions)
|
220 |
+
local_junctions = np.array(local_junctions).transpose()
|
221 |
+
junctions.append(local_junctions)
|
222 |
+
if show:
|
223 |
+
for bbox, junction in zip(coordinates, junctions):
|
224 |
+
x, y, w, h = bbox
|
225 |
+
junction_to_global = np.array(junction) + np.array([y, x])
|
226 |
+
|
227 |
+
plt.scatter(
|
228 |
+
junction_to_global[:, 1],
|
229 |
+
junction_to_global[:, 0],
|
230 |
+
color="white",
|
231 |
+
s=30,
|
232 |
+
alpha=0.35,
|
233 |
+
)
|
234 |
+
|
235 |
+
plt.imshow(skeleton_gt, cmap=cmlabel, interpolation="nearest")
|
236 |
+
plt.axis("off")
|
237 |
+
plt.xticks([])
|
238 |
+
plt.yticks([])
|
239 |
+
plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
|
240 |
+
plt.show()
|
241 |
+
|
242 |
+
fibers = []
|
243 |
+
if fix_junctions:
|
244 |
+
has_junctions = [len(j) > 0 for j in junctions]
|
245 |
+
for fiber, coordinate in zip(
|
246 |
+
compress(local_fibers, np.logical_not(has_junctions)),
|
247 |
+
compress(coordinates, np.logical_not(has_junctions)),
|
248 |
+
):
|
249 |
+
bbox = Bbox(
|
250 |
+
x=coordinate[0],
|
251 |
+
y=coordinate[1],
|
252 |
+
width=coordinate[2],
|
253 |
+
height=coordinate[3],
|
254 |
+
)
|
255 |
+
fibers.append(Fiber(bbox=bbox, data=fiber))
|
256 |
+
|
257 |
+
fibers += handle_ccs_with_junctions(
|
258 |
+
compress(local_fibers, has_junctions),
|
259 |
+
compress(junctions, has_junctions),
|
260 |
+
compress(coordinates, has_junctions),
|
261 |
+
)
|
262 |
+
else:
|
263 |
+
for fiber, coordinate in zip(local_fibers, coordinates):
|
264 |
+
bbox = Bbox(
|
265 |
+
x=coordinate[0],
|
266 |
+
y=coordinate[1],
|
267 |
+
width=coordinate[2],
|
268 |
+
height=coordinate[3],
|
269 |
+
)
|
270 |
+
fibers.append(Fiber(bbox=bbox, data=fiber))
|
271 |
+
|
272 |
+
fiberprops = [FiberProps(fiber=f, fiber_id=i) for i, f in enumerate(fibers)]
|
273 |
+
|
274 |
+
return fiberprops
|
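For context, a minimal sketch of driving refine_segmentation end to end, assuming the package layout in this commit; the toy label map (1 = first analog, 2 = second analog) is illustrative, not from the repository:

import numpy as np
from dnafiber.postprocess.core import refine_segmentation

# Toy label map: one straight fiber, first half analog 1, second half analog 2.
seg = np.zeros((32, 64), dtype=np.uint8)
seg[16, 8:32] = 1
seg[16, 32:56] = 2

for props in refine_segmentation(seg, fix_junctions=True):
    # Expected: a single fiber classified as "double".
    print(props.fiber_id, props.fiber_type, props.counts)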
dnafiber/postprocess/fiber.py
ADDED
@@ -0,0 +1,129 @@
import attrs
import numpy as np
from typing import Tuple
from dnafiber.postprocess.skan import trace_skeleton


@attrs.define
class Bbox:
    x: int
    y: int
    width: int
    height: int

    @property
    def bbox(self) -> Tuple[int, int, int, int]:
        return (self.x, self.y, self.width, self.height)

    @bbox.setter
    def bbox(self, value: Tuple[int, int, int, int]):
        self.x, self.y, self.width, self.height = value


@attrs.define
class Fiber:
    bbox: Bbox
    data: np.ndarray


@attrs.define
class FiberProps:
    fiber: Fiber
    fiber_id: int
    red_pixels: int = None
    green_pixels: int = None
    category: str = None

    @property
    def bbox(self):
        return self.fiber.bbox.bbox

    @bbox.setter
    def bbox(self, value):
        self.fiber.bbox = value

    @property
    def data(self):
        return self.fiber.data

    @data.setter
    def data(self, value):
        self.fiber.data = value

    @property
    def red(self):
        if self.red_pixels is None:
            self.red_pixels, self.green_pixels = self.counts
        return self.red_pixels

    @property
    def green(self):
        if self.green_pixels is None:
            self.red_pixels, self.green_pixels = self.counts
        return self.green_pixels

    @property
    def length(self):
        return sum(self.counts)

    @property
    def counts(self):
        if self.red_pixels is None or self.green_pixels is None:
            self.red_pixels = np.sum(self.data == 1)
            self.green_pixels = np.sum(self.data == 2)
        return self.red_pixels, self.green_pixels

    @property
    def fiber_type(self):
        if self.category is not None:
            return self.category
        red_pixels, green_pixels = self.counts
        if red_pixels == 0 or green_pixels == 0:
            self.category = "single"
        else:
            self.category = estimate_fiber_category(self.data)
        return self.category

    @property
    def ratio(self):
        return self.green / self.red

    @property
    def is_valid(self):
        return (
            self.fiber_type == "double"
            or self.fiber_type == "one-two-one"
            or self.fiber_type == "two-one-two"
        )

    def scaled_coordinates(self, scale: float) -> Tuple[int, int]:
        """
        Scale down the coordinates of the fiber's bounding box.
        """
        x, y, width, height = self.bbox
        return (
            int(x * scale),
            int(y * scale),
            int(width * scale),
            int(height * scale),
        )


def estimate_fiber_category(fiber: np.ndarray) -> str:
    """
    Estimate the fiber category from the sequence of labels along the traced
    skeleton: the number of label changes gives the number of segments.
    """
    coordinates = trace_skeleton(fiber > 0)
    coordinates = np.asarray(coordinates)
    values = fiber[coordinates[:, 0], coordinates[:, 1]]
    diff = np.diff(values)
    jump = np.sum(diff != 0)
    n_ccs = jump + 1
    if n_ccs == 2:
        return "double"
    elif n_ccs == 3:
        if values[0] == 1:
            return "one-two-one"
        else:
            return "two-one-two"
    else:
        return "multiple"
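As a quick illustration of how FiberProps derives its measurements lazily from the label data, a small hand-built example (the values are arbitrary, chosen only to exercise the properties):

import numpy as np
from dnafiber.postprocess.fiber import Bbox, Fiber, FiberProps

# A 1-pixel-wide fiber: 5 pixels of label 1 (red) then 5 of label 2 (green).
data = np.array([[1, 1, 1, 1, 1, 2, 2, 2, 2, 2]], dtype=np.uint8)
props = FiberProps(
    fiber=Fiber(bbox=Bbox(x=0, y=0, width=10, height=1), data=data),
    fiber_id=0,
)

print(props.counts)      # 5 red and 5 green pixels
print(props.ratio)       # 1.0 (green / red)
print(props.fiber_type)  # "double": exactly one label change along the trace
print(props.is_valid)    # True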
dnafiber/postprocess/skan.py
ADDED
@@ -0,0 +1,211 @@
# Functions to generate kernels of curve intersection
import numpy as np
import cv2
import itertools
from numba import njit, int64
from numba.typed import List
from numba.types import Tuple

# Define the element type: a tuple of two int64
tuple_type = Tuple((int64, int64))


def find_neighbours(fibers_map, point):
    """
    Find the 8-connected neighbours of the given point that belong to the fiber.
    Returns an empty list if the point has no neighbour on the fiber.
    """
    neighbors = []
    h, w = fibers_map.shape
    for i in range(-1, 2):
        for j in range(-1, 2):
            # Skip the center point
            if i == 0 and j == 0:
                continue
            # Get the next point
            nextpoint = (point[0] + i, point[1] + j)
            # Check if the next point is in the image
            if (
                nextpoint[0] < 0
                or nextpoint[0] >= h
                or nextpoint[1] < 0
                or nextpoint[1] >= w
            ):
                continue

            # Check if the next point is in the fiber
            if fibers_map[nextpoint]:
                neighbors.append(nextpoint)
    return neighbors


def compute_points_angle(fibers_map, points, steps=25):
    """
    For each endpoint, follow the fiber for a given number of steps and estimate
    the tangent line by fitting a line to the visited points. The angle of the
    line is returned, in radians.
    """
    points_angle = np.zeros((len(points),), dtype=np.float32)
    for i, point in enumerate(points):
        # Navigate along the fiber starting from the point for `steps` pixels,
        # then fit a line through the visited points.
        visited = trace_from_point(
            fibers_map > 0, (point[0], point[1]), max_length=steps
        )
        visited = np.array(visited)
        vx, vy, x, y = cv2.fitLine(visited[:, ::-1], cv2.DIST_L2, 0, 0.01, 0.01)
        # Compute the angle of the line
        points_angle[i] = np.arctan(vy / vx)

    return points_angle


def generate_nonadjacent_combination(input_list, take_n):
    """
    Generates combinations of m elements taken n at a time where no two taken
    elements are adjacent.
    INPUT:
        input_list = (iterable) list of elements to draw the combinations from
        take_n = (integer) number of elements taken at a time in each combination
    OUTPUT:
        all_comb = (list of np.array) all the valid combinations
    """
    all_comb = []
    for comb in itertools.combinations(input_list, take_n):
        comb = np.array(comb)
        d = np.diff(comb)
        if len(d[d == 1]) == 0 and comb[-1] - comb[0] != 7:
            all_comb.append(comb)
    return all_comb


def populate_intersection_kernel(combinations):
    """
    Maps the numbers from 0-7 onto the 8 pixels surrounding the center pixel of
    a 3 x 3 matrix, clockwise (up_pixel = 0, right_pixel = 2, etc.), and
    generates a kernel that represents a line intersection, where the center
    pixel is occupied and 3 or 4 pixels of the border are occupied too.
    INPUT:
        combinations = (np.array) matrix where every row is a vector of combinations
    OUTPUT:
        kernels = (list) list of 3 x 3 kernels/masks; each element is a mask.
    """
    template = np.array(([-1, -1, -1], [-1, 1, -1], [-1, -1, -1]), dtype="int")
    match = [(0, 1), (0, 2), (1, 2), (2, 2), (2, 1), (2, 0), (1, 0), (0, 0)]
    kernels = []
    for comb in combinations:
        tmp = np.copy(template)
        for m in comb:
            tmp[match[m][0], match[m][1]] = 1
        kernels.append(tmp)
    return kernels


def give_intersection_kernels():
    """
    Generates all the 3 x 3 intersection kernels.
    INPUT:
        None
    OUTPUT:
        kernels = (list) list of 3 x 3 kernels/masks; each element is a mask.
    """
    input_list = np.arange(8)
    taken_n = [4, 3]
    kernels = []
    for taken in taken_n:
        comb = generate_nonadjacent_combination(input_list, taken)
        tmp_ker = populate_intersection_kernel(comb)
        kernels.extend(tmp_ker)
    return kernels


def find_line_intersection(input_image, show=0):
    """
    Applies morphologyEx with the HITMISS operation to look for all the curve
    intersection kernels generated with the give_intersection_kernels() function.
    INPUT:
        input_image = (np.array dtype=np.uint8) binarized m x n image matrix
    OUTPUT:
        output_image = (np.array) image where the nonzero pixels
                       are the line intersections.
    """
    input_image = input_image.astype(np.uint8)
    kernel = np.array(give_intersection_kernels())
    output_image = np.zeros(input_image.shape)
    for i in np.arange(len(kernel)):
        out = cv2.morphologyEx(
            input_image,
            cv2.MORPH_HITMISS,
            kernel[i, :, :],
            borderValue=0,
            borderType=cv2.BORDER_CONSTANT,
        )
        output_image = output_image + out

    return output_image


@njit
def get_neighbors_8(y, x, shape):
    neighbors = List.empty_list(tuple_type)
    for dy in range(-1, 2):
        for dx in range(-1, 2):
            if dy == 0 and dx == 0:
                continue
            ny, nx = y + dy, x + dx
            if 0 <= ny < shape[0] and 0 <= nx < shape[1]:
                neighbors.append((ny, nx))
    return neighbors


@njit
def find_endpoints(skel):
    # An endpoint is a skeleton pixel with exactly one skeleton neighbour.
    endpoints = List.empty_list(tuple_type)
    for y in range(skel.shape[0]):
        for x in range(skel.shape[1]):
            if skel[y, x] == 1:
                count = 0
                neighbors = get_neighbors_8(y, x, skel.shape)
                for ny, nx in neighbors:
                    if skel[ny, nx] == 1:
                        count += 1
                if count == 1:
                    endpoints.append((y, x))
    return endpoints


@njit
def trace_skeleton(skel):
    endpoints = find_endpoints(skel)
    if len(endpoints) < 1:
        return List.empty_list(tuple_type)  # Return empty list with proper type

    return trace_from_point(skel, endpoints[0], max_length=skel.sum())


@njit
def trace_from_point(skel, point, max_length=25):
    visited = np.zeros_like(skel, dtype=np.uint8)
    path = List.empty_list(tuple_type)

    # Check if the starting point is on the skeleton
    y, x = point
    if y < 0 or y >= skel.shape[0] or x < 0 or x >= skel.shape[1] or skel[y, x] != 1:
        return path

    stack = List.empty_list(tuple_type)
    stack.append(point)

    # Depth-first traversal, capped at max_length visited pixels.
    while len(stack) > 0 and len(path) < max_length:
        y, x = stack.pop()
        if visited[y, x]:
            continue
        visited[y, x] = 1
        path.append((y, x))
        neighbors = get_neighbors_8(y, x, skel.shape)
        for ny, nx in neighbors:
            if skel[ny, nx] == 1 and not visited[ny, nx]:
                stack.append((ny, nx))
    return path
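A short sketch of these primitives on a toy skeleton (a plus-shaped cross), assuming the module as shown; numba compiles the decorated functions on first call:

import numpy as np
from dnafiber.postprocess.skan import (
    find_endpoints,
    find_line_intersection,
    trace_from_point,
)

# Two 1-pixel lines crossing at (3, 3).
skel = np.zeros((7, 7), dtype=np.uint8)
skel[3, :] = 1
skel[:, 3] = 1

print(list(find_endpoints(skel)))   # the four arm tips
crossings = find_line_intersection(skel > 0)
print(np.argwhere(crossings))       # [[3, 3]], the junction pixel
# First 4 pixels walked from the left tip, in DFS order
# (the walk may step diagonally once it reaches the junction).
print(list(trace_from_point(skel, (3, 0), max_length=4)))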
dnafiber/start.py
ADDED
@@ -0,0 +1,22 @@
import subprocess
import os


def main():
    # Start the Streamlit application
    print("Starting Streamlit application...")
    local_dir = os.path.dirname(os.path.abspath(__file__))
    subprocess.run(
        [
            "streamlit",
            "run",
            os.path.join(local_dir, "ui", "Welcome.py"),
            "--server.maxUploadSize",
            "1024",
        ],
    )


if __name__ == "__main__":
    main()
    # subprocess.run blocks, so this only prints once the app has exited.
    print("Streamlit application stopped.")
dnafiber/trainee.py
ADDED
@@ -0,0 +1,148 @@
from lightning import LightningModule
import segmentation_models_pytorch as smp
from monai.losses.dice import GeneralizedDiceLoss
from monai.losses.cldice import SoftDiceclDiceLoss
from torchmetrics.classification import Dice, JaccardIndex
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchmetrics import MetricCollection
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
import torch
import torchvision
from dnafiber.metric import DNAFIBERMetric


class Trainee(LightningModule, PyTorchModelHubMixin):
    def __init__(
        self, learning_rate=0.001, weight_decay=0.0002, num_classes=3, **model_config
    ):
        super().__init__()
        self.model_config = model_config
        if (
            self.model_config.get("arch", None) is None
            or self.model_config["arch"] == "maskrcnn"
        ):
            self.model = None
        else:
            self.model = smp.create_model(classes=3, **self.model_config, dropout=0.2)
        self.loss = GeneralizedDiceLoss(to_onehot_y=False, softmax=False)
        self.metric = MetricCollection(
            {
                "dice": Dice(num_classes=num_classes, ignore_index=0),
                "jaccard": JaccardIndex(
                    num_classes=num_classes,
                    task="multiclass" if num_classes > 2 else "binary",
                    ignore_index=0,
                ),
                "detection": DNAFIBERMetric(),
            }
        )
        self.weight_decay = weight_decay
        self.learning_rate = learning_rate
        self.save_hyperparameters()

    def forward(self, x):
        yhat = self.model(x)
        return yhat

    def training_step(self, batch, batch_idx):
        x, y = batch["image"], batch["mask"]
        y = y.clamp(0, 2)
        y_hat = self(x)
        loss = self.get_loss(y_hat, y)

        self.log("train_loss", loss)

        return loss

    def get_loss(self, y_hat, y):
        y_hat = F.softmax(y_hat, dim=1)
        y = F.one_hot(y.long(), num_classes=3)
        y = y.permute(0, 3, 1, 2).float()
        loss = self.loss(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch["image"], batch["mask"]
        y = y.clamp(0, 2)
        y_hat = self(x)
        loss = self.get_loss(y_hat, y)
        self.log("val_loss", loss, on_step=False, on_epoch=True, sync_dist=True)
        self.metric.update(y_hat, y)
        return y_hat

    def on_validation_epoch_end(self):
        scores = self.metric.compute()
        self.log_dict(scores, sync_dist=True)
        self.metric.reset()

    def configure_optimizers(self):
        optimizer = AdamW(
            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )
        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=self.trainer.max_epochs,  # type: ignore
            eta_min=self.learning_rate / 25,
        )
        scheduler = {
            "scheduler": scheduler,
            "interval": "epoch",
        }
        return [optimizer], [scheduler]


class TraineeMaskRCNN(Trainee):
    def __init__(self, learning_rate=0.001, weight_decay=0.0002, **model_config):
        super().__init__(learning_rate, weight_decay, **model_config)
        self.model = torchvision.models.get_model("maskrcnn_resnet50_fpn_v2")

    def forward(self, x):
        yhat = self.model(x)
        return yhat

    def training_step(self, batch, batch_idx):
        image = batch["image"]
        targets = batch["targets"]
        loss_dict = self.model(image, targets)
        losses = sum(loss for loss in loss_dict.values())
        self.log("train_loss", losses, on_step=True, on_epoch=False, sync_dist=True)
        return losses

    def validation_step(self, batch, batch_idx):
        image = batch["image"]
        targets = batch["targets"]

        predictions = self.model(image)
        b = len(predictions)
        predicted_masks = []
        gt_masks = []
        for i in range(b):
            scores = predictions[i]["scores"]
            masks = predictions[i]["masks"]
            good_masks = masks[scores > 0.5]
            # Combine into a single mask
            good_masks = torch.sum(good_masks, dim=0)
            predicted_masks.append(good_masks)
            gt_masks.append(targets[i]["masks"].sum(dim=0))

        gt_masks = torch.stack(gt_masks).squeeze(1) > 0
        predicted_masks = torch.stack(predicted_masks).squeeze(1) > 0
        self.metric.update(predicted_masks, gt_masks)
        return predictions

    def configure_optimizers(self):
        optimizer = AdamW(
            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )
        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=self.trainer.max_epochs,  # type: ignore
            eta_min=self.learning_rate / 25,
        )
        scheduler = {
            "scheduler": scheduler,
            "interval": "epoch",
        }
        return [optimizer], [scheduler]
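A minimal sketch of constructing the segmentation trainee outside a Lightning Trainer. The arch and encoder_name values are illustrative stand-ins (any segmentation_models_pytorch pair works), not the checkpoint's actual configuration, and this assumes the pinned smp version accepts the dropout keyword the constructor forwards:

import torch
from dnafiber.trainee import Trainee

# model_config is forwarded to smp.create_model; encoder_weights=None avoids a download.
model = Trainee(
    learning_rate=1e-3, arch="unet", encoder_name="resnet34", encoder_weights=None
)
model.eval()
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    logits = model(x)  # (1, 3, 256, 256): background / first analog / second analog
pred = logits.argmax(dim=1)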
dnafiber/ui/Welcome.py
ADDED
@@ -0,0 +1,47 @@
import streamlit as st
import torch


def main():
    st.set_page_config(
        page_title="Hello",
        page_icon="🧬",
        layout="wide",
    )
    st.write("# Welcome to DN-AI! 👋")

    st.write(
        "This is a web application for the DN-AI project, which aims to provide an easy-to-use interface for analyzing and processing fiber images."
    )
    st.write("## Features")
    st.write(
        "- **Image loading**: The application accepts CZI, JPEG and PNG files. \n"
        "- **Image segmentation**: The application provides a set of tools to segment the DNA fibers and measure the ratio between analogs. \n"
    )
    st.write("## Technical details")
    cols = st.columns(2)
    with cols[0]:
        st.write("### Source")
        st.write("The source code for this application is available on GitHub.")
        # Bare string literals are rendered by Streamlit's "magic" feature.
        """
        [](https://github.com/ClementPla/DeepFiberQ/tree/relabelled)

        """
        st.markdown("<br>", unsafe_allow_html=True)

    with cols[1]:
        st.write("### Device")
        st.write("If available, the application will try to use a GPU for processing.")
        device = "GPU" if torch.cuda.is_available() else "CPU"
        cols = st.columns(3)
        with cols[0]:
            st.write("Running on:")
        with cols[1]:
            st.button(device, icon="⚙️", disabled=True)
        if not torch.cuda.is_available():
            with cols[2]:
                st.warning("The application will run on CPU, which may be slower.")


if __name__ == "__main__":
    main()
dnafiber/ui/__init__.py
ADDED
File without changes
dnafiber/ui/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (165 Bytes)
dnafiber/ui/__pycache__/inference.cpython-312.pyc
ADDED
Binary file (2.32 kB)
dnafiber/ui/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (7.33 kB)
dnafiber/ui/inference.py
ADDED
@@ -0,0 +1,69 @@
import streamlit as st
from dnafiber.inference import infer
from dnafiber.postprocess.core import refine_segmentation
import numpy as np
from dnafiber.deployment import _get_model
import torch


@st.cache_data
def ui_inference(_model, _image, _device, postprocess=True, id=None):
    # Underscore-prefixed arguments are skipped by st.cache_data's hasher,
    # so `id` acts as the cache key for a given image.
    return ui_inference_cacheless(
        _model, _image, _device, postprocess=postprocess, id=id
    )


@st.cache_resource
def get_model(model_name):
    model = _get_model(
        device="cuda" if torch.cuda.is_available() else "cpu",
        revision=model_name,
    )
    return model


def ui_inference_cacheless(_model, _image, _device, postprocess=True, id=None):
    """
    A cacheless version of the ui_inference function, intended for
    scenarios where caching is not desired.
    """
    h, w = _image.shape[:2]
    with st.spinner("Sliding window segmentation in progress..."):
        if isinstance(_model, list):
            # Ensemble: average the probability maps of all models, then argmax.
            output = None
            for model in _model:
                if isinstance(model, str):
                    model = get_model(model)
                with st.spinner(text="Segmenting with model: {}".format(model)):
                    if output is None:
                        output = infer(
                            model,
                            image=_image,
                            device=_device,
                            scale=st.session_state.get("pixel_size", 0.13),
                            only_probabilities=True,
                        ).cpu()
                    else:
                        output = (
                            output
                            + infer(
                                model,
                                image=_image,
                                device=_device,
                                scale=st.session_state.get("pixel_size", 0.13),
                                only_probabilities=True,
                            ).cpu()
                        )
            output = (output / len(_model)).argmax(1).squeeze().numpy()
        else:
            output = infer(
                _model,
                image=_image,
                device=_device,
                scale=st.session_state.get("pixel_size", 0.13),
            )
    output = output.astype(np.uint8)
    if postprocess:
        with st.spinner("Post-processing segmentation..."):
            output = refine_segmentation(output, fix_junctions=postprocess)
    return output
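The ensembling rule above reduces to averaging per-model probability maps before the argmax; a self-contained sketch, with random tensors standing in for the infer(..., only_probabilities=True) outputs:

import torch

prob_a = torch.softmax(torch.randn(1, 3, 64, 64), dim=1)  # stand-in for model A
prob_b = torch.softmax(torch.randn(1, 3, 64, 64), dim=1)  # stand-in for model B
mean = (prob_a + prob_b) / 2                    # average the probability maps
label_map = mean.argmax(1).squeeze().numpy()    # class map with values 0, 1, 2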
dnafiber/ui/pages/1_Load.py
ADDED
@@ -0,0 +1,196 @@
import streamlit as st


st.set_page_config(
    page_title="DN-AI",
    page_icon="🔬",
    layout="wide",
)


def build_multichannel_loader():
    if (
        st.session_state.get("files_uploaded", None) is None
        or len(st.session_state.files_uploaded) == 0
    ):
        st.session_state["files_uploaded"] = st.file_uploader(
            label="Upload files",
            accept_multiple_files=True,
            type=["czi", "jpeg", "jpg", "png", "tiff", "tif"],
        )
    else:
        st.session_state["files_uploaded"] += st.file_uploader(
            label="Upload files",
            accept_multiple_files=True,
            type=["czi", "jpeg", "jpg", "png", "tiff", "tif"],
        )
    st.write("### Channel interpretation")
    st.markdown(
        "The goal is to obtain an RGB image in the order of <span style='color: red;'>First analog</span>, "
        "<span style='color: green;'>Second analog</span>, <span style='color: blue;'>Empty</span>.",
        unsafe_allow_html=True,
    )
    st.markdown(
        "By default, we assume that the first channel in a CZI/TIFF file is <span style='color: green;'>the second analog</span> "
        "(which happens to be the case with Zeiss microscopes), which means that we swap the order of the two channels for processing.",
        unsafe_allow_html=True,
    )
    st.write("If this is not the intended behavior, please tick the box below:")
    st.session_state["reverse_channels"] = st.checkbox(
        "Reverse the channels interpretation",
        value=False,
    )
    st.warning(
        "Please note that we only swap the channels for raw (CZI, TIFF) files. JPEG and PNG files "
        "are assumed to be already in the correct order (first analog in red and second analog in green)."
    )

    st.info(
        "The channel order in CZI files does not necessarily match the order in which they are displayed in ImageJ or equivalent. "
        "Indeed, such viewers will usually look at the metadata of the file to determine the order of the channels, which we don't. "
        "When in doubt, we recommend visualizing the image in ImageJ and comparing with our viewer. "
        "If the channels appear reversed, tick the option above."
    )


def build_individual_loader():
    cols = st.columns(2)
    with cols[1]:
        st.markdown(
            f"<h3 style='color: {st.session_state['color2']};'>Second analog</h3>",
            unsafe_allow_html=True,
        )
        if (
            st.session_state.get("analog_2_files", None) is None
            or len(st.session_state.analog_2_files) == 0
        ):
            st.session_state["analog_2_files"] = st.file_uploader(
                label="Upload second analog file(s)",
                accept_multiple_files=True,
                type=["czi", "jpeg", "jpg", "png", "tiff", "tif"],
            )
        else:
            st.session_state["analog_2_files"] += st.file_uploader(
                label="Upload second analog file(s)",
                accept_multiple_files=True,
                type=["czi", "jpeg", "jpg", "png", "tiff", "tif"],
            )

    with cols[0]:
        st.markdown(
            f"<h3 style='color: {st.session_state['color1']};'>First analog</h3>",
            unsafe_allow_html=True,
        )
        if (
            st.session_state.get("analog_1_files", None) is None
            or len(st.session_state.analog_1_files) == 0
        ):
            st.session_state["analog_1_files"] = st.file_uploader(
                label="Upload first analog file(s)",
                accept_multiple_files=True,
                type=["czi", "jpeg", "jpg", "png", "tiff", "tif"],
            )
        else:
            st.session_state["analog_1_files"] += st.file_uploader(
                label="Upload first analog file(s)",
                accept_multiple_files=True,
                type=["czi", "jpeg", "jpg", "png", "tiff", "tif"],
            )

    analog_1_files = st.session_state.get("analog_1_files", None)
    analog_2_files = st.session_state.get("analog_2_files", None)

    # Remove duplicates from the list of files: keep only the first occurrence of each name.
    def remove_duplicates(files):
        seen_ids = set()
        unique_files = []
        for file in files:
            if file and file.name not in seen_ids:
                unique_files.append(file)
                seen_ids.add(file.name)
        return unique_files

    analog_1_files = remove_duplicates(analog_1_files or [])
    analog_2_files = remove_duplicates(analog_2_files or [])

    if analog_1_files is None and analog_2_files is None:
        return
    else:
        if (
            len(analog_1_files) > 0
            and len(analog_2_files) > 0
            and len(analog_1_files) != len(analog_2_files)
        ):
            st.error("Please upload the same number of analog files.")
            return

    # Pair the two lists by sorted filename, padding the shorter list with None.
    analog_1_files = sorted(analog_1_files, key=lambda x: x.name)
    analog_2_files = sorted(analog_2_files, key=lambda x: x.name)
    max_size = max(len(analog_1_files), len(analog_2_files))
    if len(analog_1_files) < max_size:
        analog_1_files += [None] * (max_size - len(analog_1_files))
    if len(analog_2_files) < max_size:
        analog_2_files += [None] * (max_size - len(analog_2_files))

    combined_files = list(zip(analog_1_files, analog_2_files))

    if (
        st.session_state.get("files_uploaded", None) is None
        or len(st.session_state.files_uploaded) == 0
    ):
        st.session_state["files_uploaded"] = combined_files
    else:
        st.session_state["files_uploaded"] += combined_files

    # If any of the files (analog_1_files or analog_2_files) was included previously
    # in files_uploaded, remove the previous occurrence from the list.
    current_ids = set()
    for f in analog_1_files + analog_2_files:
        if f:
            current_ids.add(f.name)

    # Safely filter the list to exclude any files with matching names.
    def is_not_duplicate(file):
        if isinstance(file, tuple):
            f1, f2 = file
            if f1 and f2:
                return True

            return (f1 is None or f1.name not in current_ids) and (
                f2 is None or f2.name not in current_ids
            )
        else:
            return True

    st.session_state.files_uploaded = [
        f for f in st.session_state.files_uploaded if is_not_duplicate(f)
    ]


cols = st.columns(2)
with cols[1]:
    st.write("### Pixel size")
    st.session_state["pixel_size"] = st.number_input(
        "Please indicate the pixel size of the image in µm (default: 0.13 µm).",
        value=st.session_state.get("pixel_size", 0.13),
    )
    # Briefly state the technical details.
    st.write(
        "The pixel size is used to convert the pixel coordinates to µm. "
        "The model is trained on images with a pixel size of 0.26 µm, and the application automatically "
        "resizes the images to match this pixel size using your provided choice."
    )

    st.write("### Labels color")
    color_choices = st.columns(2)
    with color_choices[0]:
        st.session_state["color1"] = st.color_picker(
            "Select the color for the first analog",
            value=st.session_state.get("color1", "#FF0000"),
            help="This color will be used to display the first analog segments.",
        )
    with color_choices[1]:
        st.session_state["color2"] = st.color_picker(
            "Select the color for the second analog",
            value=st.session_state.get("color2", "#00FF00"),
            help="This color will be used to display the second analog segments.",
        )

with cols[0]:
    choice = st.segmented_control(
        "Please select the type of images you want to upload:",
        options=["Multichannel", "Individual channel"],
        default="Multichannel",
    )
    if choice == "Individual channel":
        build_individual_loader()
    else:
        build_multichannel_loader()
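The pairing logic in build_individual_loader boils down to sort-by-name, pad, and zip; a sketch with plain strings standing in for Streamlit's UploadedFile objects:

# Hypothetical filenames, chosen only to show the pairing behaviour.
analog_1 = sorted(["b_analog1.png", "a_analog1.png"])
analog_2 = sorted(["a_analog2.png"])
size = max(len(analog_1), len(analog_2))
analog_1 += [None] * (size - len(analog_1))  # pad the shorter list with None
analog_2 += [None] * (size - len(analog_2))
pairs = list(zip(analog_1, analog_2))
# [('a_analog1.png', 'a_analog2.png'), ('b_analog1.png', None)]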