[k8s] Share SSH jump pod #2826

Merged
merged 49 commits into from
Feb 21, 2024
Changes from 7 commits
Commits
ee8f113
[k8s] use constant jump pod name for namespacing
kbrgl Nov 23, 2023
56f8e99
wip: use single secret object to store jump keys
kbrgl Nov 27, 2023
9afb1ca
wip: reload SSH keys on interval in jump lifecycle manager
kbrgl Nov 27, 2023
f3004e7
fix import
kbrgl Nov 27, 2023
eb870c6
fix: ensure public key ends with newline
kbrgl Nov 27, 2023
c05ce9e
fix comment
kbrgl Nov 27, 2023
4e2b661
use threading to allow differing poll intervals in LCM script
kbrgl Dec 20, 2023
108ff89
Use timeout in first kubeapi call
kbrgl Dec 24, 2023
5053296
Use `sky-ssh-keys` instead of `sky-ssh-key`
kbrgl Dec 24, 2023
4e502a9
Update LCM manager docstring
kbrgl Dec 24, 2023
1f85115
Clean up error messages in LCM
kbrgl Dec 24, 2023
b96bc50
add parent=skypilot label to secret object
kbrgl Jan 5, 2024
4f18307
diff keys and only reload if changed
kbrgl Jan 6, 2024
190994a
more accurate error message
kbrgl Jan 6, 2024
8eff1c6
backwards compatibility fix
kbrgl Jan 6, 2024
804bf98
Run yapf and pylint
kbrgl Jan 6, 2024
ae7683d
fix format
kbrgl Jan 6, 2024
1be67fc
Merge branch 'master' into kabir/issue-2598-jump
kbrgl Jan 13, 2024
ae4a3ac
[k8s] use constant jump pod name for namespacing
kbrgl Nov 23, 2023
c80173f
wip: use single secret object to store jump keys
kbrgl Nov 27, 2023
aff92d7
wip: reload SSH keys on interval in jump lifecycle manager
kbrgl Nov 27, 2023
9f384eb
fix import
kbrgl Nov 27, 2023
4f74d3b
fix: ensure public key ends with newline
kbrgl Nov 27, 2023
f9efdf8
fix comment
kbrgl Nov 27, 2023
d945915
use threading to allow differing poll intervals in LCM script
kbrgl Dec 20, 2023
3d3a2fc
Use timeout in first kubeapi call
kbrgl Dec 24, 2023
da50079
Use `sky-ssh-keys` instead of `sky-ssh-key`
kbrgl Dec 24, 2023
e169fb4
Update LCM manager docstring
kbrgl Dec 24, 2023
f297ac4
Clean up error messages in LCM
kbrgl Dec 24, 2023
850f7d2
add parent=skypilot label to secret object
kbrgl Jan 5, 2024
0290a6b
diff keys and only reload if changed
kbrgl Jan 6, 2024
fa1bcc1
more accurate error message
kbrgl Jan 6, 2024
8e3e8e7
backwards compatibility fix
kbrgl Jan 6, 2024
e0474e8
Run yapf and pylint
kbrgl Jan 6, 2024
e125c05
fix format
kbrgl Jan 6, 2024
afcdd84
fix rebase mistakes
kbrgl Feb 7, 2024
77258da
fix ssh provision issue by reordering perm changes
kbrgl Feb 8, 2024
0acf179
run lint and format
kbrgl Feb 8, 2024
8d0e94b
Merge branch 'kabir/issue-2598-jump' of https://github.com/kbrgl/skyp…
romilbhardwaj Feb 20, 2024
59e4c96
Merge branch 'master' of https://github.com/skypilot-org/skypilot int…
romilbhardwaj Feb 20, 2024
4bc30f2
move auth key cat up
romilbhardwaj Feb 20, 2024
76e62fd
move auth key cat down
romilbhardwaj Feb 21, 2024
ffb057e
remove logger
romilbhardwaj Feb 21, 2024
13aed4f
lint
romilbhardwaj Feb 21, 2024
1837937
Merge branch 'master' of https://github.com/skypilot-org/skypilot int…
romilbhardwaj Feb 21, 2024
74e51da
Add missed commit from #200
romilbhardwaj Feb 21, 2024
2a65bac
Merge branch 'master' of https://github.com/skypilot-org/skypilot int…
romilbhardwaj Feb 21, 2024
25b5db4
update build_image to use latest or date tag
romilbhardwaj Feb 21, 2024
a0dea6c
update docker images to install socat and netcat and update boto3 on …
romilbhardwaj Feb 21, 2024
43 changes: 20 additions & 23 deletions sky/authentication.py
@@ -42,6 +42,7 @@
 from sky.adaptors import gcp
 from sky.adaptors import ibm
 from sky.clouds.utils import lambda_utils
+from sky.adaptors import kubernetes
 from sky.utils import common_utils
 from sky.utils import kubernetes_enums
 from sky.utils import kubernetes_utils
@@ -398,29 +399,26 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
                     from None
     get_or_generate_keys()

-    # Run kubectl command to add the public key to the cluster.
+    # Add the user's public key to the SkyPilot cluster.
     public_key_path = os.path.expanduser(PUBLIC_SSH_KEY_PATH)
-    key_label = clouds.Kubernetes.SKY_SSH_KEY_SECRET_NAME
-    cmd = f'kubectl create secret generic {key_label} ' \
-        f'--from-file=ssh-publickey={public_key_path}'
-    try:
-        subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True)
-    except subprocess.CalledProcessError as e:
-        output = e.output.decode('utf-8')
-        suffix = f'\nError message: {output}'
-        if 'already exists' in output:
-            logger.debug(
-                f'Key {key_label} already exists in the cluster, using it...')
-        elif any(err in output for err in ['connection refused', 'timeout']):
-            with ux_utils.print_exception_no_traceback():
-                raise ConnectionError(
-                    'Failed to connect to the cluster. Check if your '
-                    'cluster is running, your kubeconfig is correct '
-                    'and you can connect to it using: '
-                    f'kubectl get namespaces.{suffix}') from e
-        else:
-            logger.error(suffix)
-            raise
+    secret_name = clouds.Kubernetes.SKY_SSH_KEY_SECRET_NAME
+    secret_field_name = clouds.Kubernetes.SKY_SSH_KEY_SECRET_FIELD_NAME
+    namespace = kubernetes_utils.get_current_kube_config_context_namespace()
+    k8s = kubernetes.get_kubernetes()
+    with open(public_key_path, 'r') as f:
+        public_key = f.read()
+        if not public_key.endswith('\n'):
+            public_key += '\n'
+        secret = k8s.client.V1Secret(metadata=k8s.client.V1ObjectMeta(name=secret_name), string_data={
+            secret_field_name: public_key
+        })
+    if kubernetes_utils.check_secret_exists(secret_name, namespace):
+        logger.debug(f'Key {secret_name} exists in the cluster, patching it...')
+        kubernetes.core_api().patch_namespaced_secret(secret_name, namespace, secret)
+    else:
+        logger.debug(
+            f'Key {secret_name} does not exist in the cluster, creating it...')
+        kubernetes.core_api().create_namespaced_secret(namespace, secret)

     ssh_jump_name = clouds.Kubernetes.SKY_SSH_JUMP_NAME
     if network_mode == nodeport_mode:
@@ -438,7 +436,6 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
     # Setup service for SSH jump pod. We create the SSH jump service here
     # because we need to know the service IP address and port to set the
     # ssh_proxy_command in the autoscaler config.
-    namespace = kubernetes_utils.get_current_kube_config_context_namespace()
     kubernetes_utils.setup_ssh_jump_svc(ssh_jump_name, namespace, service_type)

     ssh_proxy_cmd = kubernetes_utils.get_ssh_proxy_command(
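Reviewer sketch (not part of the diff): the hunk above replaces the `kubectl create secret` subprocess call with a check, then patch-or-create, against the Kubernetes API. The following is a minimal, self-contained illustration of that control flow. `FakeSecretApi` and `upsert_public_key` are hypothetical names that stand in for the real client, and `patch` is modeled as a plain dict merge, which is an assumption about merge-style patching rather than a statement about the client's behavior.

```python
from typing import Dict


class FakeSecretApi:
    """Hypothetical in-memory stand-in for a namespaced Secret API."""

    def __init__(self) -> None:
        # Maps (namespace, name) -> {field_name: value}.
        self._secrets: Dict[tuple, Dict[str, str]] = {}

    def exists(self, name: str, namespace: str) -> bool:
        return (namespace, name) in self._secrets

    def create(self, name: str, namespace: str, fields: Dict[str, str]) -> None:
        if self.exists(name, namespace):
            raise ValueError(f'secret {name!r} already exists')
        self._secrets[(namespace, name)] = dict(fields)

    def patch(self, name: str, namespace: str, fields: Dict[str, str]) -> None:
        # Modeled as a merge: fields not named in the patch are preserved.
        self._secrets[(namespace, name)].update(fields)

    def read(self, name: str, namespace: str) -> Dict[str, str]:
        return dict(self._secrets[(namespace, name)])


def upsert_public_key(api: FakeSecretApi, secret_name: str, namespace: str,
                      field_name: str, public_key: str) -> None:
    """Stores one user's public key under its own field of a shared secret."""
    # Keys are concatenated later, so each one must end with a newline.
    if not public_key.endswith('\n'):
        public_key += '\n'
    fields = {field_name: public_key}
    if api.exists(secret_name, namespace):
        api.patch(secret_name, namespace, fields)
    else:
        api.create(secret_name, namespace, fields)


api = FakeSecretApi()
# First user creates the shared secret; second user patches it.
upsert_public_key(api, 'sky-ssh-key', 'default', 'ssh-key-aaaa',
                  'ssh-rsa AAA alice')
upsert_public_key(api, 'sky-ssh-key', 'default', 'ssh-key-bbbb',
                  'ssh-rsa BBB bob\n')
stored = api.read('sky-ssh-key', 'default')
```

Because each user writes only its own `ssh-key-<hash>` field, a second user's patch leaves the first user's key in place under this model.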
5 changes: 3 additions & 2 deletions sky/clouds/kubernetes.py
@@ -27,8 +27,9 @@
 class Kubernetes(clouds.Cloud):
     """Kubernetes."""

-    SKY_SSH_KEY_SECRET_NAME = f'sky-ssh-{common_utils.get_user_hash()}'
-    SKY_SSH_JUMP_NAME = f'sky-ssh-jump-{common_utils.get_user_hash()}'
+    SKY_SSH_KEY_SECRET_NAME = 'sky-ssh-key'
+    SKY_SSH_KEY_SECRET_FIELD_NAME = f'ssh-key-{common_utils.get_user_hash()}'
+    SKY_SSH_JUMP_NAME = 'sky-ssh-jump-pod'
     PORT_FORWARD_PROXY_CMD_TEMPLATE = \
         'kubernetes-port-forward-proxy-command.sh.j2'
     PORT_FORWARD_PROXY_CMD_PATH = '~/.sky/port-forward-proxy-cmd.sh'
5 changes: 2 additions & 3 deletions sky/templates/kubernetes-ray.yml.j2
@@ -187,7 +187,7 @@ available_node_types:
 lifecycle:
 postStart:
 exec:
-command: ["/bin/bash", "-c", "mkdir -p ~/.ssh && cp /etc/secret-volume/ssh-publickey ~/.ssh/authorized_keys && sudo service ssh restart"]
+command: ["/bin/bash", "-c", "mkdir -p ~/.ssh && cat /etc/secret-volume/ssh-key-* > ~/.ssh/authorized_keys && sudo service ssh restart"]
 resources:
 requests:
 cpu: {{cpus}}
@@ -251,7 +251,7 @@ available_node_types:
 lifecycle:
 postStart:
 exec:
-command: ["/bin/bash", "-c", "mkdir -p ~/.ssh && cp /etc/secret-volume/ssh-publickey ~/.ssh/authorized_keys && sudo service ssh restart"]
+command: ["/bin/bash", "-c", "mkdir -p ~/.ssh && cat /etc/secret-volume/ssh-key-* > ~/.ssh/authorized_keys && sudo service ssh restart"]
 ports:
 - containerPort: 22 # Used for SSH
 # This volume allocates shared memory for Ray to use for its plasma
@@ -346,4 +346,3 @@ cluster_synced_files: []
 file_mounts_sync_continuously: False
 initialization_commands: []
 rsync_exclude: []
-
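Reviewer sketch (not part of the diff): the `postStart` commands above switch from copying a single `ssh-publickey` file to running `cat /etc/secret-volume/ssh-key-* > ~/.ssh/authorized_keys`, so every field of the shared secret that is mounted as a file contributes one entry. The Python below imitates that shell command on a temporary directory. `build_authorized_keys` is a hypothetical helper, and the key contents are made up for illustration.

```python
import glob
import os
import tempfile


def build_authorized_keys(secret_dir: str, dest_path: str) -> int:
    """Concatenates every ssh-key-* file in secret_dir into dest_path.

    Imitates `cat <secret_dir>/ssh-key-* > <dest_path>` and returns the
    number of key files read.
    """
    # Shell globs expand in sorted order, so sort to get the same result.
    key_files = sorted(glob.glob(os.path.join(secret_dir, 'ssh-key-*')))
    with open(dest_path, 'w', encoding='utf-8') as dest:
        for path in key_files:
            with open(path, 'r', encoding='utf-8') as src:
                dest.write(src.read())
    return len(key_files)


with tempfile.TemporaryDirectory() as tmp:
    # Two users' keys, each stored with a trailing newline.
    for user_hash, key in [('aaaa', 'ssh-rsa AAA alice\n'),
                           ('bbbb', 'ssh-rsa BBB bob\n')]:
        key_path = os.path.join(tmp, f'ssh-key-{user_hash}')
        with open(key_path, 'w', encoding='utf-8') as f:
            f.write(key)
    dest = os.path.join(tmp, 'authorized_keys')
    count = build_authorized_keys(tmp, dest)
    with open(dest, 'r', encoding='utf-8') as f:
        authorized_keys = f.read()
```

If a stored key lacked its trailing newline, the next key would be appended to the same line, which is why `sky/authentication.py` appends `'\n'` before writing the secret.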

4 changes: 2 additions & 2 deletions sky/templates/kubernetes-ssh-jump.yml.j2
@@ -26,7 +26,7 @@ pod_spec:
 lifecycle:
 postStart:
 exec:
-command: ["/bin/bash", "-c", "mkdir -p ~/.ssh && cp /etc/secret-volume/ssh-publickey ~/.ssh/authorized_keys && sudo service ssh restart"]
+command: ["/bin/bash", "-c", "mkdir -p ~/.ssh && cat /etc/secret-volume/ssh-key-* > ~/.ssh/authorized_keys && sudo service ssh restart"]
 env:
 - name: MY_POD_NAME
 valueFrom:
@@ -74,7 +74,7 @@ role:
 rules:
 - apiGroups: [""]
 resources: ["pods", "pods/status", "pods/exec", "services"]
-verbs: ["get", "list", "create", "delete"]
+verbs: ["get", "list", "create", "delete"]
 role_binding:
 apiVersion: rbac.authorization.k8s.io/v1
 kind: RoleBinding
170 changes: 114 additions & 56 deletions sky/utils/kubernetes/ssh_jump_lifecycle_manager.py
@@ -3,15 +3,20 @@
 This script runs inside ssh jump pod as the main process (PID 1).

 It terminates itself (by removing ssh jump service and pod via a call to
-kubeapi), if it does not see ray pods in the duration of 10 minutes. If the
+kubeapi) if it does not see ray pods in the duration of 10 minutes. If the
 user re-launches a task before the duration is over, then ssh jump pod is being
-reused and will terminate itself when it sees that no ray cluster exist in that
-duration.
+reused and will terminate itself when it sees that no ray clusters exist in
+that duration.
+
+This script also reloads SSH keys from the mounted secret volume on an
+interval.
 """
 import datetime
 import os
 import sys
 import time
+import subprocess
+import threading

 from kubernetes import client
 from kubernetes import config
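Reviewer sketch (not part of the diff): the docstring above says the jump pod terminates itself once no Ray pods have been seen for the alert threshold (600 seconds by default, checked every 60 seconds). The accumulation can be illustrated in isolation as follows. `IdleTracker` is a hypothetical name; the script itself keeps this state in module-level `timedelta` variables.

```python
import datetime


class IdleTracker:
    """Hypothetical helper: accumulates idle time across fixed-interval checks."""

    def __init__(self, threshold_s: int, interval_s: int) -> None:
        self.threshold = datetime.timedelta(seconds=threshold_s)
        self.interval = datetime.timedelta(seconds=interval_s)
        self.idle = datetime.timedelta()

    def observe(self, pod_count: int) -> bool:
        """Records one check; returns True once the threshold is reached."""
        if pod_count == 0:
            self.idle += self.interval
        else:
            # Any matching pod resets the accumulated idle time.
            self.idle = datetime.timedelta()
        return self.idle >= self.threshold


tracker = IdleTracker(threshold_s=600, interval_s=60)
# Nine idle checks, one check that finds a pod, then ten more idle checks.
observations = [0] * 9 + [1] + [0] * 10
results = [tracker.observe(n) for n in observations]
```

Note that idle time is counted in whole polling intervals, so a pod that appears and disappears between two checks is never observed and does not reset the counter.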
@@ -29,82 +34,135 @@
 alert_threshold = int(os.getenv('ALERT_THRESHOLD', '600'))
 # The amount of time in seconds to wait between Ray pods existence checks
 retry_interval = int(os.getenv('RETRY_INTERVAL', '60'))
+# The amount of time in seconds to wait between SSH key reloads
+reload_interval = int(os.getenv('RELOAD_INTERVAL', '5'))

 # Ray pods are labeled with this value i.e., ssh jump name which is unique per
 # user (based on user hash)
 label_selector = f'skypilot-ssh-jump={current_name}'


-def poll():
-    sys.stdout.write('Starting polling.\n')
-
-    alert_delta = datetime.timedelta(seconds=alert_threshold)
-
-    # Set delay for each retry
-    retry_interval_delta = datetime.timedelta(seconds=retry_interval)
-
-    # Accumulated time of where no SkyPilot cluster exists. Used to compare
-    # against alert_threshold
-    nocluster_delta = datetime.timedelta()
-
-    while True:
-        sys.stdout.write(f'Sleeping {retry_interval} seconds..\n')
-        time.sleep(retry_interval)
-
-        # List the pods in the current namespace
+def poll(interval, leading=True):
+    """Decorator factory for polling function. To stop polling, return True.
+
+    Args:
+        interval (int): The amount of time to wait between function calls.
+        leading (bool): Whether to wait before (rather than after) calls.
+    """
+
+    def decorator(func):
+        def wrapper(*args, **kwargs):
+            while True:
+                if leading:
+                    time.sleep(interval)
+                done = func(*args, **kwargs)
+                if done:
+                    return
+                if not leading:
+                    time.sleep(interval)
+        return wrapper
+    return decorator
+
+
+# Flag to terminate the reload keys thread when the lifecycle thread
+# terminates.
+terminated = False
+
+@poll(interval=reload_interval, leading=False)
+def reload_keys():
+    """Reloads SSH keys from mounted secret volume."""
+
+    if terminated:
+        sys.stdout.write('[SSH Key Reloader] Terminated.\n')
+        return True
+
+    # Reload SSH keys from mounted secret volume
+    sys.stdout.write('[SSH Key Reloader] Reloading SSH keys.\n')
+    cmd = 'cat /etc/secret-volume/ssh-key-* > ~/.ssh/authorized_keys'
+    try:
+        subprocess.check_output(cmd, shell=True)
+    except Exception as e:
+        sys.stdout.write(f'[SSH Key Reloader] [ERROR] failed to reload SSH keys: {e}\n')
+        raise
+
+
+alert_delta = datetime.timedelta(seconds=alert_threshold)
+retry_interval_delta = datetime.timedelta(seconds=retry_interval)
+# Accumulated time of where no SkyPilot cluster exists. Compared
+# against alert_threshold.
+nocluster_delta = datetime.timedelta()
+
+@poll(interval=retry_interval)
+def manage_lifecycle():
+    """Manages lifecycle of ssh jump pod."""
+
+    global terminated, nocluster_delta
+
+    try:
+        ret = v1.list_namespaced_pod(current_namespace,
+                                     label_selector=label_selector)
+    except Exception as e:
+        sys.stdout.write('[Lifecycle] [ERROR] listing pods failed with '
+                         f'error: {e}\n')
+        raise
+
+    if len(ret.items) == 0:
+        sys.stdout.write(f'[Lifecycle] Did not find pods with label '
+                         f'"{label_selector}" in namespace {current_namespace}\n')
+        nocluster_delta = nocluster_delta + retry_interval_delta
+        sys.stdout.write(
+            f'[Lifecycle] Time since no pods found: {nocluster_delta}, alert '
+            f'threshold: {alert_delta}\n')
+    else:
+        sys.stdout.write(
+            f'[Lifecycle] Found pods with label "{label_selector}" in '
+            f'namespace {current_namespace}\n')
+        # reset ..
+        nocluster_delta = datetime.timedelta()
+        sys.stdout.write(f'[Lifecycle] nocluster_delta is reset: {nocluster_delta}\n')
+
+    if nocluster_delta >= alert_delta:
+        sys.stdout.write(
+            f'[Lifecycle] nocluster_delta: {nocluster_delta} crossed alert '
+            f'threshold: {alert_delta}. Time to terminate myself and my '
+            'service.\n')
         try:
-            ret = v1.list_namespaced_pod(current_namespace,
-                                         label_selector=label_selector)
+            # ssh jump resources created under same name
+            v1.delete_namespaced_service(current_name, current_namespace)
+            v1.delete_namespaced_pod(current_name, current_namespace)
         except Exception as e:
-            sys.stdout.write(f'Error: listing pods failed with error: {e}\n')
+            sys.stdout.write('[Lifecycle] [ERROR] Deletion failed. Exiting '
+                             f'poll() with error: {e}\n')
             raise

-        if len(ret.items) == 0:
-            sys.stdout.write(f'Did not find pods with label "{label_selector}" '
-                             f'in namespace {current_namespace}\n')
-            nocluster_delta = nocluster_delta + retry_interval_delta
-            sys.stdout.write(
-                f'Time since no pods found: {nocluster_delta}, alert '
-                f'threshold: {alert_delta}\n')
-        else:
-            sys.stdout.write(
-                f'Found pods with label "{label_selector}" in namespace '
-                f'{current_namespace}\n')
-            # reset ..
-            nocluster_delta = datetime.timedelta()
-            sys.stdout.write(f'noray_delta is reset: {nocluster_delta}\n')
-
-        if nocluster_delta >= alert_delta:
-            sys.stdout.write(
-                f'nocluster_delta: {nocluster_delta} crossed alert threshold: '
-                f'{alert_delta}. Time to terminate myself and my service.\n')
-            try:
-                # ssh jump resources created under same name
-                v1.delete_namespaced_service(current_name, current_namespace)
-                v1.delete_namespaced_pod(current_name, current_namespace)
-            except Exception as e:
-                sys.stdout.write('[ERROR] Deletion failed. Exiting '
-                                 f'poll() with error: {e}\n')
-                raise
-
-            break
-
-    sys.stdout.write('Done polling.\n')
+        terminated = True
+        return True


 def main():
     sys.stdout.write('SkyPilot SSH Jump Pod Lifecycle Manager\n')
-    sys.stdout.write(f'current_name: {current_name}\n')
+    sys.stdout.write(f'current_name: {current_name}\n')
     sys.stdout.write(f'current_namespace: {current_namespace}\n')
     sys.stdout.write(f'alert_threshold time: {alert_threshold}\n')
     sys.stdout.write(f'retry_interval time: {retry_interval}\n')
+    sys.stdout.write(f'reload_interval time: {reload_interval}\n')
     sys.stdout.write(f'label_selector: {label_selector}\n')

     if not current_name or not current_namespace:
         # Raise Exception with message to terminate pod
         raise Exception('Missing environment variables MY_POD_NAME or '
                         'MY_POD_NAMESPACE')
-    poll()
+
+    threads = [
+        threading.Thread(target=manage_lifecycle),
+        threading.Thread(target=reload_keys)
+    ]
+    sys.stdout.write(f'Polling with {len(threads)} threads.\n')
+    for t in threads:
+        t.start()
+    for t in threads:
+        t.join()
+    sys.stdout.write('Done.\n')


 if __name__ == '__main__':
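Reviewer sketch (not part of the diff): the `poll` decorator factory and the two-thread layout above can be exercised without a cluster. The version below keeps the same `interval`/`leading` semantics but swaps the module-level `terminated` boolean for a `threading.Event`, which is a substitution made here for illustration and not what the PR does. The intervals and the stop condition are placeholder values.

```python
import threading
import time


def poll(interval, leading=True):
    """Calls the wrapped function repeatedly until it returns True."""

    def decorator(func):

        def wrapper(*args, **kwargs):
            while True:
                if leading:
                    time.sleep(interval)
                if func(*args, **kwargs):
                    return
                if not leading:
                    time.sleep(interval)

        return wrapper

    return decorator


# Assumption: an Event carries the same intent as the PR's boolean flag,
# telling the reloader to stop once the lifecycle loop has finished.
stop = threading.Event()
calls = {'lifecycle': 0, 'reload': 0}


@poll(interval=0.05)
def manage_lifecycle():
    calls['lifecycle'] += 1
    if calls['lifecycle'] >= 3:  # Stand-in for "idle threshold crossed".
        stop.set()
        return True
    return False


@poll(interval=0.01, leading=False)
def reload_keys():
    if stop.is_set():
        return True
    calls['reload'] += 1  # Stand-in for rewriting authorized_keys.
    return False


threads = [
    threading.Thread(target=manage_lifecycle),
    threading.Thread(target=reload_keys),
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

With `leading=False` the reloader runs once immediately and then sleeps, while the lifecycle check with the default `leading=True` sleeps before its first call; this matches how the two loops are configured in the diff.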
17 changes: 17 additions & 0 deletions sky/utils/kubernetes_utils.py
@@ -991,3 +991,20 @@ def check_port_forward_mode_dependencies() -> None:
                     f' $ sudo apt install {install_cmd}\n'
                     f'On MacOS, install it with: \n'
                     f' $ brew install {install_cmd}') from None
+
+def check_secret_exists(secret_name: str, namespace: str) -> bool:
+    """Checks if a secret exists in a namespace
+
+    Args:
+        secret_name: Name of secret to check
+        namespace: Namespace to check
+    """
+
+    try:
+        kubernetes.core_api().read_namespaced_secret(secret_name, namespace)
+    except kubernetes.api_exception() as e:
+        if e.status == 404:
+            return False
+        raise
+    else:
+        return True
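Reviewer sketch (not part of the diff): `check_secret_exists` treats a 404 from the read call as "does not exist" and re-raises every other API error. The same `try`/`except`/`else` shape can be tried against a fake backend. `ApiException`, `resource_exists`, and `fake_read` below are hypothetical names for illustration, not the client's real classes.

```python
class ApiException(Exception):
    """Hypothetical stand-in for a client exception carrying an HTTP status."""

    def __init__(self, status: int) -> None:
        super().__init__(f'API error {status}')
        self.status = status


def resource_exists(read_fn, name: str, namespace: str) -> bool:
    """Returns False on a 404 from read_fn, True on success; re-raises others."""
    try:
        read_fn(name, namespace)
    except ApiException as e:
        if e.status == 404:
            return False
        raise
    else:
        return True


def fake_read(name: str, namespace: str) -> dict:
    # Hypothetical backend: one known secret and one forbidden namespace.
    if namespace == 'locked':
        raise ApiException(403)
    if (namespace, name) != ('default', 'sky-ssh-key'):
        raise ApiException(404)
    return {'name': name}


found = resource_exists(fake_read, 'sky-ssh-key', 'default')
missing = resource_exists(fake_read, 'other', 'default')
try:
    resource_exists(fake_read, 'sky-ssh-key', 'locked')
    forbidden_status = None
except ApiException as e:
    forbidden_status = e.status
```

Re-raising non-404 errors matters here: a permissions or connectivity failure should not be mistaken for a missing secret, since the caller would then try to create one.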
4 changes: 2 additions & 2 deletions tests/kubernetes/scripts/ray_k8s_sky.yaml
@@ -144,7 +144,7 @@ available_node_types:
 lifecycle:
 postStart:
 exec:
-command: ["/bin/bash", "-c", "mkdir -p ~/.ssh && cp /etc/secret-volume/ssh-publickey ~/.ssh/authorized_keys && sudo service ssh restart"]
+command: ["/bin/bash", "-c", "mkdir -p ~/.ssh && cat /etc/secret-volume/ssh-key-* > ~/.ssh/authorized_keys && sudo service ssh restart"]
 ports:
 - containerPort: 22 # Used for SSH
 # This volume allocates shared memory for Ray to use for its plasma
@@ -224,7 +224,7 @@ available_node_types:
 lifecycle:
 postStart:
 exec:
-command: ["/bin/bash", "-c", "mkdir -p ~/.ssh && cp /etc/secret-volume/ssh-publickey ~/.ssh/authorized_keys && sudo service ssh restart"]
+command: ["/bin/bash", "-c", "mkdir -p ~/.ssh && cat /etc/secret-volume/ssh-key-* > ~/.ssh/authorized_keys && sudo service ssh restart"]
 resources:
 requests:
 cpu: 1000m
3 changes: 2 additions & 1 deletion tests/kubernetes/scripts/run.sh
@@ -1,6 +1,7 @@
+# TODO(kbrgl): Fix secret creation since all SSH keys are now stored in one secret object.
 kubectl create secret generic ssh-key-secret --from-file=ssh-publickey=/Users/romilb/.ssh/sky-key.pub
 kubectl apply -f skypilot_ssh_k8s_deployment.yaml
 # Use kubectl describe service skypilot-service to get the port of the service
 kubectl describe service skypilot-service | grep NodePort
 echo Run the following command to ssh into the container:
-echo ssh sky@127.0.0.1 -p port -i ~/.ssh/sky-key
+echo ssh sky@127.0.0.1 -p port -i ~/.ssh/sky-key
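Reviewer sketch (not part of the diff): the TODO above notes that this test script still creates the secret with the old `ssh-publickey` field, which the new `cat /etc/secret-volume/ssh-key-*` command would not pick up. One possible direction, offered only as an assumption about how the TODO might be resolved, is to generate a Secret manifest whose field follows the `ssh-key-<user hash>` convention. `build_secret_manifest` and the hash `abcd1234` are hypothetical; the secret name `ssh-key-secret` is taken from the script.

```python
import json


def build_secret_manifest(secret_name: str, user_hash: str,
                          public_key: str) -> dict:
    """Builds a v1 Secret manifest holding one public key in a per-user field."""
    # Keys are concatenated into authorized_keys, so end each with a newline.
    if not public_key.endswith('\n'):
        public_key += '\n'
    return {
        'apiVersion': 'v1',
        'kind': 'Secret',
        'metadata': {'name': secret_name},
        'type': 'Opaque',
        # stringData takes plain text; it is stored base64-encoded under data.
        'stringData': {f'ssh-key-{user_hash}': public_key},
    }


manifest = build_secret_manifest('ssh-key-secret', 'abcd1234',
                                 'ssh-rsa AAAA example')
manifest_json = json.dumps(manifest, indent=2)
```

The resulting JSON could be written to a file and applied with `kubectl apply -f`, since kubectl accepts JSON manifests as well as YAML.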
2 changes: 1 addition & 1 deletion tests/kubernetes/scripts/skypilot_ssh_k8s_deployment.yaml
@@ -38,7 +38,7 @@ spec:
 lifecycle:
 postStart:
 exec:
-command: ["/bin/bash", "-c", "mkdir -p ~/.ssh && cp /etc/secret-volume/ssh-publickey ~/.ssh/authorized_keys && sudo service ssh restart"]
+command: ["/bin/bash", "-c", "mkdir -p ~/.ssh && cat /etc/secret-volume/ssh-key-* > ~/.ssh/authorized_keys && sudo service ssh restart"]
 ---
 apiVersion: v1
 kind: Service