Implementing server hooks
This page explains how to implement the inference-server hooks to deploy a model to Amazon SageMaker.
Hook definitions
inference-server defines 4 hooks that can be implemented to deploy a model as a SageMaker Docker container:
- model_fn(model_dir: str) → Any
  A function which loads a model in memory from a given filesystem directory.
  This function will be called when the server starts up. Here, ModelType can be any Python class corresponding to the model, for example sklearn.tree.DecisionTreeClassifier.
  - Parameters:
    model_dir – Local filesystem directory containing the model files
- input_fn(input_data: bytes, content_type: str) → Any
  A function which converts data sent over an HTTP connection to the input data format the model is expecting.
  Typically, the HTTP transport uses JSON bytes, but in theory, this could be any serialization format.
  The input_fn() is called as the first hook for each inference invocation/request.
  - Parameters:
    input_data – Raw HTTP body data
    content_type – The content type (MIME) corresponding with the body data, e.g. application/json
- predict_fn(data: Any, model: Any) → Any
  A function which invokes the model and returns a prediction.
  The data argument will be populated with the output from input_fn() and model will be the output from model_fn(). Data types should therefore correspond between these functions.
  - Parameters:
    data – Deserialized input features (the output from input_fn())
    model – Model object (the output from model_fn())
- output_fn(prediction: Any, accept: MIMEAccept) → Tuple[bytes, str]
  A function which serializes and returns the prediction as bytes along with the corresponding MIME type.
  The returned data would typically be JSON bytes (MIME type application/json), but in theory this could be any serialization format as long as the application which invokes the prediction accepts that type. The output_fn() implementation should therefore compare the accept argument value with the implemented serialization format(s).
  - Parameters:
    prediction – The output from the model as returned by predict_fn()
    accept – MIME type(s) requested/accepted by the client, e.g. application/json
Note
These hooks follow an API very similar to the Amazon SageMaker Inference Toolkit. See https://docs.aws.amazon.com/sagemaker/latest/dg/adapt-inference-container.html
A 5th hook is available to implement custom health-checks if required:
- ping_fn(model: Any) → bool
  A function which indicates whether the web application is up and running.
  In most cases, there is no need to implement this function as the default implementation simply returns True. Such an implementation simply confirms that the model is initialized correctly using model_fn() and that the webserver (e.g. Gunicorn) is able to respond to HTTP requests. For more advanced scenarios, any logic could be implemented provided it's reasonably fast; a sketch of such a custom health check follows below.
  - Parameters:
    model – Model object
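As an illustration of what a more advanced health check might look like, the sketch below reports healthy only if the model can still produce a prediction for a cheap, known probe input. This is not part of inference-server: the probe input and the decision to call the model here are assumptions made for this example (the model is assumed to be callable, like the shipping forecast model further down this page), and any exception is treated as an unhealthy state.

from typing import Any

import inference_server


@inference_server.plugin_hook
def ping_fn(model: Any) -> bool:
    """Report healthy only if the model can handle a known, cheap probe input"""
    try:
        # Hypothetical probe input; pick something cheap and representative for your model
        model("Fair Isle")
        return True
    except Exception:
        return False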
Implementing model hooks
To implement the server hooks for a model, we need to create a Python package. In this example, we will be deploying a shipping (weather) forecast model.
We will set up the following files:
shipping_forecast/
├── pyproject.toml
├── src/
│ └── shipping_forecast/
│ └── __init__.py
└── tests/
Then we define the server hook functions inside __init__.py.
First of all, the actual model:
from typing import Callable, Dict

import inference_server

DataType = str
PredictionType = Dict[str, str]
ModelType = Callable[[DataType], PredictionType]


@inference_server.plugin_hook
def model_fn(model_dir: str) -> ModelType:
    """Return a function that returns the weather forecast for a given location at sea"""
    return _predict_weather


def _predict_weather(location: DataType) -> PredictionType:
    """It's stormy everywhere"""
    return {
        "wind": "Southwesterly gale force 8 continuing",
        "sea_state": "Rough or very rough, occasionally moderate in southeast.",
        "weather": "Thundery showers.",
        "visibility": "Good, occasionally poor.",
    }
To make predictions, we implement the following function:
@inference_server.plugin_hook
def predict_fn(data: DataType, model: ModelType) -> PredictionType:
"""Invoke a prediction for given input data"""
return model(data)
The above implementation is a common pattern: inference-server ships with a default implementation for this hook which does exactly that. If we are happy with that default implementation, there is no need to define our own predict_fn()!
Implementing deserialization/serialization hooks
To integrate the model with the HTTP server, we need to wire up the deserialization and serialization functions.
To deserialize the input data, let’s assume the following JSON payload should be sent for a single invocation:
{"location": "Fair Isle"}
With the shipping forecast model, that would require the following input_fn():
from typing import Literal

import orjson


@inference_server.plugin_hook
def input_fn(input_data: bytes, content_type: Literal["application/json"]) -> DataType:
    """Deserialize JSON bytes and return ``location`` attribute"""
    return orjson.loads(input_data)["location"]
Bear in mind that this is a fairly naive implementation, as it does not apply any input validation to the payload or the content type; a sketch of a more defensive version follows below. Here we use orjson, a fast JSON serializer which natively serializes to and from bytes instead of string objects.
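As a sketch of what such validation might look like, the variant below (a drop-in replacement for the input_fn() above, living in the same __init__.py so DataType and the imports are already available) checks the content type and the payload shape before extracting the location. How a raised ValueError is translated into an HTTP error response depends on inference-server's error handling, so treat this as illustrative rather than a recommended pattern:

@inference_server.plugin_hook
def input_fn(input_data: bytes, content_type: str) -> DataType:
    """Deserialize JSON bytes after validating the content type and payload shape"""
    if content_type != "application/json":
        raise ValueError(f"Unsupported content type: {content_type!r}")
    payload = orjson.loads(input_data)
    if not isinstance(payload, dict) or "location" not in payload:
        raise ValueError("Expected a JSON object with a 'location' attribute")
    return payload["location"]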
In this example, the predictions should be returned using the following JSON structure:
{
"wind": "Southwesterly gale force 8 continuing",
"sea_state": "Rough or very rough, occasionally moderate in southeast.",
"weather": "Thundery showers.",
"visibility": "Good, occasionally poor."
}
That requires a simple output_fn() like this:
from typing import Tuple


@inference_server.plugin_hook
def output_fn(prediction: PredictionType, accept: inference_server.MIMEAccept) -> Tuple[bytes, str]:
    """Serialize predictions as JSON"""
    assert accept.accept_json
    return orjson.dumps(prediction), "application/json"
This function validates that a JSON serialization is acceptable to the application invoking the prediction. However, the bare assert means an unsupported Accept header surfaces as a server error, so the error handling should be improved; a sketch of one approach follows below.
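For illustration only, the assert could be replaced with an explicit check which raises a descriptive error when the client does not accept JSON. As with the input_fn() sketch earlier, how this exception surfaces as an HTTP response depends on inference-server's error handling:

@inference_server.plugin_hook
def output_fn(prediction: PredictionType, accept: inference_server.MIMEAccept) -> Tuple[bytes, str]:
    """Serialize predictions as JSON, rejecting clients which do not accept JSON"""
    if not accept.accept_json:
        raise ValueError(f"Cannot serialize the prediction to any of the accepted MIME types: {accept}")
    return orjson.dumps(prediction), "application/json"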
Tip
The 4 plugin hooks may be implemented and installed using different Python modules or packages. For example, this could be used to develop a package with JSON serialization/deserialization hooks which could be shared between many different model packages. In that case, the model packages would need to define the model_fn() and predict_fn() hooks only.
Registering the hooks
To register the hooks with inference-server, we need to add some metadata to the shipping_forecast package. In this example, we use setuptools with entry point metadata defined in pyproject.toml. Other build backends may support entry point definitions too.
See also
- Setuptools entry point documentation
https://setuptools.pypa.io/en/latest/userguide/entry_point.html#entry-points-for-plugins
Add the following content to pyproject.toml:
[project.entry-points.inference_server]
shipping_forecast = "shipping_forecast"
This configuration states that we are supplying an entry point under the group inference_server. The entry point is named shipping_forecast and it refers to the package shipping_forecast, since we defined the server hooks in shipping_forecast/__init__.py.
Note
Additional package metadata should be recorded as with any other Python package. See https://setuptools.pypa.io/ for further details.
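For orientation, a minimal pyproject.toml for this example might look roughly like the sketch below. The version number and the dependency list are illustrative assumptions rather than requirements of inference-server:

[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[project]
name = "shipping_forecast"
version = "0.1.0"
dependencies = ["inference-server", "orjson"]

[project.entry-points.inference_server]
shipping_forecast = "shipping_forecast"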