Skip to content

Commit a7fdd5b

Browse files
committed
manual deploy with tag + nbox fail fast
1 parent fc8c4fe commit a7fdd5b

File tree

12 files changed

+116
-83
lines changed

12 files changed

+116
-83
lines changed

nbox/assets/exe.jinja

+5-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ os.environ["PYTHONUNBUFFERED"] = "true" # so print comes when it should come
1010
import fire
1111
import inspect
1212
from functools import lru_cache, partial
13-
from {{ file_name }} import {{ fn_name }}
1413

1514
from nbox import Operator, logger
1615
import nbox.utils as U
@@ -20,6 +19,11 @@ from nbox.lib.dist import NBXLet
2019
def get_op(cloud = False) -> Operator:
2120
# The beauty of this function is that it ensures that the operator class is loaded only once
2221
try:
22+
# import user code, add this to the try/except because if the code does not exit and there
23+
# is an infinite loop, there can be a whole bunch of side effects, ex: 100s of LMAO live trackers
24+
from {{ file_name }} import {{ fn_name }}
25+
26+
# load the operator
2327
obj = {{ init_code }}
2428
if not type(obj) == Operator and {{ load_operator }}:
2529
# there is an initial level of precaution that we use during deployment, but we are adding simple

nbox/auth.py

+3
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,9 @@ def init_secret():
213213
token_present = len(secret.access_token) > 0,
214214
nbx_url = secret.nbx_url,
215215
))
216+
217+
if not secret.workspace_id:
218+
raise Exception("Workspace ID not found. Please run `nbox login` to login to NimbleBox.")
216219
return secret
217220
else:
218221
logger.info(f"Skipping authentication as NBOX_NO_AUTH is set to True")

nbox/init.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323

