"""Streamlit page that segments the uploaded images and charts the results.

Runs the selected model on every file in ``st.session_state.files_uploaded``,
stores one row per valid fiber in ``st.session_state.results``, and offers a
CSV download plus per-image violin plots of the fiber ratio. Without uploaded
files the user is redirected to the load page.
"""

import time

import pandas as pd
import plotly.express as px
import streamlit as st
import torch
import torch.nn.functional as F  # noqa: F401  (no longer used by this page)
from catppuccin import PALETTE
from joblib import Parallel, delayed

from dnafiber.deployment import MODELS_ZOO, _get_model
from dnafiber.postprocess import refine_segmentation
from dnafiber.ui.inference import ui_inference_cacheless
from dnafiber.ui.utils import get_image, get_multifile_image

# Fallback pixel size (µm) used when the session does not define one.
DEFAULT_PIXEL_SIZE = 0.13
# Predictions larger than this in either dimension are refined tile by tile.
REFINE_TILE_SIZE = 4096
# Number of joblib workers used for tiled refinement.
REFINE_N_JOBS = 4


def plot_result(selected_category=None, *, seleted_category=None):
    """Draw a violin plot of fiber ratios, one violin per image.

    The fiber table is read from ``st.session_state.results``; the filtering
    and ordering options are rendered as widgets by this function.

    Args:
        selected_category: Image names to display, in the order used when
            "Reorder groups by median ratio" is unchecked. ``None`` draws
            nothing; an empty sequence disables the image filter.
        seleted_category: Misspelled keyword from the earlier signature, still
            accepted so existing keyword callers keep working.
    """
    # The earlier signature misspelled this parameter, so the body silently
    # read a module-level ``selected_category`` instead of its argument.
    if selected_category is None:
        selected_category = seleted_category
    if st.session_state.get("results", None) is None or selected_category is None:
        return

    only_bilateral = st.checkbox(
        "Show only bicolor fibers",
        value=False,
    )
    remove_outliers = st.checkbox(
        "Remove outliers",
        value=True,
        help="Remove outliers from the data",
    )
    reorder = st.checkbox(
        "Reorder groups by median ratio",
        value=True,
    )
    if remove_outliers:
        min_ratio, max_ratio = st.slider(
            "Ratio range",
            min_value=0.0,
            max_value=10.0,
            value=(0.0, 5.0),
            step=0.1,
            help="Select the ratio range to display",
        )

    # ``rename`` returns a fresh frame, so later filtering/assignment never
    # touches the session-state DataFrame.
    clean_df = st.session_state.results[["ratio", "image_name", "fiber_type"]].rename(
        columns={"ratio": "Ratio", "image_name": "Image", "fiber_type": "Fiber Type"}
    )
    if only_bilateral:
        clean_df = clean_df[clean_df["Fiber Type"] == "double"]
    if remove_outliers:
        # ``between`` is inclusive on both ends, like ``>= min & <= max``.
        clean_df = clean_df[clean_df["Ratio"].between(min_ratio, max_ratio)]
    if selected_category:
        clean_df = clean_df[clean_df["Image"].isin(selected_category)]

    if clean_df.empty:
        st.info("No fibers match the current filters.")
        return

    # plotly express orders categories by first appearance, so the rows are
    # sorted by an ordered categorical to control the violin order.
    if reorder:
        order = (
            clean_df.groupby("Image")["Ratio"]
            .median()
            .sort_values(ascending=True)
            .index
        )
    elif selected_category:
        order = selected_category
    else:
        order = None
    if order is not None:
        clean_df = clean_df.assign(
            Image=pd.Categorical(clean_df["Image"], categories=order, ordered=True)
        ).sort_values("Image")

    palette = [color.hex for color in PALETTE.latte.colors]
    fig = px.violin(
        clean_df,
        y="Ratio",
        x="Image",
        color="Image",
        box=True,  # draw box plot inside the violin
        points="all",  # can be 'outliers', or False
        color_discrete_sequence=palette,
    )
    st.plotly_chart(
        fig,
        use_container_width=True,
    )


def _resolve_model(model_name, device):
    """Return the loaded model, or the list of ensemble member revisions."""
    if "ensemble" in model_name:
        # The suffix must be parenthesised on its own: without that, the
        # conditional swallowed the whole ``name + suffix`` expression and
        # produced empty names for non-finetuned ensembles.
        suffix = "_finetuned" if "finetuned" in model_name else ""
        return [name + suffix for name in MODELS_ZOO.values() if name != "ensemble"]
    return _get_model(revision=model_name, device=device)


def _load_file(file):
    """Return ``(filename, image)`` for an uploaded file or a file pair."""
    if isinstance(file, tuple):
        # Pairs may have either slot set to None; name the entry after the
        # first present file (previously undefined when both were present).
        # NOTE(review): confirm slot 0 is the preferred name when both exist.
        named = file[0] if file[0] is not None else file[1]
        return named.name, get_multifile_image(file)
    image = get_image(
        file, st.session_state.get("reverse_channels", False), file.file_id
    )
    return file.name, image


