added band-aid to fix iterating over rows
LEFTA98 committed Jul 25, 2022
1 parent fd017db commit 55b2f4f
Showing 4 changed files with 57 additions and 17 deletions.
eland/dataframe.py: 15 changes (11 additions & 4 deletions)

@@ -1444,10 +1444,15 @@ def keys(self) -> pd.Index:
         """
         return self.columns
 
-    def iterrows(self) -> Iterable[Tuple[Union[str, Tuple[str, ...]], pd.Series]]:
+    def iterrows(self, sort_index: Optional[str] = "_doc") -> Iterable[Tuple[Union[str, Tuple[str, ...]], pd.Series]]:
         """
         Iterate over eland.DataFrame rows as (index, pandas.Series) pairs.
 
+        Parameters
+        ----------
+        sort_index: str, default '_doc'
+            The field to sort the OpenSearch data by.
+
         Yields
         ------
         index: index
@@ -1490,11 +1495,11 @@ def iterrows(self) -> Iterable[Tuple[Union[str, Tuple[str, ...]], pd.Series]]:
         Cancelled                 False
         Name: 4, dtype: object
         """
-        for df in self._query_compiler.search_yield_pandas_dataframes():
+        for df in self._query_compiler.search_yield_pandas_dataframes(sort_index=sort_index):
             yield from df.iterrows()
 
     def itertuples(
-        self, index: bool = True, name: Union[str, None] = "Eland"
+        self, index: bool = True, name: Union[str, None] = "Eland", sort_index: Optional[str] = "_doc"
     ) -> Iterable[Tuple[Any, ...]]:
         """
         Iterate over eland.DataFrame rows as namedtuples.
@@ -1505,6 +1510,8 @@ def itertuples(
             If True, return the index as the first element of the tuple.
         name: str or None, default "Eland"
             The name of the returned namedtuples or None to return regular tuples.
+        sort_index: str, default '_doc'
+            The field to sort the OpenSearch data by.
 
         Returns
         -------
@@ -1558,7 +1565,7 @@ def itertuples(
         Flight(Index='3', AvgTicketPrice=181.69421554118, Cancelled=True)
         Flight(Index='4', AvgTicketPrice=730.041778346198, Cancelled=False)
         """
-        for df in self._query_compiler.search_yield_pandas_dataframes():
+        for df in self._query_compiler.search_yield_pandas_dataframes(sort_index=sort_index):
             yield from df.itertuples(index=index, name=name)
 
     def aggregate(
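For illustration, a minimal usage sketch of the new sort_index parameter. Hedged: the cluster settings mirror the commented-out test in eland/operations.py below, and "flights" is the index from the docstring examples.

import eland as ed
from opensearchpy import OpenSearch

# Hypothetical local cluster, mirroring the commented-out test in operations.py.
client = OpenSearch(
    hosts=["https://localhost:9200"],
    http_auth=("admin", "admin"),
    verify_certs=False,
)
df = ed.DataFrame(client, "flights")

# Default: paginate in Lucene '_doc' order, the fastest option.
for index, row in df.iterrows():
    print(index, row["Cancelled"])
    break

# The band-aid: sort by a real field such as '_id' if '_doc' ordering
# misbehaves while iterating over rows.
for tup in df.itertuples(index=True, name="Flight", sort_index="_id"):
    print(tup)
    break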
eland/operations.py: 38 changes (35 additions & 3 deletions)

@@ -1257,7 +1257,7 @@ def to_csv(
         ).to_csv(**kwargs)
 
     def search_yield_pandas_dataframes(
-        self, query_compiler: "QueryCompiler"
+        self, query_compiler: "QueryCompiler", sort_index: Optional[str] = "_doc"
     ) -> Generator["pd.DataFrame", None, None]:
         query_params, post_processing = self._resolve_tasks(query_compiler)
@@ -1279,11 +1279,14 @@ def search_yield_pandas_dataframes(
         if sort_params:
             body["sort"] = [sort_params]
 
         for hits in _search_yield_hits(
-            query_compiler=query_compiler, body=body, max_number_of_hits=result_size
+            query_compiler=query_compiler,
+            body=body,
+            max_number_of_hits=result_size,
+            sort_index=sort_index,
         ):
             df = query_compiler._es_results_to_pandas(hits)
             df = self._apply_df_post_processing(df, post_processing)
             yield df
 
     def index_count(self, query_compiler: "QueryCompiler", field: str) -> int:
@@ -1491,6 +1494,7 @@ def _search_yield_hits(
     query_compiler: "QueryCompiler",
     body: Dict[str, Any],
     max_number_of_hits: Optional[int],
+    sort_index: Optional[str] = "_doc",
 ) -> Generator[List[Dict[str, Any]], None, None]:
     """
     This is a generator used to initialize point in time API and query the
@@ -1531,7 +1535,7 @@ def _search_yield_hits(
     # Pagination with 'search_after' must have a 'sort' setting.
     # Sorting by '_doc:asc' is the most efficient, as it reads documents
    # in the order they're written on disk in Lucene.
-    body.setdefault("sort", [{"_doc": "asc"}])
+    body.setdefault("sort", [{sort_index: "asc"}])
 
     # Improves performance by not tracking # of hits. We only
     # care about the hit itself for these queries.
@@ -1562,3 +1566,31 @@ def _search_yield_hits(
             # to be the last sort value for this set of hits.
             body["search_after"] = hits[-1]["sort"]
 
+# if __name__ == "__main__":
+#     import eland as ed
+#     from opensearchpy import OpenSearch
+#
+#     # try connecting to an actual cluster at some point
+#     def get_os_client(cluster_url='https://localhost:9200',
+#                       username='admin',
+#                       password='admin'):
+#         '''
+#         Get an OpenSearch client.
+#         :param cluster_url: cluster URL like https://ml-te-netwo-1s12ba42br23v-ff1736fa7db98ff2.elb.us-west-2.amazonaws.com:443
+#         :param username: username for HTTP basic authentication
+#         :param password: password for HTTP basic authentication
+#         :return: OpenSearch client
+#         '''
+#         return OpenSearch(
+#             hosts=[cluster_url],
+#             http_auth=(username, password),
+#             verify_certs=False
+#         )
+#
+#     client = get_os_client()
+#     ed_df = ed.DataFrame(client, 'sagemaker_demo_data')
+#
+#     # Sanity check: when row iteration works, each index appears exactly once.
+#     indices = [index for index, _ in ed_df.iterrows('_doc')]
+#     print(len(set(indices)))
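For reference, a minimal standalone sketch of the 'sort' + 'search_after' pagination pattern that _search_yield_hits builds on. Assumptions: an opensearch-py client, a hypothetical index name, and none of the point-in-time bookkeeping the real function performs.

from typing import Any, Dict, Generator, List

from opensearchpy import OpenSearch


def yield_hits_sorted(
    client: OpenSearch, index: str, sort_index: str = "_doc", page_size: int = 500
) -> Generator[List[Dict[str, Any]], None, None]:
    # 'search_after' pagination requires a deterministic 'sort'.
    body: Dict[str, Any] = {"size": page_size, "sort": [{sort_index: "asc"}]}
    while True:
        hits = client.search(index=index, body=body)["hits"]["hits"]
        if not hits:
            return
        yield hits
        # Resume the next page from the last hit's sort values.
        body["search_after"] = hits[-1]["sort"]


# Usage against a hypothetical index:
# for page in yield_hits_sorted(client, "sagemaker_demo_data", sort_index="_id"):
#     print(len(page))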
eland/query_compiler.py: 4 changes (2 additions & 2 deletions)

@@ -515,8 +515,8 @@ def to_csv(self, **kwargs) -> Optional[str]:
         """
         return self._operations.to_csv(self, **kwargs)
 
-    def search_yield_pandas_dataframes(self) -> Generator["pd.DataFrame", None, None]:
-        return self._operations.search_yield_pandas_dataframes(self)
+    def search_yield_pandas_dataframes(self, sort_index: Optional[str] = "_doc") -> Generator["pd.DataFrame", None, None]:
+        return self._operations.search_yield_pandas_dataframes(self, sort_index)
 
     # __getitem__ methods
     def getitem_column_array(self, key, numeric=False):
eland/sagemaker_tools.py: 17 changes (9 additions & 8 deletions)

@@ -12,6 +12,7 @@
 
 def make_sagemaker_prediction(endpoint_name: str,
                               data: DataFrame,
+                              target_column: str,
                               column_order: Optional[List[str]] = None,
                               chunksize: Optional[int] = None
                               ) -> np.array:
@@ -26,32 +27,32 @@ def make_sagemaker_prediction(endpoint_name: str,
     endpoint_name: string representing name of SageMaker endpoint
     data: eland DataFrame representing data to feed to SageMaker model. The dataframe must match the input
         datatypes of the model and also have the correct number of columns.
-    column_order: list of string values representing the proper order that the columns should be read into the
-        SageMaker model. Must be a permutation of the column names of the eland DataFrame.
+    target_column: column name of the dependent variable in the data.
+    column_order: list of string values representing the proper order that the columns of independent variables
+        should be read into the SageMaker model. Must be a permutation of the column names of the eland DataFrame.
     chunksize: how large each chunk uploaded to SageMaker should be.
 
     Returns
     ----------
-    np.array representing the output of the model on input data
+    tuple of (indices, np.array): the row indices that predictions were made for, and the model's output for
+        those rows.
     """
     predictor = RealTimePredictor(endpoint=endpoint_name, content_type='text/csv')
-    test_data = data
+    data = data.drop(columns=target_column)
 
     if column_order is not None:
-        test_data = test_data[column_order]
+        data = data[column_order]
     if chunksize is None:
         chunksize = DEFAULT_UPLOAD_CHUNK_SIZE
 
-    indices = [index for index, _ in data.iterrows()]
+    indices = [index for index, _ in data.iterrows(sort_index="_id")]
 
     to_return = []
 
     for i in range(ceil(data.shape[0] / chunksize)):
         df_slice = indices[chunksize * i: min(len(indices), chunksize * (i + 1))]
 
-        to_process = test_data.filter(df_slice, axis=0)
+        to_process = data.filter(df_slice, axis=0)
         preds = predictor.predict(to_process.to_csv(header=False, index=False))
         preds = np.array(json.loads(preds.decode('utf-8'))['probabilities'])
         to_return.append(preds)
 
-    return np.concatenate(to_return, axis=0)
+    return indices, np.concatenate(to_return, axis=0)
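A hedged usage sketch of the revised helper: the endpoint, index, and column names are hypothetical, and it assumes a deployed SageMaker endpoint that accepts CSV rows and returns a JSON body with a 'probabilities' field, as the code above expects.

import eland as ed
from opensearchpy import OpenSearch
from eland.sagemaker_tools import make_sagemaker_prediction

client = OpenSearch(
    hosts=["https://localhost:9200"],
    http_auth=("admin", "admin"),
    verify_certs=False,
)
df = ed.DataFrame(client, "sagemaker_demo_data")  # hypothetical index

# The hypothetical 'label' column is dropped before upload; the remaining
# columns are sent to the endpoint in chunks, and the returned indices line
# up row-for-row with the returned predictions.
indices, preds = make_sagemaker_prediction(
    endpoint_name="my-endpoint",  # hypothetical endpoint name
    data=df,
    target_column="label",
    chunksize=500,
)
print(len(indices), preds.shape)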
