@@ -43,17 +43,26 @@ def __init__(self):
43
43
else :
44
44
self ._engine = hopsworks_engine .HopsworksEngine ()
45
45
46
- def _poll_model_available (self , model_instance , await_registration ):
46
+ def _poll_model_available (self , _client , model_instance , await_registration ):
47
47
if await_registration > 0 :
48
+ endpoint_id = _client ._project_id
49
+ if model_instance .shared_registry_project is not None :
50
+ project_info = _client ._get_project_info (
51
+ model_instance .shared_registry_project
52
+ )
53
+ endpoint_id = str (project_info ["projectId" ])
48
54
sleep_seconds = 5
49
55
for i in range (int (await_registration / sleep_seconds )):
50
56
try :
51
57
time .sleep (sleep_seconds )
52
- model = self ._model_api .get (
53
- name = model_instance .name , version = model_instance .version
58
+ model_meta = self ._model_api .get (
59
+ model_instance .name ,
60
+ model_instance .version ,
61
+ endpoint_id ,
62
+ model_instance .shared_registry_project ,
54
63
)
55
- if model is not None :
56
- return model
64
+ if model_meta is not None :
65
+ return model_meta
57
66
except RestAPIError as e :
58
67
if e .response .status_code != 404 :
59
68
raise e
@@ -88,7 +97,7 @@ def _upload_additional_resources(self, model_instance, dataset_model_version_pat
88
97
model_instance .signature = dataset_model_version_path + "/signature.json"
89
98
return model_instance
90
99
91
- def _copy_hopsfs_model (self , model_path , dataset_model_version_path , client ):
100
+ def _copy_hopsfs_model (self , model_path , dataset_model_version_path ):
92
101
# Strip hdfs prefix
93
102
if model_path .startswith ("hdfs:/" ):
94
103
projects_index = model_path .find ("/Projects" , 0 )
@@ -168,18 +177,9 @@ def _set_model_version(
168
177
)
169
178
return model_instance
170
179
171
- def _build_registry_path (self , model_instance , artifact_path ):
172
- models_path = None
173
- if model_instance .shared_registry_project is not None :
174
- models_path = "{}::{}" .format (
175
- model_instance .shared_registry_project ,
176
- constants .MODEL_SERVING .MODELS_DATASET ,
177
- )
178
- else :
179
- models_path = constants .MODEL_SERVING .MODELS_DATASET
180
- return artifact_path .replace (
181
- constants .MODEL_SERVING .MODELS_DATASET , models_path
182
- )
180
+ def _build_artifact_path (self , model_instance , artifact ):
181
+ artifact_path = model_instance .path + "/" + artifact
182
+ return artifact_path
183
183
184
184
def save (self , model_instance , model_path , await_registration = 480 ):
185
185
@@ -287,9 +287,7 @@ def save(self, model_instance, model_path, await_registration=480):
287
287
elif self ._dataset_api .path_exists (
288
288
model_path
289
289
): # check hdfs relative and absolute
290
- self ._copy_hopsfs_model (
291
- model_path , dataset_model_version_path , _client
292
- )
290
+ self ._copy_hopsfs_model (model_path , dataset_model_version_path )
293
291
else :
294
292
raise IOError (
295
293
"Could not find path {} in the local filesystem or in HopsFS" .format (
@@ -299,11 +297,9 @@ def save(self, model_instance, model_path, await_registration=480):
299
297
if step ["id" ] == 3 :
300
298
self ._model_api .put (model_instance , model_query_params )
301
299
if step ["id" ] == 4 :
302
- # We do not necessarily have access to the Models REST API for the shared model registry, so we do not know if it is registered or not
303
- if not is_shared_registry :
304
- model_instance = self ._poll_model_available (
305
- model_instance , await_registration
306
- )
300
+ model_instance = self ._poll_model_available (
301
+ _client , model_instance , await_registration
302
+ )
307
303
if step ["id" ] == 5 :
308
304
pass
309
305
except BaseException as be :
@@ -326,12 +322,7 @@ def download(self, model_instance):
326
322
zip_path = model_version_path + ".zip"
327
323
os .makedirs (model_name_path )
328
324
329
- dataset_model_name_path = (
330
- constants .MODEL_SERVING .MODELS_DATASET + "/" + model_instance ._name
331
- )
332
- dataset_model_version_path = (
333
- dataset_model_name_path + "/" + str (model_instance ._version )
334
- )
325
+ dataset_model_version_path = model_instance .path
335
326
336
327
temp_download_dir = "/Resources" + "/" + str (uuid .uuid4 ())
337
328
try :
@@ -340,7 +331,7 @@ def download(self, model_instance):
340
331
dataset_model_version_path ,
341
332
destination_path = temp_download_dir ,
342
333
block = True ,
343
- timeout = 480 ,
334
+ timeout = 600 ,
344
335
)
345
336
self ._dataset_api .download (
346
337
temp_download_dir + "/" + str (model_instance ._version ) + ".zip" ,
@@ -358,44 +349,30 @@ def download(self, model_instance):
358
349
359
350
return model_version_path
360
351
361
- def read_input_example (self , model_instance ):
362
- try :
363
- tmp_dir = tempfile .TemporaryDirectory (dir = os .getcwd ())
364
- self ._dataset_api .download (
365
- self ._build_registry_path (
366
- model_instance , model_instance ._input_example
367
- ),
368
- tmp_dir .name + "/inputs.json" ,
369
- )
370
- with open (tmp_dir .name + "/inputs.json" , "rb" ) as f :
371
- return json .loads (f .read ())
372
- finally :
373
- if tmp_dir is not None and os .path .exists (tmp_dir .name ):
374
- tmp_dir .cleanup ()
375
-
376
- def read_environment (self , model_instance ):
352
+ def read_file (self , model_instance , artifact ):
377
353
try :
354
+ artifact = os .path .basename (artifact )
378
355
tmp_dir = tempfile .TemporaryDirectory (dir = os .getcwd ())
356
+ local_artifact_path = os .path .join (tmp_dir .name , artifact )
379
357
self ._dataset_api .download (
380
- self ._build_registry_path (
381
- model_instance , model_instance ._environment [0 ]
382
- ),
383
- tmp_dir .name + "/environment.yml" ,
358
+ self ._build_artifact_path (model_instance , artifact ),
359
+ local_artifact_path ,
384
360
)
385
- with open (tmp_dir . name + "/environment.yml" , "r" ) as f :
361
+ with open (local_artifact_path , "r" ) as f :
386
362
return f .read ()
387
363
finally :
388
364
if tmp_dir is not None and os .path .exists (tmp_dir .name ):
389
365
tmp_dir .cleanup ()
390
366
391
- def read_signature (self , model_instance ):
367
+ def read_json (self , model_instance , artifact ):
392
368
try :
393
369
tmp_dir = tempfile .TemporaryDirectory (dir = os .getcwd ())
370
+ local_artifact_path = os .path .join (tmp_dir .name , artifact )
394
371
self ._dataset_api .download (
395
- self ._build_registry_path (model_instance , model_instance . _signature ),
396
- tmp_dir .name + "/signature.json" ,
372
+ self ._build_artifact_path (model_instance , artifact ),
373
+ os . path . join ( tmp_dir .name , artifact ) ,
397
374
)
398
- with open (tmp_dir . name + "/signature.json" , "rb" ) as f :
375
+ with open (local_artifact_path , "rb" ) as f :
399
376
return json .loads (f .read ())
400
377
finally :
401
378
if tmp_dir is not None and os .path .exists (tmp_dir .name ):
0 commit comments