Skip to content

Commit

Permalink
fix: Reject undefined features when using get_historical_features o…
Browse files Browse the repository at this point in the history
…r `get_online_features` (#2665)

Reject undefined features when using `get_historical_features` or
`get_online_features`.

Signed-off-by: Abhin Chhabra <abhin.chhabra@shopify.com>
  • Loading branch information
chhabrakadabra authored May 12, 2022
1 parent 4060c3d commit 36849fb
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 3 deletions.
9 changes: 8 additions & 1 deletion sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ def get_historical_features(
DeprecationWarning,
)

# TODO(achal): _group_feature_refs returns the on demand feature views, but it's no passed into the provider.
# TODO(achal): _group_feature_refs returns the on demand feature views, but it's not passed into the provider.
# This is a weird interface quirk - we should revisit the `get_historical_features` to
# pass in the on demand feature views as well.
fvs, odfvs, request_fvs, request_fv_refs = _group_feature_refs(
Expand Down Expand Up @@ -2125,8 +2125,12 @@ def _group_feature_refs(
for ref in features:
view_name, feat_name = ref.split(":")
if view_name in view_index:
view_index[view_name].projection.get_feature(feat_name) # For validation
views_features[view_name].add(feat_name)
elif view_name in on_demand_view_index:
on_demand_view_index[view_name].projection.get_feature(
feat_name
) # For validation
on_demand_view_features[view_name].add(feat_name)
# Let's also add in any FV Feature dependencies here.
for input_fv_projection in on_demand_view_index[
Expand All @@ -2135,6 +2139,9 @@ def _group_feature_refs(
for input_feat in input_fv_projection.features:
views_features[input_fv_projection.name].add(input_feat.name)
elif view_name in request_view_index:
request_view_index[view_name].projection.get_feature(
feat_name
) # For validation
request_views_features[view_name].add(feat_name)
request_view_refs.add(ref)
else:
Expand Down
8 changes: 8 additions & 0 deletions sdk/python/feast/feature_view_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,11 @@ def from_definition(base_feature_view: "BaseFeatureView"):
name_alias=None,
features=base_feature_view.features,
)

def get_feature(self, feature_name: str) -> Field:
try:
return next(field for field in self.features if field.name == feature_name)
except StopIteration:
raise KeyError(
f"Feature {feature_name} not found in projection {self.name_to_use()}"
)
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from feast.infra.offline_stores.offline_utils import (
DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL,
)
from feast.types import Int32
from feast.types import Float32, Int32
from feast.value_type import ValueType
from tests.integration.feature_repos.repo_configuration import (
construct_universal_feature_views,
Expand Down Expand Up @@ -410,6 +410,46 @@ def test_historical_features(environment, universal_data_sources, full_feature_n
)


@pytest.mark.integration
@pytest.mark.universal
@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v))
def test_historical_features_with_shared_batch_source(
environment, universal_data_sources, full_feature_names
):
# Addresses https://github.com/feast-dev/feast/issues/2576

store = environment.feature_store

entities, datasets, data_sources = universal_data_sources
driver_stats_v1 = FeatureView(
name="driver_stats_v1",
entities=["driver"],
schema=[Field(name="avg_daily_trips", dtype=Int32)],
source=data_sources.driver,
)
driver_stats_v2 = FeatureView(
name="driver_stats_v2",
entities=["driver"],
schema=[
Field(name="avg_daily_trips", dtype=Int32),
Field(name="conv_rate", dtype=Float32),
],
source=data_sources.driver,
)

store.apply([driver(), driver_stats_v1, driver_stats_v2])

with pytest.raises(KeyError):
store.get_historical_features(
entity_df=datasets.entity_df,
features=[
# `driver_stats_v1` does not have `conv_rate`
"driver_stats_v1:conv_rate",
],
full_feature_names=full_feature_names,
).to_df()


@pytest.mark.integration
@pytest.mark.universal_offline_stores
def test_historical_features_with_missing_request_data(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
RequestDataNotFoundInEntityRowsException,
)
from feast.online_response import TIMESTAMP_POSTFIX
from feast.types import String
from feast.types import Float32, Int32, String
from feast.wait import wait_retry_backoff
from tests.integration.feature_repos.repo_configuration import (
Environment,
Expand Down Expand Up @@ -324,6 +324,60 @@ def get_online_features_dict(
return dict1


@pytest.mark.integration
@pytest.mark.universal
def test_online_retrieval_with_shared_batch_source(environment, universal_data_sources):
# Addresses https://github.com/feast-dev/feast/issues/2576

fs = environment.feature_store

entities, datasets, data_sources = universal_data_sources
driver_stats_v1 = FeatureView(
name="driver_stats_v1",
entities=["driver"],
schema=[Field(name="avg_daily_trips", dtype=Int32)],
source=data_sources.driver,
)
driver_stats_v2 = FeatureView(
name="driver_stats_v2",
entities=["driver"],
schema=[
Field(name="avg_daily_trips", dtype=Int32),
Field(name="conv_rate", dtype=Float32),
],
source=data_sources.driver,
)

fs.apply([driver(), driver_stats_v1, driver_stats_v2])

data = pd.DataFrame(
{
"driver_id": [1, 2],
"avg_daily_trips": [4, 5],
"conv_rate": [0.5, 0.3],
"event_timestamp": [
pd.to_datetime(1646263500, utc=True, unit="s"),
pd.to_datetime(1646263600, utc=True, unit="s"),
],
"created": [
pd.to_datetime(1646263500, unit="s"),
pd.to_datetime(1646263600, unit="s"),
],
}
)
fs.write_to_online_store("driver_stats_v1", data.drop("conv_rate", axis=1))
fs.write_to_online_store("driver_stats_v2", data)

with pytest.raises(KeyError):
fs.get_online_features(
features=[
# `driver_stats_v1` does not have `conv_rate`
"driver_stats_v1:conv_rate",
],
entity_rows=[{"driver_id": 1}, {"driver_id": 2}],
)


@pytest.mark.integration
@pytest.mark.universal_online_stores
@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v))
Expand Down

0 comments on commit 36849fb

Please sign in to comment.