Update historical retrieval integration test for on demand feature views (#1836)

Signed-off-by: Achal Shah <achals@gmail.com>
achals authored Sep 3, 2021
1 parent fda5b55 commit b0f38ad
Showing 3 changed files with 37 additions and 14 deletions.
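
For context: the diff imports `conv_rate_plus_100_feature_view` from `tests/integration/feature_repos/universal/feature_views.py`, but that helper's definition is not part of this commit. Below is a plausible sketch of it, assuming Feast's `OnDemandFeatureView` API of this era; the constructor arguments (`name`, `features`, `udf`), the `ValueType.DOUBLE` dtype, and the import path are assumptions, while the `inputs` dict of feature views is taken from the call sites visible in the diff:

from typing import Dict

import pandas as pd

from feast import Feature, FeatureView, ValueType
from feast.on_demand_feature_view import OnDemandFeatureView  # import path is an assumption


def conv_rate_plus_100(inputs: pd.DataFrame) -> pd.DataFrame:
    # Derive a new column from the driver view's conv_rate at retrieval time.
    df = pd.DataFrame()
    df["conv_rate_plus_100"] = inputs["conv_rate"] + 100
    return df


def conv_rate_plus_100_feature_view(
    inputs: Dict[str, FeatureView]
) -> OnDemandFeatureView:
    # Wrap the pandas UDF as an on demand feature view whose single
    # input is the driver hourly stats batch feature view.
    return OnDemandFeatureView(
        name=conv_rate_plus_100.__name__,
        inputs=inputs,
        features=[Feature("conv_rate_plus_100", ValueType.DOUBLE)],
        udf=conv_rate_plus_100,
    )

The hunks below register this view under the "driver_odfv" key and request its output column, conv_rate_plus_100, alongside the regular batch features.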
File 1 of 3:

@@ -23,6 +23,7 @@
     RedshiftDataSourceCreator,
 )
 from tests.integration.feature_repos.universal.feature_views import (
+    conv_rate_plus_100_feature_view,
     create_customer_daily_profile_feature_view,
     create_driver_hourly_stats_feature_view,
 )
@@ -126,11 +127,15 @@ def construct_universal_data_sources(
 def construct_universal_feature_views(
     data_sources: Dict[str, DataSource],
 ) -> Dict[str, FeatureView]:
+    driver_hourly_stats = create_driver_hourly_stats_feature_view(
+        data_sources["driver"]
+    )
     return {
         "customer": create_customer_daily_profile_feature_view(
             data_sources["customer"]
         ),
-        "driver": create_driver_hourly_stats_feature_view(data_sources["driver"]),
+        "driver": driver_hourly_stats,
+        "driver_odfv": conv_rate_plus_100_feature_view({"driver": driver_hourly_stats}),
     }


File 2 of 3:

@@ -134,6 +134,9 @@ def get_expected_training_df(
     for col, typ in expected_column_types.items():
         expected_df[col] = expected_df[col].astype(typ)
 
+    conv_feature_name = "driver_stats__conv_rate" if full_feature_names else "conv_rate"
+    expected_df["conv_rate_plus_100"] = expected_df[conv_feature_name] + 100
+
     return expected_df


@@ -150,13 +153,14 @@ def test_historical_features(environment, universal_data_sources, full_feature_names):
         datasets["driver"],
         datasets["orders"],
     )
-    customer_fv, driver_fv = (
+    customer_fv, driver_fv, driver_odfv = (
         feature_views["customer"],
         feature_views["driver"],
+        feature_views["driver_odfv"],
     )
 
     feast_objects = []
-    feast_objects.extend([customer_fv, driver_fv, driver(), customer()])
+    feast_objects.extend([customer_fv, driver_fv, driver_odfv, driver(), customer()])
     store.apply(feast_objects)
 
     entity_df_query = None
entity_df_query = None
@@ -188,6 +192,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_names):
             "customer_profile:current_balance",
             "customer_profile:avg_passenger_count",
             "customer_profile:lifetime_trip_count",
+            "conv_rate_plus_100",
         ],
         full_feature_names=full_feature_names,
     )
@@ -221,14 +226,21 @@ def test_historical_features(environment, universal_data_sources, full_feature_names):
         actual_df_from_sql_entities, expected_df, check_dtype=False,
     )
 
+    expected_df_from_arrow = expected_df.drop(columns=["conv_rate_plus_100"])
     table_from_sql_entities = job_from_sql.to_arrow()
     df_from_sql_entities = (
-        table_from_sql_entities.to_pandas()[expected_df.columns]
+        table_from_sql_entities.to_pandas()[expected_df_from_arrow.columns]
         .sort_values(by=[event_timestamp, "order_id", "driver_id", "customer_id"])
         .drop_duplicates()
         .reset_index(drop=True)
     )
-    assert_frame_equal(actual_df_from_sql_entities, df_from_sql_entities)
+
+    for col in df_from_sql_entities.columns:
+        expected_df_from_arrow[col] = expected_df_from_arrow[col].astype(
+            df_from_sql_entities[col].dtype
+        )
+
+    assert_frame_equal(expected_df_from_arrow, df_from_sql_entities)
 
     job_from_df = store.get_historical_features(
         entity_df=orders_df,
@@ -238,6 +250,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_names):
             "customer_profile:current_balance",
             "customer_profile:avg_passenger_count",
             "customer_profile:lifetime_trip_count",
+            "conv_rate_plus_100",
         ],
         full_feature_names=full_feature_names,
    )
@@ -250,7 +263,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_names):
     print(str(f"Time to execute job_from_df.to_df() = '{(end_time - start_time)}'\n"))
 
     assert sorted(expected_df.columns) == sorted(actual_df_from_df_entities.columns)
-    expected_df = (
+    expected_df: pd.DataFrame = (
         expected_df.sort_values(
             by=[event_timestamp, "order_id", "driver_id", "customer_id"]
         )
@@ -268,11 +281,20 @@ def test_historical_features(environment, universal_data_sources, full_feature_names):
         expected_df, actual_df_from_df_entities, check_dtype=False,
     )
 
-    table_from_df_entities = job_from_df.to_arrow().to_pandas()
+    # on demand features are only plumbed through to to_df for now.
+    table_from_df_entities: pd.DataFrame = job_from_df.to_arrow().to_pandas()
+    actual_df_from_df_entities_for_table = actual_df_from_df_entities.drop(
+        columns=["conv_rate_plus_100"]
+    )
+    assert "conv_rate_plus_100" not in table_from_df_entities.columns
+
+    columns_expected_in_table = expected_df.columns.tolist()
+    columns_expected_in_table.remove("conv_rate_plus_100")
+
     table_from_df_entities = (
-        table_from_df_entities[expected_df.columns]
+        table_from_df_entities[columns_expected_in_table]
         .sort_values(by=[event_timestamp, "order_id", "driver_id", "customer_id"])
         .drop_duplicates()
         .reset_index(drop=True)
     )
-    assert_frame_equal(actual_df_from_df_entities, table_from_df_entities)
+    assert_frame_equal(actual_df_from_df_entities_for_table, table_from_df_entities)
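
Two details in the hunks above deserve a note. The astype loop exists because an Arrow round trip (to_arrow().to_pandas()) can legitimately change column dtypes, so the expected frame is coerced to the actual dtypes before an exact assert_frame_equal. And since on demand features are only computed on the to_df path at this point, the Arrow-based assertions drop conv_rate_plus_100 first. A minimal standalone sketch of the dtype-alignment idiom, using toy data that is not from the commit:

import pandas as pd
from pandas.testing import assert_frame_equal

# Same values, different dtypes -- as can happen after an Arrow round trip.
expected = pd.DataFrame({"conv_rate": [0.1, 0.2]})
actual = pd.DataFrame({"conv_rate": [0.1, 0.2]}).astype("float32")

# Coerce the expected columns to the actual dtypes, then compare exactly.
for col in actual.columns:
    expected[col] = expected[col].astype(actual[col].dtype)

assert_frame_equal(expected, actual)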
File 3 of 3:

@@ -3,9 +3,6 @@
 
 import pandas as pd
 import pytest
-from integration.feature_repos.universal.feature_views import (
-    conv_rate_plus_100_feature_view,
-)
 
 from tests.integration.feature_repos.repo_configuration import (
     construct_universal_feature_views,
@@ -20,10 +17,9 @@ def test_online_retrieval(environment, universal_data_sources, full_feature_names):
     fs = environment.feature_store
     entities, datasets, data_sources = universal_data_sources
     feature_views = construct_universal_feature_views(data_sources)
-    odfv = conv_rate_plus_100_feature_view(inputs={"driver": feature_views["driver"]})
     feast_objects = []
     feast_objects.extend(feature_views.values())
-    feast_objects.extend([odfv, driver(), customer()])
+    feast_objects.extend([driver(), customer()])
     fs.apply(feast_objects)
     fs.materialize(environment.start_date, environment.end_date)