Skip to content

Commit

Permalink
fix: CI unittest warnings (#4006)
Browse files Browse the repository at this point in the history
fix ci unittest warnings

Signed-off-by: tokoko <togurg14@freeuni.edu.ge>
  • Loading branch information
tokoko authored Mar 12, 2024
1 parent ee4c4f1 commit 0441b8b
Show file tree
Hide file tree
Showing 11 changed files with 46 additions and 30 deletions.
4 changes: 2 additions & 2 deletions sdk/python/feast/driver_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def create_driver_hourly_stats_df(drivers, start_date, end_date) -> pd.DataFrame
"event_timestamp": [
pd.Timestamp(dt, unit="ms", tz="UTC").round("ms")
for dt in pd.date_range(
start=start_date, end=end_date, freq="1H", inclusive="left"
start=start_date, end=end_date, freq="1h", inclusive="left"
)
]
# include a fixed timestamp for get_historical_features in the quickstart
Expand Down Expand Up @@ -209,7 +209,7 @@ def create_location_stats_df(locations, start_date, end_date) -> pd.DataFrame:
"event_timestamp": [
pd.Timestamp(dt, unit="ms", tz="UTC").round("ms")
for dt in pd.date_range(
start=start_date, end=end_date, freq="1H", inclusive="left"
start=start_date, end=end_date, freq="1h", inclusive="left"
)
]
}
Expand Down
19 changes: 14 additions & 5 deletions sdk/python/feast/infra/offline_stores/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import dask
import dask.dataframe as dd
Expand Down Expand Up @@ -38,10 +38,7 @@
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.saved_dataset import SavedDatasetStorage
from feast.usage import log_exceptions_and_usage
from feast.utils import (
_get_requested_feature_views_to_features_dict,
_run_dask_field_mapping,
)
from feast.utils import _get_requested_feature_views_to_features_dict

# FileRetrievalJob will cast string objects to string[pyarrow] from dask version 2023.7.1
# This is not the desired behavior for our use case, so we set the convert-string option to False
Expand Down Expand Up @@ -512,6 +509,18 @@ def _read_datasource(data_source) -> dd.DataFrame:
)


def _run_dask_field_mapping(
table: dd.DataFrame,
field_mapping: Dict[str, str],
):
if field_mapping:
# run field mapping in the forward direction
table = table.rename(columns=field_mapping)
table = table.persist()

return table