2424
from nbox.auth import secret, AuthConfig
25-
from nbox.utils import logger, env
25+
from nbox.utils import logger, env, hard_exit_program
2626
from nbox.subway import Sub30
2727
from nbox.hyperloop.jobs.nbox_ws_pb2_grpc import WSJobServiceStub
2828
from nbox.hyperloop.deploy.serve_pb2_grpc import ServingServiceStub, ModelServiceStub
@@ -110,6 +110,7 @@ def create_webserver_subway(version: str = "v1", session: requests.Session = Non
110110
except Exception as e:
111111
logger.error(f"Could not connect to webserver at {secret('nbx_url')}")
112112
logger.error(e)
113+
113114
return None
114115

115116
spec = r.json()

nbox/jobs.py

+52-27
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from nbox.auth import secret, AuthConfig, auth_info_pb
1818
from nbox.utils import logger
1919
from nbox.version import __version__
20-
from nbox.messages import rpc, streaming_rpc
20+
from nbox import messages as mpb
21+
# from nbox.messages import rpc, streaming_rpc
2122
from nbox.init import nbox_grpc_stub, nbox_ws_v1, nbox_serving_service_stub, nbox_model_service_stub
2223
from nbox.nbxlib.astea import Astea, IndexTypes as IT
2324

@@ -165,7 +166,11 @@ def upload_job_folder(
165166
init_folder: str,
166167
id: str = "",
167168
project_id: str = "",
169+
170+
# job / deploy rpc things
168171
trigger: bool = False,
172+
deploy: bool = True,
173+
pin: bool = False,
169174

170175
# all the things for resources
171176
resource_cpu: str = "",
@@ -228,10 +233,10 @@ def upload_job_folder(
228233

229234
if method not in OT._valid_deployment_types():
230235
raise ValueError(f"Invalid method: {method}, should be either {OT._valid_deployment_types()}")
231-
# if (not name and not id) or (name and id):
232-
# raise ValueError("Either --name or --id must be present")
233-
if trigger and method not in [OT.JOB, OT.SERVING]:
234-
raise ValueError(f"Trigger can only be used with '{OT.JOB}' or '{OT.SERVING}'")
236+
if trigger and method != OT.JOB:
237+
raise ValueError(f"Trigger can only be used with method='{OT.JOB}'")
238+
if pin and method != OT.SERVING:
239+
raise ValueError(f"Deploy and Pin can only be used with method='{OT.SERVING}'")
235240
if model_name and method != OT.SERVING:
236241
raise ValueError(f"model_name can only be used with '{OT.SERVING}'")
237242

@@ -410,6 +415,8 @@ def __common_resource(db: Resource) -> Resource:
410415
},
411416
exe_jinja_kwargs = exe_jinja_kwargs,
412417
)
418+
if deploy:
419+
out.deploy()
413420
if trigger:
414421
out.pin()
415422
else:
@@ -450,7 +457,7 @@ def _get_deployment_data(name: str = "", id: str = "", *, workspace_id: str = ""
450457
workspace_id = workspace_id or secret(AuthConfig.workspace_id)
451458

452459
# get the deployment
453-
serving: Serving = rpc(
460+
serving: Serving = mpb.rpc(
454461
nbox_serving_service_stub.GetServing,
455462
ServingRequest(
456463
serving=Serving(name=name, id=id),
@@ -473,7 +480,7 @@ def _get_time(t):
473480
return datetime.fromtimestamp(int(float(t))).strftime("%Y-%m-%d %H:%M:%S")
474481

475482
workspace_id = workspace_id or secret(AuthConfig.workspace_id)
476-
all_deployments: ServingListResponse = rpc(
483+
all_deployments: ServingListResponse = mpb.rpc(
477484
nbox_serving_service_stub.ListServings,
478485
ServingListRequest(
479486
auth_info=auth_info_pb(),
@@ -522,6 +529,13 @@ def __init__(self, serving_id: str = "", model_id: str = "", *, workspace_id: st
522529
self.serving_name = serving_name
523530
self.ws_stub = nbox_ws_v1.deployments
524531

532+
def __repr__(self) -> str:
533+
x = f"nbox.Serve('{self.id}', '{self.workspace_id}'"
534+
if self.model_id is not None:
535+
x += f", model_id = '{self.model_id}'"
536+
x += ")"
537+
return x
538+
525539
def pin(self) -> bool:
526540
"""Pin a model to the deployment
527541
@@ -530,7 +544,7 @@ def pin(self) -> bool:
530544
workspace_id (str, optional): Workspace ID. Defaults to "".
531545
"""
532546
logger.info(f"Pin model {self.model_id} to deployment {self.serving_id}")
533-
rpc(
547+
mpb.rpc(
534548
nbox_model_service_stub.SetModelPin,
535549
ModelRequest(
536550
model = ModelProto(
@@ -552,7 +566,7 @@ def unpin(self) -> bool:
552566
workspace_id (str, optional): Workspace ID. Defaults to "".
553567
"""
554568
logger.info(f"Unpin model {self.model_id} to deployment {self.serving_id}")
555-
rpc(
569+
mpb.rpc(
556570
nbox_model_service_stub.SetModelPin,
557571
ModelRequest(
558572
model = ModelProto(
@@ -578,7 +592,7 @@ def scale(self, replicas: int) -> bool:
578592
raise ValueError("Replicas must be greater than or equal to 0")
579593

580594
logger.info(f"Scale model deployment {self.model_id} to {replicas} replicas")
581-
rpc(
595+
mpb.rpc(
582596
nbox_model_service_stub.UpdateModel,
583597
UpdateModelRequest(
584598
model=ModelProto(
@@ -600,7 +614,7 @@ def logs(self, f = sys.stdout):
600614
f (file, optional): File to write the logs to. Defaults to sys.stdout.
601615
"""
602616
logger.debug(f"Streaming logs of model '{self.model_id}'")
603-
for model_log in streaming_rpc(
617+
for model_log in mpb.streaming_rpc(
604618
nbox_model_service_stub.ModelLogs,
605619
ModelRequest(
606620
model = ModelProto(
@@ -616,13 +630,24 @@ def logs(self, f = sys.stdout):
616630
f.write(log)
617631
f.flush()
618632

619-
def __repr__(self) -> str:
620-
x = f"nbox.Serve('{self.id}', '{self.workspace_id}'"
621-
if self.model_id is not None:
622-
x += f", model_id = '{self.model_id}'"
623-
x += ")"
624-
return x
625-
633+
def deploy(self, tag: str = ""):
634+
model = ModelProto(
635+
id = self.model_id,
636+
serving_group_id = self.serving_id,
637+
)
638+
if tag:
639+
model.feature_gates.update({
640+
"SetModelMetadata": tag
641+
})
642+
response: ModelProto = mpb.rpc(
643+
nbox_model_service_stub.Deploy,
644+
ModelRequest(
645+
model = model,
646+
auth_info = auth_info_pb(),
647+
),
648+
"Could not deploy model",
649+
raise_on_error=True
650+
)
626651

627652

628653
################################################################################
@@ -655,7 +680,7 @@ def _get_job_data(name: str = "", id: str = "", remove_archived: bool = True, *,
655680
if workspace_id == None:
656681
workspace_id = "personal"
657682

658-
job: JobProto = rpc(
683+
job: JobProto = mpb.rpc(
659684
nbox_grpc_stub.GetJob,
660685
JobRequest(
661686
auth_info = auth_info_pb(),
@@ -683,7 +708,7 @@ def get_job_list(sort: str = "name", *, workspace_id: str = ""):
683708
def _get_time(t):
684709
return datetime.fromtimestamp(int(float(t))).strftime("%Y-%m-%d %H:%M:%S")
685710

686-
out: ListJobsResponse = rpc(
711+
out: ListJobsResponse = mpb.rpc(
687712
nbox_grpc_stub.ListJobs,
688713
ListJobsRequest(auth_info = auth_info_pb()),
689714
"Could not get job list",
@@ -794,7 +819,7 @@ def change_schedule(self, new_schedule: Schedule):
794819
"""
795820
logger.debug(f"Updating job '{self.job_proto.id}'")
796821
self.job_proto.schedule.MergeFrom(new_schedule.get_message())
797-
rpc(
822+
mpb.rpc(
798823
nbox_grpc_stub.UpdateJob,
799824
UpdateJobRequest(auth_info=self.auth_info, job=self.job_proto, update_mask=FieldMask(paths=["schedule"])),
800825
"Could not update job schedule",
@@ -814,7 +839,7 @@ def __repr__(self) -> str:
814839
def logs(self, f = sys.stdout):
815840
"""Stream logs of the job, `f` can be anything has a `.write/.flush` methods"""
816841
logger.debug(f"Streaming logs of job '{self.job_proto.id}'")
817-
for job_log in streaming_rpc(
842+
for job_log in mpb.streaming_rpc(
818843
nbox_grpc_stub.GetJobLogs,
819844
JobRequest(auth_info=self.auth_info ,job = self.job_proto),
820845
f"Could not get logs of job {self.job_proto.id}, is your job complete?",
@@ -827,7 +852,7 @@ def logs(self, f = sys.stdout):
827852
def delete(self):
828853
"""Delete this job"""
829854
logger.info(f"Deleting job '{self.job_proto.id}'")
830-
rpc(nbox_grpc_stub.DeleteJob, JobRequest(auth_info=self.auth_info, job = self.job_proto,), "Could not delete job")
855+
mpb.rpc(nbox_grpc_stub.DeleteJob, JobRequest(auth_info=self.auth_info, job = self.job_proto,), "Could not delete job")
831856
logger.info(f"Deleted job '{self.job_proto.id}'")
832857
self.refresh()
833858

@@ -839,7 +864,7 @@ def refresh(self):
839864
if self.id == None:
840865
return
841866

842-
self.job_proto: JobProto = rpc(
867+
self.job_proto: JobProto = mpb.rpc(
843868
nbox_grpc_stub.GetJob,
844869
JobRequest(auth_info=self.auth_info, job = self.job_proto),
845870
f"Could not get job {self.job_proto.id}"
@@ -858,7 +883,7 @@ def trigger(self, tag: str = ""):
858883
logger.debug(f"Triggering job '{self.job_proto.id}'")
859884
if tag:
860885
self.job_proto.feature_gates.update({"SetRunMetadata": tag})
861-
rpc(nbox_grpc_stub.TriggerJob, JobRequest(auth_info=self.auth_info, job = self.job_proto), f"Could not trigger job '{self.job_proto.id}'")
886+
mpb.rpc(nbox_grpc_stub.TriggerJob, JobRequest(auth_info=self.auth_info, job = self.job_proto), f"Could not trigger job '{self.job_proto.id}'")
862887
logger.info(f"Triggered job '{self.job_proto.id}'")
863888
self.refresh()
864889

@@ -870,7 +895,7 @@ def pause(self):
870895
logger.info(f"Pausing job '{self.job_proto.id}'")
871896
job: JobProto = self.job_proto
872897
job.status = JobProto.Status.PAUSED
873-
rpc(nbox_grpc_stub.UpdateJob, UpdateJobRequest(auth_info=self.auth_info, job=job, update_mask=FieldMask(paths=["status", "paused"])), f"Could not pause job {self.job_proto.id}", True)
898+
mpb.rpc(nbox_grpc_stub.UpdateJob, UpdateJobRequest(auth_info=self.auth_info, job=job, update_mask=FieldMask(paths=["status", "paused"])), f"Could not pause job {self.job_proto.id}", True)
874899
logger.debug(f"Paused job '{self.job_proto.id}'")
875900
self.refresh()
876901

@@ -879,7 +904,7 @@ def resume(self):
879904
logger.info(f"Resuming job '{self.job_proto.id}'")
880905
job: JobProto = self.job_proto
881906
job.status = JobProto.Status.SCHEDULED
882-
rpc(nbox_grpc_stub.UpdateJob, UpdateJobRequest(auth_info=self.auth_info, job=job, update_mask=FieldMask(paths=["status", "paused"])), f"Could not resume job {self.job_proto.id}", True)
907+
mpb.rpc(nbox_grpc_stub.UpdateJob, UpdateJobRequest(auth_info=self.auth_info, job=job, update_mask=FieldMask(paths=["status", "paused"])), f"Could not resume job {self.job_proto.id}", True)
883908
logger.debug(f"Resumed job '{self.job_proto.id}'")
884909
self.refresh()
885910

nbox/lib/dist.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,8 @@ def serve(self, **serve_kwargs):
146146
# create the central project class and get the experiment tracker
147147
proj = Project()
148148
logger.info(lo("Project data:", **proj.data))
149-
live_tracker = proj.get_live_tracker()
150-
tracker_config = LiveConfig.from_json(live_tracker.serving.config)
149+
# live_tracker = proj.get_live_tracker()
150+
# tracker_config = LiveConfig.from_json(live_tracker.serving.config)
151151

152152
# now start serving
153153
serve_operator(op_or_app = self.op, **serve_kwargs)

nbox/lmao/common.py

+5-15
Original file line numberDiff line numberDiff line change
@@ -182,21 +182,20 @@ def __init__(
182182
resource: Resource,
183183
cli_comm: str,
184184
enable_system_monitoring: bool = False,
185+
extra_kwargs: Dict[str, Any] = {},
185186
):
186187
self.resource = resource
187188
self.cli_comm = cli_comm
188189
self.enable_system_monitoring = enable_system_monitoring
189-
self.keys = set()
190+
self.extra_kwargs = extra_kwargs
190191

191192
def to_dict(self):
192-
out = {
193+
return {
193194
"resource": mpb.message_to_dict(self.resource),
194195
"cli_comm": self.cli_comm,
195196
"enable_system_monitoring": self.enable_system_monitoring,
197+
"extra_kwargs": self.extra_kwargs,
196198
}
197-
for k in self.keys:
198-
out[k] = getattr(self, k)
199-
return out
200199

201200
def to_json(self):
202201
return json.dumps(self.to_dict())
@@ -206,16 +205,7 @@ def from_json(cls, json_str) -> 'LiveConfig':
206205
d = json.loads(json_str)
207206
d["resource"] = resource_from_dict(d["resource"])
208207
_cls = cls(**d)
209-
for k in d:
210-
if k not in ["resource", "cli_comm", "enable_system_monitoring"]:
211-
_cls.add(k, d[k])
212-
213-
def add(self, key, value):
214-
setattr(self, key, value)
215-
self.keys.add(key)
216-
217-
def get(self, key):
218-
return getattr(self, key)
208+
return _cls
219209

220210

221211
"""

nbox/lmao/exp.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from nbox.lmao.lmao_rpc_client import (
1717
AgentDetails,
1818
RunLog,
19-
Run,
19+
Run as RunProto,
2020
InitRunRequest
2121
)
2222
from nbox.observability.system import SystemMetricsLogger
@@ -111,7 +111,7 @@ def __init__(
111111
raise Exception(f"Project with id {self.project_id} does not exist")
112112

113113
self.config = self._get_config(metadata = metadata)
114-
self.run = self._init_experiment(
114+
self.run: RunProto = self._init_experiment(
115115
project_id = project_id,
116116
config = self.config,
117117
experiment_id = experiment_id
@@ -148,7 +148,7 @@ def _get_config(self, metadata: Dict[str, Any]):
148148
log_config["git"] = get_git_details("./")
149149
return log_config
150150

151-
def _init_experiment(self, project_id, config: Dict[str, Any] = {}, experiment_id: str = ""):
151+
def _init_experiment(self, project_id, config: Dict[str, Any] = {}, experiment_id: str = "") -> RunProto:
152152
# update the server or create new experiment
153153
agent_details = AgentDetails(
154154
workspace_id = self.workspace_id,
@@ -163,7 +163,7 @@ def _init_experiment(self, project_id, config: Dict[str, Any] = {}, experiment_i
163163

164164
if experiment_id:
165165
action = "Updated"
166-
run_details = self.lmao.get_run_details(Run(
166+
run_details = self.lmao.get_run_details(RunProto(
167167
workspace_id = self.workspace_id,
168168
project_id = project_id,
169169
experiment_id = experiment_id,
@@ -173,7 +173,7 @@ def _init_experiment(self, project_id, config: Dict[str, Any] = {}, experiment_i
173173
raise Exception("Server Side exception has occurred, Check the log for details")
174174
if run_details.experiment_id:
175175
# means that this run already exists so we need to make an update call
176-
ack = self.lmao.update_run_status(Run(
176+
ack = self.lmao.update_run_status(RunProto(
177177
workspace_id = self.workspace_id,
178178
project_id = project_id,
179179
experiment_id = run_details.experiment_id,
@@ -204,6 +204,10 @@ def _init_experiment(self, project_id, config: Dict[str, Any] = {}, experiment_i
204204

205205
"""The functions below are the ones supposed to be used."""
206206

207+
@property
208+
def run_config(self) -> Dict[str, Any]:
209+
return loads(self.run.config)
210+
207211
@lru_cache(maxsize=1)
208212
def get_relic(self):
209213
"""Get the underlying Relic for more advanced usage patterns."""
@@ -285,3 +289,4 @@ def add_files(self, *files: List[str]):
285289
relic = self.get_relic()
286290
for f in all_files:
287291
relic.put(f)
292+
return all_files

0 commit comments

Comments
 (0)