# Copyright 2023 J.P. Morgan Chase & Co.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
"""
Functions for testing **inference-server** plugins
"""
import io
import pathlib
from types import ModuleType
from typing import Any, Callable, Optional, Protocol, Tuple, Type, Union
import botocore.response # type: ignore[import-untyped]
import pluggy
import pytest
import werkzeug.test
import inference_server
import inference_server._plugin
[docs]
class ImplementsSerialize(Protocol):
    """
    Structural type describing a serializer

    Any object that exposes a ``CONTENT_TYPE`` property and a ``serialize()`` method satisfies this protocol. The
    interface is compatible with :class:`sagemaker.serializers.BaseSerializer`.
    """

    @property
    def CONTENT_TYPE(self) -> str:
        """MIME type describing the bytes produced by :meth:`serialize`"""
        ...

    def serialize(self, data: Any) -> bytes:
        """
        Convert ``data`` into its serialized form

        :param data: The object to serialize
        :return:     The serialized data as bytes
        """
        ...
[docs]
class ImplementsDeserialize(Protocol):
    """Interface compatible with :class:`sagemaker.deserializers.BaseDeserializer`"""

    @property
    def ACCEPT(self) -> Tuple[str, ...]:
        """
        The content types that are supported by this deserializer

        Annotated as a variable-length tuple so that deserializers supporting more than one content type conform to
        this protocol as well.
        """

    # The annotation is quoted, matching the style used by the pass-through deserializer in this module.
    def deserialize(self, stream: "botocore.response.StreamingBody", content_type: str) -> Any:
        """
        Return the deserialized data

        :param stream:       The stream to read the serialized data from
        :param content_type: The MIME type of the data in ``stream``
        :return:             The deserialized data
        """
class _PassThroughSerializer:
    """Default serializer: the input must already be bytes and is handed back untouched"""

    @property
    def CONTENT_TYPE(self) -> str:
        """MIME type announced for the pass-through payload"""
        return "application/octet-stream"

    def serialize(self, data: bytes) -> bytes:
        """
        Return ``data`` unchanged

        :param data: The payload, which must be a :class:`bytes` instance
        :return:     The very same object that was passed in
        :raises AssertionError: If ``data`` is not :class:`bytes`
        """
        assert isinstance(data, bytes)
        return data
class _PassThroughDeserializer:
    """Default deserializer: whatever is read from the stream is returned as-is"""

    @property
    def ACCEPT(self) -> Tuple[str]:
        """Content types handled by this deserializer (octet streams only)"""
        return ("application/octet-stream",)

    def deserialize(self, stream: "botocore.response.StreamingBody", content_type: str) -> Any:
        """
        Read the stream to its end and return what was read, closing the stream afterwards

        :param stream:       The stream holding the data
        :param content_type: MIME type of the data; must be listed in :attr:`ACCEPT`
        :return:             The result of ``stream.read()``
        :raises AssertionError: If ``content_type`` is not accepted. The stream is not touched in that case.
        """
        supported = self.ACCEPT
        assert content_type in supported
        try:
            payload = stream.read()
        finally:
            # Close the stream whether or not reading succeeded
            stream.close()
        return payload
