Skip to content

Commit b6b129a

Browse files
authored
add http2-server to aws-replicator (#64)
1 parent 0cab1aa commit b6b129a

File tree

3 files changed

+325
-3
lines changed

3 files changed

+325
-3
lines changed

aws-replicator/aws_replicator/client/auth_proxy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from localstack.utils.files import new_tmp_file, save_file
2929
from localstack.utils.functions import run_safe
3030
from localstack.utils.net import get_docker_host_from_container, get_free_tcp_port
31-
from localstack.utils.server.http2_server import run_server
3231
from localstack.utils.serving import Server
3332
from localstack.utils.strings import short_uid, to_bytes, to_str, truncate
3433
from localstack_ext.bootstrap.licensingv2 import ENV_LOCALSTACK_API_KEY, ENV_LOCALSTACK_AUTH_TOKEN
@@ -39,6 +38,8 @@
3938
from aws_replicator.config import HANDLER_PATH_PROXIES
4039
from aws_replicator.shared.models import AddProxyRequest, ProxyConfig
4140

41+
from .http2_server import run_server
42+
4243
LOG = logging.getLogger(__name__)
4344
LOG.setLevel(logging.INFO)
4445
if config.DEBUG:
Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
# TODO: currently this is only used for the auth_proxy. replace at some point with the more modern gateway
2+
# server
3+
import asyncio
4+
import collections.abc
5+
import logging
6+
import os
7+
import ssl
8+
import threading
9+
import traceback
10+
from typing import Callable, List, Tuple
11+
12+
import h11
13+
from hypercorn import utils as hypercorn_utils
14+
from hypercorn.asyncio import serve, tcp_server
15+
from hypercorn.config import Config
16+
from hypercorn.events import Closed
17+
from hypercorn.protocol import http_stream
18+
from localstack import config
19+
from localstack.utils.asyncio import ensure_event_loop, run_coroutine, run_sync
20+
from localstack.utils.files import load_file
21+
from localstack.utils.http import uses_chunked_encoding
22+
from localstack.utils.run import FuncThread
23+
from localstack.utils.sync import retry
24+
from localstack.utils.threads import TMP_THREADS
25+
from quart import Quart
26+
from quart import app as quart_app
27+
from quart import asgi as quart_asgi
28+
from quart import make_response, request
29+
from quart import utils as quart_utils
30+
from quart.app import _cancel_all_tasks
31+
32+
LOG = logging.getLogger(__name__)
33+
34+
HTTP_METHODS = ["GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "PATCH"]
35+
36+
# flag to avoid lowercasing all header names (e.g., some AWS S3 SDKs depend on "ETag" response header)
37+
RETURN_CASE_SENSITIVE_HEADERS = True
38+
39+
# default max content length for HTTP server requests (256 MB)
40+
DEFAULT_MAX_CONTENT_LENGTH = 256 * 1024 * 1024
41+
42+
# cache of SSL contexts (indexed by cert file names)
43+
SSL_CONTEXTS = {}
44+
SSL_LOCK = threading.RLock()
45+
46+
47+
def setup_quart_logging():
48+
# set up loggers to avoid duplicate log lines in quart
49+
for name in ["quart.app", "quart.serving"]:
50+
log = logging.getLogger(name)
51+
log.setLevel(logging.INFO if config.DEBUG else logging.WARNING)
52+
for hdl in list(log.handlers):
53+
log.removeHandler(hdl)
54+
55+
56+
def apply_patches():
57+
def InformationalResponse_init(self, *args, **kwargs):
58+
if kwargs.get("status_code") == 100 and not kwargs.get("reason"):
59+
# add missing "100 Continue" keyword which makes boto3 HTTP clients fail/hang
60+
kwargs["reason"] = "Continue"
61+
InformationalResponse_init_orig(self, *args, **kwargs)
62+
63+
InformationalResponse_init_orig = h11.InformationalResponse.__init__
64+
h11.InformationalResponse.__init__ = InformationalResponse_init
65+
66+
# skip error logging for ssl.SSLError in hypercorn tcp_server.py _read_data()
67+
68+
async def _read_data(self) -> None:
69+
try:
70+
return await _read_data_orig(self)
71+
except Exception:
72+
await self.protocol.handle(Closed())
73+
74+
_read_data_orig = tcp_server.TCPServer._read_data
75+
tcp_server.TCPServer._read_data = _read_data
76+
77+
# skip error logging for ssl.SSLError in hypercorn tcp_server.py _close()
78+
79+
async def _close(self) -> None:
80+
try:
81+
return await _close_orig(self)
82+
except ssl.SSLError:
83+
return
84+
85+
_close_orig = tcp_server.TCPServer._close
86+
tcp_server.TCPServer._close = _close
87+
88+
# avoid SSL context initialization errors when running multiple server threads in parallel
89+
90+
def create_ssl_context(self, *args, **kwargs):
91+
with SSL_LOCK:
92+
key = "%s%s" % (self.certfile, self.keyfile)
93+
if key not in SSL_CONTEXTS:
94+
# perform retries to circumvent "ssl.SSLError: [SSL] PEM lib (_ssl.c:4012)"
95+
def _do_create():
96+
SSL_CONTEXTS[key] = create_ssl_context_orig(self, *args, **kwargs)
97+
98+
retry(_do_create, retries=3, sleep=0.5)
99+
return SSL_CONTEXTS[key]
100+
101+
create_ssl_context_orig = Config.create_ssl_context
102+
Config.create_ssl_context = create_ssl_context
103+
104+
# apply patch for case-sensitive header names (e.g., some AWS S3 SDKs depend on "ETag" case-sensitive header)
105+
106+
def _encode_headers(headers):
107+
if RETURN_CASE_SENSITIVE_HEADERS:
108+
return [(key.encode(), value.encode()) for key, value in headers.items()]
109+
return [(key.lower().encode(), value.encode()) for key, value in headers.items()]
110+
111+
quart_asgi._encode_headers = quart_asgi.encode_headers = _encode_headers
112+
quart_app.encode_headers = quart_utils.encode_headers = _encode_headers
113+
114+
def build_and_validate_headers(headers):
115+
validated_headers = []
116+
for name, value in headers:
117+
if name[0] == b":"[0]:
118+
raise ValueError("Pseudo headers are not valid")
119+
header_name = bytes(name) if RETURN_CASE_SENSITIVE_HEADERS else bytes(name).lower()
120+
validated_headers.append((header_name.strip(), bytes(value).strip()))
121+
return validated_headers
122+
123+
hypercorn_utils.build_and_validate_headers = build_and_validate_headers
124+
http_stream.build_and_validate_headers = build_and_validate_headers
125+
126+
# avoid "h11._util.LocalProtocolError: Too little data for declared Content-Length" for certain status codes
127+
128+
def suppress_body(method, status_code):
129+
if status_code == 412:
130+
return False
131+
return suppress_body_orig(method, status_code)
132+
133+
suppress_body_orig = hypercorn_utils.suppress_body
134+
hypercorn_utils.suppress_body = suppress_body
135+
http_stream.suppress_body = suppress_body
136+
137+
138+
class HTTPErrorResponse(Exception):
139+
def __init__(self, *args, code=None, **kwargs):
140+
super(HTTPErrorResponse, self).__init__(*args, **kwargs)
141+
self.code = code
142+
143+
144+
def get_async_generator_result(result):
145+
gen, headers = result, {}
146+
if isinstance(result, tuple) and len(result) >= 2:
147+
gen, headers = result[:2]
148+
if not isinstance(gen, (collections.abc.Generator, collections.abc.AsyncGenerator)):
149+
return
150+
return gen, headers
151+
152+
153+
def run_server(
154+
port: int,
155+
bind_addresses: List[str],
156+
handler: Callable = None,
157+
asynchronous: bool = True,
158+
ssl_creds: Tuple[str, str] = None,
159+
max_content_length: int = None,
160+
send_timeout: int = None,
161+
):
162+
"""
163+
Run an HTTP2-capable Web server on the given port, processing incoming requests via a `handler` function.
164+
:param port: port to bind to
165+
:param bind_addresses: addresses to bind to
166+
:param handler: callable that receives the request and returns a response
167+
:param asynchronous: whether to start the server asynchronously in the background
168+
:param ssl_creds: optional tuple with SSL cert file names (cert file, key file)
169+
:param max_content_length: maximum content length of uploaded payload
170+
:param send_timeout: timeout (in seconds) for sending the request payload over the wire
171+
"""
172+
173+
ensure_event_loop()
174+
app = Quart(__name__, static_folder=None)
175+
app.config["MAX_CONTENT_LENGTH"] = max_content_length or DEFAULT_MAX_CONTENT_LENGTH
176+
if send_timeout:
177+
app.config["BODY_TIMEOUT"] = send_timeout
178+
179+
@app.route("/", methods=HTTP_METHODS, defaults={"path": ""})
180+
@app.route("/<path:path>", methods=HTTP_METHODS)
181+
async def index(path=None):
182+
response = await make_response("{}")
183+
if handler:
184+
data = await request.get_data()
185+
try:
186+
result = await run_sync(handler, request, data)
187+
if isinstance(result, Exception):
188+
raise result
189+
except Exception as e:
190+
LOG.warning(
191+
"Error in proxy handler for request %s %s: %s %s",
192+
request.method,
193+
request.url,
194+
e,
195+
traceback.format_exc(),
196+
)
197+
response.status_code = 500
198+
if isinstance(e, HTTPErrorResponse):
199+
response.status_code = e.code or response.status_code
200+
return response
201+
if result is not None:
202+
# check if this is an async generator (for HTTP2 push event responses)
203+
async_gen = get_async_generator_result(result)
204+
if async_gen:
205+
return async_gen
206+
# prepare and return regular response
207+
is_chunked = uses_chunked_encoding(result)
208+
result_content = result.content or ""
209+
response = await make_response(result_content)
210+
response.status_code = result.status_code
211+
if is_chunked:
212+
response.headers.pop("Content-Length", None)
213+
result.headers.pop("Server", None)
214+
result.headers.pop("Date", None)
215+
headers = {k: str(v).replace("\n", r"\n") for k, v in result.headers.items()}
216+
response.headers.update(headers)
217+
# set multi-value headers
218+
multi_value_headers = getattr(result, "multi_value_headers", {})
219+
for key, values in multi_value_headers.items():
220+
for value in values:
221+
response.headers.add_header(key, value)
222+
# set default headers, if required
223+
if not is_chunked and request.method not in ["OPTIONS", "HEAD"]:
224+
response_data = await response.get_data()
225+
response.headers["Content-Length"] = str(len(response_data or ""))
226+
if "Connection" not in response.headers:
227+
response.headers["Connection"] = "close"
228+
# fix headers for OPTIONS requests (possible fix for Firefox requests)
229+
if request.method == "OPTIONS":
230+
response.headers.pop("Content-Type", None)
231+
if not response.headers.get("Cache-Control"):
232+
response.headers["Cache-Control"] = "no-cache"
233+
return response
234+
235+
def run_app_sync(*args, loop=None, shutdown_event=None):
236+
kwargs = {}
237+
config = Config()
238+
cert_file_name, key_file_name = ssl_creds or (None, None)
239+
if cert_file_name:
240+
kwargs["certfile"] = cert_file_name
241+
config.certfile = cert_file_name
242+
if key_file_name:
243+
kwargs["keyfile"] = key_file_name
244+
config.keyfile = key_file_name
245+
setup_quart_logging()
246+
config.h11_pass_raw_headers = True
247+
config.bind = [f"{bind_address}:{port}" for bind_address in bind_addresses]
248+
config.workers = len(bind_addresses)
249+
loop = loop or ensure_event_loop()
250+
run_kwargs = {}
251+
if shutdown_event:
252+
run_kwargs["shutdown_trigger"] = shutdown_event.wait
253+
try:
254+
try:
255+
return loop.run_until_complete(serve(app, config, **run_kwargs))
256+
except Exception as e:
257+
LOG.info(
258+
"Error running server event loop on port %s: %s %s",
259+
port,
260+
e,
261+
traceback.format_exc(),
262+
)
263+
if "SSL" in str(e):
264+
c_exists = os.path.exists(cert_file_name)
265+
k_exists = os.path.exists(key_file_name)
266+
c_size = len(load_file(cert_file_name)) if c_exists else 0
267+
k_size = len(load_file(key_file_name)) if k_exists else 0
268+
LOG.warning(
269+
"Unable to create SSL context. Cert files exist: %s %s (%sB), %s %s (%sB)",
270+
cert_file_name,
271+
c_exists,
272+
c_size,
273+
key_file_name,
274+
k_exists,
275+
k_size,
276+
)
277+
raise
278+
finally:
279+
try:
280+
_cancel_all_tasks(loop)
281+
loop.run_until_complete(loop.shutdown_asyncgens())
282+
finally:
283+
asyncio.set_event_loop(None)
284+
loop.close()
285+
286+
class ProxyThread(FuncThread):
287+
def __init__(self):
288+
FuncThread.__init__(self, self.run_proxy, None, name="proxy-thread")
289+
self.shutdown_event = None
290+
self.loop = None
291+
292+
def run_proxy(self, *args):
293+
self.loop = ensure_event_loop()
294+
self.shutdown_event = asyncio.Event()
295+
run_app_sync(loop=self.loop, shutdown_event=self.shutdown_event)
296+
297+
def stop(self, quiet=None):
298+
event = self.shutdown_event
299+
300+
async def set_event():
301+
event.set()
302+
303+
run_coroutine(set_event(), self.loop)
304+
super().stop(quiet)
305+
306+
def run_in_thread():
307+
thread = ProxyThread()
308+
thread.start()
309+
TMP_THREADS.append(thread)
310+
return thread
311+
312+
if asynchronous:
313+
return run_in_thread()
314+
315+
return run_app_sync()
316+
317+
318+
# apply patches on startup
319+
apply_patches()

aws-replicator/setup.cfg

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,16 @@ install_requires =
2222
localstack-client
2323
localstack-ext
2424
xmltodict
25+
# TODO: refactor the use of http2_server
26+
hypercorn
27+
h11
28+
quart
2529
# TODO: runtime dependencies below should be removed over time (required for some LS imports)
2630
boto
2731
cbor2
2832
flask-cors
29-
h11
3033
jsonpatch
3134
moto
32-
quart
3335
werkzeug
3436

3537
[options.extras_require]

0 commit comments

Comments
 (0)