Skip to content

Commit daffe0d

Browse files
committed
Full support for shared model registry
fix fix fix fix fix fix fix fix fix move endpoint id to backend fix fix fix refactor fix fix remove self fix fix fix fix fix input example name fix fix put newline back fix formatting test add read program refactor fixes
1 parent d3c80b3 commit daffe0d

File tree

10 files changed

+132
-81
lines changed

10 files changed

+132
-81
lines changed

python/hsml/core/model_api.py

+28-7
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def put(self, model_instance, query_params):
4747
)
4848
)
4949

50-
def get(self, name, version):
50+
def get(self, name, version, endpoint_id, shared_registry_project=None):
5151
"""Get the metadata of a model with a certain name and version.
5252
5353
:param name: name of the model
@@ -64,11 +64,23 @@ def get(self, name, version):
6464
"models",
6565
name + "_" + str(version),
6666
]
67-
query_params = {"expand": "trainingdatasets"}
67+
query_params = {"expand": "trainingdatasets", "endpointId": str(endpoint_id)}
68+
6869
model_json = _client._send_request("GET", path_params, query_params)
69-
return model.Model.from_response_json(model_json)
70+
model_meta = model.Model.from_response_json(model_json)
71+
72+
model_meta.shared_registry_project = shared_registry_project
73+
74+
return model_meta
7075

71-
def get_models(self, name, metric=None, direction=None):
76+
def get_models(
77+
self,
78+
name,
79+
endpoint_id,
80+
shared_registry_project=None,
81+
metric=None,
82+
direction=None,
83+
):
7284
"""Get the metadata of models based on the name or optionally the best model given a metric and direction.
7385
7486
:param name: name of the model
@@ -83,8 +95,11 @@ def get_models(self, name, metric=None, direction=None):
8395

8496
_client = client.get_instance()
8597
path_params = ["project", _client._project_id, "models"]
98+
query_params = {
99+
"expand": "trainingdatasets",
100+
"filter_by": ["name_eq:" + name, "endpoint_id:" + str(endpoint_id)],
101+
}
86102

87-
query_params = {"expand": "trainingdatasets", "filter_by": "name_eq:" + name}
88103
if metric is not None and direction is not None:
89104
if direction.lower() == "max":
90105
direction = "desc"
@@ -95,7 +110,12 @@ def get_models(self, name, metric=None, direction=None):
95110
query_params["limit"] = "1"
96111

97112
model_json = _client._send_request("GET", path_params, query_params)
98-
return model.Model.from_response_json(model_json)
113+
models_meta = model.Model.from_response_json(model_json)
114+
115+
for model_meta in models_meta:
116+
model_meta.shared_registry_project = shared_registry_project
117+
118+
return models_meta
99119

100120
def delete(self, model_instance):
101121
"""Delete the model and metadata.
@@ -104,5 +124,6 @@ def delete(self, model_instance):
104124
:type model_instance: Model
105125
"""
106126
_client = client.get_instance()
127+
query_params = {"endpointId": str(model_instance.endpoint_id)}
107128
path_params = ["project", _client._project_id, "models", model_instance.id]
108-
_client._send_request("DELETE", path_params)
129+
_client._send_request("DELETE", path_params, query_params)

python/hsml/core/model_registry_api.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ def get(self, project=None):
5050
)
5151
)
5252

53+
endpoint_id = _client._project_id
54+
if project is not None:
55+
project_info = _client._get_project_info(project)
56+
endpoint_id = str(project_info["projectId"])
57+
5358
return ModelRegistry(
54-
_client._project_name, _client._project_id, shared_registry_project=project
59+
_client._project_name,
60+
_client._project_id,
61+
endpoint_id,
62+
shared_registry_project=project,
5563
)

python/hsml/engine/local_engine.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,4 @@ def mkdir(self, model_instance):
2525
self._dataset_api.mkdir(model_instance.path)
2626

2727
def delete(self, model_instance):
28-
self._dataset_api.delete(model_instance.path)
28+
self._dataset_api.rm(model_instance.path)

python/hsml/engine/model_engine.py

+35-58
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,26 @@ def __init__(self):
4343
else:
4444
self._engine = hopsworks_engine.HopsworksEngine()
4545

