Spaces:
Running
Running
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gunicorn application for passing requests through to the executor command.

Provides a thin, subject-agnostic request server for Vertex endpoints which
handles requests by piping their JSON bodies to the given executor command
and returning the JSON output.
"""
from collections.abc import Mapping | |
import http | |
import os | |
import sys | |
from typing import Any, Optional, Sequence | |
import json | |
from flask_cors import CORS | |
import diskcache | |
import gzip | |
from io import BytesIO | |
import shutil | |
import tempfile | |
import requests | |
from absl import app | |
from absl import logging | |
import auth | |
import flask | |
from flask import render_template, send_from_directory, Response, send_file, request, current_app, abort | |
from gunicorn.app import base as gunicorn_base | |
from flask_caching import Cache | |
# import pete_predictor_v2 | |
# Define a persistent cache directory | |
# Persistent on-disk cache directory for patch embeddings, shared by all
# gunicorn workers.
CACHE_DIR = "/home/user/app/path-cache"
# Cap the cache at 45 GB; diskcache evicts entries once the limit is hit.
# int() because the limit is a byte count, not a float.
cache_disk = diskcache.Cache(CACHE_DIR, size_limit=int(45e9))
# Use logging (consistent with the rest of this module) rather than print.
logging.info("Cache stats: %s", cache_disk.stats())

# Upstream endpoints, injected via the environment at deploy time; either may
# be None, which the request handlers report as a server misconfiguration.
DICOM_SERVER_URL = os.environ.get("DICOM_SERVER_URL")
PREDICT_SERVER_URL = os.environ.get("PREDICT_ENDPOINT_URL")
def validate_allowed_predict_request(data):
    """Validates that a /predict request body has the expected shape.

    Every instance must reference a DICOM series ('dicom_path') and the
    patches to embed ('patch_coordinates'); raw image payloads are rejected
    so the proxy only forwards paths it can rewrite and authenticate.

    Args:
        data: Parsed JSON request body.

    Raises:
        ValueError: If 'instances' is missing, an instance lacks a required
            key, or an instance carries a disallowed payload key. (The caller
            maps ValueError to HTTP 400; a plain KeyError here would have
            surfaced as a 500 instead.)
    """
    if 'instances' not in data:
        raise ValueError("Missing 'instances' key in request data.")
    for item in data['instances']:
        if 'dicom_path' not in item:
            raise ValueError("Missing 'dicom_path' key in request data.")
        if 'patch_coordinates' not in item:
            raise ValueError("Missing 'patch_coordinates' key in request data.")
        if 'raw_image_bytes' in item:
            raise ValueError("'raw_image_bytes' key found in request data, but it is not expected")
        if 'image_file_uri' in item:
            raise ValueError("'image_file_uri' key found in request data, but it is not expected")
def test_series_path_prefix(data, server_url): | |
for item in data['instances']: | |
series_path = item['dicom_path']['series_path'] | |
if not series_path.startswith(server_url): | |
logging.error(f"series_path '{series_path}' does not start with '{server_url}'") | |
return False | |
return True | |
def replace_series_path_prefix(data, prefix, server_url):
    """Rewrites each instance's series_path, swapping `prefix` for `server_url`.

    Only a *leading* occurrence of `prefix` is rewritten. The previous
    implementation used str.replace, which would also rewrite matches in the
    middle of the path; paths that do not start with `prefix` are now left
    untouched, which is what both call sites (prefix normalization) intend.

    Args:
        data: Parsed request body with an 'instances' list.
        prefix: Leading path fragment to strip.
        server_url: Replacement for the stripped prefix.

    Returns:
        The same `data` object, mutated in place.
    """
    for item in data['instances']:
        series_path = item['dicom_path']['series_path']
        if series_path.startswith(prefix):
            item['dicom_path']['series_path'] = server_url + series_path[len(prefix):]
    return data
def provide_dicom_server_token(data, token):
    """Attaches `token` as the 'bearer_token' of every instance.

    Mutates `data` in place and returns it so the call can be chained.
    """
    for instance in data['instances']:
        instance['bearer_token'] = token
    return data
def compress_response(json_data):
    """Gzip-compresses a JSON string and returns the compressed bytes."""
    buffer = BytesIO()
    gz = gzip.GzipFile(fileobj=buffer, mode='w')
    try:
        gz.write(json_data.encode('utf-8'))
    finally:
        # Closing flushes the gzip trailer into the buffer.
        gz.close()
    return buffer.getvalue()
def create_gzipped_response(data, status=http.HTTPStatus.OK.value, content_type='application/json'):
    """Serializes `data` to JSON, gzips it, and wraps it in a Flask response.

    Sets Content-Encoding so the client transparently inflates the body.
    """
    payload = compress_response(json.dumps(data))
    resp = Response(payload, status=status, content_type=content_type)
    resp.headers['Content-Encoding'] = 'gzip'
    return resp
def get_cached_and_uncached_patches(instance, dicom_path):
    """Splits an instance's patches into disk-cache hits and misses.

    Returns:
        A tuple (cached_patch_embeddings, uncached_patches,
        uncached_patch_indices): the first holds
        {"patch_coordinate", "embedding_vector"} dicts for cache hits; the
        latter two hold the missed patches and their positions within
        instance['patch_coordinates'].
    """
    hits = []
    misses = []
    miss_indices = []
    for index, patch in enumerate(instance['patch_coordinates']):
        # sort_keys makes the key deterministic across dict orderings.
        key = json.dumps({"dicom_path": dicom_path, "patch": patch}, sort_keys=True)
        embedding = cache_disk.get(key)
        if embedding is None:
            misses.append(patch)
            miss_indices.append(index)
        else:
            hits.append({"patch_coordinate": patch, "embedding_vector": embedding})
    return hits, misses, miss_indices
def process_new_results(response_json, dicom_path):
    """Caches and collects patch embeddings from a predict-server response.

    Each embedding is written to the disk cache keyed on (dicom_path, patch)
    and appended to the returned list. Returns None (after logging) when the
    response does not have the expected structure; embeddings seen before the
    malformed entry remain cached.
    """
    if "predictions" not in response_json:
        logging.error("Unexpected response format: missing 'predictions'")
        return None
    collected = []
    for prediction in response_json["predictions"]:
        if "result" not in prediction or "patch_embeddings" not in prediction["result"]:
            logging.error("Unexpected response format: missing 'result' or 'patch_embeddings'")
            return None
        for entry in prediction["result"]["patch_embeddings"]:
            patch = entry["patch_coordinate"]
            vector = entry["embedding_vector"]
            key = json.dumps({"dicom_path": dicom_path, "patch": patch}, sort_keys=True)
            cache_disk.set(key, vector)
            collected.append({"patch_coordinate": patch, "embedding_vector": vector})
    return collected
def combine_results(instance, cached_patch_embeddings, new_patch_embeddings, uncached_patch_indices):
    """Merges cached and freshly-computed embeddings back into request order.

    Positions listed in `uncached_patch_indices` are filled from
    `new_patch_embeddings` (in order); every other position is filled from
    `cached_patch_embeddings` (in order), mirroring how
    get_cached_and_uncached_patches split them.

    Args:
        instance: Request instance whose 'patch_coordinates' defines the
            output length and order.
        cached_patch_embeddings: Embedding dicts for cache hits, in hit order.
        new_patch_embeddings: Embedding dicts for cache misses, in miss order.
        uncached_patch_indices: Indices of the cache misses.

    Returns:
        A list of embedding dicts, one per requested patch coordinate.
    """
    # Set membership is O(1); the previous `i in list` scan made the merge
    # quadratic in the number of patches.
    uncached = set(uncached_patch_indices)
    new_iter = iter(new_patch_embeddings)
    cached_iter = iter(cached_patch_embeddings)
    return [
        next(new_iter) if i in uncached else next(cached_iter)
        for i in range(len(instance['patch_coordinates']))
    ]
def _create_app() -> flask.Flask:
    """Creates the Flask app fronting the DICOM and predict servers.

    NOTE(review): the handlers below (display_html, dicom, predict,
    download_cache) are defined but never registered on flask_app — no
    @flask_app.route decorators or add_url_rule calls are visible in this
    file. Confirm the registrations were not lost; without them the app
    only serves the static folder.
    """
    # Create credentials and fetch an access token at startup so the first
    # request does not pay the auth round-trip (and misconfiguration fails
    # fast).
    try:
        global credentials
        credentials = auth.create_credentials()
        auth.refresh_credentials(credentials)
    except ValueError as e:
        logging.exception(f"Failed to create credentials: {e}")
        # Every proxied call needs credentials, so exit rather than limp on.
        sys.exit(1)
    # predictor = pete_predictor_v2.PetePredictor()
    flask_app = flask.Flask(__name__, static_folder='web', static_url_path='')
    CORS(flask_app, origins='http://localhost:5432')
    flask_app.config.from_mapping({"CACHE_TYPE": "simple"})
    cache = Cache(flask_app)

    def _log_censored_predict_request(headers, request_body, access_token):
        """Logs an outbound predict request with the bearer token redacted."""
        censored_headers = dict(headers)
        censored_headers['Authorization'] = "hidden"
        logging.error("Internal request headers: %s", json.dumps(censored_headers, indent=2))
        # request_body is a plain dict (the previous code's
        # `request_body.content` raised AttributeError here) and may be None
        # if the failure happened before it was built.
        if request_body is not None:
            censored_body = json.dumps(request_body, indent=2).replace(access_token, "hidden")
            logging.error("Internal request body: %s", censored_body)

    def display_html():
        """Serves the single-page UI from web/index.html."""
        index_path = 'web/index.html'
        try:
            with open(index_path, 'r') as f:
                content = f.read()
            return Response(content, mimetype='text/html')
        except FileNotFoundError:
            abort(404, f"Error: index.html not found at {index_path}")

    def dicom(url_path):
        """Proxies a GET to the DICOM server, attaching the service token."""
        access_token = auth.get_access_token_refresh_if_needed(credentials)
        if not DICOM_SERVER_URL:
            abort(http.HTTPStatus.INTERNAL_SERVER_ERROR.value, "DICOM server URL not configured.")
        full_url = f"{DICOM_SERVER_URL}/{url_path}"
        # Deliberately NOT forwarding the client's headers; only our token.
        headers = dict()
        headers['Authorization'] = f"Bearer {access_token}"
        response = None
        try:
            response = requests.get(full_url, params=flask.request.args, data=flask.request.get_data(), headers=headers)
            response.raise_for_status()
            return Response(response.content, status=response.status_code, content_type=response.headers['Content-Type'])
        except requests.RequestException as e:
            logging.exception("Error proxying request to DICOM server. %s", e)
            # Never log the bearer token.
            headers['Authorization'] = "hidden"
            logging.error("Internal request headers: %s", json.dumps(headers, indent=2))
            # `response` is None when the request failed before a reply
            # arrived (e.g. connection error). response.content is bytes, so
            # redact with bytes (the previous str arguments raised TypeError).
            if response is not None:
                censored_content = response.content.replace(
                    b"Bearer " + access_token.encode('utf-8'), b"hidden")
                logging.error("Internal response data: %s", censored_content)
            abort(http.HTTPStatus.BAD_GATEWAY.value, f"Error proxying request to DICOM server: {e}")

    def predict():
        """Validates, rewrites, and forwards an embedding request.

        Patch embeddings already present in the disk cache are served locally;
        only cache misses are forwarded to the predict server.
        """
        access_token = auth.get_access_token_refresh_if_needed(credentials)
        if not PREDICT_SERVER_URL:
            abort(http.HTTPStatus.INTERNAL_SERVER_ERROR.value, "PREDICT server URL not configured.")
        headers = {
            'Authorization': f"Bearer {access_token}",
            'Content-Type': 'application/json',
        }
        try:
            body = json.loads(flask.request.get_data())
            validate_allowed_predict_request(body)
        except ValueError as e:
            # Covers both malformed JSON (JSONDecodeError subclasses
            # ValueError) and disallowed request shapes.
            abort(http.HTTPStatus.BAD_REQUEST.value, f"disallowed {str(e)}")
        request_body = None
        try:
            # Normalize client-visible paths to server-side paths, then to
            # the real DICOM server URL.
            body = replace_series_path_prefix(body, "http://localhost:8080/dicom/", "/dicom/")
            if not test_series_path_prefix(body, '/dicom/'):
                abort(http.HTTPStatus.BAD_REQUEST.value, "series_path does not start with dicom server url.")
            body = replace_series_path_prefix(body, "/dicom/", f"{DICOM_SERVER_URL}/")
            instance = body['instances'][0]  # assume single instance
            dicom_path = instance['dicom_path']
            cached_patch_embeddings, uncached_patches, uncached_patch_indices = get_cached_and_uncached_patches(instance, dicom_path)
            # If every patch is cached, skip the predict server entirely.
            if not uncached_patches:
                return create_gzipped_response({"predictions": [{"result": {"patch_embeddings": cached_patch_embeddings}}]})
            # Forward only the cache misses.
            request_body = {"instances": [{"dicom_path": dicom_path, "patch_coordinates": uncached_patches}]}
            request_body = provide_dicom_server_token(request_body, access_token)
            response = requests.post(PREDICT_SERVER_URL, json=request_body, headers=headers)
            response.raise_for_status()
            response_json = response.json()
            new_patch_embeddings = process_new_results(response_json, dicom_path)
            if new_patch_embeddings is None:
                abort(http.HTTPStatus.INTERNAL_SERVER_ERROR.value, "Unexpected response format from predict server")
            final_patch_embeddings = combine_results(instance, cached_patch_embeddings, new_patch_embeddings, uncached_patch_indices)
            return create_gzipped_response({"predictions": [{"result": {"patch_embeddings": final_patch_embeddings}}]}, status=response.status_code)
        except requests.RequestException as e:
            logging.exception("Error proxying request to predict server: %s", e)
            _log_censored_predict_request(headers, request_body, access_token)
            abort(http.HTTPStatus.BAD_GATEWAY.value, "Error proxying request to predict server.")
        except json.JSONDecodeError as e:
            logging.exception("Error decoding JSON response from predict server: %s", e)
            _log_censored_predict_request(headers, request_body, access_token)
            abort(http.HTTPStatus.BAD_GATEWAY.value, "Error decoding JSON response from predict server.")

    def download_cache():
        """Streams the entire cache directory to the client as a zip file."""
        logging.info("Downloading cache")
        logging.info("Cache stats: %s", cache_disk.stats())
        zip_filename = "path-cache.zip"
        # Build the archive in a temp dir that is removed when the context
        # exits (send_file has read the data by then).
        with tempfile.TemporaryDirectory() as temp_dir:
            zip_filepath = os.path.join(temp_dir, zip_filename)
            try:
                shutil.make_archive(
                    os.path.splitext(zip_filepath)[0],
                    "zip",
                    CACHE_DIR,
                )
                return send_file(
                    zip_filepath,
                    mimetype="application/zip",
                    as_attachment=True,
                    download_name=zip_filename,
                )
            except Exception as e:
                current_app.logger.error(f"Error creating zip archive: {e}")
                abort(500, f"Error creating zip archive: {e}")

    return flask_app
class PredictionApplication(gunicorn_base.BaseApplication):
    """Application to serve predictors on Vertex endpoints using gunicorn."""

    def __init__(
        self,
        *,
        options: Optional[Mapping[str, Any]] = None,
    ):
        # Copy so the caller's mapping is never mutated.
        self.options = dict(options or {})
        # Build the Flask app here (per worker) instead of preloading it in
        # the gunicorn master.
        self.options["preload_app"] = False
        self.application = _create_app()
        super().__init__()

    def load_config(self):
        """Copies recognized, non-None options into gunicorn's settings."""
        for key, value in self.options.items():
            if key in self.cfg.settings and value is not None:
                self.cfg.set(key.lower(), value)

    def load(self) -> flask.Flask:
        """Returns the WSGI application gunicorn should serve."""
        return self.application
def main(argv: Sequence[str]) -> None:
    """Runs the gunicorn server on all interfaces at port 8080."""
    del argv  # Unused.
    server_options = {
        'bind': '0.0.0.0:8080',
        'workers': 6,
        'timeout': 600,
    }
    PredictionApplication(options=server_options).run()


if __name__ == '__main__':
    app.run(main)