Skip to content

Commit

Permalink
added first iteration of custom model load supprot
Browse files Browse the repository at this point in the history
  • Loading branch information
LEFTA98 committed Sep 6, 2022
1 parent 18ebcaa commit 2d60689
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""
Copyright OpenSearch Contributors
SPDX-License-Identifier: Apache-2.0
"""

from opensearchpy import OpenSearch
from opensearch_py_ml.ml_commons_integration.ml_common_utils import ML_BASE_URI


class MLCommonLoadClient:
"""
Client for performing model upload tasks to ml-commons plugin for OpenSearch.
"""
def __init__(self, os_client: OpenSearch):
self._client = os_client

def load_model(self, model_name : str, version_num : int):
"""
Load a model with name model_name and version number version_num.
"""
return self._client.transport.perform_request(
method="POST",
url=f"{ML_BASE_URI}/custom_model/load",
body={
"name": f"\"{model_name}\"",
"version": version_num
}
)
6 changes: 5 additions & 1 deletion opensearch_py_ml/ml_commons_integration/ml_common_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from opensearchpy import OpenSearch
from opensearch_py_ml.ml_commons_integration.upload.ml_common_upload_client import MLCommonUploadClient, DEFAULT_ML_COMMON_UPLOAD_CHUNK_SIZE

from opensearch_py_ml.ml_commons_integration.load.ml_common_load_client import MLCommonLoadClient

class MLCommonClient:
"""
Expand All @@ -16,6 +16,7 @@ class MLCommonClient:
def __init__(self, os_client: OpenSearch):
self._client = os_client
self._upload_client = MLCommonUploadClient(os_client)
self._load_client = MLCommonLoadClient(os_client)

def put_model(self,
model_path: str,
Expand All @@ -24,3 +25,6 @@ def put_model(self,
chunk_size: int = DEFAULT_ML_COMMON_UPLOAD_CHUNK_SIZE,
verbose: bool = False) -> None:
self._upload_client.put_model(model_path, model_name, version_number, chunk_size, verbose)

def load_model(self, model_name: str, version_number: int):
return self._load_client.load_model(model_name, version_number)
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""
Copyright OpenSearch Contributors
SPDX-License-Identifier: Apache-2.0
"""

0 comments on commit 2d60689

Please sign in to comment.