File size: 8,273 Bytes
69591a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
import streamlit as st
import torch
from dnafiber.ui.utils import get_image, get_multifile_image
from dnafiber.deployment import MODELS_ZOO
import pandas as pd
import plotly.express as px
from dnafiber.postprocess import refine_segmentation
import torch.nn.functional as F
from joblib import Parallel, delayed
import time
from catppuccin import PALETTE
from dnafiber.deployment import _get_model
from dnafiber.ui.inference import ui_inference_cacheless


def plot_result(selected_category=None):
    """Draw the filter widgets and a violin plot of per-image fiber ratios.

    The data comes from ``st.session_state.results``. Nothing is rendered
    when no results are stored yet or when ``selected_category`` is ``None``.

    Args:
        selected_category: Collection of ``image_name`` values to keep. An
            empty collection disables the image filter (every image is
            plotted). It also defines the x-axis order when the
            "Reorder groups by median ratio" box is unchecked.
    """
    if st.session_state.get("results", None) is None or selected_category is None:
        return

    only_bilateral = st.checkbox(
        "Show only bicolor fibers",
        value=False,
    )
    remove_outliers = st.checkbox(
        "Remove outliers",
        value=True,
        help="Remove outliers from the data",
    )
    reorder = st.checkbox(
        "Reorder groups by median ratio",
        value=True,
    )
    # The range slider is only shown (and its bounds only defined) when
    # outlier removal is enabled.
    if remove_outliers:
        min_ratio, max_ratio = st.slider(
            "Ratio range",
            min_value=0.0,
            max_value=10.0,
            value=(0.0, 5.0),
            step=0.1,
            help="Select the ratio range to display",
        )

    df = st.session_state.results.copy()

    # Work on a private copy that carries display-friendly column names.
    clean_df = df[["ratio", "image_name", "fiber_type"]].copy()
    clean_df["Image"] = clean_df["image_name"]
    clean_df["Fiber Type"] = clean_df["fiber_type"]
    clean_df["Ratio"] = clean_df["ratio"]

    # Combine all row filters into one mask and copy once, so the column
    # assignments below never write into a view of another frame.
    keep = pd.Series(True, index=clean_df.index)
    if only_bilateral:
        keep &= clean_df["Fiber Type"] == "double"
    if remove_outliers:
        keep &= (clean_df["Ratio"] >= min_ratio) & (clean_df["Ratio"] <= max_ratio)
    if selected_category:
        keep &= clean_df["Image"].isin(selected_category)
    clean_df = clean_df[keep].copy()

    if selected_category and not reorder:
        # Keep the groups in the order the caller supplied them.
        clean_df["Image"] = pd.Categorical(
            clean_df["Image"], categories=selected_category, ordered=True
        )
        clean_df = clean_df.sort_values("Image")

    if reorder:
        # Order the groups by ascending median ratio.
        image_order = (
            clean_df.groupby("Image")["Ratio"]
            .median()
            .sort_values(ascending=True)
            .index
        )
        clean_df["Image"] = pd.Categorical(
            clean_df["Image"], categories=image_order, ordered=True
        )
        clean_df = clean_df.sort_values("Image")

    palette = [c.hex for c in PALETTE.latte.colors]

    fig = px.violin(
        clean_df,
        y="Ratio",
        x="Image",
        color="Image",
        box=True,  # draw box plot inside the violin
        points="all",  # can be 'outliers', or False
        color_discrete_sequence=palette,
    )
    st.plotly_chart(
        fig,
        use_container_width=True,
    )


