diff --git a/README.md b/README.md
index c0e7f1250da9c5b13a344240b6e708a520576d36..c8357a2347bd0b97667b8176fa02d6aac884cf1a 100644
--- a/README.md
+++ b/README.md
@@ -1,20 +1,59 @@
----
-title: DNAI
-emoji: 🚀
-colorFrom: red
-colorTo: red
-sdk: docker
-app_port: 8501
-tags:
-- streamlit
-pinned: false
-short_description: DNA Fiber semantic segmentation for replication assessment
-license: mit
----
-
-# Welcome to Streamlit!
-
-Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
-
-If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
-forums](https://discuss.streamlit.io).
+# DN-AI
+
+This is the official repository for DN-AI, an automated tool for measuring differentially labeled DNA replication in fluorescence microscopy images.
+
+DN-AI offers several solutions for biologists to measure DNA replication in fluorescence microscopy images, without requiring any programming skills. See the [Installation](#installation) section for instructions on how to install DN-AI.
+
+## Features
+
+- **Automated DNA replication measurement**: DN-AI uses a deep learning model to segment fluorescence microscopy images and automatically quantify DNA replication.
+- **User-friendly interface**: DN-AI provides a user-friendly web interface for uploading images and viewing the results. Both JPEG and TIFF images are supported.
+- **Batch processing**: DN-AI can process multiple images at once, making it easy to analyze large datasets. It also supports comparing ratios between different batches of images.
+
+
+## Installation
+
+DN-AI is written in Python. We recommend a recent version (3.10 or higher) and a virtual environment to avoid conflicts with other packages.
+
+### Prerequisites
+Before installing DN-AI, make sure you have the following prerequisites installed:
+- [Python 3.10 or higher](https://www.python.org/downloads/)
+- [pip](https://pip.pypa.io/en/stable/installation/) (Python package installer)
+
+### Python Package
+To install DN-AI as a Python package, use pip:
+
+```bash
+pip install git+https://github.com/ClementPla/DeepFiberQ.git
+```
+
+
+### Graphical User Interface (GUI)
+
+To launch the DN-AI graphical user interface, run:
+
+```bash
+DNAI
+```
+
+Make sure you run this command in the terminal (or activated virtual environment) where DN-AI is installed. It starts a local web server and prints the local URL to the terminal.
+
+Then open your web browser and go to `http://localhost:8501` to access the DN-AI interface.
+
+Screenshots of the GUI:
+
+![DN-AI GUI](imgs/screenshot.png)
+
+
+
+### Docker
+A Docker image is available for DN-AI. You can pull it from Docker Hub:
+
+```bash
+docker pull clementpla/dnafiber
+```
+
+### Google Colab
+We also provide a Google Colab notebook for DN-AI. You can access it [here](https://colab.research.google.com/github/ClementPla/DeepFiberQ/blob/main/Colab/DNA_Fiber_Q.ipynb).
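+
+### Python API
+
+For scripted analyses, the segmentation model and inference helpers can also be called directly from Python. The sketch below uses the functions exposed by this package (`_get_model`, `infer`, `refine_segmentation` and `format_results`); the image path and the 0.13 µm/pixel value are placeholders to adapt to your own data:
+
+```python
+import cv2
+import torch
+
+from dnafiber.deployment import _get_model, format_results
+from dnafiber.inference import infer
+from dnafiber.postprocess import refine_segmentation
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Download the default pretrained weights from the Hugging Face Hub
+model = _get_model(revision=None, device=device)
+
+# OpenCV loads BGR images; the model expects RGB
+image = cv2.cvtColor(cv2.imread("my_image.jpg"), cv2.COLOR_BGR2RGB)
+
+# Pixel-wise segmentation (0: background, 1: first analog, 2: second analog)
+segmentation = infer(model, image, device=device, scale=0.13)
+
+# Split crossing fibers and compute per-fiber measurements
+fibers = refine_segmentation(segmentation)
+df = format_results(fibers, pixel_size=0.13)
+print(df.head())
+```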
+ diff --git a/dnafiber/__init__.py b/dnafiber/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0af7c0bdc39e9496dd9ce13b58670c432dc7ed69 --- /dev/null +++ b/dnafiber/__init__.py @@ -0,0 +1 @@ +from dnafiber.deployment import _get_model diff --git a/dnafiber/__pycache__/__init__.cpython-312.pyc b/dnafiber/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbd6e37a07fcfa7a2643ec7943b810897fa211ed Binary files /dev/null and b/dnafiber/__pycache__/__init__.cpython-312.pyc differ diff --git a/dnafiber/__pycache__/deployment.cpython-312.pyc b/dnafiber/__pycache__/deployment.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..770cfcbe5b6ddd3b0a8e57b2c67c824592159a13 Binary files /dev/null and b/dnafiber/__pycache__/deployment.cpython-312.pyc differ diff --git a/dnafiber/__pycache__/inference.cpython-312.pyc b/dnafiber/__pycache__/inference.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed38d561a723dbdd96b13e5365f1c844c6ffe8db Binary files /dev/null and b/dnafiber/__pycache__/inference.cpython-312.pyc differ diff --git a/dnafiber/__pycache__/metric.cpython-312.pyc b/dnafiber/__pycache__/metric.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b7d15adb0604ccf6dde5ba60d05491250742f34 Binary files /dev/null and b/dnafiber/__pycache__/metric.cpython-312.pyc differ diff --git a/dnafiber/__pycache__/post_process.cpython-312.pyc b/dnafiber/__pycache__/post_process.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6ca6169409036efc2298fa2307b81341de9a9d2 Binary files /dev/null and b/dnafiber/__pycache__/post_process.cpython-312.pyc differ diff --git a/dnafiber/__pycache__/trainee.cpython-312.pyc b/dnafiber/__pycache__/trainee.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba7d271e535bd6475e0baa6bd77b37c483730320 Binary files /dev/null and b/dnafiber/__pycache__/trainee.cpython-312.pyc differ diff --git a/dnafiber/analysis/__init__.py b/dnafiber/analysis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dnafiber/analysis/chart.py b/dnafiber/analysis/chart.py new file mode 100644 index 0000000000000000000000000000000000000000..c60087e4b1bb890a7697cb5adbea545f57df380e --- /dev/null +++ b/dnafiber/analysis/chart.py @@ -0,0 +1,61 @@ +import pandas as pd +from dnafiber.analysis.const import palette +import plotly.express as px + + +def get_color_association(df): + """ + Get the color association for each image in the dataframe. 
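+    Colors are drawn in order from the Catppuccin Latte palette, so each
+    unique image name receives a distinct, stable color.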
+ """ + unique_name = df["image_name"].unique() + color_association = {i: p for (i, p) in zip(unique_name, palette)} + return color_association + + +def plot_ratio(df, color_association=None, only_bilateral=True): + df = df[["ratio", "image_name", "fiber_type"]].copy() + + df["Image"] = df["image_name"] + df["Fiber Type"] = df["fiber_type"] + df["Ratio"] = df["ratio"] + if only_bilateral: + df = df[df["Fiber Type"] == "double"] + + df = df.sort_values( + by=["Image", "Fiber Type"], + ascending=[True, True], + ) + + # Order the dataframe by the average ratio of each image + image_order = ( + df.groupby("Image")["Ratio"].median().sort_values(ascending=True).index + ) + df["Image"] = pd.Categorical(df["Image"], categories=image_order, ordered=True) + df.sort_values("Image", inplace=True) + if color_association is None: + color_association = get_color_association(df) + unique_name = df["image_name"].unique() + color_association = {i: p for (i, p) in zip(unique_name, palette)} + + this_palette = [color_association[i] for i in unique_name] + fig = px.violin( + df, + y="Ratio", + x="Image", + color="Image", + color_discrete_sequence=this_palette, + box=True, # draw box plot inside the violin + points="all", # can be 'outliers', or False + ) + + # Make the fig taller + + fig.update_layout( + height=500, + width=1000, + title="Ratio of green to red", + yaxis_title="Ratio", + xaxis_title="Image", + legend_title="Image", + ) + return fig diff --git a/dnafiber/analysis/const.py b/dnafiber/analysis/const.py new file mode 100644 index 0000000000000000000000000000000000000000..21121587325a77401fc3cd26ac9164b0a92a0ae3 --- /dev/null +++ b/dnafiber/analysis/const.py @@ -0,0 +1,3 @@ +from catppuccin.palette import PALETTE + +palette = [c.hex for c in PALETTE.latte.colors] diff --git a/dnafiber/analysis/utils.py b/dnafiber/analysis/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8a158f17ee2eedb2f461d14042ec2e15a1cbd6a5 --- /dev/null +++ b/dnafiber/analysis/utils.py @@ -0,0 +1,21 @@ +from tqdm.auto import tqdm +from dnafiber.data.utils import read_colormask +import numpy as np + + +def build_consensus_map(intergraders, root_img, list_img): + all_masks = [] + for img_path in tqdm(list_img): + path_from_root = img_path.relative_to(root_img) + masks = [] + for intergrader in intergraders: + intergrader_path = (intergrader / path_from_root).with_suffix(".png") + if not intergrader_path.exists(): + print(f"Missing {intergrader_path}") + continue + mask = read_colormask(intergrader_path) + masks.append(mask) + masks = np.array(masks) + + all_masks.append(masks) + return np.array(all_masks) diff --git a/dnafiber/callbacks.py b/dnafiber/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..d58d978a24c50b07c45e729faccfef8d45e709b3 --- /dev/null +++ b/dnafiber/callbacks.py @@ -0,0 +1,50 @@ +from lightning.pytorch.callbacks import Callback +from pytorch_lightning.utilities import rank_zero_only +import wandb + + +class LogPredictionSamplesCallback(Callback): + def __init__(self, wandb_logger, n_images=8): + self.n_images = n_images + self.wandb_logger = wandb_logger + super().__init__() + + @rank_zero_only + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if batch_idx < 1 and trainer.is_global_zero: + n = self.n_images + x = batch["image"][:n].float() + h, w = x.shape[-2:] + y = batch["mask"][:n] + pred = outputs[:n] + pred = pred.argmax(dim=1) + + if len(y.shape) == 4: + y = y.squeeze(1) + if len(pred.shape) == 4: + pred = 
pred.squeeze(1) + y = y.clamp(0, 2) + columns = ["image"] + class_labels = {0: "Background", 1: "Red", 2: "Green"} + + data = [ + [ + wandb.Image( + x_i, + masks={ + "Prediction": { + "mask_data": p_i.cpu().numpy(), + "class_labels": class_labels, + }, + "Groundtruth": { + "mask_data": y_i.cpu().numpy(), + "class_labels": class_labels, + }, + }, + ) + ] + for x_i, y_i, p_i in list(zip(x, y, pred)) + ] + self.wandb_logger.log_table( + data=data, key=f"Validation Batch {batch_idx}", columns=columns + ) diff --git a/dnafiber/data/__init__.py b/dnafiber/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dnafiber/data/__pycache__/__init__.cpython-312.pyc b/dnafiber/data/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c1bc70ad0f1c2d99af47542e2f3477e3c601268 Binary files /dev/null and b/dnafiber/data/__pycache__/__init__.cpython-312.pyc differ diff --git a/dnafiber/data/__pycache__/utils.cpython-312.pyc b/dnafiber/data/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0987d53b295fe7c72074cae27c9651ea3dcb170d Binary files /dev/null and b/dnafiber/data/__pycache__/utils.cpython-312.pyc differ diff --git a/dnafiber/data/dataset.py b/dnafiber/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..550497f0489d51da06bfa610af692bf86896f4b9 --- /dev/null +++ b/dnafiber/data/dataset.py @@ -0,0 +1,271 @@ +import albumentations as A +import nntools.dataset as D +import numpy as np +from albumentations.pytorch import ToTensorV2 +from lightning import LightningDataModule +from sklearn.model_selection import train_test_split +from torch.utils.data import DataLoader +from skimage.measure import label, regionprops +from skimage.morphology import skeletonize, dilation +from skimage.segmentation import expand_labels +import torch +from nntools.dataset.composer import CacheBullet + + +@D.nntools_wrapper +def convert_mask(mask): + output = np.zeros(mask.shape[:2], dtype=np.uint8) + output[mask[:, :, 0] > 200] = 1 + output[mask[:, :, 1] > 200] = 2 + binary_mask = output > 0 + skeleton = skeletonize(binary_mask) * output + output = expand_labels(skeleton, 3) + output = np.clip(output, 0, 2) + return {"mask": output} + + +@D.nntools_wrapper +def extract_bbox(mask): + binary_mask = mask > 0 + labelled = label(binary_mask) + props = regionprops(labelled, intensity_image=mask) + skeleton = skeletonize(binary_mask) * mask + mask = dilation(skeleton, np.ones((3, 3))) + bboxes = [] + masks = [] + # We want the XYXY format + for prop in props: + minr, minc, maxr, maxc = prop.bbox + bboxes.append([minc, minr, maxc, maxr]) + masks.append((labelled == prop.label).astype(np.uint8)) + if not masks: + masks = np.zeros_like(mask)[np.newaxis, :, :] + masks = np.array(masks) + masks = np.moveaxis(masks, 0, -1) + + return { + "bboxes": np.array(bboxes), + "mask": masks, + "fiber_ids": np.array([p.label for p in props]), + } + + +class FiberDatamodule(LightningDataModule): + def __init__( + self, + root_img, + crop_size=(256, 256), + shape=1024, + batch_size=32, + num_workers=8, + use_bbox=False, + **kwargs, + ): + self.shape = shape + self.root_img = str(root_img) + self.crop_size = crop_size + self.batch_size = batch_size + self.num_workers = num_workers + self.kwargs = kwargs + self.use_bbox = use_bbox + + super().__init__() + + def setup(self, *args, **kwargs): + def _get_dataset(version): + dataset = 
D.MultiImageDataset( + { + "image": f"{self.root_img}/{version}/images/", + "mask": f"{self.root_img}/{version}/annotations/", + }, + shape=(self.shape, self.shape), + use_cache=self.kwargs.get("use_cache", False), + cache_option=self.kwargs.get("cache_option", None), + ) # type: ignore + dataset.img_filepath["image"] = np.asarray( # type: ignore + sorted( + list(dataset.img_filepath["image"]), + key=lambda x: (x.parent.stem, x.stem), + ) + ) + dataset.img_filepath["mask"] = np.asarray( # type: ignore + sorted( + list(dataset.img_filepath["mask"]), + key=lambda x: (x.parent.stem, x.stem), + ) + ) + dataset.composer = D.Composition() + dataset.composer << convert_mask # type: ignore + if self.use_bbox: + dataset.composer << extract_bbox + + return dataset + + self.train = _get_dataset("train") + self.val = _get_dataset("train") + self.test = _get_dataset("test") + self.train.composer << CacheBullet() + self.val.use_cache = False + self.test.use_cache = False + + stratify = [] + for f in self.train.img_filepath["image"]: + if "tile" in f.stem: + stratify.append(int(f.parent.stem)) + else: + stratify.append(25) + train_idx, val_idx = train_test_split( + np.arange(len(self.train)), # type: ignore + stratify=stratify, + test_size=0.2, + random_state=42, + ) + self.train.subset(train_idx) + self.val.subset(val_idx) + + self.train.composer.add(*self.get_train_composer()) + self.val.composer.add(*self.cast_operators()) + self.test.composer.add(*self.cast_operators()) + + def get_train_composer(self): + transforms = [] + if self.crop_size is not None: + transforms.append( + A.CropNonEmptyMaskIfExists( + width=self.crop_size[0], height=self.crop_size[1] + ), + ) + return [ + A.Compose( + transforms + + [ + A.HorizontalFlip(), + A.VerticalFlip(), + A.Affine(), + A.ElasticTransform(), + A.RandomRotate90(), + A.OneOf( + [ + A.RandomBrightnessContrast( + brightness_limit=(-0.2, 0.1), + contrast_limit=(-0.2, 0.1), + p=0.5, + ), + A.HueSaturationValue( + hue_shift_limit=(-5, 5), + sat_shift_limit=(-20, 20), + val_shift_limit=(-20, 20), + p=0.5, + ), + ] + ), + A.GaussNoise(std_range=(0.0, 0.1), p=0.5), + ], + bbox_params=A.BboxParams( + format="pascal_voc", label_fields=["fiber_ids"], min_visibility=0.95 + ) + if self.use_bbox + else None, + ), + *self.cast_operators(), + ] + + def cast_operators(self): + return [ + A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) + if not self.use_bbox + else A.Normalize( + mean=( + 0.0, + 0.0, + 0.0, + ), + std=(1.0, 1.0, 1.0), + max_pixel_value=255, + ), + ToTensorV2(), + ] + + def train_dataloader(self): + if self.use_bbox: + return DataLoader( + self.train, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + pin_memory=True, + persistent_workers=True, + collate_fn=bbox_collate_fn, + ) + + else: + return DataLoader( + self.train, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + pin_memory=True, + persistent_workers=True, + ) + + def val_dataloader(self): + if self.use_bbox: + return DataLoader( + self.val, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + pin_memory=True, + persistent_workers=True, + collate_fn=bbox_collate_fn, + ) + return DataLoader( + self.val, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + def test_dataloader(self): + if self.use_bbox: + return DataLoader( + self.test, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + pin_memory=True, + persistent_workers=True, + 
collate_fn=bbox_collate_fn,
+            )
+        return DataLoader(
+            self.test,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+        )
+
+
+def bbox_collate_fn(batch):
+    images = []
+    targets = []
+
+    for b in batch:
+        target = dict()
+
+        images.append(b["image"])
+        target["boxes"] = torch.from_numpy(b["bboxes"])
+        if target["boxes"].shape[0] == 0:
+            # Negative samples need an empty (0, 4) float tensor, not an empty array
+            target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
+        target["masks"] = b["mask"].permute(2, 0, 1)
+        if target["boxes"].shape[0] == 0:
+            # Labels must match the number of boxes, so use an empty tensor here
+            target["labels"] = torch.zeros(0, dtype=torch.int64)
+        else:
+            target["labels"] = torch.ones_like(target["boxes"][:, 0], dtype=torch.int64)
+
+        targets.append(target)
+
+    return {
+        "image": torch.stack(images),
+        "targets": targets,
+    }
diff --git a/dnafiber/data/intergrader/__init__.py b/dnafiber/data/intergrader/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..983f876cace65cecee3b1587f6dd49e086d83a48
--- /dev/null
+++ b/dnafiber/data/intergrader/__init__.py
@@ -0,0 +1 @@
+from .const import *
\ No newline at end of file
diff --git a/dnafiber/data/intergrader/__pycache__/__init__.cpython-312.pyc b/dnafiber/data/intergrader/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7d74f18a56e8ec761c0df6edbdf0181283aa721e
Binary files /dev/null and b/dnafiber/data/intergrader/__pycache__/__init__.cpython-312.pyc differ
diff --git a/dnafiber/data/intergrader/__pycache__/analysis.cpython-312.pyc b/dnafiber/data/intergrader/__pycache__/analysis.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b4e41ac0c51b247d8b9fcbcb5c205c743d70bb3d
Binary files /dev/null and b/dnafiber/data/intergrader/__pycache__/analysis.cpython-312.pyc differ
diff --git a/dnafiber/data/intergrader/__pycache__/const.cpython-312.pyc b/dnafiber/data/intergrader/__pycache__/const.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..53bed8195ecccacbd65f01f998d2e8d2ee5b3ae3
Binary files /dev/null and b/dnafiber/data/intergrader/__pycache__/const.cpython-312.pyc differ
diff --git a/dnafiber/data/intergrader/__pycache__/io.cpython-312.pyc b/dnafiber/data/intergrader/__pycache__/io.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bba9c87fd011f943fa8c46bff5e35e2410dafed2
Binary files /dev/null and b/dnafiber/data/intergrader/__pycache__/io.cpython-312.pyc differ
diff --git a/dnafiber/data/intergrader/__pycache__/plot.cpython-312.pyc b/dnafiber/data/intergrader/__pycache__/plot.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e48b3022a399c90bd5334dcbe9b854f75719080
Binary files /dev/null and b/dnafiber/data/intergrader/__pycache__/plot.cpython-312.pyc differ
diff --git a/dnafiber/data/intergrader/analysis.py b/dnafiber/data/intergrader/analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9b7fa9e3db75332c82f5998ee3dcc0bb417874c
--- /dev/null
+++ b/dnafiber/data/intergrader/analysis.py
@@ -0,0 +1,120 @@
+from skimage.morphology import skeletonize
+import numpy as np
+from skimage.measure import label
+from tqdm.contrib.concurrent import thread_map  # or process_map
+
+
+def extract_fiber_properties(mask):
+    binary_mask = mask > 0
+    skeleton = skeletonize(binary_mask)
+    r = mask == 1
+    g = mask == 2
+    labeled_skeleton = label(skeleton, connectivity=2)
+    properties = {"R": [], "G": [], "ratio": [], "ids": []}
+    for i in range(1, labeled_skeleton.max() + 1):
+        fiber_mask = labeled_skeleton == i
+        sum_r = np.sum(r & fiber_mask)
+        sum_g = np.sum(g & fiber_mask)
+        if sum_r == 0 or sum_g == 0:
+            continue
+        properties["R"].append(sum_r)
+        properties["G"].append(sum_g)
+        # Remember which skeleton label each kept fiber came from
+        properties["ids"].append(i)
+
+    properties["R"] = np.array(properties["R"])
+    properties["G"] = np.array(properties["G"])
+    properties["ids"] = np.array(properties["ids"], dtype=int)
+    properties["ratio"] = properties["R"] / (properties["G"])
+    properties["label"] = labeled_skeleton
+    return properties
+
+
+def filter_non_commons_fibers(properties):
+    # Properties is a list of dicts. Each dict holds a labelmap and arrays of
+    # reds, greens, ratios and skeleton label ids.
+    # We want to filter out the fibers that are not common to all images.
+
+    binary_labels = [p["label"] > 0 for p in properties]
+    common_labels = np.logical_and.reduce(binary_labels)
+    filtered_properties = []
+    for p in properties:
+        # Keep only the labels whose skeleton overlaps the region common to all images
+        good_labels = common_labels * p["label"]
+        indices = np.unique(good_labels[good_labels > 0])
+        keep = np.isin(p["ids"], indices)
+        filtered_properties.append(
+            {
+                "R": p["R"][keep],
+                "G": p["G"][keep],
+                "ratio": p["ratio"][keep],
+                "ids": p["ids"][keep],
+                "label": p["label"],
+            }
+        )
+    return filtered_properties
+
+
+def skeletonize_mask(mask):
+    # Skeletonize the mask and return the skeleton, keeping the color labels
+    binary_mask = mask > 0
+    skeleton = skeletonize(binary_mask) * mask
+    return skeleton
+
+
+def skeletonize_data_dict(data_dict):
+    skeletons = dict()
+    for annotator, images in data_dict.items():
+        skeletons[annotator] = dict()
+        for image_type, masks in images.items():
+            skeletons[annotator][image_type] = thread_map(
+                skeletonize_mask, masks, max_workers=8
+            )
+
+    return skeletons
+
+
+def extract_properties_from_datadict(data_dict, with_common_analysis=True):
+    """
+    Extract the properties of the fibers from the data dictionary.
+    The data dictionary is a dict of annotators. Each value is a dict of images. Each image is a list of masks.
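+    Returns a flat dict of per-fiber records (annotator, image_type, red, green,
+    ratio, fiber_type, plus one boolean column per annotator indicating whether
+    the fiber overlaps that annotator's mask), suitable for pandas.DataFrame.from_dict.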
+ """ + properties = dict(annotator=[], image_type=[], red=[], green=[], ratio=[], fiber_type=[]) + all_annotators = list(data_dict.keys()) + + found_by = {a: [] for a in all_annotators} + properties.update(found_by) + for annotator, images in data_dict.items(): + for image_type, masks in images.items(): + for i, mask in enumerate(masks): + if with_common_analysis: + others_masks = [] + other_annotators = [] + for other in all_annotators: + if other == annotator: + continue + other_annotators.append(other) + others_masks.append(data_dict[other][image_type][i] > 0) + + labels, num = label(mask>0, connectivity=2, return_num=True) + for l in range(1, num + 1): + fiber = labels == l + if np.sum(fiber) < 10: + continue + + properties["annotator"].append(annotator) + properties["image_type"].append(image_type) + + # Check for common fibers + properties[annotator].append(True) + if with_common_analysis: + for i, (other_mask, other_annotator) in enumerate(zip(others_masks, other_annotators)): + properties[other_annotator].append(np.any(fiber & other_mask)) + + red_length = np.sum(fiber & (mask == 1)) + green_length = np.sum(fiber & (mask == 2)) + if red_length == 0 or green_length == 0: + continue + properties["ratio"].append(green_length / (red_length + 1e-7)) # Avoid division by zero + properties["red"].append(red_length) + properties["green"].append(green_length) + + segments, count = label(mask[fiber], connectivity=1, return_num=True) + if count == 1: + properties["fiber_type"].append("single") + elif count == 2: + properties["fiber_type"].append("double") + elif count > 2: + properties["fiber_type"].append("multiple") + else: + properties["fiber_type"].append("unknown") + + return properties \ No newline at end of file diff --git a/dnafiber/data/intergrader/auto.py b/dnafiber/data/intergrader/auto.py new file mode 100644 index 0000000000000000000000000000000000000000..ed06ec6053d499748f83324a4df42547c112fd66 --- /dev/null +++ b/dnafiber/data/intergrader/auto.py @@ -0,0 +1,3 @@ +def inference_model(model, path, use_cuda=False): + pass + \ No newline at end of file diff --git a/dnafiber/data/intergrader/const.py b/dnafiber/data/intergrader/const.py new file mode 100644 index 0000000000000000000000000000000000000000..d51eaa079e3d4fb8782b3e1c149aaf78b900a5e0 --- /dev/null +++ b/dnafiber/data/intergrader/const.py @@ -0,0 +1,21 @@ +BLIND_MAPPING = { + "siB+M-01": "0", + "siB+M-04": "1", + "siBRCA2-02": "5", + "siBRCA2-03": "15", + "siTONSL-03": "11", + "siTONSL-04": "14", + "HLTF ko+si MMS22L-01": "8", + "HLTF ko+si MMS22L-02": "13", + "siBRCA2+SMARCAL KO-01": "2", + "siBRCA2+SMARCAL KO-03": "9", + "siBRCA2+SMARCAL KO-04": "16", + "siBRCA2-01": "4", + "59_siBRCA2-02": "7", + "siNT-01": "10", + "siNT-02": "12", + "siMMS22L_+dox-01": "3", + "siMMS22L_+dox-02": "6", +} + +REVERSE_BLIND_MAPPING = {v: k for k, v in BLIND_MAPPING.items()} \ No newline at end of file diff --git a/dnafiber/data/intergrader/io.py b/dnafiber/data/intergrader/io.py new file mode 100644 index 0000000000000000000000000000000000000000..bdb3aa39f9e8c4f0d59e4e4130a25c80e6a080b6 --- /dev/null +++ b/dnafiber/data/intergrader/io.py @@ -0,0 +1,27 @@ +import cv2 +import numpy as np +from skimage.segmentation import expand_labels + +def read_to_mask(f): + img = cv2.imread(str(f), cv2.IMREAD_UNCHANGED)[:,:,::-1] + mask = np.zeros(img.shape[:2], dtype=np.uint8) + mask[img[:, :, 0] > 200] = 1 + mask[img[:, :, 1] > 200] = 2 + + return mask + + +def read_mask_from_path_gens(dict_gens, mapping=None): + output = {k: dict() for k in 
dict_gens.keys()} + for k, files in dict_gens.items(): + for file in files: + name = file.parent.stem + if mapping is not None: + name = mapping.get(name, name) + mask = read_to_mask(file) + mask = expand_labels(mask, 1) + if output[k].get(name) is None: + output[k][name] = [] + output[k][name].append(mask) + return output + diff --git a/dnafiber/data/intergrader/plot.py b/dnafiber/data/intergrader/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..c8d185fcddee4f037dc873d03e093fdb5d1278ea --- /dev/null +++ b/dnafiber/data/intergrader/plot.py @@ -0,0 +1,172 @@ +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.colors import ListedColormap +from skimage.measure import label, regionprops +import base64 +from typing import Callable + +def imshow_compare(data_dict, ax_size=4, draw_bbox=False, max_images=None): + """ + Display the images in a grid format for comparison. + Each key is an annotator, each value is another dict, where the key is the image type and the value the list of corresponding images. + """ + # 0 is black, 1 is red, 2 is green + cmap = ListedColormap(['black', 'red', 'green']) + + # Convert the data dictionary to a dict of annotators: list of images + data = dict() + for annotator, images in data_dict.items(): + if annotator not in data: + data[annotator] = [] + for image_type, masks in images.items(): + for mask in masks: + data[annotator].append(mask) + + annotators = list(data.keys()) + num_images = len(data[annotators[0]]) + if max_images is not None and num_images > max_images: + num_images = max_images + num_annotators = len(annotators) + + fig_size = (ax_size * num_annotators, ax_size * num_images) + fig, axes = plt.subplots(num_images, num_annotators, figsize=fig_size, squeeze=False) + + for i, annotator in enumerate(annotators): + for j in range(num_images): + if max_images is not None and j > max_images: + break + ax = axes[j, i] + mask = data[annotator][j] + ax.imshow(mask, cmap=cmap, interpolation='nearest') + ax.axis('off') + ax.set_xticks([]) + ax.set_yticks([]) + if draw_bbox: + mask = mask > 0 + labeled_mask = label(mask, connectivity=2) + regions = regionprops(labeled_mask) + for region in regions: + minr, minc, maxr, maxc = region.bbox + rect = plt.Rectangle((minc, minr), maxc - minc, maxr - minr, + fill=False, edgecolor='yellow', linewidth=0.5) + ax.add_patch(rect) + + + + if j == 0: + ax.set_title(annotator) + + + fig.tight_layout() + return fig, axes + + +def add_p_value_annotation(fig, array_columns, stats_test, subplot=None, _format=dict(interline=0.07, text_height=1.07, color='black')): + ''' Adds notations giving the p-value between two box plot data (t-test two-sided comparison) + + Parameters: + ---------- + fig: figure + plotly boxplot figure + array_columns: np.array + array of which columns to compare + e.g.: [[0,1], [1,2]] compares column 0 with 1 and 1 with 2 + subplot: None or int + specifies if the figures has subplots and what subplot to add the notation to + _format: dict + format characteristics for the lines + + Returns: + ------- + fig: figure + figure with the added notation + ''' + # Specify in what y_range to plot for each pair of columns + y_range = np.zeros([len(array_columns), 2]) + for i in range(len(array_columns)): + y_range[i] = [1.01+i*_format['interline'], 1.02+i*_format['interline']] + + # Get values from figure + fig_dict = fig.to_dict() + # Get indices if working with subplots + if subplot: + if subplot == 1: + subplot_str = '' + else: + subplot_str =str(subplot) + indices = [] 
# Change the box index to the indices of the data for that subplot
+        for index, data in enumerate(fig_dict['data']):
+            if data['xaxis'] == 'x' + subplot_str:
+                indices = np.append(indices, index)
+        indices = [int(i) for i in indices]
+    else:
+        subplot_str = ''
+
+    # Annotate the p-values
+    for index, column_pair in enumerate(array_columns):
+        if subplot:
+            data_pair = [indices[column_pair[0]], indices[column_pair[1]]]
+        else:
+            data_pair = column_pair
+
+        # Make sure it is selecting the data and subplot you want
+        if isinstance(stats_test, Callable):
+            # Get the p-value
+            d1 = fig_dict['data'][data_pair[0]]['y']
+            d2 = fig_dict['data'][data_pair[1]]['y']
+            d1 = base64.b64decode(d1['bdata'])
+            d2 = base64.b64decode(d2['bdata'])
+            d1 = np.frombuffer(d1, dtype=np.float64)
+            d2 = np.frombuffer(d2, dtype=np.float64)
+            pvalue = stats_test(
+                d1,
+                d2,
+            )[1]
+        else:
+            pvalue = stats_test[index]
+        if pvalue >= 0.05:
+            symbol = 'ns'
+        elif pvalue >= 0.01:
+            symbol = '*'
+        elif pvalue >= 0.001:
+            symbol = '**'
+        else:
+            symbol = '***'
+        # Vertical line
+        fig.add_shape(type="line",
+            xref="x"+subplot_str, yref="y"+subplot_str+" domain",
+            x0=column_pair[0], y0=y_range[index][0],
+            x1=column_pair[0], y1=y_range[index][1],
+            line=dict(color=_format['color'], width=2,)
+        )
+        # Horizontal line
+        fig.add_shape(type="line",
+            xref="x"+subplot_str, yref="y"+subplot_str+" domain",
+            x0=column_pair[0], y0=y_range[index][1],
+            x1=column_pair[1], y1=y_range[index][1],
+            line=dict(color=_format['color'], width=2,)
+        )
+        # Vertical line
+        fig.add_shape(type="line",
+            xref="x"+subplot_str, yref="y"+subplot_str+" domain",
+            x0=column_pair[1], y0=y_range[index][0],
+            x1=column_pair[1], y1=y_range[index][1],
+            line=dict(color=_format['color'], width=2,)
+        )
+        ## add text at the correct x, y coordinates
+        ## for bars, there is a direct mapping from the bar number to 0, 1, 2...
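+        ## the symbol is centered between the two compared columns, just above
+        ## the bracket drawn by the three line shapes above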
+ fig.add_annotation(dict(font=dict(color=_format['color'],size=14), + x=(column_pair[0] + column_pair[1])/2, + y=y_range[index][1]*_format['text_height'], + showarrow=False, + text=symbol, + textangle=0, + xref="x"+subplot_str, + yref="y"+subplot_str+" domain" + )) + return fig \ No newline at end of file diff --git a/dnafiber/data/utils.py b/dnafiber/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..48b48cb11d7124bcc7890d1a24d99537d1d56696 --- /dev/null +++ b/dnafiber/data/utils.py @@ -0,0 +1,80 @@ +import base64 + +from xml.dom import minidom +import cv2 +import numpy as np +from czifile import CziFile +from tifffile import imread + + +def read_svg(svg_path): + doc = minidom.parse(str(svg_path)) + img_strings = { + path.getAttribute("id"): path.getAttribute("href") + for path in doc.getElementsByTagName("image") + } + doc.unlink() + + red = img_strings["Red"] + green = img_strings["Green"] + red = base64.b64decode(red.split(",")[1]) + green = base64.b64decode(green.split(",")[1]) + red = cv2.imdecode(np.frombuffer(red, dtype=np.uint8), cv2.IMREAD_UNCHANGED) + green = cv2.imdecode(np.frombuffer(green, dtype=np.uint8), cv2.IMREAD_UNCHANGED) + + red = cv2.cvtColor(red, cv2.COLOR_BGRA2GRAY) + green = cv2.cvtColor(green, cv2.COLOR_BGRA2GRAY) + mask = np.zeros_like(red) + mask[red > 0] = 1 + mask[green > 0] = 2 + return mask + + +def extract_bboxes(mask): + mask = np.array(mask) + mask = mask.astype(np.uint8) + + # Find connected components + num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats( + mask, connectivity=8 + ) + bboxes = [] + for i in range(1, num_labels): + x, y, w, h, area = stats[i] + bboxes.append([x, y, x + w, y + h]) + return bboxes + + +def preprocess(raw_data, reverse_channels=False): + MAX_VALUE = 2**16 - 1 + if raw_data.ndim == 2: + raw_data = raw_data[np.newaxis, :, :] + h, w = raw_data.shape[1:3] + orders = np.arange(raw_data.shape[0])[::-1] # Reverse channel order + result = np.zeros((h, w, 3), dtype=np.uint8) + + for i, chan in enumerate(raw_data): + hist, bins = np.histogram(chan.ravel(), MAX_VALUE + 1, (0, MAX_VALUE + 1)) + cdf = hist.cumsum() + cdf_normalized = cdf / cdf[-1] + bmax = np.searchsorted(cdf_normalized, 0.99, side="left") + clip = np.clip(chan, 0, bmax).astype(np.float32) + clip = (clip - clip.min()) / (bmax - clip.min()) * 255 + result[:, :, orders[i]] = clip + if reverse_channels: + # Reverse channels 0 and 1 + result = result[:, :, [1, 0, 2]] + return result + + +def read_czi(filepath): + data = CziFile(filepath) + + return data.asarray().squeeze() + + +def read_tiff(filepath): + + data = imread(filepath).squeeze() + + return data \ No newline at end of file diff --git a/dnafiber/deployment.py b/dnafiber/deployment.py new file mode 100644 index 0000000000000000000000000000000000000000..0e8c2f89ea09cd9e28757e88a2c56c1daa07e404 --- /dev/null +++ b/dnafiber/deployment.py @@ -0,0 +1,44 @@ +from dnafiber.trainee import Trainee +from dnafiber.postprocess.fiber import FiberProps +import pandas as pd + +def _get_model(revision, device="cuda"): + if revision is None: + model = Trainee.from_pretrained( + "ClementP/DeepFiberQ", arch="unet", encoder_name="mit_b0" + ) + else: + model = Trainee.from_pretrained( + "ClementP/DeepFiberQ", + revision=revision, + ) + return model.eval().to(device) + + +def format_results(results: list[FiberProps], pixel_size: float) -> pd.DataFrame: + """ + Format the results for display in the UI. 
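+    Only fibers whose type is considered valid (bilateral patterns such as
+    "double") are kept. `pixel_size` is the physical size of one pixel and is
+    used to convert pixel counts into lengths.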
+ """ + results = [fiber for fiber in results if fiber.is_valid] + all_results = dict( + FirstAnalog=[], SecondAnalog=[], length=[], ratio=[], fiber_type=[] + ) + all_results["FirstAnalog"].extend([fiber.red * pixel_size for fiber in results]) + all_results["SecondAnalog"].extend([fiber.green * pixel_size for fiber in results]) + all_results["length"].extend( + [fiber.red * pixel_size + fiber.green * pixel_size for fiber in results] + ) + all_results["ratio"].extend([fiber.ratio for fiber in results]) + all_results["fiber_type"].extend([fiber.fiber_type for fiber in results]) + + return pd.DataFrame.from_dict(all_results) + + + + +MODELS_ZOO = { + "Ensemble": "ensemble", + "SegFormer MiT-B4": "segformer_mit_b4", + "SegFormer MiT-B2": "segformer_mit_b2", + "U-Net SE-ResNet50": "unet_se_resnet50", +} \ No newline at end of file diff --git a/dnafiber/inference.py b/dnafiber/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..b8bcbff46cc8dcc1d111158546fbdbfa0a254763 --- /dev/null +++ b/dnafiber/inference.py @@ -0,0 +1,105 @@ +import torch.nn.functional as F +import numpy as np +import torch +from torchvision.transforms._functional_tensor import normalize +import pandas as pd +from skimage.segmentation import expand_labels +from skimage.measure import label +import albumentations as A +from monai.inferers import SlidingWindowInferer +from dnafiber.deployment import _get_model +from dnafiber.postprocess import refine_segmentation + +transform = A.Compose( + [ + A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + A.ToTensorV2(), + ] +) + + +def preprocess_image(image): + image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0) + image = normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + return image + + +def convert_to_dataset(counts): + data = {"index": [], "red": [], "green": [], "ratio": []} + for k, v in counts.items(): + data["index"].append(k) + data["green"].append(v["green"]) + data["red"].append(v["red"]) + if v["red"] == 0: + data["ratio"].append(np.nan) + else: + data["ratio"].append(v["green"] / (v["red"])) + df = pd.DataFrame(data) + return df + + +def convert_mask_to_image(mask, expand=False): + if expand: + mask = expand_labels(mask, distance=expand) + h, w = mask.shape + image = np.zeros((h, w, 3), dtype=np.uint8) + GREEN = np.array([0, 255, 0]) + RED = np.array([255, 0, 0]) + + image[mask == 1] = RED + image[mask == 2] = GREEN + + return image + + +@torch.inference_mode() +def infer(model, image, device, scale=0.13, to_numpy=True, only_probabilities=False): + if isinstance(model, str): + model = _get_model(device=device, revision=model) + model_pixel_size = 0.26 + + scale = scale / model_pixel_size + tensor = transform(image=image)["image"].unsqueeze(0).to(device) + h, w = tensor.shape[2], tensor.shape[3] + device = torch.device(device) + with torch.autocast(device_type=device.type): + tensor = F.interpolate( + tensor, + size=(int(h * scale), int(w * scale)), + mode="bilinear", + ) + if tensor.shape[2] > 1024 or tensor.shape[3] > 1024: + inferer = SlidingWindowInferer( + roi_size=(1024, 1024), + sw_batch_size=4, + overlap=0.25, + mode="gaussian", + device=device, + progress=True, + ) + output = inferer(tensor, model) + else: + output = model(tensor) + + probabilities = F.softmax(output, dim=1) + if only_probabilities: + probabilities = probabilities.cpu() + + probabilities = F.interpolate( + probabilities, + size=(h, w), + mode="bilinear", + ) + return probabilities + + output = F.interpolate( + 
probabilities.argmax(dim=1, keepdim=True).float(), + size=(h, w), + mode="nearest", + ) + + output = output.squeeze().byte() + if to_numpy: + output = output.cpu().numpy() + + return output diff --git a/dnafiber/metric.py b/dnafiber/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..c94bd2af55b8672431c878aa3a44f17c3a2f3b5e --- /dev/null +++ b/dnafiber/metric.py @@ -0,0 +1,150 @@ +import kornia as K +import torch +import torchmetrics.functional as F +from skimage.measure import label +from torchmetrics import Metric + + +class DNAFIBERMetric(Metric): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.add_state( + "detection_tp", + default=torch.tensor(0, dtype=torch.int64), + dist_reduce_fx="sum", + ) + self.add_state( + "fiber_red_dice", + default=torch.tensor(0, dtype=torch.float32), + dist_reduce_fx="sum", + ) + self.add_state( + "fiber_green_dice", + default=torch.tensor(0, dtype=torch.float32), + dist_reduce_fx="sum", + ) + self.add_state( + "fiber_red_recall", + default=torch.tensor(0, dtype=torch.float32), + dist_reduce_fx="sum", + ) + self.add_state( + "fiber_green_recall", + default=torch.tensor(0, dtype=torch.float32), + dist_reduce_fx="sum", + ) + # Specificity + self.add_state( + "fiber_red_precision", + default=torch.tensor(0, dtype=torch.float32), + dist_reduce_fx="sum", + ) + self.add_state( + "fiber_green_precision", + default=torch.tensor(0, dtype=torch.float32), + dist_reduce_fx="sum", + ) + + self.add_state( + "detection_fp", + default=torch.tensor(0, dtype=torch.int64), + dist_reduce_fx="sum", + ) + self.add_state( + "N", + default=torch.tensor(0, dtype=torch.int64), + dist_reduce_fx="sum", + ) + + def update(self, preds, target): + if preds.ndim == 4: + preds = preds.argmax(dim=1) + if target.ndim == 4: + target = target.squeeze(1) + B, H, W = preds.shape + preds_labels = [] + target_labels = [] + binary_preds = preds > 0 + binary_target = target > 0 + N_true_labels = 0 + for i in range(B): + pred = binary_preds[i].detach().cpu().numpy() + target_np = binary_target[i].detach().cpu().numpy() + pred_labels = label(pred, connectivity=2) + target_labels_np = label(target_np, connectivity=2) + preds_labels.append(torch.from_numpy(pred_labels).to(preds.device)) + target_labels.append(torch.from_numpy(target_labels_np).to(preds.device)) + N_true_labels += target_labels_np.max() + + preds_labels = torch.stack(preds_labels) + target_labels = torch.stack(target_labels) + + for i, plab in enumerate(preds_labels): + labels = torch.unique(plab) + for blob in labels: + if blob == 0: + continue + pred_mask = plab == blob + pixels_in_common = torch.any(pred_mask & binary_target[i]) + if pixels_in_common: + self.detection_tp += 1 + gt_label = target_labels[i][pred_mask].unique()[-1] + gt_mask = target_labels[i] == gt_label + common_mask = pred_mask | gt_mask + pred_fiber = preds[i][common_mask] + gt_fiber = target[i][common_mask] + dices = F.dice( + pred_fiber, + gt_fiber, + num_classes=3, + ignore_index=0, + average=None, + ) + dices = torch.nan_to_num(dices, nan=0.0) + self.fiber_red_dice += dices[1] + self.fiber_green_dice += dices[2] + recalls = F.recall( + pred_fiber, + gt_fiber, + num_classes=3, + ignore_index=0, + task="multiclass", + average=None, + ) + recalls = torch.nan_to_num(recalls, nan=0.0) + self.fiber_red_recall += recalls[1] + self.fiber_green_recall += recalls[2] + + # Specificity + specificity = F.precision( + pred_fiber, + gt_fiber, + num_classes=3, + ignore_index=0, + task="multiclass", + average=None, + ) + 
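+                    # Note: torchmetrics' precision is computed here; the values
+                    # are accumulated into the *_precision states below.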
specificity = torch.nan_to_num(specificity, nan=0.0) + self.fiber_red_precision += specificity[1] + self.fiber_green_precision += specificity[2] + + else: + self.detection_fp += 1 + + self.N += N_true_labels + + def compute(self): + return { + "detection_precision": self.detection_tp + / (self.detection_tp + self.detection_fp + 1e-7), + "detection_recall": self.detection_tp / (self.N + 1e-7), + "fiber_red_dice": self.fiber_red_dice / (self.detection_tp + 1e-7), + "fiber_green_dice": self.fiber_green_dice / (self.detection_tp + 1e-7), + "fiber_red_recall": self.fiber_red_recall / (self.detection_tp + 1e-7), + "fiber_green_recall": self.fiber_green_recall / (self.detection_tp + 1e-7), + "fiber_red_precision": self.fiber_red_precision + / (self.detection_tp + 1e-7), + "fiber_green_precision": self.fiber_green_precision + / (self.detection_tp + 1e-7), + } diff --git a/dnafiber/model/maskrcnn.py b/dnafiber/model/maskrcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dnafiber/postprocess/__init__.py b/dnafiber/postprocess/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a8d092ba103ac2115498368327064f129764562e --- /dev/null +++ b/dnafiber/postprocess/__init__.py @@ -0,0 +1 @@ +from .core import refine_segmentation diff --git a/dnafiber/postprocess/__pycache__/__init__.cpython-312.pyc b/dnafiber/postprocess/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a52a0e1a96834225fb98bc0f3f46246eeed38f70 Binary files /dev/null and b/dnafiber/postprocess/__pycache__/__init__.cpython-312.pyc differ diff --git a/dnafiber/postprocess/__pycache__/core.cpython-312.pyc b/dnafiber/postprocess/__pycache__/core.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20bbd2765e3117d0d48c63f30216260ad4540c5c Binary files /dev/null and b/dnafiber/postprocess/__pycache__/core.cpython-312.pyc differ diff --git a/dnafiber/postprocess/__pycache__/fiber.cpython-312.pyc b/dnafiber/postprocess/__pycache__/fiber.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0053e4430a919d5740f314ace842619e9444b886 Binary files /dev/null and b/dnafiber/postprocess/__pycache__/fiber.cpython-312.pyc differ diff --git a/dnafiber/postprocess/__pycache__/skan.cpython-312.pyc b/dnafiber/postprocess/__pycache__/skan.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e495fffb403d131e38f162499ecf10ea33023b8a Binary files /dev/null and b/dnafiber/postprocess/__pycache__/skan.cpython-312.pyc differ diff --git a/dnafiber/postprocess/core.py b/dnafiber/postprocess/core.py new file mode 100644 index 0000000000000000000000000000000000000000..90fa1bdfd892d16154bbd878676ba693b3ae2edf --- /dev/null +++ b/dnafiber/postprocess/core.py @@ -0,0 +1,274 @@ +import numpy as np +import cv2 +from typing import List, Tuple +from dnafiber.postprocess.skan import find_endpoints, compute_points_angle +from scipy.spatial.distance import cdist + +from scipy.sparse.csgraph import connected_components +from scipy.sparse import csr_array +from skimage.morphology import skeletonize +from dnafiber.postprocess.skan import find_line_intersection +from dnafiber.postprocess.fiber import Fiber, FiberProps, Bbox +from itertools import compress +import matplotlib.pyplot as plt +from matplotlib.colors import ListedColormap + +cmlabel = ListedColormap(["black", "red", "green"]) + +MIN_ANGLE = 20 +MIN_BRANCH_LENGTH = 10 
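+# Heuristics for re-joining fibers that were split at crossing points:
+# skeleton branches shorter than MIN_BRANCH_LENGTH pixels are discarded, and two
+# endpoints are only reconnected when their tangent angles differ by less than
+# MIN_ANGLE degrees and they lie within MIN_BRANCH_DISTANCE pixels of each other.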
+MIN_BRANCH_DISTANCE = 30 + + +def handle_multiple_fiber_in_cc(fiber, junctions_fiber, coordinates): + for y, x in junctions_fiber: + fiber[y - 1 : y + 2, x - 1 : x + 2] = 0 + + endpoints = find_endpoints(fiber > 0) + endpoints = np.asarray(endpoints) + # We only keep the endpoints that are close to the junction + # We compute the distance between the endpoints and the junctions + distances = np.linalg.norm( + np.expand_dims(endpoints, axis=1) - np.expand_dims(junctions_fiber, axis=0), + axis=2, + ) + # We only keep the endpoints that are close to the junctions + distances = distances < 5 + endpoints = endpoints[distances.any(axis=1)] + + retval, branches, branches_stats, _ = cv2.connectedComponentsWithStatsWithAlgorithm( + fiber, connectivity=8, ccltype=cv2.CCL_BOLELLI, ltype=cv2.CV_16U + ) + branches_bboxes = branches_stats[ + :, + [cv2.CC_STAT_LEFT, cv2.CC_STAT_TOP, cv2.CC_STAT_WIDTH, cv2.CC_STAT_HEIGHT], + ] + + num_branches = branches_bboxes.shape[0] - 1 + # We associate the endpoints to the branches + endpoints_ids = np.zeros((len(endpoints),), dtype=np.uint16) + endpoints_color = np.zeros((len(endpoints),), dtype=np.uint8) + for i, endpoint in enumerate(endpoints): + # Get the branch id + branch_id = branches[endpoint[0], endpoint[1]] + # Check if the branch id is not 0 + if branch_id != 0: + endpoints_ids[i] = branch_id + endpoints_color[i] = fiber[endpoint[0], endpoint[1]] + + # We remove the small branches + kept_branches = set() + for i in range(1, num_branches + 1): + # Get the branch + branch = branches == i + # Compute the area of the branch + area = np.sum(branch.astype(np.uint8)) + # If the area is less than 10 pixels, remove the branch + if area < MIN_BRANCH_LENGTH: + branches[branch] = 0 + else: + kept_branches.add(i) + + # We remove the endpoints that are in the filtered branches + remaining_idxs = np.isin(endpoints_ids, np.asarray(list(kept_branches))) + if remaining_idxs.sum() == 0: + return [] + endpoints = endpoints[remaining_idxs] + + endpoints_color = endpoints_color[remaining_idxs] + endpoints_ids = endpoints_ids[remaining_idxs] + + # We compute the angles of the endpoints + angles = compute_points_angle(fiber, endpoints, steps=15) + angles = np.rad2deg(angles) + # We compute the difference of angles between all the endpoints + endpoints_angles_diff = cdist(angles[:, None], angles[:, None], metric="cityblock") + + # Put inf to the diagonal + endpoints_angles_diff[range(len(endpoints)), range(len(endpoints))] = np.inf + endpoints_distances = cdist(endpoints, endpoints, metric="euclidean") + + endpoints_distances[range(len(endpoints)), range(len(endpoints))] = np.inf + + # We sort by the distance + endpoints_distances[endpoints_distances > MIN_BRANCH_DISTANCE] = np.inf + endpoints_distances[endpoints_angles_diff > MIN_ANGLE] = np.inf + + matchB = np.argmin(endpoints_distances, axis=1) + values = np.take_along_axis(endpoints_distances, matchB[:, None], axis=1) + + added_edges = dict() + N = len(endpoints) + A = np.eye(N, dtype=np.uint8) + for i in range(N): + for j in range(N): + if i == j: + continue + if endpoints_ids[i] == endpoints_ids[j]: + A[i, j] = 1 + A[j, i] = 1 + + if matchB[i] == j and values[i, 0] < np.inf: + added_edges[i] = j + A[i, j] = 1 + A[j, i] = 1 + + A = csr_array(A) + n, ccs = connected_components(A, directed=False, return_labels=True) + unique_clusters = np.unique(ccs) + results = [] + for c in unique_clusters: + idx = np.where(ccs == c)[0] + branches_ids = np.unique(endpoints_ids[idx]) + + unique_branches = np.logical_or.reduce( + [branches == 
i for i in branches_ids], axis=0 + ) + + commons_bboxes = branches_bboxes[branches_ids] + # Compute the union of the bboxes + min_x = np.min(commons_bboxes[:, 0]) + min_y = np.min(commons_bboxes[:, 1]) + max_x = np.max(commons_bboxes[:, 0] + commons_bboxes[:, 2]) + max_y = np.max(commons_bboxes[:, 1] + commons_bboxes[:, 3]) + + new_fiber = fiber[min_y:max_y, min_x:max_x] + new_fiber = unique_branches[min_y:max_y, min_x:max_x] * new_fiber + for cidx in idx: + if cidx not in added_edges: + continue + pointA = endpoints[cidx] + pointB = endpoints[added_edges[cidx]] + pointA = ( + pointA[1] - min_x, + pointA[0] - min_y, + ) + pointB = ( + pointB[1] - min_x, + pointB[0] - min_y, + ) + colA = endpoints_color[cidx] + colB = endpoints_color[added_edges[cidx]] + new_fiber = cv2.line( + new_fiber, + pointA, + pointB, + color=2 if colA != colB else int(colA), + thickness=1, + ) + # We express the bbox in the original image + bbox = ( + coordinates[0] + min_x, + coordinates[1] + min_y, + max_x - min_x, + max_y - min_y, + ) + bbox = Bbox(x=bbox[0], y=bbox[1], width=bbox[2], height=bbox[3]) + result = Fiber(bbox=bbox, data=new_fiber) + results.append(result) + return results + + +def handle_ccs_with_junctions( + ccs: List[np.ndarray], + junctions: List[List[Tuple[int, int]]], + coordinates: List[Tuple[int, int]], +): + """ + Handle the connected components with junctions. + The function takes a list of connected components, a list of list of junctions and a list of coordinates. + The junctions + The coordinates corresponds to the top left corner of the connected component. + """ + jncts_fibers = [] + for fiber, junction, coordinate in zip(ccs, junctions, coordinates): + jncts_fibers += handle_multiple_fiber_in_cc(fiber, junction, coordinate) + + return jncts_fibers + + +def refine_segmentation(segmentation, fix_junctions=True, show=False): + skeleton = skeletonize(segmentation > 0, method="lee").astype(np.uint8) + skeleton_gt = skeleton * segmentation + retval, labels, stats, centroids = cv2.connectedComponentsWithStatsWithAlgorithm( + skeleton, connectivity=8, ccltype=cv2.CCL_BOLELLI, ltype=cv2.CV_16U + ) + + bboxes = stats[ + :, + [ + cv2.CC_STAT_LEFT, + cv2.CC_STAT_TOP, + cv2.CC_STAT_WIDTH, + cv2.CC_STAT_HEIGHT, + ], + ] + + local_fibers = [] + coordinates = [] + junctions = [] + for i in range(1, retval): + bbox = bboxes[i] + x1, y1, w, h = bbox + local_gt = skeleton_gt[y1 : y1 + h, x1 : x1 + w] + local_label = (labels[y1 : y1 + h, x1 : x1 + w] == i).astype(np.uint8) + local_fiber = local_gt * local_label + local_fibers.append(local_fiber) + coordinates.append(np.asarray([x1, y1, w, h])) + local_junctions = find_line_intersection(local_fiber > 0) + local_junctions = np.where(local_junctions) + local_junctions = np.array(local_junctions).transpose() + junctions.append(local_junctions) + if show: + for bbox, junction in zip(coordinates, junctions): + x, y, w, h = bbox + junction_to_global = np.array(junction) + np.array([y, x]) + + plt.scatter( + junction_to_global[:, 1], + junction_to_global[:, 0], + color="white", + s=30, + alpha=0.35, + ) + + plt.imshow(skeleton_gt, cmap=cmlabel, interpolation="nearest") + plt.axis("off") + plt.xticks([]) + plt.yticks([]) + plt.subplots_adjust(left=0, right=1, top=1, bottom=0) + plt.show() + + fibers = [] + if fix_junctions: + has_junctions = [len(j) > 0 for j in junctions] + for fiber, coordinate in zip( + compress(local_fibers, np.logical_not(has_junctions)), + compress(coordinates, np.logical_not(has_junctions)), + ): + bbox = Bbox( + x=coordinate[0], + 
y=coordinate[1], + width=coordinate[2], + height=coordinate[3], + ) + fibers.append(Fiber(bbox=bbox, data=fiber)) + + fibers += handle_ccs_with_junctions( + compress(local_fibers, has_junctions), + compress(junctions, has_junctions), + compress(coordinates, has_junctions), + ) + else: + for fiber, coordinate in zip(local_fibers, coordinates): + bbox = Bbox( + x=coordinate[0], + y=coordinate[1], + width=coordinate[2], + height=coordinate[3], + ) + fibers.append(Fiber(bbox=bbox, data=fiber)) + + fiberprops = [FiberProps(fiber=f, fiber_id=i) for i, f in enumerate(fibers)] + + return fiberprops diff --git a/dnafiber/postprocess/fiber.py b/dnafiber/postprocess/fiber.py new file mode 100644 index 0000000000000000000000000000000000000000..dff9ef480dc25fb580e3e6eebf7ba7b9f5e88272 --- /dev/null +++ b/dnafiber/postprocess/fiber.py @@ -0,0 +1,129 @@ +import attrs +import numpy as np +from typing import Tuple +from dnafiber.postprocess.skan import trace_skeleton + +@attrs.define +class Bbox: + x: int + y: int + width: int + height: int + + @property + def bbox(self) -> Tuple[int, int, int, int]: + return (self.x, self.y, self.width, self.height) + + @bbox.setter + def bbox(self, value: Tuple[int, int, int, int]): + self.x, self.y, self.width, self.height = value + + +@attrs.define +class Fiber: + bbox: Bbox + data: np.ndarray + + +@attrs.define +class FiberProps: + fiber: Fiber + fiber_id: int + red_pixels: int = None + green_pixels: int = None + category: str = None + + @property + def bbox(self): + return self.fiber.bbox.bbox + + @bbox.setter + def bbox(self, value): + self.fiber.bbox = value + + @property + def data(self): + return self.fiber.data + + @data.setter + def data(self, value): + self.fiber.data = value + + @property + def red(self): + if self.red_pixels is None: + self.red_pixels, self.green_pixels = self.counts + return self.red_pixels + + @property + def green(self): + if self.green_pixels is None: + self.red_pixels, self.green_pixels = self.counts + return self.green_pixels + + @property + def length(self): + return sum(self.counts) + + @property + def counts(self): + if self.red_pixels is None or self.green_pixels is None: + self.red_pixels = np.sum(self.data == 1) + self.green_pixels = np.sum(self.data == 2) + return self.red_pixels, self.green_pixels + + @property + def fiber_type(self): + if self.category is not None: + return self.category + red_pixels, green_pixels = self.counts + if red_pixels == 0 or green_pixels == 0: + self.category = "single" + else: + self.category = estimate_fiber_category(self.data) + return self.category + + @property + def ratio(self): + return self.green / self.red + + @property + def is_valid(self): + return ( + self.fiber_type == "double" + or self.fiber_type == "one-two-one" + or self.fiber_type == "two-one-two" + ) + + def scaled_coordinates(self, scale: float) -> Tuple[int, int]: + """ + Scale down the coordinates of the fiber's bounding box. + """ + x, y, width, height = self.bbox + return ( + int(x * scale), + int(y * scale), + int(width * scale), + int(height * scale), + ) + + +def estimate_fiber_category(fiber: np.ndarray) -> str: + """ + Estimate the fiber category based on the number of red and green pixels. 
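+    The skeleton is traced end to end and color changes along the path are
+    counted: two runs of color give "double", three give "one-two-one" or
+    "two-one-two" depending on the first color, and more give "multiple".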
+ """ + coordinates = trace_skeleton(fiber > 0) + coordinates = np.asarray(coordinates) + values = fiber[coordinates[:, 0], coordinates[:, 1]] + diff = np.diff(values) + jump = np.sum(diff != 0) + n_ccs = jump + 1 + if n_ccs == 2: + return "double" + elif n_ccs == 3: + if values[0] == 1: + return "one-two-one" + else: + return "two-one-two" + else: + return "multiple" diff --git a/dnafiber/postprocess/skan.py b/dnafiber/postprocess/skan.py new file mode 100644 index 0000000000000000000000000000000000000000..c86255f6fcbc2e41e922e1d5736b237bfc13d50c --- /dev/null +++ b/dnafiber/postprocess/skan.py @@ -0,0 +1,211 @@ +# Functions to generate kernels of curve intersection +import numpy as np +import cv2 +import itertools +from numba import njit, int64 +from numba.typed import List +from numba.types import Tuple + +# Define the element type: a tuple of two int64 +tuple_type = Tuple((int64, int64)) + + +def find_neighbours(fibers_map, point): + """ + Find the next point in the fiber starting from the given point. + The function returns None if the point is not in the fiber. + """ + # Get the fiber id + neighbors = [] + h, w = fibers_map.shape + for i in range(-1, 2): + for j in range(-1, 2): + # Skip the center point + if i == 0 and j == 0: + continue + # Get the next point + nextpoint = (point[0] + i, point[1] + j) + # Check if the next point is in the image + if ( + nextpoint[0] < 0 + or nextpoint[0] >= h + or nextpoint[1] < 0 + or nextpoint[1] >= w + ): + continue + + # Check if the next point is in the fiber + if fibers_map[nextpoint]: + neighbors.append(nextpoint) + return neighbors + + +def compute_points_angle(fibers_map, points, steps=25): + """ + For each endpoint, follow the fiber for a given number of steps and estimate the tangent line by + fitting a line to the visited points. The angle of the line is returned. + """ + points_angle = np.zeros((len(points),), dtype=np.float32) + for i, point in enumerate(points): + # Find the fiber it belongs to + # Lets navigate along the fiber starting from the point during steps pixels. + # We compute the angles at each step and return the mean angle. + visited = trace_from_point( + fibers_map > 0, (point[0], point[1]), max_length=steps + ) + visited = np.array(visited) + vx, vy, x, y = cv2.fitLine(visited[:, ::-1], cv2.DIST_L2, 0, 0.01, 0.01) + # Compute the angle of the line + points_angle[i] = np.arctan(vy / vx) + + return points_angle + + +def generate_nonadjacent_combination(input_list, take_n): + """ + It generates combinations of m taken n at a time where there is no adjacent n. + INPUT: + input_list = (iterable) List of elements you want to extract the combination + take_n = (integer) Number of elements that you are going to take at a time in + each combination + OUTPUT: + all_comb = (np.array) with all the combinations + """ + all_comb = [] + for comb in itertools.combinations(input_list, take_n): + comb = np.array(comb) + d = np.diff(comb) + if len(d[d == 1]) == 0 and comb[-1] - comb[0] != 7: + all_comb.append(comb) + return all_comb + + +def populate_intersection_kernel(combinations): + """ + Maps the numbers from 0-7 into the 8 pixels surrounding the center pixel in + a 9 x 9 matrix clockwisely i.e. up_pixel = 0, right_pixel = 2, etc. And + generates a kernel that represents a line intersection, where the center + pixel is occupied and 3 or 4 pixels of the border are ocuppied too. + INPUT: + combinations = (np.array) matrix where every row is a vector of combinations + OUTPUT: + kernels = (List) list of 9 x 9 kernels/masks. 
each element is a mask. + """ + n = len(combinations[0]) + template = np.array(([-1, -1, -1], [-1, 1, -1], [-1, -1, -1]), dtype="int") + match = [(0, 1), (0, 2), (1, 2), (2, 2), (2, 1), (2, 0), (1, 0), (0, 0)] + kernels = [] + for n in combinations: + tmp = np.copy(template) + for m in n: + tmp[match[m][0], match[m][1]] = 1 + kernels.append(tmp) + return kernels + + +def give_intersection_kernels(): + """ + Generates all the intersection kernels in a 9x9 matrix. + INPUT: + None + OUTPUT: + kernels = (List) list of 9 x 9 kernels/masks. each element is a mask. + """ + input_list = np.arange(8) + taken_n = [4, 3] + kernels = [] + for taken in taken_n: + comb = generate_nonadjacent_combination(input_list, taken) + tmp_ker = populate_intersection_kernel(comb) + kernels.extend(tmp_ker) + return kernels + + +def find_line_intersection(input_image, show=0): + """ + Applies morphologyEx with parameter HitsMiss to look for all the curve + intersection kernels generated with give_intersection_kernels() function. + INPUT: + input_image = (np.array dtype=np.uint8) binarized m x n image matrix + OUTPUT: + output_image = (np.array dtype=np.uint8) image where the nonzero pixels + are the line intersection. + """ + input_image = input_image.astype(np.uint8) + kernel = np.array(give_intersection_kernels()) + output_image = np.zeros(input_image.shape) + for i in np.arange(len(kernel)): + out = cv2.morphologyEx( + input_image, + cv2.MORPH_HITMISS, + kernel[i, :, :], + borderValue=0, + borderType=cv2.BORDER_CONSTANT, + ) + output_image = output_image + out + + return output_image + + +@njit +def get_neighbors_8(y, x, shape): + neighbors = List.empty_list(tuple_type) + for dy in range(-1, 2): + for dx in range(-1, 2): + if dy == 0 and dx == 0: + continue + ny, nx = y + dy, x + dx + if 0 <= ny < shape[0] and 0 <= nx < shape[1]: + neighbors.append((ny, nx)) + return neighbors + + +@njit +def find_endpoints(skel): + endpoints = List.empty_list(tuple_type) + for y in range(skel.shape[0]): + for x in range(skel.shape[1]): + if skel[y, x] == 1: + count = 0 + neighbors = get_neighbors_8(y, x, skel.shape) + for ny, nx in neighbors: + if skel[ny, nx] == 1: + count += 1 + if count == 1: + endpoints.append((y, x)) + return endpoints + + +@njit +def trace_skeleton(skel): + endpoints = find_endpoints(skel) + if len(endpoints) < 1: + return List.empty_list(tuple_type) # Return empty list with proper type + + return trace_from_point(skel, endpoints[0], max_length=skel.sum()) + + +@njit +def trace_from_point(skel, point, max_length=25): + visited = np.zeros_like(skel, dtype=np.uint8) + path = List.empty_list(tuple_type) + + # Check if the starting point is on the skeleton + y, x = point + if y < 0 or y >= skel.shape[0] or x < 0 or x >= skel.shape[1] or skel[y, x] != 1: + return path + + stack = List.empty_list(tuple_type) + stack.append(point) + + while len(stack) > 0 and len(path) < max_length: + y, x = stack.pop() + if visited[y, x]: + continue + visited[y, x] = 1 + path.append((y, x)) + neighbors = get_neighbors_8(y, x, skel.shape) + for ny, nx in neighbors: + if skel[ny, nx] == 1 and not visited[ny, nx]: + stack.append((ny, nx)) + return path diff --git a/dnafiber/start.py b/dnafiber/start.py new file mode 100644 index 0000000000000000000000000000000000000000..4e408cbc2f046893527334ba12b755594ac9b1b7 --- /dev/null +++ b/dnafiber/start.py @@ -0,0 +1,22 @@ +import subprocess +import os + + +def main(): + # Start the Streamlit application + print("Starting Streamlit application...") + local_dir = 
os.path.dirname(os.path.abspath(__file__)) + subprocess.run( + [ + "streamlit", + "run", + os.path.join(local_dir, "ui", "Welcome.py"), + "--server.maxUploadSize", + "1024", + ], + ) + + +if __name__ == "__main__": + main() + print("Streamlit application started successfully.") diff --git a/dnafiber/trainee.py b/dnafiber/trainee.py new file mode 100644 index 0000000000000000000000000000000000000000..41fe1f7fd76520aa5fadc9e7592b9555a09f22e7 --- /dev/null +++ b/dnafiber/trainee.py @@ -0,0 +1,148 @@ +from lightning import LightningModule +import segmentation_models_pytorch as smp +from monai.losses.dice import GeneralizedDiceLoss +from monai.losses.cldice import SoftDiceclDiceLoss +from torchmetrics.classification import Dice, JaccardIndex +from torch.optim import AdamW +from torch.optim.lr_scheduler import CosineAnnealingLR +from torchmetrics import MetricCollection +import torch.nn.functional as F +from huggingface_hub import PyTorchModelHubMixin +import torch +import torchvision +from dnafiber.metric import DNAFIBERMetric + + +class Trainee(LightningModule, PyTorchModelHubMixin): + def __init__( + self, learning_rate=0.001, weight_decay=0.0002, num_classes=3, **model_config + ): + super().__init__() + self.model_config = model_config + if ( + self.model_config.get("arch", None) is None + or self.model_config["arch"] == "maskrcnn" + ): + self.model = None + else: + self.model = smp.create_model(classes=3, **self.model_config, dropout=0.2) + self.loss = GeneralizedDiceLoss(to_onehot_y=False, softmax=False) + self.metric = MetricCollection( + { + "dice": Dice(num_classes=num_classes, ignore_index=0), + "jaccard": JaccardIndex( + num_classes=num_classes, + task="multiclass" if num_classes > 2 else "binary", + ignore_index=0, + ), + "detection": DNAFIBERMetric(), + } + ) + self.weight_decay = weight_decay + self.learning_rate = learning_rate + self.save_hyperparameters() + + def forward(self, x): + yhat = self.model(x) + return yhat + + def training_step(self, batch, batch_idx): + x, y = batch["image"], batch["mask"] + y = y.clamp(0, 2) + y_hat = self(x) + loss = self.get_loss(y_hat, y) + + self.log("train_loss", loss) + + return loss + + def get_loss(self, y_hat, y): + y_hat = F.softmax(y_hat, dim=1) + y = F.one_hot(y.long(), num_classes=3) + y = y.permute(0, 3, 1, 2).float() + loss = self.loss(y_hat, y) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch["image"], batch["mask"] + y = y.clamp(0, 2) + y_hat = self(x) + loss = self.get_loss(y_hat, y) + self.log("val_loss", loss, on_step=False, on_epoch=True, sync_dist=True) + self.metric.update(y_hat, y) + return y_hat + + def on_validation_epoch_end(self): + scores = self.metric.compute() + self.log_dict(scores, sync_dist=True) + self.metric.reset() + + def configure_optimizers(self): + optimizer = AdamW( + self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay + ) + scheduler = CosineAnnealingLR( + optimizer, + T_max=self.trainer.max_epochs, # type: ignore + eta_min=self.learning_rate / 25, + ) + scheduler = { + "scheduler": scheduler, + "interval": "epoch", + } + return [optimizer], [scheduler] + + +class TraineeMaskRCNN(Trainee): + def __init__(self, learning_rate=0.001, weight_decay=0.0002, **model_config): + super().__init__(learning_rate, weight_decay, **model_config) + self.model = torchvision.models.get_model("maskrcnn_resnet50_fpn_v2") + + def forward(self, x): + yhat = self.model(x) + return yhat + + def training_step(self, batch, batch_idx): + image = batch["image"] + targets = 
batch["targets"] + loss_dict = self.model(image, targets) + losses = sum(loss for loss in loss_dict.values()) + self.log("train_loss", losses, on_step=True, on_epoch=False, sync_dist=True) + return losses + + def validation_step(self, batch, batch_idx): + image = batch["image"] + targets = batch["targets"] + + predictions = self.model(image) + b = len(predictions) + predicted_masks = [] + gt_masks = [] + for i in range(b): + scores = predictions[i]["scores"] + masks = predictions[i]["masks"] + good_masks = masks[scores > 0.5] + # Combined into a single mask + good_masks = torch.sum(good_masks, dim=0) + predicted_masks.append(good_masks) + gt_masks.append(targets[i]["masks"].sum(dim=0)) + + gt_masks = torch.stack(gt_masks).squeeze(1) > 0 + predicted_masks = torch.stack(predicted_masks).squeeze(1) > 0 + self.metric.update(predicted_masks, gt_masks) + return predictions + + def configure_optimizers(self): + optimizer = AdamW( + self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay + ) + scheduler = CosineAnnealingLR( + optimizer, + T_max=self.trainer.max_epochs, # type: ignore + eta_min=self.learning_rate / 25, + ) + scheduler = { + "scheduler": scheduler, + "interval": "epoch", + } + return [optimizer], [scheduler] diff --git a/dnafiber/ui/Welcome.py b/dnafiber/ui/Welcome.py new file mode 100644 index 0000000000000000000000000000000000000000..e1657f273987bf13811815ccd85e442bab27d837 --- /dev/null +++ b/dnafiber/ui/Welcome.py @@ -0,0 +1,47 @@ +import streamlit as st +import torch + + +def main(): + st.set_page_config( + page_title="Hello", + page_icon="🧬", + layout="wide", + ) + st.write("# Welcome to DN-AI! 👋") + + st.write( + "This is a web application for the DN-AI project, which aims to provide an easy-to-use interface for analyzing and processing fiber images." + ) + st.write("## Features") + st.write( + "- **Image loading**: The application accepts CZI file, jpeg and PNG file. \n" + "- **Image segmentation**: The application provides a set of tools to segment the DNA fiber and measure the ratio between analogs. \n" + ) + st.write("## Technical details") + cols = st.columns(2) + with cols[0]: + st.write("### Source") + st.write("The source code for this application is available on GitHub.") + """ + [![Repo](https://badgen.net/badge/icon/GitHub?icon=github&label)](https://github.com/ClementPla/DeepFiberQ/tree/relabelled) + + """ + st.markdown("
", unsafe_allow_html=True) + + with cols[1]: + st.write("### Device ") + st.write("If available, the application will try to use a GPU for processing.") + device = "GPU" if torch.cuda.is_available() else "CPU" + cols = st.columns(3) + with cols[0]: + st.write("Running on:") + with cols[1]: + st.button(device, icon="⚙️", disabled=True) + if not torch.cuda.is_available(): + with cols[2]: + st.warning("The application will run on CPU, which may be slower.") + + +if __name__ == "__main__": + main() diff --git a/dnafiber/ui/__init__.py b/dnafiber/ui/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/dnafiber/ui/__pycache__/__init__.cpython-312.pyc b/dnafiber/ui/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9f8e6978691e9c21072f3ed6bced49f13b641b8 Binary files /dev/null and b/dnafiber/ui/__pycache__/__init__.cpython-312.pyc differ diff --git a/dnafiber/ui/__pycache__/inference.cpython-312.pyc b/dnafiber/ui/__pycache__/inference.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..793cfaed8e36d0a276a380e829826315c68b9f6b Binary files /dev/null and b/dnafiber/ui/__pycache__/inference.cpython-312.pyc differ diff --git a/dnafiber/ui/__pycache__/utils.cpython-312.pyc b/dnafiber/ui/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c480d1c3f4b0b327fec16c4ae9859f1ced007272 Binary files /dev/null and b/dnafiber/ui/__pycache__/utils.cpython-312.pyc differ diff --git a/dnafiber/ui/inference.py b/dnafiber/ui/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..1bf55ed6869f692c7bfa8877a67007d078de34b7 --- /dev/null +++ b/dnafiber/ui/inference.py @@ -0,0 +1,69 @@ +import streamlit as st +from dnafiber.inference import infer +from dnafiber.postprocess.core import refine_segmentation +import numpy as np +from dnafiber.deployment import _get_model +import torch + + +@st.cache_data +def ui_inference(_model, _image, _device, postprocess=True, id=None): + return ui_inference_cacheless( + _model, _image, _device, postprocess=postprocess, id=id + ) + + +@st.cache_resource +def get_model(model_name): + model = _get_model( + device="cuda" if torch.cuda.is_available() else "cpu", + revision=model_name, + ) + return model + + +def ui_inference_cacheless(_model, _image, _device, postprocess=True, id=None): + """ + A cacheless version of the ui_inference function. + This function does not use caching and is intended for use in scenarios where caching is not desired. 
+ """ + h, w = _image.shape[:2] + with st.spinner("Sliding window segmentation in progress..."): + if isinstance(_model, list): + output = None + for model in _model: + if isinstance(model, str): + model = get_model(model) + with st.spinner(text="Segmenting with model: {}".format(model)): + if output is None: + output = infer( + model, + image=_image, + device=_device, + scale=st.session_state.get("pixel_size", 0.13), + only_probabilities=True, + ).cpu() + else: + output = ( + output + + infer( + model, + image=_image, + device=_device, + scale=st.session_state.get("pixel_size", 0.13), + only_probabilities=True, + ).cpu() + ) + output = (output / len(_model)).argmax(1).squeeze().numpy() + else: + output = infer( + _model, + image=_image, + device=_device, + scale=st.session_state.get("pixel_size", 0.13), + ) + output = output.astype(np.uint8) + if postprocess: + with st.spinner("Post-processing segmentation..."): + output = refine_segmentation(output, fix_junctions=postprocess) + return output diff --git a/dnafiber/ui/pages/1_Load.py b/dnafiber/ui/pages/1_Load.py new file mode 100644 index 0000000000000000000000000000000000000000..1c1948696ee7a7ba8ce30ff9b7955ac33a7e75c6 --- /dev/null +++ b/dnafiber/ui/pages/1_Load.py @@ -0,0 +1,196 @@ +import streamlit as st + + +st.set_page_config( + page_title="DN-AI", + page_icon="🔬", + layout="wide", +) + +def build_multichannel_loader(): + + if ( + st.session_state.get("files_uploaded", None) is None + or len(st.session_state.files_uploaded) == 0 + ): + st.session_state["files_uploaded"] = st.file_uploader( + label="Upload files", + accept_multiple_files=True, + type=["czi", "jpeg", "jpg", "png", "tiff", "tif"], + ) + else: + st.session_state["files_uploaded"] += st.file_uploader( + label="Upload files", + accept_multiple_files=True, + type=["czi", "jpeg", "jpg", "png", "tiff", "tif"], + ) + st.write("### Channel interpretation") + st.markdown("The goal is to obtain an RGB image in the order of First analog, Second analog, Empty.", unsafe_allow_html=True) + st.markdown("By default, we assume that the first channel in CZI/TIFF file is the second analog, (which happens to be the case in Zeiss microscope) " \ + "which means that we swap the order of the two channels for processing.", unsafe_allow_html=True) + st.write("If this not the intented behavior, please tick the box below:") + st.session_state["reverse_channels"] = st.checkbox( + "Reverse the channels interpretation", + value=False, + ) + st.warning("Please note that we only swap the channels for raw (CZI, TIFF) files. JPEG and PNG files "\ + "are assumed to be already in the correct order (First analog in red and second analog in green). " \ + ) + + st.info("" \ + "The channels order in CZI files does not necessarily match the order in which they are displayed in ImageJ or equivalent. " \ + "Indeed, such viewers will usually look at the metadata of the file to determine the order of the channels, which we don't. " \ + "In doubt, we recommend visualizing the image in ImageJ and compare with our viewer. If the channels appear reversed, tick the option above.") + +def build_individual_loader(): + + cols = st.columns(2) + with cols[1]: + st.markdown(f"

<h3 style='text-align: center;'>Second analog</h3>
", unsafe_allow_html=True) + + if ( + st.session_state.get("analog_2_files", None) is None + or len(st.session_state.analog_2_files) == 0 + ): + st.session_state["analog_2_files"] = st.file_uploader( + label="Upload second analog file(s)", + accept_multiple_files=True, + type=["czi", "jpeg", "jpg", "png", "tiff", "tif"], + ) + else: + st.session_state["analog_2_files"] += st.file_uploader( + label="Upload second analog file(s)", + accept_multiple_files=True, + type=["czi", "jpeg", "jpg", "png", "tiff", "tif"], + ) + + + with cols[0]: + st.markdown(f"

<h3 style='text-align: center;'>First analog</h3>
", unsafe_allow_html=True) + if ( + st.session_state.get("analog_1_files", None) is None + or len(st.session_state.analog_1_files) == 0 + ): + st.session_state["analog_1_files"] = st.file_uploader( + label="Upload first analog file(s)", + accept_multiple_files=True, + type=["czi", "jpeg", "jpg", "png", "tiff", "tif"], + ) + else: + st.session_state["analog_1_files"] += st.file_uploader( + label="Upload first analog file(s)", + accept_multiple_files=True, + type=["czi", "jpeg", "jpg", "png", "tiff", "tif"],) + + analog_1_files=st.session_state.get("analog_1_files", None) + analog_2_files=st.session_state.get("analog_2_files", None) + + # Remove duplicates from the list of files. We loop through the files and keep only the first occurrence of each file_id. + def remove_duplicates(files): + seen_ids = set() + unique_files = [] + for file in files: + if file and file.name not in seen_ids: + unique_files.append(file) + seen_ids.add(file.name) + return unique_files + + analog_1_files = remove_duplicates(analog_1_files or []) + analog_2_files = remove_duplicates(analog_2_files or []) + + + if analog_1_files is None and analog_2_files is None: + return + else: + if len(analog_1_files)>0 and len(analog_2_files)>0 and len(analog_1_files) != len(analog_2_files): + st.error("Please upload the same number of analogs files.") + return + + # Always make sure we don't have duplicates in the list of files + + analog_1_files = sorted(analog_1_files, key=lambda x: x.name) + analog_2_files = sorted(analog_2_files, key=lambda x: x.name) + max_size = max(len(analog_1_files), len(analog_2_files)) + # Pad the shorter list with None + if len(analog_1_files) < max_size: + analog_1_files += [None] * (max_size - len(analog_1_files)) + if len(analog_2_files) < max_size: + analog_2_files += [None] * (max_size - len(analog_2_files)) + + combined_files = list(zip(analog_1_files, analog_2_files)) + + + + if ( + st.session_state.get("files_uploaded", None) is None + or len(st.session_state.files_uploaded) == 0 + ): + st.session_state["files_uploaded"] = combined_files + else: + st.session_state["files_uploaded"] += combined_files + + + + # If any of the files (analog_1_files or analog_2_files) was included previously in the files_uploaded, + # We remove the previous occurence from the files_uploaded list. + current_ids = set() + for f in analog_1_files + analog_2_files: + if f: + current_ids.add(f.name) + + # Safely filter the list to exclude any files with matching file_ids + def is_not_duplicate(file): + if isinstance(file, tuple): + f1, f2 = file + if f1 and f2: + return True + + return (f1 is None or f1.name not in current_ids) and (f2 is None or f2.name not in current_ids) + else: + return True + + st.session_state.files_uploaded = [f for f in st.session_state.files_uploaded if is_not_duplicate(f)] + + + +cols = st.columns(2) +with cols[1]: + + + st.write("### Pixel size") + st.session_state["pixel_size"] = st.number_input( + "Please indicate the pixel size of the image in µm (default: 0.13 µm).", + value=st.session_state.get("pixel_size", 0.13), + ) + # In small, lets precise the tehnical details + st.write( + "The pixel size is used to convert the pixel coordinates to µm. " \ + "The model is trained on images with a pixel size of 0.26 µm, and the application automatically " \ + "resizes the images to match this pixel size using your provided choice." 
+ ) + + st.write("### Labels color") + color_choices = st.columns(2) + with color_choices[0]: + st.session_state["color1"] = st.color_picker( + "Select the color for first analog", + value=st.session_state.get("color1", "#FF0000"), + help="This color will be used to display the first analog segments.") + with color_choices[1]: + st.session_state["color2"] = st.color_picker( + "Select the color for second analog", + value=st.session_state.get("color2", "#00FF00"), + help="This color will be used to display the second analog segments.") + +with cols[0]: + choice = st.segmented_control( + "Please select the type of images you want to upload:", + options=["Multichannel", "Individual channel"], + default="Multichannel", + ) + if choice == "Individual channel": + build_individual_loader() + else: + build_multichannel_loader() + + diff --git a/dnafiber/ui/pages/2_Viewer.py b/dnafiber/ui/pages/2_Viewer.py new file mode 100644 index 0000000000000000000000000000000000000000..8f2b42faf867043b471ca0a0d7df716df5612815 --- /dev/null +++ b/dnafiber/ui/pages/2_Viewer.py @@ -0,0 +1,490 @@ +import streamlit as st +from bokeh.plotting import figure +from bokeh.layouts import gridplot +from streamlit_bokeh import streamlit_bokeh +from dnafiber.ui.utils import ( + get_image, + get_multifile_image, + get_resized_image, + bokeh_imshow, + pad_image_to_croppable, + numpy_to_base64_png, +) +from dnafiber.deployment import MODELS_ZOO +from dnafiber.ui.inference import ui_inference, get_model +from skimage.util import view_as_blocks +import cv2 +import math +from bokeh.models import ( + Range1d, + HoverTool, +) +import streamlit_image_coordinates +from catppuccin import PALETTE +import numpy as np +import torch +from skimage.segmentation import expand_labels +import pandas as pd + +st.set_page_config( + layout="wide", + page_icon=":microscope:", +) +st.title("Viewer") + + +@st.cache_resource +def display_prediction(_prediction, _image, image_id=None): + max_width = 2048 + image = _image + if image.max() > 25: + image = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U) + scale = 1 + # Resize the image to max_width + if image.shape[1] > max_width: + scale = max_width / image.shape[1] + image = cv2.resize( + image, + None, + fx=scale, + fy=scale, + interpolation=cv2.INTER_LINEAR, + ) + + h, w = image.shape[:2] + labels_maps = np.zeros((h, w), dtype=np.uint8) + for i, region in enumerate(_prediction): + x, y, w, h = region.scaled_coordinates(scale) + data = cv2.resize( + expand_labels(region.data, 1), + None, + fx=scale, + fy=scale, + interpolation=cv2.INTER_NEAREST, + ) + labels_maps[ + y : y + data.shape[0], + x : x + data.shape[1], + ] = data + p1 = figure( + width=600, + x_range=Range1d(-image.shape[1] / 8, image.shape[1] * 1.125, bounds="auto"), + y_range=Range1d(image.shape[0] * 1.125, -image.shape[0] / 8, bounds="auto"), + title=f"Detected fibers: {len(_prediction)}", + tools="pan,wheel_zoom,box_zoom,reset", + active_scroll="wheel_zoom", + ) + + p1.image( + image=[labels_maps], + x=0, + y=0, + dw=labels_maps.shape[1], + dh=labels_maps.shape[0], + palette=["black", st.session_state["color1"], st.session_state["color2"]] + if np.max(labels_maps) > 0 + else ["black"], + ) + p2 = figure( + x_range=p1.x_range, + y_range=p1.y_range, + width=600, + tools="pan,wheel_zoom,box_zoom,reset", + active_scroll="wheel_zoom", + ) + bokeh_imshow(p2, image) + colors = [c.hex for c in PALETTE.latte.colors][:14] + data_source = dict( + x=[], + y=[], + width=[], + height=[], + color=[], + firstAnalog=[], + 
secondAnalog=[], + ratio=[], + fiber_id=[], + ) + np.random.shuffle(colors) + for i, region in enumerate(_prediction): + color = colors[i % len(colors)] + x, y, w, h = region.scaled_coordinates(scale) + + fiberId = region.fiber_id + data_source["x"].append((x + w / 2)) + data_source["y"].append((y + h / 2)) + data_source["width"].append(w) + data_source["height"].append(h) + data_source["color"].append(color) + r, g = region.counts + red_length = st.session_state["pixel_size"] * r / scale + green_length = st.session_state["pixel_size"] * g / scale + data_source["firstAnalog"].append(f"{red_length:.2f} µm") + data_source["secondAnalog"].append(f"{green_length:.2f} µm") + data_source["ratio"].append(f"{green_length / red_length:.2f}") + data_source["fiber_id"].append(fiberId) + + rect1 = p1.rect( + x="x", + y="y", + width="width", + height="height", + source=data_source, + fill_color=None, + line_color="color", + ) + rect2 = p2.rect( + x="x", + y="y", + width="width", + height="height", + source=data_source, + fill_color=None, + line_color="color", + ) + + hover = HoverTool( + tooltips=f'Fiber ID: @fiber_id

<br>@firstAnalog<br>@secondAnalog<br>
Ratio: @ratio', + ) + hover.renderers = [rect1, rect2] + hover.point_policy = "follow_mouse" + hover.attachment = "vertical" + p1.add_tools(hover) + p2.add_tools(hover) + + p1.axis.visible = False + p2.axis.visible = False + fig = gridplot( + [[p2, p1]], + merge_tools=True, + sizing_mode="stretch_width", + toolbar_options=dict(logo=None, help=None), + ) + return fig + + +@st.cache_data +def show_fibers(_prediction, _image, image_id=None): + data = dict( + fiber_id=[], + firstAnalog=[], + secondAnalog=[], + ratio=[], + fiber_type=[], + visualization=[], + ) + + for fiber in _prediction: + data["fiber_id"].append(fiber.fiber_id) + r, g = fiber.counts + red_length = st.session_state["pixel_size"] * r + green_length = st.session_state["pixel_size"] * g + data["firstAnalog"].append(f"{red_length:.3f} ") + data["secondAnalog"].append(f"{green_length:.3f} ") + data["ratio"].append(f"{green_length / red_length:.3f}") + data["fiber_type"].append(fiber.fiber_type) + + x, y, w, h = fiber.bbox + + visu = _image[y : y + h, x : x + w, :] + visu = cv2.normalize(visu, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U) + data["visualization"].append(visu) + + df = pd.DataFrame(data) + df = df.rename( + columns={ + "firstAnalog": "First analog (µm)", + "secondAnalog": "Second analog (µm)", + "ratio": "Ratio", + "fiber_type": "Fiber type", + "fiber_id": "Fiber ID", + "visualization": "Visualization", + } + ) + df["Visualization"] = df["Visualization"].apply(lambda x: numpy_to_base64_png(x)) + return df + + +def start_inference(): + image = st.session_state.image_inference + image = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U) + + if "ensemble" in st.session_state.model: + model = [ + _ + "_finetuned" if "finetuned" in st.session_state.model else "" + for _ in MODELS_ZOO.values() + if _ != "ensemble" + ] + else: + model = get_model(st.session_state.model) + prediction = ui_inference( + model, + image, + "cuda" if torch.cuda.is_available() else "cpu", + st.session_state.post_process, + st.session_state.image_id, + ) + prediction = [ + p + for p in prediction + if (p.fiber_type != "single") and p.fiber_type != "multiple" + ] + tab_viewer, tab_fibers = st.tabs(["Viewer", "Fibers"]) + + with tab_fibers: + df = show_fibers(prediction, image, st.session_state.image_id) + + event = st.dataframe( + df, + on_select="rerun", + selection_mode="multi-row", + use_container_width=True, + column_config={ + "Visualization": st.column_config.ImageColumn( + "Visualization", + help="Visualization of the fiber", + ) + }, + ) + + rows = event["selection"]["rows"] + columns = df.columns[:-2] + df = df.iloc[rows][columns] + + cols = st.columns(3) + with cols[0]: + copy_to_clipboard = st.button( + "Copy selected fibers to clipboard", + help="Copy the selected fibers to clipboard in CSV format.", + ) + if copy_to_clipboard: + df.to_clipboard(index=False) + with cols[2]: + st.download_button( + "Download selected fibers", + data=df.to_csv(index=False).encode("utf-8"), + file_name=f"fibers_{st.session_state.image_id}.csv", + mime="text/csv", + ) + + with tab_viewer: + max_width = 2048 + if image.shape[1] > max_width: + st.toast("Images are displayed at a lower resolution of 2048 pixel wide") + + fig = display_prediction(prediction, image, st.session_state.image_id) + streamlit_bokeh(fig, use_container_width=True) + + +def on_session_start(): + can_start = ( + st.session_state.get("files_uploaded", None) is not None + and len(st.session_state.files_uploaded) > 0 + ) + + if can_start: + return can_start + + 
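+    # Otherwise, check the separate per-analog (CldU/IdU) upload lists.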
+    cldu_exists = (
+        st.session_state.get("files_uploaded_cldu", None) is not None
+        and len(st.session_state.files_uploaded_cldu) > 0
+    )
+    idu_exists = (
+        st.session_state.get("files_uploaded_idu", None) is not None
+        and len(st.session_state.files_uploaded_idu) > 0
+    )
+
+    if cldu_exists and idu_exists:
+        if len(st.session_state.get("files_uploaded_cldu")) != len(
+            st.session_state.get("files_uploaded_idu")
+        ):
+            st.error("Please upload the same number of CldU and IdU files.")
+            return False
+
+    return False
+
+
+def create_display_files(files):
+    if files is None or len(files) == 0:
+        return "No files uploaded"
+    display_files = []
+    for file in files:
+        if isinstance(file, tuple):
+            if file[0] is None:
+                name = f"Second analog only {file[1].name}"
+            elif file[1] is None:
+                name = f"First analog only {file[0].name}"
+            else:
+                name = f"{file[0].name} and {file[1].name}"
+            display_files.append(name)
+        else:
+            display_files.append(file.name)
+    return display_files
+
+
+if on_session_start():
+    files = st.session_state.files_uploaded
+    displayed_names = create_display_files(files)
+    selected_file = st.selectbox(
+        "Pick an image",
+        displayed_names,
+        index=0,
+        help="Select an image to view and analyze.",
+    )
+
+    # Find the index of the selected file
+    index = displayed_names.index(selected_file)
+    file = files[index]
+    if isinstance(file, tuple):
+        file_id = file[0].file_id if file[0] is not None else file[1].file_id
+        if file[0] is None or file[1] is None:
+            missing = "First analog" if file[0] is None else "Second analog"
+            st.warning(
+                f"In this image, the {missing} channel is missing. We assume the intended goal is to segment the DNA fibers without differentiation. "
+                "Note that the model may still predict two classes and try to compute a ratio; this information can be ignored."
+ ) + image = get_multifile_image(file) + else: + file_id = file.file_id + image = get_image( + file, + reverse_channel=st.session_state.get("reverse_channels", False), + id=file_id, + ) + h, w = image.shape[:2] + with st.sidebar: + st.metric( + "Pixel size (µm)", + st.session_state.get("pixel_size", 0.13), + ) + + block_size = st.slider( + "Block size", + min_value=256, + max_value=min(4096, max(h, w)), + value=min(2048, max(h, w)), + step=256, + ) + if h < block_size: + block_size = h + if w < block_size: + block_size = w + + bx = by = block_size + image = pad_image_to_croppable(image, bx, by, file_id + str(bx) + str(by)) + thumbnail = get_resized_image(image, file_id) + + blocks = view_as_blocks(image, (bx, by, 3)) + x_blocks, y_blocks = blocks.shape[0], blocks.shape[1] + with st.sidebar: + with st.expander("Model", expanded=True): + model_name = st.selectbox( + "Select a model", + list(MODELS_ZOO.keys()), + index=0, + help="Select a model to use for inference", + ) + finetuned = st.checkbox( + "Use finetuned model", + value=True, + help="Use a finetuned model for inference", + ) + + col1, col2 = st.columns(2) + with col1: + st.write("Running on:") + with col2: + st.button( + "GPU" if torch.cuda.is_available() else "CPU", + disabled=True, + ) + + st.session_state.post_process = st.checkbox( + "Post-process", + value=True, + help="Apply post-processing to the prediction", + ) + + st.session_state.model = ( + (MODELS_ZOO[model_name] + "_finetuned") + if finetuned + else MODELS_ZOO[model_name] + ) + + which_y = st.session_state.get("which_y", 0) + which_x = st.session_state.get("which_x", 0) + + # Display the selected block + # Scale factor + h, w = image.shape[:2] + small_h, small_w = thumbnail.shape[:2] + scale_h = h / small_h + scale_w = w / small_w + # Calculate the coordinates of the block + y1 = math.floor(which_y * bx / scale_h) + y2 = math.floor((which_y + 1) * bx / scale_h) + x1 = math.floor(which_x * by / scale_w) + x2 = math.floor((which_x + 1) * by / scale_w) + # Draw a rectangle around the selected block + + # Check if the coordinates are within the bounds of the image + while y2 > small_h: + which_y -= 1 + y1 = math.floor(which_y * bx / scale_h) + y2 = math.floor((which_y + 1) * bx / scale_h) + while x2 > small_w: + which_x -= 1 + x1 = math.floor(which_x * by / scale_w) + x2 = math.floor((which_x + 1) * by / scale_w) + + st.session_state["which_x"] = which_x + st.session_state["which_y"] = which_y + + # Draw a grid on the thumbnail + for i in range(0, small_h, int(bx // scale_h)): + cv2.line(thumbnail, (0, i), (small_w, i), (255, 255, 255), 1) + for i in range(0, small_w, int(by // scale_w)): + cv2.line(thumbnail, (i, 0), (i, small_h), (255, 255, 255), 1) + + cv2.rectangle( + thumbnail, + (x1, y1), + (x2, y2), + (0, 0, 255), + 5, + ) + + st.write("### Select a block") + + coordinates = streamlit_image_coordinates.streamlit_image_coordinates( + thumbnail, use_column_width=True + ) + + if coordinates: + which_x = math.floor((w * coordinates["x"] / coordinates["width"]) / bx) + which_y = math.floor((h * coordinates["y"] / coordinates["height"]) / by) + if which_x != st.session_state.get("which_x", 0): + st.session_state["which_x"] = which_x + if which_y != st.session_state.get("which_y", 0): + st.session_state["which_y"] = which_y + + st.rerun() + + image = blocks[which_y, which_x, 0] + with st.sidebar: + st.image(image, caption="Selected block", use_container_width=True) + + st.session_state.image_inference = image + st.session_state.image_id = ( + file_id + + str(which_x) + + 
str(which_y) + + str(bx) + + str(by) + + str(model_name) + + ("_finetuned" if finetuned else "") + ) + col1, col2, col3 = st.columns([1, 1, 1]) + start_inference() +else: + st.switch_page("pages/1_Load.py") + +# Add a callback to mouse move event diff --git a/dnafiber/ui/pages/3_Analysis.py b/dnafiber/ui/pages/3_Analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..1ecb08cc73ac4f574b3838453495aa7eaa17e7a5 --- /dev/null +++ b/dnafiber/ui/pages/3_Analysis.py @@ -0,0 +1,241 @@ +import streamlit as st +import torch +from dnafiber.ui.utils import get_image, get_multifile_image +from dnafiber.deployment import MODELS_ZOO +import pandas as pd +import plotly.express as px +from dnafiber.postprocess import refine_segmentation +import torch.nn.functional as F +from joblib import Parallel, delayed +import time +from catppuccin import PALETTE +from dnafiber.deployment import _get_model +from dnafiber.ui.inference import ui_inference_cacheless + + +def plot_result(seleted_category=None): + if st.session_state.get("results", None) is None or selected_category is None: + return + only_bilateral = st.checkbox( + "Show only bicolor fibers", + value=False, + ) + remove_outliers = st.checkbox( + "Remove outliers", + value=True, + help="Remove outliers from the data", + ) + reorder = st.checkbox( + "Reorder groups by median ratio", + value=True, + ) + if remove_outliers: + min_ratio, max_ratio = st.slider( + "Ratio range", + min_value=0.0, + max_value=10.0, + value=(0.0, 5.0), + step=0.1, + help="Select the ratio range to display", + ) + df = st.session_state.results.copy() + + clean_df = df[["ratio", "image_name", "fiber_type"]].copy() + clean_df["Image"] = clean_df["image_name"] + clean_df["Fiber Type"] = clean_df["fiber_type"] + clean_df["Ratio"] = clean_df["ratio"] + + if only_bilateral: + clean_df = clean_df[clean_df["Fiber Type"] == "double"] + if remove_outliers: + clean_df = clean_df[ + (clean_df["Ratio"] >= min_ratio) & (clean_df["Ratio"] <= max_ratio) + ] + + if selected_category: + clean_df = clean_df[clean_df["Image"].isin(selected_category)] + + if not reorder: + clean_df["Image"] = pd.Categorical( + clean_df["Image"], categories=selected_category, ordered=True + ) + clean_df.sort_values("Image", inplace=True) + + if reorder: + image_order = ( + clean_df.groupby("Image")["Ratio"] + .median() + .sort_values(ascending=True) + .index + ) + clean_df["Image"] = pd.Categorical( + clean_df["Image"], categories=image_order, ordered=True + ) + clean_df.sort_values("Image", inplace=True) + + palette = [c.hex for c in PALETTE.latte.colors] + + fig = px.violin( + clean_df, + y="Ratio", + x="Image", + color="Image", + box=True, # draw box plot inside the violin + points="all", # can be 'outliers', or False + color_discrete_sequence=palette, + ) + # Set y-axis to log scale + st.plotly_chart( + fig, + use_container_width=True, + ) + + +def run_inference(model_name, pixel_size): + is_cuda_available = torch.cuda.is_available() + if "ensemble" in model_name: + model = [ + _ + "_finetuned" if "finetuned" in model_name else "" + for _ in MODELS_ZOO.values() + if _ != "ensemble" + ] + else: + model = _get_model( + revision=model_name, + device="cuda" if is_cuda_available else "cpu", + ) + + my_bar = st.progress(0, text="Running segmentation...") + all_files = st.session_state.files_uploaded + all_results = dict( + FirstAnalog=[], + SecondAnalog=[], + length=[], + ratio=[], + image_name=[], + fiber_type=[], + ) + for i, file in enumerate(all_files): + if isinstance(file, tuple): + if file[0] 
is None: + filename = file[1].name + if file[1] is None: + filename = file[0].name + image = get_multifile_image(file) + else: + filename = file.name + image = get_image( + file, st.session_state.get("reverse_channels", False), file.file_id + ) + start = time.time() + prediction = ui_inference_cacheless( + _model=model, + _image=image, + _device="cuda" if is_cuda_available else "cpu", + postprocess=False, + ) + print(f"Prediction time: {time.time() - start:.2f} seconds for {file.name}") + h, w = prediction.shape + start = time.time() + if h > 2048 or w > 2048: + # Extract blocks from the prediction + blocks = F.unfold( + torch.from_numpy(prediction).unsqueeze(0).float(), + kernel_size=(4096, 4096), + stride=(4096, 4096), + ) + blocks = blocks.view(4096, 4096, -1).permute(2, 0, 1).byte().numpy() + results = Parallel(n_jobs=4)( + delayed(refine_segmentation)(block) for block in blocks + ) + results = [x for xs in results for x in xs] + + else: + results = refine_segmentation(prediction, fix_junctions=True) + + print(f"Refinement time: {time.time() - start:.2f} seconds for {filename}") + results = [fiber for fiber in results if fiber.is_valid] + all_results["FirstAnalog"].extend([fiber.red * pixel_size for fiber in results]) + all_results["SecondAnalog"].extend( + [fiber.green * pixel_size for fiber in results] + ) + all_results["length"].extend( + [fiber.red * pixel_size + fiber.green * pixel_size for fiber in results] + ) + all_results["ratio"].extend([fiber.ratio for fiber in results]) + all_results["image_name"].extend([filename.split("-")[0] for fiber in results]) + all_results["fiber_type"].extend([fiber.fiber_type for fiber in results]) + + my_bar.progress(i / len(all_files), text=f"{filename} done") + + st.session_state.results = pd.DataFrame.from_dict(all_results) + + my_bar.empty() + + +if st.session_state.get("files_uploaded", None): + run_segmentation = st.button("Run Segmentation", use_container_width=True) + + with st.sidebar: + st.metric( + "Pixel size (µm)", + st.session_state.get("pixel_size", 0.13), + ) + + with st.expander("Model", expanded=True): + model_name = st.selectbox( + "Select a model", + list(MODELS_ZOO.keys()), + index=0, + help="Select a model to use for inference", + ) + finetuned = st.checkbox( + "Use finetuned model", + value=True, + help="Use a finetuned model for inference", + ) + col1, col2 = st.columns(2) + with col1: + st.write("Running on:") + with col2: + st.button( + "GPU" if torch.cuda.is_available() else "CPU", + disabled=True, + ) + + tab_segmentation, tab_charts = st.tabs(["Segmentation", "Charts"]) + + with tab_segmentation: + st.subheader("Segmentation") + if run_segmentation: + run_inference( + model_name=MODELS_ZOO[model_name] + "_finetuned" + if finetuned + else MODELS_ZOO[model_name], + pixel_size=st.session_state.get("pixel_size", 0.13), + ) + st.balloons() + if st.session_state.get("results", None) is not None: + st.write( + st.session_state.results, + ) + + st.download_button( + label="Download results", + data=st.session_state.results.to_csv(index=False).encode("utf-8"), + file_name="results.csv", + mime="text/csv", + use_container_width=True, + ) + with tab_charts: + if st.session_state.get("results", None) is not None: + results = st.session_state.results + + categories = results["image_name"].unique() + selected_category = st.multiselect( + "Select a category", categories, default=categories + ) + plot_result(selected_category) + +else: + st.switch_page("pages/1_Load.py") diff --git a/dnafiber/ui/streamlit/.config.toml 
b/dnafiber/ui/streamlit/.config.toml new file mode 100644 index 0000000000000000000000000000000000000000..c1f975db792882b84e3cade9e34039f0de8b29f2 --- /dev/null +++ b/dnafiber/ui/streamlit/.config.toml @@ -0,0 +1,3 @@ +[server] +headless = true +port = 8503 \ No newline at end of file diff --git a/dnafiber/ui/utils.py b/dnafiber/ui/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..558a25eeecbdc7feb6dcd5fb0633f9ab807613bb --- /dev/null +++ b/dnafiber/ui/utils.py @@ -0,0 +1,179 @@ +import PIL.Image +import streamlit as st +from dnafiber.data.utils import read_czi, read_tiff, preprocess +import cv2 +import numpy as np +import math +from dnafiber.deployment import _get_model +import PIL +from PIL import Image +import io +import base64 + +MAX_WIDTH = 512 +MAX_HEIGHT = 512 + + + +TYPE_MAPPING = { + 0: "BG", + 1: "SINGLE", + 2: "BILATERAL", + 3: "TRICOLOR", + 4: "MULTICOLOR", +} + +@st.cache_data +def load_image(_filepath, id=None): + filename = str(_filepath.name) + if filename.endswith(".czi"): + return read_czi(_filepath) + elif filename.endswith(".tif") or filename.endswith(".tiff"): + return read_tiff(_filepath) + elif ( + filename.endswith(".png") + or filename.endswith(".jpg") + or filename.endswith(".jpeg") + ): + image = PIL.Image.open(_filepath) + image = np.array(image) + return image + else: + raise NotImplementedError(f"File type {filename} is not supported yet") + + + +@st.cache_data +def get_image(_filepath, reverse_channel, id): + filename = str(_filepath.name) + image = load_image(_filepath, id) + if filename.endswith(".czi") or filename.endswith(".tif") or filename.endswith(".tiff"): + image = preprocess(image, reverse_channel) + image = cv2.normalize( + image, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U + ) + return image + + +def get_multifile_image(_filepaths): + result = None + + if _filepaths[0] is not None: + chan1 = get_image(_filepaths[0], False, _filepaths[0].file_id) + chan1 = cv2.cvtColor(chan1, cv2.COLOR_RGB2GRAY) + h, w = chan1.shape[:2] + else: + chan1 = None + if _filepaths[1] is not None: + chan2 = get_image(_filepaths[1], False, _filepaths[1].file_id) + chan2 = cv2.cvtColor(chan2, cv2.COLOR_RGB2GRAY) + h, w = chan2.shape[:2] + else: + chan2 = None + + result = np.zeros((h, w, 3), dtype=np.uint8) + + if chan1 is not None: + result[:, :, 0] = chan1 + else: + result[:, :, 0] = chan2 + + if chan2 is not None: + result[:, :, 1] = chan2 + else: + result[:, :, 1] = chan1 + + return result + + + +def numpy_to_base64_png(image_array): + """ + Encodes a NumPy image array to a base64 string (PNG format). + + Args: + image_array: A NumPy array representing the image. + + Returns: + A base64 string representing the PNG image. 
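+        (In practice the buffer is serialized as JPEG and the returned data
+        URI uses the matching image/jpeg MIME type.)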
+ """ + # Convert NumPy array to PIL Image + image = Image.fromarray(image_array) + + # Create an in-memory binary stream + buffer = io.BytesIO() + + # Save the image to the buffer in PNG format + image.save(buffer, format="jpeg") + + # Get the byte data from the buffer + png_data = buffer.getvalue() + + # Encode the byte data to base64 + base64_encoded = base64.b64encode(png_data).decode() + + return f"data:image/jpeg;base64,{base64_encoded}" + + +@st.cache_data +def get_resized_image(_image, id): + h, w = _image.shape[:2] + if w > MAX_WIDTH: + scale = MAX_WIDTH / w + new_size = (int(w * scale), int(h * scale)) + resized_image = cv2.resize(_image, new_size, interpolation=cv2.INTER_NEAREST) + else: + resized_image = _image + if h > MAX_HEIGHT: + scale = MAX_HEIGHT / h + new_size = (int(w * scale), int(h * scale)) + resized_image = cv2.resize( + resized_image, new_size, interpolation=cv2.INTER_NEAREST + ) + else: + resized_image = resized_image + return resized_image + + +def bokeh_imshow(fig, image): + # image is a numpy array of shape (h, w, 3) or (h, w) of type uint8 + if len(image.shape) == 2: + # grayscale image + image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) + + # Convert to h*w with uint32 + img = np.empty((image.shape[0], image.shape[1]), dtype=np.uint32) + view = img.view(dtype=np.uint8).reshape((image.shape[0], image.shape[1], 4)) # RGBA + view[:, :, 0] = image[:, :, 0] + view[:, :, 1] = image[:, :, 1] + view[:, :, 2] = image[:, :, 2] + view[:, :, 3] = 255 # Alpha channel + fig.image_rgba(image=[img], x=0, y=0, dw=image.shape[1], dh=image.shape[0]) + + +@st.cache_resource +def get_model(device, revision=None): + return _get_model(revision=revision, device=device) + + +def pad_image_to_croppable(_image, bx, by, uid=None): + # Pad the image to be divisible by bx and by + h, w = _image.shape[:2] + if h % bx != 0: + pad_h = bx - (h % bx) + else: + pad_h = 0 + if w % by != 0: + pad_w = by - (w % by) + else: + pad_w = 0 + _image = cv2.copyMakeBorder( + _image, + math.ceil(pad_h / 2), + math.floor(pad_h / 2), + math.ceil(pad_w / 2), + math.floor(pad_w / 2), + cv2.BORDER_CONSTANT, + value=(0, 0, 0), + ) + return _image diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..21695529aa181a6efca6c8bc7006769727bcd1fb --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,39 @@ +[tool.poetry] +name = "dnafiber" +version = "0.2.0" +description = "" +authors = ["ClementPla "] +readme = "README.md" + +[tool.poetry.dependencies] +python = "^3.10" +lightning= "^2.5.0.post0" +monai= "^1.4.0" +segmentation_models_pytorch="^0.4.0" +huggingface_hub="^0.25.2" +czifile="~2019.7.2" +tifffile="~2024.8.10" +streamlit = "1.43.2" +plotly = "6.0.1" +pandas = "2.2.3" +scipy = "1.15.2" +numpy = "^1.22.4" +matplotlib = "^3.10.3" +bokeh = "^3.0.0" +streamlit_bokeh = "^3.0.0" +torchmetrics = "1.6.3" +albumentations = "2.0.3" +streamlit_image_coordinates = "0.2.1" +catppuccin = "2.4.1" +scikit-image = "0.25.2" +kornia = "0.7.3" +numba = "0.60.0" +joblib = "1.4.2" + + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry.scripts] +DNAI = "dnafiber.start:main" \ No newline at end of file