import math

import cv2
import numpy as np
import pandas as pd
import streamlit as st
import streamlit_image_coordinates
import torch
from bokeh.layouts import gridplot
from bokeh.models import (
    HoverTool,
    LinearColorMapper,
    Range1d,
)
from bokeh.plotting import figure
from catppuccin import PALETTE
from skimage.segmentation import expand_labels
from skimage.util import view_as_blocks
from streamlit_bokeh import streamlit_bokeh

from dnafiber.deployment import MODELS_ZOO
from dnafiber.ui.inference import get_model, ui_inference
from dnafiber.ui.utils import (
    bokeh_imshow,
    get_image,
    get_multifile_image,
    get_resized_image,
    numpy_to_base64_png,
    pad_image_to_croppable,
)
# Page chrome: wide layout with a microscope favicon, then the page heading.
st.set_page_config(page_icon=":microscope:", layout="wide")
st.title("Viewer")
# Widest image (in pixels) drawn in the Bokeh viewer; wider inputs are
# downscaled for display only.
MAX_DISPLAY_WIDTH = 2048
# Fallback pixel size in µm when the session has none; the sidebar metric
# shows the same default.
DEFAULT_PIXEL_SIZE = 0.13
# Fiber types dropped from the prediction before anything is displayed.
EXCLUDED_FIBER_TYPES = ("single", "multiple")


def _session_pixel_size():
    """Return the session's pixel size in µm, or ``DEFAULT_PIXEL_SIZE``."""
    return st.session_state.get("pixel_size", DEFAULT_PIXEL_SIZE)


def _safe_ratio(first_length, second_length):
    """Return ``second_length / first_length``, or NaN if ``first_length`` is 0.

    Without the guard, a fiber with no first-analog pixels raises
    ZeroDivisionError while the page is being rendered.
    """
    if first_length == 0:
        return float("nan")
    return second_length / first_length


def _build_labels_map(prediction, shape, scale):
    """Rasterize every region's labels into one uint8 map of ``shape``.

    Each region's ``data`` is dilated by one pixel, resized by ``scale`` and
    pasted at its scaled coordinates. Only non-zero pixels are written, so the
    empty corners of one bounding box cannot erase an overlapping neighbour's
    labels. Patches are clipped to the canvas, so rounding in the scaled
    coordinates cannot cause a shape-mismatch error.
    """
    height, width = shape
    labels_map = np.zeros((height, width), dtype=np.uint8)
    for region in prediction:
        x, y, _, _ = region.scaled_coordinates(scale)
        patch = cv2.resize(
            expand_labels(region.data, 1),
            None,
            fx=scale,
            fy=scale,
            interpolation=cv2.INTER_NEAREST,
        )
        patch = patch[: max(0, height - y), : max(0, width - x)]
        if patch.size == 0:
            continue
        target = labels_map[y : y + patch.shape[0], x : x + patch.shape[1]]
        foreground = patch > 0
        target[foreground] = patch[foreground]
    return labels_map