def run_inference(model_name, pixel_size):
    """Segment every uploaded file and store per-fiber measurements.

    Iterates over ``st.session_state.files_uploaded``, runs the model on
    each image, refines the raw prediction into fibers and collects the
    valid ones. The result is written to ``st.session_state.results`` as a
    DataFrame with the columns ``FirstAnalog``, ``SecondAnalog``,
    ``length``, ``ratio``, ``image_name`` and ``fiber_type``.

    Args:
        model_name: Model revision to run. A name containing ``"ensemble"``
            selects every non-ensemble entry of ``MODELS_ZOO``; if it also
            contains ``"finetuned"`` the ``"_finetuned"`` suffix is added
            to each of them.
        pixel_size: Scale factor multiplied into ``fiber.red`` and
            ``fiber.green`` to produce the reported lengths.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if "ensemble" in model_name:
        # Resolve the suffix first: writing `name + "_finetuned" if ... else ""`
        # inline would yield an empty string for every non-finetuned member.
        suffix = "_finetuned" if "finetuned" in model_name else ""
        model = [
            name + suffix for name in MODELS_ZOO.values() if name != "ensemble"
        ]
    else:
        model = _get_model(
            revision=model_name,
            device=device,
        )

    my_bar = st.progress(0, text="Running segmentation...")
    all_files = st.session_state.files_uploaded
    all_results = dict(
        FirstAnalog=[],
        SecondAnalog=[],
        length=[],
        ratio=[],
        image_name=[],
        fiber_type=[],
    )
    # Side length of the square tiles used for very large predictions.
    tile = 4096
    for i, file in enumerate(all_files):
        if isinstance(file, tuple):
            # Two-part upload: either part may be missing, so take the name
            # from the first part that is present.
            named_part = file[0] if file[0] is not None else file[1]
            filename = named_part.name
            image = get_multifile_image(file)
        else:
            filename = file.name
            image = get_image(
                file, st.session_state.get("reverse_channels", False), file.file_id
            )

        start = time.time()
        prediction = ui_inference_cacheless(
            _model=model,
            _image=image,
            _device=device,
            postprocess=False,
        )
        # Use `filename` here: tuple uploads have no `.name` attribute.
        print(f"Prediction time: {time.time() - start:.2f} seconds for {filename}")

        h, w = prediction.shape
        start = time.time()
        # Tile only when at least one full tile fits in both dimensions;
        # F.unfold raises when the kernel is larger than the input.
        if h >= tile and w >= tile:
            # NOTE(review): with stride == kernel and no padding, pixels
            # beyond the last full tile are dropped, and this path does not
            # pass fix_junctions=True - confirm both are intended.
            blocks = F.unfold(
                torch.from_numpy(prediction).unsqueeze(0).float(),
                kernel_size=(tile, tile),
                stride=(tile, tile),
            )
            blocks = blocks.view(tile, tile, -1).permute(2, 0, 1).byte().numpy()
            per_block = Parallel(n_jobs=4)(
                delayed(refine_segmentation)(block) for block in blocks
            )
            results = [fiber for fibers in per_block for fiber in fibers]
        else:
            results = refine_segmentation(prediction, fix_junctions=True)

        print(f"Refinement time: {time.time() - start:.2f} seconds for {filename}")

        results = [fiber for fiber in results if fiber.is_valid]
        first = [fiber.red * pixel_size for fiber in results]
        second = [fiber.green * pixel_size for fiber in results]
        # Everything before the first "-" in the file name labels the group.
        image_label = filename.split("-")[0]

        all_results["FirstAnalog"].extend(first)
        all_results["SecondAnalog"].extend(second)
        all_results["length"].extend(a + b for a, b in zip(first, second))
        all_results["ratio"].extend(fiber.ratio for fiber in results)
        all_results["image_name"].extend([image_label] * len(results))
        all_results["fiber_type"].extend(fiber.fiber_type for fiber in results)

        # i is zero-based; report the number of files completed so far so
        # the bar reaches 100% on the last file.
        my_bar.progress((i + 1) / len(all_files), text=f"{filename} done")

    st.session_state.results = pd.DataFrame.from_dict(all_results)

    my_bar.empty()


# Page body. Without uploaded files there is nothing to segment, so the user
# is sent to the loading page; otherwise the controls, the results table and
# the charts are rendered.
if not st.session_state.get("files_uploaded"):
    st.switch_page("pages/1_Load.py")
else:
    run_segmentation = st.button("Run Segmentation", use_container_width=True)

    with st.sidebar:
        st.metric("Pixel size (µm)", st.session_state.get("pixel_size", 0.13))

        with st.expander("Model", expanded=True):
            model_name = st.selectbox(
                "Select a model",
                list(MODELS_ZOO.keys()),
                index=0,
                help="Select a model to use for inference",
            )
            finetuned = st.checkbox(
                "Use finetuned model",
                value=True,
                help="Use a finetuned model for inference",
            )
            label_col, device_col = st.columns(2)
            with label_col:
                st.write("Running on:")
            with device_col:
                # A disabled button is used as a read-only device badge.
                device_label = "GPU" if torch.cuda.is_available() else "CPU"
                st.button(device_label, disabled=True)

    tab_segmentation, tab_charts = st.tabs(["Segmentation", "Charts"])

    with tab_segmentation:
        st.subheader("Segmentation")
        if run_segmentation:
            # Build the revision name: zoo entry plus the optional suffix.
            revision = MODELS_ZOO[model_name]
            if finetuned:
                revision = revision + "_finetuned"
            run_inference(
                model_name=revision,
                pixel_size=st.session_state.get("pixel_size", 0.13),
            )
            st.balloons()

        # Read the results only after the optional run above so a fresh run
        # is displayed immediately.
        if st.session_state.get("results") is not None:
            st.write(st.session_state.results)

            csv_bytes = st.session_state.results.to_csv(index=False).encode("utf-8")
            st.download_button(
                label="Download results",
                data=csv_bytes,
                file_name="results.csv",
                mime="text/csv",
                use_container_width=True,
            )

    with tab_charts:
        if st.session_state.get("results") is not None:
            results = st.session_state.results

            # One selectable category per distinct image name.
            categories = results["image_name"].unique()
            selected_category = st.multiselect(
                "Select a category",
                categories,
                default=categories,
            )
            plot_result(selected_category)