# Copyright 2023 J.P. Morgan Chase & Co.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
"""
Plugin logic
Hook definitions adapted from https://docs.aws.amazon.com/sagemaker/latest/dg/adapt-inference-container.html
"""
import functools
import logging
import sys
from typing import TYPE_CHECKING, Any, Tuple
import pluggy
import werkzeug.datastructures
if TYPE_CHECKING:
import inference_server
# Placeholder aliases used in the hook signatures below. Each one is ``Any``, so they add no static type constraint;
# they only name the role a value plays (model object, deserialized input, prediction).
ModelType = Any
DataType = Any
PredictionType = Any

# Module logger, named after the enclosing package
logger = logging.getLogger(__package__)

# Both markers are keyed on the package name, the same name :func:`manager` passes to ``pluggy.PluginManager``.
#: Decorator for plugin hook functions
hook = pluggy.HookimplMarker(__package__)
#: Decorator for plugin hook function specifications/signatures
hookspec = pluggy.HookspecMarker(__package__)
[docs]
@hookspec(firstresult=True)
def model_fn(model_dir: str) -> ModelType:
    """
    Load a model into memory from a directory on the local filesystem.

    The server calls this function when it starts up. ``ModelType`` may be whichever Python class represents the
    model, for example :class:`sklearn.tree.DecisionTreeClassifier`.

    :param model_dir: Local filesystem directory holding the model files. When invoked by **inference-server** this
                      is always :file:`/opt/ml/model`.
    """
    # Signature-only specification; the behaviour is supplied by a plugin's hook implementation.
    raise NotImplementedError
[docs]
@hookspec(firstresult=True)
def ping_fn(model: ModelType) -> bool:
    """
    Report whether the web application is up and running.

    Implementing this hook is rarely necessary, because the default implementation simply returns ``True``. That
    default confirms that the model was initialized correctly by :func:`model_fn` and that the webserver
    (e.g. Gunicorn) is able to respond to HTTP requests.

    More advanced scenarios may implement any logic here, provided it is reasonably fast.

    :param model: Model object
    """
    # Signature-only specification; the behaviour is supplied by a plugin's hook implementation.
    raise NotImplementedError
[docs]
@hookspec(firstresult=True)
def predict_fn(data: DataType, model: ModelType) -> PredictionType:
    """
    Apply inference to the input features by invoking the model, and return the prediction.

    ``data`` is populated with the output from :func:`input_fn` and ``model`` with the output from :func:`model_fn`,
    so the data types used by these functions should correspond.

    :param data: Deserialized input features (the output from :func:`input_fn`)
    :param model: Model object (the output from :func:`model_fn`)
    """
    # Signature-only specification; the behaviour is supplied by a plugin's hook implementation.
    raise NotImplementedError
[docs]
@hookspec(firstresult=True)
def output_fn(prediction: PredictionType, accept: werkzeug.datastructures.MIMEAccept) -> Tuple[bytes, str]:
    """
    Serialize the prediction and return it as bytes together with the corresponding MIME type.

    Typically the result is JSON bytes (MIME type ``application/json``), though in theory any serialization format
    is possible as long as the application invoking the prediction *accepts* that type. Implementations of
    :func:`output_fn` should therefore compare the value of the ``accept`` argument against the serialization
    format(s) they implement.

    :param prediction: The output from the model as returned by :func:`predict_fn`
    :param accept: MIME type(s) requested/accepted by the client, e.g. ``application/json``
    """
    # Signature-only specification; the behaviour is supplied by a plugin's hook implementation.
    raise NotImplementedError
[docs]
@hookspec(firstresult=True)
def batch_strategy() -> "inference_server.BatchStrategy":
    """
    Return the default Batch Transform invocation strategy for this model

    Default: :attr:`inference_server.BatchStrategy.MULTI_RECORD`

    The strategy returned by this hook is used whenever users create a Batch Transform job without specifying one.

    Depending on how it implements the server hooks, a model may support one or several invocation strategies.
    """
    # Signature-only specification; the behaviour is supplied by a plugin's hook implementation.
    raise NotImplementedError
[docs]
@hookspec(firstresult=True)
def max_payload_in_mb() -> int:
    """
    Return the maximum allowed size, in MB, of a single record that a Batch Transform job submits to the model

    Default: ``6`` (MB)

    The product :func:`max_payload_in_mb` × :func:`max_concurrent_transforms` should be ≤ 100 MB.
    """
    # Signature-only specification; the behaviour is supplied by a plugin's hook implementation.
    raise NotImplementedError
@functools.lru_cache(maxsize=None)
def manager() -> pluggy.PluginManager:
    """
    Build the plugin manager that discovers and loads the plugins providing hooks

    The result is memoized by :func:`functools.lru_cache`, so repeated calls return the same manager instance.

    Plugins are loaded automatically via (setuptools) entrypoints, group ``inference_server``.

    :return: Manager holding this module's hook specifications, the default plugin and any entrypoint plugins
    """
    # NOTE(review): imported at call time rather than module level, presumably to avoid a circular import; confirm.
    from inference_server import default_plugin

    project = __package__
    specs_module = sys.modules[__name__]  # this module carries the ``@hookspec`` definitions

    logger.debug("Initializing plugin manager for '%s'", project)
    plugin_manager = pluggy.PluginManager(project)
    plugin_manager.add_hookspecs(specs_module)

    # The built-in default plugin is registered before entrypoint discovery.
    logger.debug("Loading default plugin '%s'", default_plugin.__name__)
    plugin_manager.register(default_plugin)

    logger.debug("Discovering plugins using entrypoint group '%s'", project)
    plugin_manager.load_setuptools_entrypoints(group=project)

    logger.debug("Loaded plugins: %s", plugin_manager.get_plugins())
    return plugin_manager