import streamlit as st
from bokeh.plotting import figure
from bokeh.layouts import gridplot
from streamlit_bokeh import streamlit_bokeh
from dnafiber.ui.utils import (
    get_image,
    get_multifile_image,
    get_resized_image,
    bokeh_imshow,
    pad_image_to_croppable,
    numpy_to_base64_png,
)
from dnafiber.deployment import MODELS_ZOO
from dnafiber.ui.inference import ui_inference, get_model
from skimage.util import view_as_blocks
import cv2
import math
from bokeh.models import (
    Range1d,
    HoverTool,
)
import streamlit_image_coordinates
from catppuccin import PALETTE
import numpy as np
import torch
from skimage.segmentation import expand_labels
import pandas as pd

st.set_page_config(
    layout="wide",
    page_icon=":microscope:",
)
st.title("Viewer")


@st.cache_resource
def display_prediction(_prediction, _image, image_id=None):
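    """Render the image and its segmentation as two linked Bokeh panels.

    The left panel shows the image, the right panel the predicted label map;
    both share axis ranges and a hover tool listing per-fiber analog lengths
    and ratio. Arguments prefixed with an underscore are not hashed by
    Streamlit; ``image_id`` serves as the ``st.cache_resource`` cache key.
    """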
    max_width = 2048
    image = _image
    if image.max() > 25:
        image = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
    scale = 1

    if image.shape[1] > max_width:
        scale = max_width / image.shape[1]
        image = cv2.resize(
            image,
            None,
            fx=scale,
            fy=scale,
            interpolation=cv2.INTER_LINEAR,
        )

    h, w = image.shape[:2]
    labels_maps = np.zeros((h, w), dtype=np.uint8)
    for region in _prediction:
        x, y, _, _ = region.scaled_coordinates(scale)
        data = cv2.resize(
            expand_labels(region.data, 1),
            None,
            fx=scale,
            fy=scale,
            interpolation=cv2.INTER_NEAREST,
        )
        labels_maps[
            y : y + data.shape[0],
            x : x + data.shape[1],
        ] = data
    p1 = figure(
        width=600,
        x_range=Range1d(-image.shape[1] / 8, image.shape[1] * 1.125, bounds="auto"),
        y_range=Range1d(image.shape[0] * 1.125, -image.shape[0] / 8, bounds="auto"),
        title=f"Detected fibers: {len(_prediction)}",
        tools="pan,wheel_zoom,box_zoom,reset",
        active_scroll="wheel_zoom",
    )

    p1.image(
        image=[labels_maps],
        x=0,
        y=0,
        dw=labels_maps.shape[1],
        dh=labels_maps.shape[0],
        palette=["black", st.session_state["color1"], st.session_state["color2"]]
        if np.max(labels_maps) > 0
        else ["black"],
    )
    p2 = figure(
        x_range=p1.x_range,
        y_range=p1.y_range,
        width=600,
        tools="pan,wheel_zoom,box_zoom,reset",
        active_scroll="wheel_zoom",
    )
    bokeh_imshow(p2, image)
    colors = [c.hex for c in PALETTE.latte.colors][:14]
    data_source = dict(
        x=[],
        y=[],
        width=[],
        height=[],
        color=[],
        firstAnalog=[],
        secondAnalog=[],
        ratio=[],
        fiber_id=[],
    )
    np.random.shuffle(colors)
    for i, region in enumerate(_prediction):
        color = colors[i % len(colors)]
        x, y, w, h = region.scaled_coordinates(scale)

        fiberId = region.fiber_id
        data_source["x"].append(x + w / 2)
        data_source["y"].append(y + h / 2)
        data_source["width"].append(w)
        data_source["height"].append(h)
        data_source["color"].append(color)
        r, g = region.counts
        red_length = st.session_state["pixel_size"] * r / scale
        green_length = st.session_state["pixel_size"] * g / scale
        data_source["firstAnalog"].append(f"{red_length:.2f} µm")
        data_source["secondAnalog"].append(f"{green_length:.2f} µm")
        # Guard against a zero-length first analog to avoid a ZeroDivisionError.
        data_source["ratio"].append(
            f"{green_length / red_length:.2f}" if red_length > 0 else "n/a"
        )
        data_source["fiber_id"].append(fiberId)

    rect1 = p1.rect(
        x="x",
        y="y",
        width="width",
        height="height",
        source=data_source,
        fill_color=None,
        line_color="color",
    )
    rect2 = p2.rect(
        x="x",
        y="y",
        width="width",
        height="height",
        source=data_source,
        fill_color=None,
        line_color="color",
    )

    hover = HoverTool(
        tooltips=f'<b>Fiber ID: @fiber_id</b><br><p style="color:{st.session_state["color1"]};">@firstAnalog</p> <p style="color:{st.session_state["color2"]};">@secondAnalog</p><b> Ratio: @ratio</b>',
    )
    hover.renderers = [rect1, rect2]
    hover.point_policy = "follow_mouse"
    hover.attachment = "vertical"
    p1.add_tools(hover)
    p2.add_tools(hover)

    p1.axis.visible = False
    p2.axis.visible = False
    fig = gridplot(
        [[p2, p1]],
        merge_tools=True,
        sizing_mode="stretch_width",
        toolbar_options=dict(logo=None, help=None),
    )
    return fig


@st.cache_data
def show_fibers(_prediction, _image, image_id=None):
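    """Build a DataFrame summarizing each detected fiber.

    Columns hold the analog lengths in µm, their ratio, the fiber type, and a
    base64-encoded PNG crop for Streamlit's ImageColumn. ``image_id`` is the
    ``st.cache_data`` key; the underscore-prefixed arguments are not hashed.
    """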
    data = dict(
        fiber_id=[],
        firstAnalog=[],
        secondAnalog=[],
        ratio=[],
        fiber_type=[],
        visualization=[],
    )

    for fiber in _prediction:
        data["fiber_id"].append(fiber.fiber_id)
        r, g = fiber.counts
        red_length = st.session_state["pixel_size"] * r
        green_length = st.session_state["pixel_size"] * g
        data["firstAnalog"].append(f"{red_length:.3f} ")
        data["secondAnalog"].append(f"{green_length:.3f} ")
        # Guard against a zero-length first analog to avoid a ZeroDivisionError.
        data["ratio"].append(
            f"{green_length / red_length:.3f}" if red_length > 0 else "n/a"
        )
        data["fiber_type"].append(fiber.fiber_type)

        x, y, w, h = fiber.bbox

        visu = _image[y : y + h, x : x + w, :]
        visu = cv2.normalize(visu, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
        data["visualization"].append(visu)

    df = pd.DataFrame(data)
    df = df.rename(
        columns={
            "firstAnalog": "First analog (µm)",
            "secondAnalog": "Second analog (µm)",
            "ratio": "Ratio",
            "fiber_type": "Fiber type",
            "fiber_id": "Fiber ID",
            "visualization": "Visualization",
        }
    )
    df["Visualization"] = df["Visualization"].apply(numpy_to_base64_png)
    return df


def start_inference():
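    """Run inference on the currently selected block and render the results.

    Reads the block, model name, and post-processing flag from
    ``st.session_state``, filters out fibers classified as "single" or
    "multiple", then shows a Fibers table and a Viewer tab.
    """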
    image = st.session_state.image_inference
    image = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)

    if "ensemble" in st.session_state.model:
        # Expand the ensemble into the individual model names, appending the
        # "_finetuned" suffix to each one when a finetuned variant is requested.
        suffix = "_finetuned" if "finetuned" in st.session_state.model else ""
        model = [
            name + suffix
            for name in MODELS_ZOO.values()
            if name != "ensemble"
        ]
    else:
        model = get_model(st.session_state.model)
    prediction = ui_inference(
        model,
        image,
        "cuda" if torch.cuda.is_available() else "cpu",
        st.session_state.post_process,
        st.session_state.image_id,
    )
    prediction = [p for p in prediction if p.fiber_type not in ("single", "multiple")]
    tab_viewer, tab_fibers = st.tabs(["Viewer", "Fibers"])

    with tab_fibers:
        df = show_fibers(prediction, image, st.session_state.image_id)

        event = st.dataframe(
            df,
            on_select="rerun",
            selection_mode="multi-row",
            use_container_width=True,
            column_config={
                "Visualization": st.column_config.ImageColumn(
                    "Visualization",
                    help="Visualization of the fiber",
                )
            },
        )

        rows = event["selection"]["rows"]
        # Drop the last two columns ("Fiber type" and "Visualization") from the export.
        columns = df.columns[:-2]
        df = df.iloc[rows][columns]

        cols = st.columns(3)
        with cols[0]:
            copy_to_clipboard = st.button(
                "Copy selected fibers to clipboard",
                help="Copy the selected fibers to clipboard in CSV format.",
            )
            if copy_to_clipboard:
                # Note: pandas writes to the clipboard of the machine running
                # the Streamlit server, not the browser.
                df.to_clipboard(index=False)
        with cols[2]:
            st.download_button(
                "Download selected fibers",
                data=df.to_csv(index=False).encode("utf-8"),
                file_name=f"fibers_{st.session_state.image_id}.csv",
                mime="text/csv",
            )

    with tab_viewer:
        max_width = 2048
        if image.shape[1] > max_width:
            st.toast("Images are displayed at a lower resolution of 2048 pixels wide")

        fig = display_prediction(prediction, image, st.session_state.image_id)
        streamlit_bokeh(fig, use_container_width=True)


def on_session_start():
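    """Check whether the session has uploaded files to work on.

    Returns True when ``files_uploaded`` is non-empty. When separate CldU and
    IdU uploads are present, their counts must match; a mismatch shows an
    error and returns False.
    """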
    can_start = (
        st.session_state.get("files_uploaded", None) is not None
        and len(st.session_state.files_uploaded) > 0
    )

    if can_start:
        return can_start

    cldu_exists = (
        st.session_state.get("files_uploaded_cldu", None) is not None
        and len(st.session_state.files_uploaded_cldu) > 0
    )
    idu_exists = (
        st.session_state.get("files_uploaded_idu", None) is not None
        and len(st.session_state.files_uploaded_idu) > 0
    )

    if cldu_exists and idu_exists:
        if len(st.session_state.get("files_uploaded_cldu")) != len(
            st.session_state.get("files_uploaded_idu")
        ):
            st.error("Please upload the same number of CldU and IdU files.")
            return False

    return False


def create_display_files(files):
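    """Build human-readable labels for the upload selectbox.

    Single uploads use the file name; (first, second) channel tuples are
    labeled with both names, or flagged when one channel is missing.
    """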
    if files is None or len(files) == 0:
        return "No files uploaded"
    display_files = []
    for file in files:
        if isinstance(file, tuple):
            if file[0] is None:
                name = f"Second analog only {file[1].name}"
            elif file[1] is None:
                name = f"First analog only {file[0].name}"
            else:
                name = f"{file[0].name} and {file[1].name}"
            display_files.append(name)
        else:
            display_files.append(file.name)
    return display_files


if on_session_start():
    files = st.session_state.files_uploaded
    displayed_names = create_display_files(files)
    selected_file = st.selectbox(
        "Pick an image",
        displayed_names,
        index=0,
        help="Select an image to view and analyze.",
    )

    index = displayed_names.index(selected_file)
    file = files[index]
    if isinstance(file, tuple):
        file_id = file[0].file_id if file[0] is not None else file[1].file_id
        if file[0] is None or file[1] is None:
            missing = "First analog" if file[0] is None else "Second analog"
            st.warning(
                f"In this image, the {missing} channel is missing. We assume the "
                "intended goal is to segment the DNA fibers without differentiation. "
                "Note that the model may still predict two classes and try to "
                "compute a ratio; this information can be ignored."
            )
        image = get_multifile_image(file)
    else:
        file_id = file.file_id
        image = get_image(
            file,
            reverse_channel=st.session_state.get("reverse_channels", False),
            id=file_id,
        )
    h, w = image.shape[:2]
    with st.sidebar:
        st.metric(
            "Pixel size (µm)",
            st.session_state.get("pixel_size", 0.13),
        )

        block_size = st.slider(
            "Block size",
            min_value=256,
            max_value=min(4096, max(h, w)),
            value=min(2048, max(h, w)),
            step=256,
        )
    # Clamp the block size so it never exceeds the image itself.
    if h < block_size:
        block_size = h
    if w < block_size:
        block_size = w

    bx = by = block_size
    image = pad_image_to_croppable(image, bx, by, file_id + str(bx) + str(by))
    thumbnail = get_resized_image(image, file_id)

    # Tile the padded image into non-overlapping (bx, by, 3) blocks.
    blocks = view_as_blocks(image, (bx, by, 3))
    x_blocks, y_blocks = blocks.shape[0], blocks.shape[1]
    with st.sidebar:
        with st.expander("Model", expanded=True):
            model_name = st.selectbox(
                "Select a model",
                list(MODELS_ZOO.keys()),
                index=0,
                help="Select a model to use for inference",
            )
            finetuned = st.checkbox(
                "Use finetuned model",
                value=True,
                help="Use a finetuned model for inference",
            )

            col1, col2 = st.columns(2)
            with col1:
                st.write("Running on:")
            with col2:
                st.button(
                    "GPU" if torch.cuda.is_available() else "CPU",
                    disabled=True,
                )

            st.session_state.post_process = st.checkbox(
                "Post-process",
                value=True,
                help="Apply post-processing to the prediction",
            )

    st.session_state.model = (
        (MODELS_ZOO[model_name] + "_finetuned")
        if finetuned
        else MODELS_ZOO[model_name]
    )

    which_y = st.session_state.get("which_y", 0)
    which_x = st.session_state.get("which_x", 0)
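
    # Map the selected block's bounds from full-image coordinates onto the
    # thumbnail. bx/by are the block height/width in full-resolution pixels;
    # scale_h/scale_w convert them to thumbnail pixels.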
    h, w = image.shape[:2]
    small_h, small_w = thumbnail.shape[:2]
    scale_h = h / small_h
    scale_w = w / small_w

    y1 = math.floor(which_y * bx / scale_h)
    y2 = math.floor((which_y + 1) * bx / scale_h)
    x1 = math.floor(which_x * by / scale_w)
    x2 = math.floor((which_x + 1) * by / scale_w)

    # If a previously selected block falls outside the current grid (e.g. the
    # block size changed), step back until the selection fits.
    while y2 > small_h:
        which_y -= 1
        y1 = math.floor(which_y * bx / scale_h)
        y2 = math.floor((which_y + 1) * bx / scale_h)
    while x2 > small_w:
        which_x -= 1
        x1 = math.floor(which_x * by / scale_w)
        x2 = math.floor((which_x + 1) * by / scale_w)

    st.session_state["which_x"] = which_x
    st.session_state["which_y"] = which_y

    # Draw the block grid and highlight the current selection on the thumbnail.
    for i in range(0, small_h, int(bx // scale_h)):
        cv2.line(thumbnail, (0, i), (small_w, i), (255, 255, 255), 1)
    for i in range(0, small_w, int(by // scale_w)):
        cv2.line(thumbnail, (i, 0), (i, small_h), (255, 255, 255), 1)

    cv2.rectangle(
        thumbnail,
        (x1, y1),
        (x2, y2),
        (0, 0, 255),
        5,
    )

    st.write("### Select a block")

    coordinates = streamlit_image_coordinates.streamlit_image_coordinates(
        thumbnail, use_column_width=True
    )

    if coordinates:
        # Convert the click position (thumbnail pixels) back to a block index.
        which_x = math.floor((w * coordinates["x"] / coordinates["width"]) / by)
        which_y = math.floor((h * coordinates["y"] / coordinates["height"]) / bx)
        # Rerun only when the selection actually changed; an unconditional
        # rerun would loop forever since the component keeps returning the
        # last click.
        if which_x != st.session_state.get(
            "which_x", 0
        ) or which_y != st.session_state.get("which_y", 0):
            st.session_state["which_x"] = which_x
            st.session_state["which_y"] = which_y
            st.rerun()

    image = blocks[which_y, which_x, 0]
    with st.sidebar:
        st.image(image, caption="Selected block", use_container_width=True)

    st.session_state.image_inference = image
    st.session_state.image_id = (
        file_id
        + str(which_x)
        + str(which_y)
        + str(bx)
        + str(by)
        + str(model_name)
        + ("_finetuned" if finetuned else "")
    )
    col1, col2, col3 = st.columns([1, 1, 1])
    start_inference()
else:
    st.switch_page("pages/1_Load.py")