Commit
reverted indexing change for sagemaker predict
LEFTA98 committed Jul 28, 2022
1 parent 5f2d9aa commit 5679b13
Showing 3 changed files with 16 additions and 14 deletions.
eland/operations.py (10 changes: 4 additions & 6 deletions)
@@ -1279,14 +1279,14 @@ def search_yield_pandas_dataframes(
         if sort_params:
             body["sort"] = [sort_params]
 
-        i = 1
+        # i = 1
         for hits in _search_yield_hits(
             query_compiler=query_compiler, body=body, max_number_of_hits=result_size, sort_index=sort_index
         ):
             df = query_compiler._es_results_to_pandas(hits)
             df = self._apply_df_post_processing(df, post_processing)
-            df.to_csv(f'debug_{i}.csv')
-            i += 1
+            # df.to_csv(f'debug_{i}.csv')
+            # i += 1
             yield df
 
     def index_count(self, query_compiler: "QueryCompiler", field: str) -> int:
@@ -1591,6 +1591,4 @@ def get_os_client(cluster_url='https://localhost:9200',
     ed_df = ed.DataFrame(client, 'sagemaker_demo_data')
 
     indices = [index for index, _ in ed_df.iterrows('_doc')]
-    print(len(set(indices)))
-
-    pass
+    print(len(set(indices)))
eland/sagemaker_tools.py (17 changes: 11 additions & 6 deletions)
@@ -5,16 +5,18 @@
 from typing import List, Optional
 from math import ceil
 
-from sagemaker import RealTimePredictor
+from sagemaker import RealTimePredictor, Session
 
 DEFAULT_UPLOAD_CHUNK_SIZE = 1000
 
 
 def make_sagemaker_prediction(endpoint_name: str,
                               data: DataFrame,
                               target_column: str,
+                              sagemaker_session: Optional[Session] = None,
                               column_order: Optional[List[str]] = None,
-                              chunksize: int = None
+                              chunksize: int = None,
+                              sort_index: Optional[str] = '_doc'
                               )-> np.array:
     """
     Make a prediction on an eland dataframe using a deployed SageMaker model endpoint.
@@ -28,31 +28,34 @@ def make_sagemaker_prediction(endpoint_name: str,
     data: eland DataFrame representing data to feed to SageMaker model. The dataframe must match the input datatypes
         of the model and also have the correct number of columns.
     target_column: column name of the dependent variable in the data.
+    sagemaker_session: A SageMaker Session object, used for SageMaker interactions (default: None). If not specified,
+        one is created using the default AWS configuration chain.
     column_order: list of string values representing the proper order that the columns of independent variables should
         be read into the SageMaker model. Must be a permutation of the column names of the eland DataFrame.
     chunksize: how large each chunk being uploaded to sagemaker should be.
+    sort_index: the index with which to sort the predictions by. Defaults to '_doc', an internal identifier for
+        Lucene that optimizes performance.
     Returns
     ----------
     np.array representing the output of the model on input data
     """
-    predictor = RealTimePredictor(endpoint=endpoint_name, content_type='text/csv')
+    predictor = RealTimePredictor(endpoint=endpoint_name, sagemaker_session=sagemaker_session, content_type='text/csv')
     data = data.drop(columns=target_column)
 
     if column_order is not None:
         data = data[column_order]
     if chunksize is None:
         chunksize = DEFAULT_UPLOAD_CHUNK_SIZE
 
-    indices = [index for index, _ in data.iterrows(sort_index="_id")]
+    indices = [index for index, _ in data.iterrows(sort_index=sort_index)]
 
     to_return = []
 
     for i in range(ceil(data.shape[0] / chunksize)):
         df_slice = indices[chunksize * i: min(len(indices), chunksize * (i+1))]
         to_process = data.filter(df_slice, axis=0)
         preds = predictor.predict(to_process.to_csv(header=False, index=False))
         preds = np.array(json.loads(preds.decode('utf-8'))['probabilities'])
         to_return.append(preds)
 
-    return indices, np.concatenate(to_return, axis=0)
+    return indices, to_return
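
A minimal usage sketch of the updated make_sagemaker_prediction signature, based on the docstring above and not part of the commit: the endpoint name, target column, and chunk size are hypothetical, while get_os_client and the 'sagemaker_demo_data' index are borrowed from the debugging snippet in eland/operations.py and assumed to point at a reachable local cluster.

import numpy as np
import eland as ed
from eland.operations import get_os_client
from eland.sagemaker_tools import make_sagemaker_prediction

client = get_os_client()  # defaults to https://localhost:9200
ed_df = ed.DataFrame(client, 'sagemaker_demo_data')

indices, chunked_preds = make_sagemaker_prediction(
    endpoint_name='my-endpoint',   # hypothetical SageMaker endpoint name
    data=ed_df,
    target_column='label',         # hypothetical dependent-variable column
    chunksize=500,                 # rows per upload chunk
    sort_index='_doc',             # the new default added in this commit
)

# After this commit the second return value is a list of per-chunk arrays
# rather than a single concatenated array; join it if one array is needed.
predictions = np.concatenate(chunked_preds, axis=0)
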
tests/setup_tests.py (3 changes: 1 addition & 2 deletions)
@@ -91,8 +91,7 @@ def _update_max_compilations_limit(es: OpenSearch, limit="10000/1m"):
         body={
             "transient": {
                 "script.max_compilations_rate": "use-context",
-                "senecccbnijrkjjerviltfgjlibhffleggivlgcrhgthi"
-                "cript.context.field.max_compilations_rate": limit,
+                "script.context.field.max_compilations_rate": limit,
             }
         }
     )
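
For context, a sketch of how the corrected transient setting could be applied directly, assuming an opensearch-py OpenSearch client pointed at the local test cluster used elsewhere in this commit; the connection details are illustrative and the limit value mirrors the helper's default.

from opensearchpy import OpenSearch

es = OpenSearch('https://localhost:9200', verify_certs=False)  # assumed local test cluster
es.cluster.put_settings(
    body={
        "transient": {
            "script.max_compilations_rate": "use-context",
            # this key previously had stray characters spliced into it; the commit restores it
            "script.context.field.max_compilations_rate": "10000/1m",
        }
    }
)
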