@st.cache_resource
def display_prediction(
    _prediction, _image, image_id=None, pixel_size=None, label_colors=None
):
    """Build the linked Bokeh viewer for one predicted image block.

    Args:
        _prediction: Fiber regions returned by ``ui_inference``. The leading
            underscore keeps Streamlit from hashing it.
        _image: Image the prediction was computed on (also not hashed).
        image_id: Cache key identifying the image block / model combination.
        pixel_size: Pixel size in µm; defaults to the session value. Pass it
            explicitly so that changing it invalidates the cached figure.
        label_colors: ``(color1, color2)`` used for label values 1 and 2;
            defaults to ``st.session_state["color1"]`` / ``["color2"]``.

    Returns:
        A Bokeh gridplot with the image on the left and the label map on the
        right. Both panels share ranges and a hover tool over each fiber's
        bounding box.
    """
    if pixel_size is None:
        pixel_size = _session_pixel_size()
    if label_colors is None:
        label_colors = (st.session_state["color1"], st.session_state["color2"])

    image = _image
    # NOTE(review): the > 25 test presumably skips stretching near-black
    # images — confirm the intended threshold.
    if image.max() > 25:
        image = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)

    # Downscale wide images for display; `scale` maps full-res -> display.
    scale = 1
    if image.shape[1] > MAX_DISPLAY_WIDTH:
        scale = MAX_DISPLAY_WIDTH / image.shape[1]
        image = cv2.resize(
            image, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR
        )
    height, width = image.shape[:2]
    labels_map = _build_labels_map(_prediction, (height, width), scale)

    tools = "pan,wheel_zoom,box_zoom,reset"
    p1 = figure(
        width=600,
        x_range=Range1d(-width / 8, width * 1.125, bounds="auto"),
        y_range=Range1d(height * 1.125, -height / 8, bounds="auto"),
        title=f"Detected fibers: {len(_prediction)}",
        tools=tools,
        active_scroll="wheel_zoom",
    )
    # Pin low/high so label 1 is always color1 and label 2 always color2.
    # An auto-ranged mapper stretches the palette over whichever values are
    # present, which paints label 1 in color2 when label 2 is absent.
    p1.image(
        image=[labels_map],
        x=0,
        y=0,
        dw=width,
        dh=height,
        color_mapper=LinearColorMapper(
            palette=["black", *label_colors], low=0, high=2
        ),
    )
    p2 = figure(
        x_range=p1.x_range,
        y_range=p1.y_range,
        width=600,
        tools=tools,
        active_scroll="wheel_zoom",
    )
    bokeh_imshow(p2, image)

    colors = [c.hex for c in PALETTE.latte.colors][:14]
    np.random.shuffle(colors)
    source = {
        key: []
        for key in (
            "x",
            "y",
            "width",
            "height",
            "color",
            "firstAnalog",
            "secondAnalog",
            "ratio",
            "fiber_id",
        )
    }
    for i, region in enumerate(_prediction):
        x, y, w, h = region.scaled_coordinates(scale)
        # region.data is full-resolution (it is rescaled by `scale` when
        # rasterized), so counts are presumably full-resolution pixel counts.
        # Dividing by the display scale would inflate the lengths and
        # disagree with the Fibers table.
        r, g = region.counts
        red_length = pixel_size * r
        green_length = pixel_size * g
        source["x"].append(x + w / 2)
        source["y"].append(y + h / 2)
        source["width"].append(w)
        source["height"].append(h)
        source["color"].append(colors[i % len(colors)])
        source["firstAnalog"].append(f"{red_length:.2f} µm")
        source["secondAnalog"].append(f"{green_length:.2f} µm")
        source["ratio"].append(f"{_safe_ratio(red_length, green_length):.2f}")
        source["fiber_id"].append(region.fiber_id)

    rect1 = p1.rect(
        x="x",
        y="y",
        width="width",
        height="height",
        source=source,
        fill_color=None,
        line_color="color",
    )
    rect2 = p2.rect(
        x="x",
        y="y",
        width="width",
        height="height",
        source=source,
        fill_color=None,
        line_color="color",
    )
    # A string tooltip is an HTML template, so lines are separated with <br>.
    hover = HoverTool(
        tooltips="Fiber ID: @fiber_id<br>@firstAnalog<br>@secondAnalog<br>Ratio: @ratio",
        renderers=[rect1, rect2],
        point_policy="follow_mouse",
        attachment="vertical",
    )
    p1.add_tools(hover)
    p2.add_tools(hover)
    p1.axis.visible = False
    p2.axis.visible = False

    return gridplot(
        [[p2, p1]],
        merge_tools=True,
        sizing_mode="stretch_width",
        toolbar_options=dict(logo=None, help=None),
    )


