Spaces:
Running
Running
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gunicorn application for passing requests through to the executor command.

Provides a thin, subject-agnostic request server for Vertex endpoints which
handles requests by piping their JSON bodies to the given executor command
and returning the JSON output.
"""
from collections.abc import Mapping | |
import http | |
import os | |
from typing import Any, Optional, Sequence | |
from absl import app | |
from absl import logging | |
import flask | |
from gunicorn.app import base as gunicorn_base | |
import pete_predictor_v2 | |
def _create_app() -> flask.Flask:
    """Builds the Flask application that serves prediction requests.

    One `PetePredictor` is constructed here and shared, through the closure,
    by every request. A single POST route is registered at the path named by
    the `AIP_PREDICT_ROUTE` environment variable ("/predict" when unset).

    Returns:
      The configured Flask application.
    """
    predictor = pete_predictor_v2.PetePredictor()
    server = flask.Flask(__name__)
    # See the Flask configuration docs for the effect of this setting.
    server.config["TRAP_BAD_REQUEST_ERRORS"] = True

    # NOTE: the view must stay named `predict`; Flask derives the endpoint
    # name from the view function's name.
    def predict() -> tuple[dict[str, Any], int]:
        """Feeds the request's JSON body to the predictor.

        Returns:
          A (body, HTTP status code) pair: the predictor's result with 200,
          an error body with 400 when the request carries no parseable JSON,
          or an error body with 500 when the predictor raises RuntimeError.
        """
        logging.info("predict route hit")
        payload = flask.request.get_json(silent=True)
        if payload is None:
            bad_request = http.HTTPStatus.BAD_REQUEST.value
            return {"error": "No JSON body."}, bad_request

        logging.debug("Dispatching request to executor.")
        try:
            exec_result = predictor.predict(payload)
            logging.debug("Executor returned results.")
        except RuntimeError:
            # Log the traceback server-side; the client only sees a generic
            # message.
            logging.exception("Internal error handling request: Executor failed.")
            server_error = http.HTTPStatus.INTERNAL_SERVER_ERROR.value
            return {"error": "Internal server error."}, server_error
        return (exec_result, http.HTTPStatus.OK.value)

    route = os.environ.get("AIP_PREDICT_ROUTE", "/predict")
    logging.info("predict route: %s", route)
    server.add_url_rule(route, view_func=predict, methods=["POST"])
    return server
class PredictionApplication(gunicorn_base.BaseApplication):
    """Gunicorn application that serves the prediction Flask app."""

    def __init__(
        self,
        *,
        options: Optional[Mapping[str, Any]] = None,
    ):
        """Stores the gunicorn options and builds the Flask app.

        Args:
          options: Gunicorn settings keyed by setting name. The mapping is
            copied, and `preload_app` is always forced to False regardless
            of what the caller supplied.
        """
        settings = dict(options) if options else {}
        settings["preload_app"] = False
        self.options = settings
        self.application = _create_app()
        # Both attributes are assigned before the base initializer runs.
        super().__init__()

    def load_config(self):
        """Copies the stored options into the gunicorn config.

        Options whose value is None, or whose name is not among the config's
        settings, are skipped.
        """
        recognized = self.cfg.settings
        for name, value in self.options.items():
            if value is None or name not in recognized:
                continue
            self.cfg.set(name.lower(), value)

    def load(self) -> flask.Flask:
        """Returns the Flask app built in `__init__`."""
        return self.application
def main(argv: Sequence[str]) -> None:
    """Runs the prediction server under gunicorn.

    Args:
      argv: Command-line arguments; unused.
    """
    del argv  # Unused.
    # Keys are gunicorn setting names; see the gunicorn settings reference
    # for their exact semantics.
    options = {
        "bind": "127.0.0.1:80",
        "workers": 3,
        "timeout": 600,
    }
    PredictionApplication(options=options).run()
if __name__ == "__main__":
    # Start through absl's runner, which invokes main.
    app.run(main)