Skip to content

Commit

Permalink
fix: Update udf tests and add base functions to streaming fcos and fi…
Browse files Browse the repository at this point in the history
…x some nonetype errors (#2776)

* Fix lint and add comments

Signed-off-by: Kevin Zhang <kzhang@tecton.ai>

* Fix

Signed-off-by: Kevin Zhang <kzhang@tecton.ai>

* Fix lint

Signed-off-by: Kevin Zhang <kzhang@tecton.ai>
  • Loading branch information
kevjumba authored Jun 10, 2022
1 parent 83ab682 commit 331a214
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 23 deletions.
2 changes: 1 addition & 1 deletion sdk/python/feast/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def __hash__(self):
@staticmethod
def from_proto(data_source: DataSourceProto):
watermark = None
if data_source.kafka_options.HasField("watermark"):
if data_source.kafka_options.watermark:
watermark = (
timedelta(days=0)
if data_source.kafka_options.watermark.ToNanoseconds() == 0
Expand Down
99 changes: 78 additions & 21 deletions sdk/python/feast/stream_feature_view.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import functools
import warnings
from datetime import timedelta
Expand All @@ -9,7 +10,7 @@

from feast import utils
from feast.aggregation import Aggregation
from feast.data_source import DataSource, KafkaSource
from feast.data_source import DataSource, KafkaSource, PushSource
from feast.entity import Entity
from feast.feature_view import FeatureView
from feast.field import Field
Expand Down Expand Up @@ -39,6 +40,26 @@ class StreamFeatureView(FeatureView):
"""
NOTE: Stream Feature Views are not yet fully implemented and exist to allow users to register their stream sources and
schemas with Feast.
Attributes:
name: str. The unique name of the stream feature view.
entities: Union[List[Entity], List[str]]. List of entities or entity join keys.
ttl: timedelta. The amount of time this group of features lives. A ttl of 0 indicates that
this group of features lives forever. Note that large ttl's or a ttl of 0
can result in extremely computationally intensive queries.
tags: Dict[str, str]. A dictionary of key-value pairs to store arbitrary metadata.
online: bool. Defines whether this stream feature view is used in online feature retrieval.
description: str. A human-readable description.
owner: str. The owner of the stream feature view, typically the email of the primary
maintainer.
schema: List[Field]. The schema of the feature view, including feature, timestamp, and entity
columns. If not specified, can be inferred from the underlying data source.
source: DataSource. The stream source of data where this group of features
is stored.
aggregations (optional): List[Aggregation]. List of aggregations registered with the stream feature view.
mode (optional): str. The mode of execution.
timestamp_field (optional): str. Must be specified if aggregations are specified. Defines the timestamp column on which to aggregate windows.
udf (optional): MethodType. The user-defined transformation function. This transformation function should have all of the corresponding imports imported within the function.
"""

def __init__(
Expand All @@ -54,18 +75,19 @@ def __init__(
schema: Optional[List[Field]] = None,
source: Optional[DataSource] = None,
aggregations: Optional[List[Aggregation]] = None,
mode: Optional[str] = "spark", # Mode of ingestion/transformation
timestamp_field: Optional[str] = "", # Timestamp for aggregation
mode: Optional[str] = "spark",
timestamp_field: Optional[str] = "",
udf: Optional[MethodType] = None,
):
warnings.warn(
"Stream Feature Views are experimental features in alpha development. "
"Some functionality may still be unstable so functionality can change in the future.",
RuntimeWarning,
)

if source is None:
raise ValueError("Stream Feature views need a source specified")
# source uses the batch_source of the kafkasource in feature_view
raise ValueError("Stream Feature views need a source to be specified")

if (
type(source).__name__ not in SUPPORTED_STREAM_SOURCES
and source.to_proto().type != DataSourceProto.SourceType.CUSTOM_SOURCE
Expand All @@ -74,18 +96,26 @@ def __init__(
f"Stream feature views need a stream source, expected one of {SUPPORTED_STREAM_SOURCES} "
f"or CUSTOM_SOURCE, got {type(source).__name__}: {source.name} instead "
)

if aggregations and not timestamp_field:
raise ValueError(
"aggregations must have a timestamp field associated with them to perform the aggregations"
)

self.aggregations = aggregations or []
self.mode = mode
self.timestamp_field = timestamp_field
self.mode = mode or ""
self.timestamp_field = timestamp_field or ""
self.udf = udf
_batch_source = None
if isinstance(source, KafkaSource):
if isinstance(source, KafkaSource) or isinstance(source, PushSource):
_batch_source = source.batch_source if source.batch_source else None

_ttl = ttl
if not _ttl:
_ttl = timedelta(days=0)
super().__init__(
name=name,
entities=entities,
ttl=ttl,
ttl=_ttl,
batch_source=_batch_source,
stream_source=source,
tags=tags,
Expand All @@ -102,7 +132,10 @@ def __eq__(self, other):

if not super().__eq__(other):
return False

if not self.udf:
return not other.udf
if not other.udf:
return False
if (
self.mode != other.mode
or self.timestamp_field != other.timestamp_field
Expand All @@ -113,13 +146,14 @@ def __eq__(self, other):

return True

def __hash__(self):
def __hash__(self) -> int:
return super().__hash__()

def to_proto(self):
meta = StreamFeatureViewMetaProto(materialization_intervals=[])
if self.created_timestamp:
meta.created_timestamp.FromDatetime(self.created_timestamp)

if self.last_updated_timestamp:
meta.last_updated_timestamp.FromDatetime(self.last_updated_timestamp)

Expand All @@ -134,6 +168,7 @@ def to_proto(self):
ttl_duration = Duration()
ttl_duration.FromTimedelta(self.ttl)

batch_source_proto = None
if self.batch_source:
batch_source_proto = self.batch_source.to_proto()
batch_source_proto.data_source_class_type = f"{self.batch_source.__class__.__module__}.{self.batch_source.__class__.__name__}"
Expand All @@ -143,23 +178,24 @@ def to_proto(self):
stream_source_proto = self.stream_source.to_proto()
stream_source_proto.data_source_class_type = f"{self.stream_source.__class__.__module__}.{self.stream_source.__class__.__name__}"

udf_proto = None
if self.udf:
udf_proto = UserDefinedFunctionProto(
name=self.udf.__name__, body=dill.dumps(self.udf, recurse=True),
)
spec = StreamFeatureViewSpecProto(
name=self.name,
entities=self.entities,
entity_columns=[field.to_proto() for field in self.entity_columns],
features=[field.to_proto() for field in self.schema],
user_defined_function=UserDefinedFunctionProto(
name=self.udf.__name__, body=dill.dumps(self.udf, recurse=True),
)
if self.udf
else None,
user_defined_function=udf_proto,
description=self.description,
tags=self.tags,
owner=self.owner,
ttl=(ttl_duration if ttl_duration is not None else None),
ttl=ttl_duration,
online=self.online,
batch_source=batch_source_proto or None,
stream_source=stream_source_proto,
stream_source=stream_source_proto or None,
timestamp_field=self.timestamp_field,
aggregations=[agg.to_proto() for agg in self.aggregations],
mode=self.mode,
Expand Down Expand Up @@ -239,6 +275,25 @@ def from_proto(cls, sfv_proto):

return sfv_feature_view

def __copy__(self):
    """
    Build a new StreamFeatureView from this object's constructor-level attributes.

    The new view is created through the regular constructor, so all of the
    constructor's validation applies to the copy as well. The projection is
    not rebuilt by the constructor call; it is shallow-copied from this
    object with ``copy.copy`` and assigned onto the new view afterwards.

    Returns:
        A new StreamFeatureView carrying the same name, schema, entities, ttl,
        tags, online flag, description, owner, aggregations, mode,
        timestamp field, source and udf as this one.
    """
    fv = StreamFeatureView(
        name=self.name,
        schema=self.schema,
        entities=self.entities,
        ttl=self.ttl,
        tags=self.tags,
        online=self.online,
        description=self.description,
        owner=self.owner,
        aggregations=self.aggregations,
        mode=self.mode,
        timestamp_field=self.timestamp_field,
        # The constructor's keyword is `source` (singular) and it raises
        # ValueError when no source is given, so the stream source that the
        # constructor stored on this view has to be handed back explicitly.
        source=self.stream_source,
        udf=self.udf,
    )
    fv.projection = copy.copy(self.projection)
    return fv


def stream_feature_view(
*,
Expand All @@ -251,11 +306,13 @@ def stream_feature_view(
schema: Optional[List[Field]] = None,
source: Optional[DataSource] = None,
aggregations: Optional[List[Aggregation]] = None,
mode: Optional[str] = "spark", # Mode of ingestion/transformation
timestamp_field: Optional[str] = "", # Timestamp for aggregation
mode: Optional[str] = "spark",
timestamp_field: Optional[str] = "",
):
"""
Creates a StreamFeatureView object with the given user function as udf.
Please make sure that the udf contains all non-built-in imports within the function to ensure that the execution
of a deserialized function does not miss imports.
"""

def mainify(obj):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,71 @@ def simple_sfv(df):
assert features["test_key"] == [1001]
assert "dummy_field" in features
assert features["dummy_field"] == [None]


@pytest.mark.integration
def test_stream_feature_view_udf(environment) -> None:
    """
    Apply a udf-backed StreamFeatureView to the feature store, list it back,
    and check that the listed view equals the applied one and that its udf
    still transforms a pandas DataFrame as expected.
    """
    import pandas as pd

    store = environment.feature_store

    # Objects that the stream feature view is built from.
    driver = Entity(name="driver_entity", join_keys=["test_key"])
    kafka_source = KafkaSource(
        name="kafka",
        timestamp_field="event_timestamp",
        bootstrap_servers="",
        message_format=AvroFormat(""),
        topic="topic",
        batch_source=FileSource(path="test_path", timestamp_field="event_timestamp"),
        watermark=timedelta(days=1),
    )
    max_over_one_day = Aggregation(
        column="dummy_field", function="max", time_window=timedelta(days=1),
    )
    count_over_24_days = Aggregation(
        column="dummy_field2", function="count", time_window=timedelta(days=24),
    )

    @stream_feature_view(
        entities=[driver],
        ttl=timedelta(days=30),
        owner="test@example.com",
        online=True,
        schema=[Field(name="dummy_field", dtype=Float32)],
        description="desc",
        aggregations=[max_over_one_day, count_over_24_days],
        timestamp_field="event_timestamp",
        mode="spark",
        source=kafka_source,
        tags={},
    )
    def pandas_view(pandas_df):
        # The udf carries its own import so it stays usable after deserialization.
        import pandas as pd

        assert type(pandas_df) == pd.DataFrame
        df = pandas_df.transform(lambda x: x + 10, axis=1)
        df.insert(2, "C", [20.2, 230.0, 34.0], True)
        return df

    input_df = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})

    # Register the entity and the view, then read the views back.
    store.apply([driver, pandas_view])
    listed_views = store.list_stream_feature_views()
    assert len(listed_views) == 1

    listed_view = listed_views[0]
    assert listed_view.name == "pandas_view"
    assert listed_view == pandas_view

    # The udf on the listed view must still produce the transformed frame.
    transformed_df = listed_view.udf(input_df)
    wanted_df = pd.DataFrame(
        {"A": [11, 12, 13], "B": [20, 30, 40], "C": [20.2, 230.0, 34.0]}
    )
    assert transformed_df.equals(wanted_df)
74 changes: 73 additions & 1 deletion sdk/python/tests/unit/test_feature_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from feast.entity import Entity
from feast.field import Field
from feast.infra.offline_stores.file_source import FileSource
from feast.stream_feature_view import StreamFeatureView
from feast.stream_feature_view import StreamFeatureView, stream_feature_view
from feast.types import Float32


Expand Down Expand Up @@ -129,3 +129,75 @@ def test_stream_feature_view_serialization():

new_sfv = StreamFeatureView.from_proto(sfv_proto=sfv_proto)
assert new_sfv == sfv


def test_stream_feature_view_udfs():
    """
    Round-trip a udf-backed StreamFeatureView through to_proto/from_proto and
    check that the udf on the rebuilt view still transforms a pandas DataFrame.
    """
    import pandas as pd

    driver = Entity(name="driver_entity", join_keys=["test_key"])
    kafka_source = KafkaSource(
        name="kafka",
        timestamp_field="event_timestamp",
        bootstrap_servers="",
        message_format=AvroFormat(""),
        topic="topic",
        batch_source=FileSource(path="some path"),
    )
    max_over_one_day = Aggregation(
        column="dummy_field", function="max", time_window=timedelta(days=1),
    )

    @stream_feature_view(
        entities=[driver],
        ttl=timedelta(days=30),
        owner="test@example.com",
        online=True,
        schema=[Field(name="dummy_field", dtype=Float32)],
        description="desc",
        aggregations=[max_over_one_day],
        timestamp_field="event_timestamp",
        source=kafka_source,
    )
    def pandas_udf(pandas_df):
        # The udf carries its own import so it stays usable after deserialization.
        import pandas as pd

        assert type(pandas_df) == pd.DataFrame
        df = pandas_df.transform(lambda x: x + 10, axis=1)
        return df

    input_df = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})

    # Serialize the decorated view and rebuild it from the proto.
    serialized = pandas_udf.to_proto()
    rebuilt_view = StreamFeatureView.from_proto(serialized)

    transformed_df = rebuilt_view.udf(input_df)
    wanted_df = pd.DataFrame({"A": [11, 12, 13], "B": [20, 30, 40]})

    assert transformed_df.equals(wanted_df)


def test_stream_feature_view_initialization_with_optional_fields_omitted():
    """
    Build a StreamFeatureView without ttl, owner, online, aggregations, mode
    or udf, and check that it survives a to_proto/from_proto round trip
    unchanged.
    """
    driver = Entity(name="driver_entity", join_keys=["test_key"])
    kafka_source = KafkaSource(
        name="kafka",
        timestamp_field="event_timestamp",
        bootstrap_servers="",
        message_format=AvroFormat(""),
        topic="topic",
        batch_source=FileSource(path="some path"),
    )

    original_view = StreamFeatureView(
        name="test kafka stream feature view",
        entities=[driver],
        schema=[],
        description="desc",
        timestamp_field="event_timestamp",
        source=kafka_source,
        tags={},
    )

    # Serialize, rebuild, and compare against the view we started from.
    rebuilt_view = StreamFeatureView.from_proto(sfv_proto=original_view.to_proto())
    assert rebuilt_view == original_view

0 comments on commit 331a214

Please sign in to comment.