Implementing server hooks

This page explains how to implement the inference-server hooks to deploy a model to Amazon SageMaker.

Hook definitions

inference-server defines 4 hooks that can be implemented to deploy a model as a SageMaker Docker container:

model_fn(model_dir: str) → Any

A function which loads a model in memory from a given filesystem directory.

This function will be called when the server starts up. The returned model (ModelType in the examples below) can be any Python object corresponding to the model, for example an instance of sklearn.tree.DecisionTreeClassifier.

Parameters:

model_dir – Local filesystem directory containing the model files
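
For illustration, a model_fn() for a pickled scikit-learn model might look roughly like this (a hypothetical sketch: the file name model.pkl is an assumption, and hook registration via the plugin decorator is shown later on this page):

import os
import pickle

def model_fn(model_dir: str):
    """Load a pickled scikit-learn model, e.g. a DecisionTreeClassifier, from model_dir"""
    with open(os.path.join(model_dir, "model.pkl"), "rb") as f:  # hypothetical file name
        return pickle.load(f)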

input_fn(input_data: bytes, content_type: str) → Any

A function which converts data sent over an HTTP connection to the input data format the model is expecting.

Typically, the HTTP transport uses JSON bytes, but in theory, this could be any serialization format.

The input_fn() is called as the first hook for each inference invocation/request.

Parameters:
  • input_data – Raw HTTP body data

  • content_type – The content type (MIME) corresponding with the body data, e.g. application/json

predict_fn(data: Any, model: Any) → Any

A function which invokes the model and returns a prediction.

Argument data will be populated with the output from input_fn() and model with the output from model_fn(); data types should therefore correspond between these functions.

predict_fn() applies inference to the input features and returns the resulting prediction.

Parameters:
  • data – Deserialized input features (the output from input_fn())

  • model – Model object (the output from model_fn())

output_fn(prediction: Any, accept: MIMEAccept) → Tuple[bytes, str]

A function which serializes and returns the prediction as bytes along with the corresponding MIME type.

The returned data would typically be JSON bytes (MIME type application/json), but in theory this could be any serialization format, as long as the application which invokes the prediction accepts that type. The output_fn() implementation should therefore compare the accept argument value with the implemented serialization format(s).

Parameters:
  • prediction – The output from the model as returned by predict_fn()

  • accept – MIME type(s) requested/accepted by the client, e.g. application/json
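
Taken together, for each inference request the server chains these hooks roughly as follows (a simplified conceptual sketch, not the actual server implementation):

def handle_invocation(body: bytes, content_type: str, accept, model):
    """Conceptual request flow: deserialize, predict, serialize"""
    data = input_fn(body, content_type)    # raw HTTP body -> model input
    prediction = predict_fn(data, model)   # model is the object returned by model_fn()
    response_bytes, mime_type = output_fn(prediction, accept)
    return response_bytes, mime_type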

Note

These hooks follow an API very similar to the Amazon SageMaker Inference Toolkit. See https://docs.aws.amazon.com/sagemaker/latest/dg/adapt-inference-container.html

A 5th hook is available to implement custom health checks if required:

ping_fn(model: Any) → bool

A function which indicates whether the web application is up and running.

In most cases, there is no need to implement this function: the default implementation simply returns True, which confirms that the model was initialized correctly via model_fn() and that the webserver (e.g. Gunicorn) is able to respond to HTTP requests.

For more advanced scenarios, any logic could be implemented provided it’s reasonably fast.

Parameters:

model – Model object
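
For example, a custom health check could verify that the model can still produce a prediction for a cheap, known-good input (a hypothetical sketch, assuming the hook is registered with the same plugin decorator as the other hooks):

@inference_server.plugin_hook
def ping_fn(model) -> bool:
    """Hypothetical health check: confirm the model responds to a known-good input"""
    try:
        model(KNOWN_GOOD_INPUT)  # hypothetical constant; any cheap invocation will do
        return True
    except Exception:
        return False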

Implementing model hooks

To implement the server hooks for a model, we need to create a Python package. In this example, we will be deploying a shipping (weather) forecast model.

We will set up the following files:

shipping_forecast/
├── pyproject.toml
├── src/
│   └── shipping_forecast/
│       └── __init__.py
└── tests/

Then we define the server hook functions inside __init__.py.

First of all, the actual model:

from typing import Callable, Dict, Literal, Tuple

import inference_server

DataType = str
PredictionType = Dict[str, str]
ModelType = Callable[[DataType], PredictionType]

@inference_server.plugin_hook
def model_fn(model_dir: str) -> ModelType:
    """Return a function that returns the weather forecast for a given location at sea"""
    return _predict_weather

def _predict_weather(location: DataType) -> PredictionType:
    """It's stormy everywhere"""
    return {
        "wind": "Southwesterly gale force 8 continuing",
        "sea_state": "Rough or very rough, occasionally moderate in southeast.",
        "weather": "Thundery showers.",
        "visibility": "Good, occasionally poor.",
    }

(The typing imports at the top also cover the type hints used in the later code snippets.)
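
As a quick local sanity check (not part of the package itself), the underlying model function can be exercised directly, for example:

forecast = _predict_weather("Fair Isle")  # call the model callable directly
print(forecast["wind"])                   # -> "Southwesterly gale force 8 continuing"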

To make predictions, we implement the following function:

@inference_server.plugin_hook
def predict_fn(data: DataType, model: ModelType) -> PredictionType:
    """Invoke a prediction for given input data"""
    return model(data)

The above implementation is a common pattern: inference-server ships with a default implementation for this hook which does exactly that. If we are happy with that default implementation, there is no need to define our own predict_fn()!

Implementing deserialization/serialization hooks

To integrate the model with the HTTP server, we need to wire up the deserialization and serialization functions.

To deserialize the input data, let’s assume the following JSON payload should be sent for a single invocation:

{"location": "Fair Isle"}

With the shipping forecast model that would require the following input_fn():

import orjson

@inference_server.plugin_hook
def input_fn(input_data: bytes, content_type: Literal["application/json"]) -> DataType:
    """Deserialize JSON bytes and return ``location`` attribute"""
    return orjson.loads(input_data)["location"]

Bear in mind that this is a fairly naive implementation, as it does not apply any validation to the payload or the content type. Here we use orjson, a fast JSON serializer which natively serializes to and from bytes rather than str objects.
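
A slightly more defensive variant might check the content type and the payload shape before deserializing (a sketch only; the most appropriate exception type depends on how you want the server to report bad requests, and ValueError is just a placeholder):

@inference_server.plugin_hook
def input_fn(input_data: bytes, content_type: str) -> DataType:
    """Deserialize JSON bytes and return the ``location`` attribute, with basic validation"""
    if content_type != "application/json":
        raise ValueError(f"Unsupported content type: {content_type}")  # placeholder error type
    payload = orjson.loads(input_data)
    if not isinstance(payload, dict) or "location" not in payload:
        raise ValueError("Expected a JSON object with a 'location' attribute")
    return payload["location"]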

In this example, the predictions should be returned using the following JSON structure:

{
    "wind": "Southwesterly gale force 8 continuing",
    "sea_state": "Rough or very rough, occasionally moderate in southeast.",
    "weather": "Thundery showers.",
    "visibility": "Good, occasionally poor."
}

That requires a simple output_fn() like this:

@inference_server.plugin_hook
def output_fn(prediction: PredictionType, accept: inference_server.MIMEAccept) -> Tuple[bytes, str]:
    """Serialize predictions as JSON"""
    assert accept.accept_json
    return orjson.dumps(prediction), "application/json"

This function validates that a JSON serialization is acceptable to the application invoking the prediction. However, a bare assert is not robust error handling, and this should be improved in a real deployment.
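
One way to improve on the bare assert is to fail with a clear error when the client does not accept JSON (a sketch; whether raising ValueError is the right way to signal this to inference-server is an assumption here):

@inference_server.plugin_hook
def output_fn(prediction: PredictionType, accept: inference_server.MIMEAccept) -> Tuple[bytes, str]:
    """Serialize predictions as JSON, rejecting clients that do not accept it"""
    if not accept.accept_json:
        raise ValueError("Cannot serialize the prediction to any accepted MIME type")  # placeholder error type
    return orjson.dumps(prediction), "application/json"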

Tip

The 4 plugin hooks may be implemented and installed using different Python modules or packages. For example, this could be used to develop a package with JSON serialization/deserialization hooks which could be shared between many different model packages. In that case, the model packages would need to define the model_fn() and predict_fn() hooks only.

Registering the hooks

To register the hooks with inference-server, we need to add some metadata to the shipping_forecast package. In this example, we use setuptools with entry point metadata defined in pyproject.toml. Other build backends may support entry point definitions too.

Add the following content to pyproject.toml:

[project.entry-points.inference_server]
shipping_forecast = "shipping_forecast"

This configuration states that we are supplying an entry point under the group inference_server. The entry point is named shipping_forecast and it refers to the package shipping_forecast, since we defined the server hooks in shipping_forecast/__init__.py.
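
After installing the package (for example with pip install -e .), you can verify that the entry point is discoverable using the standard library:

from importlib.metadata import entry_points

# List the entry points registered under the ``inference_server`` group (Python 3.10+ API)
for ep in entry_points(group="inference_server"):
    print(ep.name, "->", ep.value)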

Note

Additional package metadata should be recorded as with any other Python package. See https://setuptools.pypa.io/ for further details.