Spaces:
Running
Running
# Copyright 2024 Google LLC | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Callable responsible for running Inference on provided patches.""" | |
import functools | |
from typing import Any, Mapping | |
from ez_wsi_dicomweb import credential_factory | |
from ez_wsi_dicomweb import dicom_slide | |
from ez_wsi_dicomweb import patch_embedding | |
from ez_wsi_dicomweb import dicom_web_interface | |
from ez_wsi_dicomweb import patch_embedding_endpoints | |
from ez_wsi_dicomweb.ml_toolkit import dicom_path | |
import numpy as np | |
import tensorflow as tf | |
from data_models import embedding_response | |
from data_models import embedding_request | |
from data_models import embedding_converter | |
def _load_huggingface_model() -> tf.keras.Model:
  """Placeholder for loading the embedding model from Hugging Face.

  NOTE(review): this is not implemented yet and always raises. Because the
  module-level `_ENDPOINT_MODEL` calls it while the module is being imported,
  importing this module currently fails until a real loader is written.

  Raises:
    ValueError: Always, until the loader is implemented.
  """
  # TODO: replace with the real Hugging Face model loader.
  not_implemented_message = 'Liron implement model loader'
  raise ValueError(not_implemented_message)
def _endpoint_model(ml_model: tf.keras.Model, image: np.ndarray) -> np.ndarray:
  """Runs the locally loaded model on pixel data supplied by ez-wsi.

  ez-wsi's LocalEndpoint invokes this (via `_ENDPOINT_MODEL`) in place of a
  remote prediction endpoint.

  Args:
    ml_model: Loaded model exposing a 'serving_default' signature.
    image: Pixel array handed over by ez-wsi; it is cast to float32 before
      being fed to the signature (exact shape is set by ez-wsi — not visible
      here).

  Returns:
    The signature's 'output_0' result converted to a NumPy array.
  """
  serving_fn = ml_model.signatures['serving_default']
  # The serving signature is fed float32 regardless of the incoming dtype.
  model_input = tf.cast(tf.constant(image), tf.float32)
  signature_outputs = serving_fn(model_input)
  return signature_outputs['output_0'].numpy()
# Single-argument callable (image -> embeddings) handed to ez-wsi's
# LocalEndpoint. The model is loaded eagerly, once, when this module is
# imported, and bound as the first positional argument of _endpoint_model.
# NOTE(review): while _load_huggingface_model is a stub that raises, this line
# makes the module fail to import.
_ENDPOINT_MODEL = functools.partial(
    _endpoint_model,
    _load_huggingface_model(),
)
def _build_dicom_slide(
    instance: embedding_request.DicomImageV2,
) -> dicom_slide.DicomSlide:
  """Opens the DICOMweb slide referenced by a request instance.

  Args:
    instance: Request instance naming a DICOM series and carrying an optional
      bearer token.

  Returns:
    A DicomSlide for `instance.series_path`.
  """
  token = instance.bearer_token
  # Forward the caller's bearer token when one was supplied; otherwise query
  # the DICOM store without credentials.
  if token:
    cf = credential_factory.TokenPassthroughCredentialFactory(token)
  else:
    cf = credential_factory.NoAuthCredentialsFactory()
  dwi = dicom_web_interface.DicomWebInterface(cf)
  path = dicom_path.FromString(instance.series_path)
  return dicom_slide.DicomSlide(dwi=dwi, path=path)


def _embed_instance(
    endpoint: patch_embedding_endpoints.LocalEndpoint,
    instance: embedding_request.DicomImageV2,
) -> list:
  """Extracts the requested patches of one instance and embeds them.

  Args:
    endpoint: ez-wsi endpoint used to compute the embeddings.
    instance: Request instance listing instance UIDs and patch coordinates.

  Returns:
    A list of PatchEmbeddingV2, one per embedding result, each paired with the
    patch coordinate it was computed from.

  Raises:
    ValueError: If the instance names no instance UIDs.
  """
  if not instance.instance_uids:
    raise ValueError('instance_uids must contain at least one instance UID')
  slide = _build_dicom_slide(instance)
  # All patch coordinates are resolved against the level of the first UID.
  level = slide.get_instance_level(instance.instance_uids[0])
  coordinates = instance.patch_coordinates
  patches = [
      slide.get_patch(
          level, coord.x_origin, coord.y_origin, coord.width, coord.height
      )
      for coord in coordinates
  ]
  results = patch_embedding.generate_patch_embeddings(endpoint, patches)
  # Assumes results are yielded in the same order as `patches`, so each one is
  # paired with the coordinate that produced it.
  return [
      embedding_response.PatchEmbeddingV2(
          embedding_vector=np.array(result.embedding).tolist(),
          patch_coordinate=coordinate,
      )
      for coordinate, result in zip(coordinates, results)
  ]


class PetePredictor:
  """Generates patch embeddings for DICOMweb prediction requests."""

  def predict(
      self,
      prediction_input: Mapping[str, Any],
  ) -> Mapping[str, Any]:
    """Runs inference on the patches described by a prediction request.

    Args:
      prediction_input: JSON formatted input for embedding prediction.

    Returns:
      JSON formatted output.

    Raises:
      ValueError: If a request instance is not a DicomImageV2, or if it names
        no instance UIDs.
    """
    embedding_json_converter = embedding_converter.EmbeddingConverterV2()
    request = embedding_json_converter.json_to_embedding_request(
        prediction_input
    )
    endpoint = patch_embedding_endpoints.LocalEndpoint(_ENDPOINT_MODEL)
    embedding_results = []
    for instance in request.instances:
      # Only DICOMweb image instances are supported by this predictor.
      if not isinstance(instance, embedding_request.DicomImageV2):
        raise ValueError(
            f'unsupported instance type: {type(instance).__name__}'
        )
      embedding_results.append(
          embedding_response.embedding_instance_response_v2(
              _embed_instance(endpoint, instance)
          )
      )
    return embedding_converter.embedding_response_v2_to_json(embedding_results)