@st.cache_data
def show_fibers(_prediction, _image, image_id=None, pixel_size=None):
    """Tabulate the detected fibers with their lengths, ratio and a crop.

    Args:
        _prediction: Fiber objects exposing ``fiber_id``, ``counts``,
            ``fiber_type`` and ``bbox`` (not hashed by Streamlit).
        _image: Image the prediction was computed on (not hashed).
        image_id: Cache key identifying the image block / model combination.
        pixel_size: Pixel size in µm; defaults to the session value. It is
            part of the cache key, so changing it refreshes the table.

    Returns:
        DataFrame with the columns "Fiber ID", "First analog (µm)",
        "Second analog (µm)", "Ratio", "Fiber type" and "Visualization" (a
        base64 PNG of the min-max-normalised bounding-box crop).
    """
    if pixel_size is None:
        pixel_size = _session_pixel_size()
    columns = [
        "Fiber ID",
        "First analog (µm)",
        "Second analog (µm)",
        "Ratio",
        "Fiber type",
        "Visualization",
    ]
    rows = []
    for fiber in _prediction:
        r, g = fiber.counts
        red_length = pixel_size * r
        green_length = pixel_size * g
        x, y, w, h = fiber.bbox
        crop = cv2.normalize(
            _image[y : y + h, x : x + w, :],
            None,
            0,
            255,
            cv2.NORM_MINMAX,
            dtype=cv2.CV_8U,
        )
        rows.append(
            [
                fiber.fiber_id,
                f"{red_length:.3f} ",
                f"{green_length:.3f} ",
                f"{_safe_ratio(red_length, green_length):.3f}",
                fiber.fiber_type,
                numpy_to_base64_png(crop),
            ]
        )
    # Passing `columns` keeps the header even when there are no fibers.
    return pd.DataFrame(rows, columns=columns)


def _resolve_model(model_key):
    """Return what ``ui_inference`` expects for ``model_key``.

    For an ensemble key, this is the list of every non-ensemble model name in
    MODELS_ZOO, with "_finetuned" appended when the key is finetuned.
    Otherwise it is the loaded model from ``get_model``.
    """
    if "ensemble" not in model_key:
        return get_model(model_key)
    # The suffix must be parenthesised on its own. Without that, the old
    # conditional swallowed the whole name and produced "" for every model.
    suffix = "_finetuned" if "finetuned" in model_key else ""
    return [name + suffix for name in MODELS_ZOO.values() if name != "ensemble"]


def _render_fiber_table(prediction, image, image_id, pixel_size):
    """Show the fiber table, plus copy/download of the selected rows as CSV."""
    df = show_fibers(prediction, image, image_id, pixel_size=pixel_size)
    event = st.dataframe(
        df,
        on_select="rerun",
        selection_mode="multi-row",
        use_container_width=True,
        column_config={
            "Visualization": st.column_config.ImageColumn(
                "Visualization",
                help="Visualization of the fiber",
            )
        },
    )
    # Drops the last two columns ("Fiber type" and "Visualization") from the
    # export. NOTE(review): confirm "Fiber type" is meant to be excluded.
    selected = df.iloc[event["selection"]["rows"]][df.columns[:-2]]

    cols = st.columns(3)
    with cols[0]:
        copy_to_clipboard = st.button(
            "Copy selected fibers to clipboard",
            help="Copy the selected fibers to clipboard in CSV format.",
        )
        if copy_to_clipboard:
            # to_clipboard writes to the clipboard of the machine running
            # Streamlit, which usually does not exist on a headless server.
            try:
                selected.to_clipboard(index=False)
            except pd.errors.PyperclipException:
                st.error(
                    "Clipboard is not available where this app is running; "
                    "use the download button instead."
                )
    with cols[2]:
        st.download_button(
            "Download selected fibers",
            data=selected.to_csv(index=False).encode("utf-8"),
            file_name=f"fibers_{st.session_state.image_id}.csv",
            mime="text/csv",
        )


