import streamlit as st
from dnafiber.inference import infer
from dnafiber.postprocess.core import refine_segmentation
import numpy as np
from dnafiber.deployment import _get_model
import torch


@st.cache_data
def ui_inference(_model, _image, _device, postprocess=True, id=None):
    """Run segmentation inference through Streamlit's data cache.

    Thin cached wrapper around :func:`ui_inference_cacheless`.

    Streamlit does not hash parameters whose name starts with an underscore,
    so ``_model``, ``_image`` and ``_device`` are not part of the cache key;
    only ``postprocess`` and ``id`` are. Callers must therefore pass an ``id``
    that uniquely identifies the (model, image) pair.

    NOTE(review): the ``pixel_size`` session value read by
    ``ui_inference_cacheless`` is not part of the cache key either, so
    changing it alone does not invalidate a cached result.

    Args:
        _model: A model, or a list of models / model revision names, forwarded
            unchanged to ``ui_inference_cacheless``.
        _image: Image to segment, forwarded unchanged.
        _device: Device passed on to ``infer``.
        postprocess: Whether to refine the raw segmentation.
        id: Cache discriminator; not used by the computation itself.

    Returns:
        Whatever ``ui_inference_cacheless`` returns for the same arguments.
    """
    return ui_inference_cacheless(
        _model, _image, _device, postprocess=postprocess, id=id
    )


@st.cache_resource
def get_model(model_name):
    """Load the model for revision ``model_name``, cached as a resource.

    The model is requested on ``"cuda"`` when ``torch.cuda.is_available()``
    and on ``"cpu"`` otherwise.

    Args:
        model_name: Revision identifier passed to ``_get_model``.

    Returns:
        The object returned by ``_get_model``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return _get_model(device=device, revision=model_name)


def ui_inference_cacheless(_model, _image, _device, postprocess=True, id=None):
    """Segment ``_image`` without any Streamlit caching.

    When ``_model`` is a list, every entry is run with
    ``only_probabilities=True``, the outputs are averaged and the final label
    map is the argmax over dimension 1. String entries are resolved through
    :func:`get_model`. When ``_model`` is not a list it is handed to ``infer``
    as-is (a single string is NOT resolved here).

    The inference scale is read from ``st.session_state["pixel_size"]``
    (default ``0.13``).

    Args:
        _model: A single model, or a non-empty list of models and/or model
            revision names to ensemble.
        _image: Image to segment.
        _device: Device passed to ``infer``.
        postprocess: If truthy, the segmentation is passed through
            ``refine_segmentation`` with ``fix_junctions=postprocess``.
        id: Unused here; accepted so the signature matches ``ui_inference``.

    Returns:
        The ``uint8`` segmentation, or the result of ``refine_segmentation``
        applied to it when ``postprocess`` is truthy.

    Raises:
        ValueError: If ``_model`` is an empty list.
    """
    # Fail early with a clear message instead of an opaque ``None / 0`` error.
    if isinstance(_model, list) and not _model:
        raise ValueError(
            "ui_inference_cacheless requires at least one model, got an empty list."
        )

    # Loop-invariant: read once so every ensemble member uses the same scale.
    pixel_size = st.session_state.get("pixel_size", 0.13)

    with st.spinner("Sliding window segmentation in progress..."):
        if isinstance(_model, list):
            summed = None
            for entry in _model:
                if isinstance(entry, str):
                    # Keep the revision name for display; formatting the
                    # loaded model object would print its full string form.
                    label = entry
                    model = get_model(entry)
                else:
                    label = entry
                    model = entry
                with st.spinner(text="Segmenting with model: {}".format(label)):
                    probabilities = infer(
                        model,
                        image=_image,
                        device=_device,
                        scale=pixel_size,
                        only_probabilities=True,
                    ).cpu()
                summed = probabilities if summed is None else summed + probabilities
            output = (summed / len(_model)).argmax(1).squeeze().numpy()
        else:
            output = infer(
                _model,
                image=_image,
                device=_device,
                scale=pixel_size,
            )

    output = output.astype(np.uint8)

    if postprocess:
        with st.spinner("Post-processing segmentation..."):
            output = refine_segmentation(output, fix_junctions=postprocess)
    return output