Skip to content

Commit

Permalink
Add KeepAliveClientRequest class for k8s async client (#15220)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevingrismore authored Sep 4, 2024
1 parent 522c254 commit 308f461
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 52 deletions.
45 changes: 24 additions & 21 deletions src/integrations/prefect-kubernetes/prefect_kubernetes/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import sys
from typing import Optional, TypeVar

from kubernetes_asyncio.client import ApiClient
from aiohttp import ClientResponse
from aiohttp.client_reqrep import ClientRequest
from aiohttp.connector import Connection
from slugify import slugify

# Note: `dict(str, str)` is the Kubernetes API convention for
Expand All @@ -14,34 +16,35 @@
V1KubernetesModel = TypeVar("V1KubernetesModel")


def enable_socket_keep_alive(client: ApiClient) -> None:
class KeepAliveClientRequest(ClientRequest):
"""
Setting the keep-alive flags on the kubernetes client object.
Unfortunately neither the kubernetes library nor the urllib3 library which
kubernetes is using internally offer the functionality to enable keep-alive
messages. Thus the flags are added to be used on the underlying sockets.
aiohttp only directly implements socket keepalive for incoming connections
in its RequestHandler. For client connections, we need to set the keepalive
ourselves.
Refer to https://github.com/aio-libs/aiohttp/issues/3904#issuecomment-759205696
"""

socket_options = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)]
async def send(self, conn: Connection) -> ClientResponse:
sock = conn.protocol.transport.get_extra_info("socket")
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

if hasattr(socket, "TCP_KEEPINTVL"):
socket_options.append((socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 30))
if hasattr(socket, "TCP_KEEPIDLE"):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30)

if hasattr(socket, "TCP_KEEPCNT"):
socket_options.append((socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 6))
if hasattr(socket, "TCP_KEEPINTVL"):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 6)

if hasattr(socket, "TCP_KEEPIDLE"):
socket_options.append((socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 6))
if hasattr(socket, "TCP_KEEPCNT"):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 6)

if sys.platform == "darwin":
# TCP_KEEP_ALIVE not available on socket module in macOS, but defined in
# https://github.com/apple/darwin-xnu/blob/2ff845c2e033bd0ff64b5b6aa6063a1f8f65aa32/bsd/netinet/tcp.h#L215
TCP_KEEP_ALIVE = 0x10
socket_options.append((socket.IPPROTO_TCP, TCP_KEEP_ALIVE, 30))
if sys.platform == "darwin":
# TCP_KEEP_ALIVE not available on socket module in macOS, but defined in
# https://github.com/apple/darwin-xnu/blob/2ff845c2e033bd0ff64b5b6aa6063a1f8f65aa32/bsd/netinet/tcp.h#L215
TCP_KEEP_ALIVE = 0x10
sock.setsockopt(socket.IPPROTO_TCP, TCP_KEEP_ALIVE, 30)

client.rest_client.pool_manager.connection_pool_kw[
"socket_options"
] = socket_options
return await super().send(conn)


def _slugify_name(name: str, max_length: int = 45) -> Optional[str]:
Expand Down
12 changes: 7 additions & 5 deletions src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,10 @@
from prefect_kubernetes.credentials import KubernetesClusterConfig
from prefect_kubernetes.events import KubernetesEventsReplicator
from prefect_kubernetes.utilities import (
KeepAliveClientRequest,
_slugify_label_key,
_slugify_label_value,
_slugify_name,
enable_socket_keep_alive,
)

