diff --git a/README.md b/README.md
index c0e7f1250da9c5b13a344240b6e708a520576d36..c8357a2347bd0b97667b8176fa02d6aac884cf1a 100644
--- a/README.md
+++ b/README.md
@@ -1,20 +1,59 @@
----
-title: DNAI
-emoji: 🚀
-colorFrom: red
-colorTo: red
-sdk: docker
-app_port: 8501
-tags:
-- streamlit
-pinned: false
-short_description: DNA Fiber semantic segmentation for replication assessment
-license: mit
----
-
-# Welcome to Streamlit!
-
-Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
-
-If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
-forums](https://discuss.streamlit.io).
+# DN-AI
+
+This is the official repository for DN-AI, an automated tool for measuring DNA replication in fluorescence microscopy images of DNA fibers.
+
+DN-AI offers several ways for biologists to measure DNA replication in fluorescence microscopy images without requiring programming skills. See the [Installation](#installation) section for instructions on how to install DN-AI.
+
+## Features
+
+- **Automated DNA replication measurement**: DN-AI uses a deep learning model to segment fluorescence microscopy images and automatically quantify DNA replication.
+- **User-friendly interface**: DN-AI provides a user-friendly web interface for uploading images and viewing the results. Both JPEG and TIFF images are supported.
+- **Batch processing**: DN-AI can process multiple images at once, making it easy to analyze large datasets. It also supports comparing ratios between different batches of images.
+
+
+## Installation
+
+DN-AI relies on Python. We recommend a recent version (3.10 or higher) and a virtual environment to avoid conflicts with other packages.
+
+### Prerequisites
+Before installing DN-AI, make sure you have the following prerequisites installed:
+- [Python 3.10 or higher](https://www.python.org/downloads/)
+- [pip](https://pip.pypa.io/en/stable/installation/) (Python package installer)
+
+### Python Package
+To install DN-AI as a Python package, you can use pip:
+
+```bash
+pip install git+https://github.com/ClementPla/DeepFiberQ.git
+```
+
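+The package also exposes the segmentation pipeline directly from Python. The sketch below is based on the functions shipped in this repository (`_get_model`, `infer`, `refine_segmentation`, `format_results`); the file name and pixel size are placeholders, and the exact API may evolve:
+
+```python
+import cv2
+
+from dnafiber.deployment import _get_model, format_results
+from dnafiber.inference import infer
+from dnafiber.postprocess import refine_segmentation
+
+image = cv2.imread("fibers.jpg")[:, :, ::-1]  # OpenCV loads BGR; convert to RGB
+model = _get_model(revision="segformer_mit_b2", device="cuda")
+mask = infer(model, image, device="cuda", scale=0.13)  # scale: image pixel size
+fibers = refine_segmentation(mask)
+df = format_results(fibers, pixel_size=0.13)
+print(df.head())
+```
+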
+
+### Graphical User Interface (GUI)
+
+To run the DN-AI graphical user interface, you can use the following command:
+
+```bash
+DNAI
+```
+
+Make sure you run this command from the environment in which DN-AI is installed. It starts a local web server and prints output similar to the following (the exact banner comes from Streamlit and may vary):
+
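+```
+  You can now view your Streamlit app in your browser.
+
+  Local URL: http://localhost:8501
+```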
+
+Then open your web browser and go to `http://localhost:8501` to access the DN-AI interface.
+
+
+### Docker
+A Docker image is available for DN-AI. You can pull the image from Docker Hub:
+
+```bash
+docker pull clementpla/dnafiber
+```
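+
+To run the containerized interface, a typical invocation would be (assuming the image serves the Streamlit app on port 8501, as configured in this repository):
+
+```bash
+docker run -p 8501:8501 clementpla/dnafiber
+```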
+
+### Google Colab
+We also provide a Google Colab notebook for DN-AI. You can access it [here](https://colab.research.google.com/github/ClementPla/DeepFiberQ/blob/main/Colab/DNA_Fiber_Q.ipynb).
+
diff --git a/dnafiber/__init__.py b/dnafiber/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0af7c0bdc39e9496dd9ce13b58670c432dc7ed69
--- /dev/null
+++ b/dnafiber/__init__.py
@@ -0,0 +1 @@
+from dnafiber.deployment import _get_model
diff --git a/dnafiber/__pycache__/__init__.cpython-312.pyc b/dnafiber/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dbd6e37a07fcfa7a2643ec7943b810897fa211ed
Binary files /dev/null and b/dnafiber/__pycache__/__init__.cpython-312.pyc differ
diff --git a/dnafiber/__pycache__/deployment.cpython-312.pyc b/dnafiber/__pycache__/deployment.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..770cfcbe5b6ddd3b0a8e57b2c67c824592159a13
Binary files /dev/null and b/dnafiber/__pycache__/deployment.cpython-312.pyc differ
diff --git a/dnafiber/__pycache__/inference.cpython-312.pyc b/dnafiber/__pycache__/inference.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ed38d561a723dbdd96b13e5365f1c844c6ffe8db
Binary files /dev/null and b/dnafiber/__pycache__/inference.cpython-312.pyc differ
diff --git a/dnafiber/__pycache__/metric.cpython-312.pyc b/dnafiber/__pycache__/metric.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4b7d15adb0604ccf6dde5ba60d05491250742f34
Binary files /dev/null and b/dnafiber/__pycache__/metric.cpython-312.pyc differ
diff --git a/dnafiber/__pycache__/post_process.cpython-312.pyc b/dnafiber/__pycache__/post_process.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f6ca6169409036efc2298fa2307b81341de9a9d2
Binary files /dev/null and b/dnafiber/__pycache__/post_process.cpython-312.pyc differ
diff --git a/dnafiber/__pycache__/trainee.cpython-312.pyc b/dnafiber/__pycache__/trainee.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ba7d271e535bd6475e0baa6bd77b37c483730320
Binary files /dev/null and b/dnafiber/__pycache__/trainee.cpython-312.pyc differ
diff --git a/dnafiber/analysis/__init__.py b/dnafiber/analysis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/dnafiber/analysis/chart.py b/dnafiber/analysis/chart.py
new file mode 100644
index 0000000000000000000000000000000000000000..c60087e4b1bb890a7697cb5adbea545f57df380e
--- /dev/null
+++ b/dnafiber/analysis/chart.py
@@ -0,0 +1,61 @@
+import pandas as pd
+from dnafiber.analysis.const import palette
+import plotly.express as px
+
+
+def get_color_association(df):
+ """
+ Get the color association for each image in the dataframe.
+ """
+ unique_name = df["image_name"].unique()
+ color_association = {i: p for (i, p) in zip(unique_name, palette)}
+ return color_association
+
+
+def plot_ratio(df, color_association=None, only_bilateral=True):
+ df = df[["ratio", "image_name", "fiber_type"]].copy()
+
+ df["Image"] = df["image_name"]
+ df["Fiber Type"] = df["fiber_type"]
+ df["Ratio"] = df["ratio"]
+ if only_bilateral:
+ df = df[df["Fiber Type"] == "double"]
+
+ df = df.sort_values(
+ by=["Image", "Fiber Type"],
+ ascending=[True, True],
+ )
+
+    # Order the images by their median ratio
+ image_order = (
+ df.groupby("Image")["Ratio"].median().sort_values(ascending=True).index
+ )
+ df["Image"] = pd.Categorical(df["Image"], categories=image_order, ordered=True)
+ df.sort_values("Image", inplace=True)
+    unique_name = df["image_name"].unique()
+    if color_association is None:
+        color_association = get_color_association(df)
+
+    this_palette = [color_association[i] for i in unique_name]
+ fig = px.violin(
+ df,
+ y="Ratio",
+ x="Image",
+ color="Image",
+ color_discrete_sequence=this_palette,
+ box=True, # draw box plot inside the violin
+ points="all", # can be 'outliers', or False
+ )
+
+ # Make the fig taller
+
+ fig.update_layout(
+ height=500,
+ width=1000,
+ title="Ratio of green to red",
+ yaxis_title="Ratio",
+ xaxis_title="Image",
+ legend_title="Image",
+ )
+ return fig
diff --git a/dnafiber/analysis/const.py b/dnafiber/analysis/const.py
new file mode 100644
index 0000000000000000000000000000000000000000..21121587325a77401fc3cd26ac9164b0a92a0ae3
--- /dev/null
+++ b/dnafiber/analysis/const.py
@@ -0,0 +1,3 @@
+from catppuccin.palette import PALETTE
+
+palette = [c.hex for c in PALETTE.latte.colors]
diff --git a/dnafiber/analysis/utils.py b/dnafiber/analysis/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a158f17ee2eedb2f461d14042ec2e15a1cbd6a5
--- /dev/null
+++ b/dnafiber/analysis/utils.py
@@ -0,0 +1,21 @@
+from tqdm.auto import tqdm
+from dnafiber.data.utils import read_colormask
+import numpy as np
+
+
+def build_consensus_map(intergraders, root_img, list_img):
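+    """
+    Stack, for every image under `root_img`, the color masks produced by each
+    grader. Each element of `intergraders` is a directory mirroring
+    `root_img`'s layout and holding one PNG color mask per image.
+    """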
+ all_masks = []
+ for img_path in tqdm(list_img):
+ path_from_root = img_path.relative_to(root_img)
+ masks = []
+ for intergrader in intergraders:
+ intergrader_path = (intergrader / path_from_root).with_suffix(".png")
+ if not intergrader_path.exists():
+ print(f"Missing {intergrader_path}")
+ continue
+ mask = read_colormask(intergrader_path)
+ masks.append(mask)
+ masks = np.array(masks)
+
+ all_masks.append(masks)
+ return np.array(all_masks)
diff --git a/dnafiber/callbacks.py b/dnafiber/callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..d58d978a24c50b07c45e729faccfef8d45e709b3
--- /dev/null
+++ b/dnafiber/callbacks.py
@@ -0,0 +1,50 @@
+from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.utilities import rank_zero_only
+import wandb
+
+
+class LogPredictionSamplesCallback(Callback):
+ def __init__(self, wandb_logger, n_images=8):
+ self.n_images = n_images
+ self.wandb_logger = wandb_logger
+ super().__init__()
+
+ @rank_zero_only
+ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+ if batch_idx < 1 and trainer.is_global_zero:
+ n = self.n_images
+ x = batch["image"][:n].float()
+ h, w = x.shape[-2:]
+ y = batch["mask"][:n]
+ pred = outputs[:n]
+ pred = pred.argmax(dim=1)
+
+ if len(y.shape) == 4:
+ y = y.squeeze(1)
+ if len(pred.shape) == 4:
+ pred = pred.squeeze(1)
+ y = y.clamp(0, 2)
+ columns = ["image"]
+ class_labels = {0: "Background", 1: "Red", 2: "Green"}
+
+ data = [
+ [
+ wandb.Image(
+ x_i,
+ masks={
+ "Prediction": {
+ "mask_data": p_i.cpu().numpy(),
+ "class_labels": class_labels,
+ },
+ "Groundtruth": {
+ "mask_data": y_i.cpu().numpy(),
+ "class_labels": class_labels,
+ },
+ },
+ )
+ ]
+ for x_i, y_i, p_i in list(zip(x, y, pred))
+ ]
+ self.wandb_logger.log_table(
+ data=data, key=f"Validation Batch {batch_idx}", columns=columns
+ )
diff --git a/dnafiber/data/__init__.py b/dnafiber/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/dnafiber/data/__pycache__/__init__.cpython-312.pyc b/dnafiber/data/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5c1bc70ad0f1c2d99af47542e2f3477e3c601268
Binary files /dev/null and b/dnafiber/data/__pycache__/__init__.cpython-312.pyc differ
diff --git a/dnafiber/data/__pycache__/utils.cpython-312.pyc b/dnafiber/data/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0987d53b295fe7c72074cae27c9651ea3dcb170d
Binary files /dev/null and b/dnafiber/data/__pycache__/utils.cpython-312.pyc differ
diff --git a/dnafiber/data/dataset.py b/dnafiber/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..550497f0489d51da06bfa610af692bf86896f4b9
--- /dev/null
+++ b/dnafiber/data/dataset.py
@@ -0,0 +1,271 @@
+import albumentations as A
+import nntools.dataset as D
+import numpy as np
+from albumentations.pytorch import ToTensorV2
+from lightning import LightningDataModule
+from sklearn.model_selection import train_test_split
+from torch.utils.data import DataLoader
+from skimage.measure import label, regionprops
+from skimage.morphology import skeletonize, dilation
+from skimage.segmentation import expand_labels
+import torch
+from nntools.dataset.composer import CacheBullet
+
+
+@D.nntools_wrapper
+def convert_mask(mask):
+ output = np.zeros(mask.shape[:2], dtype=np.uint8)
+ output[mask[:, :, 0] > 200] = 1
+ output[mask[:, :, 1] > 200] = 2
+ binary_mask = output > 0
+ skeleton = skeletonize(binary_mask) * output
+ output = expand_labels(skeleton, 3)
+ output = np.clip(output, 0, 2)
+ return {"mask": output}
+
+
+@D.nntools_wrapper
+def extract_bbox(mask):
+ binary_mask = mask > 0
+ labelled = label(binary_mask)
+ props = regionprops(labelled, intensity_image=mask)
+ skeleton = skeletonize(binary_mask) * mask
+ mask = dilation(skeleton, np.ones((3, 3)))
+ bboxes = []
+ masks = []
+ # We want the XYXY format
+ for prop in props:
+ minr, minc, maxr, maxc = prop.bbox
+ bboxes.append([minc, minr, maxc, maxr])
+ masks.append((labelled == prop.label).astype(np.uint8))
+ if not masks:
+ masks = np.zeros_like(mask)[np.newaxis, :, :]
+ masks = np.array(masks)
+ masks = np.moveaxis(masks, 0, -1)
+
+ return {
+ "bboxes": np.array(bboxes),
+ "mask": masks,
+ "fiber_ids": np.array([p.label for p in props]),
+ }
+
+
+class FiberDatamodule(LightningDataModule):
+ def __init__(
+ self,
+ root_img,
+ crop_size=(256, 256),
+ shape=1024,
+ batch_size=32,
+ num_workers=8,
+ use_bbox=False,
+ **kwargs,
+ ):
+ self.shape = shape
+ self.root_img = str(root_img)
+ self.crop_size = crop_size
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.kwargs = kwargs
+ self.use_bbox = use_bbox
+
+ super().__init__()
+
+ def setup(self, *args, **kwargs):
+ def _get_dataset(version):
+ dataset = D.MultiImageDataset(
+ {
+ "image": f"{self.root_img}/{version}/images/",
+ "mask": f"{self.root_img}/{version}/annotations/",
+ },
+ shape=(self.shape, self.shape),
+ use_cache=self.kwargs.get("use_cache", False),
+ cache_option=self.kwargs.get("cache_option", None),
+ ) # type: ignore
+ dataset.img_filepath["image"] = np.asarray( # type: ignore
+ sorted(
+ list(dataset.img_filepath["image"]),
+ key=lambda x: (x.parent.stem, x.stem),
+ )
+ )
+ dataset.img_filepath["mask"] = np.asarray( # type: ignore
+ sorted(
+ list(dataset.img_filepath["mask"]),
+ key=lambda x: (x.parent.stem, x.stem),
+ )
+ )
+ dataset.composer = D.Composition()
+ dataset.composer << convert_mask # type: ignore
+ if self.use_bbox:
+ dataset.composer << extract_bbox
+
+ return dataset
+
+ self.train = _get_dataset("train")
+ self.val = _get_dataset("train")
+ self.test = _get_dataset("test")
+ self.train.composer << CacheBullet()
+ self.val.use_cache = False
+ self.test.use_cache = False
+
+ stratify = []
+ for f in self.train.img_filepath["image"]:
+ if "tile" in f.stem:
+ stratify.append(int(f.parent.stem))
+ else:
+ stratify.append(25)
+ train_idx, val_idx = train_test_split(
+ np.arange(len(self.train)), # type: ignore
+ stratify=stratify,
+ test_size=0.2,
+ random_state=42,
+ )
+ self.train.subset(train_idx)
+ self.val.subset(val_idx)
+
+ self.train.composer.add(*self.get_train_composer())
+ self.val.composer.add(*self.cast_operators())
+ self.test.composer.add(*self.cast_operators())
+
+ def get_train_composer(self):
+ transforms = []
+ if self.crop_size is not None:
+ transforms.append(
+ A.CropNonEmptyMaskIfExists(
+ width=self.crop_size[0], height=self.crop_size[1]
+ ),
+ )
+ return [
+ A.Compose(
+ transforms
+ + [
+ A.HorizontalFlip(),
+ A.VerticalFlip(),
+ A.Affine(),
+ A.ElasticTransform(),
+ A.RandomRotate90(),
+ A.OneOf(
+ [
+ A.RandomBrightnessContrast(
+ brightness_limit=(-0.2, 0.1),
+ contrast_limit=(-0.2, 0.1),
+ p=0.5,
+ ),
+ A.HueSaturationValue(
+ hue_shift_limit=(-5, 5),
+ sat_shift_limit=(-20, 20),
+ val_shift_limit=(-20, 20),
+ p=0.5,
+ ),
+ ]
+ ),
+ A.GaussNoise(std_range=(0.0, 0.1), p=0.5),
+ ],
+ bbox_params=A.BboxParams(
+ format="pascal_voc", label_fields=["fiber_ids"], min_visibility=0.95
+ )
+ if self.use_bbox
+ else None,
+ ),
+ *self.cast_operators(),
+ ]
+
+ def cast_operators(self):
+ return [
+ A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
+ if not self.use_bbox
+ else A.Normalize(
+ mean=(
+ 0.0,
+ 0.0,
+ 0.0,
+ ),
+ std=(1.0, 1.0, 1.0),
+ max_pixel_value=255,
+ ),
+ ToTensorV2(),
+ ]
+
+ def train_dataloader(self):
+ if self.use_bbox:
+ return DataLoader(
+ self.train,
+ batch_size=self.batch_size,
+ shuffle=True,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ persistent_workers=True,
+ collate_fn=bbox_collate_fn,
+ )
+
+ else:
+ return DataLoader(
+ self.train,
+ batch_size=self.batch_size,
+ shuffle=True,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ persistent_workers=True,
+ )
+
+ def val_dataloader(self):
+ if self.use_bbox:
+ return DataLoader(
+ self.val,
+ batch_size=self.batch_size,
+ shuffle=False,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ persistent_workers=True,
+ collate_fn=bbox_collate_fn,
+ )
+ return DataLoader(
+ self.val,
+ batch_size=self.batch_size,
+ shuffle=False,
+ num_workers=self.num_workers,
+ )
+
+ def test_dataloader(self):
+ if self.use_bbox:
+ return DataLoader(
+ self.test,
+ batch_size=self.batch_size,
+ shuffle=False,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ persistent_workers=True,
+ collate_fn=bbox_collate_fn,
+ )
+ return DataLoader(
+ self.test,
+ batch_size=self.batch_size,
+ shuffle=False,
+ num_workers=self.num_workers,
+ )
+
+
+def bbox_collate_fn(batch):
+    images = []
+    targets = []
+
+    for b in batch:
+        target = dict()
+        target["boxes"] = torch.from_numpy(b["bboxes"])
+        if target["boxes"].shape[0] == 0:
+            target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
+            target["labels"] = torch.zeros(1, dtype=torch.int64)
+        else:
+            target["labels"] = torch.ones_like(target["boxes"][:, 0], dtype=torch.int64)
+        target["masks"] = b["mask"].permute(2, 0, 1)
+
+        images.append(b["image"])
+        targets.append(target)
+
+    return {
+        "image": torch.stack(images),
+        "targets": targets,
+    }
diff --git a/dnafiber/data/intergrader/__init__.py b/dnafiber/data/intergrader/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..983f876cace65cecee3b1587f6dd49e086d83a48
--- /dev/null
+++ b/dnafiber/data/intergrader/__init__.py
@@ -0,0 +1 @@
+from .const import *
\ No newline at end of file
diff --git a/dnafiber/data/intergrader/__pycache__/__init__.cpython-312.pyc b/dnafiber/data/intergrader/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7d74f18a56e8ec761c0df6edbdf0181283aa721e
Binary files /dev/null and b/dnafiber/data/intergrader/__pycache__/__init__.cpython-312.pyc differ
diff --git a/dnafiber/data/intergrader/__pycache__/analysis.cpython-312.pyc b/dnafiber/data/intergrader/__pycache__/analysis.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b4e41ac0c51b247d8b9fcbcb5c205c743d70bb3d
Binary files /dev/null and b/dnafiber/data/intergrader/__pycache__/analysis.cpython-312.pyc differ
diff --git a/dnafiber/data/intergrader/__pycache__/const.cpython-312.pyc b/dnafiber/data/intergrader/__pycache__/const.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..53bed8195ecccacbd65f01f998d2e8d2ee5b3ae3
Binary files /dev/null and b/dnafiber/data/intergrader/__pycache__/const.cpython-312.pyc differ
diff --git a/dnafiber/data/intergrader/__pycache__/io.cpython-312.pyc b/dnafiber/data/intergrader/__pycache__/io.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bba9c87fd011f943fa8c46bff5e35e2410dafed2
Binary files /dev/null and b/dnafiber/data/intergrader/__pycache__/io.cpython-312.pyc differ
diff --git a/dnafiber/data/intergrader/__pycache__/plot.cpython-312.pyc b/dnafiber/data/intergrader/__pycache__/plot.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e48b3022a399c90bd5334dcbe9b854f75719080
Binary files /dev/null and b/dnafiber/data/intergrader/__pycache__/plot.cpython-312.pyc differ
diff --git a/dnafiber/data/intergrader/analysis.py b/dnafiber/data/intergrader/analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9b7fa9e3db75332c82f5998ee3dcc0bb417874c
--- /dev/null
+++ b/dnafiber/data/intergrader/analysis.py
@@ -0,0 +1,120 @@
+from skimage.morphology import skeletonize
+import numpy as np
+from skimage.measure import label
+from tqdm.contrib.concurrent import thread_map
+
+
+def extract_fiber_properties(mask):
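+    """
+    Compute per-fiber red and green pixel counts (and their ratio) from a
+    labelmap where 0 = background, 1 = red and 2 = green.
+    """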
+ binary_mask = mask > 0
+ skeleton = skeletonize(binary_mask)
+ r = mask == 1
+ g = mask == 2
+ labeled_skeleton = label(skeleton, connectivity=2)
+ properties = {"R": [], "G": [], "ratio": []}
+ for i in range(1, labeled_skeleton.max() + 1):
+ fiber_mask = labeled_skeleton == i
+ sum_r = np.sum(r & fiber_mask)
+ sum_g = np.sum(g & fiber_mask)
+ if sum_r == 0 or sum_g == 0:
+ continue
+ properties["R"].append(np.sum(r & fiber_mask))
+ properties["G"].append(np.sum(g & fiber_mask))
+
+ properties["R"] = np.array(properties["R"])
+ properties["G"] = np.array(properties["G"])
+ properties["ratio"] = properties["R"] / (properties["G"])
+ properties["label"] = labeled_skeleton
+ return properties
+
+
+def filter_non_commons_fibers(properties):
+    # `properties` is a list of dicts. Each dict holds a labelmap and arrays of
+    # red counts, green counts and ratios. We keep only the fibers that are
+    # present in every annotation, i.e. whose pixels overlap the common mask.
+    binary_labels = [p["label"] > 0 for p in properties]
+    common_mask = np.logical_and.reduce(binary_labels)
+    filtered_properties = []
+    for p in properties:
+        # Ids of the fibers that intersect the common mask
+        good_labels = common_mask * p["label"]
+        common_ids = np.unique(good_labels[good_labels > 0])
+        # Boolean selector over this grader's fibers (assumes the R/G/ratio
+        # arrays are aligned with label ids 1..label.max())
+        keep = np.isin(np.arange(1, p["label"].max() + 1), common_ids)
+        filtered_properties.append(
+            {
+                "R": p["R"][keep],
+                "G": p["G"][keep],
+                "ratio": p["ratio"][keep],
+                "label": p["label"] * common_mask,
+            }
+        )
+    return filtered_properties
+
+
+def skeletonize_mask(mask):
+ # Skeletonize the mask and return the skeleton
+ binary_mask = mask > 0
+ skeleton = skeletonize(binary_mask) * mask
+ return skeleton
+
+
+def skeletonize_data_dict(data_dict):
+ skeletons = dict()
+ for annotator, images in data_dict.items():
+ skeletons[annotator] = dict()
+ for image_type, masks in images.items():
+ skeletons[annotator][image_type] = thread_map(skeletonize_mask, masks, max_workers=8)
+
+ return skeletons
+
+
+def extract_properties_from_datadict(data_dict, with_common_analysis=True):
+ """
+ Extract the properties of the fibers from the data dictionary.
+ The data dictionary is a dict of annotators. Each value is a dict of images. Each image is a list of masks.
+ """
+ properties = dict(annotator=[], image_type=[], red=[], green=[], ratio=[], fiber_type=[])
+ all_annotators = list(data_dict.keys())
+
+ found_by = {a: [] for a in all_annotators}
+ properties.update(found_by)
+ for annotator, images in data_dict.items():
+ for image_type, masks in images.items():
+ for i, mask in enumerate(masks):
+ if with_common_analysis:
+ others_masks = []
+ other_annotators = []
+ for other in all_annotators:
+ if other == annotator:
+ continue
+ other_annotators.append(other)
+ others_masks.append(data_dict[other][image_type][i] > 0)
+
+                labels, num = label(mask > 0, connectivity=2, return_num=True)
+                for l in range(1, num + 1):
+                    fiber = labels == l
+                    if np.sum(fiber) < 10:
+                        continue
+
+                    red_length = np.sum(fiber & (mask == 1))
+                    green_length = np.sum(fiber & (mask == 2))
+                    # Skip unilateral fibers before appending anything, so that
+                    # every list in `properties` stays the same length
+                    if red_length == 0 or green_length == 0:
+                        continue
+
+                    properties["annotator"].append(annotator)
+                    properties["image_type"].append(image_type)
+                    properties[annotator].append(True)
+                    if with_common_analysis:
+                        # A fiber counts as "found" by another grader if the masks overlap
+                        for other_mask, other_annotator in zip(others_masks, other_annotators):
+                            properties[other_annotator].append(np.any(fiber & other_mask))
+
+                    properties["ratio"].append(green_length / (red_length + 1e-7))  # avoid division by zero
+                    properties["red"].append(red_length)
+                    properties["green"].append(green_length)
+
+                    # Count the color segments along the fiber; keep the 2D structure
+                    # (8-connectivity) so diagonal skeleton steps stay connected
+                    segments, count = label(mask * fiber, connectivity=2, return_num=True)
+                    if count == 1:
+                        properties["fiber_type"].append("single")
+                    elif count == 2:
+                        properties["fiber_type"].append("double")
+                    else:
+                        properties["fiber_type"].append("multiple")
+
+ return properties
\ No newline at end of file
diff --git a/dnafiber/data/intergrader/auto.py b/dnafiber/data/intergrader/auto.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed06ec6053d499748f83324a4df42547c112fd66
--- /dev/null
+++ b/dnafiber/data/intergrader/auto.py
@@ -0,0 +1,3 @@
+def inference_model(model, path, use_cuda=False):
+ pass
+
\ No newline at end of file
diff --git a/dnafiber/data/intergrader/const.py b/dnafiber/data/intergrader/const.py
new file mode 100644
index 0000000000000000000000000000000000000000..d51eaa079e3d4fb8782b3e1c149aaf78b900a5e0
--- /dev/null
+++ b/dnafiber/data/intergrader/const.py
@@ -0,0 +1,21 @@
+BLIND_MAPPING = {
+ "siB+M-01": "0",
+ "siB+M-04": "1",
+ "siBRCA2-02": "5",
+ "siBRCA2-03": "15",
+ "siTONSL-03": "11",
+ "siTONSL-04": "14",
+ "HLTF ko+si MMS22L-01": "8",
+ "HLTF ko+si MMS22L-02": "13",
+ "siBRCA2+SMARCAL KO-01": "2",
+ "siBRCA2+SMARCAL KO-03": "9",
+ "siBRCA2+SMARCAL KO-04": "16",
+ "siBRCA2-01": "4",
+ "59_siBRCA2-02": "7",
+ "siNT-01": "10",
+ "siNT-02": "12",
+ "siMMS22L_+dox-01": "3",
+ "siMMS22L_+dox-02": "6",
+}
+
+REVERSE_BLIND_MAPPING = {v: k for k, v in BLIND_MAPPING.items()}
\ No newline at end of file
diff --git a/dnafiber/data/intergrader/io.py b/dnafiber/data/intergrader/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdb3aa39f9e8c4f0d59e4e4130a25c80e6a080b6
--- /dev/null
+++ b/dnafiber/data/intergrader/io.py
@@ -0,0 +1,27 @@
+import cv2
+import numpy as np
+from skimage.segmentation import expand_labels
+
+def read_to_mask(f):
+    # OpenCV reads BGR; flip to RGB before thresholding the channels
+    img = cv2.imread(str(f), cv2.IMREAD_UNCHANGED)[:, :, ::-1]
+ mask = np.zeros(img.shape[:2], dtype=np.uint8)
+ mask[img[:, :, 0] > 200] = 1
+ mask[img[:, :, 1] > 200] = 2
+
+ return mask
+
+
+def read_mask_from_path_gens(dict_gens, mapping=None):
+ output = {k: dict() for k in dict_gens.keys()}
+ for k, files in dict_gens.items():
+ for file in files:
+ name = file.parent.stem
+ if mapping is not None:
+ name = mapping.get(name, name)
+ mask = read_to_mask(file)
+ mask = expand_labels(mask, 1)
+ if output[k].get(name) is None:
+ output[k][name] = []
+ output[k][name].append(mask)
+ return output
+
diff --git a/dnafiber/data/intergrader/plot.py b/dnafiber/data/intergrader/plot.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8d185fcddee4f037dc873d03e093fdb5d1278ea
--- /dev/null
+++ b/dnafiber/data/intergrader/plot.py
@@ -0,0 +1,172 @@
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.colors import ListedColormap
+from skimage.measure import label, regionprops
+import base64
+from typing import Callable
+
+def imshow_compare(data_dict, ax_size=4, draw_bbox=False, max_images=None):
+ """
+ Display the images in a grid format for comparison.
+ Each key is an annotator, each value is another dict, where the key is the image type and the value the list of corresponding images.
+ """
+ # 0 is black, 1 is red, 2 is green
+ cmap = ListedColormap(['black', 'red', 'green'])
+
+ # Convert the data dictionary to a dict of annotators: list of images
+ data = dict()
+ for annotator, images in data_dict.items():
+ if annotator not in data:
+ data[annotator] = []
+ for image_type, masks in images.items():
+ for mask in masks:
+ data[annotator].append(mask)
+
+ annotators = list(data.keys())
+ num_images = len(data[annotators[0]])
+ if max_images is not None and num_images > max_images:
+ num_images = max_images
+ num_annotators = len(annotators)
+
+ fig_size = (ax_size * num_annotators, ax_size * num_images)
+ fig, axes = plt.subplots(num_images, num_annotators, figsize=fig_size, squeeze=False)
+
+ for i, annotator in enumerate(annotators):
+ for j in range(num_images):
+ if max_images is not None and j > max_images:
+ break
+ ax = axes[j, i]
+ mask = data[annotator][j]
+ ax.imshow(mask, cmap=cmap, interpolation='nearest')
+ ax.axis('off')
+ ax.set_xticks([])
+ ax.set_yticks([])
+ if draw_bbox:
+ mask = mask > 0
+ labeled_mask = label(mask, connectivity=2)
+ regions = regionprops(labeled_mask)
+ for region in regions:
+ minr, minc, maxr, maxc = region.bbox
+ rect = plt.Rectangle((minc, minr), maxc - minc, maxr - minr,
+ fill=False, edgecolor='yellow', linewidth=0.5)
+ ax.add_patch(rect)
+
+            if j == 0:
+                ax.set_title(annotator)
+
+ fig.tight_layout()
+ return fig, axes
+
+
+def add_p_value_annotation(fig, array_columns, stats_test, subplot=None, _format=dict(interline=0.07, text_height=1.07, color='black')):
+    ''' Adds annotations giving the p-value between two box plot data
+
+    Parameters:
+    ----------
+    fig: figure
+        plotly boxplot figure
+    array_columns: np.array
+        array of which columns to compare
+        e.g.: [[0,1], [1,2]] compares column 0 with 1 and 1 with 2
+    stats_test: Callable or sequence
+        a two-sample test returning (statistic, p-value), or a precomputed
+        sequence of p-values (one per column pair)
+    subplot: None or int
+        specifies if the figure has subplots and what subplot to add the notation to
+    _format: dict
+        format characteristics for the lines
+
+    Returns:
+    -------
+    fig: figure
+        figure with the added notation
+    '''
+ # Specify in what y_range to plot for each pair of columns
+ y_range = np.zeros([len(array_columns), 2])
+ for i in range(len(array_columns)):
+ y_range[i] = [1.01+i*_format['interline'], 1.02+i*_format['interline']]
+
+ # Get values from figure
+ fig_dict = fig.to_dict()
+ # Get indices if working with subplots
+    if subplot:
+        if subplot == 1:
+            subplot_str = ''
+        else:
+            subplot_str = str(subplot)
+        indices = []  # Map the box index to the indices of the data for that subplot
+        for index, data in enumerate(fig_dict['data']):
+            if data['xaxis'] == 'x' + subplot_str:
+                indices = np.append(indices, index)
+        indices = [int(i) for i in indices]
+    else:
+        subplot_str = ''
+
+ # Print the p-values
+ for index, column_pair in enumerate(array_columns):
+ if subplot:
+ data_pair = [indices[column_pair[0]], indices[column_pair[1]]]
+ else:
+ data_pair = column_pair
+
+        # Make sure it is selecting the data and subplot you want
+
+ if isinstance(stats_test, Callable):
+ # Get the p-value
+            d1 = fig_dict['data'][data_pair[0]]['y']
+            d2 = fig_dict['data'][data_pair[1]]['y']
+            # Plotly serializes numeric arrays as base64-encoded binary ('bdata')
+            d1 = np.frombuffer(base64.b64decode(d1['bdata']), dtype=np.float64)
+            d2 = np.frombuffer(base64.b64decode(d2['bdata']), dtype=np.float64)
+ pvalue = stats_test(
+ d1,
+ d2,
+ )[1]
+ else:
+ pvalue = stats_test[index]
+ if pvalue >= 0.05:
+ symbol = 'ns'
+ elif pvalue >= 0.01:
+ symbol = '*'
+ elif pvalue >= 0.001:
+ symbol = '**'
+ else:
+ symbol = '***'
+ # Vertical line
+ fig.add_shape(type="line",
+ xref="x"+subplot_str, yref="y"+subplot_str+" domain",
+ x0=column_pair[0], y0=y_range[index][0],
+ x1=column_pair[0], y1=y_range[index][1],
+ line=dict(color=_format['color'], width=2,)
+ )
+ # Horizontal line
+ fig.add_shape(type="line",
+ xref="x"+subplot_str, yref="y"+subplot_str+" domain",
+ x0=column_pair[0], y0=y_range[index][1],
+ x1=column_pair[1], y1=y_range[index][1],
+ line=dict(color=_format['color'], width=2,)
+ )
+ # Vertical line
+ fig.add_shape(type="line",
+ xref="x"+subplot_str, yref="y"+subplot_str+" domain",
+ x0=column_pair[1], y0=y_range[index][0],
+ x1=column_pair[1], y1=y_range[index][1],
+ line=dict(color=_format['color'], width=2,)
+ )
+ ## add text at the correct x, y coordinates
+ ## for bars, there is a direct mapping from the bar number to 0, 1, 2...
+ fig.add_annotation(dict(font=dict(color=_format['color'],size=14),
+ x=(column_pair[0] + column_pair[1])/2,
+ y=y_range[index][1]*_format['text_height'],
+ showarrow=False,
+ text=symbol,
+ textangle=0,
+ xref="x"+subplot_str,
+ yref="y"+subplot_str+" domain"
+ ))
+ return fig
\ No newline at end of file
diff --git a/dnafiber/data/utils.py b/dnafiber/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..48b48cb11d7124bcc7890d1a24d99537d1d56696
--- /dev/null
+++ b/dnafiber/data/utils.py
@@ -0,0 +1,80 @@
+import base64
+
+from xml.dom import minidom
+import cv2
+import numpy as np
+from czifile import CziFile
+from tifffile import imread
+
+
+def read_svg(svg_path):
+ doc = minidom.parse(str(svg_path))
+ img_strings = {
+ path.getAttribute("id"): path.getAttribute("href")
+ for path in doc.getElementsByTagName("image")
+ }
+ doc.unlink()
+
+ red = img_strings["Red"]
+ green = img_strings["Green"]
+ red = base64.b64decode(red.split(",")[1])
+ green = base64.b64decode(green.split(",")[1])
+ red = cv2.imdecode(np.frombuffer(red, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
+ green = cv2.imdecode(np.frombuffer(green, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
+
+ red = cv2.cvtColor(red, cv2.COLOR_BGRA2GRAY)
+ green = cv2.cvtColor(green, cv2.COLOR_BGRA2GRAY)
+ mask = np.zeros_like(red)
+ mask[red > 0] = 1
+ mask[green > 0] = 2
+ return mask
+
+
+def extract_bboxes(mask):
+ mask = np.array(mask)
+ mask = mask.astype(np.uint8)
+
+ # Find connected components
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
+ mask, connectivity=8
+ )
+ bboxes = []
+ for i in range(1, num_labels):
+ x, y, w, h, area = stats[i]
+ bboxes.append([x, y, x + w, y + h])
+ return bboxes
+
+
+def preprocess(raw_data, reverse_channels=False):
+ MAX_VALUE = 2**16 - 1
+ if raw_data.ndim == 2:
+ raw_data = raw_data[np.newaxis, :, :]
+ h, w = raw_data.shape[1:3]
+ orders = np.arange(raw_data.shape[0])[::-1] # Reverse channel order
+ result = np.zeros((h, w, 3), dtype=np.uint8)
+
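+    # Per channel: clip at the 99th intensity percentile (estimated from the
+    # CDF of the 16-bit histogram), then rescale to the 8-bit range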
+ for i, chan in enumerate(raw_data):
+ hist, bins = np.histogram(chan.ravel(), MAX_VALUE + 1, (0, MAX_VALUE + 1))
+ cdf = hist.cumsum()
+ cdf_normalized = cdf / cdf[-1]
+ bmax = np.searchsorted(cdf_normalized, 0.99, side="left")
+ clip = np.clip(chan, 0, bmax).astype(np.float32)
+ clip = (clip - clip.min()) / (bmax - clip.min()) * 255
+ result[:, :, orders[i]] = clip
+ if reverse_channels:
+ # Reverse channels 0 and 1
+ result = result[:, :, [1, 0, 2]]
+ return result
+
+
+def read_czi(filepath):
+ data = CziFile(filepath)
+
+ return data.asarray().squeeze()
+
+
+def read_tiff(filepath):
+    data = imread(filepath).squeeze()
+    return data
\ No newline at end of file
diff --git a/dnafiber/deployment.py b/dnafiber/deployment.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e8c2f89ea09cd9e28757e88a2c56c1daa07e404
--- /dev/null
+++ b/dnafiber/deployment.py
@@ -0,0 +1,44 @@
+from dnafiber.trainee import Trainee
+from dnafiber.postprocess.fiber import FiberProps
+import pandas as pd
+
+def _get_model(revision, device="cuda"):
+ if revision is None:
+ model = Trainee.from_pretrained(
+ "ClementP/DeepFiberQ", arch="unet", encoder_name="mit_b0"
+ )
+ else:
+ model = Trainee.from_pretrained(
+ "ClementP/DeepFiberQ",
+ revision=revision,
+ )
+ return model.eval().to(device)
+
+
+def format_results(results: list[FiberProps], pixel_size: float) -> pd.DataFrame:
+ """
+ Format the results for display in the UI.
+ """
+ results = [fiber for fiber in results if fiber.is_valid]
+ all_results = dict(
+ FirstAnalog=[], SecondAnalog=[], length=[], ratio=[], fiber_type=[]
+ )
+ all_results["FirstAnalog"].extend([fiber.red * pixel_size for fiber in results])
+ all_results["SecondAnalog"].extend([fiber.green * pixel_size for fiber in results])
+ all_results["length"].extend(
+ [fiber.red * pixel_size + fiber.green * pixel_size for fiber in results]
+ )
+ all_results["ratio"].extend([fiber.ratio for fiber in results])
+ all_results["fiber_type"].extend([fiber.fiber_type for fiber in results])
+
+ return pd.DataFrame.from_dict(all_results)
+
+
+MODELS_ZOO = {
+ "Ensemble": "ensemble",
+ "SegFormer MiT-B4": "segformer_mit_b4",
+ "SegFormer MiT-B2": "segformer_mit_b2",
+ "U-Net SE-ResNet50": "unet_se_resnet50",
+}
\ No newline at end of file
diff --git a/dnafiber/inference.py b/dnafiber/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8bcbff46cc8dcc1d111158546fbdbfa0a254763
--- /dev/null
+++ b/dnafiber/inference.py
@@ -0,0 +1,105 @@
+import torch.nn.functional as F
+import numpy as np
+import torch
+from torchvision.transforms._functional_tensor import normalize
+import pandas as pd
+from skimage.segmentation import expand_labels
+from skimage.measure import label
+import albumentations as A
+from albumentations.pytorch import ToTensorV2
+from monai.inferers import SlidingWindowInferer
+from dnafiber.deployment import _get_model
+from dnafiber.postprocess import refine_segmentation
+
+transform = A.Compose(
+ [
+ A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+        ToTensorV2(),
+ ]
+)
+
+
+def preprocess_image(image):
+ image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
+ image = normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ return image
+
+
+def convert_to_dataset(counts):
+ data = {"index": [], "red": [], "green": [], "ratio": []}
+ for k, v in counts.items():
+ data["index"].append(k)
+ data["green"].append(v["green"])
+ data["red"].append(v["red"])
+ if v["red"] == 0:
+ data["ratio"].append(np.nan)
+ else:
+ data["ratio"].append(v["green"] / (v["red"]))
+ df = pd.DataFrame(data)
+ return df
+
+
+def convert_mask_to_image(mask, expand=False):
+ if expand:
+ mask = expand_labels(mask, distance=expand)
+ h, w = mask.shape
+ image = np.zeros((h, w, 3), dtype=np.uint8)
+ GREEN = np.array([0, 255, 0])
+ RED = np.array([255, 0, 0])
+
+ image[mask == 1] = RED
+ image[mask == 2] = GREEN
+
+ return image
+
+
+@torch.inference_mode()
+def infer(model, image, device, scale=0.13, to_numpy=True, only_probabilities=False):
+ if isinstance(model, str):
+ model = _get_model(device=device, revision=model)
+ model_pixel_size = 0.26
+
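+    # `scale` is the input image's pixel size; rescale so the image matches the
+    # resolution the model expects (0.26 presumably reflects the training data)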
+ scale = scale / model_pixel_size
+ tensor = transform(image=image)["image"].unsqueeze(0).to(device)
+ h, w = tensor.shape[2], tensor.shape[3]
+ device = torch.device(device)
+ with torch.autocast(device_type=device.type):
+ tensor = F.interpolate(
+ tensor,
+ size=(int(h * scale), int(w * scale)),
+ mode="bilinear",
+ )
+ if tensor.shape[2] > 1024 or tensor.shape[3] > 1024:
+ inferer = SlidingWindowInferer(
+ roi_size=(1024, 1024),
+ sw_batch_size=4,
+ overlap=0.25,
+ mode="gaussian",
+ device=device,
+ progress=True,
+ )
+ output = inferer(tensor, model)
+ else:
+ output = model(tensor)
+
+ probabilities = F.softmax(output, dim=1)
+ if only_probabilities:
+ probabilities = probabilities.cpu()
+
+ probabilities = F.interpolate(
+ probabilities,
+ size=(h, w),
+ mode="bilinear",
+ )
+ return probabilities
+
+ output = F.interpolate(
+ probabilities.argmax(dim=1, keepdim=True).float(),
+ size=(h, w),
+ mode="nearest",
+ )
+
+ output = output.squeeze().byte()
+ if to_numpy:
+ output = output.cpu().numpy()
+
+ return output
diff --git a/dnafiber/metric.py b/dnafiber/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..c94bd2af55b8672431c878aa3a44f17c3a2f3b5e
--- /dev/null
+++ b/dnafiber/metric.py
@@ -0,0 +1,150 @@
+import torch
+import torchmetrics.functional as F
+from skimage.measure import label
+from torchmetrics import Metric
+
+
+class DNAFIBERMetric(Metric):
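+    """
+    Detection and per-fiber segmentation metrics. A predicted fiber counts as
+    detected (true positive) if it overlaps any ground-truth fiber; Dice,
+    recall and precision are then accumulated per matched fiber for the red
+    and green classes.
+    """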
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ self.add_state(
+ "detection_tp",
+ default=torch.tensor(0, dtype=torch.int64),
+ dist_reduce_fx="sum",
+ )
+ self.add_state(
+ "fiber_red_dice",
+ default=torch.tensor(0, dtype=torch.float32),
+ dist_reduce_fx="sum",
+ )
+ self.add_state(
+ "fiber_green_dice",
+ default=torch.tensor(0, dtype=torch.float32),
+ dist_reduce_fx="sum",
+ )
+ self.add_state(
+ "fiber_red_recall",
+ default=torch.tensor(0, dtype=torch.float32),
+ dist_reduce_fx="sum",
+ )
+ self.add_state(
+ "fiber_green_recall",
+ default=torch.tensor(0, dtype=torch.float32),
+ dist_reduce_fx="sum",
+ )
+        # Precision
+ self.add_state(
+ "fiber_red_precision",
+ default=torch.tensor(0, dtype=torch.float32),
+ dist_reduce_fx="sum",
+ )
+ self.add_state(
+ "fiber_green_precision",
+ default=torch.tensor(0, dtype=torch.float32),
+ dist_reduce_fx="sum",
+ )
+
+ self.add_state(
+ "detection_fp",
+ default=torch.tensor(0, dtype=torch.int64),
+ dist_reduce_fx="sum",
+ )
+ self.add_state(
+ "N",
+ default=torch.tensor(0, dtype=torch.int64),
+ dist_reduce_fx="sum",
+ )
+
+ def update(self, preds, target):
+ if preds.ndim == 4:
+ preds = preds.argmax(dim=1)
+ if target.ndim == 4:
+ target = target.squeeze(1)
+ B, H, W = preds.shape
+ preds_labels = []
+ target_labels = []
+ binary_preds = preds > 0
+ binary_target = target > 0
+ N_true_labels = 0
+ for i in range(B):
+ pred = binary_preds[i].detach().cpu().numpy()
+ target_np = binary_target[i].detach().cpu().numpy()
+ pred_labels = label(pred, connectivity=2)
+ target_labels_np = label(target_np, connectivity=2)
+ preds_labels.append(torch.from_numpy(pred_labels).to(preds.device))
+ target_labels.append(torch.from_numpy(target_labels_np).to(preds.device))
+ N_true_labels += target_labels_np.max()
+
+ preds_labels = torch.stack(preds_labels)
+ target_labels = torch.stack(target_labels)
+
+ for i, plab in enumerate(preds_labels):
+ labels = torch.unique(plab)
+ for blob in labels:
+ if blob == 0:
+ continue
+ pred_mask = plab == blob
+ pixels_in_common = torch.any(pred_mask & binary_target[i])
+ if pixels_in_common:
+ self.detection_tp += 1
+ gt_label = target_labels[i][pred_mask].unique()[-1]
+ gt_mask = target_labels[i] == gt_label
+ common_mask = pred_mask | gt_mask
+ pred_fiber = preds[i][common_mask]
+ gt_fiber = target[i][common_mask]
+ dices = F.dice(
+ pred_fiber,
+ gt_fiber,
+ num_classes=3,
+ ignore_index=0,
+ average=None,
+ )
+ dices = torch.nan_to_num(dices, nan=0.0)
+ self.fiber_red_dice += dices[1]
+ self.fiber_green_dice += dices[2]
+ recalls = F.recall(
+ pred_fiber,
+ gt_fiber,
+ num_classes=3,
+ ignore_index=0,
+ task="multiclass",
+ average=None,
+ )
+ recalls = torch.nan_to_num(recalls, nan=0.0)
+ self.fiber_red_recall += recalls[1]
+ self.fiber_green_recall += recalls[2]
+
+                    # Precision
+                    precision = F.precision(
+                        pred_fiber,
+                        gt_fiber,
+                        num_classes=3,
+                        ignore_index=0,
+                        task="multiclass",
+                        average=None,
+                    )
+                    precision = torch.nan_to_num(precision, nan=0.0)
+                    self.fiber_red_precision += precision[1]
+                    self.fiber_green_precision += precision[2]
+
+ else:
+ self.detection_fp += 1
+
+ self.N += N_true_labels
+
+ def compute(self):
+ return {
+ "detection_precision": self.detection_tp
+ / (self.detection_tp + self.detection_fp + 1e-7),
+ "detection_recall": self.detection_tp / (self.N + 1e-7),
+ "fiber_red_dice": self.fiber_red_dice / (self.detection_tp + 1e-7),
+ "fiber_green_dice": self.fiber_green_dice / (self.detection_tp + 1e-7),
+ "fiber_red_recall": self.fiber_red_recall / (self.detection_tp + 1e-7),
+ "fiber_green_recall": self.fiber_green_recall / (self.detection_tp + 1e-7),
+ "fiber_red_precision": self.fiber_red_precision
+ / (self.detection_tp + 1e-7),
+ "fiber_green_precision": self.fiber_green_precision
+ / (self.detection_tp + 1e-7),
+ }
diff --git a/dnafiber/model/maskrcnn.py b/dnafiber/model/maskrcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/dnafiber/postprocess/__init__.py b/dnafiber/postprocess/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8d092ba103ac2115498368327064f129764562e
--- /dev/null
+++ b/dnafiber/postprocess/__init__.py
@@ -0,0 +1 @@
+from .core import refine_segmentation
diff --git a/dnafiber/postprocess/__pycache__/__init__.cpython-312.pyc b/dnafiber/postprocess/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a52a0e1a96834225fb98bc0f3f46246eeed38f70
Binary files /dev/null and b/dnafiber/postprocess/__pycache__/__init__.cpython-312.pyc differ
diff --git a/dnafiber/postprocess/__pycache__/core.cpython-312.pyc b/dnafiber/postprocess/__pycache__/core.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..20bbd2765e3117d0d48c63f30216260ad4540c5c
Binary files /dev/null and b/dnafiber/postprocess/__pycache__/core.cpython-312.pyc differ
diff --git a/dnafiber/postprocess/__pycache__/fiber.cpython-312.pyc b/dnafiber/postprocess/__pycache__/fiber.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0053e4430a919d5740f314ace842619e9444b886
Binary files /dev/null and b/dnafiber/postprocess/__pycache__/fiber.cpython-312.pyc differ
diff --git a/dnafiber/postprocess/__pycache__/skan.cpython-312.pyc b/dnafiber/postprocess/__pycache__/skan.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e495fffb403d131e38f162499ecf10ea33023b8a
Binary files /dev/null and b/dnafiber/postprocess/__pycache__/skan.cpython-312.pyc differ
diff --git a/dnafiber/postprocess/core.py b/dnafiber/postprocess/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..90fa1bdfd892d16154bbd878676ba693b3ae2edf
--- /dev/null
+++ b/dnafiber/postprocess/core.py
@@ -0,0 +1,274 @@
+import numpy as np
+import cv2
+from typing import List, Tuple
+from dnafiber.postprocess.skan import find_endpoints, compute_points_angle
+from scipy.spatial.distance import cdist
+
+from scipy.sparse.csgraph import connected_components
+from scipy.sparse import csr_array
+from skimage.morphology import skeletonize
+from dnafiber.postprocess.skan import find_line_intersection
+from dnafiber.postprocess.fiber import Fiber, FiberProps, Bbox
+from itertools import compress
+import matplotlib.pyplot as plt
+from matplotlib.colors import ListedColormap
+
+cmlabel = ListedColormap(["black", "red", "green"])
+
+MIN_ANGLE = 20
+MIN_BRANCH_LENGTH = 10
+MIN_BRANCH_DISTANCE = 30
+
+
+def handle_multiple_fiber_in_cc(fiber, junctions_fiber, coordinates):
+ for y, x in junctions_fiber:
+ fiber[y - 1 : y + 2, x - 1 : x + 2] = 0
+
+ endpoints = find_endpoints(fiber > 0)
+ endpoints = np.asarray(endpoints)
+ # We only keep the endpoints that are close to the junction
+ # We compute the distance between the endpoints and the junctions
+ distances = np.linalg.norm(
+ np.expand_dims(endpoints, axis=1) - np.expand_dims(junctions_fiber, axis=0),
+ axis=2,
+ )
+    distances = distances < 5
+ endpoints = endpoints[distances.any(axis=1)]
+
+ retval, branches, branches_stats, _ = cv2.connectedComponentsWithStatsWithAlgorithm(
+ fiber, connectivity=8, ccltype=cv2.CCL_BOLELLI, ltype=cv2.CV_16U
+ )
+ branches_bboxes = branches_stats[
+ :,
+ [cv2.CC_STAT_LEFT, cv2.CC_STAT_TOP, cv2.CC_STAT_WIDTH, cv2.CC_STAT_HEIGHT],
+ ]
+
+ num_branches = branches_bboxes.shape[0] - 1
+ # We associate the endpoints to the branches
+ endpoints_ids = np.zeros((len(endpoints),), dtype=np.uint16)
+ endpoints_color = np.zeros((len(endpoints),), dtype=np.uint8)
+ for i, endpoint in enumerate(endpoints):
+ # Get the branch id
+ branch_id = branches[endpoint[0], endpoint[1]]
+ # Check if the branch id is not 0
+ if branch_id != 0:
+ endpoints_ids[i] = branch_id
+ endpoints_color[i] = fiber[endpoint[0], endpoint[1]]
+
+ # We remove the small branches
+ kept_branches = set()
+ for i in range(1, num_branches + 1):
+ # Get the branch
+ branch = branches == i
+ # Compute the area of the branch
+ area = np.sum(branch.astype(np.uint8))
+        # If the branch is shorter than MIN_BRANCH_LENGTH pixels, remove it
+ if area < MIN_BRANCH_LENGTH:
+ branches[branch] = 0
+ else:
+ kept_branches.add(i)
+
+ # We remove the endpoints that are in the filtered branches
+ remaining_idxs = np.isin(endpoints_ids, np.asarray(list(kept_branches)))
+ if remaining_idxs.sum() == 0:
+ return []
+ endpoints = endpoints[remaining_idxs]
+
+ endpoints_color = endpoints_color[remaining_idxs]
+ endpoints_ids = endpoints_ids[remaining_idxs]
+
+ # We compute the angles of the endpoints
+ angles = compute_points_angle(fiber, endpoints, steps=15)
+ angles = np.rad2deg(angles)
+ # We compute the difference of angles between all the endpoints
+ endpoints_angles_diff = cdist(angles[:, None], angles[:, None], metric="cityblock")
+
+ # Put inf to the diagonal
+ endpoints_angles_diff[range(len(endpoints)), range(len(endpoints))] = np.inf
+ endpoints_distances = cdist(endpoints, endpoints, metric="euclidean")
+
+ endpoints_distances[range(len(endpoints)), range(len(endpoints))] = np.inf
+
+ # We sort by the distance
+ endpoints_distances[endpoints_distances > MIN_BRANCH_DISTANCE] = np.inf
+ endpoints_distances[endpoints_angles_diff > MIN_ANGLE] = np.inf
+
+ matchB = np.argmin(endpoints_distances, axis=1)
+ values = np.take_along_axis(endpoints_distances, matchB[:, None], axis=1)
+
+ added_edges = dict()
+ N = len(endpoints)
+ A = np.eye(N, dtype=np.uint8)
+ for i in range(N):
+ for j in range(N):
+ if i == j:
+ continue
+ if endpoints_ids[i] == endpoints_ids[j]:
+ A[i, j] = 1
+ A[j, i] = 1
+
+ if matchB[i] == j and values[i, 0] < np.inf:
+ added_edges[i] = j
+ A[i, j] = 1
+ A[j, i] = 1
+
+ A = csr_array(A)
+ n, ccs = connected_components(A, directed=False, return_labels=True)
+ unique_clusters = np.unique(ccs)
+ results = []
+ for c in unique_clusters:
+ idx = np.where(ccs == c)[0]
+ branches_ids = np.unique(endpoints_ids[idx])
+
+ unique_branches = np.logical_or.reduce(
+ [branches == i for i in branches_ids], axis=0
+ )
+
+ commons_bboxes = branches_bboxes[branches_ids]
+ # Compute the union of the bboxes
+ min_x = np.min(commons_bboxes[:, 0])
+ min_y = np.min(commons_bboxes[:, 1])
+ max_x = np.max(commons_bboxes[:, 0] + commons_bboxes[:, 2])
+ max_y = np.max(commons_bboxes[:, 1] + commons_bboxes[:, 3])
+
+ new_fiber = fiber[min_y:max_y, min_x:max_x]
+ new_fiber = unique_branches[min_y:max_y, min_x:max_x] * new_fiber
+ for cidx in idx:
+ if cidx not in added_edges:
+ continue
+ pointA = endpoints[cidx]
+ pointB = endpoints[added_edges[cidx]]
+ pointA = (
+ pointA[1] - min_x,
+ pointA[0] - min_y,
+ )
+ pointB = (
+ pointB[1] - min_x,
+ pointB[0] - min_y,
+ )
+ colA = endpoints_color[cidx]
+ colB = endpoints_color[added_edges[cidx]]
+ new_fiber = cv2.line(
+ new_fiber,
+ pointA,
+ pointB,
+ color=2 if colA != colB else int(colA),
+ thickness=1,
+ )
+ # We express the bbox in the original image
+ bbox = (
+ coordinates[0] + min_x,
+ coordinates[1] + min_y,
+ max_x - min_x,
+ max_y - min_y,
+ )
+ bbox = Bbox(x=bbox[0], y=bbox[1], width=bbox[2], height=bbox[3])
+ result = Fiber(bbox=bbox, data=new_fiber)
+ results.append(result)
+ return results
+
+
+def handle_ccs_with_junctions(
+ ccs: List[np.ndarray],
+ junctions: List[List[Tuple[int, int]]],
+ coordinates: List[Tuple[int, int]],
+):
+ """
+ Handle the connected components with junctions.
+ The function takes a list of connected components, a list of list of junctions and a list of coordinates.
+ The junctions
+ The coordinates corresponds to the top left corner of the connected component.
+ """
+ jncts_fibers = []
+ for fiber, junction, coordinate in zip(ccs, junctions, coordinates):
+ jncts_fibers += handle_multiple_fiber_in_cc(fiber, junction, coordinate)
+
+ return jncts_fibers
+
+
+def refine_segmentation(segmentation, fix_junctions=True, show=False):
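+    """
+    Split a semantic segmentation (0 = background, 1 = red, 2 = green) into
+    individual fibers: skeletonize, extract connected components and, if
+    `fix_junctions` is set, split components at junction points and re-merge
+    compatible branches. Returns a list of FiberProps.
+    """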
+ skeleton = skeletonize(segmentation > 0, method="lee").astype(np.uint8)
+ skeleton_gt = skeleton * segmentation
+ retval, labels, stats, centroids = cv2.connectedComponentsWithStatsWithAlgorithm(
+ skeleton, connectivity=8, ccltype=cv2.CCL_BOLELLI, ltype=cv2.CV_16U
+ )
+
+ bboxes = stats[
+ :,
+ [
+ cv2.CC_STAT_LEFT,
+ cv2.CC_STAT_TOP,
+ cv2.CC_STAT_WIDTH,
+ cv2.CC_STAT_HEIGHT,
+ ],
+ ]
+
+ local_fibers = []
+ coordinates = []
+ junctions = []
+ for i in range(1, retval):
+ bbox = bboxes[i]
+ x1, y1, w, h = bbox
+ local_gt = skeleton_gt[y1 : y1 + h, x1 : x1 + w]
+ local_label = (labels[y1 : y1 + h, x1 : x1 + w] == i).astype(np.uint8)
+ local_fiber = local_gt * local_label
+ local_fibers.append(local_fiber)
+ coordinates.append(np.asarray([x1, y1, w, h]))
+ local_junctions = find_line_intersection(local_fiber > 0)
+ local_junctions = np.where(local_junctions)
+ local_junctions = np.array(local_junctions).transpose()
+ junctions.append(local_junctions)
+ if show:
+ for bbox, junction in zip(coordinates, junctions):
+ x, y, w, h = bbox
+ junction_to_global = np.array(junction) + np.array([y, x])
+
+ plt.scatter(
+ junction_to_global[:, 1],
+ junction_to_global[:, 0],
+ color="white",
+ s=30,
+ alpha=0.35,
+ )
+
+ plt.imshow(skeleton_gt, cmap=cmlabel, interpolation="nearest")
+ plt.axis("off")
+ plt.xticks([])
+ plt.yticks([])
+ plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
+ plt.show()
+
+ fibers = []
+ if fix_junctions:
+ has_junctions = [len(j) > 0 for j in junctions]
+ for fiber, coordinate in zip(
+ compress(local_fibers, np.logical_not(has_junctions)),
+ compress(coordinates, np.logical_not(has_junctions)),
+ ):
+ bbox = Bbox(
+ x=coordinate[0],
+ y=coordinate[1],
+ width=coordinate[2],
+ height=coordinate[3],
+ )
+ fibers.append(Fiber(bbox=bbox, data=fiber))
+
+ fibers += handle_ccs_with_junctions(
+ compress(local_fibers, has_junctions),
+ compress(junctions, has_junctions),
+ compress(coordinates, has_junctions),
+ )
+ else:
+ for fiber, coordinate in zip(local_fibers, coordinates):
+ bbox = Bbox(
+ x=coordinate[0],
+ y=coordinate[1],
+ width=coordinate[2],
+ height=coordinate[3],
+ )
+ fibers.append(Fiber(bbox=bbox, data=fiber))
+
+ fiberprops = [FiberProps(fiber=f, fiber_id=i) for i, f in enumerate(fibers)]
+
+ return fiberprops
diff --git a/dnafiber/postprocess/fiber.py b/dnafiber/postprocess/fiber.py
new file mode 100644
index 0000000000000000000000000000000000000000..dff9ef480dc25fb580e3e6eebf7ba7b9f5e88272
--- /dev/null
+++ b/dnafiber/postprocess/fiber.py
@@ -0,0 +1,129 @@
+import attrs
+import numpy as np
+from typing import Tuple
+from dnafiber.postprocess.skan import trace_skeleton
+
+@attrs.define
+class Bbox:
+ x: int
+ y: int
+ width: int
+ height: int
+
+ @property
+ def bbox(self) -> Tuple[int, int, int, int]:
+ return (self.x, self.y, self.width, self.height)
+
+ @bbox.setter
+ def bbox(self, value: Tuple[int, int, int, int]):
+ self.x, self.y, self.width, self.height = value
+
+
+@attrs.define
+class Fiber:
+ bbox: Bbox
+ data: np.ndarray
+
+
+@attrs.define
+class FiberProps:
+ fiber: Fiber
+ fiber_id: int
+    red_pixels: int | None = None
+    green_pixels: int | None = None
+    category: str | None = None
+
+ @property
+ def bbox(self):
+ return self.fiber.bbox.bbox
+
+ @bbox.setter
+ def bbox(self, value):
+ self.fiber.bbox = value
+
+ @property
+ def data(self):
+ return self.fiber.data
+
+ @data.setter
+ def data(self, value):
+ self.fiber.data = value
+
+ @property
+ def red(self):
+ if self.red_pixels is None:
+ self.red_pixels, self.green_pixels = self.counts
+ return self.red_pixels
+
+ @property
+ def green(self):
+ if self.green_pixels is None:
+ self.red_pixels, self.green_pixels = self.counts
+ return self.green_pixels
+
+ @property
+ def length(self):
+ return sum(self.counts)
+
+ @property
+ def counts(self):
+ if self.red_pixels is None or self.green_pixels is None:
+ self.red_pixels = np.sum(self.data == 1)
+ self.green_pixels = np.sum(self.data == 2)
+ return self.red_pixels, self.green_pixels
+
+ @property
+ def fiber_type(self):
+ if self.category is not None:
+ return self.category
+ red_pixels, green_pixels = self.counts
+ if red_pixels == 0 or green_pixels == 0:
+ self.category = "single"
+ else:
+ self.category = estimate_fiber_category(self.data)
+ return self.category
+
+ @property
+ def ratio(self):
+ return self.green / self.red
+
+ @property
+ def is_valid(self):
+ return (
+ self.fiber_type == "double"
+ or self.fiber_type == "one-two-one"
+ or self.fiber_type == "two-one-two"
+ )
+
+ def scaled_coordinates(self, scale: float) -> Tuple[int, int]:
+ """
+ Scale down the coordinates of the fiber's bounding box.
+ """
+ x, y, width, height = self.bbox
+ return (
+ int(x * scale),
+ int(y * scale),
+ int(width * scale),
+ int(height * scale),
+ )
+
+
+def estimate_fiber_category(fiber: np.ndarray) -> str:
+ """
+    Estimate the fiber category from the sequence of colors encountered
+    along the traced skeleton.
+ """
+ coordinates = trace_skeleton(fiber > 0)
+ coordinates = np.asarray(coordinates)
+ values = fiber[coordinates[:, 0], coordinates[:, 1]]
+ diff = np.diff(values)
+ jump = np.sum(diff != 0)
+ n_ccs = jump + 1
+ if n_ccs == 2:
+ return "double"
+ elif n_ccs == 3:
+ if values[0] == 1:
+ return "one-two-one"
+ else:
+ return "two-one-two"
+ else:
+ return "multiple"
diff --git a/dnafiber/postprocess/skan.py b/dnafiber/postprocess/skan.py
new file mode 100644
index 0000000000000000000000000000000000000000..c86255f6fcbc2e41e922e1d5736b237bfc13d50c
--- /dev/null
+++ b/dnafiber/postprocess/skan.py
@@ -0,0 +1,211 @@
+# Functions to generate kernels of curve intersection
+import numpy as np
+import cv2
+import itertools
+from numba import njit, int64
+from numba.typed import List
+from numba.types import Tuple
+
+# Define the element type: a tuple of two int64
+tuple_type = Tuple((int64, int64))
+
+
+def find_neighbours(fibers_map, point):
+ """
+    Return the 8-connected neighbours of the given point that belong to the fiber.
+ """
+ # Get the fiber id
+ neighbors = []
+ h, w = fibers_map.shape
+ for i in range(-1, 2):
+ for j in range(-1, 2):
+ # Skip the center point
+ if i == 0 and j == 0:
+ continue
+ # Get the next point
+ nextpoint = (point[0] + i, point[1] + j)
+ # Check if the next point is in the image
+ if (
+ nextpoint[0] < 0
+ or nextpoint[0] >= h
+ or nextpoint[1] < 0
+ or nextpoint[1] >= w
+ ):
+ continue
+
+ # Check if the next point is in the fiber
+ if fibers_map[nextpoint]:
+ neighbors.append(nextpoint)
+ return neighbors
+
+
+def compute_points_angle(fibers_map, points, steps=25):
+ """
+ For each endpoint, follow the fiber for a given number of steps and estimate the tangent line by
+ fitting a line to the visited points. The angle of the line is returned.
+ """
+ points_angle = np.zeros((len(points),), dtype=np.float32)
+ for i, point in enumerate(points):
+        # Find the fiber the endpoint belongs to, walk along it for up to
+        # ``steps`` pixels, then fit a line to the visited points to estimate
+        # the local tangent direction.
+ visited = trace_from_point(
+ fibers_map > 0, (point[0], point[1]), max_length=steps
+ )
+ visited = np.array(visited)
+ vx, vy, x, y = cv2.fitLine(visited[:, ::-1], cv2.DIST_L2, 0, 0.01, 0.01)
+ # Compute the angle of the line
+ points_angle[i] = np.arctan(vy / vx)
+
+ return points_angle
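+
+
+# Sketch of the tangent estimate (illustrative): cv2.fitLine returns a unit
+# direction vector (vx, vy) plus a point on the line; for an endpoint trail
+# lying on the diagonal y = x, vx == vy, so arctan(vy / vx) is pi / 4
+# (45 degrees).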
+
+
+def generate_nonadjacent_combination(input_list, take_n):
+ """
+    Generates the combinations of ``input_list`` taken ``take_n`` at a time in
+    which no two selected positions are adjacent on the 8-neighbour ring
+    (positions 0 and 7 are also neighbours).
+    INPUT:
+        input_list = (iterable) list of elements to draw the combinations from
+        take_n = (integer) number of elements taken at a time in each combination
+    OUTPUT:
+        all_comb = (list of np.array) all the valid combinations
+ """
+ all_comb = []
+ for comb in itertools.combinations(input_list, take_n):
+ comb = np.array(comb)
+ d = np.diff(comb)
+ if len(d[d == 1]) == 0 and comb[-1] - comb[0] != 7:
+ all_comb.append(comb)
+ return all_comb
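+
+
+# Example (illustrative): with input_list = np.arange(8) and take_n = 3,
+# (0, 2, 4) is kept, (0, 1, 4) is rejected because positions 0 and 1 are
+# adjacent, and (0, 3, 7) is rejected because 0 and 7 touch on the ring
+# (comb[-1] - comb[0] == 7).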
+
+
+def populate_intersection_kernel(combinations):
+    """
+    Maps the numbers 0-7 onto the 8 pixels surrounding the center pixel of a
+    3 x 3 matrix, clockwise (up_pixel = 0, right_pixel = 2, etc.), and
+    generates a kernel that represents a line intersection, where the center
+    pixel is occupied and 3 or 4 pixels of the border are occupied too.
+    INPUT:
+        combinations = (np.array) matrix where every row is a vector of combinations
+    OUTPUT:
+        kernels = (List) list of 3 x 3 kernels/masks; each element is a mask.
+    """
+    template = np.array(([-1, -1, -1], [-1, 1, -1], [-1, -1, -1]), dtype="int")
+    match = [(0, 1), (0, 2), (1, 2), (2, 2), (2, 1), (2, 0), (1, 0), (0, 0)]
+    kernels = []
+    for comb in combinations:
+        tmp = np.copy(template)
+        for m in comb:
+            tmp[match[m][0], match[m][1]] = 1
+        kernels.append(tmp)
+    return kernels
+
+
+def give_intersection_kernels():
+ """
+    Generates all the 3 x 3 intersection kernels.
+    INPUT:
+        None
+    OUTPUT:
+        kernels = (List) list of 3 x 3 kernels/masks; each element is a mask.
+ """
+ input_list = np.arange(8)
+ taken_n = [4, 3]
+ kernels = []
+ for taken in taken_n:
+ comb = generate_nonadjacent_combination(input_list, taken)
+ tmp_ker = populate_intersection_kernel(comb)
+ kernels.extend(tmp_ker)
+ return kernels
+
+
+def find_line_intersection(input_image, show=0):
+ """
+    Applies cv2.morphologyEx with the cv2.MORPH_HITMISS operation to look for all
+    the curve intersection kernels generated by give_intersection_kernels().
+ INPUT:
+ input_image = (np.array dtype=np.uint8) binarized m x n image matrix
+ OUTPUT:
+ output_image = (np.array dtype=np.uint8) image where the nonzero pixels
+ are the line intersection.
+ """
+ input_image = input_image.astype(np.uint8)
+ kernel = np.array(give_intersection_kernels())
+ output_image = np.zeros(input_image.shape)
+ for i in np.arange(len(kernel)):
+ out = cv2.morphologyEx(
+ input_image,
+ cv2.MORPH_HITMISS,
+ kernel[i, :, :],
+ borderValue=0,
+ borderType=cv2.BORDER_CONSTANT,
+ )
+ output_image = output_image + out
+
+ return output_image
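+
+
+# Minimal usage sketch (assumed toy input): a plus-shaped skeleton has a 4-way
+# junction at its centre, which matches the kernel built from positions
+# (0, 2, 4, 6) of the ring.
+#
+#   skel = np.zeros((7, 7), dtype=np.uint8)
+#   skel[3, :] = 1
+#   skel[:, 3] = 1
+#   intersections = find_line_intersection(skel)  # nonzero only at (3, 3)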
+
+
+@njit
+def get_neighbors_8(y, x, shape):
+ neighbors = List.empty_list(tuple_type)
+ for dy in range(-1, 2):
+ for dx in range(-1, 2):
+ if dy == 0 and dx == 0:
+ continue
+ ny, nx = y + dy, x + dx
+ if 0 <= ny < shape[0] and 0 <= nx < shape[1]:
+ neighbors.append((ny, nx))
+ return neighbors
+
+
+@njit
+def find_endpoints(skel):
+ endpoints = List.empty_list(tuple_type)
+ for y in range(skel.shape[0]):
+ for x in range(skel.shape[1]):
+ if skel[y, x] == 1:
+ count = 0
+ neighbors = get_neighbors_8(y, x, skel.shape)
+ for ny, nx in neighbors:
+ if skel[ny, nx] == 1:
+ count += 1
+ if count == 1:
+ endpoints.append((y, x))
+ return endpoints
+
+
+@njit
+def trace_skeleton(skel):
+ endpoints = find_endpoints(skel)
+ if len(endpoints) < 1:
+ return List.empty_list(tuple_type) # Return empty list with proper type
+
+ return trace_from_point(skel, endpoints[0], max_length=skel.sum())
+
+
+@njit
+def trace_from_point(skel, point, max_length=25):
+ visited = np.zeros_like(skel, dtype=np.uint8)
+ path = List.empty_list(tuple_type)
+
+ # Check if the starting point is on the skeleton
+ y, x = point
+ if y < 0 or y >= skel.shape[0] or x < 0 or x >= skel.shape[1] or skel[y, x] != 1:
+ return path
+
+ stack = List.empty_list(tuple_type)
+ stack.append(point)
+
+ while len(stack) > 0 and len(path) < max_length:
+ y, x = stack.pop()
+ if visited[y, x]:
+ continue
+ visited[y, x] = 1
+ path.append((y, x))
+ neighbors = get_neighbors_8(y, x, skel.shape)
+ for ny, nx in neighbors:
+ if skel[ny, nx] == 1 and not visited[ny, nx]:
+ stack.append((ny, nx))
+ return path
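+
+
+# Usage sketch (illustrative): tracing a short horizontal skeleton from its
+# left endpoint visits the three foreground pixels in order.
+#
+#   skel = np.zeros((3, 5), dtype=np.uint8)
+#   skel[1, 1:4] = 1
+#   path = trace_from_point(skel, (1, 1), max_length=10)
+#   # list(path) == [(1, 1), (1, 2), (1, 3)]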
diff --git a/dnafiber/start.py b/dnafiber/start.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e408cbc2f046893527334ba12b755594ac9b1b7
--- /dev/null
+++ b/dnafiber/start.py
@@ -0,0 +1,22 @@
+import subprocess
+import os
+
+
+def main():
+ # Start the Streamlit application
+ print("Starting Streamlit application...")
+ local_dir = os.path.dirname(os.path.abspath(__file__))
+ subprocess.run(
+ [
+ "streamlit",
+ "run",
+ os.path.join(local_dir, "ui", "Welcome.py"),
+ "--server.maxUploadSize",
+ "1024",
+ ],
+ )
+
+
+if __name__ == "__main__":
+ main()
+ print("Streamlit application started successfully.")
diff --git a/dnafiber/trainee.py b/dnafiber/trainee.py
new file mode 100644
index 0000000000000000000000000000000000000000..41fe1f7fd76520aa5fadc9e7592b9555a09f22e7
--- /dev/null
+++ b/dnafiber/trainee.py
@@ -0,0 +1,148 @@
+from lightning import LightningModule
+import segmentation_models_pytorch as smp
+from monai.losses.dice import GeneralizedDiceLoss
+from monai.losses.cldice import SoftDiceclDiceLoss
+from torchmetrics.classification import Dice, JaccardIndex
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import CosineAnnealingLR
+from torchmetrics import MetricCollection
+import torch.nn.functional as F
+from huggingface_hub import PyTorchModelHubMixin
+import torch
+import torchvision
+from dnafiber.metric import DNAFIBERMetric
+
+
+class Trainee(LightningModule, PyTorchModelHubMixin):
+ def __init__(
+ self, learning_rate=0.001, weight_decay=0.0002, num_classes=3, **model_config
+ ):
+ super().__init__()
+ self.model_config = model_config
+ if (
+ self.model_config.get("arch", None) is None
+ or self.model_config["arch"] == "maskrcnn"
+ ):
+ self.model = None
+ else:
+ self.model = smp.create_model(classes=3, **self.model_config, dropout=0.2)
+ self.loss = GeneralizedDiceLoss(to_onehot_y=False, softmax=False)
+ self.metric = MetricCollection(
+ {
+ "dice": Dice(num_classes=num_classes, ignore_index=0),
+ "jaccard": JaccardIndex(
+ num_classes=num_classes,
+ task="multiclass" if num_classes > 2 else "binary",
+ ignore_index=0,
+ ),
+ "detection": DNAFIBERMetric(),
+ }
+ )
+ self.weight_decay = weight_decay
+ self.learning_rate = learning_rate
+ self.save_hyperparameters()
+
+ def forward(self, x):
+ yhat = self.model(x)
+ return yhat
+
+ def training_step(self, batch, batch_idx):
+ x, y = batch["image"], batch["mask"]
+ y = y.clamp(0, 2)
+ y_hat = self(x)
+ loss = self.get_loss(y_hat, y)
+
+ self.log("train_loss", loss)
+
+ return loss
+
+ def get_loss(self, y_hat, y):
+ y_hat = F.softmax(y_hat, dim=1)
+ y = F.one_hot(y.long(), num_classes=3)
+ y = y.permute(0, 3, 1, 2).float()
+ loss = self.loss(y_hat, y)
+ return loss
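+
+    # Shape sketch (illustrative): a target batch y of shape (B, H, W) with
+    # integer labels becomes one-hot (B, H, W, 3), then (B, 3, H, W) after the
+    # permute, matching the softmaxed logits y_hat.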
+
+ def validation_step(self, batch, batch_idx):
+ x, y = batch["image"], batch["mask"]
+ y = y.clamp(0, 2)
+ y_hat = self(x)
+ loss = self.get_loss(y_hat, y)
+ self.log("val_loss", loss, on_step=False, on_epoch=True, sync_dist=True)
+ self.metric.update(y_hat, y)
+ return y_hat
+
+ def on_validation_epoch_end(self):
+ scores = self.metric.compute()
+ self.log_dict(scores, sync_dist=True)
+ self.metric.reset()
+
+ def configure_optimizers(self):
+ optimizer = AdamW(
+ self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
+ )
+ scheduler = CosineAnnealingLR(
+ optimizer,
+ T_max=self.trainer.max_epochs, # type: ignore
+ eta_min=self.learning_rate / 25,
+ )
+ scheduler = {
+ "scheduler": scheduler,
+ "interval": "epoch",
+ }
+ return [optimizer], [scheduler]
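+
+    # Construction sketch (the arch/encoder below are assumptions, not project
+    # defaults; the chosen smp architecture must accept the extra ``dropout``
+    # keyword passed in ``__init__``):
+    #
+    #   trainee = Trainee(learning_rate=1e-3, arch="unet", encoder_name="resnet34")
+    #   logits = trainee(torch.randn(1, 3, 256, 256))  # -> (1, 3, 256, 256)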
+
+
+class TraineeMaskRCNN(Trainee):
+ def __init__(self, learning_rate=0.001, weight_decay=0.0002, **model_config):
+ super().__init__(learning_rate, weight_decay, **model_config)
+ self.model = torchvision.models.get_model("maskrcnn_resnet50_fpn_v2")
+
+ def forward(self, x):
+ yhat = self.model(x)
+ return yhat
+
+ def training_step(self, batch, batch_idx):
+ image = batch["image"]
+ targets = batch["targets"]
+ loss_dict = self.model(image, targets)
+ losses = sum(loss for loss in loss_dict.values())
+ self.log("train_loss", losses, on_step=True, on_epoch=False, sync_dist=True)
+ return losses
+
+ def validation_step(self, batch, batch_idx):
+ image = batch["image"]
+ targets = batch["targets"]
+
+ predictions = self.model(image)
+ b = len(predictions)
+ predicted_masks = []
+ gt_masks = []
+ for i in range(b):
+ scores = predictions[i]["scores"]
+ masks = predictions[i]["masks"]
+ good_masks = masks[scores > 0.5]
+ # Combined into a single mask
+ good_masks = torch.sum(good_masks, dim=0)
+ predicted_masks.append(good_masks)
+ gt_masks.append(targets[i]["masks"].sum(dim=0))
+
+ gt_masks = torch.stack(gt_masks).squeeze(1) > 0
+ predicted_masks = torch.stack(predicted_masks).squeeze(1) > 0
+ self.metric.update(predicted_masks, gt_masks)
+ return predictions
+
+ def configure_optimizers(self):
+ optimizer = AdamW(
+ self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
+ )
+ scheduler = CosineAnnealingLR(
+ optimizer,
+ T_max=self.trainer.max_epochs, # type: ignore
+ eta_min=self.learning_rate / 25,
+ )
+ scheduler = {
+ "scheduler": scheduler,
+ "interval": "epoch",
+ }
+ return [optimizer], [scheduler]
diff --git a/dnafiber/ui/Welcome.py b/dnafiber/ui/Welcome.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1657f273987bf13811815ccd85e442bab27d837
--- /dev/null
+++ b/dnafiber/ui/Welcome.py
@@ -0,0 +1,47 @@
+import streamlit as st
+import torch
+
+
+def main():
+ st.set_page_config(
+ page_title="Hello",
+ page_icon="🧬",
+ layout="wide",
+ )
+ st.write("# Welcome to DN-AI! 👋")
+
+ st.write(
+ "This is a web application for the DN-AI project, which aims to provide an easy-to-use interface for analyzing and processing fiber images."
+ )
+ st.write("## Features")
+ st.write(
+ "- **Image loading**: The application accepts CZI file, jpeg and PNG file. \n"
+ "- **Image segmentation**: The application provides a set of tools to segment the DNA fiber and measure the ratio between analogs. \n"
+ )
+ st.write("## Technical details")
+ cols = st.columns(2)
+ with cols[0]:
+ st.write("### Source")
+ st.write("The source code for this application is available on GitHub.")
+ """
+ [](https://github.com/ClementPla/DeepFiberQ/tree/relabelled)
+
+ """
+ st.markdown("
", unsafe_allow_html=True)
+
+ with cols[1]:
+ st.write("### Device ")
+ st.write("If available, the application will try to use a GPU for processing.")
+ device = "GPU" if torch.cuda.is_available() else "CPU"
+ cols = st.columns(3)
+ with cols[0]:
+ st.write("Running on:")
+ with cols[1]:
+ st.button(device, icon="⚙️", disabled=True)
+ if not torch.cuda.is_available():
+ with cols[2]:
+ st.warning("The application will run on CPU, which may be slower.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/dnafiber/ui/__init__.py b/dnafiber/ui/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/dnafiber/ui/__pycache__/__init__.cpython-312.pyc b/dnafiber/ui/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d9f8e6978691e9c21072f3ed6bced49f13b641b8
Binary files /dev/null and b/dnafiber/ui/__pycache__/__init__.cpython-312.pyc differ
diff --git a/dnafiber/ui/__pycache__/inference.cpython-312.pyc b/dnafiber/ui/__pycache__/inference.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..793cfaed8e36d0a276a380e829826315c68b9f6b
Binary files /dev/null and b/dnafiber/ui/__pycache__/inference.cpython-312.pyc differ
diff --git a/dnafiber/ui/__pycache__/utils.cpython-312.pyc b/dnafiber/ui/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c480d1c3f4b0b327fec16c4ae9859f1ced007272
Binary files /dev/null and b/dnafiber/ui/__pycache__/utils.cpython-312.pyc differ
diff --git a/dnafiber/ui/inference.py b/dnafiber/ui/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bf55ed6869f692c7bfa8877a67007d078de34b7
--- /dev/null
+++ b/dnafiber/ui/inference.py
@@ -0,0 +1,69 @@
+import streamlit as st
+from dnafiber.inference import infer
+from dnafiber.postprocess.core import refine_segmentation
+import numpy as np
+from dnafiber.deployment import _get_model
+import torch
+
+
+@st.cache_data
+def ui_inference(_model, _image, _device, postprocess=True, id=None):
+ return ui_inference_cacheless(
+ _model, _image, _device, postprocess=postprocess, id=id
+ )
+
+
+@st.cache_resource
+def get_model(model_name):
+ model = _get_model(
+ device="cuda" if torch.cuda.is_available() else "cpu",
+ revision=model_name,
+ )
+ return model
+
+
+def ui_inference_cacheless(_model, _image, _device, postprocess=True, id=None):
+ """
+    Uncached implementation backing ``ui_inference``; call this directly when
+    Streamlit caching is not desired.
+ """
+ h, w = _image.shape[:2]
+ with st.spinner("Sliding window segmentation in progress..."):
+ if isinstance(_model, list):
+ output = None
+            for model in _model:
+                # Keep a readable label for the spinner before the checkpoint
+                # name is replaced by the instantiated model.
+                model_label = model if isinstance(model, str) else type(model).__name__
+                if isinstance(model, str):
+                    model = get_model(model)
+                with st.spinner(text="Segmenting with model: {}".format(model_label)):
+                    probabilities = infer(
+                        model,
+                        image=_image,
+                        device=_device,
+                        scale=st.session_state.get("pixel_size", 0.13),
+                        only_probabilities=True,
+                    ).cpu()
+                    output = (
+                        probabilities if output is None else output + probabilities
+                    )
+            output = (output / len(_model)).argmax(1).squeeze().numpy()
+ else:
+ output = infer(
+ _model,
+ image=_image,
+ device=_device,
+ scale=st.session_state.get("pixel_size", 0.13),
+ )
+ output = output.astype(np.uint8)
+ if postprocess:
+ with st.spinner("Post-processing segmentation..."):
+            output = refine_segmentation(output, fix_junctions=True)
+ return output
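+
+
+# Ensemble note (illustrative arithmetic): per-model probability maps are summed,
+# divided by the number of models, and the label is the argmax over classes.
+# With two models giving [0.2, 0.5, 0.3] and [0.4, 0.4, 0.2] for a pixel, the
+# mean is [0.3, 0.45, 0.25], so the predicted class is 1.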
diff --git a/dnafiber/ui/pages/1_Load.py b/dnafiber/ui/pages/1_Load.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c1948696ee7a7ba8ce30ff9b7955ac33a7e75c6
--- /dev/null
+++ b/dnafiber/ui/pages/1_Load.py
@@ -0,0 +1,196 @@
+import streamlit as st
+
+
+st.set_page_config(
+ page_title="DN-AI",
+ page_icon="🔬",
+ layout="wide",
+)
+
+def build_multichannel_loader():
+
+ if (
+ st.session_state.get("files_uploaded", None) is None
+ or len(st.session_state.files_uploaded) == 0
+ ):
+ st.session_state["files_uploaded"] = st.file_uploader(
+ label="Upload files",
+ accept_multiple_files=True,
+ type=["czi", "jpeg", "jpg", "png", "tiff", "tif"],
+ )
+ else:
+ st.session_state["files_uploaded"] += st.file_uploader(
+ label="Upload files",
+ accept_multiple_files=True,
+ type=["czi", "jpeg", "jpg", "png", "tiff", "tif"],
+ )
+ st.write("### Channel interpretation")
+ st.markdown("The goal is to obtain an RGB image in the order of First analog, Second analog, Empty.", unsafe_allow_html=True)
+ st.markdown("By default, we assume that the first channel in CZI/TIFF file is the second analog, (which happens to be the case in Zeiss microscope) " \
+ "which means that we swap the order of the two channels for processing.", unsafe_allow_html=True)
+ st.write("If this not the intented behavior, please tick the box below:")
+ st.session_state["reverse_channels"] = st.checkbox(
+ "Reverse the channels interpretation",
+ value=False,
+ )
+ st.warning("Please note that we only swap the channels for raw (CZI, TIFF) files. JPEG and PNG files "\
+ "are assumed to be already in the correct order (First analog in red and second analog in green). " \
+ )
+
+ st.info("" \
+ "The channels order in CZI files does not necessarily match the order in which they are displayed in ImageJ or equivalent. " \
+ "Indeed, such viewers will usually look at the metadata of the file to determine the order of the channels, which we don't. " \
+ "In doubt, we recommend visualizing the image in ImageJ and compare with our viewer. If the channels appear reversed, tick the option above.")
+
+def build_individual_loader():
+
+ cols = st.columns(2)
+ with cols[1]:
+ st.markdown(f"
+@firstAnalog
+@secondAnalog
+Ratio: @ratio',
+    )
+    hover.renderers = [rect1, rect2]
+    hover.point_policy = "follow_mouse"
+    hover.attachment = "vertical"
+    p1.add_tools(hover)
+    p2.add_tools(hover)
+
+    p1.axis.visible = False
+    p2.axis.visible = False
+    fig = gridplot(
+        [[p2, p1]],
+        merge_tools=True,
+        sizing_mode="stretch_width",
+        toolbar_options=dict(logo=None, help=None),
+    )
+    return fig
+
+
+@st.cache_data
+def show_fibers(_prediction, _image, image_id=None):
+    data = dict(
+        fiber_id=[],
+        firstAnalog=[],
+        secondAnalog=[],
+        ratio=[],
+        fiber_type=[],
+        visualization=[],
+    )
+
+    for fiber in _prediction:
+        data["fiber_id"].append(fiber.fiber_id)
+        r, g = fiber.counts
+        red_length = st.session_state["pixel_size"] * r
+        green_length = st.session_state["pixel_size"] * g
+        data["firstAnalog"].append(f"{red_length:.3f} ")
+        data["secondAnalog"].append(f"{green_length:.3f} ")
+        data["ratio"].append(f"{green_length / red_length:.3f}")
+        data["fiber_type"].append(fiber.fiber_type)
+
+        x, y, w, h = fiber.bbox
+
+        visu = _image[y : y + h, x : x + w, :]
+        visu = cv2.normalize(visu, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
+        data["visualization"].append(visu)
+
+    df = pd.DataFrame(data)
+    df = df.rename(
+        columns={
+            "firstAnalog": "First analog (µm)",
+            "secondAnalog": "Second analog (µm)",
+            "ratio": "Ratio",
+            "fiber_type": "Fiber type",
+            "fiber_id": "Fiber ID",
+            "visualization": "Visualization",
+        }
+    )
+    df["Visualization"] = df["Visualization"].apply(lambda x: numpy_to_base64_png(x))
+    return df
+
+
+def start_inference():
+    image = st.session_state.image_inference
+    image = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
+
+    if "ensemble" in st.session_state.model:
+        # Parenthesized so that non-finetuned ensembles keep the base model
+        # names instead of collapsing to empty strings.
+        model = [
+            _ + ("_finetuned" if "finetuned" in st.session_state.model else "")
+            for _ in MODELS_ZOO.values()
+            if _ != "ensemble"
+        ]
+    else:
+        model = get_model(st.session_state.model)
+    prediction = ui_inference(
+        model,
+        image,
+        "cuda" if torch.cuda.is_available() else "cpu",
+        st.session_state.post_process,
+        st.session_state.image_id,
+    )
+    prediction = [
+        p
+        for p in prediction
+        if (p.fiber_type != "single") and p.fiber_type != "multiple"
+    ]
+    tab_viewer, tab_fibers = st.tabs(["Viewer", "Fibers"])
+
+    with tab_fibers:
+        df = show_fibers(prediction, image, st.session_state.image_id)
+
+        event = st.dataframe(
+            df,
+            on_select="rerun",
+            selection_mode="multi-row",
+            use_container_width=True,
+            column_config={
+                "Visualization": st.column_config.ImageColumn(
+                    "Visualization",
+                    help="Visualization of the fiber",
+                )
+            },
+        )
+
+        rows = event["selection"]["rows"]
+        columns = df.columns[:-2]
+        df = df.iloc[rows][columns]
+
+        cols = st.columns(3)
+        with cols[0]:
+            copy_to_clipboard = st.button(
+                "Copy selected fibers to clipboard",
+                help="Copy the selected fibers to clipboard in CSV format.",
+            )
+            if copy_to_clipboard:
+                df.to_clipboard(index=False)
+        with cols[2]:
+            st.download_button(
+                "Download selected fibers",
+                data=df.to_csv(index=False).encode("utf-8"),
+                file_name=f"fibers_{st.session_state.image_id}.csv",
+                mime="text/csv",
+            )
+
+    with tab_viewer:
+        max_width = 2048
+        if image.shape[1] > max_width:
+            st.toast("Images are displayed at a lower resolution of 2048 pixels wide")
+
+        fig = display_prediction(prediction, image, st.session_state.image_id)
+        streamlit_bokeh(fig, use_container_width=True)
+
+
+def on_session_start():
+    can_start = (
+        st.session_state.get("files_uploaded", None) is not None
+        and len(st.session_state.files_uploaded) > 0
+    )
+
+    if can_start:
+        return can_start
+
+    cldu_exists = (
+        st.session_state.get("files_uploaded_cldu", None) is not None
+        and len(st.session_state.files_uploaded_cldu) > 0
+    )
+    idu_exists = (
+        st.session_state.get("files_uploaded_idu", None) is not None
+        and len(st.session_state.files_uploaded_idu) > 0
+    )
+
+    if cldu_exists and idu_exists:
+        if len(st.session_state.get("files_uploaded_cldu")) != len(
+            st.session_state.get("files_uploaded_idu")
+        ):
+            st.error("Please upload the same number of CldU and IdU files.")
+            return False
+
+
+def create_display_files(files):
+    if files is None or len(files) == 0:
+        return "No files uploaded"
+    display_files = []
+    for file in files:
+        if isinstance(file, tuple):
+            if file[0] is None:
+                name = f"Second analog only {file[1].name}"
+            elif file[1] is None:
+                name = f"First analog only {file[0].name}"
+            else:
+                name = f"{file[0].name} and {file[1].name}"
+            display_files.append(name)
+        else:
+            display_files.append(file.name)
+    return display_files
+
+
+if on_session_start():
+    files = st.session_state.files_uploaded
+    displayed_names = create_display_files(files)
+    selected_file = st.selectbox(
+        "Pick an image",
+        displayed_names,
+        index=0,
+        help="Select an image to view and analyze.",
+    )
+
+    # Find index of the selected file
+    index = displayed_names.index(selected_file)
+    file = files[index]
+    if isinstance(file, tuple):
+        file_id = file[0].file_id if file[0] is not None else file[1].file_id
+        if file[0] is None or file[1] is None:
+            missing = "First analog" if file[0] is None else "Second analog"
+            st.warning(
+                f"In this image, the {missing} channel is missing. We assume the intended goal is "
+                "to segment the DNA fibers without differentiation. Note that the model may still "
+                "predict two classes and try to compute a ratio; this information can be ignored."
+            )
+        image = get_multifile_image(file)
+    else:
+        file_id = file.file_id
+        image = get_image(
+            file,
+            reverse_channel=st.session_state.get("reverse_channels", False),
+            id=file_id,
+        )
+    h, w = image.shape[:2]
+    with st.sidebar:
+        st.metric(
+            "Pixel size (µm)",
+            st.session_state.get("pixel_size", 0.13),
+        )
+
+        block_size = st.slider(
+            "Block size",
+            min_value=256,
+            max_value=min(4096, max(h, w)),
+            value=min(2048, max(h, w)),
+            step=256,
+        )
+    if h < block_size:
+        block_size = h
+    if w < block_size:
+        block_size = w
+
+    bx = by = block_size
+    image = pad_image_to_croppable(image, bx, by, file_id + str(bx) + str(by))
+    thumbnail = get_resized_image(image, file_id)
+
+    blocks = view_as_blocks(image, (bx, by, 3))
+    x_blocks, y_blocks = blocks.shape[0], blocks.shape[1]
+    with st.sidebar:
+        with st.expander("Model", expanded=True):
+            model_name = st.selectbox(
+                "Select a model",
+                list(MODELS_ZOO.keys()),
+                index=0,
+                help="Select a model to use for inference",
+            )
+            finetuned = st.checkbox(
+                "Use finetuned model",
+                value=True,
+                help="Use a finetuned model for inference",
+            )
+
+            col1, col2 = st.columns(2)
+            with col1:
+                st.write("Running on:")
+            with col2:
+                st.button(
+                    "GPU" if torch.cuda.is_available() else "CPU",
+                    disabled=True,
+                )
+
+        st.session_state.post_process = st.checkbox(
+            "Post-process",
+            value=True,
+            help="Apply post-processing to the prediction",
+        )
+
+    st.session_state.model = (
+        (MODELS_ZOO[model_name] + "_finetuned")
+        if finetuned
+        else MODELS_ZOO[model_name]
+    )
+
+    which_y = st.session_state.get("which_y", 0)
+    which_x = st.session_state.get("which_x", 0)
+
+    # Display the selected block
+    # Scale factor
+    h, w = image.shape[:2]
+    small_h, small_w = thumbnail.shape[:2]
+    scale_h = h / small_h
+    scale_w = w / small_w
+    # Calculate the coordinates of the block
+    y1 = math.floor(which_y * bx / scale_h)
+    y2 = math.floor((which_y + 1) * bx / scale_h)
+    x1 = math.floor(which_x * by / scale_w)
+    x2 = math.floor((which_x + 1) * by / scale_w)
+
+    # Check if the coordinates are within the bounds of the image
+    while y2 > small_h:
+        which_y -= 1
+        y1 = math.floor(which_y * bx / scale_h)
+        y2 = math.floor((which_y + 1) * bx / scale_h)
+    while x2 > small_w:
+        which_x -= 1
+        x1 = math.floor(which_x * by / scale_w)
+        x2 = math.floor((which_x + 1) * by / scale_w)
+
+    st.session_state["which_x"] = which_x
+    st.session_state["which_y"] = which_y
+
+    # Draw a grid on the thumbnail
+    for i in range(0, small_h, int(bx // scale_h)):
+        cv2.line(thumbnail, (0, i), (small_w, i), (255, 255, 255), 1)
+    for i in range(0, small_w, int(by // scale_w)):
+        cv2.line(thumbnail, (i, 0), (i, small_h), (255, 255, 255), 1)
+
+    # Draw a rectangle around the selected block
+    cv2.rectangle(
+        thumbnail,
+        (x1, y1),
+        (x2, y2),
+        (0, 0, 255),
+        5,
+    )
+
+    st.write("### Select a block")
+
+    coordinates = streamlit_image_coordinates.streamlit_image_coordinates(
+        thumbnail, use_column_width=True
+    )
+
+    if coordinates:
+        which_x = math.floor((w * coordinates["x"] / coordinates["width"]) / bx)
+        which_y = math.floor((h * coordinates["y"] / coordinates["height"]) / by)
+        if which_x != st.session_state.get("which_x", 0):
+            st.session_state["which_x"] = which_x
+        if which_y != st.session_state.get("which_y", 0):
+            st.session_state["which_y"] = which_y
+
+        st.rerun()
+
+    image = blocks[which_y, which_x, 0]
+    with st.sidebar:
+        st.image(image, caption="Selected block", use_container_width=True)
+
+    st.session_state.image_inference = image
+    st.session_state.image_id = (
+        file_id
+        + str(which_x)
+        + str(which_y)
+        + str(bx)
+        + str(by)
+        + str(model_name)
+        + ("_finetuned" if finetuned else "")
+    )
+    col1, col2, col3 = st.columns([1, 1, 1])
+    start_inference()
+else:
+    st.switch_page("pages/1_Load.py")
+
+# Add a callback to mouse move event
diff --git a/dnafiber/ui/pages/3_Analysis.py b/dnafiber/ui/pages/3_Analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ecb08cc73ac4f574b3838453495aa7eaa17e7a5
--- /dev/null
+++ b/dnafiber/ui/pages/3_Analysis.py
@@ -0,0 +1,241 @@
+import streamlit as st
+import torch
+from dnafiber.ui.utils import get_image, get_multifile_image
+from dnafiber.deployment import MODELS_ZOO
+import pandas as pd
+import plotly.express as px
+from dnafiber.postprocess import refine_segmentation
+import torch.nn.functional as F
+from joblib import Parallel, delayed
+import time
+from catppuccin import PALETTE
+from dnafiber.deployment import _get_model
+from dnafiber.ui.inference import ui_inference_cacheless
+
+
+def plot_result(selected_category=None):
+    if st.session_state.get("results", None) is None or selected_category is None:
+        return
+    only_bilateral = st.checkbox(
+        "Show only bicolor fibers",
+        value=False,
+    )
+    remove_outliers = st.checkbox(
+        "Remove outliers",
+        value=True,
+        help="Remove outliers from the data",
+    )
+    reorder = st.checkbox(
+        "Reorder groups by median ratio",
+        value=True,
+    )
+    if remove_outliers:
+        min_ratio, max_ratio = st.slider(
+            "Ratio range",
+            min_value=0.0,
+            max_value=10.0,
+            value=(0.0, 5.0),
+            step=0.1,
+            help="Select the ratio range to display",
+        )
+    df = st.session_state.results.copy()
+
+    clean_df = df[["ratio", "image_name", "fiber_type"]].copy()
+    clean_df["Image"] = clean_df["image_name"]
+    clean_df["Fiber Type"] = clean_df["fiber_type"]
+    clean_df["Ratio"] = clean_df["ratio"]
+
+    if only_bilateral:
+        clean_df = clean_df[clean_df["Fiber Type"] == "double"]
+    if remove_outliers:
+        clean_df = clean_df[
+            (clean_df["Ratio"] >= min_ratio) & (clean_df["Ratio"] <= max_ratio)
+        ]
+
+    if selected_category:
+        clean_df = clean_df[clean_df["Image"].isin(selected_category)]
+
+    if not reorder:
+        clean_df["Image"] = pd.Categorical(
+            clean_df["Image"], categories=selected_category, ordered=True
+        )
+        clean_df.sort_values("Image", inplace=True)
+
+    if reorder:
+        image_order = (
+            clean_df.groupby("Image")["Ratio"]
+            .median()
+            .sort_values(ascending=True)
+            .index
+        )
+        clean_df["Image"] = pd.Categorical(
+            clean_df["Image"], categories=image_order, ordered=True
+        )
+        clean_df.sort_values("Image", inplace=True)
+
+    palette = [c.hex for c in PALETTE.latte.colors]
+
+    fig = px.violin(
+        clean_df,
+        y="Ratio",
+        x="Image",
+        color="Image",
+        box=True,  # draw box plot inside the violin
+        points="all",  # can be 'outliers', or False
+        color_discrete_sequence=palette,
+    )
+    st.plotly_chart(
+        fig,
+        use_container_width=True,
+    )
+
+
+def run_inference(model_name, pixel_size):
+    is_cuda_available = torch.cuda.is_available()
+    if "ensemble" in model_name:
+        # Parenthesized so that non-finetuned ensembles keep the base model names.
+        model = [
+            _ + ("_finetuned" if "finetuned" in model_name else "")
+            for _ in MODELS_ZOO.values()
+            if _ != "ensemble"
+        ]
+    else:
+        model = _get_model(
+            revision=model_name,
+            device="cuda" if is_cuda_available else "cpu",
+        )
+
+    my_bar = st.progress(0, text="Running segmentation...")
+    all_files = st.session_state.files_uploaded
+    all_results = dict(
+        FirstAnalog=[],
+        SecondAnalog=[],
+        length=[],
+        ratio=[],
+        image_name=[],
+        fiber_type=[],
+    )
+    for i, file in enumerate(all_files):
+        if isinstance(file, tuple):
+            # Use whichever channel is present to name the pair.
+            filename = file[1].name if file[0] is None else file[0].name
+            image = get_multifile_image(file)
+        else:
+            filename = file.name
+            image = get_image(
+                file, st.session_state.get("reverse_channels", False), file.file_id
+            )
+        start = time.time()
+        prediction = ui_inference_cacheless(
+            _model=model,
+            _image=image,
+            _device="cuda" if is_cuda_available else "cpu",
+            postprocess=False,
+        )
+        print(f"Prediction time: {time.time() - start:.2f} seconds for {filename}")
+        h, w = prediction.shape
+        start = time.time()
+        if h > 2048 or w > 2048:
+            # Extract blocks from the prediction
+            blocks = F.unfold(
+                torch.from_numpy(prediction).unsqueeze(0).float(),
+                kernel_size=(4096, 4096),
+                stride=(4096, 4096),
+            )
+            blocks = blocks.view(4096, 4096, -1).permute(2, 0, 1).byte().numpy()
+            results = Parallel(n_jobs=4)(
+                delayed(refine_segmentation)(block) for block in blocks
+            )
+            results = [x for xs in results for x in xs]
+
+        else:
+            results = refine_segmentation(prediction, fix_junctions=True)
+
+        print(f"Refinement time: {time.time() - start:.2f} seconds for {filename}")
+        results = [fiber for fiber in results if fiber.is_valid]
+        all_results["FirstAnalog"].extend([fiber.red * pixel_size for fiber in results])
+        all_results["SecondAnalog"].extend(
+            [fiber.green * pixel_size for fiber in results]
+        )
+        all_results["length"].extend(
+            [fiber.red * pixel_size + fiber.green * pixel_size for fiber in results]
+        )
+        all_results["ratio"].extend([fiber.ratio for fiber in results])
+        all_results["image_name"].extend([filename.split("-")[0] for _ in results])
+        all_results["fiber_type"].extend([fiber.fiber_type for fiber in results])
+
+        my_bar.progress((i + 1) / len(all_files), text=f"{filename} done")
+
+    st.session_state.results = pd.DataFrame.from_dict(all_results)
+
+    my_bar.empty()
+
+
+if st.session_state.get("files_uploaded", None):
+    run_segmentation = st.button("Run Segmentation", use_container_width=True)
+
+    with st.sidebar:
+        st.metric(
+            "Pixel size (µm)",
+            st.session_state.get("pixel_size", 0.13),
+        )
+
+        with st.expander("Model", expanded=True):
+            model_name = st.selectbox(
+                "Select a model",
+                list(MODELS_ZOO.keys()),
+                index=0,
+                help="Select a model to use for inference",
+            )
+            finetuned = st.checkbox(
+                "Use finetuned model",
+                value=True,
+                help="Use a finetuned model for inference",
+            )
+            col1, col2 = st.columns(2)
+            with col1:
+                st.write("Running on:")
+            with col2:
+                st.button(
+                    "GPU" if torch.cuda.is_available() else "CPU",
+                    disabled=True,
+                )
+
+    tab_segmentation, tab_charts = st.tabs(["Segmentation", "Charts"])
+
+    with tab_segmentation:
+        st.subheader("Segmentation")
+        if run_segmentation:
+            run_inference(
+                model_name=(MODELS_ZOO[model_name] + "_finetuned")
+                if finetuned
+                else MODELS_ZOO[model_name],
+                pixel_size=st.session_state.get("pixel_size", 0.13),
+            )
+            st.balloons()
+        if st.session_state.get("results", None) is not None:
+            st.write(
+                st.session_state.results,
+            )
+
+            st.download_button(
+                label="Download results",
+                data=st.session_state.results.to_csv(index=False).encode("utf-8"),
+                file_name="results.csv",
+                mime="text/csv",
+                use_container_width=True,
+            )
+    with tab_charts:
+        if st.session_state.get("results", None) is not None:
+            results = st.session_state.results
+
+            categories = results["image_name"].unique()
+            selected_category = st.multiselect(
+                "Select a category", categories, default=categories
+            )
+            plot_result(selected_category)
+
+else:
+    st.switch_page("pages/1_Load.py")
diff --git a/dnafiber/ui/streamlit/.config.toml b/dnafiber/ui/streamlit/.config.toml
new file mode 100644
index 0000000000000000000000000000000000000000..c1f975db792882b84e3cade9e34039f0de8b29f2
--- /dev/null
+++ b/dnafiber/ui/streamlit/.config.toml
@@ -0,0 +1,3 @@
+[server]
+headless = true
+port = 8503
\ No newline at end of file
diff --git a/dnafiber/ui/utils.py b/dnafiber/ui/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..558a25eeecbdc7feb6dcd5fb0633f9ab807613bb
--- /dev/null
+++ b/dnafiber/ui/utils.py
@@ -0,0 +1,179 @@
+import PIL.Image
+import streamlit as st
+from dnafiber.data.utils import read_czi, read_tiff, preprocess
+import cv2
+import numpy as np
+import math
+from dnafiber.deployment import _get_model
+from PIL import Image
+import io
+import base64
+
+MAX_WIDTH = 512
+MAX_HEIGHT = 512
+
+
+TYPE_MAPPING = {
+    0: "BG",
+    1: "SINGLE",
+    2: "BILATERAL",
+    3: "TRICOLOR",
+    4: "MULTICOLOR",
+}
+
+
+@st.cache_data
+def load_image(_filepath, id=None):
+    filename = str(_filepath.name)
+    if filename.endswith(".czi"):
+        return read_czi(_filepath)
+    elif filename.endswith(".tif") or filename.endswith(".tiff"):
+        return read_tiff(_filepath)
+    elif (
+        filename.endswith(".png")
+        or filename.endswith(".jpg")
+        or filename.endswith(".jpeg")
+    ):
+        image = PIL.Image.open(_filepath)
+        image = np.array(image)
+        return image
+    else:
+        raise NotImplementedError(f"File type {filename} is not supported yet")
+
+
+@st.cache_data
+def get_image(_filepath, reverse_channel, id):
+    filename = str(_filepath.name)
+    image = load_image(_filepath, id)
+    if (
+        filename.endswith(".czi")
+        or filename.endswith(".tif")
+        or filename.endswith(".tiff")
+    ):
+        image = preprocess(image, reverse_channel)
+        image = cv2.normalize(
+            image, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U
+        )
+    return image
+
+
+def get_multifile_image(_filepaths):
+    if _filepaths[0] is not None:
+        chan1 = get_image(_filepaths[0], False, _filepaths[0].file_id)
+        chan1 = cv2.cvtColor(chan1, cv2.COLOR_RGB2GRAY)
+        h, w = chan1.shape[:2]
+    else:
+        chan1 = None
+    if _filepaths[1] is not None:
+        chan2 = get_image(_filepaths[1], False, _filepaths[1].file_id)
+        chan2 = cv2.cvtColor(chan2, cv2.COLOR_RGB2GRAY)
+        h, w = chan2.shape[:2]
+    else:
+        chan2 = None
+
+    result = np.zeros((h, w, 3), dtype=np.uint8)
+
+    if chan1 is not None:
+        result[:, :, 0] = chan1
+    else:
+        result[:, :, 0] = chan2
+
+    if chan2 is not None:
+        result[:, :, 1] = chan2
+    else:
+        result[:, :, 1] = chan1
+
+    return result
+
+
+def numpy_to_base64_png(image_array):
+    """
+    Encodes a NumPy image array to a base64 string (PNG format).
+
+    Args:
+        image_array: A NumPy array representing the image.
+
+    Returns:
+        A base64 string representing the PNG image.
+ """ + # Convert NumPy array to PIL Image + image = Image.fromarray(image_array) + + # Create an in-memory binary stream + buffer = io.BytesIO() + + # Save the image to the buffer in PNG format + image.save(buffer, format="jpeg") + + # Get the byte data from the buffer + png_data = buffer.getvalue() + + # Encode the byte data to base64 + base64_encoded = base64.b64encode(png_data).decode() + + return f"data:image/jpeg;base64,{base64_encoded}" + + +@st.cache_data +def get_resized_image(_image, id): + h, w = _image.shape[:2] + if w > MAX_WIDTH: + scale = MAX_WIDTH / w + new_size = (int(w * scale), int(h * scale)) + resized_image = cv2.resize(_image, new_size, interpolation=cv2.INTER_NEAREST) + else: + resized_image = _image + if h > MAX_HEIGHT: + scale = MAX_HEIGHT / h + new_size = (int(w * scale), int(h * scale)) + resized_image = cv2.resize( + resized_image, new_size, interpolation=cv2.INTER_NEAREST + ) + else: + resized_image = resized_image + return resized_image + + +def bokeh_imshow(fig, image): + # image is a numpy array of shape (h, w, 3) or (h, w) of type uint8 + if len(image.shape) == 2: + # grayscale image + image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) + + # Convert to h*w with uint32 + img = np.empty((image.shape[0], image.shape[1]), dtype=np.uint32) + view = img.view(dtype=np.uint8).reshape((image.shape[0], image.shape[1], 4)) # RGBA + view[:, :, 0] = image[:, :, 0] + view[:, :, 1] = image[:, :, 1] + view[:, :, 2] = image[:, :, 2] + view[:, :, 3] = 255 # Alpha channel + fig.image_rgba(image=[img], x=0, y=0, dw=image.shape[1], dh=image.shape[0]) + + +@st.cache_resource +def get_model(device, revision=None): + return _get_model(revision=revision, device=device) + + +def pad_image_to_croppable(_image, bx, by, uid=None): + # Pad the image to be divisible by bx and by + h, w = _image.shape[:2] + if h % bx != 0: + pad_h = bx - (h % bx) + else: + pad_h = 0 + if w % by != 0: + pad_w = by - (w % by) + else: + pad_w = 0 + _image = cv2.copyMakeBorder( + _image, + math.ceil(pad_h / 2), + math.floor(pad_h / 2), + math.ceil(pad_w / 2), + math.floor(pad_w / 2), + cv2.BORDER_CONSTANT, + value=(0, 0, 0), + ) + return _image diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..21695529aa181a6efca6c8bc7006769727bcd1fb --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,39 @@ +[tool.poetry] +name = "dnafiber" +version = "0.2.0" +description = "" +authors = ["ClementPla