Skip to content

Commit abdeb60

Browse files
added support for missing features for python client library for inference and compute deployment (#48)
1 parent b88c95c commit abdeb60

File tree

5 files changed

+314
-53
lines changed

5 files changed

+314
-53
lines changed

centml/cli/cluster.py

Lines changed: 191 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,25 @@
1+
import sys
12
import click
23
from tabulate import tabulate
34
import platform_api_client
45
from platform_api_client.models.endpoint_ready_state import EndpointReadyState
56
from platform_api_client.models.deployment_status import DeploymentStatus
6-
77
from ..sdk import api
88

99

10+
# Custom class to parse key-value pairs for env variables for inference deployment
11+
class InferenceEnvType(click.ParamType):
    """Click parameter type that parses ``KEY=VALUE`` into a ``(key, value)`` tuple.

    Only the first ``=`` separates the key from the value, so the value itself
    may contain further ``=`` characters (e.g. ``TOKEN=abc==``).
    """

    name = "key_value"

    def convert(self, value, param, ctx):
        """Split *value* on its first ``=`` and return ``(key, val)``.

        Reports the error through ``self.fail`` when *value* has no ``=`` or
        when the key part is empty (an environment variable needs a name).
        """
        key, separator, val = value.partition("=")
        if not separator or not key:
            # self.fail() raises, so the return below is only reached on success.
            self.fail(f"{value} is not a valid key=value pair", param, ctx)
        return key, val
21+
22+
1023
hw_to_id_map = {"small": 1000, "medium": 1001, "large": 1002}
1124
id_to_hw_map = {v: k for k, v in hw_to_id_map.items()}
1225

@@ -17,19 +30,43 @@
1730
}
1831

1932

33+
def format_ssh_key(ssh_key):
    """Return a shortened preview of *ssh_key* suitable for display.

    A missing or empty key yields a placeholder message; otherwise the first
    ten characters are shown followed by an ellipsis.
    """
    if ssh_key:
        preview = ssh_key[:10]
        return preview + "..."
    return "No SSH Key Found"
37+
38+
2039
def get_ready_status(api_status, service_status):
    """Return a colourised status label for a deployment.

    Args:
        api_status: The deployment's ``DeploymentStatus``.
        service_status: The endpoint's ``EndpointReadyState`` (may be ``None``).

    Returns:
        The label wrapped by ``click.style``. Combinations that are not in the
        table below are reported as ``"unknown"``.
    """
    # Each entry is (label, foreground, background); background None means
    # "leave the terminal's background untouched".
    status_styles = {
        (DeploymentStatus.PAUSED, None): ("paused", "yellow", "black"),
        (DeploymentStatus.DELETED, None): ("deleted", "white", "black"),
        (DeploymentStatus.ACTIVE, EndpointReadyState.READY): ("ready", "green", "black"),
        (DeploymentStatus.ACTIVE, EndpointReadyState.NOT_READY): ("starting", "black", "white"),
        # This entry used to be a 2-tuple, which made the style lookup below
        # raise IndexError; the explicit None keeps the "no background" look.
        (DeploymentStatus.ACTIVE, EndpointReadyState.NOT_FOUND): ("not found", "cyan", None),
        (DeploymentStatus.ACTIVE, EndpointReadyState.FOUND_MULTIPLE): ("found multiple", "black", "white"),
        (DeploymentStatus.ACTIVE, EndpointReadyState.INGRESS_RULE_NOT_FOUND): (
            "ingress rule not found",
            "black",
            "white",
        ),
        (DeploymentStatus.ACTIVE, EndpointReadyState.CONDITION_NOT_FOUND): ("condition not found", "black", "white"),
        (DeploymentStatus.ACTIVE, EndpointReadyState.INGRESS_NOT_CONFIGURED): (
            "ingress not configured",
            "black",
            "white",
        ),
        (DeploymentStatus.ACTIVE, EndpointReadyState.CONTAINER_MISSING): ("container missing", "black", "white"),
        (DeploymentStatus.ACTIVE, EndpointReadyState.PROGRESS_DEADLINE_EXCEEDED): (
            "progress deadline exceeded",
            "black",
            "white",
        ),
        (DeploymentStatus.ACTIVE, EndpointReadyState.REVISION_MISSING): ("revision missing", "black", "white"),
    }

    style = status_styles.get((api_status, service_status))
    if style is None:
        # Paused/deleted deployments are keyed with service_status=None; fall
        # back to that entry so they are still recognised when the service
        # reports some state of its own.
        style = status_styles.get((api_status, None), ("unknown", "black", "white"))

    label, fg, bg = style
    return click.style(label, fg=fg, bg=bg)
