17
17
from nbox .auth import secret , AuthConfig , auth_info_pb
18
18
from nbox .utils import logger
19
19
from nbox .version import __version__
20
- from nbox .messages import rpc , streaming_rpc
20
+ from nbox import messages as mpb
21
+ # from nbox.messages import rpc, streaming_rpc
21
22
from nbox .init import nbox_grpc_stub , nbox_ws_v1 , nbox_serving_service_stub , nbox_model_service_stub
22
23
from nbox .nbxlib .astea import Astea , IndexTypes as IT
23
24
@@ -165,7 +166,11 @@ def upload_job_folder(
165
166
init_folder : str ,
166
167
id : str = "" ,
167
168
project_id : str = "" ,
169
+
170
+ # job / deploy rpc things
168
171
trigger : bool = False ,
172
+ deploy : bool = True ,
173
+ pin : bool = False ,
169
174
170
175
# all the things for resources
171
176
resource_cpu : str = "" ,
@@ -228,10 +233,10 @@ def upload_job_folder(
228
233
229
234
if method not in OT ._valid_deployment_types ():
230
235
raise ValueError (f"Invalid method: { method } , should be either { OT ._valid_deployment_types ()} " )
231
- # if (not name and not id) or (name and id) :
232
- # raise ValueError("Either --name or --id must be present ")
233
- if trigger and method not in [ OT .JOB , OT . SERVING ] :
234
- raise ValueError (f"Trigger can only be used with ' { OT . JOB } ' or '{ OT .SERVING } '" )
236
+ if trigger and method != OT . JOB :
237
+ raise ValueError (f"Trigger can only be used with method=' { OT . JOB } ' " )
238
+ if pin and method != OT .SERVING :
239
+ raise ValueError (f"Deploy and Pin can only be used with method= '{ OT .SERVING } '" )
235
240
if model_name and method != OT .SERVING :
236
241
raise ValueError (f"model_name can only be used with '{ OT .SERVING } '" )
237
242
@@ -410,6 +415,8 @@ def __common_resource(db: Resource) -> Resource:
410
415
},
411
416
exe_jinja_kwargs = exe_jinja_kwargs ,
412
417
)
418
+ if deploy :
419
+ out .deploy ()
413
420
if trigger :
414
421
out .pin ()
415
422
else :
@@ -450,7 +457,7 @@ def _get_deployment_data(name: str = "", id: str = "", *, workspace_id: str = ""
450
457
workspace_id = workspace_id or secret (AuthConfig .workspace_id )
451
458
452
459
# get the deployment
453
- serving : Serving = rpc (
460
+ serving : Serving = mpb . rpc (
454
461
nbox_serving_service_stub .GetServing ,
455
462
ServingRequest (
456
463
serving = Serving (name = name , id = id ),
@@ -473,7 +480,7 @@ def _get_time(t):
473
480
return datetime .fromtimestamp (int (float (t ))).strftime ("%Y-%m-%d %H:%M:%S" )
474
481
475
482
workspace_id = workspace_id or secret (AuthConfig .workspace_id )
476
- all_deployments : ServingListResponse = rpc (
483
+ all_deployments : ServingListResponse = mpb . rpc (
477
484
nbox_serving_service_stub .ListServings ,
478
485
ServingListRequest (
479
486
auth_info = auth_info_pb (),
@@ -522,6 +529,13 @@ def __init__(self, serving_id: str = "", model_id: str = "", *, workspace_id: st
522
529
self .serving_name = serving_name
523
530
self .ws_stub = nbox_ws_v1 .deployments
524
531
532
+ def __repr__ (self ) -> str :
533
+ x = f"nbox.Serve('{ self .id } ', '{ self .workspace_id } '"
534
+ if self .model_id is not None :
535
+ x += f", model_id = '{ self .model_id } '"
536
+ x += ")"
537
+ return x
538
+
525
539
def pin (self ) -> bool :
526
540
"""Pin a model to the deployment
527
541
@@ -530,7 +544,7 @@ def pin(self) -> bool:
530
544
workspace_id (str, optional): Workspace ID. Defaults to "".
531
545
"""
532
546
logger .info (f"Pin model { self .model_id } to deployment { self .serving_id } " )
533
- rpc (
547
+ mpb . rpc (
534
548
nbox_model_service_stub .SetModelPin ,
535
549
ModelRequest (
536
550
model = ModelProto (
@@ -552,7 +566,7 @@ def unpin(self) -> bool:
552
566
workspace_id (str, optional): Workspace ID. Defaults to "".
553
567
"""
554
568
logger .info (f"Unpin model { self .model_id } to deployment { self .serving_id } " )
555
- rpc (
569
+ mpb . rpc (
556
570
nbox_model_service_stub .SetModelPin ,
557
571
ModelRequest (
558
572
model = ModelProto (
@@ -578,7 +592,7 @@ def scale(self, replicas: int) -> bool:
578
592
raise ValueError ("Replicas must be greater than or equal to 0" )
579
593
580
594
logger .info (f"Scale model deployment { self .model_id } to { replicas } replicas" )
581
- rpc (
595
+ mpb . rpc (
582
596
nbox_model_service_stub .UpdateModel ,
583
597
UpdateModelRequest (
584
598
model = ModelProto (
@@ -600,7 +614,7 @@ def logs(self, f = sys.stdout):
600
614
f (file, optional): File to write the logs to. Defaults to sys.stdout.
601
615
"""
602
616
logger .debug (f"Streaming logs of model '{ self .model_id } '" )
603
- for model_log in streaming_rpc (
617
+ for model_log in mpb . streaming_rpc (
604
618
nbox_model_service_stub .ModelLogs ,
605
619
ModelRequest (
606
620
model = ModelProto (
@@ -616,13 +630,24 @@ def logs(self, f = sys.stdout):
616
630
f .write (log )
617
631
f .flush ()
618
632
619
- def __repr__ (self ) -> str :
620
- x = f"nbox.Serve('{ self .id } ', '{ self .workspace_id } '"
621
- if self .model_id is not None :
622
- x += f", model_id = '{ self .model_id } '"
623
- x += ")"
624
- return x
625
-
633
+ def deploy (self , tag : str = "" ):
634
+ model = ModelProto (
635
+ id = self .model_id ,
636
+ serving_group_id = self .serving_id ,
637
+ )
638
+ if tag :
639
+ model .feature_gates .update ({
640
+ "SetModelMetadata" : tag
641
+ })
642
+ response : ModelProto = mpb .rpc (
643
+ nbox_model_service_stub .Deploy ,
644
+ ModelRequest (
645
+ model = model ,
646
+ auth_info = auth_info_pb (),
647
+ ),
648
+ "Could not deploy model" ,
649
+ raise_on_error = True
650
+ )
626
651
627
652
628
653
################################################################################
@@ -655,7 +680,7 @@ def _get_job_data(name: str = "", id: str = "", remove_archived: bool = True, *,
655
680
if workspace_id == None :
656
681
workspace_id = "personal"
657
682
658
- job : JobProto = rpc (
683
+ job : JobProto = mpb . rpc (
659
684
nbox_grpc_stub .GetJob ,
660
685
JobRequest (
661
686
auth_info = auth_info_pb (),
@@ -683,7 +708,7 @@ def get_job_list(sort: str = "name", *, workspace_id: str = ""):
683
708
def _get_time (t ):
684
709
return datetime .fromtimestamp (int (float (t ))).strftime ("%Y-%m-%d %H:%M:%S" )
685
710
686
- out : ListJobsResponse = rpc (
711
+ out : ListJobsResponse = mpb . rpc (
687
712
nbox_grpc_stub .ListJobs ,
688
713
ListJobsRequest (auth_info = auth_info_pb ()),
689
714
"Could not get job list" ,
@@ -794,7 +819,7 @@ def change_schedule(self, new_schedule: Schedule):
794
819
"""
795
820
logger .debug (f"Updating job '{ self .job_proto .id } '" )
796
821
self .job_proto .schedule .MergeFrom (new_schedule .get_message ())
797
- rpc (
822
+ mpb . rpc (
798
823
nbox_grpc_stub .UpdateJob ,
799
824
UpdateJobRequest (auth_info = self .auth_info , job = self .job_proto , update_mask = FieldMask (paths = ["schedule" ])),
800
825
"Could not update job schedule" ,
@@ -814,7 +839,7 @@ def __repr__(self) -> str:
814
839
def logs (self , f = sys .stdout ):
815
840
"""Stream logs of the job, `f` can be anything has a `.write/.flush` methods"""
816
841
logger .debug (f"Streaming logs of job '{ self .job_proto .id } '" )
817
- for job_log in streaming_rpc (
842
+ for job_log in mpb . streaming_rpc (
818
843
nbox_grpc_stub .GetJobLogs ,
819
844
JobRequest (auth_info = self .auth_info ,job = self .job_proto ),
820
845
f"Could not get logs of job { self .job_proto .id } , is your job complete?" ,
@@ -827,7 +852,7 @@ def logs(self, f = sys.stdout):
827
852
def delete (self ):
828
853
"""Delete this job"""
829
854
logger .info (f"Deleting job '{ self .job_proto .id } '" )
830
- rpc (nbox_grpc_stub .DeleteJob , JobRequest (auth_info = self .auth_info , job = self .job_proto ,), "Could not delete job" )
855
+ mpb . rpc (nbox_grpc_stub .DeleteJob , JobRequest (auth_info = self .auth_info , job = self .job_proto ,), "Could not delete job" )
831
856
logger .info (f"Deleted job '{ self .job_proto .id } '" )
832
857
self .refresh ()
833
858
@@ -839,7 +864,7 @@ def refresh(self):
839
864
if self .id == None :
840
865
return
841
866
842
- self .job_proto : JobProto = rpc (
867
+ self .job_proto : JobProto = mpb . rpc (
843
868
nbox_grpc_stub .GetJob ,
844
869
JobRequest (auth_info = self .auth_info , job = self .job_proto ),
845
870
f"Could not get job { self .job_proto .id } "
@@ -858,7 +883,7 @@ def trigger(self, tag: str = ""):
858
883
logger .debug (f"Triggering job '{ self .job_proto .id } '" )
859
884
if tag :
860
885
self .job_proto .feature_gates .update ({"SetRunMetadata" : tag })
861
- rpc (nbox_grpc_stub .TriggerJob , JobRequest (auth_info = self .auth_info , job = self .job_proto ), f"Could not trigger job '{ self .job_proto .id } '" )
886
+ mpb . rpc (nbox_grpc_stub .TriggerJob , JobRequest (auth_info = self .auth_info , job = self .job_proto ), f"Could not trigger job '{ self .job_proto .id } '" )
862
887
logger .info (f"Triggered job '{ self .job_proto .id } '" )
863
888
self .refresh ()
864
889
@@ -870,7 +895,7 @@ def pause(self):
870
895
logger .info (f"Pausing job '{ self .job_proto .id } '" )
871
896
job : JobProto = self .job_proto
872
897
job .status = JobProto .Status .PAUSED
873
- rpc (nbox_grpc_stub .UpdateJob , UpdateJobRequest (auth_info = self .auth_info , job = job , update_mask = FieldMask (paths = ["status" , "paused" ])), f"Could not pause job { self .job_proto .id } " , True )
898
+ mpb . rpc (nbox_grpc_stub .UpdateJob , UpdateJobRequest (auth_info = self .auth_info , job = job , update_mask = FieldMask (paths = ["status" , "paused" ])), f"Could not pause job { self .job_proto .id } " , True )
874
899
logger .debug (f"Paused job '{ self .job_proto .id } '" )
875
900
self .refresh ()
876
901
@@ -879,7 +904,7 @@ def resume(self):
879
904
logger .info (f"Resuming job '{ self .job_proto .id } '" )
880
905
job : JobProto = self .job_proto
881
906
job .status = JobProto .Status .SCHEDULED
882
- rpc (nbox_grpc_stub .UpdateJob , UpdateJobRequest (auth_info = self .auth_info , job = job , update_mask = FieldMask (paths = ["status" , "paused" ])), f"Could not resume job { self .job_proto .id } " , True )
907
+ mpb . rpc (nbox_grpc_stub .UpdateJob , UpdateJobRequest (auth_info = self .auth_info , job = job , update_mask = FieldMask (paths = ["status" , "paused" ])), f"Could not resume job { self .job_proto .id } " , True )
883
908
logger .debug (f"Resumed job '{ self .job_proto .id } '" )
884
909
self .refresh ()
885
910
0 commit comments