Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Rework get_online_features helper functions #5060

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2018,7 +2018,7 @@ def _retrieve_from_online_store_v2(
entity_key_dict[key] = []
entity_key_dict[key].append(python_value)

table_entity_values, idxs = utils._get_unique_entities_from_values(
table_entity_values, idxs, output_len = utils._get_unique_entities_from_values(
entity_key_dict,
)

Expand All @@ -2040,6 +2040,7 @@ def _retrieve_from_online_store_v2(
full_feature_names=False,
requested_features=features_to_request,
table=table,
output_len=output_len,
)

return OnlineResponse(online_features_response)
Expand Down
10 changes: 6 additions & 4 deletions sdk/python/feast/infra/online_stores/online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def get_online_features(

for table, requested_features in grouped_refs:
# Get the correct set of entity values with the correct join keys.
table_entity_values, idxs = utils._get_unique_entities(
table_entity_values, idxs, output_len = utils._get_unique_entities(
table,
join_key_values,
entity_name_to_join_key_map,
Expand Down Expand Up @@ -215,6 +215,7 @@ def get_online_features(
full_feature_names,
requested_features,
table,
output_len,
)

if requested_on_demand_feature_views:
Expand Down Expand Up @@ -274,7 +275,7 @@ async def get_online_features_async(

async def query_table(table, requested_features):
# Get the correct set of entity values with the correct join keys.
table_entity_values, idxs = utils._get_unique_entities(
table_entity_values, idxs, output_len = utils._get_unique_entities(
table,
join_key_values,
entity_name_to_join_key_map,
Expand All @@ -290,7 +291,7 @@ async def query_table(table, requested_features):
requested_features=requested_features,
)

return idxs, read_rows
return idxs, read_rows, output_len

all_responses = await asyncio.gather(
*[
Expand All @@ -299,7 +300,7 @@ async def query_table(table, requested_features):
]
)

for (idxs, read_rows), (table, requested_features) in zip(
for (idxs, read_rows, output_len), (table, requested_features) in zip(
all_responses, grouped_refs
):
feature_data = utils._convert_rows_to_protobuf(
Expand All @@ -314,6 +315,7 @@ async def query_table(table, requested_features):
full_feature_names,
requested_features,
table,
output_len,
)

if requested_on_demand_feature_views:
Expand Down
137 changes: 67 additions & 70 deletions sdk/python/feast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,16 +490,28 @@ def _group_feature_refs(
return fvs_result, odfvs_result


def apply_list_mapping(
lst: Iterable[Any], mapping_indexes: Iterable[List[int]]
) -> Iterable[Any]:
output_len = sum(len(item) for item in mapping_indexes)
output = [None] * output_len
for elem, destinations in zip(lst, mapping_indexes):
def construct_response_feature_vector(
values_vector: Iterable[Any],
statuses_vector: Iterable[Any],
timestamp_vector: Iterable[Any],
mapping_indexes: Iterable[List[int]],
output_len: int,
) -> GetOnlineFeaturesResponse.FeatureVector:
values_output: Iterable[Any] = [None] * output_len
statuses_output: Iterable[Any] = [None] * output_len
timestamp_output: Iterable[Any] = [None] * output_len

for i, destinations in enumerate(mapping_indexes):
for idx in destinations:
output[idx] = elem

return output
values_output[idx] = values_vector[i] # type: ignore[index]
statuses_output[idx] = statuses_vector[i] # type: ignore[index]
timestamp_output[idx] = timestamp_vector[i] # type: ignore[index]

return GetOnlineFeaturesResponse.FeatureVector(
values=values_output,
statuses=statuses_output,
event_timestamps=timestamp_output,
)


def _augment_response_with_on_demand_transforms(
Expand Down Expand Up @@ -674,7 +686,7 @@ def _get_unique_entities(
table: "FeatureView",
join_key_values: Dict[str, List[ValueProto]],
entity_name_to_join_key_map: Dict[str, str],
) -> Tuple[Tuple[Dict[str, ValueProto], ...], Tuple[List[int], ...]]:
) -> Tuple[Tuple[Dict[str, ValueProto], ...], Tuple[List[int], ...], int]:
"""Return the set of unique composite Entities for a Feature View and the indexes at which they appear.

This method allows us to query the OnlineStore for data we need only once
Expand Down Expand Up @@ -712,7 +724,7 @@ def _get_unique_entities(

# If there are no rows, return empty tuples and an output length of 0.
if not rowise:
return (), ()
return (), (), 0

# Sort rowise so that rows with the same join key values are adjacent.
rowise.sort(key=lambda row: tuple(getattr(x, x.WhichOneof("val")) for x in row[1]))
Expand All @@ -725,16 +737,16 @@ def _get_unique_entities(

# If no groups were formed (should not happen for valid input), return empty tuples and an output length of 0.
if not groups:
return (), ()
return (), (), 0

# Unpack the unique entities and their original row indexes.
unique_entities, indexes = tuple(zip(*groups))
return unique_entities, indexes
return unique_entities, indexes, len(rowise)


def _get_unique_entities_from_values(
table_entity_values: Dict[str, List[ValueProto]],
) -> Tuple[Tuple[Dict[str, ValueProto], ...], Tuple[List[int], ...]]:
) -> Tuple[Tuple[Dict[str, ValueProto], ...], Tuple[List[int], ...], int]:
"""Return the set of unique composite Entities for a Feature View and the indexes at which they appear.

This method allows us to query the OnlineStore for data we need only once
Expand All @@ -758,7 +770,7 @@ def _get_unique_entities_from_values(
]
)
)
return unique_entities, indexes
return unique_entities, indexes, len(rowise)


def _drop_unneeded_columns(
Expand Down Expand Up @@ -835,6 +847,7 @@ def _populate_response_from_feature_data(
full_feature_names: bool,
requested_features: Iterable[str],
table: "FeatureView",
output_len: int,
):
"""Populate the GetOnlineFeaturesResponse with feature data.

Expand All @@ -853,33 +866,22 @@ def _populate_response_from_feature_data(
requested_features: The names of the features in `feature_data`. This should be ordered in the same way as the
data in `feature_data`.
table: The FeatureView that `feature_data` was retrieved from.
output_len: The number of result rows in `online_features_response`.
"""
# Add the feature names to the response.
table_name = table.projection.name_to_use()
requested_feature_refs = [
(
f"{table.projection.name_to_use()}__{feature_name}"
if full_feature_names
else feature_name
)
f"{table_name}__{feature_name}" if full_feature_names else feature_name
for feature_name in requested_features
]
online_features_response.metadata.feature_names.val.extend(requested_feature_refs)

timestamps, statuses, values = zip(*feature_data)

# Populate the result with data fetched from the OnlineStore
# which is guaranteed to be aligned with `requested_features`.
for (
feature_idx,
(timestamp_vector, statuses_vector, values_vector),
) in enumerate(zip(zip(*timestamps), zip(*statuses), zip(*values))):
online_features_response.results.append(
GetOnlineFeaturesResponse.FeatureVector(
values=apply_list_mapping(values_vector, indexes),
statuses=apply_list_mapping(statuses_vector, indexes),
event_timestamps=apply_list_mapping(timestamp_vector, indexes),
)
# Process each feature vector in a single pass
for timestamp_vector, statuses_vector, values_vector in feature_data:
response_vector = construct_response_feature_vector(
values_vector, statuses_vector, timestamp_vector, indexes, output_len
)
online_features_response.results.append(response_vector)


def _populate_response_from_feature_data_v2(
Expand All @@ -891,6 +893,7 @@ def _populate_response_from_feature_data_v2(
indexes: Iterable[List[int]],
online_features_response: GetOnlineFeaturesResponse,
requested_features: Iterable[str],
output_len: int,
):
"""Populate the GetOnlineFeaturesResponse with feature data.

Expand All @@ -908,6 +911,7 @@ def _populate_response_from_feature_data_v2(
"customer_fv__daily_transactions").
requested_features: The names of the features in `feature_data`. This should be ordered in the same way as the
data in `feature_data`.
output_len: The number of result rows in `online_features_response`.
"""
# Add the feature names to the response.
requested_feature_refs = [(feature_name) for feature_name in requested_features]
Expand All @@ -917,17 +921,11 @@ def _populate_response_from_feature_data_v2(

# Populate the result with data fetched from the OnlineStore
# which is guaranteed to be aligned with `requested_features`.
for (
feature_idx,
(timestamp_vector, statuses_vector, values_vector),
) in enumerate(zip(zip(*timestamps), zip(*statuses), zip(*values))):
online_features_response.results.append(
GetOnlineFeaturesResponse.FeatureVector(
values=apply_list_mapping(values_vector, indexes),
statuses=apply_list_mapping(statuses_vector, indexes),
event_timestamps=apply_list_mapping(timestamp_vector, indexes),
)
for timestamp_vector, statuses_vector, values_vector in feature_data:
response_vector = construct_response_feature_vector(
values_vector, statuses_vector, timestamp_vector, indexes, output_len
)
online_features_response.results.append(response_vector)


def _convert_entity_key_to_proto_to_dict(
Expand Down Expand Up @@ -1246,33 +1244,32 @@ def _convert_rows_to_protobuf(
requested_features: List[str],
read_rows: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]],
) -> List[Tuple[List[Timestamp], List["FieldStatus.ValueType"], List[ValueProto]]]:
# Each row is a set of features for a given entity key.
# We only need to convert the data to Protobuf once.
# Pre-calculate the length to avoid repeated calculations
n_rows = len(read_rows)

# Create single instances of commonly used values
null_value = ValueProto()
read_row_protos = []
for read_row in read_rows:
row_ts_proto = Timestamp()
row_ts, feature_data = read_row
# TODO (Ly): reuse whatever timestamp if row_ts is None?
if row_ts is not None:
row_ts_proto.FromDatetime(row_ts)
event_timestamps = [row_ts_proto] * len(requested_features)
if feature_data is None:
statuses = [FieldStatus.NOT_FOUND] * len(requested_features)
values = [null_value] * len(requested_features)
else:
statuses = []
values = []
for feature_name in requested_features:
# Make sure order of data is the same as requested_features.
if feature_name not in feature_data:
statuses.append(FieldStatus.NOT_FOUND)
values.append(null_value)
else:
statuses.append(FieldStatus.PRESENT)
values.append(feature_data[feature_name])
read_row_protos.append((event_timestamps, statuses, values))
return read_row_protos
null_status = FieldStatus.NOT_FOUND
null_timestamp = Timestamp()
present_status = FieldStatus.PRESENT

requested_features_vectors = []
for feature_name in requested_features:
ts_vector = [null_timestamp] * n_rows
status_vector = [null_status] * n_rows
value_vector = [null_value] * n_rows
for idx, read_row in enumerate(read_rows):
row_ts_proto = Timestamp()
row_ts, feature_data = read_row
# TODO (Ly): reuse whatever timestamp if row_ts is None?
if row_ts is not None:
row_ts_proto.FromDatetime(row_ts)
ts_vector[idx] = row_ts_proto
if (feature_data is not None) and (feature_name in feature_data):
status_vector[idx] = present_status
value_vector[idx] = feature_data[feature_name]
requested_features_vectors.append((ts_vector, status_vector, value_vector))
return requested_features_vectors


def has_all_tags(
Expand Down
6 changes: 4 additions & 2 deletions sdk/python/tests/unit/test_unit_feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_get_unique_entities_success():
projection=MockFeatureViewProjection(join_key_map={}),
)

unique_entities, indexes = utils._get_unique_entities(
unique_entities, indexes, output_len = utils._get_unique_entities(
table=fv,
join_key_values=entity_values,
entity_name_to_join_key_map=entity_name_to_join_key_map,
Expand All @@ -51,6 +51,7 @@ def test_get_unique_entities_success():

assert unique_entities == expected_entities
assert indexes == expected_indexes
assert output_len == 3


def test_get_unique_entities_missing_join_key_success():
Expand All @@ -74,7 +75,7 @@ def test_get_unique_entities_missing_join_key_success():
projection=MockFeatureViewProjection(join_key_map={}),
)

unique_entities, indexes = utils._get_unique_entities(
unique_entities, indexes, output_len = utils._get_unique_entities(
table=fv,
join_key_values=entity_values,
entity_name_to_join_key_map=entity_name_to_join_key_map,
Expand All @@ -87,6 +88,7 @@ def test_get_unique_entities_missing_join_key_success():

assert unique_entities == expected_entities
assert indexes == expected_indexes
assert output_len == 3
# We're not saying anything about entity_1 being missing from the unique_entities list
assert "entity_1" not in [entity.keys() for entity in unique_entities]

Expand Down
Loading