def start_inference():
    """Segment the selected block and render the Viewer and Fibers tabs.

    Reads ``image_inference``, ``model``, ``post_process`` and ``image_id``
    from ``st.session_state``. Fibers whose type is in
    ``EXCLUDED_FIBER_TYPES`` are discarded before display.
    """
    image = cv2.normalize(
        st.session_state.image_inference,
        None,
        0,
        255,
        cv2.NORM_MINMAX,
        dtype=cv2.CV_8U,
    )
    image_id = st.session_state.image_id
    prediction = ui_inference(
        _resolve_model(st.session_state.model),
        image,
        "cuda" if torch.cuda.is_available() else "cpu",
        st.session_state.post_process,
        image_id,
    )
    prediction = [p for p in prediction if p.fiber_type not in EXCLUDED_FIBER_TYPES]
    pixel_size = _session_pixel_size()

    tab_viewer, tab_fibers = st.tabs(["Viewer", "Fibers"])
    with tab_fibers:
        _render_fiber_table(prediction, image, image_id, pixel_size)
    with tab_viewer:
        if image.shape[1] > MAX_DISPLAY_WIDTH:
            st.toast(
                f"Images are displayed at a lower resolution of {MAX_DISPLAY_WIDTH} pixel wide"
            )
        fig = display_prediction(
            prediction,
            image,
            image_id,
            pixel_size=pixel_size,
            label_colors=(st.session_state["color1"], st.session_state["color2"]),
        )
        streamlit_bokeh(fig, use_container_width=True)


def on_session_start():
    """Return True when ``files_uploaded`` holds at least one image.

    Otherwise returns False. If CldU and IdU uploads are both present but
    differ in count, an error message is shown first.
    """
    uploaded = st.session_state.get("files_uploaded", None)
    if uploaded is not None and len(uploaded) > 0:
        return True

    cldu = st.session_state.get("files_uploaded_cldu", None)
    idu = st.session_state.get("files_uploaded_idu", None)
    cldu_exists = cldu is not None and len(cldu) > 0
    idu_exists = idu is not None and len(idu) > 0
    if cldu_exists and idu_exists and len(cldu) != len(idu):
        st.error("Please upload the same number of CldU and IdU files.")
    return False


def create_display_files(files):
    """Return one display name per uploaded entry.

    Tuple entries are (first analog, second analog) pairs where either side
    may be None. Returns the string "No files uploaded" when ``files`` is
    None or empty.
    """
    if files is None or len(files) == 0:
        return "No files uploaded"
    display_files = []
    for file in files:
        if not isinstance(file, tuple):
            display_files.append(file.name)
        elif file[0] is None:
            display_files.append(f"Second analog only {file[1].name}")
        elif file[1] is None:
            display_files.append(f"First analog only {file[0].name}")
        else:
            display_files.append(f"{file[0].name} and {file[1].name}")
    return display_files