def _refine_prediction(prediction, tile_size=REFINE_TILE_SIZE, n_jobs=REFINE_N_JOBS):
    """Convert a 2-D prediction map into a flat list of fiber objects.

    Maps no larger than ``tile_size`` in either dimension are refined in one
    call. Larger maps are split into ``tile_size`` tiles, with smaller edge
    tiles, so every pixel is covered. The tiles are refined in parallel.
    """
    h, w = prediction.shape
    if h <= tile_size and w <= tile_size:
        return refine_segmentation(prediction, fix_junctions=True)

    # Plain slicing covers partial edge tiles. The earlier F.unfold version
    # crashed on maps smaller than one kernel and dropped any remainder.
    # NOTE(review): fibers crossing a tile border are split in two, and tiles
    # are refined without fix_junctions=True (as before) — confirm intent.
    tiles = [
        prediction[y : y + tile_size, x : x + tile_size]
        for y in range(0, h, tile_size)
        for x in range(0, w, tile_size)
    ]
    per_tile = Parallel(n_jobs=n_jobs)(
        delayed(refine_segmentation)(tile) for tile in tiles
    )
    return [fiber for fibers in per_tile for fiber in fibers]


def run_inference(model_name, pixel_size):
    """Segment every uploaded file and store per-fiber measurements.

    Args:
        model_name: Model revision passed to ``_get_model``. Names containing
            "ensemble" instead select every non-ensemble entry of
            ``MODELS_ZOO``, with "_finetuned" appended when ``model_name``
            contains "finetuned".
        pixel_size: Factor applied to ``fiber.red`` / ``fiber.green`` to get
            the analog lengths (the sidebar labels it in µm).

    Side effects:
        Sets ``st.session_state.results`` to a DataFrame with columns
        FirstAnalog, SecondAnalog, length, ratio, image_name (filename prefix
        before the first "-") and fiber_type, with one row per valid fiber.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = _resolve_model(model_name, device)

    my_bar = st.progress(0, text="Running segmentation...")
    all_files = st.session_state.files_uploaded
    all_results = dict(
        FirstAnalog=[],
        SecondAnalog=[],
        length=[],
        ratio=[],
        image_name=[],
        fiber_type=[],
    )
    for done, file in enumerate(all_files, start=1):
        filename, image = _load_file(file)

        start = time.time()
        prediction = ui_inference_cacheless(
            _model=model,
            _image=image,
            _device=device,
            postprocess=False,
        )
        # ``filename`` rather than ``file.name``: tuple entries have no .name.
        print(f"Prediction time: {time.time() - start:.2f} seconds for {filename}")

        start = time.time()
        refined = _refine_prediction(prediction)
        print(f"Refinement time: {time.time() - start:.2f} seconds for {filename}")

        image_name = filename.split("-")[0]
        for fiber in (fiber for fiber in refined if fiber.is_valid):
            first_analog = fiber.red * pixel_size
            second_analog = fiber.green * pixel_size
            all_results["FirstAnalog"].append(first_analog)
            all_results["SecondAnalog"].append(second_analog)
            all_results["length"].append(first_analog + second_analog)
            all_results["ratio"].append(fiber.ratio)
            all_results["image_name"].append(image_name)
            all_results["fiber_type"].append(fiber.fiber_type)

        # ``done`` counts from 1 so the bar reaches 100% after the last file.
        my_bar.progress(done / len(all_files), text=f"{filename} done")

    st.session_state.results = pd.DataFrame.from_dict(all_results)
    my_bar.empty()


if st.session_state.get("files_uploaded", None):
    run_segmentation = st.button("Run Segmentation", use_container_width=True)

    with st.sidebar:
        st.metric(
            "Pixel size (µm)",
            st.session_state.get("pixel_size", DEFAULT_PIXEL_SIZE),
        )

        with st.expander("Model", expanded=True):
            model_name = st.selectbox(
                "Select a model",
                list(MODELS_ZOO.keys()),
                index=0,
                help="Select a model to use for inference",
            )
            finetuned = st.checkbox(
                "Use finetuned model",
                value=True,
                help="Use a finetuned model for inference",
            )
            col1, col2 = st.columns(2)
            with col1:
                st.write("Running on:")
            with col2:
                # Disabled button used purely as a device indicator.
                st.button(
                    "GPU" if torch.cuda.is_available() else "CPU",
                    disabled=True,
                )

    tab_segmentation, tab_charts = st.tabs(["Segmentation", "Charts"])

    with tab_segmentation:
        st.subheader("Segmentation")
        if run_segmentation:
            revision = MODELS_ZOO[model_name]
            if finetuned:
                revision += "_finetuned"
            run_inference(
                model_name=revision,
                pixel_size=st.session_state.get("pixel_size", DEFAULT_PIXEL_SIZE),
            )
            st.balloons()
        if st.session_state.get("results", None) is not None:
            st.write(
                st.session_state.results,
            )
            st.download_button(
                label="Download results",
                data=st.session_state.results.to_csv(index=False).encode("utf-8"),
                file_name="results.csv",
                mime="text/csv",
                use_container_width=True,
            )

    with tab_charts:
        if st.session_state.get("results", None) is not None:
            results_df = st.session_state.results
            categories = results_df["image_name"].unique().tolist()
            selected_category = st.multiselect(
                "Select a category", categories, default=categories
            )
            plot_result(selected_category)
else:
    st.switch_page("pages/1_Load.py")