Skip to content

Commit

Permalink
added initial connection to predicting with sagemaker
Browse files Browse the repository at this point in the history
  • Loading branch information
LEFTA98 committed Jul 19, 2022
1 parent bfcd903 commit 3644f84
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 2 deletions.
2 changes: 2 additions & 0 deletions eland/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .index import Index
from .ndframe import NDFrame
from .series import Series
from .sagemaker_tools import make_sagemaker_prediction

__all__ = [
"DataFrame",
Expand All @@ -41,4 +42,5 @@
"eland_to_pandas",
"csv_to_eland",
"SortOrder",
"make_sagemaker_prediction"
]
37 changes: 37 additions & 0 deletions eland/sagemaker_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import json

import numpy as np
from eland import DataFrame
from typing import List, Optional

from sagemaker import RealTimePredictor


def make_sagemaker_prediction(endpoint_name: str,
data: DataFrame,
column_order: Optional[List[str]] = None
) -> np.array:
"""
Make a prediction on an eland dataframe using a deployed SageMaker model endpoint.
Parameters
----------
endpoint_name: string representing name of SageMaker endpoint
data: eland DataFrame representing data to feed to SageMaker model. The dataframe must match the input datatypes
of the model and also have the correct number of columns.
column_order: list of string values representing the proper order that the columns should be read into the
SageMaker model. Must be a permutation of the column names of the eland DataFrame.
Returns
----------
np.array representing the output of the model on input data
"""
predictor = RealTimePredictor(endpoint=endpoint_name, content_type='text/csv')

test_data = data
if column_order is not None:
test_data = test_data[column_order]

preds = predictor.predict(test_data.to_csv(header=False, index=False))
preds = np.array(json.loads(preds.decode('utf-8'))['probabilities'])
return preds
2 changes: 2 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ elasticsearch>=8,<9
pandas>=1.2,<2
matplotlib<4
numpy<2
opensearch-py>=2
sagemaker>=1.72,<2
tqdm<5

#
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ elasticsearch>=8,<9
pandas>=1.2,<2
matplotlib<4
numpy<2
opensearch-py>=2
opensearch-py>=2
sagemaker>=1.72,<2
1 change: 0 additions & 1 deletion tests/dataframe/test_es_query_pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def test_es_query_allows_query_in_dict(self):
assert len(left) > 0
assert_eland_frame_equal(left, right)

# @pytest.mark.skip(reason="OpenSearch currently does not support geosearch")
def test_es_query_geo_location(self):
df = self.ed_ecommerce()
cur_nearby = df.es_query(
Expand Down

0 comments on commit 3644f84

Please sign in to comment.