ClementP committed
Commit 69591a9 · verified · 1 Parent(s): 06bdd49

Upload 55 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. README.md +59 -20
  2. dnafiber/__init__.py +1 -0
  3. dnafiber/__pycache__/__init__.cpython-312.pyc +0 -0
  4. dnafiber/__pycache__/deployment.cpython-312.pyc +0 -0
  5. dnafiber/__pycache__/inference.cpython-312.pyc +0 -0
  6. dnafiber/__pycache__/metric.cpython-312.pyc +0 -0
  7. dnafiber/__pycache__/post_process.cpython-312.pyc +0 -0
  8. dnafiber/__pycache__/trainee.cpython-312.pyc +0 -0
  9. dnafiber/analysis/__init__.py +0 -0
  10. dnafiber/analysis/chart.py +61 -0
  11. dnafiber/analysis/const.py +3 -0
  12. dnafiber/analysis/utils.py +21 -0
  13. dnafiber/callbacks.py +50 -0
  14. dnafiber/data/__init__.py +0 -0
  15. dnafiber/data/__pycache__/__init__.cpython-312.pyc +0 -0
  16. dnafiber/data/__pycache__/utils.cpython-312.pyc +0 -0
  17. dnafiber/data/dataset.py +271 -0
  18. dnafiber/data/intergrader/__init__.py +1 -0
  19. dnafiber/data/intergrader/__pycache__/__init__.cpython-312.pyc +0 -0
  20. dnafiber/data/intergrader/__pycache__/analysis.cpython-312.pyc +0 -0
  21. dnafiber/data/intergrader/__pycache__/const.cpython-312.pyc +0 -0
  22. dnafiber/data/intergrader/__pycache__/io.cpython-312.pyc +0 -0
  23. dnafiber/data/intergrader/__pycache__/plot.cpython-312.pyc +0 -0
  24. dnafiber/data/intergrader/analysis.py +120 -0
  25. dnafiber/data/intergrader/auto.py +3 -0
  26. dnafiber/data/intergrader/const.py +21 -0
  27. dnafiber/data/intergrader/io.py +27 -0
  28. dnafiber/data/intergrader/plot.py +172 -0
  29. dnafiber/data/utils.py +80 -0
  30. dnafiber/deployment.py +44 -0
  31. dnafiber/inference.py +105 -0
  32. dnafiber/metric.py +150 -0
  33. dnafiber/model/maskrcnn.py +0 -0
  34. dnafiber/postprocess/__init__.py +1 -0
  35. dnafiber/postprocess/__pycache__/__init__.cpython-312.pyc +0 -0
  36. dnafiber/postprocess/__pycache__/core.cpython-312.pyc +0 -0
  37. dnafiber/postprocess/__pycache__/fiber.cpython-312.pyc +0 -0
  38. dnafiber/postprocess/__pycache__/skan.cpython-312.pyc +0 -0
  39. dnafiber/postprocess/core.py +274 -0
  40. dnafiber/postprocess/fiber.py +129 -0
  41. dnafiber/postprocess/skan.py +211 -0
  42. dnafiber/start.py +22 -0
  43. dnafiber/trainee.py +148 -0
  44. dnafiber/ui/Welcome.py +47 -0
  45. dnafiber/ui/__init__.py +0 -0
  46. dnafiber/ui/__pycache__/__init__.cpython-312.pyc +0 -0
  47. dnafiber/ui/__pycache__/inference.cpython-312.pyc +0 -0
  48. dnafiber/ui/__pycache__/utils.cpython-312.pyc +0 -0
  49. dnafiber/ui/inference.py +69 -0
  50. dnafiber/ui/pages/1_Load.py +196 -0
README.md CHANGED
@@ -1,20 +1,59 @@
- ---
- title: DNAI
- emoji: 🚀
- colorFrom: red
- colorTo: red
- sdk: docker
- app_port: 8501
- tags:
- - streamlit
- pinned: false
- short_description: DNA Fiber semantic segmentation for replication assessment
- license: mit
- ---
-
- # Welcome to Streamlit!
-
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
-
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).
+ # DN-AI
+
+ This is the official repository for DN-AI, an automated tool for the measurement of differentiated DNA replication in fluorescence microscopy images.
+
+ DN-AI offers several ways for biologists to measure DNA replication in fluorescence microscopy images, none of which require programming skills. See the [Installation](#installation) section for instructions on how to install DN-AI.
+
+ ## Features
+
+ - **Automated DNA replication measurement**: DN-AI automatically quantifies DNA replication in fluorescence microscopy images, using a deep learning model to segment the labeled fibers.
+ - **User-friendly interface**: DN-AI provides a user-friendly web interface for uploading images and viewing the results. Both JPEG and TIFF images are supported.
+ - **Batch processing**: DN-AI can process multiple images at once, making it easy to analyze large datasets. It also supports comparing ratios between different batches of images.
+
+ ## Installation
+
+ DN-AI relies on Python. We recommend installing its latest version (3.10 or higher) and using a virtual environment to avoid conflicts with other packages.
+
+ ### Prerequisites
+ Before installing DN-AI, make sure you have the following prerequisites installed:
+ - [Python 3.10 or higher](https://www.python.org/downloads/)
+ - [pip](https://pip.pypa.io/en/stable/installation/) (Python package installer)
+
+ ### Python Package
+ To install DN-AI as a Python package, use pip:
+
+ ```bash
+ pip install git+https://github.com/ClementPla/DeepFiberQ.git
+ ```
+
+ ### Graphical User Interface (GUI)
+
+ To run the DN-AI graphical user interface, use the following command:
+
+ ```bash
+ DNAI
+ ```
+
+ Make sure you run this command in the terminal (and environment) where you installed DN-AI. It starts a local web server; open your web browser and go to `http://localhost:8501` to access the DN-AI interface.
+
+ Screenshots of the GUI:
+
+ ![DN-AI GUI](imgs/screenshot.png)
+
+ ### Docker
+ A Docker image is available for DN-AI. You can pull it from Docker Hub:
+
+ ```bash
+ docker pull clementpla/dnafiber
+ ```
+
+ ### Google Colab
+ We also provide a Google Colab notebook for DN-AI. You can access it [here](https://colab.research.google.com/github/ClementPla/DeepFiberQ/blob/main/Colab/DNA_Fiber_Q.ipynb).
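
Beyond the GUI, the installed package can be scripted directly. A minimal sketch (not part of the README), assuming the `infer` entry point defined in `dnafiber/inference.py` below; `image.jpg` is a placeholder file:

```python
import cv2

from dnafiber.inference import infer

# "image.jpg" is a placeholder; OpenCV reads BGR, the model expects RGB
image = cv2.imread("image.jpg")[:, :, ::-1]

# Revision names come from MODELS_ZOO in dnafiber/deployment.py
mask = infer("segformer_mit_b2", image, device="cpu", scale=0.13)
# mask: 0 = background, 1 = first analog (red), 2 = second analog (green)
```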
dnafiber/__init__.py ADDED
@@ -0,0 +1 @@
+ from dnafiber.deployment import _get_model
dnafiber/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (219 Bytes).
dnafiber/__pycache__/deployment.cpython-312.pyc ADDED
Binary file (2.46 kB).
dnafiber/__pycache__/inference.cpython-312.pyc ADDED
Binary file (4.95 kB).
dnafiber/__pycache__/metric.cpython-312.pyc ADDED
Binary file (6.79 kB).
dnafiber/__pycache__/post_process.cpython-312.pyc ADDED
Binary file (5.36 kB).
dnafiber/__pycache__/trainee.cpython-312.pyc ADDED
Binary file (7.73 kB).
dnafiber/analysis/__init__.py ADDED
File without changes
dnafiber/analysis/chart.py ADDED
@@ -0,0 +1,61 @@
+ import pandas as pd
+ import plotly.express as px
+
+ from dnafiber.analysis.const import palette
+
+
+ def get_color_association(df):
+     """
+     Get the color association for each image in the dataframe.
+     """
+     unique_name = df["image_name"].unique()
+     color_association = {i: p for (i, p) in zip(unique_name, palette)}
+     return color_association
+
+
+ def plot_ratio(df, color_association=None, only_bilateral=True):
+     df = df[["ratio", "image_name", "fiber_type"]].copy()
+
+     df["Image"] = df["image_name"]
+     df["Fiber Type"] = df["fiber_type"]
+     df["Ratio"] = df["ratio"]
+     if only_bilateral:
+         df = df[df["Fiber Type"] == "double"]
+
+     df = df.sort_values(
+         by=["Image", "Fiber Type"],
+         ascending=[True, True],
+     )
+
+     # Order the dataframe by the median ratio of each image
+     image_order = (
+         df.groupby("Image")["Ratio"].median().sort_values(ascending=True).index
+     )
+     df["Image"] = pd.Categorical(df["Image"], categories=image_order, ordered=True)
+     df.sort_values("Image", inplace=True)
+     # Compute unique_name before the branch so a user-provided
+     # color_association can also be resolved below
+     unique_name = df["image_name"].unique()
+     if color_association is None:
+         color_association = get_color_association(df)
+
+     this_palette = [color_association[i] for i in unique_name]
+     fig = px.violin(
+         df,
+         y="Ratio",
+         x="Image",
+         color="Image",
+         color_discrete_sequence=this_palette,
+         box=True,  # draw box plot inside the violin
+         points="all",  # can be 'outliers', or False
+     )
+
+     # Make the figure taller
+     fig.update_layout(
+         height=500,
+         width=1000,
+         title="Ratio of green to red",
+         yaxis_title="Ratio",
+         xaxis_title="Image",
+         legend_title="Image",
+     )
+     return fig
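
A hedged usage sketch for `plot_ratio` with a toy dataframe (the image names are borrowed from the blind mapping in `dnafiber/data/intergrader/const.py`):

```python
import pandas as pd

from dnafiber.analysis.chart import plot_ratio

df = pd.DataFrame(
    {
        "image_name": ["siNT-01"] * 3 + ["siBRCA2-01"] * 3,
        "fiber_type": ["double"] * 6,
        "ratio": [0.9, 1.1, 1.0, 1.6, 1.4, 1.8],
    }
)
fig = plot_ratio(df)  # one violin per image, ordered by median ratio
fig.show()
```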
dnafiber/analysis/const.py ADDED
@@ -0,0 +1,3 @@
+ from catppuccin.palette import PALETTE
+
+ palette = [c.hex for c in PALETTE.latte.colors]
dnafiber/analysis/utils.py ADDED
@@ -0,0 +1,21 @@
+ import numpy as np
+ from tqdm.auto import tqdm
+
+ from dnafiber.data.utils import read_colormask
+
+
+ def build_consensus_map(intergraders, root_img, list_img):
+     all_masks = []
+     for img_path in tqdm(list_img):
+         path_from_root = img_path.relative_to(root_img)
+         masks = []
+         for intergrader in intergraders:
+             intergrader_path = (intergrader / path_from_root).with_suffix(".png")
+             if not intergrader_path.exists():
+                 print(f"Missing {intergrader_path}")
+                 continue
+             mask = read_colormask(intergrader_path)
+             masks.append(mask)
+         masks = np.array(masks)
+
+         all_masks.append(masks)
+     return np.array(all_masks)
dnafiber/callbacks.py ADDED
@@ -0,0 +1,50 @@
+ import wandb
+ from lightning.pytorch.callbacks import Callback
+ from lightning.pytorch.utilities import rank_zero_only
+
+
+ class LogPredictionSamplesCallback(Callback):
+     def __init__(self, wandb_logger, n_images=8):
+         self.n_images = n_images
+         self.wandb_logger = wandb_logger
+         super().__init__()
+
+     @rank_zero_only
+     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+         if batch_idx < 1 and trainer.is_global_zero:
+             n = self.n_images
+             x = batch["image"][:n].float()
+             y = batch["mask"][:n]
+             pred = outputs[:n]
+             pred = pred.argmax(dim=1)
+
+             if len(y.shape) == 4:
+                 y = y.squeeze(1)
+             if len(pred.shape) == 4:
+                 pred = pred.squeeze(1)
+             y = y.clamp(0, 2)
+             columns = ["image"]
+             class_labels = {0: "Background", 1: "Red", 2: "Green"}
+
+             data = [
+                 [
+                     wandb.Image(
+                         x_i,
+                         masks={
+                             "Prediction": {
+                                 "mask_data": p_i.cpu().numpy(),
+                                 "class_labels": class_labels,
+                             },
+                             "Groundtruth": {
+                                 "mask_data": y_i.cpu().numpy(),
+                                 "class_labels": class_labels,
+                             },
+                         },
+                     )
+                 ]
+                 for x_i, y_i, p_i in zip(x, y, pred)
+             ]
+             self.wandb_logger.log_table(
+                 data=data, key=f"Validation Batch {batch_idx}", columns=columns
+             )
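
A sketch of how this callback is likely wired into a trainer; the W&B project name is an assumption:

```python
from lightning import Trainer
from lightning.pytorch.loggers import WandbLogger

from dnafiber.callbacks import LogPredictionSamplesCallback

wandb_logger = WandbLogger(project="dnafiber")  # project name is an assumption
trainer = Trainer(
    logger=wandb_logger,
    callbacks=[LogPredictionSamplesCallback(wandb_logger, n_images=8)],
)
# trainer.fit(model, datamodule=...) would then log one table of overlaid
# prediction/ground-truth masks on the first validation batch.
```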
dnafiber/data/__init__.py ADDED
File without changes
dnafiber/data/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (167 Bytes).
dnafiber/data/__pycache__/utils.cpython-312.pyc ADDED
Binary file (4.66 kB).
dnafiber/data/dataset.py ADDED
@@ -0,0 +1,271 @@
+ import albumentations as A
+ import nntools.dataset as D
+ import numpy as np
+ import torch
+ from albumentations.pytorch import ToTensorV2
+ from lightning import LightningDataModule
+ from nntools.dataset.composer import CacheBullet
+ from skimage.measure import label, regionprops
+ from skimage.morphology import skeletonize, dilation
+ from skimage.segmentation import expand_labels
+ from sklearn.model_selection import train_test_split
+ from torch.utils.data import DataLoader
+
+
+ @D.nntools_wrapper
+ def convert_mask(mask):
+     output = np.zeros(mask.shape[:2], dtype=np.uint8)
+     output[mask[:, :, 0] > 200] = 1
+     output[mask[:, :, 1] > 200] = 2
+     binary_mask = output > 0
+     skeleton = skeletonize(binary_mask) * output
+     output = expand_labels(skeleton, 3)
+     output = np.clip(output, 0, 2)
+     return {"mask": output}
+
+
+ @D.nntools_wrapper
+ def extract_bbox(mask):
+     binary_mask = mask > 0
+     labelled = label(binary_mask)
+     props = regionprops(labelled, intensity_image=mask)
+     skeleton = skeletonize(binary_mask) * mask
+     mask = dilation(skeleton, np.ones((3, 3)))
+     bboxes = []
+     masks = []
+     # We want the XYXY format
+     for prop in props:
+         minr, minc, maxr, maxc = prop.bbox
+         bboxes.append([minc, minr, maxc, maxr])
+         masks.append((labelled == prop.label).astype(np.uint8))
+     if not masks:
+         masks = np.zeros_like(mask)[np.newaxis, :, :]
+     masks = np.array(masks)
+     masks = np.moveaxis(masks, 0, -1)
+
+     return {
+         "bboxes": np.array(bboxes),
+         "mask": masks,
+         "fiber_ids": np.array([p.label for p in props]),
+     }
+
+
+ class FiberDatamodule(LightningDataModule):
+     def __init__(
+         self,
+         root_img,
+         crop_size=(256, 256),
+         shape=1024,
+         batch_size=32,
+         num_workers=8,
+         use_bbox=False,
+         **kwargs,
+     ):
+         self.shape = shape
+         self.root_img = str(root_img)
+         self.crop_size = crop_size
+         self.batch_size = batch_size
+         self.num_workers = num_workers
+         self.kwargs = kwargs
+         self.use_bbox = use_bbox
+
+         super().__init__()
+
+     def setup(self, *args, **kwargs):
+         def _get_dataset(version):
+             dataset = D.MultiImageDataset(
+                 {
+                     "image": f"{self.root_img}/{version}/images/",
+                     "mask": f"{self.root_img}/{version}/annotations/",
+                 },
+                 shape=(self.shape, self.shape),
+                 use_cache=self.kwargs.get("use_cache", False),
+                 cache_option=self.kwargs.get("cache_option", None),
+             )  # type: ignore
+             # Sort images and masks identically so they stay paired
+             dataset.img_filepath["image"] = np.asarray(  # type: ignore
+                 sorted(
+                     list(dataset.img_filepath["image"]),
+                     key=lambda x: (x.parent.stem, x.stem),
+                 )
+             )
+             dataset.img_filepath["mask"] = np.asarray(  # type: ignore
+                 sorted(
+                     list(dataset.img_filepath["mask"]),
+                     key=lambda x: (x.parent.stem, x.stem),
+                 )
+             )
+             dataset.composer = D.Composition()
+             dataset.composer << convert_mask  # type: ignore
+             if self.use_bbox:
+                 dataset.composer << extract_bbox
+
+             return dataset
+
+         # The validation set is carved out of the training set below
+         self.train = _get_dataset("train")
+         self.val = _get_dataset("train")
+         self.test = _get_dataset("test")
+         self.train.composer << CacheBullet()
+         self.val.use_cache = False
+         self.test.use_cache = False
+
+         stratify = []
+         for f in self.train.img_filepath["image"]:
+             if "tile" in f.stem:
+                 stratify.append(int(f.parent.stem))
+             else:
+                 # Non-tiled images share a single stratification bucket
+                 stratify.append(25)
+         train_idx, val_idx = train_test_split(
+             np.arange(len(self.train)),  # type: ignore
+             stratify=stratify,
+             test_size=0.2,
+             random_state=42,
+         )
+         self.train.subset(train_idx)
+         self.val.subset(val_idx)
+
+         self.train.composer.add(*self.get_train_composer())
+         self.val.composer.add(*self.cast_operators())
+         self.test.composer.add(*self.cast_operators())
+
+     def get_train_composer(self):
+         transforms = []
+         if self.crop_size is not None:
+             transforms.append(
+                 A.CropNonEmptyMaskIfExists(
+                     width=self.crop_size[0], height=self.crop_size[1]
+                 ),
+             )
+         return [
+             A.Compose(
+                 transforms
+                 + [
+                     A.HorizontalFlip(),
+                     A.VerticalFlip(),
+                     A.Affine(),
+                     A.ElasticTransform(),
+                     A.RandomRotate90(),
+                     A.OneOf(
+                         [
+                             A.RandomBrightnessContrast(
+                                 brightness_limit=(-0.2, 0.1),
+                                 contrast_limit=(-0.2, 0.1),
+                                 p=0.5,
+                             ),
+                             A.HueSaturationValue(
+                                 hue_shift_limit=(-5, 5),
+                                 sat_shift_limit=(-20, 20),
+                                 val_shift_limit=(-20, 20),
+                                 p=0.5,
+                             ),
+                         ]
+                     ),
+                     A.GaussNoise(std_range=(0.0, 0.1), p=0.5),
+                 ],
+                 bbox_params=A.BboxParams(
+                     format="pascal_voc", label_fields=["fiber_ids"], min_visibility=0.95
+                 )
+                 if self.use_bbox
+                 else None,
+             ),
+             *self.cast_operators(),
+         ]
+
+     def cast_operators(self):
+         return [
+             A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
+             if not self.use_bbox
+             else A.Normalize(
+                 mean=(0.0, 0.0, 0.0),
+                 std=(1.0, 1.0, 1.0),
+                 max_pixel_value=255,
+             ),
+             ToTensorV2(),
+         ]
+
+     def train_dataloader(self):
+         if self.use_bbox:
+             return DataLoader(
+                 self.train,
+                 batch_size=self.batch_size,
+                 shuffle=True,
+                 num_workers=self.num_workers,
+                 pin_memory=True,
+                 persistent_workers=True,
+                 collate_fn=bbox_collate_fn,
+             )
+
+         else:
+             return DataLoader(
+                 self.train,
+                 batch_size=self.batch_size,
+                 shuffle=True,
+                 num_workers=self.num_workers,
+                 pin_memory=True,
+                 persistent_workers=True,
+             )
+
+     def val_dataloader(self):
+         if self.use_bbox:
+             return DataLoader(
+                 self.val,
+                 batch_size=self.batch_size,
+                 shuffle=False,
+                 num_workers=self.num_workers,
+                 pin_memory=True,
+                 persistent_workers=True,
+                 collate_fn=bbox_collate_fn,
+             )
+         return DataLoader(
+             self.val,
+             batch_size=self.batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+         )
+
+     def test_dataloader(self):
+         if self.use_bbox:
+             return DataLoader(
+                 self.test,
+                 batch_size=self.batch_size,
+                 shuffle=False,
+                 num_workers=self.num_workers,
+                 pin_memory=True,
+                 persistent_workers=True,
+                 collate_fn=bbox_collate_fn,
+             )
+         return DataLoader(
+             self.test,
+             batch_size=self.batch_size,
+             shuffle=False,
+             num_workers=self.num_workers,
+         )
+
+
+ def bbox_collate_fn(batch):
+     images = []
+     targets = []
+
+     for b in batch:
+         target = dict()
+         target["boxes"] = torch.from_numpy(b["bboxes"])
+         if target["boxes"].shape[0] == 0:
+             target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
+         target["masks"] = b["mask"].permute(2, 0, 1)
+         if target["boxes"].shape[0] == 0:
+             target["labels"] = torch.zeros(1, dtype=torch.int64)
+         else:
+             target["labels"] = torch.ones_like(target["boxes"][:, 0], dtype=torch.int64)
+
+         images.append(b["image"])
+         targets.append(target)
+
+     return {
+         "image": torch.stack(images),
+         "targets": targets,
+     }
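
A construction sketch for the datamodule. The directory layout (`<root>/{train,test}/{images,annotations}`) is inferred from `_get_dataset` above; the root path is a placeholder:

```python
from dnafiber.data.dataset import FiberDatamodule

dm = FiberDatamodule(
    root_img="/data/dnafiber",  # placeholder path
    crop_size=(256, 256),
    batch_size=32,
    num_workers=8,
    use_bbox=False,
)
dm.setup()
train_loader = dm.train_dataloader()  # batches of {"image": ..., "mask": ...}
```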
dnafiber/data/intergrader/__init__.py ADDED
@@ -0,0 +1 @@
+ from .const import *
dnafiber/data/intergrader/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (206 Bytes).
dnafiber/data/intergrader/__pycache__/analysis.cpython-312.pyc ADDED
Binary file (6.06 kB).
dnafiber/data/intergrader/__pycache__/const.cpython-312.pyc ADDED
Binary file (994 Bytes).
dnafiber/data/intergrader/__pycache__/io.cpython-312.pyc ADDED
Binary file (1.72 kB).
dnafiber/data/intergrader/__pycache__/plot.cpython-312.pyc ADDED
Binary file (6.6 kB).
dnafiber/data/intergrader/analysis.py ADDED
@@ -0,0 +1,120 @@
+ import numpy as np
+ from skimage.measure import label
+ from skimage.morphology import skeletonize
+ from tqdm.contrib.concurrent import thread_map  # or process_map
+
+
+ def extract_fiber_properties(mask):
+     binary_mask = mask > 0
+     skeleton = skeletonize(binary_mask)
+     r = mask == 1
+     g = mask == 2
+     labeled_skeleton = label(skeleton, connectivity=2)
+     properties = {"R": [], "G": [], "ratio": []}
+     for i in range(1, labeled_skeleton.max() + 1):
+         fiber_mask = labeled_skeleton == i
+         sum_r = np.sum(r & fiber_mask)
+         sum_g = np.sum(g & fiber_mask)
+         # Skip unilateral fibers (missing one of the two analogs)
+         if sum_r == 0 or sum_g == 0:
+             continue
+         properties["R"].append(sum_r)
+         properties["G"].append(sum_g)
+
+     properties["R"] = np.array(properties["R"])
+     properties["G"] = np.array(properties["G"])
+     properties["ratio"] = properties["R"] / properties["G"]
+     properties["label"] = labeled_skeleton
+     return properties
+
+
+ def filter_non_commons_fibers(properties):
+     # `properties` is a list of dicts. Each dict holds a labelmap and arrays of
+     # red counts, green counts and ratios. We filter out the fibers that are
+     # not common to all images.
+     binary_labels = [p["label"] > 0 for p in properties]
+     common_labels = np.logical_and.reduce(binary_labels)
+     filtered_properties = []
+     for p in properties:
+         # Keep only the labels that overlap the common (intersection) mask
+         good_labels = common_labels * p["label"]
+         indices = np.unique(good_labels[good_labels > 0])
+         # NOTE: this assumes the per-fiber arrays are aligned with label ids
+         keep = np.isin(np.arange(1, len(p["R"]) + 1), indices)
+         filtered_properties.append(
+             {
+                 "R": p["R"][keep],
+                 "G": p["G"][keep],
+                 "ratio": p["ratio"][keep],
+                 "label": good_labels,
+             }
+         )
+     return filtered_properties
+
+
+ def skeletonize_mask(mask):
+     # Skeletonize the mask, preserving the class labels on the skeleton
+     binary_mask = mask > 0
+     skeleton = skeletonize(binary_mask) * mask
+     return skeleton
+
+
+ def skeletonize_data_dict(data_dict):
+     skeletons = dict()
+     for annotator, images in data_dict.items():
+         skeletons[annotator] = dict()
+         for image_type, masks in images.items():
+             skeletons[annotator][image_type] = thread_map(
+                 skeletonize_mask, masks, max_workers=8
+             )
+
+     return skeletons
+
+
+ def extract_properties_from_datadict(data_dict, with_common_analysis=True):
+     """
+     Extract the properties of the fibers from the data dictionary.
+     The data dictionary is a dict of annotators. Each value is a dict of
+     images. Each image is a list of masks.
+     """
+     properties = dict(
+         annotator=[], image_type=[], red=[], green=[], ratio=[], fiber_type=[]
+     )
+     all_annotators = list(data_dict.keys())
+
+     found_by = {a: [] for a in all_annotators}
+     properties.update(found_by)
+     for annotator, images in data_dict.items():
+         for image_type, masks in images.items():
+             for i, mask in enumerate(masks):
+                 if with_common_analysis:
+                     others_masks = []
+                     other_annotators = []
+                     for other in all_annotators:
+                         if other == annotator:
+                             continue
+                         other_annotators.append(other)
+                         others_masks.append(data_dict[other][image_type][i] > 0)
+
+                 labels, num = label(mask > 0, connectivity=2, return_num=True)
+                 for l in range(1, num + 1):
+                     fiber = labels == l
+                     if np.sum(fiber) < 10:
+                         continue
+
+                     red_length = np.sum(fiber & (mask == 1))
+                     green_length = np.sum(fiber & (mask == 2))
+                     # Skip unilateral fibers before appending anything, so all
+                     # the columns stay the same length
+                     if red_length == 0 or green_length == 0:
+                         continue
+
+                     properties["annotator"].append(annotator)
+                     properties["image_type"].append(image_type)
+
+                     # Check which annotators also found this fiber
+                     properties[annotator].append(True)
+                     if with_common_analysis:
+                         for other_mask, other_annotator in zip(
+                             others_masks, other_annotators
+                         ):
+                             properties[other_annotator].append(
+                                 np.any(fiber & other_mask)
+                             )
+
+                     properties["ratio"].append(
+                         green_length / (red_length + 1e-7)
+                     )  # Avoid division by zero
+                     properties["red"].append(red_length)
+                     properties["green"].append(green_length)
+
+                     segments, count = label(mask[fiber], connectivity=1, return_num=True)
+                     if count == 1:
+                         properties["fiber_type"].append("single")
+                     elif count == 2:
+                         properties["fiber_type"].append("double")
+                     elif count > 2:
+                         properties["fiber_type"].append("multiple")
+                     else:
+                         properties["fiber_type"].append("unknown")
+
+     return properties
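
A self-contained sketch of the expected input layout and output of `extract_properties_from_datadict`:

```python
import numpy as np
import pandas as pd

from dnafiber.data.intergrader.analysis import extract_properties_from_datadict

# Two annotators, one image type, one toy bicolor fiber (labels 0/1/2)
mask = np.zeros((20, 20), dtype=np.uint8)
mask[5, 2:10] = 1   # red (first analog) segment
mask[5, 10:18] = 2  # green (second analog) segment
data_dict = {"A": {"img": [mask]}, "B": {"img": [mask.copy()]}}

props = extract_properties_from_datadict(data_dict)
df = pd.DataFrame(props)
print(df[["annotator", "red", "green", "ratio", "fiber_type"]])
```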
dnafiber/data/intergrader/auto.py ADDED
@@ -0,0 +1,3 @@
+ def inference_model(model, path, use_cuda=False):
+     pass
dnafiber/data/intergrader/const.py ADDED
@@ -0,0 +1,21 @@
+ BLIND_MAPPING = {
+     "siB+M-01": "0",
+     "siB+M-04": "1",
+     "siBRCA2-02": "5",
+     "siBRCA2-03": "15",
+     "siTONSL-03": "11",
+     "siTONSL-04": "14",
+     "HLTF ko+si MMS22L-01": "8",
+     "HLTF ko+si MMS22L-02": "13",
+     "siBRCA2+SMARCAL KO-01": "2",
+     "siBRCA2+SMARCAL KO-03": "9",
+     "siBRCA2+SMARCAL KO-04": "16",
+     "siBRCA2-01": "4",
+     "59_siBRCA2-02": "7",
+     "siNT-01": "10",
+     "siNT-02": "12",
+     "siMMS22L_+dox-01": "3",
+     "siMMS22L_+dox-02": "6",
+ }
+
+ REVERSE_BLIND_MAPPING = {v: k for k, v in BLIND_MAPPING.items()}
dnafiber/data/intergrader/io.py ADDED
@@ -0,0 +1,27 @@
+ import cv2
+ import numpy as np
+ from skimage.segmentation import expand_labels
+
+
+ def read_to_mask(f):
+     img = cv2.imread(str(f), cv2.IMREAD_UNCHANGED)[:, :, ::-1]  # BGR -> RGB
+     mask = np.zeros(img.shape[:2], dtype=np.uint8)
+     mask[img[:, :, 0] > 200] = 1
+     mask[img[:, :, 1] > 200] = 2
+
+     return mask
+
+
+ def read_mask_from_path_gens(dict_gens, mapping=None):
+     output = {k: dict() for k in dict_gens.keys()}
+     for k, files in dict_gens.items():
+         for file in files:
+             name = file.parent.stem
+             if mapping is not None:
+                 name = mapping.get(name, name)
+             mask = read_to_mask(file)
+             mask = expand_labels(mask, 1)
+             if output[k].get(name) is None:
+                 output[k][name] = []
+             output[k][name].append(mask)
+     return output
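
A sketch of how `dict_gens` might be built with pathlib generators; the annotation root and the one-subfolder-per-annotator layout are assumptions:

```python
from pathlib import Path

from dnafiber.data.intergrader.const import BLIND_MAPPING
from dnafiber.data.intergrader.io import read_mask_from_path_gens

root = Path("/data/annotations")  # placeholder root, one subfolder per annotator
dict_gens = {
    annotator.name: annotator.rglob("*.png")
    for annotator in root.iterdir()
    if annotator.is_dir()
}
masks = read_mask_from_path_gens(dict_gens, mapping=BLIND_MAPPING)
# masks: {annotator: {blinded_image_name: [mask, ...]}}
```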
dnafiber/data/intergrader/plot.py ADDED
@@ -0,0 +1,172 @@
+ import base64
+ from typing import Callable
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from matplotlib.colors import ListedColormap
+ from skimage.measure import label, regionprops
+
+
+ def imshow_compare(data_dict, ax_size=4, draw_bbox=False, max_images=None):
+     """
+     Display the images in a grid format for comparison.
+     Each key is an annotator; each value is another dict, where the key is the
+     image type and the value the list of corresponding images.
+     """
+     # 0 is black, 1 is red, 2 is green
+     cmap = ListedColormap(["black", "red", "green"])
+
+     # Flatten the data dictionary to a dict of annotator: list of images
+     data = dict()
+     for annotator, images in data_dict.items():
+         if annotator not in data:
+             data[annotator] = []
+         for image_type, masks in images.items():
+             for mask in masks:
+                 data[annotator].append(mask)
+
+     annotators = list(data.keys())
+     num_images = len(data[annotators[0]])
+     if max_images is not None and num_images > max_images:
+         num_images = max_images
+     num_annotators = len(annotators)
+
+     fig_size = (ax_size * num_annotators, ax_size * num_images)
+     fig, axes = plt.subplots(num_images, num_annotators, figsize=fig_size, squeeze=False)
+
+     for i, annotator in enumerate(annotators):
+         for j in range(num_images):
+             ax = axes[j, i]
+             mask = data[annotator][j]
+             ax.imshow(mask, cmap=cmap, interpolation="nearest")
+             ax.axis("off")
+             ax.set_xticks([])
+             ax.set_yticks([])
+             if draw_bbox:
+                 mask = mask > 0
+                 labeled_mask = label(mask, connectivity=2)
+                 regions = regionprops(labeled_mask)
+                 for region in regions:
+                     minr, minc, maxr, maxc = region.bbox
+                     rect = plt.Rectangle(
+                         (minc, minr),
+                         maxc - minc,
+                         maxr - minr,
+                         fill=False,
+                         edgecolor="yellow",
+                         linewidth=0.5,
+                     )
+                     ax.add_patch(rect)
+
+             if j == 0:
+                 ax.set_title(annotator)
+
+     fig.tight_layout()
+     return fig, axes
+
+
+ def add_p_value_annotation(
+     fig,
+     array_columns,
+     stats_test,
+     subplot=None,
+     _format=dict(interline=0.07, text_height=1.07, color="black"),
+ ):
+     """Adds annotations giving the p-value between two box plot data
+     (two-sided comparison).
+
+     Parameters:
+     ----------
+     fig: figure
+         plotly boxplot figure
+     array_columns: np.array
+         array of which columns to compare
+         e.g.: [[0,1], [1,2]] compares column 0 with 1 and 1 with 2
+     stats_test: Callable or sequence
+         a scipy-style test returning (statistic, pvalue), or a precomputed
+         sequence of p-values (one per column pair)
+     subplot: None or int
+         specifies if the figure has subplots and what subplot to add the
+         annotation to
+     _format: dict
+         format characteristics for the lines
+
+     Returns:
+     -------
+     fig: figure
+         figure with the added annotation
+     """
+     # Specify in what y_range to plot for each pair of columns
+     y_range = np.zeros([len(array_columns), 2])
+     for i in range(len(array_columns)):
+         y_range[i] = [1.01 + i * _format["interline"], 1.02 + i * _format["interline"]]
+
+     # Get values from figure
+     fig_dict = fig.to_dict()
+     # Get indices if working with subplots
+     if subplot:
+         if subplot == 1:
+             subplot_str = ""
+         else:
+             subplot_str = str(subplot)
+         # Map the box index to the indices of the data for that subplot
+         indices = []
+         for index, data in enumerate(fig_dict["data"]):
+             if data["xaxis"] == "x" + subplot_str:
+                 indices = np.append(indices, index)
+         indices = [int(i) for i in indices]
+     else:
+         subplot_str = ""
+
+     # Draw the p-value annotations
+     for index, column_pair in enumerate(array_columns):
+         if subplot:
+             data_pair = [indices[column_pair[0]], indices[column_pair[1]]]
+         else:
+             data_pair = column_pair
+
+         if isinstance(stats_test, Callable):
+             # Get the p-value; plotly serializes numeric traces as base64 "bdata"
+             d1 = fig_dict["data"][data_pair[0]]["y"]
+             d2 = fig_dict["data"][data_pair[1]]["y"]
+             d1 = np.frombuffer(base64.b64decode(d1["bdata"]), dtype=np.float64)
+             d2 = np.frombuffer(base64.b64decode(d2["bdata"]), dtype=np.float64)
+             pvalue = stats_test(d1, d2)[1]
+         else:
+             pvalue = stats_test[index]
+         if pvalue >= 0.05:
+             symbol = "ns"
+         elif pvalue >= 0.01:
+             symbol = "*"
+         elif pvalue >= 0.001:
+             symbol = "**"
+         else:
+             symbol = "***"
+         # Vertical line
+         fig.add_shape(
+             type="line",
+             xref="x" + subplot_str,
+             yref="y" + subplot_str + " domain",
+             x0=column_pair[0],
+             y0=y_range[index][0],
+             x1=column_pair[0],
+             y1=y_range[index][1],
+             line=dict(color=_format["color"], width=2),
+         )
+         # Horizontal line
+         fig.add_shape(
+             type="line",
+             xref="x" + subplot_str,
+             yref="y" + subplot_str + " domain",
+             x0=column_pair[0],
+             y0=y_range[index][1],
+             x1=column_pair[1],
+             y1=y_range[index][1],
+             line=dict(color=_format["color"], width=2),
+         )
+         # Vertical line
+         fig.add_shape(
+             type="line",
+             xref="x" + subplot_str,
+             yref="y" + subplot_str + " domain",
+             x0=column_pair[1],
+             y0=y_range[index][0],
+             x1=column_pair[1],
+             y1=y_range[index][1],
+             line=dict(color=_format["color"], width=2),
+         )
+         # Add the significance symbol centered above the bracket
+         fig.add_annotation(
+             dict(
+                 font=dict(color=_format["color"], size=14),
+                 x=(column_pair[0] + column_pair[1]) / 2,
+                 y=y_range[index][1] * _format["text_height"],
+                 showarrow=False,
+                 text=symbol,
+                 textangle=0,
+                 xref="x" + subplot_str,
+                 yref="y" + subplot_str + " domain",
+             )
+         )
+     return fig
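
A usage sketch for `add_p_value_annotation` with a Mann-Whitney test; it assumes a plotly version that serializes numeric traces as base64 `bdata`, which is what the helper decodes:

```python
import numpy as np
import pandas as pd
import plotly.express as px
from scipy.stats import mannwhitneyu

from dnafiber.data.intergrader.plot import add_p_value_annotation

rng = np.random.default_rng(0)
df = pd.DataFrame(
    {
        "Image": ["siNT-01"] * 50 + ["siBRCA2-01"] * 50,
        "Ratio": np.concatenate(
            [rng.normal(1.0, 0.2, 50), rng.normal(1.5, 0.2, 50)]
        ),
    }
)
fig = px.box(df, x="Image", y="Ratio")
# Compare column 0 with column 1; a significance symbol is drawn above the boxes
fig = add_p_value_annotation(fig, [[0, 1]], mannwhitneyu)
fig.show()
```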
dnafiber/data/utils.py ADDED
@@ -0,0 +1,80 @@
+ import base64
+ from xml.dom import minidom
+
+ import cv2
+ import numpy as np
+ from czifile import CziFile
+ from tifffile import imread
+
+
+ def read_svg(svg_path):
+     doc = minidom.parse(str(svg_path))
+     img_strings = {
+         path.getAttribute("id"): path.getAttribute("href")
+         for path in doc.getElementsByTagName("image")
+     }
+     doc.unlink()
+
+     red = img_strings["Red"]
+     green = img_strings["Green"]
+     red = base64.b64decode(red.split(",")[1])
+     green = base64.b64decode(green.split(",")[1])
+     red = cv2.imdecode(np.frombuffer(red, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
+     green = cv2.imdecode(np.frombuffer(green, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
+
+     red = cv2.cvtColor(red, cv2.COLOR_BGRA2GRAY)
+     green = cv2.cvtColor(green, cv2.COLOR_BGRA2GRAY)
+     mask = np.zeros_like(red)
+     mask[red > 0] = 1
+     mask[green > 0] = 2
+     return mask
+
+
+ def extract_bboxes(mask):
+     mask = np.array(mask)
+     mask = mask.astype(np.uint8)
+
+     # Find connected components
+     num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
+         mask, connectivity=8
+     )
+     bboxes = []
+     for i in range(1, num_labels):
+         x, y, w, h, area = stats[i]
+         bboxes.append([x, y, x + w, y + h])
+     return bboxes
+
+
+ def preprocess(raw_data, reverse_channels=False):
+     MAX_VALUE = 2**16 - 1
+     if raw_data.ndim == 2:
+         raw_data = raw_data[np.newaxis, :, :]
+     h, w = raw_data.shape[1:3]
+     orders = np.arange(raw_data.shape[0])[::-1]  # Reverse channel order
+     result = np.zeros((h, w, 3), dtype=np.uint8)
+
+     for i, chan in enumerate(raw_data):
+         # Clip each channel at its 99th intensity percentile, then rescale to 8 bits
+         hist, bins = np.histogram(chan.ravel(), MAX_VALUE + 1, (0, MAX_VALUE + 1))
+         cdf = hist.cumsum()
+         cdf_normalized = cdf / cdf[-1]
+         bmax = np.searchsorted(cdf_normalized, 0.99, side="left")
+         clip = np.clip(chan, 0, bmax).astype(np.float32)
+         clip = (clip - clip.min()) / (bmax - clip.min()) * 255
+         result[:, :, orders[i]] = clip
+     if reverse_channels:
+         # Swap channels 0 and 1
+         result = result[:, :, [1, 0, 2]]
+     return result
+
+
+ def read_czi(filepath):
+     data = CziFile(filepath)
+
+     return data.asarray().squeeze()
+
+
+ def read_tiff(filepath):
+     data = imread(filepath).squeeze()
+
+     return data
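
A loading sketch tying the readers to `preprocess`; the file name is a placeholder, and the channel-first (C, H, W) layout after `squeeze()` is an assumption about the acquisition:

```python
from dnafiber.data.utils import preprocess, read_czi

raw = read_czi("experiment.czi")               # placeholder file; ~(C, H, W) uint16
rgb = preprocess(raw, reverse_channels=False)  # (H, W, 3) uint8, 99th-percentile clipped
```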
dnafiber/deployment.py ADDED
@@ -0,0 +1,44 @@
+ import pandas as pd
+
+ from dnafiber.postprocess.fiber import FiberProps
+ from dnafiber.trainee import Trainee
+
+
+ def _get_model(revision, device="cuda"):
+     if revision is None:
+         model = Trainee.from_pretrained(
+             "ClementP/DeepFiberQ", arch="unet", encoder_name="mit_b0"
+         )
+     else:
+         model = Trainee.from_pretrained(
+             "ClementP/DeepFiberQ",
+             revision=revision,
+         )
+     return model.eval().to(device)
+
+
+ def format_results(results: list[FiberProps], pixel_size: float) -> pd.DataFrame:
+     """
+     Format the results for display in the UI.
+     """
+     results = [fiber for fiber in results if fiber.is_valid]
+     all_results = dict(
+         FirstAnalog=[], SecondAnalog=[], length=[], ratio=[], fiber_type=[]
+     )
+     all_results["FirstAnalog"].extend([fiber.red * pixel_size for fiber in results])
+     all_results["SecondAnalog"].extend([fiber.green * pixel_size for fiber in results])
+     all_results["length"].extend(
+         [fiber.red * pixel_size + fiber.green * pixel_size for fiber in results]
+     )
+     all_results["ratio"].extend([fiber.ratio for fiber in results])
+     all_results["fiber_type"].extend([fiber.fiber_type for fiber in results])
+
+     return pd.DataFrame.from_dict(all_results)
+
+
+ MODELS_ZOO = {
+     "Ensemble": "ensemble",
+     "SegFormer MiT-B4": "segformer_mit_b4",
+     "SegFormer MiT-B2": "segformer_mit_b2",
+     "U-Net SE-ResNet50": "unet_se_resnet50",
+ }
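
A sketch resolving a display name to a Hub revision and loading the weights (requires network access to the Hugging Face Hub):

```python
from dnafiber.deployment import MODELS_ZOO, _get_model

revision = MODELS_ZOO["SegFormer MiT-B2"]   # -> "segformer_mit_b2"
model = _get_model(revision, device="cpu")  # downloads ClementP/DeepFiberQ weights
```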
dnafiber/inference.py ADDED
@@ -0,0 +1,105 @@
+ import albumentations as A
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torch.nn.functional as F
+ from albumentations.pytorch import ToTensorV2
+ from monai.inferers import SlidingWindowInferer
+ from skimage.measure import label
+ from skimage.segmentation import expand_labels
+ from torchvision.transforms._functional_tensor import normalize
+
+ from dnafiber.deployment import _get_model
+ from dnafiber.postprocess import refine_segmentation
+
+ transform = A.Compose(
+     [
+         A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+         ToTensorV2(),
+     ]
+ )
+
+
+ def preprocess_image(image):
+     image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
+     image = normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+     return image
+
+
+ def convert_to_dataset(counts):
+     data = {"index": [], "red": [], "green": [], "ratio": []}
+     for k, v in counts.items():
+         data["index"].append(k)
+         data["green"].append(v["green"])
+         data["red"].append(v["red"])
+         if v["red"] == 0:
+             data["ratio"].append(np.nan)
+         else:
+             data["ratio"].append(v["green"] / v["red"])
+     df = pd.DataFrame(data)
+     return df
+
+
+ def convert_mask_to_image(mask, expand=False):
+     if expand:
+         mask = expand_labels(mask, distance=expand)
+     h, w = mask.shape
+     image = np.zeros((h, w, 3), dtype=np.uint8)
+     GREEN = np.array([0, 255, 0])
+     RED = np.array([255, 0, 0])
+
+     image[mask == 1] = RED
+     image[mask == 2] = GREEN
+
+     return image
+
+
+ @torch.inference_mode()
+ def infer(model, image, device, scale=0.13, to_numpy=True, only_probabilities=False):
+     if isinstance(model, str):
+         model = _get_model(device=device, revision=model)
+     # The model was trained at a pixel size of 0.26 µm; rescale accordingly
+     model_pixel_size = 0.26
+     scale = scale / model_pixel_size
+     tensor = transform(image=image)["image"].unsqueeze(0).to(device)
+     h, w = tensor.shape[2], tensor.shape[3]
+     device = torch.device(device)
+     with torch.autocast(device_type=device.type):
+         tensor = F.interpolate(
+             tensor,
+             size=(int(h * scale), int(w * scale)),
+             mode="bilinear",
+         )
+         if tensor.shape[2] > 1024 or tensor.shape[3] > 1024:
+             inferer = SlidingWindowInferer(
+                 roi_size=(1024, 1024),
+                 sw_batch_size=4,
+                 overlap=0.25,
+                 mode="gaussian",
+                 device=device,
+                 progress=True,
+             )
+             output = inferer(tensor, model)
+         else:
+             output = model(tensor)
+
+         probabilities = F.softmax(output, dim=1)
+         if only_probabilities:
+             probabilities = probabilities.cpu()
+             probabilities = F.interpolate(
+                 probabilities,
+                 size=(h, w),
+                 mode="bilinear",
+             )
+             return probabilities
+
+         output = F.interpolate(
+             probabilities.argmax(dim=1, keepdim=True).float(),
+             size=(h, w),
+             mode="nearest",
+         )
+
+     output = output.squeeze().byte()
+     if to_numpy:
+         output = output.cpu().numpy()
+
+     return output
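
An end-to-end sketch of the pipeline this module anchors: segment, refine, tabulate. `image.jpg` is a placeholder; `pixel_size` mirrors the 0.13 µm default of `infer`:

```python
import cv2

from dnafiber.deployment import format_results
from dnafiber.inference import infer
from dnafiber.postprocess import refine_segmentation

image = cv2.imread("image.jpg")[:, :, ::-1]   # placeholder file, BGR -> RGB
mask = infer("segformer_mit_b2", image, device="cpu", scale=0.13)
fibers = refine_segmentation(mask)            # list of FiberProps
df = format_results(fibers, pixel_size=0.13)  # analog lengths in µm per fiber
```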
dnafiber/metric.py ADDED
@@ -0,0 +1,150 @@
+ import torch
+ import torchmetrics.functional as F
+ from skimage.measure import label
+ from torchmetrics import Metric
+
+
+ class DNAFIBERMetric(Metric):
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+         self.add_state(
+             "detection_tp",
+             default=torch.tensor(0, dtype=torch.int64),
+             dist_reduce_fx="sum",
+         )
+         self.add_state(
+             "fiber_red_dice",
+             default=torch.tensor(0, dtype=torch.float32),
+             dist_reduce_fx="sum",
+         )
+         self.add_state(
+             "fiber_green_dice",
+             default=torch.tensor(0, dtype=torch.float32),
+             dist_reduce_fx="sum",
+         )
+         self.add_state(
+             "fiber_red_recall",
+             default=torch.tensor(0, dtype=torch.float32),
+             dist_reduce_fx="sum",
+         )
+         self.add_state(
+             "fiber_green_recall",
+             default=torch.tensor(0, dtype=torch.float32),
+             dist_reduce_fx="sum",
+         )
+         # Precision
+         self.add_state(
+             "fiber_red_precision",
+             default=torch.tensor(0, dtype=torch.float32),
+             dist_reduce_fx="sum",
+         )
+         self.add_state(
+             "fiber_green_precision",
+             default=torch.tensor(0, dtype=torch.float32),
+             dist_reduce_fx="sum",
+         )
+
+         self.add_state(
+             "detection_fp",
+             default=torch.tensor(0, dtype=torch.int64),
+             dist_reduce_fx="sum",
+         )
+         self.add_state(
+             "N",
+             default=torch.tensor(0, dtype=torch.int64),
+             dist_reduce_fx="sum",
+         )
+
+     def update(self, preds, target):
+         if preds.ndim == 4:
+             preds = preds.argmax(dim=1)
+         if target.ndim == 4:
+             target = target.squeeze(1)
+         B, H, W = preds.shape
+         preds_labels = []
+         target_labels = []
+         binary_preds = preds > 0
+         binary_target = target > 0
+         N_true_labels = 0
+         for i in range(B):
+             pred = binary_preds[i].detach().cpu().numpy()
+             target_np = binary_target[i].detach().cpu().numpy()
+             pred_labels = label(pred, connectivity=2)
+             target_labels_np = label(target_np, connectivity=2)
+             preds_labels.append(torch.from_numpy(pred_labels).to(preds.device))
+             target_labels.append(torch.from_numpy(target_labels_np).to(preds.device))
+             N_true_labels += target_labels_np.max()
+
+         preds_labels = torch.stack(preds_labels)
+         target_labels = torch.stack(target_labels)
+
+         for i, plab in enumerate(preds_labels):
+             labels = torch.unique(plab)
+             for blob in labels:
+                 if blob == 0:
+                     continue
+                 pred_mask = plab == blob
+                 pixels_in_common = torch.any(pred_mask & binary_target[i])
+                 if pixels_in_common:
+                     self.detection_tp += 1
+                     gt_label = target_labels[i][pred_mask].unique()[-1]
+                     gt_mask = target_labels[i] == gt_label
+                     common_mask = pred_mask | gt_mask
+                     pred_fiber = preds[i][common_mask]
+                     gt_fiber = target[i][common_mask]
+                     dices = F.dice(
+                         pred_fiber,
+                         gt_fiber,
+                         num_classes=3,
+                         ignore_index=0,
+                         average=None,
+                     )
+                     dices = torch.nan_to_num(dices, nan=0.0)
+                     self.fiber_red_dice += dices[1]
+                     self.fiber_green_dice += dices[2]
+                     recalls = F.recall(
+                         pred_fiber,
+                         gt_fiber,
+                         num_classes=3,
+                         ignore_index=0,
+                         task="multiclass",
+                         average=None,
+                     )
+                     recalls = torch.nan_to_num(recalls, nan=0.0)
+                     self.fiber_red_recall += recalls[1]
+                     self.fiber_green_recall += recalls[2]
+
+                     # Precision (per class, ignoring the background)
+                     precisions = F.precision(
+                         pred_fiber,
+                         gt_fiber,
+                         num_classes=3,
+                         ignore_index=0,
+                         task="multiclass",
+                         average=None,
+                     )
+                     precisions = torch.nan_to_num(precisions, nan=0.0)
+                     self.fiber_red_precision += precisions[1]
+                     self.fiber_green_precision += precisions[2]
+
+                 else:
+                     self.detection_fp += 1
+
+         self.N += N_true_labels
+
+     def compute(self):
+         return {
+             "detection_precision": self.detection_tp
+             / (self.detection_tp + self.detection_fp + 1e-7),
+             "detection_recall": self.detection_tp / (self.N + 1e-7),
+             "fiber_red_dice": self.fiber_red_dice / (self.detection_tp + 1e-7),
+             "fiber_green_dice": self.fiber_green_dice / (self.detection_tp + 1e-7),
+             "fiber_red_recall": self.fiber_red_recall / (self.detection_tp + 1e-7),
+             "fiber_green_recall": self.fiber_green_recall / (self.detection_tp + 1e-7),
+             "fiber_red_precision": self.fiber_red_precision
+             / (self.detection_tp + 1e-7),
+             "fiber_green_precision": self.fiber_green_precision
+             / (self.detection_tp + 1e-7),
+         }
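
A minimal sketch of the tensors the metric consumes: (B, C, H, W) logits and (B, H, W) integer targets with labels {0: background, 1: red, 2: green}:

```python
import torch

from dnafiber.metric import DNAFIBERMetric

metric = DNAFIBERMetric()
preds = torch.zeros(1, 3, 32, 32)
preds[:, 0] = 1.0            # background logits everywhere...
preds[0, 1, 10, 5:15] = 2.0  # ...except one predicted red segment
target = torch.zeros(1, 32, 32, dtype=torch.long)
target[0, 10, 5:20] = 1      # ground-truth red fiber
metric.update(preds, target)
print(metric.compute())      # detection precision/recall and per-class Dice
```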
dnafiber/model/maskrcnn.py ADDED
File without changes
dnafiber/postprocess/__init__.py ADDED
@@ -0,0 +1 @@
+ from .core import refine_segmentation
dnafiber/postprocess/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (225 Bytes).
dnafiber/postprocess/__pycache__/core.cpython-312.pyc ADDED
Binary file (11.9 kB).
dnafiber/postprocess/__pycache__/fiber.cpython-312.pyc ADDED
Binary file (6.39 kB).
dnafiber/postprocess/__pycache__/skan.cpython-312.pyc ADDED
Binary file (9.43 kB).
dnafiber/postprocess/core.py ADDED
@@ -0,0 +1,274 @@
+ from itertools import compress
+ from typing import List, Tuple
+
+ import cv2
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from matplotlib.colors import ListedColormap
+ from scipy.sparse import csr_array
+ from scipy.sparse.csgraph import connected_components
+ from scipy.spatial.distance import cdist
+ from skimage.morphology import skeletonize
+
+ from dnafiber.postprocess.fiber import Fiber, FiberProps, Bbox
+ from dnafiber.postprocess.skan import (
+     compute_points_angle,
+     find_endpoints,
+     find_line_intersection,
+ )
+
+ cmlabel = ListedColormap(["black", "red", "green"])
+
+ MIN_ANGLE = 20
+ MIN_BRANCH_LENGTH = 10
+ MIN_BRANCH_DISTANCE = 30
+
+
+ def handle_multiple_fiber_in_cc(fiber, junctions_fiber, coordinates):
+     # Erase a 3x3 neighborhood around each junction to split the skeleton
+     for y, x in junctions_fiber:
+         fiber[y - 1 : y + 2, x - 1 : x + 2] = 0
+
+     endpoints = find_endpoints(fiber > 0)
+     endpoints = np.asarray(endpoints)
+     # We only keep the endpoints that are close to a junction:
+     # compute the distance between the endpoints and the junctions
+     distances = np.linalg.norm(
+         np.expand_dims(endpoints, axis=1) - np.expand_dims(junctions_fiber, axis=0),
+         axis=2,
+     )
+     distances = distances < 5
+     endpoints = endpoints[distances.any(axis=1)]
+
+     retval, branches, branches_stats, _ = cv2.connectedComponentsWithStatsWithAlgorithm(
+         fiber, connectivity=8, ccltype=cv2.CCL_BOLELLI, ltype=cv2.CV_16U
+     )
+     branches_bboxes = branches_stats[
+         :,
+         [cv2.CC_STAT_LEFT, cv2.CC_STAT_TOP, cv2.CC_STAT_WIDTH, cv2.CC_STAT_HEIGHT],
+     ]
+
+     num_branches = branches_bboxes.shape[0] - 1
+     # We associate the endpoints to the branches
+     endpoints_ids = np.zeros((len(endpoints),), dtype=np.uint16)
+     endpoints_color = np.zeros((len(endpoints),), dtype=np.uint8)
+     for i, endpoint in enumerate(endpoints):
+         # Get the branch id
+         branch_id = branches[endpoint[0], endpoint[1]]
+         # Check if the branch id is not 0
+         if branch_id != 0:
+             endpoints_ids[i] = branch_id
+             endpoints_color[i] = fiber[endpoint[0], endpoint[1]]
+
+     # We remove the small branches
+     kept_branches = set()
+     for i in range(1, num_branches + 1):
+         # Get the branch
+         branch = branches == i
+         # Compute the area of the branch
+         area = np.sum(branch.astype(np.uint8))
+         # Remove branches smaller than MIN_BRANCH_LENGTH pixels
+         if area < MIN_BRANCH_LENGTH:
+             branches[branch] = 0
+         else:
+             kept_branches.add(i)
+
+     # We remove the endpoints that are in the filtered branches
+     remaining_idxs = np.isin(endpoints_ids, np.asarray(list(kept_branches)))
+     if remaining_idxs.sum() == 0:
+         return []
+     endpoints = endpoints[remaining_idxs]
+
+     endpoints_color = endpoints_color[remaining_idxs]
+     endpoints_ids = endpoints_ids[remaining_idxs]
+
+     # We compute the angles of the endpoints
+     angles = compute_points_angle(fiber, endpoints, steps=15)
+     angles = np.rad2deg(angles)
+     # We compute the difference of angles between all the endpoints
+     endpoints_angles_diff = cdist(angles[:, None], angles[:, None], metric="cityblock")
+
+     # Put inf on the diagonal
+     endpoints_angles_diff[range(len(endpoints)), range(len(endpoints))] = np.inf
+     endpoints_distances = cdist(endpoints, endpoints, metric="euclidean")
+     endpoints_distances[range(len(endpoints)), range(len(endpoints))] = np.inf
+
+     # Endpoint pairs that are too far apart or too misaligned cannot be merged
+     endpoints_distances[endpoints_distances > MIN_BRANCH_DISTANCE] = np.inf
+     endpoints_distances[endpoints_angles_diff > MIN_ANGLE] = np.inf
+
+     matchB = np.argmin(endpoints_distances, axis=1)
+     values = np.take_along_axis(endpoints_distances, matchB[:, None], axis=1)
+
+     added_edges = dict()
+     N = len(endpoints)
+     A = np.eye(N, dtype=np.uint8)
+     for i in range(N):
+         for j in range(N):
+             if i == j:
+                 continue
+             if endpoints_ids[i] == endpoints_ids[j]:
+                 A[i, j] = 1
+                 A[j, i] = 1
+
+             if matchB[i] == j and values[i, 0] < np.inf:
+                 added_edges[i] = j
+                 A[i, j] = 1
+                 A[j, i] = 1
+
+     A = csr_array(A)
+     n, ccs = connected_components(A, directed=False, return_labels=True)
+     unique_clusters = np.unique(ccs)
+     results = []
+     for c in unique_clusters:
+         idx = np.where(ccs == c)[0]
+         branches_ids = np.unique(endpoints_ids[idx])
+
+         unique_branches = np.logical_or.reduce(
+             [branches == i for i in branches_ids], axis=0
+         )
+
+         commons_bboxes = branches_bboxes[branches_ids]
+         # Compute the union of the bboxes
+         min_x = np.min(commons_bboxes[:, 0])
+         min_y = np.min(commons_bboxes[:, 1])
+         max_x = np.max(commons_bboxes[:, 0] + commons_bboxes[:, 2])
+         max_y = np.max(commons_bboxes[:, 1] + commons_bboxes[:, 3])
+
+         new_fiber = fiber[min_y:max_y, min_x:max_x]
+         new_fiber = unique_branches[min_y:max_y, min_x:max_x] * new_fiber
+         for cidx in idx:
+             if cidx not in added_edges:
+                 continue
+             pointA = endpoints[cidx]
+             pointB = endpoints[added_edges[cidx]]
+             pointA = (
+                 pointA[1] - min_x,
+                 pointA[0] - min_y,
+             )
+             pointB = (
+                 pointB[1] - min_x,
+                 pointB[0] - min_y,
+             )
+             colA = endpoints_color[cidx]
+             colB = endpoints_color[added_edges[cidx]]
+             new_fiber = cv2.line(
+                 new_fiber,
+                 pointA,
+                 pointB,
+                 color=2 if colA != colB else int(colA),
+                 thickness=1,
+             )
+         # We express the bbox in the original image
+         bbox = (
+             coordinates[0] + min_x,
+             coordinates[1] + min_y,
+             max_x - min_x,
+             max_y - min_y,
+         )
+         bbox = Bbox(x=bbox[0], y=bbox[1], width=bbox[2], height=bbox[3])
+         result = Fiber(bbox=bbox, data=new_fiber)
+         results.append(result)
+     return results
+
+
+ def handle_ccs_with_junctions(
+     ccs: List[np.ndarray],
+     junctions: List[List[Tuple[int, int]]],
+     coordinates: List[Tuple[int, int]],
+ ):
+     """
+     Handle the connected components with junctions.
+     The function takes a list of connected components, a list of lists of
+     junctions and a list of coordinates. The junctions are (y, x) points local
+     to each connected component; the coordinates correspond to the top-left
+     corner of the connected component in the original image.
+     """
+     jncts_fibers = []
+     for fiber, junction, coordinate in zip(ccs, junctions, coordinates):
+         jncts_fibers += handle_multiple_fiber_in_cc(fiber, junction, coordinate)
+
+     return jncts_fibers
+
+
+ def refine_segmentation(segmentation, fix_junctions=True, show=False):
+     skeleton = skeletonize(segmentation > 0, method="lee").astype(np.uint8)
+     skeleton_gt = skeleton * segmentation
+     retval, labels, stats, centroids = cv2.connectedComponentsWithStatsWithAlgorithm(
+         skeleton, connectivity=8, ccltype=cv2.CCL_BOLELLI, ltype=cv2.CV_16U
+     )
+
+     bboxes = stats[
+         :,
+         [
+             cv2.CC_STAT_LEFT,
+             cv2.CC_STAT_TOP,
+             cv2.CC_STAT_WIDTH,
+             cv2.CC_STAT_HEIGHT,
+         ],
+     ]
+
+     local_fibers = []
+     coordinates = []
+     junctions = []
+     for i in range(1, retval):
+         bbox = bboxes[i]
+         x1, y1, w, h = bbox
+         local_gt = skeleton_gt[y1 : y1 + h, x1 : x1 + w]
+         local_label = (labels[y1 : y1 + h, x1 : x1 + w] == i).astype(np.uint8)
+         local_fiber = local_gt * local_label
+         local_fibers.append(local_fiber)
+         coordinates.append(np.asarray([x1, y1, w, h]))
+         local_junctions = find_line_intersection(local_fiber > 0)
+         local_junctions = np.where(local_junctions)
+         local_junctions = np.array(local_junctions).transpose()
+         junctions.append(local_junctions)
+     if show:
+         for bbox, junction in zip(coordinates, junctions):
+             x, y, w, h = bbox
+             junction_to_global = np.array(junction) + np.array([y, x])
+
+             plt.scatter(
+                 junction_to_global[:, 1],
+                 junction_to_global[:, 0],
+                 color="white",
+                 s=30,
+                 alpha=0.35,
+             )
+
+         plt.imshow(skeleton_gt, cmap=cmlabel, interpolation="nearest")
+         plt.axis("off")
+         plt.xticks([])
+         plt.yticks([])
+         plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
+         plt.show()
+
+     fibers = []
+     if fix_junctions:
+         has_junctions = [len(j) > 0 for j in junctions]
+         for fiber, coordinate in zip(
+             compress(local_fibers, np.logical_not(has_junctions)),
+             compress(coordinates, np.logical_not(has_junctions)),
+         ):
+             bbox = Bbox(
+                 x=coordinate[0],
+                 y=coordinate[1],
+                 width=coordinate[2],
+                 height=coordinate[3],
+             )
+             fibers.append(Fiber(bbox=bbox, data=fiber))
+
+         fibers += handle_ccs_with_junctions(
+             compress(local_fibers, has_junctions),
+             compress(junctions, has_junctions),
+             compress(coordinates, has_junctions),
+         )
+     else:
+         for fiber, coordinate in zip(local_fibers, coordinates):
+             bbox = Bbox(
+                 x=coordinate[0],
+                 y=coordinate[1],
+                 width=coordinate[2],
+                 height=coordinate[3],
+             )
+             fibers.append(Fiber(bbox=bbox, data=fiber))
+
+     fiberprops = [FiberProps(fiber=f, fiber_id=i) for i, f in enumerate(fibers)]
+
+     return fiberprops
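
A toy run of `refine_segmentation`, assuming the helpers in the (truncated) `dnafiber/postprocess/skan.py` behave as their names suggest:

```python
import numpy as np

from dnafiber.postprocess import refine_segmentation

seg = np.zeros((64, 64), dtype=np.uint8)
seg[30:33, 5:30] = 1   # red half of a straight fiber
seg[30:33, 30:55] = 2  # green half
fibers = refine_segmentation(seg)  # list of FiberProps
print(fibers[0].fiber_type, fibers[0].counts)
```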
dnafiber/postprocess/fiber.py ADDED
@@ -0,0 +1,129 @@
+ import attrs
+ import numpy as np
+ from typing import Tuple
+ from dnafiber.postprocess.skan import trace_skeleton
+
+
+ @attrs.define
+ class Bbox:
+     x: int
+     y: int
+     width: int
+     height: int
+
+     @property
+     def bbox(self) -> Tuple[int, int, int, int]:
+         return (self.x, self.y, self.width, self.height)
+
+     @bbox.setter
+     def bbox(self, value: Tuple[int, int, int, int]):
+         self.x, self.y, self.width, self.height = value
+
+
+ @attrs.define
+ class Fiber:
+     bbox: Bbox
+     data: np.ndarray
+
+
+ @attrs.define
+ class FiberProps:
+     fiber: Fiber
+     fiber_id: int
+     red_pixels: int = None
+     green_pixels: int = None
+     category: str = None
+
+     @property
+     def bbox(self):
+         return self.fiber.bbox.bbox
+
+     @bbox.setter
+     def bbox(self, value):
+         self.fiber.bbox = value
+
+     @property
+     def data(self):
+         return self.fiber.data
+
+     @data.setter
+     def data(self, value):
+         self.fiber.data = value
+
+     @property
+     def red(self):
+         if self.red_pixels is None:
+             self.red_pixels, self.green_pixels = self.counts
+         return self.red_pixels
+
+     @property
+     def green(self):
+         if self.green_pixels is None:
+             self.red_pixels, self.green_pixels = self.counts
+         return self.green_pixels
+
+     @property
+     def length(self):
+         return sum(self.counts)
+
+     @property
+     def counts(self):
+         if self.red_pixels is None or self.green_pixels is None:
+             self.red_pixels = np.sum(self.data == 1)
+             self.green_pixels = np.sum(self.data == 2)
+         return self.red_pixels, self.green_pixels
+
+     @property
+     def fiber_type(self):
+         if self.category is not None:
+             return self.category
+         red_pixels, green_pixels = self.counts
+         if red_pixels == 0 or green_pixels == 0:
+             self.category = "single"
+         else:
+             self.category = estimate_fiber_category(self.data)
+         return self.category
+
+     @property
+     def ratio(self):
+         # Assumes red > 0, which holds for the valid bilateral fiber types.
+         return self.green / self.red
+
+     @property
+     def is_valid(self):
+         return (
+             self.fiber_type == "double"
+             or self.fiber_type == "one-two-one"
+             or self.fiber_type == "two-one-two"
+         )
+
+     def scaled_coordinates(self, scale: float) -> Tuple[int, int, int, int]:
+         """
+         Scale the coordinates of the fiber's bounding box by the given factor.
+         """
+         x, y, width, height = self.bbox
+         return (
+             int(x * scale),
+             int(y * scale),
+             int(width * scale),
+             int(height * scale),
+         )
+
+
+ def estimate_fiber_category(fiber: np.ndarray) -> str:
+     """
+     Estimate the fiber category from the sequence of analog labels (1 = red,
+     2 = green) encountered while walking along the fiber's skeleton.
+     """
+     coordinates = trace_skeleton(fiber > 0)
+     coordinates = np.asarray(coordinates)
+     values = fiber[coordinates[:, 0], coordinates[:, 1]]
+     diff = np.diff(values)
+     jump = np.sum(diff != 0)
+     n_ccs = jump + 1
+     if n_ccs == 2:
+         return "double"
+     elif n_ccs == 3:
+         if values[0] == 1:
+             return "one-two-one"
+         else:
+             return "two-one-two"
+     else:
+         return "multiple"
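
Since `estimate_fiber_category` classifies by counting color transitions along the traced skeleton rather than by comparing pixel totals, a quick self-contained check of the logic looks like this (toy arrays, not repository code):

```python
import numpy as np
from dnafiber.postprocess.fiber import estimate_fiber_category

fiber = np.zeros((3, 30), dtype=np.uint8)
fiber[1, 1:10] = 1    # red segment ...
fiber[1, 10:20] = 2   # ... green segment ...
fiber[1, 20:29] = 1   # ... red again: 2 transitions -> 3 segments
print(estimate_fiber_category(fiber))  # "one-two-one"

fiber[1, 20:29] = 2   # red then green only: 1 transition -> 2 segments
print(estimate_fiber_category(fiber))  # "double"
```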
dnafiber/postprocess/skan.py ADDED
@@ -0,0 +1,211 @@
+ # Functions to generate kernels of curve intersection
+ import numpy as np
+ import cv2
+ import itertools
+ from numba import njit, int64
+ from numba.typed import List
+ from numba.types import Tuple
+
+ # Define the element type: a tuple of two int64
+ tuple_type = Tuple((int64, int64))
+
+
+ def find_neighbours(fibers_map, point):
+     """
+     Return the 8-connected neighbours of the given point that belong to the fiber.
+     """
+     neighbors = []
+     h, w = fibers_map.shape
+     for i in range(-1, 2):
+         for j in range(-1, 2):
+             # Skip the center point
+             if i == 0 and j == 0:
+                 continue
+             # Get the next point
+             nextpoint = (point[0] + i, point[1] + j)
+             # Check if the next point is in the image
+             if (
+                 nextpoint[0] < 0
+                 or nextpoint[0] >= h
+                 or nextpoint[1] < 0
+                 or nextpoint[1] >= w
+             ):
+                 continue
+
+             # Check if the next point is in the fiber
+             if fibers_map[nextpoint]:
+                 neighbors.append(nextpoint)
+     return neighbors
+
+
+ def compute_points_angle(fibers_map, points, steps=25):
+     """
+     For each endpoint, follow the fiber for a given number of steps and estimate
+     the tangent line by fitting a line to the visited points. The angle of the
+     line is returned.
+     """
+     points_angle = np.zeros((len(points),), dtype=np.float32)
+     for i, point in enumerate(points):
+         # Navigate along the fiber starting from the point for `steps` pixels,
+         # then fit a line to the visited points.
+         visited = trace_from_point(
+             fibers_map > 0, (point[0], point[1]), max_length=steps
+         )
+         visited = np.array(visited)
+         vx, vy, x, y = cv2.fitLine(visited[:, ::-1], cv2.DIST_L2, 0, 0.01, 0.01)
+         # Compute the angle of the line
+         points_angle[i] = np.arctan(vy / vx)
+
+     return points_angle
+
+
+ def generate_nonadjacent_combination(input_list, take_n):
+     """
+     Generate the combinations of the input elements taken n at a time such
+     that no two chosen positions are adjacent (positions are circular).
+     INPUT:
+         input_list = (iterable) list of elements to combine
+         take_n = (integer) number of elements taken at a time in each combination
+     OUTPUT:
+         all_comb = (list of np.array) all the valid combinations
+     """
+     all_comb = []
+     for comb in itertools.combinations(input_list, take_n):
+         comb = np.array(comb)
+         d = np.diff(comb)
+         # Reject combinations with adjacent positions, including the 0/7 wrap-around.
+         if len(d[d == 1]) == 0 and comb[-1] - comb[0] != 7:
+             all_comb.append(comb)
+     return all_comb
+
+
+ def populate_intersection_kernel(combinations):
+     """
+     Maps the numbers 0-7 onto the 8 pixels surrounding the center pixel of a
+     3 x 3 matrix, clockwise (up_pixel = 0, right_pixel = 2, etc.), and
+     generates a kernel that represents a line intersection: the center pixel
+     is occupied together with 3 or 4 of the border pixels.
+     INPUT:
+         combinations = (list) every element is a vector of border positions
+     OUTPUT:
+         kernels = (list) 3 x 3 kernels/masks, one per combination
+     """
+     template = np.array(([-1, -1, -1], [-1, 1, -1], [-1, -1, -1]), dtype="int")
+     match = [(0, 1), (0, 2), (1, 2), (2, 2), (2, 1), (2, 0), (1, 0), (0, 0)]
+     kernels = []
+     for comb in combinations:
+         tmp = np.copy(template)
+         for m in comb:
+             tmp[match[m][0], match[m][1]] = 1
+         kernels.append(tmp)
+     return kernels
+
+
+ def give_intersection_kernels():
+     """
+     Generates all the 3 x 3 intersection kernels.
+     INPUT:
+         None
+     OUTPUT:
+         kernels = (list) 3 x 3 kernels/masks, one per valid combination
+     """
+     input_list = np.arange(8)
+     taken_n = [4, 3]
+     kernels = []
+     for taken in taken_n:
+         comb = generate_nonadjacent_combination(input_list, taken)
+         tmp_ker = populate_intersection_kernel(comb)
+         kernels.extend(tmp_ker)
+     return kernels
+
+
+ def find_line_intersection(input_image):
+     """
+     Applies cv2.morphologyEx with cv2.MORPH_HITMISS for every curve
+     intersection kernel generated by give_intersection_kernels().
+     INPUT:
+         input_image = (np.array dtype=np.uint8) binarized m x n image matrix
+     OUTPUT:
+         output_image = (np.array) image where the nonzero pixels are the
+             line intersections.
+     """
+     input_image = input_image.astype(np.uint8)
+     kernel = np.array(give_intersection_kernels())
+     output_image = np.zeros(input_image.shape)
+     for i in range(len(kernel)):
+         out = cv2.morphologyEx(
+             input_image,
+             cv2.MORPH_HITMISS,
+             kernel[i, :, :],
+             borderValue=0,
+             borderType=cv2.BORDER_CONSTANT,
+         )
+         output_image = output_image + out
+
+     return output_image
+
+
+ @njit
+ def get_neighbors_8(y, x, shape):
+     neighbors = List.empty_list(tuple_type)
+     for dy in range(-1, 2):
+         for dx in range(-1, 2):
+             if dy == 0 and dx == 0:
+                 continue
+             ny, nx = y + dy, x + dx
+             if 0 <= ny < shape[0] and 0 <= nx < shape[1]:
+                 neighbors.append((ny, nx))
+     return neighbors
+
+
+ @njit
+ def find_endpoints(skel):
+     endpoints = List.empty_list(tuple_type)
+     for y in range(skel.shape[0]):
+         for x in range(skel.shape[1]):
+             if skel[y, x] == 1:
+                 count = 0
+                 neighbors = get_neighbors_8(y, x, skel.shape)
+                 for ny, nx in neighbors:
+                     if skel[ny, nx] == 1:
+                         count += 1
+                 if count == 1:
+                     endpoints.append((y, x))
+     return endpoints
+
+
+ @njit
+ def trace_skeleton(skel):
+     endpoints = find_endpoints(skel)
+     if len(endpoints) < 1:
+         return List.empty_list(tuple_type)  # Return empty list with proper type
+
+     return trace_from_point(skel, endpoints[0], max_length=skel.sum())
+
+
+ @njit
+ def trace_from_point(skel, point, max_length=25):
+     visited = np.zeros_like(skel, dtype=np.uint8)
+     path = List.empty_list(tuple_type)
+
+     # Check if the starting point is on the skeleton
+     y, x = point
+     if y < 0 or y >= skel.shape[0] or x < 0 or x >= skel.shape[1] or skel[y, x] != 1:
+         return path
+
+     stack = List.empty_list(tuple_type)
+     stack.append(point)
+
+     # Depth-first walk along the skeleton, capped at max_length pixels.
+     while len(stack) > 0 and len(path) < max_length:
+         y, x = stack.pop()
+         if visited[y, x]:
+             continue
+         visited[y, x] = 1
+         path.append((y, x))
+         neighbors = get_neighbors_8(y, x, skel.shape)
+         for ny, nx in neighbors:
+             if skel[ny, nx] == 1 and not visited[ny, nx]:
+                 stack.append((ny, nx))
+     return path
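
Two quick checks of this module on a synthetic plus-shaped skeleton: the hit-or-miss pass should flag exactly the central 4-way junction, and the DFS trace should visit every skeleton pixel (the kernel count of 18 follows from the non-adjacent combinations: 16 three-branch plus 2 four-branch masks).

```python
import numpy as np
from dnafiber.postprocess.skan import (
    find_line_intersection,
    give_intersection_kernels,
    trace_skeleton,
)

print(len(give_intersection_kernels()))  # 18 hit-or-miss masks

skel = np.zeros((7, 7), dtype=np.uint8)
skel[3, :] = 1  # horizontal stroke
skel[:, 3] = 1  # vertical stroke -> 4-way junction at (3, 3)
print(np.argwhere(find_line_intersection(skel > 0)))  # [[3 3]]

path = trace_skeleton(skel > 0)  # DFS walk from the first detected endpoint
print(len(path))  # 13: every skeleton pixel is visited
```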
dnafiber/start.py ADDED
@@ -0,0 +1,22 @@
+ import subprocess
+ import os
+
+
+ def main():
+     # Start the Streamlit application
+     print("Starting Streamlit application...")
+     local_dir = os.path.dirname(os.path.abspath(__file__))
+     subprocess.run(
+         [
+             "streamlit",
+             "run",
+             os.path.join(local_dir, "ui", "Welcome.py"),
+             "--server.maxUploadSize",
+             "1024",
+         ],
+     )
+
+
+ if __name__ == "__main__":
+     main()
+     # subprocess.run blocks until the Streamlit server exits,
+     # so this message is printed on shutdown.
+     print("Streamlit application stopped.")
dnafiber/trainee.py ADDED
@@ -0,0 +1,148 @@
+ from lightning import LightningModule
+ import segmentation_models_pytorch as smp
+ from monai.losses.dice import GeneralizedDiceLoss
+ from torchmetrics.classification import Dice, JaccardIndex
+ from torch.optim import AdamW
+ from torch.optim.lr_scheduler import CosineAnnealingLR
+ from torchmetrics import MetricCollection
+ import torch.nn.functional as F
+ from huggingface_hub import PyTorchModelHubMixin
+ import torch
+ import torchvision
+ from dnafiber.metric import DNAFIBERMetric
+
+
+ class Trainee(LightningModule, PyTorchModelHubMixin):
+     def __init__(
+         self, learning_rate=0.001, weight_decay=0.0002, num_classes=3, **model_config
+     ):
+         super().__init__()
+         self.model_config = model_config
+         self.num_classes = num_classes
+         if (
+             self.model_config.get("arch", None) is None
+             or self.model_config["arch"] == "maskrcnn"
+         ):
+             self.model = None
+         else:
+             self.model = smp.create_model(
+                 classes=num_classes, **self.model_config, dropout=0.2
+             )
+         self.loss = GeneralizedDiceLoss(to_onehot_y=False, softmax=False)
+         self.metric = MetricCollection(
+             {
+                 "dice": Dice(num_classes=num_classes, ignore_index=0),
+                 "jaccard": JaccardIndex(
+                     num_classes=num_classes,
+                     task="multiclass" if num_classes > 2 else "binary",
+                     ignore_index=0,
+                 ),
+                 "detection": DNAFIBERMetric(),
+             }
+         )
+         self.weight_decay = weight_decay
+         self.learning_rate = learning_rate
+         self.save_hyperparameters()
+
+     def forward(self, x):
+         yhat = self.model(x)
+         return yhat
+
+     def training_step(self, batch, batch_idx):
+         x, y = batch["image"], batch["mask"]
+         y = y.clamp(0, 2)
+         y_hat = self(x)
+         loss = self.get_loss(y_hat, y)
+
+         self.log("train_loss", loss)
+
+         return loss
+
+     def get_loss(self, y_hat, y):
+         y_hat = F.softmax(y_hat, dim=1)
+         y = F.one_hot(y.long(), num_classes=self.num_classes)
+         y = y.permute(0, 3, 1, 2).float()
+         loss = self.loss(y_hat, y)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         x, y = batch["image"], batch["mask"]
+         y = y.clamp(0, 2)
+         y_hat = self(x)
+         loss = self.get_loss(y_hat, y)
+         self.log("val_loss", loss, on_step=False, on_epoch=True, sync_dist=True)
+         self.metric.update(y_hat, y)
+         return y_hat
+
+     def on_validation_epoch_end(self):
+         scores = self.metric.compute()
+         self.log_dict(scores, sync_dist=True)
+         self.metric.reset()
+
+     def configure_optimizers(self):
+         optimizer = AdamW(
+             self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
+         )
+         scheduler = CosineAnnealingLR(
+             optimizer,
+             T_max=self.trainer.max_epochs,  # type: ignore
+             eta_min=self.learning_rate / 25,
+         )
+         scheduler = {
+             "scheduler": scheduler,
+             "interval": "epoch",
+         }
+         return [optimizer], [scheduler]
+
+
+ class TraineeMaskRCNN(Trainee):
+     def __init__(self, learning_rate=0.001, weight_decay=0.0002, **model_config):
+         super().__init__(learning_rate, weight_decay, **model_config)
+         self.model = torchvision.models.get_model("maskrcnn_resnet50_fpn_v2")
+
+     def forward(self, x):
+         yhat = self.model(x)
+         return yhat
+
+     def training_step(self, batch, batch_idx):
+         image = batch["image"]
+         targets = batch["targets"]
+         loss_dict = self.model(image, targets)
+         losses = sum(loss for loss in loss_dict.values())
+         self.log("train_loss", losses, on_step=True, on_epoch=False, sync_dist=True)
+         return losses
+
+     def validation_step(self, batch, batch_idx):
+         image = batch["image"]
+         targets = batch["targets"]
+
+         predictions = self.model(image)
+         b = len(predictions)
+         predicted_masks = []
+         gt_masks = []
+         for i in range(b):
+             scores = predictions[i]["scores"]
+             masks = predictions[i]["masks"]
+             good_masks = masks[scores > 0.5]
+             # Combine the confident instance masks into a single mask
+             good_masks = torch.sum(good_masks, dim=0)
+             predicted_masks.append(good_masks)
+             gt_masks.append(targets[i]["masks"].sum(dim=0))
+
+         gt_masks = torch.stack(gt_masks).squeeze(1) > 0
+         predicted_masks = torch.stack(predicted_masks).squeeze(1) > 0
+         self.metric.update(predicted_masks, gt_masks)
+         return predictions
+
+     def configure_optimizers(self):
+         optimizer = AdamW(
+             self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
+         )
+         scheduler = CosineAnnealingLR(
+             optimizer,
+             T_max=self.trainer.max_epochs,  # type: ignore
+             eta_min=self.learning_rate / 25,
+         )
+         scheduler = {
+             "scheduler": scheduler,
+             "interval": "epoch",
+         }
+         return [optimizer], [scheduler]
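
To make the tensor contract of `get_loss` explicit — softmax probabilities of shape (B, C, H, W) against a one-hot mask of the same shape — here is a standalone sketch using the same MONAI loss configuration as above:

```python
import torch
import torch.nn.functional as F
from monai.losses.dice import GeneralizedDiceLoss

loss_fn = GeneralizedDiceLoss(to_onehot_y=False, softmax=False)
logits = torch.randn(2, 3, 64, 64)       # raw model output (B, C, H, W)
mask = torch.randint(0, 3, (2, 64, 64))  # integer labels (B, H, W)

probs = F.softmax(logits, dim=1)
onehot = F.one_hot(mask, num_classes=3).permute(0, 3, 1, 2).float()
print(loss_fn(probs, onehot).item())
```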
dnafiber/ui/Welcome.py ADDED
@@ -0,0 +1,47 @@
+ import streamlit as st
+ import torch
+
+
+ def main():
+     st.set_page_config(
+         page_title="Hello",
+         page_icon="🧬",
+         layout="wide",
+     )
+     st.write("# Welcome to DN-AI! 👋")
+
+     st.write(
+         "This is a web application for the DN-AI project, which aims to provide an easy-to-use interface for analyzing and processing fiber images."
+     )
+     st.write("## Features")
+     st.write(
+         "- **Image loading**: The application accepts CZI, TIFF, JPEG, and PNG files. \n"
+         "- **Image segmentation**: The application provides a set of tools to segment the DNA fibers and measure the ratio between analogs. \n"
+     )
+     st.write("## Technical details")
+     cols = st.columns(2)
+     with cols[0]:
+         st.write("### Source")
+         st.write("The source code for this application is available on GitHub.")
+         st.markdown(
+             "[![Repo](https://badgen.net/badge/icon/GitHub?icon=github&label)](https://github.com/ClementPla/DeepFiberQ/tree/relabelled)"
+         )
+         st.markdown("<br>", unsafe_allow_html=True)
+
+     with cols[1]:
+         st.write("### Device")
+         st.write("If available, the application will try to use a GPU for processing.")
+         device = "GPU" if torch.cuda.is_available() else "CPU"
+         cols = st.columns(3)
+         with cols[0]:
+             st.write("Running on:")
+         with cols[1]:
+             st.button(device, icon="⚙️", disabled=True)
+         if not torch.cuda.is_available():
+             with cols[2]:
+                 st.warning("The application will run on CPU, which may be slower.")
+
+
+ if __name__ == "__main__":
+     main()
dnafiber/ui/__init__.py ADDED
File without changes
dnafiber/ui/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (165 Bytes)
dnafiber/ui/__pycache__/inference.cpython-312.pyc ADDED
Binary file (2.32 kB)
dnafiber/ui/__pycache__/utils.cpython-312.pyc ADDED
Binary file (7.33 kB)
dnafiber/ui/inference.py ADDED
@@ -0,0 +1,69 @@
+ import streamlit as st
+ from dnafiber.inference import infer
+ from dnafiber.postprocess.core import refine_segmentation
+ import numpy as np
+ from dnafiber.deployment import _get_model
+ import torch
+
+
+ @st.cache_data
+ def ui_inference(_model, _image, _device, postprocess=True, id=None):
+     return ui_inference_cacheless(
+         _model, _image, _device, postprocess=postprocess, id=id
+     )
+
+
+ @st.cache_resource
+ def get_model(model_name):
+     model = _get_model(
+         device="cuda" if torch.cuda.is_available() else "cpu",
+         revision=model_name,
+     )
+     return model
+
+
+ def ui_inference_cacheless(_model, _image, _device, postprocess=True, id=None):
+     """
+     A cacheless version of the ui_inference function, intended for scenarios
+     where caching is not desired. The `id` argument only serves as a cache key
+     for the cached wrapper above.
+     """
+     with st.spinner("Sliding window segmentation in progress..."):
+         if isinstance(_model, list):
+             # Ensemble: average the per-model probabilities before the argmax.
+             output = None
+             for model in _model:
+                 model_name = model if isinstance(model, str) else model.__class__.__name__
+                 if isinstance(model, str):
+                     model = get_model(model)
+                 with st.spinner(text="Segmenting with model: {}".format(model_name)):
+                     probabilities = infer(
+                         model,
+                         image=_image,
+                         device=_device,
+                         scale=st.session_state.get("pixel_size", 0.13),
+                         only_probabilities=True,
+                     ).cpu()
+                 output = probabilities if output is None else output + probabilities
+             output = (output / len(_model)).argmax(1).squeeze().numpy()
+         else:
+             output = infer(
+                 _model,
+                 image=_image,
+                 device=_device,
+                 scale=st.session_state.get("pixel_size", 0.13),
+             )
+     output = output.astype(np.uint8)
+     if postprocess:
+         with st.spinner("Post-processing segmentation..."):
+             output = refine_segmentation(output, fix_junctions=postprocess)
+     return output
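
Outside Streamlit, the same pipeline can be assembled from the lower-level helpers; in this sketch the `revision` string is a placeholder for a published checkpoint name and the blank image stands in for real data:

```python
import numpy as np
import torch
from dnafiber.deployment import _get_model
from dnafiber.inference import infer
from dnafiber.postprocess.core import refine_segmentation

device = "cuda" if torch.cuda.is_available() else "cpu"
model = _get_model(device=device, revision="model_name")  # placeholder revision
image = np.zeros((1024, 1024, 3), dtype=np.uint8)         # stand-in RGB image

labelmap = infer(model, image=image, device=device, scale=0.13).astype(np.uint8)
fibers = refine_segmentation(labelmap, fix_junctions=True)
print(len(fibers), "fibers detected")
```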
dnafiber/ui/pages/1_Load.py ADDED
@@ -0,0 +1,196 @@
+ import streamlit as st
+
+
+ st.set_page_config(
+     page_title="DN-AI",
+     page_icon="🔬",
+     layout="wide",
+ )
+
+
+ def build_multichannel_loader():
+     if (
+         st.session_state.get("files_uploaded", None) is None
+         or len(st.session_state.files_uploaded) == 0
+     ):
+         st.session_state["files_uploaded"] = st.file_uploader(
+             label="Upload files",
+             accept_multiple_files=True,
+             type=["czi", "jpeg", "jpg", "png", "tiff", "tif"],
+         )
+     else:
+         st.session_state["files_uploaded"] += st.file_uploader(
+             label="Upload files",
+             accept_multiple_files=True,
+             type=["czi", "jpeg", "jpg", "png", "tiff", "tif"],
+         )
+     st.write("### Channel interpretation")
+     st.markdown(
+         "The goal is to obtain an RGB image in the order of <span style='color: red;'>First analog</span>, <span style='color: green;'>Second analog</span>, <span style='color: blue;'>Empty</span>.",
+         unsafe_allow_html=True,
+     )
+     st.markdown(
+         "By default, we assume that the first channel in a CZI/TIFF file is <span style='color: green;'>the second analog</span> (which happens to be the case for Zeiss microscopes), "
+         "which means that we swap the order of the two channels for processing.",
+         unsafe_allow_html=True,
+     )
+     st.write("If this is not the intended behavior, please tick the box below:")
+     st.session_state["reverse_channels"] = st.checkbox(
+         "Reverse the channels interpretation",
+         value=False,
+     )
+     st.warning(
+         "Please note that we only swap the channels for raw (CZI, TIFF) files. JPEG and PNG files "
+         "are assumed to be already in the correct order (first analog in red and second analog in green)."
+     )
+
+     st.info(
+         "The channel order in CZI files does not necessarily match the order in which they are displayed in ImageJ or equivalent. "
+         "Indeed, such viewers usually look at the metadata of the file to determine the order of the channels, which we do not. "
+         "If in doubt, we recommend visualizing the image in ImageJ and comparing with our viewer. If the channels appear reversed, tick the option above."
+     )
+
+
+ def build_individual_loader():
+     cols = st.columns(2)
+     with cols[1]:
+         st.markdown(
+             f"<h3 style='color: {st.session_state['color2']};'>Second analog</h3>",
+             unsafe_allow_html=True,
+         )
+
+         if (
+             st.session_state.get("analog_2_files", None) is None
+             or len(st.session_state.analog_2_files) == 0
+         ):
+             st.session_state["analog_2_files"] = st.file_uploader(
+                 label="Upload second analog file(s)",
+                 accept_multiple_files=True,
+                 type=["czi", "jpeg", "jpg", "png", "tiff", "tif"],
+             )
+         else:
+             st.session_state["analog_2_files"] += st.file_uploader(
+                 label="Upload second analog file(s)",
+                 accept_multiple_files=True,
+                 type=["czi", "jpeg", "jpg", "png", "tiff", "tif"],
+             )
+
+     with cols[0]:
+         st.markdown(
+             f"<h3 style='color: {st.session_state['color1']};'>First analog</h3>",
+             unsafe_allow_html=True,
+         )
+         if (
+             st.session_state.get("analog_1_files", None) is None
+             or len(st.session_state.analog_1_files) == 0
+         ):
+             st.session_state["analog_1_files"] = st.file_uploader(
+                 label="Upload first analog file(s)",
+                 accept_multiple_files=True,
+                 type=["czi", "jpeg", "jpg", "png", "tiff", "tif"],
+             )
+         else:
+             st.session_state["analog_1_files"] += st.file_uploader(
+                 label="Upload first analog file(s)",
+                 accept_multiple_files=True,
+                 type=["czi", "jpeg", "jpg", "png", "tiff", "tif"],
+             )
+
+     analog_1_files = st.session_state.get("analog_1_files", None)
+     analog_2_files = st.session_state.get("analog_2_files", None)
+
+     # Remove duplicates from the list of files: keep only the first occurrence
+     # of each file name.
+     def remove_duplicates(files):
+         seen_ids = set()
+         unique_files = []
+         for file in files:
+             if file and file.name not in seen_ids:
+                 unique_files.append(file)
+                 seen_ids.add(file.name)
+         return unique_files
+
+     analog_1_files = remove_duplicates(analog_1_files or [])
+     analog_2_files = remove_duplicates(analog_2_files or [])
+
+     if not analog_1_files and not analog_2_files:
+         return
+     if (
+         len(analog_1_files) > 0
+         and len(analog_2_files) > 0
+         and len(analog_1_files) != len(analog_2_files)
+     ):
+         st.error("Please upload the same number of analog files.")
+         return
+
+     # Always make sure we don't have duplicates in the list of files
+     analog_1_files = sorted(analog_1_files, key=lambda x: x.name)
+     analog_2_files = sorted(analog_2_files, key=lambda x: x.name)
+     max_size = max(len(analog_1_files), len(analog_2_files))
+     # Pad the shorter list with None
+     if len(analog_1_files) < max_size:
+         analog_1_files += [None] * (max_size - len(analog_1_files))
+     if len(analog_2_files) < max_size:
+         analog_2_files += [None] * (max_size - len(analog_2_files))
+
+     combined_files = list(zip(analog_1_files, analog_2_files))
+
+     if (
+         st.session_state.get("files_uploaded", None) is None
+         or len(st.session_state.files_uploaded) == 0
+     ):
+         st.session_state["files_uploaded"] = combined_files
+     else:
+         st.session_state["files_uploaded"] += combined_files
+
+     # If any of the files (analog_1_files or analog_2_files) was included
+     # previously in files_uploaded, remove the previous occurrence from the list.
+     current_ids = set()
+     for f in analog_1_files + analog_2_files:
+         if f:
+             current_ids.add(f.name)
+
+     # Safely filter the list to exclude any files with matching names
+     def is_not_duplicate(file):
+         if isinstance(file, tuple):
+             f1, f2 = file
+             if f1 and f2:
+                 return True
+             return (f1 is None or f1.name not in current_ids) and (
+                 f2 is None or f2.name not in current_ids
+             )
+         return True
+
+     st.session_state.files_uploaded = [
+         f for f in st.session_state.files_uploaded if is_not_duplicate(f)
+     ]
+
+
+ cols = st.columns(2)
+ with cols[1]:
+     st.write("### Pixel size")
+     st.session_state["pixel_size"] = st.number_input(
+         "Please indicate the pixel size of the image in µm (default: 0.13 µm).",
+         value=st.session_state.get("pixel_size", 0.13),
+     )
+     # Briefly explain the technical details
+     st.write(
+         "The pixel size is used to convert the pixel coordinates to µm. "
+         "The model is trained on images with a pixel size of 0.26 µm, and the application automatically "
+         "resizes the images to match this pixel size using your provided value."
+     )
+
+     st.write("### Labels color")
+     color_choices = st.columns(2)
+     with color_choices[0]:
+         st.session_state["color1"] = st.color_picker(
+             "Select the color for the first analog",
+             value=st.session_state.get("color1", "#FF0000"),
+             help="This color will be used to display the first analog segments.",
+         )
+     with color_choices[1]:
+         st.session_state["color2"] = st.color_picker(
+             "Select the color for the second analog",
+             value=st.session_state.get("color2", "#00FF00"),
+             help="This color will be used to display the second analog segments.",
+         )
+
+ with cols[0]:
+     choice = st.segmented_control(
+         "Please select the type of images you want to upload:",
+         options=["Multichannel", "Individual channel"],
+         default="Multichannel",
+     )
+     if choice == "Individual channel":
+         build_individual_loader()
+     else:
+         build_multichannel_loader()
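
As a concrete instance of the rescaling described in the pixel-size help text above: the image is resized by the ratio of the acquisition pixel size to the 0.26 µm/px training resolution (the actual resampling lives in `dnafiber.inference` and may differ in detail).

```python
TRAINING_PIXEL_SIZE = 0.26  # µm per pixel, the resolution the model was trained on

def resize_factor(pixel_size_um: float) -> float:
    # < 1 shrinks the image (finer acquisition), > 1 enlarges it.
    return pixel_size_um / TRAINING_PIXEL_SIZE

print(resize_factor(0.13))  # 0.5 -> a 0.13 µm/px image is halved in each dimension
```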