MAX_ATTEMPTS = 3
Expand Down Expand Up @@ -637,10 +637,6 @@ async def _get_configured_kubernetes_client(
Returns a configured Kubernetes client.
"""
client = None
if os.environ.get(
"PREFECT_KUBERNETES_WORKER_ADD_TCP_KEEPALIVE", "TRUE"
).strip().lower() in ("true", "1"):
enable_socket_keep_alive(client)

if configuration.cluster_config:
config_dict = configuration.cluster_config.config
Expand All @@ -657,6 +653,12 @@ async def _get_configured_kubernetes_client(
except config.ConfigException:
# If in-cluster config fails, load the local kubeconfig
client = await config.new_client_from_config()

if os.environ.get(
"PREFECT_KUBERNETES_WORKER_ADD_TCP_KEEPALIVE", "TRUE"
).strip().lower() in ("true", "1"):
client.rest_client.pool_manager._request_class = KeepAliveClientRequest

try:
yield client
finally:
Expand Down
14 changes: 0 additions & 14 deletions src/integrations/prefect-kubernetes/tests/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@

import pytest
from kubernetes_asyncio.config import ConfigException
from prefect_kubernetes.utilities import (
enable_socket_keep_alive,
)

FAKE_CLUSTER = "fake-cluster"

Expand All @@ -26,14 +23,3 @@ def mock_cluster_config(monkeypatch):
@pytest.fixture
def mock_api_client(mock_cluster_config):
return MagicMock()


def test_keep_alive_updates_socket_options(mock_api_client):
enable_socket_keep_alive(mock_api_client)

assert (
mock_api_client.rest_client.pool_manager.connection_pool_kw[
"socket_options"
]._mock_set_call
is not None
)
36 changes: 24 additions & 12 deletions src/integrations/prefect-kubernetes/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
)
from kubernetes_asyncio.config import ConfigException
from prefect_kubernetes import KubernetesWorker
from prefect_kubernetes.utilities import _slugify_label_value, _slugify_name
from prefect_kubernetes.utilities import (
KeepAliveClientRequest,
_slugify_label_value,
_slugify_name,
)
from prefect_kubernetes.worker import KubernetesWorkerJobConfiguration
from pydantic import ValidationError

Expand Down Expand Up @@ -1504,7 +1508,7 @@ async def test_can_store_api_key_in_secret(
)

# Make sure secret gets deleted
assert mock_core_client.return_value.delete_namespaced_secret(
assert await mock_core_client.return_value.delete_namespaced_secret(
name=f"prefect-{_slugify_name(k8s_worker.name)}-api-key",
namespace=configuration.namespace,
)
Expand Down Expand Up @@ -2107,6 +2111,24 @@ async def test_uses_specified_image_pull_policy(
)
assert call_image_pull_policy == "IfNotPresent"

@pytest.mark.usefixtures("mock_core_client_lean", "mock_cluster_config")
async def test_keepalive_enabled(
self,
):
configuration = await KubernetesWorkerJobConfiguration.from_template_and_values(
KubernetesWorker.get_default_base_job_template(),
{"image": "foo"},
)

async with KubernetesWorker(work_pool_name="test") as k8s_worker:
async with k8s_worker._get_configured_kubernetes_client(
configuration
) as client:
assert (
client.rest_client.pool_manager._request_class
is KeepAliveClientRequest
)

async def test_defaults_to_incluster_config(
self,
flow_run,
Expand All @@ -2117,12 +2139,7 @@ async def test_defaults_to_incluster_config(
mock_batch_client,
mock_job,
mock_pod,
monkeypatch,
):
monkeypatch.setattr(
"prefect_kubernetes.worker.enable_socket_keep_alive", MagicMock()
)

async def mock_stream(*args, **kwargs):
if kwargs["func"] == mock_core_client_lean.return_value.list_namespaced_pod:
yield {"object": mock_pod, "type": "MODIFIED"}
Expand All @@ -2148,12 +2165,7 @@ async def test_uses_cluster_config_if_not_in_cluster(
mock_core_client_lean,
mock_job,
mock_pod,
monkeypatch,
):
monkeypatch.setattr(
"prefect_kubernetes.worker.enable_socket_keep_alive", MagicMock()
)

async def mock_stream(*args, **kwargs):
if kwargs["func"] == mock_core_client_lean.return_value.list_namespaced_pod:
yield {"object": mock_pod, "type": "MODIFIED"}
Expand Down

0 comments on commit 308f461

Please sign in to comment.