46-
def _poll_model_available(self, model_instance, await_registration):
46+
def _poll_model_available(self, _client, model_instance, await_registration):
4747
if await_registration > 0:
48+
endpoint_id = _client._project_id
49+
if model_instance.shared_registry_project is not None:
50+
project_info = _client._get_project_info(
51+
model_instance.shared_registry_project
52+
)
53+
endpoint_id = str(project_info["projectId"])
4854
sleep_seconds = 5
4955
for i in range(int(await_registration / sleep_seconds)):
5056
try:
5157
time.sleep(sleep_seconds)
52-
model = self._model_api.get(
53-
name=model_instance.name, version=model_instance.version
58+
model_meta = self._model_api.get(
59+
model_instance.name,
60+
model_instance.version,
61+
endpoint_id,
62+
model_instance.shared_registry_project,
5463
)
55-
if model is not None:
56-
return model
64+
if model_meta is not None:
65+
return model_meta
5766
except RestAPIError as e:
5867
if e.response.status_code != 404:
5968
raise e
@@ -88,7 +97,7 @@ def _upload_additional_resources(self, model_instance, dataset_model_version_pat
8897
model_instance.signature = dataset_model_version_path + "/signature.json"
8998
return model_instance
9099

91-
def _copy_hopsfs_model(self, model_path, dataset_model_version_path, client):
100+
def _copy_hopsfs_model(self, model_path, dataset_model_version_path):
92101
# Strip hdfs prefix
93102
if model_path.startswith("hdfs:/"):
94103
projects_index = model_path.find("/Projects", 0)
@@ -168,18 +177,9 @@ def _set_model_version(
168177
)
169178
return model_instance
170179

171-
def _build_registry_path(self, model_instance, artifact_path):
172-
models_path = None
173-
if model_instance.shared_registry_project is not None:
174-
models_path = "{}::{}".format(
175-
model_instance.shared_registry_project,
176-
constants.MODEL_SERVING.MODELS_DATASET,
177-
)
178-
else:
179-
models_path = constants.MODEL_SERVING.MODELS_DATASET
180-
return artifact_path.replace(
181-
constants.MODEL_SERVING.MODELS_DATASET, models_path
182-
)
180+
def _build_artifact_path(self, model_instance, artifact):
181+
artifact_path = model_instance.path + "/" + artifact
182+
return artifact_path
183183

184184
def save(self, model_instance, model_path, await_registration=480):
185185

@@ -287,9 +287,7 @@ def save(self, model_instance, model_path, await_registration=480):
287287
elif self._dataset_api.path_exists(
288288
model_path
289289
): # check hdfs relative and absolute
290-
self._copy_hopsfs_model(
291-
model_path, dataset_model_version_path, _client
292-
)
290+
self._copy_hopsfs_model(model_path, dataset_model_version_path)
293291
else:
294292
raise IOError(
295293
"Could not find path {} in the local filesystem or in HopsFS".format(
@@ -299,11 +297,9 @@ def save(self, model_instance, model_path, await_registration=480):
299297
if step["id"] == 3:
300298
self._model_api.put(model_instance, model_query_params)
301299
if step["id"] == 4:
302-
# We do not necessarily have access to the Models REST API for the shared model registry, so we do not know if it is registered or not
303-
if not is_shared_registry:
304-
model_instance = self._poll_model_available(
305-
model_instance, await_registration
306-
)
300+
model_instance = self._poll_model_available(
301+
_client, model_instance, await_registration
302+
)
307303
if step["id"] == 5:
308304
pass
309305
except BaseException as be:
@@ -326,12 +322,7 @@ def download(self, model_instance):
326322
zip_path = model_version_path + ".zip"
327323
os.makedirs(model_name_path)
328324

329-
dataset_model_name_path = (
330-
constants.MODEL_SERVING.MODELS_DATASET + "/" + model_instance._name
331-
)
332-
dataset_model_version_path = (
333-
dataset_model_name_path + "/" + str(model_instance._version)
334-
)
325+
dataset_model_version_path = model_instance.path
335326

336327
temp_download_dir = "/Resources" + "/" + str(uuid.uuid4())
337328
try:
@@ -340,7 +331,7 @@ def download(self, model_instance):
340331
dataset_model_version_path,
341332
destination_path=temp_download_dir,
342333
block=True,
343-
timeout=480,
334+
timeout=600,
344335
)
345336
self._dataset_api.download(
346337
temp_download_dir + "/" + str(model_instance._version) + ".zip",
@@ -358,44 +349,30 @@ def download(self, model_instance):
358349

359350
return model_version_path
360351

361-
def read_input_example(self, model_instance):
362-
try:
363-
tmp_dir = tempfile.TemporaryDirectory(dir=os.getcwd())
364-
self._dataset_api.download(
365-
self._build_registry_path(
366-
model_instance, model_instance._input_example
367-
),
368-
tmp_dir.name + "/inputs.json",
369-
)
370-
with open(tmp_dir.name + "/inputs.json", "rb") as f:
371-
return json.loads(f.read())
372-
finally:
373-
if tmp_dir is not None and os.path.exists(tmp_dir.name):
374-
tmp_dir.cleanup()
375-
376-
def read_environment(self, model_instance):
352+
def read_file(self, model_instance, artifact):
377353
try:
354+
artifact = os.path.basename(artifact)
378355
tmp_dir = tempfile.TemporaryDirectory(dir=os.getcwd())
356+
local_artifact_path = os.path.join(tmp_dir.name, artifact)
379357
self._dataset_api.download(
380-
self._build_registry_path(
381-
model_instance, model_instance._environment[0]
382-
),
383-
tmp_dir.name + "/environment.yml",
358+
self._build_artifact_path(model_instance, artifact),
359+
local_artifact_path,
384360
)
385-
with open(tmp_dir.name + "/environment.yml", "r") as f:
361+
with open(local_artifact_path, "r") as f:
386362
return f.read()
387363
finally:
388364
if tmp_dir is not None and os.path.exists(tmp_dir.name):
389365
tmp_dir.cleanup()
390366

391-
def read_signature(self, model_instance):
367+
def read_json(self, model_instance, artifact):
392368
try:
393369
tmp_dir = tempfile.TemporaryDirectory(dir=os.getcwd())
370+
local_artifact_path = os.path.join(tmp_dir.name, artifact)
394371
self._dataset_api.download(
395-
self._build_registry_path(model_instance, model_instance._signature),
396-
tmp_dir.name + "/signature.json",
372+
self._build_artifact_path(model_instance, artifact),
373+
os.path.join(tmp_dir.name, artifact),
397374
)
398-
with open(tmp_dir.name + "/signature.json", "rb") as f:
375+
with open(local_artifact_path, "rb") as f:
399376
return json.loads(f.read())
400377
finally:
401378
if tmp_dir is not None and os.path.exists(tmp_dir.name):

python/hsml/model.py

+24-10
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,10 @@ def __init__(
4545
training_dataset=None,
4646
input_example=None,
4747
framework=None,
48+
endpoint_id=None,
4849
):
4950

50-
if id is None:
51-
self._id = name + "_" + str(version)
52-
else:
53-
self._id = id
54-
51+
self._id = id
5552
self._name = name
5653
self._version = version
5754

@@ -73,7 +70,11 @@ def __init__(
7370
self._signature = signature
7471
self._training_dataset = training_dataset
7572

76-
self._shared_registry_project = None
73+
# Needed so that update_from_response_json does not overwrite the name of the shared registry project this model originates from
74+
if not hasattr(self, "_shared_registry_project"):
75+
self._shared_registry_project = None
76+
77+
self._endpoint_id = endpoint_id
7778

7879
self._model_api = model_api.ModelApi()
7980
self._dataset_api = dataset_api.DatasetApi()
@@ -186,8 +187,8 @@ def created(self, created):
186187
@property
187188
def environment(self):
188189
"""Environment of the model."""
189-
if self._environment is not None and isinstance(self._environment, list):
190-
self._environment = self._model_engine.read_environment(self)
190+
if self._environment is not None:
191+
return self._model_engine.read_file(self, "environment.yml")
191192
return self._environment
192193

193194
@environment.setter
@@ -215,6 +216,8 @@ def training_metrics(self, training_metrics):
215216
@property
216217
def program(self):
217218
"""Executable used to export the model."""
219+
if self._program is not None:
220+
return self._model_engine.read_file(self, self._program)
218221
return self._program
219222

220223
@program.setter
@@ -234,7 +237,9 @@ def user(self, user_full_name):
234237
def input_example(self):
235238
"""input_example of the model."""
236239
if self._input_example is not None and isinstance(self._input_example, str):
237-
self._input_example = self._model_engine.read_input_example(self)
240+
self._input_example = self._model_engine.read_json(
241+
self, "input_example.json"
242+
)
238243
return self._input_example
239244

240245
@input_example.setter
@@ -254,7 +259,7 @@ def framework(self, framework):
254259
def signature(self):
255260
"""signature of the model."""
256261
if self._signature is not None and isinstance(self._signature, str):
257-
self._signature = self._model_engine.read_signature(self)
262+
self._signature = self._model_engine.read_json(self, "signature.json")
258263
return self._signature
259264

260265
@signature.setter
@@ -289,6 +294,15 @@ def project_name(self):
289294
def project_name(self, project_name):
290295
self._project_name = project_name
291296

297+
@property
298+
def endpoint_id(self):
299+
"""endpoint_id of the model."""
300+
return self._endpoint_id
301+
302+
@endpoint_id.setter
303+
def endpoint_id(self, endpoint_id):
304+
self._endpoint_id = endpoint_id
305+
292306
@property
293307
def experiment_project_name(self):
294308
"""experiment_project_name of the model."""

0 commit comments

Comments
 (0)