Skip to content

interoperability with asyncio (part 2): integration with aiohttp #175

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ readme = "README.md"
dynamic = ["version"]
requires-python = ">= 3.8"
dependencies = [
"aiohttp >= 3.9.4",
"grpcio >= 1.60.0",
"protobuf >= 4.24.0",
"types-protobuf >= 4.24.0.20240129",
Expand Down
131 changes: 131 additions & 0 deletions src/dispatch/aiohttp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from typing import Optional, Union

from aiohttp import web

from dispatch.function import Registry
from dispatch.http import (
FunctionServiceError,
function_service_run,
make_error_response_body,
)
from dispatch.signature import Ed25519PublicKey, parse_verification_key


class Dispatch(web.Application):
"""A Dispatch instance servicing as a http server."""

registry: Registry
verification_key: Optional[Ed25519PublicKey]

def __init__(
self,
registry: Registry,
verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None,
):
"""Initialize a Dispatch application.

Args:
registry: The registry of functions to be serviced.

verification_key: The verification key to use for requests.
"""
super().__init__()
self.registry = registry
self.verification_key = parse_verification_key(verification_key)
self.add_routes(
[
web.post(
"/dispatch.sdk.v1.FunctionService/Run", self.handle_run_request
),
]
)

async def handle_run_request(self, request: web.Request) -> web.Response:
return await function_service_run_handler(
request, self.registry, self.verification_key
)


class Server:
host: str
port: int
app: Dispatch

_runner: web.AppRunner
_site: web.TCPSite

def __init__(self, host: str, port: int, app: Dispatch):
self.host = host
self.port = port
self.app = app

async def __aenter__(self):
await self.start()
return self

async def __aexit__(self, exc_type, exc_value, traceback):
await self.stop()

async def start(self):
self._runner = web.AppRunner(self.app)
await self._runner.setup()

self._site = web.TCPSite(self._runner, self.host, self.port)
await self._site.start()

async def stop(self):
await self._site.stop()
await self._runner.cleanup()


def make_error_response(status: int, code: str, message: str) -> web.Response:
body = make_error_response_body(code, message)
return web.Response(status=status, content_type="application/json", body=body)


def make_error_response_invalid_argument(message: str) -> web.Response:
return make_error_response(400, "invalid_argument", message)


def make_error_response_not_found(message: str) -> web.Response:
return make_error_response(404, "not_found", message)


def make_error_response_unauthenticated(message: str) -> web.Response:
return make_error_response(401, "unauthenticated", message)


def make_error_response_permission_denied(message: str) -> web.Response:
return make_error_response(403, "permission_denied", message)


def make_error_response_internal(message: str) -> web.Response:
return make_error_response(500, "internal", message)


async def function_service_run_handler(
request: web.Request,
function_registry: Registry,
verification_key: Optional[Ed25519PublicKey],
) -> web.Response:
content_length = request.content_length
if content_length is None or content_length == 0:
return make_error_response_invalid_argument("content length is required")
if content_length < 0:
return make_error_response_invalid_argument("content length is negative")
if content_length > 16_000_000:
return make_error_response_invalid_argument("content length is too large")

data: bytes = await request.read()
try:
content = await function_service_run(
str(request.url),
request.method,
dict(request.headers),
data,
function_registry,
verification_key,
)
except FunctionServiceError as e:
return make_error_response(e.status, e.code, e.message)
return web.Response(status=200, content_type="application/proto", body=content)
15 changes: 11 additions & 4 deletions src/dispatch/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@


class Dispatch:
"""A Dispatch instance to be serviced by a http server. The Dispatch class
acts as a factory for DispatchHandler objects, by capturing the variables
that would be shared between all DispatchHandler instances it created."""
"""A Dispatch instance servicing as a http server."""

registry: Registry
verification_key: Optional[Ed25519PublicKey]

def __init__(
self,
Expand All @@ -38,6 +39,8 @@ def __init__(

Args:
registry: The registry of functions to be serviced.

verification_key: The verification key to use for requests.
"""
self.registry = registry
self.verification_key = parse_verification_key(verification_key)
Expand Down Expand Up @@ -92,7 +95,7 @@ def send_error_response_internal(self, message: str):
self.send_error_response(500, "internal", message)

def send_error_response(self, status: int, code: str, message: str):
body = f'{{"code":"{code}","message":"{message}"}}'.encode()
body = make_error_response_body(code, message)
self.send_response(status)
self.send_header("Content-Type", self.error_content_type)
self.send_header("Content-Length", str(len(body)))
Expand Down Expand Up @@ -234,3 +237,7 @@ async def function_service_run(

logger.debug("finished handling run request with status %s", status.name)
return response.SerializeToString()


def make_error_response_body(code: str, message: str) -> bytes:
return f'{{"code":"{code}","message":"{message}"}}'.encode()
130 changes: 130 additions & 0 deletions tests/test_aiohttp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import asyncio
import base64
import os
import pickle
import struct
import threading
import unittest
from typing import Any, Tuple
from unittest import mock

import fastapi
import google.protobuf.any_pb2
import google.protobuf.wrappers_pb2
import httpx
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey

import dispatch.test.httpx
from dispatch.aiohttp import Dispatch, Server
from dispatch.asyncio import Runner
from dispatch.experimental.durable.registry import clear_functions
from dispatch.function import Arguments, Error, Function, Input, Output, Registry
from dispatch.proto import _any_unpickle as any_unpickle
from dispatch.sdk.v1 import call_pb2 as call_pb
from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.signature import parse_verification_key, public_key_from_pem
from dispatch.status import Status
from dispatch.test import EndpointClient

public_key_pem = "-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\n-----END PUBLIC KEY-----"
public_key_pem2 = "-----BEGIN PUBLIC KEY-----\\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\\n-----END PUBLIC KEY-----"
public_key = public_key_from_pem(public_key_pem)
public_key_bytes = public_key.public_bytes_raw()
public_key_b64 = base64.b64encode(public_key_bytes)

from datetime import datetime


def run(runner: Runner, server: Server, ready: threading.Event):
try:
with runner:
runner.run(serve(server, ready))
except RuntimeError as e:
pass # silence errors triggered by stopping the loop after tests are done


async def serve(server: Server, ready: threading.Event):
async with server:
ready.set() # allow the test to continue after the server started
await asyncio.Event().wait()


class TestAIOHTTP(unittest.TestCase):
def setUp(self):
ready = threading.Event()
self.runner = Runner()

host = "127.0.0.1"
port = 9997

self.endpoint = f"http://{host}:{port}"
self.dispatch = Dispatch(
Registry(
endpoint=self.endpoint,
api_key="0000000000000000",
api_url="http://127.0.0.1:10000",
),
)

self.client = httpx.Client(timeout=1.0)
self.server = Server(host, port, self.dispatch)
self.thread = threading.Thread(
target=lambda: run(self.runner, self.server, ready)
)
self.thread.start()
ready.wait()

def tearDown(self):
loop = self.runner.get_loop()
loop.call_soon_threadsafe(loop.stop)
self.thread.join(timeout=1.0)
self.client.close()

def test_content_length_missing(self):
resp = self.client.post(f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run")
body = resp.read()
self.assertEqual(resp.status_code, 400)
self.assertEqual(
body, b'{"code":"invalid_argument","message":"content length is required"}'
)

def test_content_length_too_large(self):
resp = self.client.post(
f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run",
data={"msg": "a" * 16_000_001},
)
body = resp.read()
self.assertEqual(resp.status_code, 400)
self.assertEqual(
body, b'{"code":"invalid_argument","message":"content length is too large"}'
)

def test_simple_request(self):
@self.dispatch.registry.primitive_function
async def my_function(input: Input) -> Output:
return Output.value(
f"You told me: '{input.input}' ({len(input.input)} characters)"
)

http_client = dispatch.test.httpx.Client(httpx.Client(base_url=self.endpoint))
client = EndpointClient(http_client)

pickled = pickle.dumps("Hello World!")
input_any = google.protobuf.any_pb2.Any()
input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled))

req = function_pb.RunRequest(
function=my_function.name,
input=input_any,
)

resp = client.run(req)

self.assertIsInstance(resp, function_pb.RunResponse)

resp.exit.result.output.Unpack(
output_bytes := google.protobuf.wrappers_pb2.BytesValue()
)
output = pickle.loads(output_bytes.value)

self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")
28 changes: 14 additions & 14 deletions tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import threading
import unittest
from http.server import HTTPServer
from typing import Any
from typing import Any, Tuple
from unittest import mock

import fastapi
Expand Down Expand Up @@ -34,21 +34,21 @@
from datetime import datetime


def create_dispatch_instance(endpoint: str):
return Dispatch(
Registry(
endpoint=endpoint,
api_key="0000000000000000",
api_url="http://127.0.0.1:10000",
),
)


class TestHTTP(unittest.TestCase):
def setUp(self):
self.server_address = ("127.0.0.1", 9999)
self.endpoint = f"http://{self.server_address[0]}:{self.server_address[1]}"
self.dispatch = create_dispatch_instance(self.endpoint)
host = "127.0.0.1"
port = 9999

self.server_address = (host, port)
self.endpoint = f"http://{host}:{port}"
self.dispatch = Dispatch(
Registry(
endpoint=self.endpoint,
api_key="0000000000000000",
api_url="http://127.0.0.1:10000",
),
)

self.client = httpx.Client(timeout=1.0)
self.server = HTTPServer(self.server_address, self.dispatch)
self.thread = threading.Thread(
Expand Down