Source code for inference_server

# Copyright 2023 J.P. Morgan Chase & Co.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

"""
Pluggable Python HTTP web service (WSGI) for real-time AI/ML model inference compatible with Amazon SageMaker
"""

import enum
import functools
import http
import logging
from typing import TYPE_CHECKING

import codetiming
import orjson
import werkzeug
import werkzeug.exceptions
from werkzeug.datastructures import MIMEAccept

import inference_server._plugin
from inference_server._plugin import hook as plugin_hook

if TYPE_CHECKING:
    from _typeshed.wsgi import WSGIApplication

try:
    from importlib import metadata
except ImportError:  # pragma: no cover
    # Python < 3.8
    import importlib_metadata as metadata  # type: ignore

# Names exported by ``from inference_server import *``
__all__ = (
    "BatchStrategy",
    "MIMEAccept",  # Exporting for plugin developers' convenience
    "create_app",
    "plugin_hook",
    "warmup",
)

#: Library version, e.g. 1.0.0, taken from Git tags
#: (looked up at import time from the metadata of the installed ``inference-server`` distribution)
__version__ = metadata.version("inference-server")

#: Well known location for model artifacts. Passed as ``model_dir`` to the ``model_fn`` plugin hook.
_MODEL_DIR = "/opt/ml/model"

#: Logger shared by this module, named after the containing package
logger = logging.getLogger(__package__)


class BatchStrategy(enum.Enum):
    """
    Batch Transform invocation strategies

    The strategy determines how many records go into the mini-batch sent with each HTTP inference request. A record
    is one unit of input data on which inference can be made, e.g. one line of a CSV file.

    See:
    https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTransformJob.html#sagemaker-CreateTransformJob-request-BatchStrategy
    """  # noqa: E501

    #: One record per request
    SINGLE_RECORD = "SingleRecord"

    #: Multiple records per request
    MULTI_RECORD = "MultiRecord"
def create_app() -> "WSGIApplication":
    """
    WSGI application factory

    Pass this function to a WSGI-compatible web server, which calls it to obtain the application object that serves
    the HTTP requests.

    :return: The WSGI application
    """
    wsgi_application = _app
    return wsgi_application
def warmup() -> None:
    """
    Load the model ahead of the first request

    Invokes the ``model_fn`` plugin hook so that any additional resources are initialized upfront.
    """
    # Only the side effect of loading matters here; the returned model is discarded
    _model()
@werkzeug.Request.application
def _app(request: werkzeug.Request) -> werkzeug.Response:
    """
    Dispatch an incoming HTTP request to its route handler

    The handler is selected by an exact match on the request's HTTP method and path.

    :param request: HTTP request data
    :return:        HTTP response produced by the route handler
    :raises werkzeug.exceptions.NotFound: If no handler is registered for the request's method and path
    """
    try:
        route_handler = _ROUTES[(request.method, request.path)]
    except KeyError:
        # The failed dictionary lookup is an implementation detail, so do not chain it onto the HTTP error
        raise werkzeug.exceptions.NotFound() from None
    response = route_handler(request)
    return response


def _handle_invocations(request: werkzeug.Request) -> werkzeug.Response:
    """
    Handle an incoming inference POST request

    Runs the ``input_fn``, ``predict_fn`` and ``output_fn`` plugin hooks in that order. The time taken is logged at
    debug level.

    :param request: HTTP request data
    :return:        HTTP response with the serialized prediction as body
    """
    with codetiming.Timer(text="Invocation took {:.3f} seconds", logger=logger.debug):
        pm = inference_server._plugin.manager()
        # Deserialize HTTP body payload (bytes) into input features
        data = pm.hook.input_fn(input_data=request.data, content_type=request.content_type)
        # Then use the model to make a prediction
        prediction = pm.hook.predict_fn(data=data, model=_model())
        # Then serialize the data as bytes. This is often (but not necessarily) JSON bytes.
        prediction_bytes, content_type = pm.hook.output_fn(prediction=prediction, accept=request.accept_mimetypes)
        return werkzeug.Response(prediction_bytes, mimetype=content_type)


def _handle_ping(request: werkzeug.Request) -> werkzeug.Response:
    """
    Handle an incoming ping GET request

    Responds with HTTP 200 (OK) if the ``ping_fn`` plugin hook returns a truthy value, and with HTTP 503 (Service
    Unavailable) otherwise.

    :param request: HTTP request data
    :return:        HTTP response without body carrying the health status code
    """
    pm = inference_server._plugin.manager()
    if pm.hook.ping_fn(model=_model()):
        status = http.HTTPStatus.OK
    else:
        status = http.HTTPStatus.SERVICE_UNAVAILABLE
    return werkzeug.Response(status=status)


def _handle_execution_parameters(request: werkzeug.Request) -> werkzeug.Response:
    """
    Handle an incoming execution-parameters GET request

    This will enable BatchTransform job to choose the optimal tuning parameters during runtime.

    :param request: HTTP request data
    :return:        HTTP response with the execution parameters as JSON body
    """
    pm = inference_server._plugin.manager()
    response_data = {
        "BatchStrategy": pm.hook.batch_strategy(),
        "MaxConcurrentTransforms": pm.hook.max_concurrent_transforms(),
        "MaxPayloadInMB": pm.hook.max_payload_in_mb(),
    }
    return werkzeug.Response(orjson.dumps(response_data), mimetype="application/json")


# Stupidly simple request routing: exact match on (HTTP method, request path)
_ROUTES = {
    ("GET", "/execution-parameters"): _handle_execution_parameters,
    ("POST", "/invocations"): _handle_invocations,
    ("GET", "/ping"): _handle_ping,
}


@functools.lru_cache(maxsize=None)
def _model() -> inference_server._plugin.ModelType:
    """
    Load a previously serialized ML model from the well known model directory

    The model is loaded by the ``model_fn`` plugin hook. The result is cached, so after the first successful call
    the same model object is returned without invoking the hook again.

    :return: The model object as returned by the ``model_fn`` plugin hook
    """
    pm = inference_server._plugin.manager()
    logger.info("Loading model using 'model_fn' hook...")
    model = pm.hook.model_fn(model_dir=_MODEL_DIR)
    logger.info("Finished loading model %s", model)
    return model