[k8s] Share SSH jump pod #2826

Merged: 49 commits, Feb 21, 2024
Changes from 46 commits
Commits
ee8f113
[k8s] use constant jump pod name for namespacing
kbrgl Nov 23, 2023
56f8e99
wip: use single secret object to store jump keys
kbrgl Nov 27, 2023
9afb1ca
wip: reload SSH keys on interval in jump lifecycle manager
kbrgl Nov 27, 2023
f3004e7
fix import
kbrgl Nov 27, 2023
eb870c6
fix: ensure public key ends with newline
kbrgl Nov 27, 2023
c05ce9e
fix comment
kbrgl Nov 27, 2023
4e2b661
use threading to allow differing poll intervals in LCM script
kbrgl Dec 20, 2023
108ff89
Use timeout in first kubeapi call
kbrgl Dec 24, 2023
5053296
Use `sky-ssh-keys` instead of `sky-ssh-key`
kbrgl Dec 24, 2023
4e502a9
Update LCM manager docstring
kbrgl Dec 24, 2023
1f85115
Clean up error messages in LCM
kbrgl Dec 24, 2023
b96bc50
add parent=skypilot label to secret object
kbrgl Jan 5, 2024
4f18307
diff keys and only reload if changed
kbrgl Jan 6, 2024
190994a
more accurate error message
kbrgl Jan 6, 2024
8eff1c6
backwards compatibility fix
kbrgl Jan 6, 2024
804bf98
Run yapf and pylint
kbrgl Jan 6, 2024
ae7683d
fix format
kbrgl Jan 6, 2024
1be67fc
Merge branch 'master' into kabir/issue-2598-jump
kbrgl Jan 13, 2024
ae4a3ac
[k8s] use constant jump pod name for namespacing
kbrgl Nov 23, 2023
c80173f
wip: use single secret object to store jump keys
kbrgl Nov 27, 2023
aff92d7
wip: reload SSH keys on interval in jump lifecycle manager
kbrgl Nov 27, 2023
9f384eb
fix import
kbrgl Nov 27, 2023
4f74d3b
fix: ensure public key ends with newline
kbrgl Nov 27, 2023
f9efdf8
fix comment
kbrgl Nov 27, 2023
d945915
use threading to allow differing poll intervals in LCM script
kbrgl Dec 20, 2023
3d3a2fc
Use timeout in first kubeapi call
kbrgl Dec 24, 2023
da50079
Use `sky-ssh-keys` instead of `sky-ssh-key`
kbrgl Dec 24, 2023
e169fb4
Update LCM manager docstring
kbrgl Dec 24, 2023
f297ac4
Clean up error messages in LCM
kbrgl Dec 24, 2023
850f7d2
add parent=skypilot label to secret object
kbrgl Jan 5, 2024
0290a6b
diff keys and only reload if changed
kbrgl Jan 6, 2024
fa1bcc1
more accurate error message
kbrgl Jan 6, 2024
8e3e8e7
backwards compatibility fix
kbrgl Jan 6, 2024
e0474e8
Run yapf and pylint
kbrgl Jan 6, 2024
e125c05
fix format
kbrgl Jan 6, 2024
afcdd84
fix rebase mistakes
kbrgl Feb 7, 2024
77258da
fix ssh provision issue by reordering perm changes
kbrgl Feb 8, 2024
0acf179
run lint and format
kbrgl Feb 8, 2024
8d0e94b
Merge branch 'kabir/issue-2598-jump' of https://github.com/kbrgl/skyp…
romilbhardwaj Feb 20, 2024
59e4c96
Merge branch 'master' of https://github.com/skypilot-org/skypilot int…
romilbhardwaj Feb 20, 2024
4bc30f2
move auth key cat up
romilbhardwaj Feb 20, 2024
76e62fd
move auth key cat down
romilbhardwaj Feb 21, 2024
ffb057e
remove logger
romilbhardwaj Feb 21, 2024
13aed4f
lint
romilbhardwaj Feb 21, 2024
1837937
Merge branch 'master' of https://github.com/skypilot-org/skypilot int…
romilbhardwaj Feb 21, 2024
74e51da
Add missed commit from #200
romilbhardwaj Feb 21, 2024
2a65bac
Merge branch 'master' of https://github.com/skypilot-org/skypilot int…
romilbhardwaj Feb 21, 2024
25b5db4
update build_image to use latest or date tag
romilbhardwaj Feb 21, 2024
a0dea6c
update docker images to install socat and netcat and update boto3 on …
romilbhardwaj Feb 21, 2024
1 change: 1 addition & 0 deletions docs/source/reference/kubernetes/kubernetes-setup.rst
@@ -379,6 +379,7 @@ To use this mode:
.. code-block:: bash

# Patch the nginx ingress service with an external IP. Can be any node's IP if using NodePort service.
# Replace <IP> in the following command with the IP you select.
$ kubectl patch svc ingress-nginx-controller -n ingress-nginx -p '{"spec": {"externalIPs": ["<IP>"]}}'

If the ``EXTERNAL-IP`` field is left as ``<none>``, SkyPilot will use ``localhost`` as the external IP for the Ingress,
46 changes: 23 additions & 23 deletions sky/authentication.py
@@ -41,6 +41,7 @@
from sky import skypilot_config
from sky.adaptors import gcp
from sky.adaptors import ibm
from sky.adaptors import kubernetes
from sky.adaptors import runpod
from sky.clouds.utils import lambda_utils
from sky.provision.fluidstack import fluidstack_utils
@@ -398,29 +399,29 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
from None
get_or_generate_keys()

# Run kubectl command to add the public key to the cluster.
# Add the user's public key to the SkyPilot cluster.
public_key_path = os.path.expanduser(PUBLIC_SSH_KEY_PATH)
key_label = clouds.Kubernetes.SKY_SSH_KEY_SECRET_NAME
cmd = f'kubectl create secret generic {key_label} ' \
f'--from-file=ssh-publickey={public_key_path}'
try:
subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True)
except subprocess.CalledProcessError as e:
output = e.output.decode('utf-8')
suffix = f'\nError message: {output}'
if 'already exists' in output:
logger.debug(
f'Key {key_label} already exists in the cluster, using it...')
elif any(err in output for err in ['connection refused', 'timeout']):
with ux_utils.print_exception_no_traceback():
raise ConnectionError(
'Failed to connect to the cluster. Check if your '
'cluster is running, your kubeconfig is correct '
'and you can connect to it using: '
f'kubectl get namespaces.{suffix}') from e
else:
logger.error(suffix)
raise
secret_name = clouds.Kubernetes.SKY_SSH_KEY_SECRET_NAME
secret_field_name = clouds.Kubernetes.SKY_SSH_KEY_SECRET_FIELD_NAME
namespace = kubernetes_utils.get_current_kube_config_context_namespace()
k8s = kubernetes.get_kubernetes()
with open(public_key_path, 'r', encoding='utf-8') as f:
public_key = f.read()
if not public_key.endswith('\n'):
public_key += '\n'
secret_metadata = k8s.client.V1ObjectMeta(name=secret_name,
labels={'parent': 'skypilot'})
secret = k8s.client.V1Secret(
metadata=secret_metadata,
string_data={secret_field_name: public_key})
if kubernetes_utils.check_secret_exists(secret_name, namespace):
logger.debug(f'Key {secret_name} exists in the cluster, patching it...')
kubernetes.core_api().patch_namespaced_secret(secret_name, namespace,
secret)
else:
logger.debug(
f'Key {secret_name} does not exist in the cluster, creating it...')
kubernetes.core_api().create_namespaced_secret(namespace, secret)

ssh_jump_name = clouds.Kubernetes.SKY_SSH_JUMP_NAME
if network_mode == nodeport_mode:
@@ -438,7 +439,6 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
# Setup service for SSH jump pod. We create the SSH jump service here
# because we need to know the service IP address and port to set the
# ssh_proxy_command in the autoscaler config.
namespace = kubernetes_utils.get_current_kube_config_context_namespace()
kubernetes_utils.setup_ssh_jump_svc(ssh_jump_name, namespace, service_type)

ssh_proxy_cmd = kubernetes_utils.get_ssh_proxy_command(
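
The change above moves from one secret per user to a single shared secret (`sky-ssh-keys`) that holds one field per user (`ssh-publickey-<user hash>`). A minimal sketch of why that lets several users coexist, using a plain dict as a stand-in for the secret's data. It assumes that patching the secret merges fields rather than replacing the whole object; `upsert_public_key` and the sample user hashes are hypothetical names for illustration, not part of SkyPilot or the Kubernetes client:

```python
from typing import Dict

# Mirrors the prefix used by SKY_SSH_KEY_SECRET_FIELD_NAME in the diff.
SECRET_FIELD_PREFIX = 'ssh-publickey-'


def normalize_public_key(public_key: str) -> str:
    """Ensure the key ends with a newline so concatenated keys stay one per line."""
    return public_key if public_key.endswith('\n') else public_key + '\n'


def upsert_public_key(secret_data: Dict[str, str], user_hash: str,
                      public_key: str) -> Dict[str, str]:
    """Return a copy of secret_data with this user's field added or replaced.

    Fields belonging to other users are left untouched (assumed merge
    semantics of a patch; hypothetical helper, in-memory only).
    """
    updated = dict(secret_data)
    updated[SECRET_FIELD_PREFIX + user_hash] = normalize_public_key(public_key)
    return updated


data: Dict[str, str] = {}
data = upsert_public_key(data, 'aaaa1111', 'ssh-rsa AAAA alice')
data = upsert_public_key(data, 'bbbb2222', 'ssh-rsa BBBB bob\n')
# The first user rotates their key; the second user's field is unaffected.
data = upsert_public_key(data, 'aaaa1111', 'ssh-rsa CCCC alice-rotated')
print(sorted(data))
```

Keying the field on the user hash means a repeated launch by the same user overwrites only that user's entry.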
6 changes: 4 additions & 2 deletions sky/clouds/kubernetes.py
@@ -28,8 +28,10 @@
class Kubernetes(clouds.Cloud):
"""Kubernetes."""

SKY_SSH_KEY_SECRET_NAME = f'sky-ssh-{common_utils.get_user_hash()}'
SKY_SSH_JUMP_NAME = f'sky-ssh-jump-{common_utils.get_user_hash()}'
SKY_SSH_KEY_SECRET_NAME = 'sky-ssh-keys'
SKY_SSH_KEY_SECRET_FIELD_NAME = \
f'ssh-publickey-{common_utils.get_user_hash()}'
SKY_SSH_JUMP_NAME = 'sky-ssh-jump-pod'
PORT_FORWARD_PROXY_CMD_TEMPLATE = \
'kubernetes-port-forward-proxy-command.sh.j2'
PORT_FORWARD_PROXY_CMD_PATH = '~/.sky/port-forward-proxy-cmd.sh'
4 changes: 2 additions & 2 deletions sky/provision/kubernetes/instance.py
@@ -338,11 +338,11 @@ def _setup_ssh_in_pods(namespace: str, new_nodes: List) -> None:
'pam_loginuid.so@g" -i /etc/pam.d/sshd; '
'cd /etc/ssh/ && $(prefix_cmd) ssh-keygen -A; '
'$(prefix_cmd) mkdir -p ~/.ssh; '
'$(prefix_cmd) cp /etc/secret-volume/ssh-publickey '
'~/.ssh/authorized_keys; '
'$(prefix_cmd) chown -R $(whoami) ~/.ssh;'
'$(prefix_cmd) chmod 700 ~/.ssh; '
'$(prefix_cmd) chmod 644 ~/.ssh/authorized_keys; '
'$(prefix_cmd) cat /etc/secret-volume/ssh-publickey* > '
'~/.ssh/authorized_keys; '
'$(prefix_cmd) service ssh restart; '
# Eliminate the error
# `mesg: ttyname failed: inappropriate ioctl for device`.
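
The switch from `cp .../ssh-publickey` to `cat .../ssh-publickey* >` relies on the mounted secret exposing one file per field. A rough Python equivalent of that concatenation, run against a temporary directory rather than a real `/etc/secret-volume`; `build_authorized_keys` and the sample keys are hypothetical and only illustrate the idea:

```python
import glob
import os
import tempfile


def build_authorized_keys(secret_dir: str, dest_path: str) -> int:
    """Concatenate every ssh-publickey* file in secret_dir into dest_path.

    Returns the number of key files merged. Files are read in sorted order,
    matching how a shell expands the glob.
    """
    key_files = sorted(glob.glob(os.path.join(secret_dir, 'ssh-publickey*')))
    with open(dest_path, 'w', encoding='utf-8') as out:
        for path in key_files:
            with open(path, 'r', encoding='utf-8') as f:
                out.write(f.read())
    return len(key_files)


with tempfile.TemporaryDirectory() as tmp:
    secret_dir = os.path.join(tmp, 'secret-volume')
    os.makedirs(secret_dir)
    sample_keys = [('aaaa1111', 'ssh-rsa AAAA alice\n'),
                   ('bbbb2222', 'ssh-rsa BBBB bob\n')]
    for user_hash, key in sample_keys:
        key_path = os.path.join(secret_dir, f'ssh-publickey-{user_hash}')
        with open(key_path, 'w', encoding='utf-8') as f:
            f.write(key)

    dest = os.path.join(tmp, 'authorized_keys')
    merged = build_authorized_keys(secret_dir, dest)
    with open(dest, 'r', encoding='utf-8') as f:
        authorized_keys = f.read()

print(merged)
print(authorized_keys, end='')
```

Because the files are joined byte for byte, each stored key needs its own trailing newline, which is what the newline check in `sky/authentication.py` provides.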
19 changes: 19 additions & 0 deletions sky/provision/kubernetes/utils.py
@@ -1150,3 +1150,22 @@ def check_nvidia_runtime_class() -> bool:
nvidia_exists = any(
rc.metadata.name == 'nvidia' for rc in runtime_classes.items)
return nvidia_exists


def check_secret_exists(secret_name: str, namespace: str) -> bool:
"""Checks if a secret exists in a namespace

Args:
secret_name: Name of secret to check
namespace: Namespace to check
"""

try:
kubernetes.core_api().read_namespaced_secret(
secret_name, namespace, _request_timeout=kubernetes.API_TIMEOUT)
except kubernetes.api_exception() as e:
if e.status == 404:
return False
raise
else:
return True
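
The helper above treats a 404 as "the secret does not exist" and re-raises every other API error. A small sketch of that pattern against an in-memory fake, so it runs without a cluster; `FakeApiException` and `FakeCoreApi` are hypothetical stand-ins, not classes from the Kubernetes client:

```python
from typing import Dict, Optional, Tuple


class FakeApiException(Exception):
    """Hypothetical stand-in for an API error that carries an HTTP status."""

    def __init__(self, status: int) -> None:
        super().__init__(f'API error {status}')
        self.status = status


class FakeCoreApi:
    """Hypothetical in-memory API holding secrets keyed by (namespace, name)."""

    def __init__(self, secrets: Dict[Tuple[str, str], dict],
                 fail_with: Optional[int] = None) -> None:
        self._secrets = secrets
        self._fail_with = fail_with

    def read_namespaced_secret(self, name: str, namespace: str) -> dict:
        if self._fail_with is not None:
            raise FakeApiException(self._fail_with)
        if (namespace, name) not in self._secrets:
            raise FakeApiException(404)
        return self._secrets[(namespace, name)]


def secret_exists(api: FakeCoreApi, name: str, namespace: str) -> bool:
    """Return False only for 'not found'; let other failures propagate."""
    try:
        api.read_namespaced_secret(name, namespace)
    except FakeApiException as e:
        if e.status == 404:
            return False
        raise
    return True


api = FakeCoreApi({('default', 'sky-ssh-keys'): {}})
print(secret_exists(api, 'sky-ssh-keys', 'default'))
print(secret_exists(api, 'missing', 'default'))
```

Re-raising on anything other than 404 keeps a permissions or connectivity problem from being mistaken for a missing secret, which would otherwise lead to a failing create call.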
2 changes: 1 addition & 1 deletion sky/skylet/providers/kubernetes/node_provider.py
@@ -427,7 +427,7 @@ def _setup_ssh_in_pods(self, new_nodes):
'$(prefix_cmd) sed "s@session\\s*required\\s*pam_loginuid.so@session optional pam_loginuid.so@g" -i /etc/pam.d/sshd; '
'cd /etc/ssh/ && $(prefix_cmd) ssh-keygen -A; '
'$(prefix_cmd) mkdir -p ~/.ssh; '
'$(prefix_cmd) cp /etc/secret-volume/ssh-publickey ~/.ssh/authorized_keys; '
'$(prefix_cmd) cat /etc/secret-volume/ssh-publickey* > ~/.ssh/authorized_keys; '
'$(prefix_cmd) service ssh restart')
]

4 changes: 2 additions & 2 deletions sky/templates/kubernetes-ssh-jump.yml.j2
@@ -26,7 +26,7 @@ pod_spec:
lifecycle:
postStart:
exec:
command: ["/bin/bash", "-c", "mkdir -p ~/.ssh && cp /etc/secret-volume/ssh-publickey ~/.ssh/authorized_keys && sudo service ssh restart"]
command: ["/bin/bash", "-c", "mkdir -p ~/.ssh && cat /etc/secret-volume/ssh-publickey* > ~/.ssh/authorized_keys && sudo service ssh restart"]
env:
- name: MY_POD_NAME
valueFrom:
@@ -74,7 +74,7 @@ role:
rules:
- apiGroups: [""]
resources: ["pods", "pods/status", "pods/exec", "services"]
verbs: ["get", "list", "create", "delete"]
verbs: ["get", "list", "create", "delete"]
role_binding:
apiVersion: rbac.authorization.k8s.io/v1
kind: RoleBinding
178 changes: 129 additions & 49 deletions sky/utils/kubernetes/ssh_jump_lifecycle_manager.py
@@ -3,14 +3,20 @@
This script runs inside ssh jump pod as the main process (PID 1).

It terminates itself (by removing ssh jump service and pod via a call to
kubeapi), if it does not see ray pods in the duration of 10 minutes. If the
kubeapi) if it does not see ray pods in the duration of 10 minutes. If the
user re-launches a task before the duration is over, then ssh jump pod is being
reused and will terminate itself when it sees that no ray cluster exist in that
duration.
reused and will terminate itself when it sees that no ray clusters exist in
that duration.

To allow multiple users to share the same SSH jump pod,
this script also reloads SSH keys from the mounted secret volume on an
interval and updates `~/.ssh/authorized_keys`.
"""
import datetime
import os
import subprocess
import sys
import threading
import time

from kubernetes import client
@@ -29,67 +35,130 @@
alert_threshold = int(os.getenv('ALERT_THRESHOLD', '600'))
# The amount of time in seconds to wait between Ray pods existence checks
retry_interval = int(os.getenv('RETRY_INTERVAL', '60'))
# The amount of time in seconds to wait between SSH key reloads
reload_interval = int(os.getenv('RELOAD_INTERVAL', '5'))

# Ray pods are labeled with this value i.e., ssh jump name which is unique per
# user (based on user hash)
label_selector = f'skypilot-ssh-jump={current_name}'


def poll():
sys.stdout.write('Starting polling.\n')
def poll(interval, leading=True):
"""Decorator factory for polling function. To stop polling, return True.

alert_delta = datetime.timedelta(seconds=alert_threshold)
Args:
interval (int): The amount of time to wait between function calls.
leading (bool): Whether to wait before (rather than after) calls.
"""

# Set delay for each retry
retry_interval_delta = datetime.timedelta(seconds=retry_interval)
def decorator(func):

# Accumulated time of where no SkyPilot cluster exists. Used to compare
# against alert_threshold
nocluster_delta = datetime.timedelta()
def wrapper(*args, **kwargs):
while True:
if leading:
time.sleep(interval)
done = func(*args, **kwargs)
if done:
return
if not leading:
time.sleep(interval)

while True:
sys.stdout.write(f'Sleeping {retry_interval} seconds..\n')
time.sleep(retry_interval)
return wrapper

# List the pods in the current namespace
try:
ret = v1.list_namespaced_pod(current_namespace,
label_selector=label_selector)
except Exception as e:
sys.stdout.write(f'Error: listing pods failed with error: {e}\n')
raise
return decorator


# Flag to terminate the reload keys thread when the lifecycle thread
# terminates.
terminated = False

if len(ret.items) == 0:
sys.stdout.write(f'Did not find pods with label "{label_selector}" '
f'in namespace {current_namespace}\n')
nocluster_delta = nocluster_delta + retry_interval_delta
sys.stdout.write(
f'Time since no pods found: {nocluster_delta}, alert '
f'threshold: {alert_delta}\n')
else:
sys.stdout.write(
f'Found pods with label "{label_selector}" in namespace '
f'{current_namespace}\n')
# reset ..
nocluster_delta = datetime.timedelta()
sys.stdout.write(f'noray_delta is reset: {nocluster_delta}\n')

if nocluster_delta >= alert_delta:
@poll(interval=reload_interval, leading=False)
def reload_keys():
"""Reloads SSH keys from mounted secret volume."""

if terminated:
sys.stdout.write('[SSH Key Reloader] Terminated.\n')
return True

# Reload SSH keys from mounted secret volume if changed.
tmpfile = '/tmp/sky-ssh-keys'
try:
subprocess.check_output(
f'cat /etc/secret-volume/ssh-publickey* > {tmpfile}', shell=True)
try:
subprocess.check_output(f'diff {tmpfile} ~/.ssh/authorized_keys',
shell=True)
sys.stdout.write(
f'nocluster_delta: {nocluster_delta} crossed alert threshold: '
f'{alert_delta}. Time to terminate myself and my service.\n')
try:
# ssh jump resources created under same name
v1.delete_namespaced_service(current_name, current_namespace)
v1.delete_namespaced_pod(current_name, current_namespace)
except Exception as e:
sys.stdout.write('[ERROR] Deletion failed. Exiting '
f'poll() with error: {e}\n')
'[SSH Key Reloader] No keys changed, continuing.\n')
except subprocess.CalledProcessError as e:
if e.returncode == 1:
sys.stdout.write(
'[SSH Key Reloader] Changes detected, reloading.\n')
subprocess.check_output(f'mv {tmpfile} ~/.ssh/authorized_keys',
shell=True)
else:
raise
except Exception as e:
sys.stdout.write(
f'[SSH Key Reloader][ERROR] Failed to reload SSH keys: {e}\n')
raise
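
The reload step above only rewrites `authorized_keys` when the concatenated keys differ from what is already in place. A sketch of the same idea in pure Python, comparing contents in memory instead of shelling out to `diff` and `mv`; `reload_if_changed` is a hypothetical helper and the paths are temporary files, not the pod's real `~/.ssh`:

```python
import os
import tempfile
from typing import List, Optional


def reload_if_changed(new_contents: str, authorized_keys_path: str) -> bool:
    """Rewrite authorized_keys only when its contents differ.

    Returns True if the file was rewritten, False if it was already current.
    """
    current: Optional[str]
    try:
        with open(authorized_keys_path, 'r', encoding='utf-8') as f:
            current = f.read()
    except FileNotFoundError:
        current = None
    if current == new_contents:
        return False
    # Write to a temporary file in the same directory, then rename it over
    # the target so readers never observe a partially written file.
    directory = os.path.dirname(authorized_keys_path) or '.'
    fd, tmp_path = tempfile.mkstemp(dir=directory)
    with os.fdopen(fd, 'w', encoding='utf-8') as tmp:
        tmp.write(new_contents)
    os.replace(tmp_path, authorized_keys_path)
    return True


with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, 'authorized_keys')
    results: List[bool] = [
        reload_if_changed('ssh-rsa AAAA alice\n', path),
        reload_if_changed('ssh-rsa AAAA alice\n', path),
        reload_if_changed('ssh-rsa AAAA alice\nssh-rsa BBBB bob\n', path),
    ]
    with open(path, 'r', encoding='utf-8') as f:
        final_contents = f.read()

print(results)
```

Skipping the write when nothing changed keeps the 5-second reload interval cheap, since most polls find no new keys.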


alert_delta = datetime.timedelta(seconds=alert_threshold)
retry_interval_delta = datetime.timedelta(seconds=retry_interval)
# Accumulated time during which no SkyPilot cluster exists. Compared
# against alert_threshold.
nocluster_delta = datetime.timedelta()


@poll(interval=retry_interval)
def manage_lifecycle():
"""Manages lifecycle of ssh jump pod."""

global terminated, nocluster_delta

try:
ret = v1.list_namespaced_pod(current_namespace,
label_selector=label_selector)
except Exception as e:
sys.stdout.write('[Lifecycle][ERROR] Listing pods failed with '
f'error: {e}\n')
raise

if len(ret.items) == 0:
sys.stdout.write(
f'[Lifecycle] Did not find pods with label '
f'"{label_selector}" in namespace {current_namespace}\n')
nocluster_delta = nocluster_delta + retry_interval_delta
sys.stdout.write(
f'[Lifecycle] Time since no pods found: {nocluster_delta}, alert '
f'threshold: {alert_delta}\n')
else:
sys.stdout.write(
f'[Lifecycle] Found pods with label "{label_selector}" in '
f'namespace {current_namespace}\n')
# reset ..
nocluster_delta = datetime.timedelta()
sys.stdout.write(
f'[Lifecycle] nocluster_delta is reset: {nocluster_delta}\n')

if nocluster_delta >= alert_delta:
sys.stdout.write(
f'[Lifecycle] nocluster_delta: {nocluster_delta} crossed alert '
f'threshold: {alert_delta}. Time to terminate myself and my '
'service.\n')
try:
# ssh jump resources created under same name
v1.delete_namespaced_service(current_name, current_namespace)
v1.delete_namespaced_pod(current_name, current_namespace)
except Exception as e:
sys.stdout.write('[Lifecycle][ERROR] Deletion failed. Exiting '
f'manage_lifecycle() with error: {e}\n')
raise

break

sys.stdout.write('Done polling.\n')
terminated = True
return True


def main():
@@ -98,13 +167,24 @@ def main():
sys.stdout.write(f'current_namespace: {current_namespace}\n')
sys.stdout.write(f'alert_threshold time: {alert_threshold}\n')
sys.stdout.write(f'retry_interval time: {retry_interval}\n')
sys.stdout.write(f'reload_interval time: {reload_interval}\n')
sys.stdout.write(f'label_selector: {label_selector}\n')

if not current_name or not current_namespace:
# Raise Exception with message to terminate pod
raise Exception('Missing environment variables MY_POD_NAME or '
'MY_POD_NAMESPACE')
poll()

threads = [
threading.Thread(target=manage_lifecycle),
threading.Thread(target=reload_keys)
]
sys.stdout.write(f'Polling with {len(threads)} threads.\n')
for t in threads:
t.start()
for t in threads:
t.join()
sys.stdout.write('Done.\n')


if __name__ == '__main__':
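
The script now runs two polling loops with different intervals on separate threads, and the lifecycle loop tells the key reloader when to stop. A condensed sketch of that structure with very short intervals so it finishes quickly. It swaps the module-level `terminated` flag for a `threading.Event`, which is a substitution made here for illustration and not what the PR does; the tick counter stands in for the real pod checks:

```python
import functools
import threading
import time
from typing import Callable, Dict, List


def poll(interval: float, leading: bool = True) -> Callable:
    """Decorator factory: call func repeatedly until it returns True.

    With leading=True the sleep happens before each call, otherwise after.
    """

    def decorator(func: Callable[..., bool]) -> Callable[..., None]:

        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> None:
            while True:
                if leading:
                    time.sleep(interval)
                if func(*args, **kwargs):
                    return
                if not leading:
                    time.sleep(interval)

        return wrapper

    return decorator


stop = threading.Event()  # stand-in for the `terminated` flag
events: List[str] = []
ticks: Dict[str, int] = {'lifecycle': 0}


@poll(interval=0.01)
def manage_lifecycle() -> bool:
    # Pretend the idle threshold is crossed on the third check.
    ticks['lifecycle'] += 1
    if ticks['lifecycle'] >= 3:
        events.append('lifecycle done')
        stop.set()
        return True
    return False


@poll(interval=0.005, leading=False)
def reload_keys() -> bool:
    if stop.is_set():
        events.append('reloader done')
        return True
    return False


threads = [
    threading.Thread(target=manage_lifecycle),
    threading.Thread(target=reload_keys),
]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(events)
```

Using `leading=False` for the reloader means keys are loaded once immediately at startup, while the lifecycle check waits a full interval before its first look at the cluster.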
4 changes: 2 additions & 2 deletions tests/kubernetes/scripts/ray_k8s_sky.yaml
@@ -144,7 +144,7 @@ available_node_types:
lifecycle:
postStart:
exec:
command: ["/bin/bash", "-c", "mkdir -p ~/.ssh && cp /etc/secret-volume/ssh-publickey ~/.ssh/authorized_keys && sudo service ssh restart"]
command: ["/bin/bash", "-c", "mkdir -p ~/.ssh && cat /etc/secret-volume/ssh-publickey* > ~/.ssh/authorized_keys && sudo service ssh restart"]
ports:
- containerPort: 22 # Used for SSH
# This volume allocates shared memory for Ray to use for its plasma
@@ -224,7 +224,7 @@ available_node_types:
lifecycle:
postStart:
exec:
command: ["/bin/bash", "-c", "mkdir -p ~/.ssh && cp /etc/secret-volume/ssh-publickey ~/.ssh/authorized_keys && sudo service ssh restart"]
command: ["/bin/bash", "-c", "mkdir -p ~/.ssh && cat /etc/secret-volume/ssh-publickey* > ~/.ssh/authorized_keys && sudo service ssh restart"]
resources:
requests:
cpu: 1000m