def _field_mapping(
df_to_join: dd.DataFrame,
feature_view: FeatureView,
Expand Down
10 changes: 9 additions & 1 deletion sdk/python/feast/infra/offline_stores/file_source.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Callable, Dict, Iterable, List, Optional, Tuple

import pyarrow
from packaging import version
from pyarrow._fs import FileSystem
from pyarrow._s3fs import S3FileSystem
from pyarrow.parquet import ParquetDataset
Expand Down Expand Up @@ -158,7 +160,13 @@ def get_table_column_names_and_types(
# Adding support for different file format path
# based on S3 filesystem
if filesystem is None:
schema = ParquetDataset(path, use_legacy_dataset=False).schema
kwargs = (
{"use_legacy_dataset": False}
if version.parse(pyarrow.__version__) < version.parse("15.0.0")
else {}
)

schema = ParquetDataset(path, **kwargs).schema
if hasattr(schema, "names") and hasattr(schema, "types"):
# Newer versions of pyarrow doesn't have this method,
# but this field is good enough.
Expand Down
3 changes: 0 additions & 3 deletions sdk/python/feast/on_demand_pandas_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ def __eq__(self, other):
"Comparisons should only involve OnDemandPandasTransformation class objects."
)

if not super().__eq__(other):
return False

if (
self.udf_string != other.udf_string
or self.udf.__code__.co_code != other.udf.__code__.co_code
Expand Down
13 changes: 0 additions & 13 deletions sdk/python/feast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import pandas as pd
import pyarrow
from dask import dataframe as dd
from dateutil.tz import tzlocal
from pytz import utc

Expand Down Expand Up @@ -174,18 +173,6 @@ def _run_pyarrow_field_mapping(
return table


def _run_dask_field_mapping(
table: dd.DataFrame,
field_mapping: Dict[str, str],
):
if field_mapping:
# run field mapping in the forward direction
table = table.rename(columns=field_mapping)
table = table.persist()

return table


def _coerce_datetime(ts):
"""
Depending on underlying time resolution, arrow to_pydict() sometimes returns pd
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def test_offline_write_batch(
s3_staging_location="s3://bucket/path",
workgroup="",
),
entity_key_serialization_version=2,
)

batch_source = RedshiftSource(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def retrieval_job(request):
provider="snowflake.offline",
online_store=SqliteOnlineStoreConfig(type="sqlite"),
offline_store=offline_store_config,
entity_key_serialization_version=2,
),
full_feature_names=True,
on_demand_feature_views=[],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def test_apply_feature_view_with_inline_batch_source(
driver_fv = FeatureView(
name="driver_fv",
entities=[entity],
schema=[Field(name="test_key", dtype=Int64)],
source=file_source,
)

Expand Down Expand Up @@ -178,6 +179,7 @@ def test_apply_feature_view_with_inline_stream_source(
driver_fv = FeatureView(
name="driver_fv",
entities=[entity],
schema=[Field(name="test_key", dtype=Int64)],
source=stream_source,
)

Expand Down Expand Up @@ -332,6 +334,7 @@ def test_apply_conflicting_feature_view_names(feature_store_with_local_registry)
driver_stats = FeatureView(
name="driver_hourly_stats",
entities=[driver],
schema=[Field(name="driver_id", dtype=Int64)],
ttl=timedelta(seconds=10),
online=False,
source=FileSource(path="driver_stats.parquet"),
Expand All @@ -341,6 +344,7 @@ def test_apply_conflicting_feature_view_names(feature_store_with_local_registry)
customer_stats = FeatureView(
name="DRIVER_HOURLY_STATS",
entities=[customer],
schema=[Field(name="customer_id", dtype=Int64)],
ttl=timedelta(seconds=10),
online=False,
source=FileSource(path="customer_stats.parquet"),
Expand Down
2 changes: 2 additions & 0 deletions sdk/python/tests/unit/test_on_demand_feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import pandas as pd
import pytest

from feast.feature_view import FeatureView
from feast.field import Field
Expand All @@ -38,6 +39,7 @@ def udf2(features_df: pd.DataFrame) -> pd.DataFrame:
return df


@pytest.mark.filterwarnings("ignore:udf and udf_string parameters are deprecated")
def test_hash():
file_source = FileSource(name="my-file-source", path="test.parquet")
feature_view = FeatureView(
Expand Down
11 changes: 9 additions & 2 deletions sdk/python/tests/unit/test_sql_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def mysql_registry():
container.start()

# The log string uses '8.0.*' since the version might be changed as new Docker images are pushed.
log_string_to_wait_for = "/usr/sbin/mysqld: ready for connections. Version: '(\d+(\.\d+){1,2})' socket: '/var/run/mysqld/mysqld.sock' port: 3306" # noqa: W605
log_string_to_wait_for = "/usr/sbin/mysqld: ready for connections. Version: '(\\d+(\\.\\d+){1,2})' socket: '/var/run/mysqld/mysqld.sock' port: 3306" # noqa: W605
waited = wait_for_logs(
container=container,
predicate=log_string_to_wait_for,
Expand Down Expand Up @@ -218,6 +218,7 @@ def test_apply_feature_view_success(sql_registry):
fv1 = FeatureView(
name="my_feature_view_1",
schema=[
Field(name="test", dtype=Int64),
Field(name="fs1_my_feature_1", dtype=Int64),
Field(name="fs1_my_feature_2", dtype=String),
Field(name="fs1_my_feature_3", dtype=Array(String)),
Expand Down Expand Up @@ -313,6 +314,7 @@ def test_apply_on_demand_feature_view_success(sql_registry):
entities=[driver()],
ttl=timedelta(seconds=8640000000),
schema=[
Field(name="driver_id", dtype=Int64),
Field(name="daily_miles_driven", dtype=Float32),
Field(name="lat", dtype=Float32),
Field(name="lon", dtype=Float32),
Expand Down Expand Up @@ -403,7 +405,10 @@ def test_modify_feature_views_success(sql_registry):

fv1 = FeatureView(
name="my_feature_view_1",
schema=[Field(name="fs1_my_feature_1", dtype=Int64)],
schema=[
Field(name="test", dtype=Int64),
Field(name="fs1_my_feature_1", dtype=Int64),
],
entities=[entity],
tags={"team": "matchmaking"},
source=batch_source,
Expand Down Expand Up @@ -527,6 +532,7 @@ def test_apply_data_source(sql_registry):
fv1 = FeatureView(
name="my_feature_view_1",
schema=[
Field(name="test", dtype=Int64),
Field(name="fs1_my_feature_1", dtype=Int64),
Field(name="fs1_my_feature_2", dtype=String),
Field(name="fs1_my_feature_3", dtype=Array(String)),
Expand Down Expand Up @@ -596,6 +602,7 @@ def test_registry_cache(sql_registry):
fv1 = FeatureView(
name="my_feature_view_1",
schema=[
Field(name="test", dtype=Int64),
Field(name="fs1_my_feature_1", dtype=Int64),
Field(name="fs1_my_feature_2", dtype=String),
Field(name="fs1_my_feature_3", dtype=Array(String)),
Expand Down
8 changes: 4 additions & 4 deletions sdk/python/tests/utils/test_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import pytest
import warnings


def no_warnings(func):
def wrapper_no_warnings(*args, **kwargs):
with pytest.warns(None) as warnings:
with warnings.catch_warnings(record=True) as record:
func(*args, **kwargs)

if len(warnings) > 0:
if len(record) > 0:
raise AssertionError(
"Warnings were raised: " + ", ".join([str(w) for w in warnings])
"Warnings were raised: " + ", ".join([str(w) for w in record])
)

return wrapper_no_warnings

0 comments on commit 0441b8b

Please sign in to comment.