Skip to content

Commit

Permalink
PoC for model upload
Browse files Browse the repository at this point in the history
  • Loading branch information
LEFTA98 committed Aug 12, 2022
1 parent 410ad0e commit 7f3d7dd
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 10 deletions.
2 changes: 1 addition & 1 deletion opensearch_py_ml/ml_commons_integration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
Copyright OpenSearch Contributors
SPDX-License-Identifier: Apache-2.0
"""
from ml_common_client import MLCommonClient
from opensearch_py_ml.ml_commons_integration.ml_common_client import MLCommonClient

__all__ = ["MLCommonClient"]
11 changes: 8 additions & 3 deletions opensearch_py_ml/ml_commons_integration/ml_common_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

from opensearchpy import OpenSearch
from upload.ml_common_upload_client import MLCommonUploadClient, DEFAULT_ML_COMMON_UPLOAD_CHUNK_SIZE
from opensearch_py_ml.ml_commons_integration.upload.ml_common_upload_client import MLCommonUploadClient, DEFAULT_ML_COMMON_UPLOAD_CHUNK_SIZE


class MLCommonClient:
Expand All @@ -17,5 +17,10 @@ def __init__(self, os_client: OpenSearch):
self._client = os_client
self._upload_client = MLCommonUploadClient(os_client)

def put_model(self, model_path: str, chunk_size: int = DEFAULT_ML_COMMON_UPLOAD_CHUNK_SIZE):
self._upload_client.put_model(model_path, chunk_size)
def put_model(
    self,
    model_path: str,
    model_name: str,
    version_number: int,
    chunk_size: int = DEFAULT_ML_COMMON_UPLOAD_CHUNK_SIZE,
    verbose: bool = False,
) -> None:
    """
    Upload a model file by delegating to the wrapped upload client.

    All arguments are forwarded unchanged, in the same order, to
    ``self._upload_client.put_model``.

    :param model_path: path of the model file to upload
    :param model_name: name to upload the model under
    :param version_number: version to upload the model under
    :param chunk_size: size passed through to the upload client for chunking
    :param verbose: passed through to the upload client
    """
    uploader = self._upload_client
    uploader.put_model(model_path, model_name, version_number, chunk_size, verbose)
6 changes: 6 additions & 0 deletions opensearch_py_ml/ml_commons_integration/ml_common_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""
Copyright OpenSearch Contributors
SPDX-License-Identifier: Apache-2.0
"""

ML_BASE_URI = '_plugins/_ml'
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,17 @@
SPDX-License-Identifier: Apache-2.0
"""

DEFAULT_ML_COMMON_UPLOAD_CHUNK_SIZE = 4 * 1024 * 1024

from opensearchpy import OpenSearch
import os
import math
from typing import Iterable
import base64
from opensearch_py_ml.ml_commons_integration.ml_common_utils import ML_BASE_URI

from tqdm.auto import tqdm # type: ignore

DEFAULT_ML_COMMON_UPLOAD_CHUNK_SIZE = 10_000_000 # 10MB


class MLCommonUploadClient:
"""
Expand All @@ -14,5 +22,29 @@ class MLCommonUploadClient:
def __init__(self, os_client: OpenSearch):
self._client = os_client

def put_model(self, model_path: str, chunk_size: int = DEFAULT_ML_COMMON_UPLOAD_CHUNK_SIZE):
pass
def put_model(self,
              model_path: str,
              model_name: str,
              version_number: int,
              chunk_size: int = DEFAULT_ML_COMMON_UPLOAD_CHUNK_SIZE,
              verbose: bool = False) -> None:
    """
    Upload a model file to the cluster in fixed-size chunks.

    The file at ``model_path`` is read ``chunk_size`` bytes at a time and each
    piece is POSTed to
    ``/{ML_BASE_URI}/custom_model/upload/{model_name}/{version_number}/{i}``,
    where ``i`` is the zero-based index of the chunk.

    :param model_path: path of the model file on the local filesystem
    :param model_name: model name interpolated into the upload URL
    :param version_number: model version interpolated into the upload URL
    :param chunk_size: maximum number of bytes sent per request; must be > 0
    :param verbose: when True, show a tqdm progress bar over the chunks
    :raises ValueError: if ``chunk_size`` is not positive
    :raises OSError: if ``model_path`` cannot be stat'ed or opened
    """
    if chunk_size <= 0:
        # read(0) returns b"" straight away (nothing would be uploaded) and a
        # negative size reads the whole file as a single chunk; neither is a
        # meaningful chunked upload, so fail loudly instead.
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    total_model_size = os.stat(model_path).st_size
    # Integer ceiling division: number of requests needed for the whole file.
    total_num_chunks = (total_model_size + chunk_size - 1) // chunk_size

    def model_file_chunk_generator() -> Iterable[bytes]:
        # The file is opened in binary mode, so every chunk is a bytes object.
        with open(model_path, "rb") as f:
            while True:
                data = f.read(chunk_size)
                if not data:
                    break
                yield data  # TODO: check if we actually need to do base64 encoding

    to_iterate_over = enumerate(model_file_chunk_generator())
    if verbose:
        # A generator has no len(), so give tqdm the expected chunk count;
        # without it the bar cannot show a percentage or an ETA.
        to_iterate_over = tqdm(to_iterate_over, total=total_num_chunks)

    for i, chunk in to_iterate_over:
        self._client.transport.perform_request(
            method="POST",
            url=f"/{ML_BASE_URI}/custom_model/upload/{model_name}/{version_number}/{i}",
            body=chunk
        )
4 changes: 2 additions & 2 deletions opensearch_py_ml/sagemaker_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from sagemaker import RealTimePredictor, Session

DEFAULT_UPLOAD_CHUNK_SIZE = 1000
DEFAULT_SAGEMAKER_UPLOAD_CHUNK_SIZE = 1000


def make_sagemaker_prediction(endpoint_name: str,
Expand Down Expand Up @@ -52,7 +52,7 @@ def make_sagemaker_prediction(endpoint_name: str,
if column_order is not None:
data = data[column_order]
if chunksize is None:
chunksize = DEFAULT_UPLOAD_CHUNK_SIZE
chunksize = DEFAULT_SAGEMAKER_UPLOAD_CHUNK_SIZE

indices = [index for index, _ in data.iterrows(sort_index=sort_index)]

Expand Down

0 comments on commit 7f3d7dd

Please sign in to comment.