Spaces:
Running
Running
initial
Browse files- Dockerfile +27 -0
- data_models/__init__.py +14 -0
- data_models/embedding_converter.py +445 -0
- data_models/embedding_request.py +103 -0
- data_models/embedding_response.py +143 -0
- data_models/patch_coordinate.py +54 -0
- pete_errors.py +119 -0
- pete_predictor_v2.py +101 -0
- requirements.txt +7 -0
- server_gunicorn.py +100 -0
- test.py +18 -0
Dockerfile
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Google LLC
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# This is used to build a Docker image that includes the necessary dependencies
|
16 |
+
# for running the Path Foundation as a microservice.
|
17 |
+
|
18 |
+
FROM python:3.12-slim-bullseye
|
19 |
+
|
20 |
+
RUN apt-get update && apt-get install -y nano tmux
|
21 |
+
|
22 |
+
COPY ./requirements.txt /
|
23 |
+
RUN pip3 install -r requirements.txt
|
24 |
+
|
25 |
+
COPY ./ /
|
26 |
+
|
27 |
+
ENTRYPOINT ["/bin/bash"]
|
data_models/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright 2024 Google LLC
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
data_models/embedding_converter.py
ADDED
@@ -0,0 +1,445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Google LLC
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Converts Embedding Requests and Responses to json and vice versa."""
|
16 |
+
|
17 |
+
import dataclasses
|
18 |
+
import json
|
19 |
+
from typing import Any, List, Mapping, Sequence
|
20 |
+
|
21 |
+
from ez_wsi_dicomweb import patch_embedding_endpoints
|
22 |
+
|
23 |
+
import pete_errors
|
24 |
+
from data_models import embedding_request
|
25 |
+
from data_models import embedding_response
|
26 |
+
from data_models import patch_coordinate as patch_coordinate_module
|
27 |
+
|
28 |
+
_EndpointJsonKeys = patch_embedding_endpoints.EndpointJsonKeys
|
29 |
+
|
30 |
+
|
31 |
+
class ValidationError(Exception):
|
32 |
+
pass
|
33 |
+
|
34 |
+
|
35 |
+
class _InvalidCoordinateError(Exception):
|
36 |
+
pass
|
37 |
+
|
38 |
+
|
39 |
+
class _InstanceUIDMetadataError(Exception):
|
40 |
+
pass
|
41 |
+
|
42 |
+
|
43 |
+
def validate_int(val: Any) -> int:
|
44 |
+
if isinstance(val, float):
|
45 |
+
cast_val = int(val)
|
46 |
+
if cast_val != val:
|
47 |
+
raise ValidationError('coordinate value is not int')
|
48 |
+
val = cast_val
|
49 |
+
elif not isinstance(val, int):
|
50 |
+
raise ValidationError('coordinate value is not int')
|
51 |
+
return val
|
52 |
+
|
53 |
+
|
54 |
+
def _get_patch_coord(patch_coordinates: Sequence[Mapping[str, Any]]):
|
55 |
+
"""Returns patch coodianates."""
|
56 |
+
result = []
|
57 |
+
if not isinstance(patch_coordinates, list):
|
58 |
+
raise _InvalidCoordinateError('patch_coordinates is not list')
|
59 |
+
for patch_coordinate in patch_coordinates:
|
60 |
+
try:
|
61 |
+
pc = patch_coordinate_module.create_patch_coordinate(**patch_coordinate)
|
62 |
+
except TypeError as exp:
|
63 |
+
if not isinstance(patch_coordinate, dict):
|
64 |
+
raise _InvalidCoordinateError('Patch coordinate is not dict.') from exp
|
65 |
+
keys = ', '.join(
|
66 |
+
list(
|
67 |
+
dataclasses.asdict(
|
68 |
+
patch_coordinate_module.create_patch_coordinate(0, 0)
|
69 |
+
)
|
70 |
+
)
|
71 |
+
)
|
72 |
+
raise _InvalidCoordinateError(
|
73 |
+
f'Patch coordinate dict has invalid keys; expecting: {keys}'
|
74 |
+
) from exp
|
75 |
+
try:
|
76 |
+
validate_int(pc.x_origin)
|
77 |
+
validate_int(pc.y_origin)
|
78 |
+
validate_int(pc.width)
|
79 |
+
validate_int(pc.height)
|
80 |
+
except ValidationError as exp:
|
81 |
+
raise _InvalidCoordinateError(
|
82 |
+
f'Invalid patch coordinate; x_origin: {pc.x_origin}, y_origin:'
|
83 |
+
f' {pc.y_origin}, width: {pc.width}, height: {pc.height}'
|
84 |
+
) from exp
|
85 |
+
result.append(pc)
|
86 |
+
if not result:
|
87 |
+
raise _InvalidCoordinateError('empty patch_coordinates')
|
88 |
+
return result
|
89 |
+
|
90 |
+
|
91 |
+
def embedding_response_v1_to_json(
|
92 |
+
response: embedding_response.EmbeddingResponseV1,
|
93 |
+
) -> Mapping[str, Any]:
|
94 |
+
"""Loads the model artifact.
|
95 |
+
|
96 |
+
Args:
|
97 |
+
response: Structed EmbeddingResponse object.
|
98 |
+
|
99 |
+
Returns:
|
100 |
+
The value of the JSON payload to return in the API.
|
101 |
+
"""
|
102 |
+
json_response = dataclasses.asdict(response)
|
103 |
+
if response.error_response:
|
104 |
+
json_response['error_response'][
|
105 |
+
'error_code'
|
106 |
+
] = response.error_response.error_code.value
|
107 |
+
return {_EndpointJsonKeys.PREDICTIONS: json_response}
|
108 |
+
|
109 |
+
|
110 |
+
def embedding_response_v2_to_json(
|
111 |
+
json_response: Sequence[Mapping[str, Any]],
|
112 |
+
) -> Mapping[str, Any]:
|
113 |
+
return {_EndpointJsonKeys.PREDICTIONS: json_response}
|
114 |
+
|
115 |
+
|
116 |
+
def validate_str_list(val: Any) -> List[str]:
|
117 |
+
if not isinstance(val, List):
|
118 |
+
raise ValidationError('not list')
|
119 |
+
for v in val:
|
120 |
+
if not isinstance(v, str) or not v:
|
121 |
+
raise ValidationError('list contains invalid value')
|
122 |
+
return val
|
123 |
+
|
124 |
+
|
125 |
+
def _validate_instance_uids_not_empty_str_list(val: Any) -> List[str]:
|
126 |
+
try:
|
127 |
+
val = validate_str_list(val)
|
128 |
+
except ValidationError as exp:
|
129 |
+
raise _InstanceUIDMetadataError() from exp
|
130 |
+
if not val:
|
131 |
+
raise _InstanceUIDMetadataError('list is empty')
|
132 |
+
return val
|
133 |
+
|
134 |
+
|
135 |
+
def validate_str_key_dict(val: Any) -> Mapping[str, Any]:
|
136 |
+
if not isinstance(val, dict):
|
137 |
+
raise ValidationError('not a dict')
|
138 |
+
if val:
|
139 |
+
for k in val:
|
140 |
+
if not isinstance(k, str) or not k:
|
141 |
+
raise ValidationError('dict contains invalid value')
|
142 |
+
return val
|
143 |
+
|
144 |
+
|
145 |
+
def validate_str(val: Any) -> str:
|
146 |
+
if not isinstance(val, str):
|
147 |
+
raise ValidationError('not string')
|
148 |
+
return val
|
149 |
+
|
150 |
+
|
151 |
+
def _validate_not_empty_str(val: Any) -> str:
|
152 |
+
if not isinstance(val, str) or not val:
|
153 |
+
raise ValidationError('not string or empty')
|
154 |
+
return val
|
155 |
+
|
156 |
+
|
157 |
+
def _generate_instance_metadata_error_string(
|
158 |
+
metadata: Mapping[str, Any], *keys: str
|
159 |
+
) -> str:
|
160 |
+
"""returns instance metadata as a error string."""
|
161 |
+
result = {}
|
162 |
+
for key in keys:
|
163 |
+
if key not in metadata:
|
164 |
+
continue
|
165 |
+
if key == _EndpointJsonKeys.EXTENSIONS:
|
166 |
+
value = metadata[key]
|
167 |
+
if isinstance(value, Mapping):
|
168 |
+
value = dict(value)
|
169 |
+
# Strip ez_wsi_state from output.
|
170 |
+
# Not contributing to validation errors here and may be very large.
|
171 |
+
if _EndpointJsonKeys.EZ_WSI_STATE in value:
|
172 |
+
del value[_EndpointJsonKeys.EZ_WSI_STATE]
|
173 |
+
result[key] = value
|
174 |
+
continue
|
175 |
+
elif key == _EndpointJsonKeys.BEARER_TOKEN:
|
176 |
+
value = metadata[key]
|
177 |
+
# If bearer token is present, and defined strip
|
178 |
+
if isinstance(value, str) and value:
|
179 |
+
result[key] = 'PRESENT'
|
180 |
+
continue
|
181 |
+
# otherwise just associate key and value.
|
182 |
+
result[key] = metadata[key]
|
183 |
+
return json.dumps(result, sort_keys=True)
|
184 |
+
|
185 |
+
|
186 |
+
def _validate_instance_list(json_metadata: Mapping[str, Any]) -> List[Any]:
|
187 |
+
val = json_metadata.get(_EndpointJsonKeys.INSTANCES)
|
188 |
+
if isinstance(val, list):
|
189 |
+
return val
|
190 |
+
raise pete_errors.InvalidRequestFieldError(
|
191 |
+
'Invalid input, missing expected'
|
192 |
+
f' key: {_EndpointJsonKeys.INSTANCES} and associated list of values.'
|
193 |
+
)
|
194 |
+
|
195 |
+
|
196 |
+
class EmbeddingConverterV1:
|
197 |
+
"""Class containing methods for transforming embedding request and responses."""
|
198 |
+
|
199 |
+
def json_to_embedding_request(
|
200 |
+
self, json_metadata: Mapping[str, Any]
|
201 |
+
) -> embedding_request.EmbeddingRequestV1:
|
202 |
+
"""Converts json to embedding request.
|
203 |
+
|
204 |
+
Args:
|
205 |
+
json_metadata: The value of the JSON payload provided to the API.
|
206 |
+
|
207 |
+
Returns:
|
208 |
+
Structured EmbeddingRequest object.
|
209 |
+
|
210 |
+
Raises:
|
211 |
+
InvalidRequestFieldError: If the provided fields are invalid.
|
212 |
+
"""
|
213 |
+
instances = []
|
214 |
+
try:
|
215 |
+
model_params = json_metadata[_EndpointJsonKeys.PARAMETERS]
|
216 |
+
try:
|
217 |
+
parameters = embedding_request.EmbeddingParameters(
|
218 |
+
model_size=_validate_not_empty_str(
|
219 |
+
model_params.get(_EndpointJsonKeys.MODEL_SIZE)
|
220 |
+
),
|
221 |
+
model_kind=_validate_not_empty_str(
|
222 |
+
model_params.get(_EndpointJsonKeys.MODEL_KIND)
|
223 |
+
),
|
224 |
+
)
|
225 |
+
except ValidationError as exp:
|
226 |
+
raise pete_errors.InvalidRequestFieldError(
|
227 |
+
'Invalid model size and/or kind parameters.'
|
228 |
+
) from exp
|
229 |
+
for instance in _validate_instance_list(json_metadata):
|
230 |
+
ez_wsi_state = instance.get(_EndpointJsonKeys.EZ_WSI_STATE, {})
|
231 |
+
try:
|
232 |
+
ez_wsi_state = validate_str_key_dict(ez_wsi_state)
|
233 |
+
except ValidationError:
|
234 |
+
try:
|
235 |
+
ez_wsi_state = validate_str(ez_wsi_state)
|
236 |
+
except ValidationError as exp:
|
237 |
+
raise pete_errors.InvalidRequestFieldError(
|
238 |
+
'Invalid EZ-WSI state metadata.'
|
239 |
+
) from exp
|
240 |
+
try:
|
241 |
+
instances.append(
|
242 |
+
embedding_request.EmbeddingInstanceV1(
|
243 |
+
dicom_web_store_url=_validate_not_empty_str(
|
244 |
+
instance.get(_EndpointJsonKeys.DICOM_WEB_STORE_URL)
|
245 |
+
),
|
246 |
+
dicom_study_uid=_validate_not_empty_str(
|
247 |
+
instance.get(_EndpointJsonKeys.DICOM_STUDY_UID)
|
248 |
+
),
|
249 |
+
dicom_series_uid=_validate_not_empty_str(
|
250 |
+
instance.get(_EndpointJsonKeys.DICOM_SERIES_UID)
|
251 |
+
),
|
252 |
+
bearer_token=_validate_not_empty_str(
|
253 |
+
instance.get(_EndpointJsonKeys.BEARER_TOKEN)
|
254 |
+
),
|
255 |
+
ez_wsi_state=ez_wsi_state,
|
256 |
+
instance_uids=_validate_instance_uids_not_empty_str_list(
|
257 |
+
instance.get(_EndpointJsonKeys.INSTANCE_UIDS)
|
258 |
+
),
|
259 |
+
patch_coordinates=_get_patch_coord(
|
260 |
+
instance.get(_EndpointJsonKeys.PATCH_COORDINATES)
|
261 |
+
),
|
262 |
+
)
|
263 |
+
)
|
264 |
+
except ValidationError as exp:
|
265 |
+
instance_error_msg = _generate_instance_metadata_error_string(
|
266 |
+
instance,
|
267 |
+
_EndpointJsonKeys.DICOM_WEB_STORE_URL,
|
268 |
+
_EndpointJsonKeys.DICOM_STUDY_UID,
|
269 |
+
_EndpointJsonKeys.DICOM_SERIES_UID,
|
270 |
+
_EndpointJsonKeys.BEARER_TOKEN,
|
271 |
+
_EndpointJsonKeys.INSTANCE_UIDS,
|
272 |
+
)
|
273 |
+
raise pete_errors.InvalidRequestFieldError(
|
274 |
+
f'Invalid instance; {instance_error_msg}'
|
275 |
+
) from exp
|
276 |
+
except _InstanceUIDMetadataError as exp:
|
277 |
+
instance_error_msg = _generate_instance_metadata_error_string(
|
278 |
+
instance,
|
279 |
+
_EndpointJsonKeys.PATCH_COORDINATES,
|
280 |
+
)
|
281 |
+
raise pete_errors.InvalidRequestFieldError(
|
282 |
+
f'Invalid DICOM SOP Instance UID metadata; {instance_error_msg}'
|
283 |
+
) from exp
|
284 |
+
except _InvalidCoordinateError as exp:
|
285 |
+
raise pete_errors.InvalidRequestFieldError(
|
286 |
+
f'Invalid patch coordinate; {exp}'
|
287 |
+
) from exp
|
288 |
+
except (TypeError, ValueError, KeyError) as exp:
|
289 |
+
raise pete_errors.InvalidRequestFieldError(
|
290 |
+
f'Invalid input: {json.dumps(json_metadata)}'
|
291 |
+
) from exp
|
292 |
+
return embedding_request.EmbeddingRequestV1(
|
293 |
+
parameters=parameters, instances=instances
|
294 |
+
)
|
295 |
+
|
296 |
+
|
297 |
+
class EmbeddingConverterV2:
|
298 |
+
"""Class containing methods for transforming embedding request and responses."""
|
299 |
+
|
300 |
+
def json_to_embedding_request(
|
301 |
+
self, json_metadata: Mapping[str, Any]
|
302 |
+
) -> embedding_request.EmbeddingRequestV2:
|
303 |
+
"""Converts json to embedding request.
|
304 |
+
|
305 |
+
Args:
|
306 |
+
json_metadata: The value of the JSON payload provided to the API.
|
307 |
+
|
308 |
+
Returns:
|
309 |
+
Structured EmbeddingRequest object.
|
310 |
+
|
311 |
+
Raises:
|
312 |
+
InvalidRequestFieldError: If the provided fields are invalid.
|
313 |
+
"""
|
314 |
+
instances = []
|
315 |
+
for instance in _validate_instance_list(json_metadata):
|
316 |
+
try:
|
317 |
+
patch_coordinates = _get_patch_coord(
|
318 |
+
instance.get(_EndpointJsonKeys.PATCH_COORDINATES)
|
319 |
+
)
|
320 |
+
except _InvalidCoordinateError as exp:
|
321 |
+
instance_error_msg = _generate_instance_metadata_error_string(
|
322 |
+
instance,
|
323 |
+
_EndpointJsonKeys.PATCH_COORDINATES,
|
324 |
+
)
|
325 |
+
raise pete_errors.InvalidRequestFieldError(
|
326 |
+
f'Invalid patch coordinate; {exp}; {instance_error_msg}'
|
327 |
+
) from exp
|
328 |
+
if _EndpointJsonKeys.DICOM_PATH in instance:
|
329 |
+
try:
|
330 |
+
dicom_path = validate_str_key_dict(
|
331 |
+
instance.get(_EndpointJsonKeys.DICOM_PATH)
|
332 |
+
)
|
333 |
+
except ValidationError as exp:
|
334 |
+
raise pete_errors.InvalidRequestFieldError(
|
335 |
+
'Invalid DICOM path.'
|
336 |
+
) from exp
|
337 |
+
try:
|
338 |
+
instances.append(
|
339 |
+
embedding_request.DicomImageV2(
|
340 |
+
series_path=_validate_not_empty_str(
|
341 |
+
dicom_path.get(_EndpointJsonKeys.SERIES_PATH)
|
342 |
+
),
|
343 |
+
bearer_token=validate_str(
|
344 |
+
instance.get(
|
345 |
+
_EndpointJsonKeys.BEARER_TOKEN,
|
346 |
+
'',
|
347 |
+
)
|
348 |
+
),
|
349 |
+
extensions=validate_str_key_dict(
|
350 |
+
instance.get(
|
351 |
+
_EndpointJsonKeys.EXTENSIONS,
|
352 |
+
{},
|
353 |
+
)
|
354 |
+
),
|
355 |
+
instance_uids=_validate_instance_uids_not_empty_str_list(
|
356 |
+
dicom_path.get(_EndpointJsonKeys.INSTANCE_UIDS)
|
357 |
+
),
|
358 |
+
patch_coordinates=patch_coordinates,
|
359 |
+
)
|
360 |
+
)
|
361 |
+
except _InstanceUIDMetadataError as exp:
|
362 |
+
error_msg = _generate_instance_metadata_error_string(
|
363 |
+
instance,
|
364 |
+
_EndpointJsonKeys.SERIES_PATH,
|
365 |
+
_EndpointJsonKeys.BEARER_TOKEN,
|
366 |
+
_EndpointJsonKeys.EXTENSIONS,
|
367 |
+
_EndpointJsonKeys.INSTANCE_UIDS,
|
368 |
+
)
|
369 |
+
raise pete_errors.InvalidRequestFieldError(
|
370 |
+
f'Invalid DICOM SOP Instance UID metadata; {error_msg}'
|
371 |
+
) from exp
|
372 |
+
except ValidationError as exp:
|
373 |
+
error_msg = _generate_instance_metadata_error_string(
|
374 |
+
instance,
|
375 |
+
_EndpointJsonKeys.SERIES_PATH,
|
376 |
+
_EndpointJsonKeys.BEARER_TOKEN,
|
377 |
+
_EndpointJsonKeys.EXTENSIONS,
|
378 |
+
_EndpointJsonKeys.INSTANCE_UIDS,
|
379 |
+
)
|
380 |
+
raise pete_errors.InvalidRequestFieldError(
|
381 |
+
f'DICOM instance JSON formatting is invalid; {error_msg}'
|
382 |
+
) from exp
|
383 |
+
elif _EndpointJsonKeys.IMAGE_FILE_URI in instance:
|
384 |
+
try:
|
385 |
+
instances.append(
|
386 |
+
embedding_request.GcsImageV2(
|
387 |
+
image_file_uri=_validate_not_empty_str(
|
388 |
+
instance.get(_EndpointJsonKeys.IMAGE_FILE_URI)
|
389 |
+
),
|
390 |
+
bearer_token=validate_str(
|
391 |
+
instance.get(
|
392 |
+
_EndpointJsonKeys.BEARER_TOKEN,
|
393 |
+
'',
|
394 |
+
)
|
395 |
+
),
|
396 |
+
extensions=validate_str_key_dict(
|
397 |
+
instance.get(
|
398 |
+
_EndpointJsonKeys.EXTENSIONS,
|
399 |
+
{},
|
400 |
+
)
|
401 |
+
),
|
402 |
+
patch_coordinates=patch_coordinates,
|
403 |
+
)
|
404 |
+
)
|
405 |
+
except ValidationError as exp:
|
406 |
+
error_msg = _generate_instance_metadata_error_string(
|
407 |
+
instance,
|
408 |
+
_EndpointJsonKeys.IMAGE_FILE_URI,
|
409 |
+
_EndpointJsonKeys.BEARER_TOKEN,
|
410 |
+
_EndpointJsonKeys.EXTENSIONS,
|
411 |
+
)
|
412 |
+
raise pete_errors.InvalidRequestFieldError(
|
413 |
+
'Google Cloud Storage instance JSON formatting is invalid;'
|
414 |
+
f' {error_msg}'
|
415 |
+
) from exp
|
416 |
+
elif _EndpointJsonKeys.RAW_IMAGE_BYTES in instance:
|
417 |
+
try:
|
418 |
+
instances.append(
|
419 |
+
embedding_request.EmbeddedImageV2(
|
420 |
+
image_bytes=_validate_not_empty_str(
|
421 |
+
instance.get(_EndpointJsonKeys.RAW_IMAGE_BYTES)
|
422 |
+
),
|
423 |
+
extensions=validate_str_key_dict(
|
424 |
+
instance.get(
|
425 |
+
_EndpointJsonKeys.EXTENSIONS,
|
426 |
+
{},
|
427 |
+
)
|
428 |
+
),
|
429 |
+
patch_coordinates=patch_coordinates,
|
430 |
+
)
|
431 |
+
)
|
432 |
+
except ValidationError as exp:
|
433 |
+
error_msg = _generate_instance_metadata_error_string(
|
434 |
+
instance,
|
435 |
+
_EndpointJsonKeys.IMAGE_FILE_URI,
|
436 |
+
_EndpointJsonKeys.BEARER_TOKEN,
|
437 |
+
_EndpointJsonKeys.EXTENSIONS,
|
438 |
+
)
|
439 |
+
raise pete_errors.InvalidRequestFieldError(
|
440 |
+
'Embedded image instance JSON formatting is invalid; '
|
441 |
+
f' {error_msg}'
|
442 |
+
) from exp
|
443 |
+
else:
|
444 |
+
raise pete_errors.InvalidRequestFieldError('unidentified type')
|
445 |
+
return embedding_request.EmbeddingRequestV2(instances)
|
data_models/embedding_request.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Google LLC
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Request dataclasses for Pete."""
|
16 |
+
|
17 |
+
import dataclasses
|
18 |
+
import enum
|
19 |
+
from typing import Any, List, Mapping, Union
|
20 |
+
from data_models import patch_coordinate
|
21 |
+
|
22 |
+
|
23 |
+
class ModelSize(enum.Enum):
|
24 |
+
UNDEFINED = 0
|
25 |
+
SMALL = 1 # ~1M parameters
|
26 |
+
MEDIUM = 2 # ~20M parameters.
|
27 |
+
LARGE = 3 # ~100M parameters.
|
28 |
+
|
29 |
+
|
30 |
+
class ModelKind(enum.Enum):
|
31 |
+
UNDEFINED = 0
|
32 |
+
# Best suited for high magnification images.
|
33 |
+
# Pixel spacings of .002mm, .001mm, .0005mm or 5x, 10x, 20x.
|
34 |
+
LOW_PIXEL_SPACING = 1
|
35 |
+
# Best suited for low magnification images.
|
36 |
+
# Pixel spacings of .004mm, .008mm, .016mm, 5x_div_2, 5x_div4, 5x_div8.
|
37 |
+
HIGH_PIXEL_SPACING = 2
|
38 |
+
|
39 |
+
|
40 |
+
@dataclasses.dataclass(frozen=True)
|
41 |
+
class EmbeddingInstanceV1:
|
42 |
+
"""An instance in a DICOM Embedding Request as described in the schema file."""
|
43 |
+
|
44 |
+
dicom_web_store_url: str
|
45 |
+
dicom_study_uid: str
|
46 |
+
dicom_series_uid: str
|
47 |
+
bearer_token: str
|
48 |
+
ez_wsi_state: Union[str, Mapping[str, Any]]
|
49 |
+
instance_uids: List[str]
|
50 |
+
patch_coordinates: List[patch_coordinate.PatchCoordinate]
|
51 |
+
|
52 |
+
|
53 |
+
@dataclasses.dataclass(frozen=True)
|
54 |
+
class DicomImageV2:
|
55 |
+
"""An instance in a DICOM Embedding Request as described in the schema file."""
|
56 |
+
series_path: str
|
57 |
+
bearer_token: str
|
58 |
+
extensions: Mapping[str, Any]
|
59 |
+
instance_uids: List[str]
|
60 |
+
patch_coordinates: List[patch_coordinate.PatchCoordinate]
|
61 |
+
|
62 |
+
|
63 |
+
@dataclasses.dataclass(frozen=True)
|
64 |
+
class GcsImageV2:
|
65 |
+
"""An instance in a DICOM Embedding Request as described in the schema file."""
|
66 |
+
|
67 |
+
image_file_uri: str
|
68 |
+
bearer_token: str
|
69 |
+
extensions: Mapping[str, Any]
|
70 |
+
patch_coordinates: List[patch_coordinate.PatchCoordinate]
|
71 |
+
|
72 |
+
|
73 |
+
@dataclasses.dataclass(frozen=True)
|
74 |
+
class EmbeddedImageV2:
|
75 |
+
"""An instance in a DICOM Embedding Request as described in the schema file."""
|
76 |
+
image_bytes: str
|
77 |
+
extensions: Mapping[str, Any]
|
78 |
+
patch_coordinates: List[patch_coordinate.PatchCoordinate]
|
79 |
+
|
80 |
+
|
81 |
+
EmbeddingInstanceV2 = Union[DicomImageV2, GcsImageV2, EmbeddedImageV2]
|
82 |
+
|
83 |
+
|
84 |
+
@dataclasses.dataclass(frozen=True)
|
85 |
+
class EmbeddingParameters:
|
86 |
+
"""A prediction in a DICOM Embedding Request as described in the schema file."""
|
87 |
+
|
88 |
+
model_size: str
|
89 |
+
model_kind: str
|
90 |
+
|
91 |
+
|
92 |
+
@dataclasses.dataclass(frozen=True)
|
93 |
+
class EmbeddingRequestV1:
|
94 |
+
"""A DICOM Embedding Request is a single parameter and list of instances."""
|
95 |
+
parameters: EmbeddingParameters
|
96 |
+
instances: List[EmbeddingInstanceV1]
|
97 |
+
|
98 |
+
|
99 |
+
@dataclasses.dataclass(frozen=True)
|
100 |
+
class EmbeddingRequestV2:
|
101 |
+
"""A DICOM Embedding Request is a list of instances."""
|
102 |
+
|
103 |
+
instances: List[EmbeddingInstanceV2]
|
data_models/embedding_response.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Google LLC
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Response dataclasses for Pete."""
|
16 |
+
|
17 |
+
import dataclasses
|
18 |
+
import enum
|
19 |
+
from typing import Any, List, Mapping, Optional, Sequence
|
20 |
+
|
21 |
+
from ez_wsi_dicomweb import patch_embedding_endpoints
|
22 |
+
|
23 |
+
import pete_errors
|
24 |
+
from data_models import patch_coordinate
|
25 |
+
|
26 |
+
_MAX_ERROR_DESCRIPTION_LENGTH = 1024
|
27 |
+
|
28 |
+
|
29 |
+
class ErrorCode(enum.Enum):
|
30 |
+
"""The error codes for PeteErrorResponse mapped from PeteErrors."""
|
31 |
+
|
32 |
+
TOO_MANY_PATCHES_ERROR = 'TOO_MANY_PATCHES_ERROR'
|
33 |
+
INVALID_CREDENTIALS_ERROR = (
|
34 |
+
patch_embedding_endpoints.EndpointJsonKeys.INVALID_CREDENTIALS
|
35 |
+
)
|
36 |
+
PATCH_DIMENSIONS_DO_NOT_MATCH_ENDPOINT_INPUT_DIMENSIONS_ERROR = (
|
37 |
+
'PATCH_DIMENSIONS_DO_NOT_MATCH_ENDPOINT_INPUT_DIMENSIONS_ERROR'
|
38 |
+
)
|
39 |
+
INSTANCES_NOT_CONCATENATED_ERROR = 'INSTANCES_NOT_CONCATENATED_ERROR'
|
40 |
+
INVALID_REQUEST_FIELD_ERROR = 'INVALID_REQUEST_FIELD_ERROR'
|
41 |
+
INVALID_RESPONSE_ERROR = 'INVALID_RESPONSE_ERROR'
|
42 |
+
LEVEL_NOT_FOUND_ERROR = 'LEVEL_NOT_FOUND_ERROR'
|
43 |
+
EZ_WSI_STATE_ERROR = 'EZ_WSI_STATE_ERROR'
|
44 |
+
IMAGE_ERROR = 'IMAGE_ERROR'
|
45 |
+
HTTP_ERROR = 'HTTP_ERROR'
|
46 |
+
INVALID_ICC_PROFILE_TRANSFORM_ERROR = 'INVALID_ICC_PROFILE_TRANSFORM_ERROR'
|
47 |
+
IMAGE_DIMENSION_ERROR = 'IMAGE_DIMENSION_ERROR'
|
48 |
+
DICOM_TILED_FULL_ERROR = 'DICOM_TILED_FULL_ERROR'
|
49 |
+
DICOM_ERROR = 'DICOM_ERROR'
|
50 |
+
DICOM_IMAGE_DOWNSAMPLING_TOO_LARGE_ERROR = (
|
51 |
+
'DICOM_IMAGE_DOWNSAMPLING_TOO_LARGE_ERROR'
|
52 |
+
)
|
53 |
+
PATCH_OUTSIDE_OF_IMAGE_DIMENSIONS_ERROR = (
|
54 |
+
'PATCH_OUTSIDE_OF_IMAGE_DIMENSIONS_ERROR'
|
55 |
+
)
|
56 |
+
DICOM_PATH_ERROR = 'DICOM_PATH_ERROR'
|
57 |
+
GCS_IMAGE_PATH_FORMAT_ERROR = 'GCS_IMAGE_PATH_FORMAT_ERROR'
|
58 |
+
UNAPPROVED_DICOM_STORE_ERROR = 'UNAPPROVED_DICOM_STORE_ERROR'
|
59 |
+
UNAPPROVED_GCS_BUCKET_ERROR = 'UNAPPROVED_GCS_BUCKET_ERROR'
|
60 |
+
|
61 |
+
|
62 |
+
@dataclasses.dataclass(frozen=True)
|
63 |
+
class PeteErrorResponse:
|
64 |
+
"""The response when Pete is unable to successfully complete a request."""
|
65 |
+
|
66 |
+
error_code: ErrorCode
|
67 |
+
|
68 |
+
|
69 |
+
@dataclasses.dataclass(frozen=True)
|
70 |
+
class PatchEmbeddingV1:
|
71 |
+
"""A List of embeddings, instance uids, and patch coordinate."""
|
72 |
+
|
73 |
+
embeddings: List[float]
|
74 |
+
patch_coordinate: patch_coordinate.PatchCoordinate
|
75 |
+
|
76 |
+
|
77 |
+
@dataclasses.dataclass(frozen=True)
|
78 |
+
class PatchEmbeddingV2:
|
79 |
+
"""A List of embeddings, instance uids, and patch coordinate."""
|
80 |
+
|
81 |
+
embedding_vector: List[float]
|
82 |
+
patch_coordinate: patch_coordinate.PatchCoordinate
|
83 |
+
|
84 |
+
|
85 |
+
@dataclasses.dataclass(frozen=True)
|
86 |
+
class EmbeddingResultV1:
|
87 |
+
"""The response when Pete is able to successfully complete a request."""
|
88 |
+
|
89 |
+
dicom_study_uid: str
|
90 |
+
dicom_series_uid: str
|
91 |
+
instance_uids: List[str]
|
92 |
+
patch_embeddings: List[PatchEmbeddingV1]
|
93 |
+
|
94 |
+
|
95 |
+
@dataclasses.dataclass(frozen=True)
|
96 |
+
class EmbeddingResponseV1:
|
97 |
+
"""An instance in a Embedding Response as described in the schema file."""
|
98 |
+
|
99 |
+
model_version: str
|
100 |
+
error_response: Optional[PeteErrorResponse]
|
101 |
+
embedding_result: List[EmbeddingResultV1]
|
102 |
+
|
103 |
+
def __post_init__(self):
|
104 |
+
if self.error_response is None and self.embedding_result is None:
|
105 |
+
raise pete_errors.InvalidResponseError(
|
106 |
+
'At least one of error_response or embedding_result must be set.'
|
107 |
+
)
|
108 |
+
|
109 |
+
|
110 |
+
def embedding_instance_response_v2(
|
111 |
+
results: Sequence[PatchEmbeddingV2],
|
112 |
+
) -> Mapping[str, Any]:
|
113 |
+
"""Returns a JSON-serializable embedding instance responses."""
|
114 |
+
return {
|
115 |
+
patch_embedding_endpoints.EndpointJsonKeys.RESULT: {
|
116 |
+
patch_embedding_endpoints.EndpointJsonKeys.PATCH_EMBEDDINGS: [
|
117 |
+
dataclasses.asdict(patch_embedding) for patch_embedding in results
|
118 |
+
]
|
119 |
+
},
|
120 |
+
}
|
121 |
+
|
122 |
+
|
123 |
+
def instance_error_response_v2(
|
124 |
+
error_code: ErrorCode, description: str = ''
|
125 |
+
) -> Mapping[str, Any]:
|
126 |
+
error = {
|
127 |
+
patch_embedding_endpoints.EndpointJsonKeys.ERROR_CODE: error_code.value
|
128 |
+
}
|
129 |
+
if description:
|
130 |
+
error[patch_embedding_endpoints.EndpointJsonKeys.ERROR_CODE_DESCRIPTION] = (
|
131 |
+
description[:_MAX_ERROR_DESCRIPTION_LENGTH]
|
132 |
+
)
|
133 |
+
return {
|
134 |
+
patch_embedding_endpoints.EndpointJsonKeys.ERROR: error,
|
135 |
+
}
|
136 |
+
|
137 |
+
|
138 |
+
def prediction_error_response_v2(error_code: ErrorCode) -> Mapping[str, Any]:
|
139 |
+
return {
|
140 |
+
patch_embedding_endpoints.EndpointJsonKeys.VERTEXAI_ERROR: (
|
141 |
+
error_code.value
|
142 |
+
)
|
143 |
+
}
|
data_models/patch_coordinate.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Google LLC
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Shared dataclasses across requests and responses for Pete."""
|
16 |
+
|
17 |
+
import dataclasses
|
18 |
+
|
19 |
+
import pete_errors
|
20 |
+
|
21 |
+
|
22 |
+
@dataclasses.dataclass(frozen=True)
|
23 |
+
class PatchCoordinate:
|
24 |
+
"""A coordinate of a patch."""
|
25 |
+
|
26 |
+
x_origin: int
|
27 |
+
y_origin: int
|
28 |
+
height: int
|
29 |
+
width: int
|
30 |
+
|
31 |
+
def __post_init__(self):
|
32 |
+
if (self.width != 224 or self.height != 224):
|
33 |
+
raise pete_errors.PatchDimensionsDoNotMatchEndpointInputDimensionsError(
|
34 |
+
'Patch coordinate width and height must be', f' 224x224.'
|
35 |
+
)
|
36 |
+
|
37 |
+
|
38 |
+
def create_patch_coordinate(
|
39 |
+
x_origin: int,
|
40 |
+
y_origin: int,
|
41 |
+
width: int = -1,
|
42 |
+
height: int = -1,
|
43 |
+
) -> PatchCoordinate:
|
44 |
+
"""Creates a patch coordinate."""
|
45 |
+
if width == -1:
|
46 |
+
width = 224
|
47 |
+
if height == -1:
|
48 |
+
height = 224
|
49 |
+
return PatchCoordinate(
|
50 |
+
x_origin=x_origin,
|
51 |
+
y_origin=y_origin,
|
52 |
+
width=width,
|
53 |
+
height=height,
|
54 |
+
)
|
pete_errors.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Google LLC
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Error classes for Pete."""
|
16 |
+
|
17 |
+
|
18 |
+
class InternalBugError(Exception):
|
19 |
+
"""Internal error capture exceptions which should never happen.
|
20 |
+
|
21 |
+
The exception is purposefully not a child of PeteError to prevent it from
|
22 |
+
being caught by pete exception handling logic. If InternalBugError are
|
23 |
+
raised they should be investigated as bugs. Most internal errors check for
|
24 |
+
expected conditions between the EZ-WSI pete interface.
|
25 |
+
"""
|
26 |
+
|
27 |
+
|
28 |
+
class PeteError(Exception):
|
29 |
+
"""Base error class for Pete Errors."""
|
30 |
+
|
31 |
+
def __init__(self, message: str = '', api_description: str = ''):
|
32 |
+
"""Errors with optional alternative descriptions for API echoing."""
|
33 |
+
super().__init__(message if message else api_description)
|
34 |
+
self._api_description = api_description
|
35 |
+
|
36 |
+
@property
|
37 |
+
def api_description(self) -> str:
|
38 |
+
"""Returns the API description of the error."""
|
39 |
+
return self._api_description if self._api_description else str(self)
|
40 |
+
|
41 |
+
|
42 |
+
class InstancesNotConcatenatedError(PeteError):
|
43 |
+
pass
|
44 |
+
|
45 |
+
|
46 |
+
class InvalidRequestFieldError(PeteError):
|
47 |
+
pass
|
48 |
+
|
49 |
+
|
50 |
+
class InvalidResponseError(PeteError):
|
51 |
+
pass
|
52 |
+
|
53 |
+
|
54 |
+
class InvalidCredentialsError(PeteError):
|
55 |
+
pass
|
56 |
+
|
57 |
+
|
58 |
+
class LevelNotFoundError(PeteError):
|
59 |
+
pass
|
60 |
+
|
61 |
+
|
62 |
+
class TooManyPatchesError(PeteError):
|
63 |
+
pass
|
64 |
+
|
65 |
+
|
66 |
+
class EzWsiStateError(PeteError):
|
67 |
+
pass
|
68 |
+
|
69 |
+
|
70 |
+
class GcsImagePathFormatError(PeteError):
|
71 |
+
pass
|
72 |
+
|
73 |
+
|
74 |
+
class ImageError(PeteError):
|
75 |
+
pass
|
76 |
+
|
77 |
+
|
78 |
+
class PatchOutsideOfImageDimensionsError(PeteError):
|
79 |
+
pass
|
80 |
+
|
81 |
+
|
82 |
+
class HttpError(PeteError):
|
83 |
+
pass
|
84 |
+
|
85 |
+
|
86 |
+
class InvalidIccProfileTransformError(PeteError):
|
87 |
+
pass
|
88 |
+
|
89 |
+
|
90 |
+
class ImageDimensionError(PeteError):
|
91 |
+
pass
|
92 |
+
|
93 |
+
|
94 |
+
class DicomTiledFullError(PeteError):
|
95 |
+
pass
|
96 |
+
|
97 |
+
|
98 |
+
class DicomPathError(PeteError):
|
99 |
+
pass
|
100 |
+
|
101 |
+
|
102 |
+
class DicomError(PeteError):
|
103 |
+
pass
|
104 |
+
|
105 |
+
|
106 |
+
class DicomImageDownsamplingTooLargeError(PeteError):
|
107 |
+
pass
|
108 |
+
|
109 |
+
|
110 |
+
class UnapprovedDicomStoreError(PeteError):
|
111 |
+
pass
|
112 |
+
|
113 |
+
|
114 |
+
class UnapprovedGcsBucketError(PeteError):
|
115 |
+
pass
|
116 |
+
|
117 |
+
|
118 |
+
class PatchDimensionsDoNotMatchEndpointInputDimensionsError(PeteError):
|
119 |
+
pass
|
pete_predictor_v2.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Google LLC
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Callable responsible for running Inference on provided patches."""
|
16 |
+
|
17 |
+
import functools
|
18 |
+
from typing import Any, Mapping
|
19 |
+
|
20 |
+
from ez_wsi_dicomweb import credential_factory
|
21 |
+
from ez_wsi_dicomweb import dicom_slide
|
22 |
+
from ez_wsi_dicomweb import patch_embedding
|
23 |
+
from ez_wsi_dicomweb import dicom_web_interface
|
24 |
+
from ez_wsi_dicomweb import patch_embedding_endpoints
|
25 |
+
from ez_wsi_dicomweb.ml_toolkit import dicom_path
|
26 |
+
import numpy as np
|
27 |
+
import tensorflow as tf
|
28 |
+
|
29 |
+
from data_models import embedding_response
|
30 |
+
from data_models import embedding_request
|
31 |
+
from data_models import embedding_converter
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
def _load_huggingface_model() -> tf.keras.Model:
|
36 |
+
raise ValueError('Liron implement model loader')
|
37 |
+
|
38 |
+
|
39 |
+
def _endpoint_model(ml_model: tf.keras.Model, image: np.ndarray) -> np.ndarray:
|
40 |
+
"""Function ez-wsi will use to run local ML model."""
|
41 |
+
result = ml_model.signatures['serving_default'](
|
42 |
+
tf.cast(tf.constant(image), tf.float32)
|
43 |
+
)
|
44 |
+
return result['output_0'].numpy()
|
45 |
+
|
46 |
+
|
47 |
+
_ENDPOINT_MODEL = functools.partial(_endpoint_model, _load_huggingface_model())
|
48 |
+
|
49 |
+
|
50 |
+
class PetePredictor:
|
51 |
+
"""Callable responsible for generating embeddings."""
|
52 |
+
|
53 |
+
def predict(
|
54 |
+
self,
|
55 |
+
prediction_input: Mapping[str, Any],
|
56 |
+
) -> Mapping[str, Any]:
|
57 |
+
"""Runs inference on provided patches.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
prediction_input: JSON formatted input for embedding prediction.
|
61 |
+
model: ModelRunner to handle model step.
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
JSON formatted output.
|
65 |
+
|
66 |
+
Raises:
|
67 |
+
ERROR_LOADING_DICOM: If the provided patches are not concated.
|
68 |
+
"""
|
69 |
+
embedding_json_converter = embedding_converter.EmbeddingConverterV2()
|
70 |
+
request = embedding_json_converter.json_to_embedding_request(prediction_input)
|
71 |
+
endpoint = patch_embedding_endpoints.LocalEndpoint(_ENDPOINT_MODEL)
|
72 |
+
|
73 |
+
embedding_results = []
|
74 |
+
for instance in request.instances:
|
75 |
+
patches = []
|
76 |
+
if not isinstance(instance, embedding_request.DicomImageV2):
|
77 |
+
raise ValueError('unsupported')
|
78 |
+
token = instance.bearer_token
|
79 |
+
if token:
|
80 |
+
cf = credential_factory.TokenPassthroughCredentialFactory(token)
|
81 |
+
else:
|
82 |
+
cf = credential_factory.NoAuthCredentialsFactory()
|
83 |
+
dwi = dicom_web_interface.DicomWebInterface(cf)
|
84 |
+
path = dicom_path.FromString(instance.series_path)
|
85 |
+
ds = dicom_slide.DicomSlide(dwi=dwi, path=path)
|
86 |
+
level = ds.get_instance_level(instance.instance_uids[0])
|
87 |
+
for coor in instance.patch_coordinates:
|
88 |
+
patches.append(ds.get_patch(level, coor.x_origin, coor.y_origin, coor.width, coor.height))
|
89 |
+
|
90 |
+
patch_embeddings = []
|
91 |
+
for index, result in enumerate(patch_embedding.generate_patch_embeddings(endpoint, patches)):
|
92 |
+
embedding = np.array(result.embedding)
|
93 |
+
patch_embeddings.append(
|
94 |
+
embedding_response.PatchEmbeddingV2(
|
95 |
+
embedding_vector=embedding.tolist(),
|
96 |
+
patch_coordinate=instance.patch_coordinates[index],
|
97 |
+
))
|
98 |
+
embedding_results.append(
|
99 |
+
embedding_response.embedding_instance_response_v2(patch_embeddings)
|
100 |
+
)
|
101 |
+
return embedding_converter.embedding_response_v2_to_json(embedding_results)
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py~=2.1.0
|
2 |
+
ez-wsi-dicomweb~=6.0.9
|
3 |
+
flask~=3.0.3
|
4 |
+
gunicorn~=23.0.0
|
5 |
+
numpy~=1.26.4
|
6 |
+
tensorflow~=2.17.0
|
7 |
+
typing-extensions~=4.12.2
|
server_gunicorn.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Google LLC
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Gunicorn application for passing requests through to the executor command.
|
16 |
+
|
17 |
+
Provides a thin, subject-agnostic request server for Vertex endpoints which
|
18 |
+
handles requests by piping their JSON bodies to the given executor command
|
19 |
+
and returning the json output.
|
20 |
+
"""
|
21 |
+
|
22 |
+
from collections.abc import Mapping
|
23 |
+
import http
|
24 |
+
import os
|
25 |
+
from typing import Any, Optional, Sequence
|
26 |
+
|
27 |
+
from absl import app
|
28 |
+
from absl import logging
|
29 |
+
import flask
|
30 |
+
from gunicorn.app import base as gunicorn_base
|
31 |
+
|
32 |
+
import pete_predictor_v2
|
33 |
+
|
34 |
+
|
35 |
+
def _create_app() -> flask.Flask:
|
36 |
+
"""Creates a Flask app with the given executor."""
|
37 |
+
predictor = pete_predictor_v2.PetePredictor()
|
38 |
+
flask_app = flask.Flask(__name__)
|
39 |
+
|
40 |
+
def predict() -> tuple[dict[str, Any], int]:
|
41 |
+
logging.info("predict route hit")
|
42 |
+
if flask.request.get_json(silent=True) is None:
|
43 |
+
return {"error": "No JSON body."}, http.HTTPStatus.BAD_REQUEST.value
|
44 |
+
|
45 |
+
logging.debug("Dispatching request to executor.")
|
46 |
+
try:
|
47 |
+
exec_result = predictor.predict(flask.request.get_json())
|
48 |
+
logging.debug("Executor returned results.")
|
49 |
+
return (exec_result, http.HTTPStatus.OK.value)
|
50 |
+
except RuntimeError:
|
51 |
+
logging.exception("Internal error handling request: Executor failed.")
|
52 |
+
return {
|
53 |
+
"error": "Internal server error."
|
54 |
+
}, http.HTTPStatus.INTERNAL_SERVER_ERROR.value
|
55 |
+
|
56 |
+
predict_route = os.environ.get("AIP_PREDICT_ROUTE", "/predict")
|
57 |
+
logging.info("predict route: %s", predict_route)
|
58 |
+
flask_app.add_url_rule(predict_route, view_func=predict, methods=["POST"])
|
59 |
+
|
60 |
+
flask_app.config["TRAP_BAD_REQUEST_ERRORS"] = True
|
61 |
+
|
62 |
+
return flask_app
|
63 |
+
|
64 |
+
|
65 |
+
class PredictionApplication(gunicorn_base.BaseApplication):
|
66 |
+
"""Application to serve predictors on Vertex endpoints using gunicorn."""
|
67 |
+
|
68 |
+
def __init__(
|
69 |
+
self,
|
70 |
+
*,
|
71 |
+
options: Optional[Mapping[str, Any]] = None,
|
72 |
+
):
|
73 |
+
self.options = options or {}
|
74 |
+
self.options = dict(self.options)
|
75 |
+
self.options["preload_app"] = False
|
76 |
+
self.application = _create_app()
|
77 |
+
super().__init__()
|
78 |
+
|
79 |
+
def load_config(self):
|
80 |
+
config = {
|
81 |
+
key: value
|
82 |
+
for key, value in self.options.items()
|
83 |
+
if key in self.cfg.settings and value is not None
|
84 |
+
}
|
85 |
+
for key, value in config.items():
|
86 |
+
self.cfg.set(key.lower(), value)
|
87 |
+
|
88 |
+
def load(self) -> flask.Flask:
|
89 |
+
return self.application
|
90 |
+
|
91 |
+
def main(argv: Sequence[str]) -> None:
|
92 |
+
options = {'bind': f'127.0.0.1:80',
|
93 |
+
'workers': 3,
|
94 |
+
'timeout':600
|
95 |
+
}
|
96 |
+
PredictionApplication(options=options).run()
|
97 |
+
|
98 |
+
|
99 |
+
if __name__ == '__main__':
|
100 |
+
app.run(main)
|
test.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ez_wsi_dicomweb import credential_factory
|
2 |
+
from ez_wsi_dicomweb import dicom_slide
|
3 |
+
from ez_wsi_dicomweb import dicom_web_interface
|
4 |
+
from ez_wsi_dicomweb import patch_embedding
|
5 |
+
from ez_wsi_dicomweb import patch_embedding_endpoints
|
6 |
+
from ez_wsi_dicomweb.ml_toolkit import dicom_path
|
7 |
+
|
8 |
+
|
9 |
+
if __name__ == '__main__':
|
10 |
+
endpoint = patch_embedding_endpoints.V2PatchEmbeddingEndpoint(credential_factory=credential_factory.NoAuthCredentialsFactory())
|
11 |
+
endpoint._end_point_url = 'http://127.0.0.1/predict'
|
12 |
+
|
13 |
+
dwi = dicom_web_interface.DicomWebInterface(credential_factory.NoAuthCredentialsFactory())
|
14 |
+
path = dicom_path.FromString("https://proxy.imaging.datacommons.cancer.gov/current/viewer-only-no-downloads-see-tinyurl-dot-com-slash-3j3d9jyp/dicomWeb/studies/2.25.247578737460869511622147617375340640521/series/1.3.6.1.4.1.5962.99.1.1334257398.450227235.1637716829942.2.0")
|
15 |
+
slide = dicom_slide.DicomSlide(dwi, path)
|
16 |
+
patch = slide.get_patch(slide.native_level, 0,0, 224,224)
|
17 |
+
embedding = patch_embedding.get_patch_embedding(endpoint, patch)
|
18 |
+
print(embedding)
|