if on_session_start():
    files = st.session_state.files_uploaded
    displayed_names = create_display_files(files)
    # Select by position rather than by name: two uploads may share a display
    # name, and list.index() would always resolve to the first of them.
    index = st.selectbox(
        "Pick an image",
        range(len(displayed_names)),
        index=0,
        format_func=lambda i: displayed_names[i],
        help="Select an image to view and analyze.",
    )
    file = files[index]

    if isinstance(file, tuple):
        first, second = file[0], file[1]
        file_id = (first if first is not None else second).file_id
        if first is None or second is None:
            missing = "First analog" if first is None else "Second analog"
            st.warning(
                f"In this image, {missing} channel is missing. We assume the "
                "intended goal is to segment the DNA fibers without "
                "differentiation. Note the model may still predict two classes "
                "and try to compute a ratio; these informations can be ignored."
            )
        image = get_multifile_image(file)
    else:
        file_id = file.file_id
        image = get_image(
            file,
            reverse_channel=st.session_state.get("reverse_channels", False),
            id=file_id,
        )

    h, w = image.shape[:2]
    with st.sidebar:
        st.metric("Pixel size (µm)", _session_pixel_size())
        block_size = st.slider(
            "Block size",
            min_value=256,
            max_value=min(4096, max(h, w)),
            value=min(2048, max(h, w)),
            step=256,
        )
    # Blocks are square and may not exceed either image dimension.
    bx = by = min(block_size, h, w)
    block_key = file_id + str(bx) + str(by)
    image = pad_image_to_croppable(image, bx, by, block_key)
    # The padding depends on the block size, so the thumbnail key does too.
    # Copy the thumbnail so drawing the grid below never mutates an array
    # that a cache may still hold.
    thumbnail = get_resized_image(image, block_key).copy()
    blocks = view_as_blocks(image, (bx, by, 3))
    n_rows, n_cols = blocks.shape[:2]

    with st.sidebar:
        with st.expander("Model", expanded=True):
            model_name = st.selectbox(
                "Select a model",
                list(MODELS_ZOO.keys()),
                index=0,
                help="Select a model to use for inference",
            )
            finetuned = st.checkbox(
                "Use finetuned model",
                value=True,
                help="Use a finetuned model for inference",
            )
            col1, col2 = st.columns(2)
            with col1:
                st.write("Running on:")
            with col2:
                st.button(
                    "GPU" if torch.cuda.is_available() else "CPU",
                    disabled=True,
                )
            st.session_state.post_process = st.checkbox(
                "Post-process",
                value=True,
                help="Apply post-processing to the prediction",
            )
    st.session_state.model = (
        (MODELS_ZOO[model_name] + "_finetuned")
        if finetuned
        else MODELS_ZOO[model_name]
    )

    # Restore the previously selected block. Clamp it to the current grid,
    # which shrinks when the block size grows or another image is picked.
    which_y = min(max(st.session_state.get("which_y", 0), 0), n_rows - 1)
    which_x = min(max(st.session_state.get("which_x", 0), 0), n_cols - 1)
    st.session_state["which_x"] = which_x
    st.session_state["which_y"] = which_y

    # Ratio between the padded image and its thumbnail along each axis.
    h, w = image.shape[:2]
    small_h, small_w = thumbnail.shape[:2]
    scale_h = h / small_h
    scale_w = w / small_w

    # Selected block outline in thumbnail pixels. Rows span bx and columns
    # span by pixels of the padded image.
    y1 = math.floor(which_y * bx / scale_h)
    y2 = math.floor((which_y + 1) * bx / scale_h)
    x1 = math.floor(which_x * by / scale_w)
    x2 = math.floor((which_x + 1) * by / scale_w)

    # Block grid; max(1, ...) avoids a zero range step on tiny thumbnails.
    for row in range(0, small_h, max(1, int(bx // scale_h))):
        cv2.line(thumbnail, (0, row), (small_w, row), (255, 255, 255), 1)
    for col in range(0, small_w, max(1, int(by // scale_w))):
        cv2.line(thumbnail, (col, 0), (col, small_h), (255, 255, 255), 1)
    cv2.rectangle(thumbnail, (x1, y1), (x2, y2), (0, 0, 255), 5)

    st.write("### Select a block")
    coordinates = streamlit_image_coordinates.streamlit_image_coordinates(
        thumbnail, use_column_width=True
    )
    if coordinates:
        clicked_x = math.floor((w * coordinates["x"] / coordinates["width"]) / by)
        clicked_y = math.floor((h * coordinates["y"] / coordinates["height"]) / bx)
        # Clamp: a click on the right/bottom edge maps to one past the last
        # block. Without the clamp, that index would fail on `blocks` or
        # trigger an endless rerun loop.
        clicked_x = min(max(clicked_x, 0), n_cols - 1)
        clicked_y = min(max(clicked_y, 0), n_rows - 1)
        if (clicked_x, clicked_y) != (which_x, which_y):
            st.session_state["which_x"] = clicked_x
            st.session_state["which_y"] = clicked_y
            st.rerun()

    image = blocks[which_y, which_x, 0]
    with st.sidebar:
        st.image(image, caption="Selected block", use_container_width=True)
    st.session_state.image_inference = image
    # Separators keep ids unambiguous (e.g. x=1,y=12 vs x=11,y=2). The
    # post-process flag is included because the cached viewer and table are
    # keyed on this id.
    st.session_state.image_id = (
        f"{file_id}_{which_x}_{which_y}_{bx}_{by}_{model_name}"
        + ("_finetuned" if finetuned else "")
        + ("_pp" if st.session_state.post_process else "_raw")
    )
    start_inference()
else:
    st.switch_page("pages/1_Load.py")

# TODO: add a callback for mouse-move events on the thumbnail.