Skip to content

Commit

Permalink
feat: Support entity fields in feature view schema parameter by dro…
Browse files Browse the repository at this point in the history
…pping them (#2568)

Signed-off-by: Felix Wang <wangfelix98@gmail.com>
  • Loading branch information
felixwang9817 authored Apr 19, 2022
1 parent e8e418e commit c8fcc35
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 8 deletions.
31 changes: 29 additions & 2 deletions sdk/python/feast/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class Entity:
owner: The owner of the entity, typically the email of the primary maintainer.
created_timestamp: The time when the entity was created.
last_updated_timestamp: The time when the entity was last updated.
join_keys: A list of property that uniquely identifies different entities within the
join_keys: A list of properties that uniquely identifies different entities within the
collection. This is meant to replace the `join_key` parameter, but currently only
supports a list of size one.
"""
Expand All @@ -67,7 +67,25 @@ def __init__(
owner: str = "",
join_keys: Optional[List[str]] = None,
):
"""Creates an Entity object."""
"""
Creates an Entity object.
Args:
name: The unique name of the entity.
value_type: The type of the entity, such as string or float.
description: A human-readable description.
join_key (deprecated): A property that uniquely identifies different entities within the
collection. The join_key property is typically used for joining entities
with their associated features. If not specified, defaults to the name.
tags: A dictionary of key-value pairs to store arbitrary metadata.
owner: The owner of the entity, typically the email of the primary maintainer.
join_keys: A list of properties that uniquely identifies different entities within the
collection. This is meant to replace the `join_key` parameter, but currently only
supports a list of size one.
Raises:
ValueError: Parameters are specified incorrectly.
"""
if len(args) == 1:
warnings.warn(
(
Expand All @@ -88,6 +106,15 @@ def __init__(

self.value_type = value_type

if join_key:
warnings.warn(
(
"The `join_key` parameter is being deprecated in favor of the `join_keys` parameter. "
"Please switch from using `join_key` to `join_keys`. Feast 0.22 and onwards will not "
"support the `join_key` parameter."
),
DeprecationWarning,
)
self.join_keys = join_keys or []
if join_keys and len(join_keys) > 1:
raise ValueError(
Expand Down
5 changes: 5 additions & 0 deletions sdk/python/feast/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,13 @@ def update_feature_views_with_inferred_features(
config: The config for the current feature store.
"""
entity_name_to_join_key_map = {entity.name: entity.join_key for entity in entities}
join_keys = entity_name_to_join_key_map.values()

for fv in fvs:
# First drop all Entity fields. Then infer features if necessary.
fv.schema = [field for field in fv.schema if field.name not in join_keys]
fv.features = [field for field in fv.features if field.name not in join_keys]

if not fv.features:
columns_to_exclude = {
fv.batch_source.timestamp_field,
Expand Down
59 changes: 53 additions & 6 deletions sdk/python/tests/integration/registration/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@
from feast.inference import (
update_data_sources_with_inferred_event_timestamp_col,
update_entities_with_inferred_types_from_feature_views,
update_feature_views_with_inferred_features,
)
from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
SparkSource,
)
from feast.on_demand_feature_view import on_demand_feature_view
from feast.types import PrimitiveFeastType, String, UnixTimestamp
from feast.types import Float32, PrimitiveFeastType, String, UnixTimestamp
from tests.utils.data_source_utils import (
prep_file_source,
simple_bq_source_using_query_arg,
Expand Down Expand Up @@ -168,15 +169,14 @@ def test_update_data_sources_with_inferred_event_timestamp_col(universal_data_so
def test_on_demand_features_type_inference():
# Create Feature Views
date_request = RequestSource(
name="date_request",
schema=[Field(name="some_date", dtype=PrimitiveFeastType.UNIX_TIMESTAMP)],
name="date_request", schema=[Field(name="some_date", dtype=UnixTimestamp)],
)

@on_demand_feature_view(
sources={"date_request": date_request},
features=[
Feature(name="output", dtype=ValueType.UNIX_TIMESTAMP),
Feature(name="string_output", dtype=ValueType.STRING),
schema=[
Field(name="output", dtype=UnixTimestamp),
Field(name="string_output", dtype=String),
],
)
def test_view(features_df: pd.DataFrame) -> pd.DataFrame:
Expand Down Expand Up @@ -285,3 +285,50 @@ def test_view_with_missing_feature(features_df: pd.DataFrame) -> pd.DataFrame:

with pytest.raises(SpecifiedFeaturesNotPresentError):
test_view_with_missing_feature.infer_features()


def test_update_feature_views_with_inferred_features():
    """Check that entity join-key fields are dropped from feature views.

    Two feature views are declared whose schemas list their entities' join-key
    columns next to one real feature. After running
    ``update_feature_views_with_inferred_features`` only the real feature is
    expected to remain in both ``schema`` and ``features``.
    """

    def _sizes(view):
        # (number of schema fields, number of features) for a feature view.
        return len(view.schema), len(view.features)

    source = FileSource(name="test", path="test path")
    first_entity = Entity(name="test1", join_key="test_column_1")
    second_entity = Entity(name="test2", join_key="test_column_2")

    # One entity: the schema carries the feature plus that entity's join key.
    single_entity_view = FeatureView(
        name="test1",
        entities=[first_entity],
        schema=[
            Field(name="feature", dtype=Float32),
            Field(name="test_column_1", dtype=String),
        ],
        source=source,
    )
    # Two entities: the schema carries the feature plus both join keys.
    double_entity_view = FeatureView(
        name="test2",
        entities=[first_entity, second_entity],
        schema=[
            Field(name="feature", dtype=Float32),
            Field(name="test_column_1", dtype=String),
            Field(name="test_column_2", dtype=String),
        ],
        source=source,
    )

    # Before inference the join-key column is still counted.
    assert _sizes(single_entity_view) == (2, 2)

    # The entity field should be deleted from the schema and features of the feature view.
    update_feature_views_with_inferred_features(
        [single_entity_view],
        [first_entity],
        RepoConfig(provider="local", project="test"),
    )
    assert _sizes(single_entity_view) == (1, 1)

    # Before inference both join-key columns are still counted.
    assert _sizes(double_entity_view) == (3, 3)

    # The entity fields should be deleted from the schema and features of the feature view.
    update_feature_views_with_inferred_features(
        [double_entity_view],
        [first_entity, second_entity],
        RepoConfig(provider="local", project="test"),
    )
    assert _sizes(double_entity_view) == (1, 1)

0 comments on commit c8fcc35

Please sign in to comment.