Spaces:
Running
Running
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Converts Embedding Requests and Responses to json and vice versa."""
import dataclasses | |
import json | |
from typing import Any, List, Mapping, Sequence | |
from ez_wsi_dicomweb import patch_embedding_endpoints | |
import pete_errors | |
from data_models import embedding_request | |
from data_models import embedding_response | |
from data_models import patch_coordinate as patch_coordinate_module | |
# Module-private alias for the JSON key constants defined in
# ez_wsi_dicomweb.patch_embedding_endpoints.
_EndpointJsonKeys = patch_embedding_endpoints.EndpointJsonKeys
class ValidationError(Exception):
  """Raised by the validate_* helpers when a value fails its check."""
class _InvalidCoordinateError(Exception):
  """Raised when the patch coordinates of an instance cannot be parsed."""
class _InstanceUIDMetadataError(Exception):
  """Raised when the instance UID list of an instance is invalid or empty."""
def validate_int(val: Any) -> int:
  """Returns val as an int, rejecting anything that is not integer valued.

  Floats are accepted only when they hold an exact integer value (e.g. 3.0)
  and are converted to int. Non-finite floats (inf, -inf, nan) are rejected
  with ValidationError rather than leaking the OverflowError / ValueError
  that int() raises for them.

  Args:
    val: Value to check.

  Returns:
    The value as an int.

  Raises:
    ValidationError: If val is neither an int nor an integer-valued float.
  """
  if isinstance(val, float):
    # is_integer() is False for inf and nan, so the int() conversion below
    # can never raise.
    if not val.is_integer():
      raise ValidationError('coordinate value is not int')
    return int(val)
  if not isinstance(val, int):
    raise ValidationError('coordinate value is not int')
  return val
def _get_patch_coord(patch_coordinates: Sequence[Mapping[str, Any]]):
  """Parses and validates the patch coordinates of one request instance.

  Args:
    patch_coordinates: JSON decoded list of dicts; each dict is expanded into
      keyword arguments of patch_coordinate_module.create_patch_coordinate.

  Returns:
    Non-empty list of the created patch coordinates, in input order.

  Raises:
    _InvalidCoordinateError: If the input is not a list, is empty, contains an
      entry that is not a dict or has unexpected keys, or contains a
      coordinate whose origin or dimensions are not integer valued.
  """
  if not isinstance(patch_coordinates, list):
    raise _InvalidCoordinateError('patch_coordinates is not list')
  parsed = []
  for entry in patch_coordinates:
    try:
      coord = patch_coordinate_module.create_patch_coordinate(**entry)
    except TypeError as exp:
      if not isinstance(entry, dict):
        raise _InvalidCoordinateError('Patch coordinate is not dict.') from exp
      # Build a reference coordinate only to list the expected field names.
      reference = patch_coordinate_module.create_patch_coordinate(0, 0)
      keys = ', '.join(dataclasses.asdict(reference))
      raise _InvalidCoordinateError(
          f'Patch coordinate dict has invalid keys; expecting: {keys}'
      ) from exp
    try:
      # Values are only checked here; the coordinate object is left unchanged.
      for field_name in ('x_origin', 'y_origin', 'width', 'height'):
        validate_int(getattr(coord, field_name))
    except ValidationError as exp:
      raise _InvalidCoordinateError(
          f'Invalid patch coordinate; x_origin: {coord.x_origin}, y_origin:'
          f' {coord.y_origin}, width: {coord.width}, height: {coord.height}'
      ) from exp
    parsed.append(coord)
  if parsed:
    return parsed
  raise _InvalidCoordinateError('empty patch_coordinates')
def embedding_response_v1_to_json(
    response: embedding_response.EmbeddingResponseV1,
) -> Mapping[str, Any]:
  """Converts a V1 embedding response into the JSON payload of the API.

  Args:
    response: Structured EmbeddingResponseV1 dataclass instance.

  Returns:
    Mapping holding the response, converted to a dict, under the predictions
    key.
  """
  payload = dataclasses.asdict(response)
  error = response.error_response
  if error:
    # Replace the error code object with its `.value` in the payload.
    payload['error_response']['error_code'] = error.error_code.value
  return {_EndpointJsonKeys.PREDICTIONS: payload}
def embedding_response_v2_to_json(
    json_response: Sequence[Mapping[str, Any]],
) -> Mapping[str, Any]:
  """Wraps V2 per-instance responses under the predictions key.

  Args:
    json_response: Sequence of already JSON formatted instance responses.

  Returns:
    Mapping with json_response stored, unmodified, under the predictions key.
  """
  wrapped = {_EndpointJsonKeys.PREDICTIONS: json_response}
  return wrapped
def validate_str_list(val: Any) -> List[str]:
  """Checks that val is a list containing only non-empty strings.

  An empty list passes this check.

  Args:
    val: Value to check.

  Returns:
    val, unchanged.

  Raises:
    ValidationError: If val is not a list or holds an element that is not a
      non-empty string.
  """
  if not isinstance(val, List):
    raise ValidationError('not list')
  if any(not (isinstance(item, str) and item) for item in val):
    raise ValidationError('list contains invalid value')
  return val
def _validate_instance_uids_not_empty_str_list(val: Any) -> List[str]:
  """Checks that val is a non-empty list of non-empty strings.

  Args:
    val: Value to check.

  Returns:
    val, unchanged.

  Raises:
    _InstanceUIDMetadataError: If val fails validate_str_list or is empty.
  """
  try:
    uids = validate_str_list(val)
  except ValidationError as exp:
    raise _InstanceUIDMetadataError() from exp
  if uids:
    return uids
  raise _InstanceUIDMetadataError('list is empty')
def validate_str_key_dict(val: Any) -> Mapping[str, Any]:
  """Checks that val is a dict whose keys are all non-empty strings.

  Values are not inspected and an empty dict passes.

  Args:
    val: Value to check.

  Returns:
    val, unchanged.

  Raises:
    ValidationError: If val is not a dict or has a key that is not a non-empty
      string.
  """
  if not isinstance(val, dict):
    raise ValidationError('not a dict')
  for key in val:
    if not (isinstance(key, str) and key):
      raise ValidationError('dict contains invalid value')
  return val
def validate_str(val: Any) -> str:
  """Checks that val is a string; the empty string is accepted.

  Args:
    val: Value to check.

  Returns:
    val, unchanged.

  Raises:
    ValidationError: If val is not a str.
  """
  if isinstance(val, str):
    return val
  raise ValidationError('not string')
def _validate_not_empty_str(val: Any) -> str:
  """Checks that val is a string with at least one character.

  Args:
    val: Value to check.

  Returns:
    val, unchanged.

  Raises:
    ValidationError: If val is not a str or is the empty string.
  """
  if isinstance(val, str) and val:
    return val
  raise ValidationError('not string or empty')
def _generate_instance_metadata_error_string(
    metadata: Mapping[str, Any], *keys: str
) -> str:
  """Serializes selected instance metadata for use in an error message.

  Args:
    metadata: Instance metadata to report on.
    *keys: Keys to include; keys missing from metadata are skipped.

  Returns:
    JSON string, with sorted keys, of the selected metadata. A non-empty
    string bearer token is replaced by 'PRESENT', and the EZ-WSI state entry
    is removed from a copy of a mapping valued extensions entry.
  """
  summary = {}
  for key in keys:
    if key not in metadata:
      continue
    value = metadata[key]
    if key == _EndpointJsonKeys.EXTENSIONS and isinstance(value, Mapping):
      # Copy before stripping so the caller's metadata is left untouched.
      # The EZ-WSI state does not contribute to validation errors and may be
      # very large.
      value = dict(value)
      value.pop(_EndpointJsonKeys.EZ_WSI_STATE, None)
    elif (
        key == _EndpointJsonKeys.BEARER_TOKEN
        and isinstance(value, str)
        and value
    ):
      # Never echo a defined token; only report that one was supplied.
      value = 'PRESENT'
    summary[key] = value
  return json.dumps(summary, sort_keys=True)
def _validate_instance_list(json_metadata: Mapping[str, Any]) -> List[Any]:
  """Returns the list of instances stored in the request payload.

  Args:
    json_metadata: The value of the JSON payload provided to the API.

  Returns:
    The list stored under the instances key.

  Raises:
    pete_errors.InvalidRequestFieldError: If the payload is not a mapping, or
      the instances key is missing or does not hold a list.
  """
  # A payload that is not a JSON object (e.g. a top-level list) has no .get();
  # treat it like a payload that is missing the instances key.
  instances = (
      json_metadata.get(_EndpointJsonKeys.INSTANCES)
      if isinstance(json_metadata, Mapping)
      else None
  )
  if isinstance(instances, list):
    return instances
  raise pete_errors.InvalidRequestFieldError(
      'Invalid input, missing expected'
      f' key: {_EndpointJsonKeys.INSTANCES} and associated list of values.'
  )
class EmbeddingConverterV1:
  """Converts V1 JSON request payloads into structured embedding requests."""

  def json_to_embedding_request(
      self, json_metadata: Mapping[str, Any]
  ) -> embedding_request.EmbeddingRequestV1:
    """Converts json to embedding request.

    Args:
      json_metadata: The value of the JSON payload provided to the API.

    Returns:
      Structured EmbeddingRequest object.

    Raises:
      InvalidRequestFieldError: If the provided fields are invalid.
    """
    instances = []
    try:
      model_params = json_metadata[_EndpointJsonKeys.PARAMETERS]
      try:
        # A "parameters" value that is not a JSON object has no .get(); reject
        # it here so it is reported as an invalid field rather than escaping
        # as an AttributeError.
        if not isinstance(model_params, Mapping):
          raise ValidationError('parameters is not a dict')
        parameters = embedding_request.EmbeddingParameters(
            model_size=_validate_not_empty_str(
                model_params.get(_EndpointJsonKeys.MODEL_SIZE)
            ),
            model_kind=_validate_not_empty_str(
                model_params.get(_EndpointJsonKeys.MODEL_KIND)
            ),
        )
      except ValidationError as exp:
        raise pete_errors.InvalidRequestFieldError(
            'Invalid model size and/or kind parameters.'
        ) from exp
      for instance in _validate_instance_list(json_metadata):
        # Same reasoning as above: every instance must be a JSON object.
        if not isinstance(instance, Mapping):
          raise pete_errors.InvalidRequestFieldError(
              'Invalid instance; expected a JSON object.'
          )
        # The EZ-WSI state may be either a dict with string keys or a string.
        ez_wsi_state = instance.get(_EndpointJsonKeys.EZ_WSI_STATE, {})
        try:
          ez_wsi_state = validate_str_key_dict(ez_wsi_state)
        except ValidationError:
          try:
            ez_wsi_state = validate_str(ez_wsi_state)
          except ValidationError as exp:
            raise pete_errors.InvalidRequestFieldError(
                'Invalid EZ-WSI state metadata.'
            ) from exp
        try:
          instances.append(
              embedding_request.EmbeddingInstanceV1(
                  dicom_web_store_url=_validate_not_empty_str(
                      instance.get(_EndpointJsonKeys.DICOM_WEB_STORE_URL)
                  ),
                  dicom_study_uid=_validate_not_empty_str(
                      instance.get(_EndpointJsonKeys.DICOM_STUDY_UID)
                  ),
                  dicom_series_uid=_validate_not_empty_str(
                      instance.get(_EndpointJsonKeys.DICOM_SERIES_UID)
                  ),
                  bearer_token=_validate_not_empty_str(
                      instance.get(_EndpointJsonKeys.BEARER_TOKEN)
                  ),
                  ez_wsi_state=ez_wsi_state,
                  instance_uids=_validate_instance_uids_not_empty_str_list(
                      instance.get(_EndpointJsonKeys.INSTANCE_UIDS)
                  ),
                  patch_coordinates=_get_patch_coord(
                      instance.get(_EndpointJsonKeys.PATCH_COORDINATES)
                  ),
              )
          )
        except ValidationError as exp:
          instance_error_msg = _generate_instance_metadata_error_string(
              instance,
              _EndpointJsonKeys.DICOM_WEB_STORE_URL,
              _EndpointJsonKeys.DICOM_STUDY_UID,
              _EndpointJsonKeys.DICOM_SERIES_UID,
              _EndpointJsonKeys.BEARER_TOKEN,
              _EndpointJsonKeys.INSTANCE_UIDS,
          )
          raise pete_errors.InvalidRequestFieldError(
              f'Invalid instance; {instance_error_msg}'
          ) from exp
        except _InstanceUIDMetadataError as exp:
          # Report the instance UIDs that failed validation (not the patch
          # coordinates, which are unrelated to this error).
          instance_error_msg = _generate_instance_metadata_error_string(
              instance,
              _EndpointJsonKeys.INSTANCE_UIDS,
          )
          raise pete_errors.InvalidRequestFieldError(
              f'Invalid DICOM SOP Instance UID metadata; {instance_error_msg}'
          ) from exp
        except _InvalidCoordinateError as exp:
          raise pete_errors.InvalidRequestFieldError(
              f'Invalid patch coordinate; {exp}'
          ) from exp
    except (TypeError, ValueError, KeyError) as exp:
      # NOTE(review): this message echoes the whole payload, including any
      # bearer tokens it contains; consider redacting before logging.
      raise pete_errors.InvalidRequestFieldError(
          f'Invalid input: {json.dumps(json_metadata)}'
      ) from exp
    return embedding_request.EmbeddingRequestV1(
        parameters=parameters, instances=instances
    )
class EmbeddingConverterV2:
  """Converts V2 JSON request payloads into structured embedding requests."""

  def json_to_embedding_request(
      self, json_metadata: Mapping[str, Any]
  ) -> embedding_request.EmbeddingRequestV2:
    """Converts json to embedding request.

    Each instance is converted according to the first of these keys it
    contains: DICOM path, image file URI, raw image bytes.

    Args:
      json_metadata: The value of the JSON payload provided to the API.

    Returns:
      Structured EmbeddingRequest object.

    Raises:
      InvalidRequestFieldError: If the provided fields are invalid.
    """
    instances = []
    for instance in _validate_instance_list(json_metadata):
      # An instance that is not a JSON object has no .get(); reject it here so
      # it is reported as an invalid field rather than escaping as an
      # AttributeError / TypeError.
      if not isinstance(instance, Mapping):
        raise pete_errors.InvalidRequestFieldError(
            'Invalid instance; expected a JSON object.'
        )
      try:
        patch_coordinates = _get_patch_coord(
            instance.get(_EndpointJsonKeys.PATCH_COORDINATES)
        )
      except _InvalidCoordinateError as exp:
        instance_error_msg = _generate_instance_metadata_error_string(
            instance,
            _EndpointJsonKeys.PATCH_COORDINATES,
        )
        raise pete_errors.InvalidRequestFieldError(
            f'Invalid patch coordinate; {exp}; {instance_error_msg}'
        ) from exp
      if _EndpointJsonKeys.DICOM_PATH in instance:
        try:
          dicom_path = validate_str_key_dict(
              instance.get(_EndpointJsonKeys.DICOM_PATH)
          )
        except ValidationError as exp:
          raise pete_errors.InvalidRequestFieldError(
              'Invalid DICOM path.'
          ) from exp
        # The series path and instance UIDs are read from dicom_path, not from
        # the instance itself; overlay the instance on top of dicom_path so
        # the error summary can actually report them.
        error_source = {**dicom_path, **instance}
        try:
          instances.append(
              embedding_request.DicomImageV2(
                  series_path=_validate_not_empty_str(
                      dicom_path.get(_EndpointJsonKeys.SERIES_PATH)
                  ),
                  bearer_token=validate_str(
                      instance.get(_EndpointJsonKeys.BEARER_TOKEN, '')
                  ),
                  extensions=validate_str_key_dict(
                      instance.get(_EndpointJsonKeys.EXTENSIONS, {})
                  ),
                  instance_uids=_validate_instance_uids_not_empty_str_list(
                      dicom_path.get(_EndpointJsonKeys.INSTANCE_UIDS)
                  ),
                  patch_coordinates=patch_coordinates,
              )
          )
        except _InstanceUIDMetadataError as exp:
          error_msg = _generate_instance_metadata_error_string(
              error_source,
              _EndpointJsonKeys.SERIES_PATH,
              _EndpointJsonKeys.BEARER_TOKEN,
              _EndpointJsonKeys.EXTENSIONS,
              _EndpointJsonKeys.INSTANCE_UIDS,
          )
          raise pete_errors.InvalidRequestFieldError(
              f'Invalid DICOM SOP Instance UID metadata; {error_msg}'
          ) from exp
        except ValidationError as exp:
          error_msg = _generate_instance_metadata_error_string(
              error_source,
              _EndpointJsonKeys.SERIES_PATH,
              _EndpointJsonKeys.BEARER_TOKEN,
              _EndpointJsonKeys.EXTENSIONS,
              _EndpointJsonKeys.INSTANCE_UIDS,
          )
          raise pete_errors.InvalidRequestFieldError(
              f'DICOM instance JSON formatting is invalid; {error_msg}'
          ) from exp
      elif _EndpointJsonKeys.IMAGE_FILE_URI in instance:
        try:
          instances.append(
              embedding_request.GcsImageV2(
                  image_file_uri=_validate_not_empty_str(
                      instance.get(_EndpointJsonKeys.IMAGE_FILE_URI)
                  ),
                  bearer_token=validate_str(
                      instance.get(_EndpointJsonKeys.BEARER_TOKEN, '')
                  ),
                  extensions=validate_str_key_dict(
                      instance.get(_EndpointJsonKeys.EXTENSIONS, {})
                  ),
                  patch_coordinates=patch_coordinates,
              )
          )
        except ValidationError as exp:
          error_msg = _generate_instance_metadata_error_string(
              instance,
              _EndpointJsonKeys.IMAGE_FILE_URI,
              _EndpointJsonKeys.BEARER_TOKEN,
              _EndpointJsonKeys.EXTENSIONS,
          )
          raise pete_errors.InvalidRequestFieldError(
              'Google Cloud Storage instance JSON formatting is invalid;'
              f' {error_msg}'
          ) from exp
      elif _EndpointJsonKeys.RAW_IMAGE_BYTES in instance:
        try:
          instances.append(
              embedding_request.EmbeddedImageV2(
                  image_bytes=_validate_not_empty_str(
                      instance.get(_EndpointJsonKeys.RAW_IMAGE_BYTES)
                  ),
                  extensions=validate_str_key_dict(
                      instance.get(_EndpointJsonKeys.EXTENSIONS, {})
                  ),
                  patch_coordinates=patch_coordinates,
              )
          )
        except ValidationError as exp:
          # The raw image bytes are not included in the summary. The image
          # file URI key cannot be present in this branch, so it is skipped.
          error_msg = _generate_instance_metadata_error_string(
              instance,
              _EndpointJsonKeys.IMAGE_FILE_URI,
              _EndpointJsonKeys.BEARER_TOKEN,
              _EndpointJsonKeys.EXTENSIONS,
          )
          raise pete_errors.InvalidRequestFieldError(
              'Embedded image instance JSON formatting is invalid; '
              f' {error_msg}'
          ) from exp
      else:
        raise pete_errors.InvalidRequestFieldError('unidentified type')
    return embedding_request.EmbeddingRequestV2(instances)