Skip to content

Commit

Permalink
feat: Implement offline_write_batch for BigQuery and Snowflake (#2840)
Browse files Browse the repository at this point in the history
* Factor out Redshift pyarrow schema inference logic into helper method

Signed-off-by: Felix Wang <wangfelix98@gmail.com>

* Switch file offline store to use offline_utils for offline_write_batch

Signed-off-by: Felix Wang <wangfelix98@gmail.com>

* Implement offline_write_batch for bigquery

Signed-off-by: Felix Wang <wangfelix98@gmail.com>

* Implement offline_write_batch for snowflake

Signed-off-by: Felix Wang <wangfelix98@gmail.com>

* Enable bigquery and snowflake for test_push_features_and_read_from_offline_store test

Signed-off-by: Felix Wang <wangfelix98@gmail.com>

* Rename get_pyarrow_schema

Signed-off-by: Felix Wang <wangfelix98@gmail.com>
  • Loading branch information
felixwang9817 authored Jun 22, 2022
1 parent a88cd30 commit 97444e4
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 35 deletions.
55 changes: 55 additions & 0 deletions sdk/python/feast/infra/offline_stores/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import date, datetime, timedelta
from pathlib import Path
from typing import (
Any,
Callable,
ContextManager,
Dict,
Expand Down Expand Up @@ -303,6 +304,60 @@ def write_logged_features(
job_config=job_config,
)

@staticmethod
def offline_write_batch(
    config: RepoConfig,
    feature_view: FeatureView,
    table: pyarrow.Table,
    progress: Optional[Callable[[int], Any]],
):
    """
    Append a pyarrow table to the BigQuery table behind a feature view's batch source.

    The input table must have exactly the column names (in the same order) that
    the batch source reports; if only the types differ, the table is cast to the
    expected schema before being loaded.

    Args:
        config: Repo configuration object; its offline store must be a
            BigQueryOfflineStoreConfig.
        feature_view: FeatureView whose batch source (a BigQuerySource) is the
            destination of the write.
        table: pyarrow table holding the rows to append.
        progress: Accepted for interface compatibility; not used by this
            implementation.

    Raises:
        ValueError: If the feature view has no batch source, if the offline store
            config or batch source is not of the BigQuery type, or if the input
            table's column names do not match the batch source's columns.
    """
    if not feature_view.batch_source:
        raise ValueError(
            "feature view does not have a batch source to persist offline data"
        )
    if not isinstance(config.offline_store, BigQueryOfflineStoreConfig):
        raise ValueError(
            f"offline store config is of type {type(config.offline_store)} when bigquery type required"
        )
    if not isinstance(feature_view.batch_source, BigQuerySource):
        raise ValueError(
            f"feature view batch source is {type(feature_view.batch_source)} not bigquery source"
        )

    pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
        config, feature_view.batch_source
    )
    if column_names != table.column_names:
        # Report what was actually passed in (the input table's schema and
        # columns) next to what the batch source expects.
        raise ValueError(
            f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. "
            f"The columns are expected to be (in this order): {column_names}."
        )

    # Column names already match; align the types with the batch source's schema.
    if table.schema != pa_schema:
        table = table.cast(pa_schema)

    client = _get_bigquery_client(
        project=config.offline_store.project_id,
        location=config.offline_store.location,
    )

    job_config = bigquery.LoadJobConfig(
        source_format=bigquery.SourceFormat.PARQUET,
        schema=arrow_schema_to_bq_schema(pa_schema),
        write_disposition="WRITE_APPEND",  # Default but included for clarity
    )

    with tempfile.TemporaryFile() as parquet_temp_file:
        pyarrow.parquet.write_table(table=table, where=parquet_temp_file)

        # Rewind so the load job reads the parquet bytes from the beginning.
        parquet_temp_file.seek(0)

        load_job = client.load_table_from_file(
            file_obj=parquet_temp_file,
            destination=feature_view.batch_source.table,
            job_config=job_config,
        )
        # Block until the load job finishes so that a failed load surfaces here
        # instead of being silently dropped with the discarded job handle.
        load_job.result()


class BigQueryRetrievalJob(RetrievalJob):
def __init__(
Expand Down
28 changes: 18 additions & 10 deletions sdk/python/feast/infra/offline_stores/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)
from feast.infra.offline_stores.offline_utils import (
DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL,
get_pyarrow_schema_from_batch_source,
)
from feast.infra.provider import (
_get_requested_feature_views_to_features_dict,
Expand Down Expand Up @@ -408,7 +409,7 @@ def write_logged_features(
def offline_write_batch(
config: RepoConfig,
feature_view: FeatureView,
data: pyarrow.Table,
table: pyarrow.Table,
progress: Optional[Callable[[int], Any]],
):
if not feature_view.batch_source:
Expand All @@ -423,20 +424,27 @@ def offline_write_batch(
raise ValueError(
f"feature view batch source is {type(feature_view.batch_source)} not file source"
)

pa_schema, column_names = get_pyarrow_schema_from_batch_source(
config, feature_view.batch_source
)
if column_names != table.column_names:
raise ValueError(
f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
f"The columns are expected to be (in this order): {column_names}."
)

file_options = feature_view.batch_source.file_options
filesystem, path = FileSource.create_filesystem_and_path(
file_options.uri, file_options.s3_endpoint_override
)

prev_table = pyarrow.parquet.read_table(path, memory_map=True)
if prev_table.column_names != data.column_names:
raise ValueError(
f"Input dataframe has incorrect schema or wrong order, expected columns are: {prev_table.column_names}"
)
if data.schema != prev_table.schema:
data = data.cast(prev_table.schema)
new_table = pyarrow.concat_tables([data, prev_table])
writer = pyarrow.parquet.ParquetWriter(path, data.schema, filesystem=filesystem)
if table.schema != prev_table.schema:
table = table.cast(prev_table.schema)
new_table = pyarrow.concat_tables([table, prev_table])
writer = pyarrow.parquet.ParquetWriter(
path, table.schema, filesystem=filesystem
)
writer.write_table(new_table)
writer.close()

Expand Down
6 changes: 3 additions & 3 deletions sdk/python/feast/infra/offline_stores/offline_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def write_logged_features(
def offline_write_batch(
config: RepoConfig,
feature_view: FeatureView,
data: pyarrow.Table,
table: pyarrow.Table,
progress: Optional[Callable[[int], Any]],
):
"""
Expand All @@ -286,8 +286,8 @@ def offline_write_batch(
Args:
config: Repo configuration object
table: FeatureView to write the data to.
data: pyarrow table containing feature data and timestamp column for historical feature retrieval
feature_view: FeatureView to write the data to.
table: pyarrow table containing feature data and timestamp column for historical feature retrieval
progress: Optional function to be called once every mini-batch of rows is written to
the online store. Can be used to display progress.
"""
Expand Down
26 changes: 26 additions & 0 deletions sdk/python/feast/infra/offline_stores/offline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

import numpy as np
import pandas as pd
import pyarrow as pa
from jinja2 import BaseLoader, Environment
from pandas import Timestamp

from feast.data_source import DataSource
from feast.errors import (
EntityTimestampInferenceException,
FeastEntityDFMissingColumnsError,
Expand All @@ -17,6 +19,8 @@
from feast.infra.offline_stores.offline_store import OfflineStore
from feast.infra.provider import _get_requested_feature_views_to_features_dict
from feast.registry import BaseRegistry
from feast.repo_config import RepoConfig
from feast.type_map import feast_value_type_to_pa
from feast.utils import to_naive_utc

DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL = "event_timestamp"
Expand Down Expand Up @@ -217,3 +221,25 @@ def get_offline_store_from_config(offline_store_config: Any) -> OfflineStore:
class_name = qualified_name.replace("Config", "")
offline_store_class = import_class(module_name, class_name, "OfflineStore")
return offline_store_class()


def get_pyarrow_schema_from_batch_source(
    config: RepoConfig, batch_source: DataSource
) -> Tuple[pa.Schema, List[str]]:
    """
    Build the pyarrow schema and the ordered column names for a batch source.

    Each (column name, source type) pair reported by the batch source is mapped
    to a Feast value type using the source's own type mapper, and from there to
    a pyarrow type.

    Args:
        config: Repo configuration object, passed through to the batch source.
        batch_source: Data source whose columns and types are inspected.

    Returns:
        A tuple of the pyarrow schema and the list of column names, both in the
        order the batch source reports its columns.
    """
    # Single pass over the reported columns; names are derived from the built
    # fields afterwards so the source iterable is consumed only once.
    fields = [
        (
            name,
            feast_value_type_to_pa(
                batch_source.source_datatype_to_feast_value_type()(source_type)
            ),
        )
        for name, source_type in batch_source.get_table_column_names_and_types(config)
    ]
    return pa.schema(fields), [name for name, _ in fields]
27 changes: 8 additions & 19 deletions sdk/python/feast/infra/offline_stores/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
from feast.registry import BaseRegistry
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.saved_dataset import SavedDatasetStorage
from feast.type_map import feast_value_type_to_pa, redshift_to_feast_value_type
from feast.usage import log_exceptions_and_usage


Expand Down Expand Up @@ -318,33 +317,23 @@ def offline_write_batch(
raise ValueError(
f"feature view batch source is {type(feature_view.batch_source)} not redshift source"
)
redshift_options = feature_view.batch_source.redshift_options
redshift_client = aws_utils.get_redshift_data_client(
config.offline_store.region
)

column_name_to_type = feature_view.batch_source.get_table_column_names_and_types(
config
pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
config, feature_view.batch_source
)
pa_schema_list = []
column_names = []
for column_name, redshift_type in column_name_to_type:
pa_schema_list.append(
(
column_name,
feast_value_type_to_pa(redshift_to_feast_value_type(redshift_type)),
)
)
column_names.append(column_name)
pa_schema = pa.schema(pa_schema_list)
if column_names != table.column_names:
raise ValueError(
f"Input dataframe has incorrect schema or wrong order, expected columns are: {column_names}"
f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
f"The columns are expected to be (in this order): {column_names}."
)

if table.schema != pa_schema:
table = table.cast(pa_schema)

redshift_options = feature_view.batch_source.redshift_options
redshift_client = aws_utils.get_redshift_data_client(
config.offline_store.region
)
s3_resource = aws_utils.get_s3_resource(config.offline_store.region)

aws_utils.upload_arrow_table_to_redshift(
Expand Down
42 changes: 42 additions & 0 deletions sdk/python/feast/infra/offline_stores/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import datetime
from pathlib import Path
from typing import (
Any,
Callable,
ContextManager,
Dict,
Expand Down Expand Up @@ -306,6 +307,47 @@ def write_logged_features(
auto_create_table=True,
)

@staticmethod
def offline_write_batch(
    config: RepoConfig,
    feature_view: FeatureView,
    table: pyarrow.Table,
    progress: Optional[Callable[[int], Any]],
):
    """
    Write a pyarrow table to the Snowflake table behind a feature view's batch source.

    The input table must have exactly the column names (in the same order) that
    the batch source reports; if only the types differ, the table is cast to the
    expected schema before being written.

    Args:
        config: Repo configuration object; its offline store must be a
            SnowflakeOfflineStoreConfig.
        feature_view: FeatureView whose batch source (a SnowflakeSource) is the
            destination of the write.
        table: pyarrow table holding the rows to write.
        progress: Accepted for interface compatibility; not used by this
            implementation.

    Raises:
        ValueError: If the feature view has no batch source, if the offline store
            config or batch source is not of the Snowflake type, or if the input
            table's column names do not match the batch source's columns.
    """
    if not feature_view.batch_source:
        raise ValueError(
            "feature view does not have a batch source to persist offline data"
        )
    if not isinstance(config.offline_store, SnowflakeOfflineStoreConfig):
        raise ValueError(
            f"offline store config is of type {type(config.offline_store)} when snowflake type required"
        )
    if not isinstance(feature_view.batch_source, SnowflakeSource):
        raise ValueError(
            f"feature view batch source is {type(feature_view.batch_source)} not snowflake source"
        )

    pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
        config, feature_view.batch_source
    )
    if column_names != table.column_names:
        # Report what was actually passed in (the input table's schema and
        # columns) next to what the batch source expects.
        raise ValueError(
            f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. "
            f"The columns are expected to be (in this order): {column_names}."
        )

    # Column names already match; align the types with the batch source's schema.
    if table.schema != pa_schema:
        table = table.cast(pa_schema)

    snowflake_conn = get_snowflake_conn(config.offline_store)

    # write_pandas takes a DataFrame, so the arrow table is converted first.
    write_pandas(
        snowflake_conn,
        table.to_pandas(),
        table_name=feature_view.batch_source.table,
        auto_create_table=True,
    )


class SnowflakeRetrievalJob(RetrievalJob):
def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@

OFFLINE_STORE_TO_PROVIDER_CONFIG: Dict[str, DataSourceCreator] = {
"file": ("local", FileDataSourceCreator),
"gcp": ("gcp", BigQueryDataSourceCreator),
"bigquery": ("gcp", BigQueryDataSourceCreator),
"redshift": ("aws", RedshiftDataSourceCreator),
"snowflake": ("aws", RedshiftDataSourceCreator),
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_writing_incorrect_schema_fails(environment, universal_data_sources):


@pytest.mark.integration
@pytest.mark.universal_offline_stores(only=["file", "redshift"])
@pytest.mark.universal_offline_stores
@pytest.mark.universal_online_stores(only=["sqlite"])
def test_writing_consecutively_to_offline_store(environment, universal_data_sources):
store = environment.feature_store
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


@pytest.mark.integration
@pytest.mark.universal_offline_stores(only=["file", "redshift"])
@pytest.mark.universal_offline_stores
@pytest.mark.universal_online_stores(only=["sqlite"])
def test_push_features_and_read_from_offline_store(environment, universal_data_sources):
store = environment.feature_store
Expand Down

0 comments on commit 97444e4

Please sign in to comment.