[docs]
def predict(
    data: Any,
    *,
    model_dir: Optional[pathlib.Path] = None,
    serializer: Optional[ImplementsSerialize] = None,
    deserializer: Optional[ImplementsDeserialize] = None,
) -> Any:
    """
    Invoke the model and return a prediction

    :param data:         Model input data
    :param model_dir:    Optional pass a custom model directory to load the model from. Default is
                         :file:`/opt/ml/model/`.
    :param serializer:   Optional. A serializer for sending the data as bytes to the model server. Should be compatible
                         with :class:`sagemaker.serializers.BaseSerializer`. Default: bytes pass-through.
    :param deserializer: Optional. A deserializer for processing the prediction as sent by the model server. Should be
                         compatible with :class:`sagemaker.deserializers.BaseDeserializer`. Default: bytes pass-through.
    :return:             The prediction as returned by ``deserializer.deserialize()``
    :raises AssertionError: If the ``/invocations`` request does not respond with HTTP status 200
    """
    # Compare against None explicitly: a caller-supplied (de)serializer that happens to evaluate as falsy (for example
    # because it defines ``__len__`` or ``__bool__``) must not be silently swapped for the pass-through default.
    if serializer is None:
        serializer = _PassThroughSerializer()
    if deserializer is None:
        deserializer = _PassThroughDeserializer()

    serialized_data = serializer.serialize(data)
    http_headers = {
        "Content-Type": serializer.CONTENT_TYPE,  # The serializer declares the content-type of the input data
        "Accept": ", ".join(deserializer.ACCEPT),  # The deserializer dictates the content-type of the prediction
    }
    prediction_response = post_invocations(model_dir=model_dir, data=serialized_data, headers=http_headers)

    # Wrap the response bytes in a StreamingBody, the stream type named in the deserializer interface
    prediction_stream = botocore.response.StreamingBody(
        raw_stream=io.BytesIO(prediction_response.data),
        content_length=prediction_response.content_length,
    )
    prediction_deserialized = deserializer.deserialize(prediction_stream, content_type=prediction_response.content_type)
    return prediction_deserialized
[docs]
def client() -> werkzeug.test.Client:
    """
    Build an HTTP test client for :mod:`inference_server`

    Every call creates a new **inference-server** WSGI app and wraps it in a :class:`werkzeug.test.Client`. See the
    :mod:`werkzeug` documentation for details on how to use the test client.

    :return: A test client bound to the newly created WSGI app
    """
    wsgi_app = inference_server.create_app()
    return werkzeug.test.Client(wsgi_app)
[docs]
def post_invocations(*, model_dir: Optional[pathlib.Path] = None, **kwargs) -> werkzeug.test.TestResponse:
    """
    Send an HTTP POST request to ``/invocations`` using a test HTTP client and return the response

    This function should be used to verify an inference request using the full **inference-server** logic.

    :param model_dir: Optional pass a custom model directory to load the model from. Default is :file:`/opt/ml/model/`.
    :param kwargs:    Keyword arguments passed to :meth:`werkzeug.test.Client.post`
    :return:          The response to the request
    :raises AssertionError: If the response status code is not 200. The message includes the status code and the
                            response body.
    """
    # pytest should be available when we are using inference_server.testing
    with pytest.MonkeyPatch.context() as monkeypatch:
        if model_dir:
            # Only override the model directory for the duration of this request
            monkeypatch.setattr(inference_server, "_MODEL_DIR", str(model_dir))
        response = client().post("/invocations", **kwargs)

    # Report status and body on failure so a failing invocation can be diagnosed straight from the test output
    assert (
        response.status_code == 200
    ), f"POST /invocations returned HTTP status {response.status_code}: {response.data!r}"
    return response
[docs]
def plugin_manager() -> pluggy.PluginManager:
    """
    Give access to the plugin manager used by **inference-server**

    :return: The result of ``inference_server._plugin.manager()``
    """
    manager = inference_server._plugin.manager()
    return manager
[docs]
def plugin_is_registered(plugin: Union[Type, ModuleType]) -> bool:
    """
    Tell whether the given plugin is registered with :mod:`inference_server`

    Use this to validate that a plugin entrypoint is defined in :file:`pyproject.toml` like this:

    .. code-block:: toml

       [project.entry-points.inference_server]
       my_plugin_name = "my_module_name"

    :param plugin: The plugin, typically a module containing the hook functions.
    :return:       ``True`` if the plugin manager reports the plugin as registered, ``False`` otherwise
    """
    manager = plugin_manager()
    return manager.is_registered(plugin)
[docs]
def hookimpl_is_valid(function: Callable) -> bool:
    """
    Tell whether the given function is a valid implementation of an :mod:`inference_server` hook

    The function is valid when the plugin manager has a hook named like the function and the function is one of
    the implementations registered for that hook.

    :param function: The hook function to validate
    :return:         ``True`` if ``function`` is a registered hook implementation, ``False`` otherwise
    """
    try:
        hook_caller = getattr(plugin_manager().hook, function.__name__)
    except AttributeError:
        # Either there is no hook with this name, or the attribute lookup failed before getting that far
        return False
    else:
        registered_functions = (impl.function for impl in hook_caller.get_hookimpls())
        return function in registered_functions