feat: Upgrade the pydantic from v1 to V2 #3942

Closed · wants to merge 15 commits
Changes from 5 commits
18 changes: 9 additions & 9 deletions sdk/python/feast/data_source.py
@@ -485,12 +485,12 @@ def to_proto(self) -> DataSourceProto:
         return data_source_proto

     def validate(self, config: RepoConfig):
-        pass
+        raise NotImplementedError

     def get_table_column_names_and_types(
         self, config: RepoConfig
     ) -> Iterable[Tuple[str, str]]:
-        pass
+        raise NotImplementedError

     @staticmethod
     def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
@@ -534,12 +534,12 @@ def __init__(
         self.schema = schema

     def validate(self, config: RepoConfig):
-        pass
+        raise NotImplementedError

     def get_table_column_names_and_types(
         self, config: RepoConfig
     ) -> Iterable[Tuple[str, str]]:
-        pass
+        raise NotImplementedError

     def __eq__(self, other):
         if not isinstance(other, RequestSource):
@@ -610,12 +610,12 @@ def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
 @typechecked
 class KinesisSource(DataSource):
     def validate(self, config: RepoConfig):
-        pass
+        raise NotImplementedError

     def get_table_column_names_and_types(
         self, config: RepoConfig
     ) -> Iterable[Tuple[str, str]]:
-        pass
+        raise NotImplementedError

     @staticmethod
     def from_proto(data_source: DataSourceProto):
@@ -639,7 +639,7 @@ def from_proto(data_source: DataSourceProto):

     @staticmethod
     def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
-        pass
+        raise NotImplementedError

     def get_table_query_string(self) -> str:
         raise NotImplementedError
@@ -772,12 +772,12 @@ def __hash__(self):
         return super().__hash__()

     def validate(self, config: RepoConfig):
-        pass
+        raise NotImplementedError

     def get_table_column_names_and_types(
         self, config: RepoConfig
     ) -> Iterable[Tuple[str, str]]:
-        pass
+        raise NotImplementedError

     @staticmethod
     def from_proto(data_source: DataSourceProto):
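Note on the pattern above: every stub that used to `pass` now raises `NotImplementedError`, so an unimplemented source fails loudly instead of silently returning None. A minimal sketch of the difference, with hypothetical class names (not Feast code):

    class Source:
        def validate(self, config) -> None:
            # Before: `pass` meant a subclass that forgot to override this
            # silently "passed" validation. Now the omission is an error.
            raise NotImplementedError

    class CsvSource(Source):
        def validate(self, config) -> None:
            print("validated")  # a real implementation would inspect config

    CsvSource().validate(None)   # prints "validated"
    # Source().validate(None)    # raises NotImplementedError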
2 changes: 1 addition & 1 deletion sdk/python/feast/feature_service.py
@@ -56,7 +56,7 @@ def __init__(
         *,
         name: str,
         features: List[Union[FeatureView, OnDemandFeatureView]],
-        tags: Dict[str, str] = None,
+        tags: Optional[Dict[str, str]] = None,
         description: str = "",
         owner: str = "",
         logging_config: Optional[LoggingConfig] = None,
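The `x: Dict[...] = None` to `x: Optional[Dict[...]] = None` change repeats across this PR. It targets implicit Optional, which pydantic v1 allowed and v2 dropped: in v1 a `None` default made a field implicitly Optional, while v2 validates against the annotation as written (for plain constructors like this one the motivation is the type checker rather than runtime validation). A minimal sketch with a hypothetical model:

    from typing import Dict, Optional
    from pydantic import BaseModel

    class Tagged(BaseModel):
        # v1: a None default made this implicitly Optional, so
        # Tagged(tags=None) was accepted. In v2 it raises ValidationError:
        # tags: Dict[str, str] = None
        tags: Optional[Dict[str, str]] = None  # explicit, as in this PR

    print(Tagged(tags=None))  # tags=None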
2 changes: 1 addition & 1 deletion sdk/python/feast/feature_view.py
@@ -101,7 +101,7 @@ def __init__(
         name: str,
         source: DataSource,
         schema: Optional[List[Field]] = None,
-        entities: List[Entity] = None,
+        entities: Optional[List[Entity]] = None,
         ttl: Optional[timedelta] = timedelta(days=0),
         online: bool = True,
         description: str = "",
2 changes: 1 addition & 1 deletion sdk/python/feast/importer.py
@@ -7,7 +7,7 @@
 )


-def import_class(module_name: str, class_name: str, class_type: str = None):
+def import_class(module_name: str, class_name: str, class_type: str = ""):
     """
     Dynamically loads and returns a class from a module.
4 changes: 3 additions & 1 deletion sdk/python/feast/infra/contrib/spark_kafka_processor.py
@@ -1,5 +1,5 @@
 from types import MethodType
-from typing import List, Optional
+from typing import List, Optional, no_type_check

 import pandas as pd
 from pyspark.sql import DataFrame, SparkSession
@@ -69,6 +69,8 @@ def ingest_stream_feature_view(self, to: PushMode = PushMode.ONLINE) -> None:
         online_store_query = self._write_stream_data(transformed_df, to)
         return online_store_query

+    # In __init__() (line 64), data_source is assigned the stream_source, which must be a KafkaSource (see the check on line 40).
+    @no_type_check
     def _ingest_stream_data(self) -> StreamTable:
         """Only supports json and avro formats currently."""
         if self.format == "json":
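`typing.no_type_check` only tells static type checkers to skip the decorated function; it has no runtime effect, which is why it suits a spot where the declared base type is narrower than what the code actually receives. A minimal sketch with hypothetical names (not Feast code):

    from typing import no_type_check

    @no_type_check
    def topic_of(source):
        # A checker would flag .topic as missing on the declared base type;
        # at runtime the concrete subtype is known to provide it.
        return source.topic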
6 changes: 5 additions & 1 deletion sdk/python/feast/infra/contrib/stream_processor.py
@@ -1,4 +1,4 @@
-from abc import ABC
+from abc import ABC, abstractmethod
 from types import MethodType
 from typing import TYPE_CHECKING, Optional
@@ -49,19 +49,22 @@ def __init__(
         self.sfv = sfv
         self.data_source = data_source

+    @abstractmethod
     def ingest_stream_feature_view(self, to: PushMode = PushMode.ONLINE) -> None:
         """
         Ingests data from the stream source attached to the stream feature view; transforms the data
         and then persists it to the online store and/or offline store, depending on the 'to' parameter.
         """
         pass

+    @abstractmethod
     def _ingest_stream_data(self) -> StreamTable:
         """
         Ingests data into a StreamTable.
         """
         pass

+    @abstractmethod
     def _construct_transformation_plan(self, table: StreamTable) -> StreamTable:
         """
         Applies transformations on top of StreamTable object. Since stream engines use lazy
@@ -70,6 +73,7 @@ def _construct_transformation_plan(self, table: StreamTable) -> StreamTable:
         """
         pass

+    @abstractmethod
     def _write_stream_data(self, table: StreamTable, to: PushMode) -> None:
         """
         Launches a job to persist stream data to the online store and/or offline store, depending
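Adding `@abstractmethod` turns these from silent stubs into an enforced contract: a concrete processor missing any of them cannot be instantiated at all. A minimal sketch with hypothetical classes:

    from abc import ABC, abstractmethod

    class Processor(ABC):
        @abstractmethod
        def ingest(self) -> None:
            """Ingest data from the attached stream source."""

    class SparkProcessor(Processor):
        def ingest(self) -> None:
            print("ingesting")

    SparkProcessor().ingest()  # ok
    # Processor()              # TypeError: abstract method 'ingest'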
@@ -1,5 +1,6 @@
+from typing import Literal
+
 from pydantic import StrictBool, StrictStr
-from pydantic.typing import Literal

 from feast.infra.feature_servers.base_config import BaseFeatureServerConfig
2 changes: 1 addition & 1 deletion sdk/python/feast/infra/feature_servers/base_config.py
@@ -30,5 +30,5 @@ class BaseFeatureServerConfig(FeastConfigBaseModel):
     enabled: StrictBool = False
     """Whether the feature server should be launched."""

-    feature_logging: Optional[FeatureLoggingConfig]
+    feature_logging: Optional[FeatureLoggingConfig] = None
     """ Feature logging configuration """
@@ -1,5 +1,6 @@
+from typing import Literal
+
 from pydantic import StrictBool
-from pydantic.typing import Literal

 from feast.infra.feature_servers.base_config import BaseFeatureServerConfig
@@ -1,4 +1,4 @@
-from pydantic.typing import Literal
+from typing import Literal

 from feast.infra.feature_servers.base_config import BaseFeatureServerConfig
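The recurring `pydantic.typing.Literal` to `typing.Literal` swap across these config modules is mechanical: `pydantic.typing` was a v1 compatibility shim and the module is gone in v2, while the standard-library form works under both. A minimal sketch of how these configs use it (the values are illustrative, not the real set):

    from typing import Literal
    from pydantic import BaseModel

    class ServerCfg(BaseModel):
        type: Literal["local", "gcp"] = "local"  # hypothetical discriminator values

    print(ServerCfg(type="gcp"))
    # ServerCfg(type="redis")  -> ValidationError: not a permitted value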
6 changes: 2 additions & 4 deletions sdk/python/feast/infra/materialization/snowflake_engine.py
@@ -7,7 +7,7 @@
 import click
 import pandas as pd
 from colorama import Fore, Style
-from pydantic import Field, StrictStr
+from pydantic import ConfigDict, Field, StrictStr
 from pytz import utc
 from tqdm import tqdm
@@ -72,9 +72,7 @@ class SnowflakeMaterializationEngineConfig(FeastConfigBaseModel):

     schema_: Optional[str] = Field("PUBLIC", alias="schema")
     """ Snowflake schema name """
-
-    class Config:
-        allow_population_by_field_name = True
+    model_config = ConfigDict(populate_by_name=True)


 @dataclass
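`class Config` to `model_config = ConfigDict(...)` is the canonical v2 translation of model configuration, and `allow_population_by_field_name` is the old name of `populate_by_name`. A minimal sketch of what the option buys, with a hypothetical model:

    from typing import Optional
    from pydantic import BaseModel, ConfigDict, Field

    class EngineCfg(BaseModel):
        model_config = ConfigDict(populate_by_name=True)

        # "schema" shadowed BaseModel.schema() in v1, hence the trailing
        # underscore plus an alias, mirroring the Snowflake config above.
        schema_: Optional[str] = Field("PUBLIC", alias="schema")

    print(EngineCfg(schema="ANALYTICS").schema_)   # populate by alias
    print(EngineCfg(schema_="ANALYTICS").schema_)  # populate by field name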
22 changes: 10 additions & 12 deletions sdk/python/feast/infra/offline_stores/bigquery.py
@@ -10,6 +10,7 @@
     Dict,
     Iterator,
     List,
+    Literal,
     Optional,
     Tuple,
     Union,
@@ -19,8 +20,7 @@
 import pandas as pd
 import pyarrow
 import pyarrow.parquet
-from pydantic import ConstrainedStr, StrictStr, validator
-from pydantic.typing import Literal
+from pydantic import StrictStr, field_validator
 from tenacity import Retrying, retry_if_exception_type, stop_after_delay, wait_fixed

 from feast import flags_helper
@@ -72,13 +72,6 @@ def get_http_client_info():
     return http_client_info.ClientInfo(user_agent=get_user_agent())


-class BigQueryTableCreateDisposition(ConstrainedStr):
-    """Custom constraint for table_create_disposition. To understand more, see:
-    https://cloud.google.com/bigquery/docs/reference/rest/v2/Job#JobConfigurationLoad.FIELDS.create_disposition"""
-
-    values = {"CREATE_NEVER", "CREATE_IF_NEEDED"}
-
-
 class BigQueryOfflineStoreConfig(FeastConfigBaseModel):
     """Offline store config for GCP BigQuery"""
@@ -102,10 +95,15 @@ class BigQueryOfflineStoreConfig(FeastConfigBaseModel):
     gcs_staging_location: Optional[str] = None
     """ (optional) GCS location used for offloading BigQuery results as parquet files."""

-    table_create_disposition: Optional[BigQueryTableCreateDisposition] = None
-    """ (optional) Specifies whether the job is allowed to create new tables. The default value is CREATE_IF_NEEDED."""
+    table_create_disposition: Literal[
+        "CREATE_NEVER", "CREATE_IF_NEEDED"
+    ] = "CREATE_IF_NEEDED"
+    """ (optional) Specifies whether the job is allowed to create new tables. The default value is CREATE_IF_NEEDED.
+    Custom constraint for table_create_disposition. To understand more, see:
+    https://cloud.google.com/bigquery/docs/reference/rest/v2/Job#JobConfigurationLoad.FIELDS.create_disposition
+    """

-    @validator("billing_project_id")
+    @field_validator("billing_project_id")
     def project_id_exists(cls, v, values, **kwargs):
         if v and not values["project_id"]:
             raise ValueError(
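Two v2 renames land in this file: `ConstrainedStr` subclasses are gone (a `Literal` covers this two-value case), and `validator` becomes `field_validator`. One thing worth double-checking in review: v2 validators no longer receive a `values` dict; previously validated fields are exposed on a `ValidationInfo` object, so the `values["project_id"]` lookup kept above may need a follow-up. A minimal sketch of the v2 idiom, with a hypothetical model:

    from typing import Literal, Optional
    from pydantic import BaseModel, ValidationInfo, field_validator

    class BQCfg(BaseModel):
        project_id: Optional[str] = None
        billing_project_id: Optional[str] = None
        table_create_disposition: Literal[
            "CREATE_NEVER", "CREATE_IF_NEEDED"
        ] = "CREATE_IF_NEEDED"

        @field_validator("billing_project_id")
        @classmethod
        def project_id_exists(cls, v: Optional[str], info: ValidationInfo):
            # Fields validated earlier (project_id is declared first) live on info.data.
            if v and not info.data.get("project_id"):
                raise ValueError("project_id is required when billing_project_id is set")
            return v

    print(BQCfg(project_id="p", billing_project_id="b").billing_project_id)  # b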
@@ -8,6 +8,7 @@
     Dict,
     Iterator,
     List,
+    Literal,
     Optional,
     Tuple,
     Union,

@@ -18,7 +19,6 @@
 import pyarrow
 import pyarrow as pa
 from pydantic import StrictStr
-from pydantic.typing import Literal
 from pytz import utc

 from feast import OnDemandFeatureView
@@ -297,9 +297,9 @@ class SavedDatasetAthenaStorage(SavedDatasetStorage):
     def __init__(
         self,
         table_ref: str,
-        query: str = None,
-        database: str = None,
-        data_source: str = None,
+        query: Optional[str] = None,
+        database: Optional[str] = None,
+        data_source: Optional[str] = None,
     ):
         self.athena_options = AthenaOptions(
             table=table_ref, query=query, database=database, data_source=data_source
@@ -43,15 +43,16 @@ def __init__(self, project_name: str, *args, **kwargs):
             workgroup=workgroup,
             s3_staging_location=f"s3://{bucket_name}/test_dir",
         )

     def create_data_source(
         self,
         df: pd.DataFrame,
         destination_name: str,
         suffix: Optional[str] = None,
-        timestamp_field="ts",
         event_timestamp_column="ts",
         created_timestamp_column="created_ts",
-        field_mapping: Dict[str, str] = None,
+        field_mapping: Optional[Dict[str, str]] = None,
+        timestamp_field: Optional[str] = "ts",
     ) -> DataSource:

         table_name = destination_name
@@ -3,15 +3,14 @@
 import warnings
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Set, Tuple, Union

 import numpy as np
 import pandas
 import pyarrow
 import pyarrow as pa
 import sqlalchemy
 from pydantic.types import StrictStr
-from pydantic.typing import Literal
 from sqlalchemy import create_engine
 from sqlalchemy.engine import Engine
 from sqlalchemy.orm import sessionmaker
@@ -32,7 +31,7 @@
 from feast.infra.provider import RetrievalJob
 from feast.infra.registry.base_registry import BaseRegistry
 from feast.on_demand_feature_view import OnDemandFeatureView
-from feast.repo_config import FeastBaseModel, RepoConfig
+from feast.repo_config import FeastConfigBaseModel, RepoConfig
 from feast.saved_dataset import SavedDatasetStorage
 from feast.type_map import pa_to_mssql_type
 from feast.usage import log_exceptions_and_usage
@@ -43,7 +42,7 @@
 EntitySchema = Dict[str, np.dtype]


-class MsSqlServerOfflineStoreConfig(FeastBaseModel):
+class MsSqlServerOfflineStoreConfig(FeastConfigBaseModel):
     """Offline store config for SQL Server"""

     type: Literal["mssql"] = "mssql"
@@ -1,4 +1,4 @@
-from typing import Dict, List
+from typing import Dict, List, Optional

 import pandas as pd
 import pytest
@@ -64,10 +64,10 @@ def create_data_source(
         self,
         df: pd.DataFrame,
         destination_name: str,
-        timestamp_field="ts",
+        event_timestamp_column="ts",
         created_timestamp_column="created_ts",
-        field_mapping: Dict[str, str] = None,
-        **kwargs,
+        field_mapping: Optional[Dict[str, str]] = None,
+        timestamp_field: Optional[str] = "ts",
     ) -> DataSource:
         # Make sure the field mapping is correct and convert the datetime datasources.
         if timestamp_field in df:
)

def create_saved_dataset_destination(self) -> SavedDatasetStorage:
pass
raise NotImplementedError

def get_prefixed_table_name(self, destination_name: str) -> str:
return f"{self.project_name}_{destination_name}"
Expand Down
@@ -9,6 +9,7 @@
     Iterator,
     KeysView,
     List,
+    Literal,
     Optional,
     Tuple,
     Union,

@@ -19,7 +20,6 @@
 import pyarrow as pa
 from jinja2 import BaseLoader, Environment
 from psycopg2 import sql
-from pydantic.typing import Literal
 from pytz import utc

 from feast.data_source import DataSource
@@ -82,10 +82,10 @@ def create_data_source(
         self,
         df: pd.DataFrame,
         destination_name: str,
         suffix: Optional[str] = None,
-        timestamp_field="ts",
         event_timestamp_column="ts",
         created_timestamp_column="created_ts",
-        field_mapping: Dict[str, str] = None,
+        field_mapping: Optional[Dict[str, str]] = None,
+        timestamp_field: Optional[str] = "ts",
     ) -> DataSource:
         destination_name = self.get_prefixed_table_name(destination_name)