Skip to content

Commit

Permalink
chore: Force entity inference without modifying fv.schema (feast-dev#3448)
Browse files Browse the repository at this point in the history

* Remove unnecessary wrapper `create_feature_view`

Signed-off-by: Felix Wang <wangfelix98@gmail.com>

* Add `infer_entities` option to `driver_feature_view`

Signed-off-by: Felix Wang <wangfelix98@gmail.com>

* Force entity inference

Signed-off-by: Felix Wang <wangfelix98@gmail.com>

Signed-off-by: Felix Wang <wangfelix98@gmail.com>
  • Loading branch information
felixwang9817 authored Jan 15, 2023
1 parent 73930f6 commit 5a83c6e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
def driver_feature_view(
data_source: DataSource,
name="test_correctness",
infer_entities: bool = False,
infer_features: bool = False,
dtype: FeastType = Float32,
entity_type: FeastType = Int64,
Expand All @@ -34,7 +35,7 @@ def driver_feature_view(
return FeatureView(
name=name,
entities=[d],
schema=[Field(name=d.join_key, dtype=entity_type)]
schema=([] if infer_entities else [Field(name=d.join_key, dtype=entity_type)])
+ ([] if infer_features else [Field(name="value", dtype=dtype)]),
ttl=timedelta(days=5),
source=data_source,
Expand Down
62 changes: 22 additions & 40 deletions sdk/python/tests/integration/registration/test_universal_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
Expand All @@ -12,6 +12,7 @@
from feast.types import (
Array,
Bool,
FeastType,
Float32,
Float64,
Int32,
Expand Down Expand Up @@ -42,20 +43,15 @@ def test_entity_inference_types_match(environment, entity_type):
destination_name=f"entity_type_{entity_type.name.lower()}",
field_mapping={"ts_1": "ts"},
)
fv = create_feature_view(
f"fv_entity_type_{entity_type.name.lower()}",
feature_dtype="int32",
feature_is_list=False,
has_empty_list=False,
fv = driver_feature_view(
data_source=data_source,
name=f"fv_entity_type_{entity_type.name.lower()}",
infer_entities=True, # Forces entity inference by not including a field for the entity.
dtype=_get_feast_type("int32", False),
entity_type=entity_type,
)

# TODO(felixwang9817): Refactor this by finding a better way to force type inference.
# Override the schema and entity_columns to force entity inference.
entity = driver()
fv.schema = list(filter(lambda x: x.name != entity.join_key, fv.schema))
fv.entity_columns = []
fs.apply([fv, entity])

entity_type_to_expected_inferred_entity_type = {
Expand Down Expand Up @@ -88,12 +84,10 @@ def test_feature_get_historical_features_types_match(
config, data_source, fv = offline_types_test_fixtures
fs = environment.feature_store
entity = driver()
fv = create_feature_view(
"get_historical_features_types_match",
config.feature_dtype,
config.feature_is_list,
config.has_empty_list,
data_source,
fv = driver_feature_view(
data_source=data_source,
name="get_historical_features_types_match",
dtype=_get_feast_type(config.feature_dtype, config.feature_is_list),
)
fs.apply([fv, entity])

Expand Down Expand Up @@ -139,12 +133,10 @@ def test_feature_get_online_features_types_match(
):
config, data_source, fv = online_types_test_fixtures
entity = driver()
fv = create_feature_view(
"get_online_features_types_match",
config.feature_dtype,
config.feature_is_list,
config.has_empty_list,
data_source,
fv = driver_feature_view(
data_source=data_source,
name="get_online_features_types_match",
dtype=_get_feast_type(config.feature_dtype, config.feature_is_list),
)
fs = environment.feature_store
features = [fv.name + ":value"]
Expand Down Expand Up @@ -188,14 +180,8 @@ def test_feature_get_online_features_types_match(
assert isinstance(feature, expected_dtype)


def create_feature_view(
name,
feature_dtype,
feature_is_list,
has_empty_list,
data_source,
entity_type=Int64,
):
def _get_feast_type(feature_dtype: str, feature_is_list: bool) -> FeastType:
dtype: Optional[FeastType] = None
if feature_is_list is True:
if feature_dtype == "int32":
dtype = Array(Int32)
Expand All @@ -218,10 +204,8 @@ def create_feature_view(
dtype = Bool
elif feature_dtype == "datetime":
dtype = UnixTimestamp

return driver_feature_view(
data_source, name=name, dtype=dtype, entity_type=entity_type
)
assert dtype
return dtype


def assert_expected_historical_feature_types(
Expand Down Expand Up @@ -388,12 +372,10 @@ def get_fixtures(request, environment):
destination_name=destination_name,
field_mapping={"ts_1": "ts"},
)
fv = create_feature_view(
destination_name,
config.feature_dtype,
config.feature_is_list,
config.has_empty_list,
data_source,
fv = driver_feature_view(
data_source=data_source,
name=destination_name,
dtype=_get_feast_type(config.feature_dtype, config.feature_is_list),
)

return config, data_source, fv

0 comments on commit 5a83c6e

Please sign in to comment.