# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Callable responsible for running Inference on provided patches."""

import functools
from typing import Any, Mapping

from ez_wsi_dicomweb import credential_factory
from ez_wsi_dicomweb import dicom_slide
from ez_wsi_dicomweb import dicom_web_interface
from ez_wsi_dicomweb import patch_embedding
from ez_wsi_dicomweb import patch_embedding_endpoints
from ez_wsi_dicomweb.ml_toolkit import dicom_path
import numpy as np
import tensorflow as tf

from data_models import embedding_converter
from data_models import embedding_request
from data_models import embedding_response



def _load_huggingface_model() -> tf.keras.Model:
  """Loads the ML model used to compute patch embeddings locally."""
  # TODO(liron): Implement the model loader.
  raise NotImplementedError('Model loader is not implemented yet.')
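  # A minimal sketch of one possible loader, assuming the model weights are
  # published on Hugging Face Hub as a TensorFlow SavedModel (the repo id
  # below is a placeholder, not something defined in this codebase):
  #
  #   from huggingface_hub import snapshot_download
  #   local_dir = snapshot_download(repo_id='<huggingface-model-repo-id>')
  #   return tf.saved_model.load(local_dir)
  #
  # Note: the object returned by tf.saved_model.load exposes `.signatures`,
  # which is what _endpoint_model below expects; the tf.keras.Model return
  # annotation may need loosening for that approach.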


def _endpoint_model(ml_model: tf.keras.Model, image: np.ndarray) -> np.ndarray:
  """Function ez-wsi will use to run local ML model."""
  result = ml_model.signatures['serving_default'](
      tf.cast(tf.constant(image), tf.float32)
  )
  return result['output_0'].numpy()


# The model is loaded once at module import time; ez-wsi's LocalEndpoint calls
# the resulting single-argument function with patch imagery as a NumPy array.
_ENDPOINT_MODEL = functools.partial(_endpoint_model, _load_huggingface_model())


class PetePredictor:
  """Callable responsible for generating embeddings."""

  def predict(
      self,
      prediction_input: Mapping[str, Any],
  ) -> Mapping[str, Any]:
    """Runs inference on provided patches.

    Args:
      prediction_input: JSON formatted input for embedding prediction.
      model: ModelRunner to handle model step.

    Returns:
      JSON formatted output.

    Raises:
      ERROR_LOADING_DICOM: If the provided patches are not concated.
    """
    embedding_json_converter = embedding_converter.EmbeddingConverterV2()
    request = embedding_json_converter.json_to_embedding_request(
        prediction_input
    )
    endpoint = patch_embedding_endpoints.LocalEndpoint(_ENDPOINT_MODEL)

    embedding_results = []
    for instance in request.instances:
      if not isinstance(instance, embedding_request.DicomImageV2):
        raise ValueError('Only DicomImageV2 instances are supported.')
      # Use the request's bearer token if provided; otherwise connect to the
      # DICOM store without credentials.
      token = instance.bearer_token
      if token:
        cf = credential_factory.TokenPassthroughCredentialFactory(token)
      else:
        cf = credential_factory.NoAuthCredentialsFactory()
      dwi = dicom_web_interface.DicomWebInterface(cf)
      path = dicom_path.FromString(instance.series_path)
      ds = dicom_slide.DicomSlide(dwi=dwi, path=path)
      # Resolve the pyramid level for the requested instance and collect the
      # requested patches from it.
      level = ds.get_instance_level(instance.instance_uids[0])
      patches = []
      for coor in instance.patch_coordinates:
        patches.append(
            ds.get_patch(
                level, coor.x_origin, coor.y_origin, coor.width, coor.height
            )
        )

      patch_embeddings = []
      results = patch_embedding.generate_patch_embeddings(endpoint, patches)
      for index, result in enumerate(results):
        embedding = np.array(result.embedding)
        patch_embeddings.append(
            embedding_response.PatchEmbeddingV2(
                embedding_vector=embedding.tolist(),
                patch_coordinate=instance.patch_coordinates[index],
            ))
      embedding_results.append(
          embedding_response.embedding_instance_response_v2(patch_embeddings)
      )
    return embedding_converter.embedding_response_v2_to_json(embedding_results)
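

# Illustrative usage sketch (not part of the module): the exact request JSON
# schema is defined by embedding_converter.EmbeddingConverterV2 and is not
# reproduced here; `request_json` below is a hypothetical placeholder.
#
#   predictor = PetePredictor()
#   response_json = predictor.predict(request_json)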