3370

3471

3572
@click.command(help="List all deployments")
@@ -49,19 +86,25 @@ def ls(type):
4986

5087

5188
@click.command(help="Get deployment details")
89+
@click.argument("type", type=click.Choice(list(depl_type_map.keys())))
5290
@click.argument("id", type=int)
53-
def get(id):
54-
deployment = api.get_inference(id)
91+
def get(type, id):
92+
if type == platform_api_client.DeploymentType.INFERENCE:
93+
deployment = api.get_inference(id)
94+
elif type == platform_api_client.DeploymentType.COMPUTE:
95+
deployment = api.get_compute(id)
96+
else:
97+
sys.exit("Please enter correct deployment type")
5598
state = api.get_status(id)
5699
ready_status = get_ready_status(deployment.status, state.service_status)
57100

58-
click.echo(f"Inference deployment #{id} is {ready_status}")
101+
click.echo(f"The current status of Deployment #{id} is: {ready_status}.")
59102
click.echo(
60103
tabulate(
61104
[
62105
("Name", deployment.name),
63106
("Image", deployment.image_url),
64-
("Endpoint", f"https://{deployment.endpoint_url}/"),
107+
("Endpoint", deployment.endpoint_url),
65108
("Created at", deployment.created_at.strftime("%Y-%m-%d %H:%M:%S")),
66109
("Hardware", id_to_hw_map[deployment.hardware_instance_id]),
67110
],
@@ -71,43 +114,146 @@ def get(id):
71114
)
72115

73116
click.echo("Additional deployment configurations:")
74-
click.echo(
75-
tabulate(
76-
[
77-
("Is private?", deployment.secrets is not None),
78-
("Hardware", id_to_hw_map[deployment.hardware_instance_id]),
79-
("Port", deployment.port),
80-
("Healthcheck", deployment.healthcheck or "/"),
81-
("Replicas", {"min": deployment.min_replicas, "max": deployment.max_replicas}),
82-
("Environment variables", deployment.env_vars or "None"),
83-
],
84-
tablefmt="rounded_outline",
85-
disable_numparse=True,
117+
if type == platform_api_client.DeploymentType.INFERENCE:
118+
click.echo(
119+
tabulate(
120+
[
121+
("Port", deployment.port),
122+
("Healthcheck", deployment.healthcheck or "/"),
123+
("Replicas", {"min": deployment.min_replicas, "max": deployment.max_replicas}),
124+
("Environment variables", deployment.env_vars or "None"),
125+
("Max concurrency", deployment.timeout or "None"),
126+
],
127+
tablefmt="rounded_outline",
128+
disable_numparse=True,
129+
)
130+
)
131+
elif type == platform_api_client.DeploymentType.COMPUTE:
132+
click.echo(
133+
tabulate(
134+
[
135+
("Port", deployment.port),
136+
("Username", deployment.username or "None"),
137+
("SSH key", format_ssh_key(deployment.ssh_key)),
138+
],
139+
tablefmt="rounded_outline",
140+
disable_numparse=True,
141+
)
86142
)
143+
144+
145+
# Define the options common to every deployment type
146+
def common_options(func):
    """Decorate *func* with the options shared by all ``create`` subcommands.

    Adds ``--name``, ``--image`` and ``--hardware`` (in that application
    order) and returns the decorated callable.
    """
    decorators = (
        click.option("--name", "-n", prompt="Name", help="Name of the deployment"),
        click.option("--image", "-i", prompt="Image", help="Container image"),
        click.option(
            "--hardware",
            "-h",
            prompt="Hardware",
            type=click.Choice(list(hw_to_id_map.keys())),
            help="Hardware instance type",
        ),
    )
    for decorate in decorators:
        func = decorate(func)
    return func
157+
158+
159+
# Define inference specific options
160+
def inference_options(func):
    """Decorate *func* with the options specific to inference deployments.

    The options are applied one after another in the order listed below and
    the decorated callable is returned.
    """
    decorators = (
        click.option("--port", "-p", prompt="Port", type=int, help="Port to expose"),
        click.option(
            "--env", type=InferenceEnvType(), help="Environment variables in the format KEY=VALUE", multiple=True
        ),
        click.option("--min_replicas", default="1", prompt="Min replicas", type=click.IntRange(1, 10)),
        click.option("--max_replicas", default="1", prompt="Max replicas", type=click.IntRange(1, 10)),
        click.option("--health", default="/", prompt="Health check", help="Health check endpoint"),
        click.option("--is_private", default=False, type=bool, prompt="Is private?", help="Is private endpoint?"),
        click.option("--timeout", prompt="Max concurrency", default=0, type=int),
        click.option("--command", type=str, required=False, default=None, help="Define a command for a container"),
        click.option("--command_args", multiple=True, type=str, default=None, help="List of command arguments"),
    )
    for decorate in decorators:
        func = decorate(func)
    return func
88177

89178

90-
@click.command(help="Create a new deployment")
91-
@click.argument("type", type=click.Choice(list(depl_type_map.keys())))
92-
@click.option("--name", "-n", prompt="Name", help="Name of the deployment")
93-
@click.option("--image", "-i", prompt="Image", help="Container image")
94-
@click.option("--port", "-p", prompt="Port", type=int, help="Port to expose")
95-
@click.option(
96-
"--hardware", "-h", prompt="Hardware", type=click.Choice(list(hw_to_id_map.keys())), help="Hardware instance type"
97-
)
98-
@click.option("--health", default="/", prompt="Health check", help="Health check endpoint")
99-
@click.option("--min_replicas", default="1", prompt="Min replicas", type=click.IntRange(1, 10))
100-
@click.option("--max_replicas", default="1", prompt="Max replicas", type=click.IntRange(1, 10))
101-
@click.option("--username", prompt=True, default="", help="Username for HTTP authentication")
102-
@click.option("--password", prompt=True, default="", hide_input=True, help="Password for HTTP authentication")
103-
@click.option("--env", "-e", required=False, type=str, multiple=True, help="Environment variables (KEY=VALUE)")
104-
def create(type, name, image, port, hardware, health, min_replicas, max_replicas, username, password, env):
179+
# Define compute specific options
180+
def compute_options(func):
    """Decorate *func* with the options specific to compute deployments.

    Adds ``--username``, ``--password`` (input hidden) and ``--ssh_key`` in
    that application order and returns the decorated callable.
    """
    decorators = (
        click.option("--username", prompt="Username", type=str, help="Username"),
        click.option("--password", prompt="Password", hide_input=True, type=str, help="password"),
        click.option(
            "--ssh_key", prompt="Add ssh key", default="", type=str, help="Would you like to add an SSH key?"
        ),
    )
    for decorate in decorators:
        func = decorate(func)
    return func
187+
188+
189+
# Main command group
190+
@click.group(help="Create a new deployment")
@click.pass_context
def create(ctx):
    """Command group for deployment creation; the subcommands do the actual work."""
194+
195+
196+
# Define the inference subcommand
197+
@create.command(name="inference", help="Create an inference deployment")
@common_options
@inference_options
@click.pass_context
def create_inference(ctx, **kwargs):
    """Create an inference deployment from the parsed CLI options.

    The option values arrive in *kwargs*; they are forwarded positionally to
    ``api.create_inference`` and the resulting deployment id and endpoint are
    echoed.
    """
    click.echo("Creating inference deployment with the following options:")

    # The CLI deals in hardware names; the API call takes the numeric id.
    hardware_id = hw_to_id_map[kwargs.get("hardware")]

    # Call the API function for creating the inference deployment
    resp = api.create_inference(
        kwargs.get("name"),
        kwargs.get("image"),
        kwargs.get("port"),
        kwargs.get("is_private"),
        hardware_id,
        kwargs.get("health"),
        kwargs.get("min_replicas"),
        kwargs.get("max_replicas"),
        kwargs.get("env"),
        kwargs.get("command"),
        kwargs.get("command_args"),
        kwargs.get("timeout"),
    )

    click.echo(f"Inference deployment #{resp.id} created at https://{resp.endpoint_url}/")
109234

110235

236+
# Define the compute subcommand
237+
@create.command(name="compute", help="Create a compute deployment")
@common_options
@compute_options
@click.pass_context
def create_compute(ctx, **kwargs):
    """Create a compute deployment from the parsed CLI options.

    The option values arrive in *kwargs*; they are forwarded positionally to
    ``api.create_compute`` and the resulting deployment id and endpoint are
    echoed.
    """
    click.echo("Creating compute deployment with the following options:")

    # The CLI deals in hardware names; the API call takes the numeric id.
    hardware_id = hw_to_id_map[kwargs.get("hardware")]

    # Call the API function for creating the compute deployment
    resp = api.create_compute(
        kwargs.get("name"),
        kwargs.get("image"),
        kwargs.get("username"),
        kwargs.get("password"),
        kwargs.get("ssh_key"),
        hardware_id,
    )

    click.echo(f"Compute deployment #{resp.id} created at https://{resp.endpoint_url}/")
255+
256+
111257
@click.command(help="Delete a deployment")
112258
@click.argument("id", type=int)
113259
def delete(id):

centml/sdk/api.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from . import auth
66
from .config import Config
7+
from .utils import client_certs
78

89

910
@contextlib.contextmanager
@@ -44,25 +45,44 @@ def get_compute(id):
4445
return api.get_compute_deployment_deployments_compute_deployment_id_get(id)
4546

4647

47-
def create_inference(
    name, image, port, is_private, hw_to_id_map, health, min_replicas, max_replicas, env, command, command_args, timeout
):
    """Create an inference deployment through the platform API.

    Args:
        name: Deployment name; also passed to the client-certificate helpers.
        image: Container image URL.
        port: Port the container exposes.
        is_private: When truthy, a CA/client certificate triplet is generated,
            the client key and certificate are saved via
            ``client_certs.save_pem_file`` and the CA is attached to the request.
        hw_to_id_map: Despite the name, this is the hardware instance id; it is
            sent unchanged as ``hardware_instance_id``.
        health: Healthcheck path.
        min_replicas: Minimum replica count.
        max_replicas: Maximum replica count.
        env: Iterable of ``(key, value)`` pairs, or empty/``None`` for none.
        command: Optional container command; sent as a one-element list.
        command_args: Optional arguments; only sent when *command* is set.
        timeout: Sent unchanged as the request's ``timeout``.

    Returns:
        The response of the inference-deployment creation call.
    """
    # Must be bound on every path: it is read when building the request even
    # for public deployments (previously this raised UnboundLocalError).
    triplet = None
    if is_private:
        triplet = client_certs.generate_ca_client_triplet(name)
        # Handle automatic download of client private secrets
        client_certs.save_pem_file(name, triplet.client_private_key, triplet.client_certificate)

    # Empty or missing collections are sent as None rather than {} / [].
    env_vars = dict(env) if env else None
    args = list(command_args) if command and command_args else None

    with get_api() as api:
        req = platform_api_client.CreateInferenceDeploymentRequest(
            name=name,
            image_url=image,
            port=port,
            hardware_instance_id=hw_to_id_map,
            healthcheck=health,
            min_replicas=min_replicas,
            max_replicas=max_replicas,
            env_vars=env_vars or None,
            command=[command] if command else None,
            command_args=args or None,
            timeout=timeout,
            endpoint_certificate_authority=triplet.certificate_authority if triplet else None,
        )
        return api.create_inference_deployment_deployments_inference_post(req)
6471

6572

73+
def create_compute(name, image, username, password, ssh_key, hw_to_id_map):
    """Create a compute deployment through the platform API.

    ``hw_to_id_map`` is sent unchanged as the request's
    ``hardware_instance_id``; a falsy ``ssh_key`` is sent as ``None``.
    Returns the response of the compute-deployment creation call.
    """
    request_fields = {
        "name": name,
        "image_url": image,
        "hardware_instance_id": hw_to_id_map,
        "username": username,
        "password": password,
        "ssh_key": ssh_key or None,
    }
    with get_api() as api:
        req = platform_api_client.CreateComputeDeploymentRequest(**request_fields)
        return api.create_compute_deployment_deployments_compute_post(req)
84+
85+
6686
def update_status(id, new_status):
6787
with get_api() as api:
6888
status_req = platform_api_client.DeploymentStatusRequest(status=new_status)

0 commit comments

Comments
 (0)