File size: 3,177 Bytes
69591a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch.nn.functional as F
import numpy as np
import torch
from torchvision.transforms._functional_tensor import normalize
import pandas as pd
from skimage.segmentation import expand_labels
from skimage.measure import label
import albumentations as A
from monai.inferers import SlidingWindowInferer
from dnafiber.deployment import _get_model
from dnafiber.postprocess import refine_segmentation

# Per-channel mean / standard deviation used to normalise input images
# (these are the widely used ImageNet RGB statistics).
_CHANNEL_MEAN = (0.485, 0.456, 0.406)
_CHANNEL_STD = (0.229, 0.224, 0.225)

# Preprocessing pipeline used by `infer`: normalise the image, then convert
# it to a torch tensor. It is called as `transform(image=...)` and the result
# is read from the "image" key.
_normalize_step = A.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD)
transform = A.Compose([_normalize_step, A.ToTensorV2()])


def preprocess_image(image):
    """Turn a channels-last NumPy image into a normalised, batched tensor.

    The array is wrapped as a tensor, its axes are reordered from
    (dim0, dim1, dim2) to (dim2, dim0, dim1), a leading batch axis of size 1
    is added, and the result is normalised per channel with the mean / std
    values below.

    NOTE(review): the torchvision ``normalize`` helper appears to require a
    floating-point tensor -- confirm callers pass a float image.
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    channels_first = torch.from_numpy(image).permute(2, 0, 1)
    batched = channels_first.unsqueeze(0)
    return normalize(batched, mean, std)


def convert_to_dataset(counts):
    """Build a DataFrame from a mapping of per-item red/green counts.

    Args:
        counts: Mapping whose values are themselves mappings with "green"
            and "red" entries. The keys become the "index" column.

    Returns:
        A DataFrame with columns "index", "red", "green" and "ratio", where
        "ratio" is green / red, or NaN when red equals 0.
    """
    keys = list(counts)
    greens = [entry["green"] for entry in counts.values()]
    reds = [entry["red"] for entry in counts.values()]
    # A zero red count would divide by zero, so those rows get NaN instead.
    ratios = [
        np.nan if red == 0 else green / red
        for green, red in zip(greens, reds)
    ]
    return pd.DataFrame(
        {"index": keys, "red": reds, "green": greens, "ratio": ratios}
    )


def convert_mask_to_image(mask, expand=False):
    """Render a 2-D label mask as an RGB uint8 image.

    Pixels labelled 1 are drawn red and pixels labelled 2 green. Every other
    value stays black.

    Args:
        mask: 2-D array of labels.
        expand: When truthy, the labels are first grown with
            ``expand_labels`` using this value as the ``distance`` argument.

    Returns:
        An array of shape (height, width, 3) and dtype uint8.
    """
    if expand:
        mask = expand_labels(mask, distance=expand)

    rows, cols = mask.shape
    canvas = np.zeros((rows, cols, 3), dtype=np.uint8)

    # Label value -> RGB colour. Applied in this order: red first, then green.
    palette = ((1, (255, 0, 0)), (2, (0, 255, 0)))
    for label_value, colour in palette:
        canvas[mask == label_value] = np.array(colour)

    return canvas


@torch.inference_mode()
def infer(model, image, device, scale=0.13, to_numpy=True, only_probabilities=False):
    """Run the model on one image and return a per-pixel result at input size.

    The image is passed through the module-level ``transform``, resized by
    ``scale / 0.26``, fed to the model (using sliding-window inference when
    the resized tensor is larger than 1024 in either spatial dimension), and
    the prediction is resized back to the original height and width.

    Args:
        model: Callable applied to the batched tensor, or a string, in which
            case it is passed as ``revision`` to ``_get_model`` to obtain
            the model.
        image: Input given to ``transform(image=...)``; presumably an
            H x W x 3 array -- confirm with callers.
        device: Device specifier accepted by ``Tensor.to`` and
            ``torch.device``.
        scale: Divided by the hard-coded model pixel size (0.26) to obtain
            the resize factor; presumably the pixel size of ``image`` in the
            same unit -- TODO confirm.
        to_numpy: If true, the label map is moved to CPU and returned as a
            NumPy array. Has no effect when ``only_probabilities`` is true.
        only_probabilities: If true, return the softmax probabilities
            instead of the argmax label map.

    Returns:
        If ``only_probabilities`` is true: a CPU tensor of softmax
        probabilities, bilinearly resized to the original (h, w).
        Otherwise: the argmax over dim 1, resized with nearest-neighbour
        interpolation to (h, w), squeezed and cast to uint8; a NumPy array
        when ``to_numpy`` is true, else a tensor.
    """
    # A string is treated as a model revision to load rather than a model.
    if isinstance(model, str):
        model = _get_model(device=device, revision=model)
    model_pixel_size = 0.26

    # Convert the caller's scale into a resize factor relative to 0.26.
    scale = scale / model_pixel_size
    # Add a leading batch axis and move the tensor to the target device.
    tensor = transform(image=image)["image"].unsqueeze(0).to(device)
    # Remember the original spatial size so the output can be resized back.
    h, w = tensor.shape[2], tensor.shape[3]
    # Normalise to a torch.device so `.type` is available for autocast.
    device = torch.device(device)
    with torch.autocast(device_type=device.type):
        tensor = F.interpolate(
            tensor,
            size=(int(h * scale), int(w * scale)),
            mode="bilinear",
        )
        # Inputs larger than 1024 in either dimension are processed in
        # overlapping 1024x1024 windows; smaller inputs go through in one pass.
        if tensor.shape[2] > 1024 or tensor.shape[3] > 1024:
            inferer = SlidingWindowInferer(
                roi_size=(1024, 1024),
                sw_batch_size=4,
                overlap=0.25,
                mode="gaussian",
                device=device,
                progress=True,
            )
            output = inferer(tensor, model)
        else:
            output = model(tensor)

        # Softmax over dim 1 (the class/channel axis of the model output).
        probabilities = F.softmax(output, dim=1)
        if only_probabilities:
            # Moved to CPU before the resize back to the original size.
            probabilities = probabilities.cpu()

            probabilities = F.interpolate(
                probabilities,
                size=(h, w),
                mode="bilinear",
            )
            return probabilities

        # Nearest-neighbour keeps the class ids intact when resizing; the
        # argmax is cast to float because it is fed to F.interpolate.
        output = F.interpolate(
            probabilities.argmax(dim=1, keepdim=True).float(),
            size=(h, w),
            mode="nearest",
        )

    # NOTE(review): squeeze() with no argument drops every size-1 axis, so an
    # input with h == 1 or w == 1 would also lose that spatial axis -- confirm
    # whether such inputs can occur.
    output = output.squeeze().byte()
    if to_numpy:
        output = output.cpu().numpy()

    return output