path-foundation-demo / data_models /embedding_converter.py
lirony's picture
initial
a3e8b4b
raw
history blame
15.5 kB
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Converts Embedding Requests and Responses to json and vice versa."""
import dataclasses
import json
from typing import Any, List, Mapping, Sequence
from ez_wsi_dicomweb import patch_embedding_endpoints
import pete_errors
from data_models import embedding_request
from data_models import embedding_response
from data_models import patch_coordinate as patch_coordinate_module
_EndpointJsonKeys = patch_embedding_endpoints.EndpointJsonKeys
class ValidationError(Exception):
  """Raised when a request value fails a type or content check."""
class _InvalidCoordinateError(Exception):
  """Raised when patch coordinate metadata cannot be parsed or validated."""
class _InstanceUIDMetadataError(Exception):
  """Raised when a list of instance UIDs is not a list, malformed, or empty."""
def validate_int(val: Any) -> int:
  """Validates that a coordinate value is integral and returns it as an int.

  Floats are accepted only when they hold a whole number (e.g. 3.0) and are
  converted to int. Non-finite floats (NaN, +/-Infinity), which Python's JSON
  parser can produce, are rejected with ValidationError instead of leaking the
  ValueError / OverflowError that int() raises for them.

  Args:
    val: Value to validate.

  Returns:
    The value as an int.

  Raises:
    ValidationError: If val is neither an int nor a whole-number float.
  """
  if isinstance(val, float):
    # float.is_integer() is False for NaN and +/-Infinity as well as for
    # fractional values, so int() below is only reached for safe inputs.
    if not val.is_integer():
      raise ValidationError('coordinate value is not int')
    return int(val)
  if not isinstance(val, int):
    raise ValidationError('coordinate value is not int')
  return val
def _get_patch_coord(patch_coordinates: Sequence[Mapping[str, Any]]):
  """Parses and validates the patch coordinates of a request instance.

  Args:
    patch_coordinates: List of dicts; each dict is passed as keyword arguments
      to patch_coordinate_module.create_patch_coordinate.

  Returns:
    Non-empty list of the created patch coordinate objects.

  Raises:
    _InvalidCoordinateError: If the input is not a list, is empty, contains an
      entry that is not a dict or has unexpected keys, or contains a
      coordinate whose origin or dimensions are not integers.
  """
  if not isinstance(patch_coordinates, list):
    raise _InvalidCoordinateError('patch_coordinates is not list')
  parsed = []
  for raw_coordinate in patch_coordinates:
    try:
      coordinate = patch_coordinate_module.create_patch_coordinate(
          **raw_coordinate
      )
    except TypeError as exp:
      if not isinstance(raw_coordinate, dict):
        raise _InvalidCoordinateError('Patch coordinate is not dict.') from exp
      # Build a reference coordinate only to list the accepted field names.
      reference = patch_coordinate_module.create_patch_coordinate(0, 0)
      keys = ', '.join(dataclasses.asdict(reference))
      raise _InvalidCoordinateError(
          f'Patch coordinate dict has invalid keys; expecting: {keys}'
      ) from exp
    try:
      for field_name in ('x_origin', 'y_origin', 'width', 'height'):
        validate_int(getattr(coordinate, field_name))
    except ValidationError as exp:
      raise _InvalidCoordinateError(
          f'Invalid patch coordinate; x_origin: {coordinate.x_origin},'
          f' y_origin: {coordinate.y_origin}, width: {coordinate.width},'
          f' height: {coordinate.height}'
      ) from exp
    parsed.append(coordinate)
  if parsed:
    return parsed
  raise _InvalidCoordinateError('empty patch_coordinates')
def embedding_response_v1_to_json(
    response: embedding_response.EmbeddingResponseV1,
) -> Mapping[str, Any]:
  """Converts a V1 embedding response into the JSON payload for the API.

  Args:
    response: Structured EmbeddingResponse object.

  Returns:
    Dict holding the response, converted with dataclasses.asdict, under the
    predictions key.
  """
  payload = dataclasses.asdict(response)
  error = response.error_response
  if error:
    # Store the error code's `.value` in the payload in place of the
    # error_code object itself (presumably an enum member; confirm against
    # embedding_response).
    payload['error_response']['error_code'] = error.error_code.value
  return {_EndpointJsonKeys.PREDICTIONS: payload}
def embedding_response_v2_to_json(
    json_response: Sequence[Mapping[str, Any]],
) -> Mapping[str, Any]:
  """Wraps an already JSON-formatted V2 response under the predictions key.

  Args:
    json_response: Sequence of per-instance JSON responses.

  Returns:
    Dict mapping the predictions key to json_response (not copied).
  """
  predictions_key = _EndpointJsonKeys.PREDICTIONS
  return {predictions_key: json_response}
def validate_str_list(val: Any) -> List[str]:
  """Validates that val is a list containing only non-empty strings.

  An empty list is accepted.

  Args:
    val: Value to validate.

  Returns:
    The same list, unchanged.

  Raises:
    ValidationError: If val is not a list or holds a non-string or empty
      string element.
  """
  if not isinstance(val, list):
    raise ValidationError('not list')
  if any(not isinstance(item, str) or not item for item in val):
    raise ValidationError('list contains invalid value')
  return val
def _validate_instance_uids_not_empty_str_list(val: Any) -> List[str]:
  """Validates that val is a non-empty list of non-empty strings.

  Args:
    val: Value to validate.

  Returns:
    The validated list.

  Raises:
    _InstanceUIDMetadataError: If val is not a list, contains a non-string or
      empty string, or is empty.
  """
  try:
    instance_uids = validate_str_list(val)
  except ValidationError as exp:
    raise _InstanceUIDMetadataError() from exp
  if instance_uids:
    return instance_uids
  raise _InstanceUIDMetadataError('list is empty')
def validate_str_key_dict(val: Any) -> Mapping[str, Any]:
  """Validates that val is a dict whose keys are all non-empty strings.

  An empty dict is accepted; values are not inspected.

  Args:
    val: Value to validate.

  Returns:
    The same dict, unchanged.

  Raises:
    ValidationError: If val is not a dict or has a non-string or empty key.
  """
  if not isinstance(val, dict):
    raise ValidationError('not a dict')
  if not all(isinstance(key, str) and key for key in val):
    raise ValidationError('dict contains invalid value')
  return val
def validate_str(val: Any) -> str:
  """Returns val if it is a string (possibly empty).

  Raises:
    ValidationError: If val is not a string.
  """
  if isinstance(val, str):
    return val
  raise ValidationError('not string')
def _validate_not_empty_str(val: Any) -> str:
  """Returns val if it is a non-empty string.

  Raises:
    ValidationError: If val is not a string or is the empty string.
  """
  if isinstance(val, str) and val:
    return val
  raise ValidationError('not string or empty')
def _generate_instance_metadata_error_string(
    metadata: Mapping[str, Any], *keys: str
) -> str:
  """Summarizes selected instance metadata as JSON for use in error messages.

  Args:
    metadata: Instance metadata to summarize.
    *keys: Keys to include; keys absent from metadata are skipped.

  Returns:
    JSON string (sorted keys) of the selected metadata. A mapping stored under
    the extensions key is copied with its EZ-WSI state entry removed, and a
    non-empty string bearer token is replaced by the literal 'PRESENT'.
  """
  summary = {}
  for key in keys:
    if key not in metadata:
      continue
    value = metadata[key]
    if key == _EndpointJsonKeys.EXTENSIONS and isinstance(value, Mapping):
      # Work on a copy so the caller's metadata is untouched. The EZ-WSI state
      # does not contribute to validation errors and may be very large.
      value = dict(value)
      value.pop(_EndpointJsonKeys.EZ_WSI_STATE, None)
    elif (
        key == _EndpointJsonKeys.BEARER_TOKEN
        and isinstance(value, str)
        and value
    ):
      # Report only that a token was supplied, never the token itself.
      value = 'PRESENT'
    summary[key] = value
  return json.dumps(summary, sort_keys=True)
def _validate_instance_list(json_metadata: Mapping[str, Any]) -> List[Any]:
  """Returns the request's instance list after structural validation.

  Both converters call `.get()` on the request and on every instance. A
  payload that is not a JSON object, or an instance that is not a JSON object
  (e.g. a string or number), would therefore fail with AttributeError rather
  than a request error. Such inputs are rejected here instead.

  Args:
    json_metadata: The value of the JSON payload provided to the API.

  Returns:
    The list stored under the instances key; every element is a mapping.

  Raises:
    InvalidRequestFieldError: If json_metadata is not a mapping, the instances
      key is missing or not a list, or an instance is not a mapping.
  """
  instances = (
      json_metadata.get(_EndpointJsonKeys.INSTANCES)
      if isinstance(json_metadata, Mapping)
      else None
  )
  if not isinstance(instances, list):
    raise pete_errors.InvalidRequestFieldError(
        'Invalid input, missing expected'
        f' key: {_EndpointJsonKeys.INSTANCES} and associated list of values.'
    )
  for instance in instances:
    if not isinstance(instance, Mapping):
      raise pete_errors.InvalidRequestFieldError(
          'Invalid input, each instance must be a JSON object.'
      )
  return instances
class EmbeddingConverterV1:
  """Converts V1 embedding request JSON into structured request objects."""

  @staticmethod
  def _redacted_request_json(json_metadata: Any) -> str:
    """Serializes a request for an error message with bearer tokens masked.

    Mirrors _generate_instance_metadata_error_string: a non-empty string
    bearer token on an instance is replaced by 'PRESENT' so credentials are
    not echoed in error text. Everything else is serialized unchanged and the
    input is never mutated.

    Args:
      json_metadata: The request payload (any JSON-decoded value).

    Returns:
      JSON string of the payload with instance bearer tokens masked.
    """
    redacted = json_metadata
    if isinstance(json_metadata, Mapping):
      instances = json_metadata.get(_EndpointJsonKeys.INSTANCES)
      if isinstance(instances, list):
        masked_instances = []
        for instance in instances:
          if isinstance(instance, Mapping):
            token = instance.get(_EndpointJsonKeys.BEARER_TOKEN)
            if isinstance(token, str) and token:
              instance = {
                  **instance,
                  _EndpointJsonKeys.BEARER_TOKEN: 'PRESENT',
              }
          masked_instances.append(instance)
        redacted = {
            **json_metadata,
            _EndpointJsonKeys.INSTANCES: masked_instances,
        }
    return json.dumps(redacted)

  def json_to_embedding_request(
      self, json_metadata: Mapping[str, Any]
  ) -> embedding_request.EmbeddingRequestV1:
    """Converts json to embedding request.

    Args:
      json_metadata: The value of the JSON payload provided to the API.

    Returns:
      Structured EmbeddingRequest object.

    Raises:
      InvalidRequestFieldError: If the provided fields are invalid.
    """
    instances = []
    try:
      model_params = json_metadata[_EndpointJsonKeys.PARAMETERS]
      # A "parameters" value that is not a JSON object has no .get(); reject
      # it here instead of failing with an uncaught AttributeError below.
      if not isinstance(model_params, Mapping):
        raise pete_errors.InvalidRequestFieldError(
            'Invalid model size and/or kind parameters.'
        )
      try:
        parameters = embedding_request.EmbeddingParameters(
            model_size=_validate_not_empty_str(
                model_params.get(_EndpointJsonKeys.MODEL_SIZE)
            ),
            model_kind=_validate_not_empty_str(
                model_params.get(_EndpointJsonKeys.MODEL_KIND)
            ),
        )
      except ValidationError as exp:
        raise pete_errors.InvalidRequestFieldError(
            'Invalid model size and/or kind parameters.'
        ) from exp
      for instance in _validate_instance_list(json_metadata):
        # EZ-WSI state may be given either as a dict or as a string.
        ez_wsi_state = instance.get(_EndpointJsonKeys.EZ_WSI_STATE, {})
        try:
          ez_wsi_state = validate_str_key_dict(ez_wsi_state)
        except ValidationError:
          try:
            ez_wsi_state = validate_str(ez_wsi_state)
          except ValidationError as exp:
            raise pete_errors.InvalidRequestFieldError(
                'Invalid EZ-WSI state metadata.'
            ) from exp
        try:
          instances.append(
              embedding_request.EmbeddingInstanceV1(
                  dicom_web_store_url=_validate_not_empty_str(
                      instance.get(_EndpointJsonKeys.DICOM_WEB_STORE_URL)
                  ),
                  dicom_study_uid=_validate_not_empty_str(
                      instance.get(_EndpointJsonKeys.DICOM_STUDY_UID)
                  ),
                  dicom_series_uid=_validate_not_empty_str(
                      instance.get(_EndpointJsonKeys.DICOM_SERIES_UID)
                  ),
                  bearer_token=_validate_not_empty_str(
                      instance.get(_EndpointJsonKeys.BEARER_TOKEN)
                  ),
                  ez_wsi_state=ez_wsi_state,
                  instance_uids=_validate_instance_uids_not_empty_str_list(
                      instance.get(_EndpointJsonKeys.INSTANCE_UIDS)
                  ),
                  patch_coordinates=_get_patch_coord(
                      instance.get(_EndpointJsonKeys.PATCH_COORDINATES)
                  ),
              )
          )
        except ValidationError as exp:
          instance_error_msg = _generate_instance_metadata_error_string(
              instance,
              _EndpointJsonKeys.DICOM_WEB_STORE_URL,
              _EndpointJsonKeys.DICOM_STUDY_UID,
              _EndpointJsonKeys.DICOM_SERIES_UID,
              _EndpointJsonKeys.BEARER_TOKEN,
              _EndpointJsonKeys.INSTANCE_UIDS,
          )
          raise pete_errors.InvalidRequestFieldError(
              f'Invalid instance; {instance_error_msg}'
          ) from exp
        except _InstanceUIDMetadataError as exp:
          # Report the offending instance UIDs; the patch coordinates are
          # unrelated to this failure.
          instance_error_msg = _generate_instance_metadata_error_string(
              instance,
              _EndpointJsonKeys.INSTANCE_UIDS,
          )
          raise pete_errors.InvalidRequestFieldError(
              f'Invalid DICOM SOP Instance UID metadata; {instance_error_msg}'
          ) from exp
        except _InvalidCoordinateError as exp:
          raise pete_errors.InvalidRequestFieldError(
              f'Invalid patch coordinate; {exp}'
          ) from exp
    except (TypeError, ValueError, KeyError) as exp:
      # Structural problems (e.g. missing "parameters"). Bearer tokens are
      # masked before the request is echoed back in the error text.
      redacted = EmbeddingConverterV1._redacted_request_json(json_metadata)
      raise pete_errors.InvalidRequestFieldError(
          f'Invalid input: {redacted}'
      ) from exp
    return embedding_request.EmbeddingRequestV1(
        parameters=parameters, instances=instances
    )
class EmbeddingConverterV2:
  """Converts V2 embedding request JSON into structured request objects."""

  @staticmethod
  def _dicom_image(
      instance: Mapping[str, Any], patch_coordinates: Sequence[Any]
  ) -> embedding_request.DicomImageV2:
    """Builds a DICOM image request from an instance holding a DICOM path."""
    try:
      dicom_path = validate_str_key_dict(
          instance.get(_EndpointJsonKeys.DICOM_PATH)
      )
    except ValidationError as exp:
      raise pete_errors.InvalidRequestFieldError(
          'Invalid DICOM path.'
      ) from exp

    def error_string() -> str:
      # The series path and instance UIDs are read from the nested DICOM path
      # dict, while the bearer token and extensions are read from the
      # instance itself; gather each from where it was actually read so the
      # error message shows the values that failed validation.
      described = {}
      sources = (
          (
              dicom_path,
              (_EndpointJsonKeys.SERIES_PATH, _EndpointJsonKeys.INSTANCE_UIDS),
          ),
          (
              instance,
              (_EndpointJsonKeys.BEARER_TOKEN, _EndpointJsonKeys.EXTENSIONS),
          ),
      )
      for source, keys in sources:
        for key in keys:
          if key in source:
            described[key] = source[key]
      return _generate_instance_metadata_error_string(
          described,
          _EndpointJsonKeys.SERIES_PATH,
          _EndpointJsonKeys.BEARER_TOKEN,
          _EndpointJsonKeys.EXTENSIONS,
          _EndpointJsonKeys.INSTANCE_UIDS,
      )

    try:
      return embedding_request.DicomImageV2(
          series_path=_validate_not_empty_str(
              dicom_path.get(_EndpointJsonKeys.SERIES_PATH)
          ),
          bearer_token=validate_str(
              instance.get(_EndpointJsonKeys.BEARER_TOKEN, '')
          ),
          extensions=validate_str_key_dict(
              instance.get(_EndpointJsonKeys.EXTENSIONS, {})
          ),
          instance_uids=_validate_instance_uids_not_empty_str_list(
              dicom_path.get(_EndpointJsonKeys.INSTANCE_UIDS)
          ),
          patch_coordinates=patch_coordinates,
      )
    except _InstanceUIDMetadataError as exp:
      raise pete_errors.InvalidRequestFieldError(
          f'Invalid DICOM SOP Instance UID metadata; {error_string()}'
      ) from exp
    except ValidationError as exp:
      raise pete_errors.InvalidRequestFieldError(
          f'DICOM instance JSON formatting is invalid; {error_string()}'
      ) from exp

  @staticmethod
  def _gcs_image(
      instance: Mapping[str, Any], patch_coordinates: Sequence[Any]
  ) -> embedding_request.GcsImageV2:
    """Builds an image-file-URI request from an instance."""
    try:
      return embedding_request.GcsImageV2(
          image_file_uri=_validate_not_empty_str(
              instance.get(_EndpointJsonKeys.IMAGE_FILE_URI)
          ),
          bearer_token=validate_str(
              instance.get(_EndpointJsonKeys.BEARER_TOKEN, '')
          ),
          extensions=validate_str_key_dict(
              instance.get(_EndpointJsonKeys.EXTENSIONS, {})
          ),
          patch_coordinates=patch_coordinates,
      )
    except ValidationError as exp:
      error_msg = _generate_instance_metadata_error_string(
          instance,
          _EndpointJsonKeys.IMAGE_FILE_URI,
          _EndpointJsonKeys.BEARER_TOKEN,
          _EndpointJsonKeys.EXTENSIONS,
      )
      raise pete_errors.InvalidRequestFieldError(
          'Google Cloud Storage instance JSON formatting is invalid;'
          f' {error_msg}'
      ) from exp

  @staticmethod
  def _embedded_image(
      instance: Mapping[str, Any], patch_coordinates: Sequence[Any]
  ) -> embedding_request.EmbeddedImageV2:
    """Builds an embedded-image request from an instance with image bytes."""
    try:
      return embedding_request.EmbeddedImageV2(
          image_bytes=_validate_not_empty_str(
              instance.get(_EndpointJsonKeys.RAW_IMAGE_BYTES)
          ),
          extensions=validate_str_key_dict(
              instance.get(_EndpointJsonKeys.EXTENSIONS, {})
          ),
          patch_coordinates=patch_coordinates,
      )
    except ValidationError as exp:
      # The image bytes are deliberately not echoed. The image-file-URI key
      # is not listed because this branch is only reached when that key is
      # absent from the instance.
      error_msg = _generate_instance_metadata_error_string(
          instance,
          _EndpointJsonKeys.BEARER_TOKEN,
          _EndpointJsonKeys.EXTENSIONS,
      )
      raise pete_errors.InvalidRequestFieldError(
          'Embedded image instance JSON formatting is invalid; '
          f' {error_msg}'
      ) from exp

  def json_to_embedding_request(
      self, json_metadata: Mapping[str, Any]
  ) -> embedding_request.EmbeddingRequestV2:
    """Converts json to embedding request.

    Each instance's patch coordinates are validated first. The instance is
    then dispatched on the first key present: DICOM path, image file URI, or
    raw image bytes.

    Args:
      json_metadata: The value of the JSON payload provided to the API.

    Returns:
      Structured EmbeddingRequest object.

    Raises:
      InvalidRequestFieldError: If the provided fields are invalid.
    """
    instances = []
    for instance in _validate_instance_list(json_metadata):
      try:
        patch_coordinates = _get_patch_coord(
            instance.get(_EndpointJsonKeys.PATCH_COORDINATES)
        )
      except _InvalidCoordinateError as exp:
        instance_error_msg = _generate_instance_metadata_error_string(
            instance,
            _EndpointJsonKeys.PATCH_COORDINATES,
        )
        raise pete_errors.InvalidRequestFieldError(
            f'Invalid patch coordinate; {exp}; {instance_error_msg}'
        ) from exp
      if _EndpointJsonKeys.DICOM_PATH in instance:
        instances.append(
            EmbeddingConverterV2._dicom_image(instance, patch_coordinates)
        )
      elif _EndpointJsonKeys.IMAGE_FILE_URI in instance:
        instances.append(
            EmbeddingConverterV2._gcs_image(instance, patch_coordinates)
        )
      elif _EndpointJsonKeys.RAW_IMAGE_BYTES in instance:
        instances.append(
            EmbeddingConverterV2._embedded_image(instance, patch_coordinates)
        )
      else:
        raise pete_errors.InvalidRequestFieldError('unidentified type')
    return embedding_request.EmbeddingRequestV2(instances)