Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Added Offline Store Arrow client errors handler #4524

Merged
merged 2 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions sdk/python/feast/arrow_error_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import logging
from functools import wraps

import pyarrow.flight as fl

from feast.errors import FeastError

logger = logging.getLogger(__name__)


def arrow_client_error_handling_decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
mapped_error = FeastError.from_error_detail(_get_exception_data(e.args[0]))
if mapped_error is not None:
raise mapped_error
raise e

return wrapper


def arrow_server_error_handling_decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
if isinstance(e, FeastError):
raise fl.FlightError(e.to_error_detail())

return wrapper


def _get_exception_data(except_str) -> str:
tmihalac marked this conversation as resolved.
Show resolved Hide resolved
substring = "Flight error: "

# Find the starting index of the substring
position = except_str.find(substring)
end_json_index = except_str.find("}")

if position != -1 and end_json_index != -1:
# Extract the part of the string after the substring
result = except_str[position + len(substring) : end_json_index + 1]
return result

return ""
68 changes: 60 additions & 8 deletions sdk/python/feast/infra/offline_stores/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
import pyarrow as pa
import pyarrow.flight as fl
import pyarrow.parquet
from pyarrow import Schema
from pyarrow._flight import FlightCallOptions, FlightDescriptor, Ticket
from pydantic import StrictInt, StrictStr

from feast import OnDemandFeatureView
from feast.arrow_error_handler import arrow_client_error_handling_decorator
from feast.data_source import DataSource
from feast.feature_logging import (
FeatureServiceLoggingSource,
Expand All @@ -27,15 +30,54 @@
RetrievalMetadata,
)
from feast.infra.registry.base_registry import BaseRegistry
from feast.permissions.auth.auth_type import AuthType
from feast.permissions.auth_model import AuthConfig
from feast.permissions.client.arrow_flight_auth_interceptor import (
build_arrow_flight_client,
FlightAuthInterceptorFactory,
)
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.saved_dataset import SavedDatasetStorage

logger = logging.getLogger(__name__)


class FeastFlightClient(fl.FlightClient):
@arrow_client_error_handling_decorator
def get_flight_info(
self, descriptor: FlightDescriptor, options: FlightCallOptions = None
):
return super().get_flight_info(descriptor, options)

@arrow_client_error_handling_decorator
def do_get(self, ticket: Ticket, options: FlightCallOptions = None):
return super().do_get(ticket, options)

@arrow_client_error_handling_decorator
def do_put(
self,
descriptor: FlightDescriptor,
schema: Schema,
options: FlightCallOptions = None,
):
return super().do_put(descriptor, schema, options)

@arrow_client_error_handling_decorator
def list_flights(self, criteria: bytes = b"", options: FlightCallOptions = None):
return super().list_flights(criteria, options)

@arrow_client_error_handling_decorator
def list_actions(self, options: FlightCallOptions = None):
return super().list_actions(options)


def build_arrow_flight_client(host: str, port, auth_config: AuthConfig):
if auth_config.type != AuthType.NONE.value:
middlewares = [FlightAuthInterceptorFactory(auth_config)]
return FeastFlightClient(f"grpc://{host}:{port}", middleware=middlewares)

return FeastFlightClient(f"grpc://{host}:{port}")


class RemoteOfflineStoreConfig(FeastConfigBaseModel):
type: Literal["remote"] = "remote"
host: StrictStr
Expand All @@ -48,7 +90,7 @@ class RemoteOfflineStoreConfig(FeastConfigBaseModel):
class RemoteRetrievalJob(RetrievalJob):
def __init__(
self,
client: fl.FlightClient,
client: FeastFlightClient,
api: str,
api_parameters: Dict[str, Any],
entity_df: Union[pd.DataFrame, str] = None,
Expand Down Expand Up @@ -338,7 +380,7 @@ def _send_retrieve_remote(
api_parameters: Dict[str, Any],
entity_df: Union[pd.DataFrame, str],
table: pa.Table,
client: fl.FlightClient,
client: FeastFlightClient,
):
command_descriptor = _call_put(
api,
Expand All @@ -351,19 +393,19 @@ def _send_retrieve_remote(


def _call_get(
client: fl.FlightClient,
client: FeastFlightClient,
command_descriptor: fl.FlightDescriptor,
):
flight = client.get_flight_info(command_descriptor)
ticket = flight.endpoints[0].ticket
reader = client.do_get(ticket)
return reader.read_all()
return read_all(reader)


def _call_put(
api: str,
api_parameters: Dict[str, Any],
client: fl.FlightClient,
client: FeastFlightClient,
entity_df: Union[pd.DataFrame, str],
table: pa.Table,
):
Expand Down Expand Up @@ -391,7 +433,7 @@ def _put_parameters(
command_descriptor: fl.FlightDescriptor,
entity_df: Union[pd.DataFrame, str],
table: pa.Table,
client: fl.FlightClient,
client: FeastFlightClient,
):
updatedTable: pa.Table

Expand All @@ -404,10 +446,20 @@ def _put_parameters(

writer, _ = client.do_put(command_descriptor, updatedTable.schema)

writer.write_table(updatedTable)
write_table(writer, updatedTable)


@arrow_client_error_handling_decorator
def write_table(writer, updated_table: pa.Table):
writer.write_table(updated_table)
writer.close()


@arrow_client_error_handling_decorator
def read_all(reader):
return reader.read_all()


def _create_empty_table():
schema = pa.schema(
{
Expand Down
54 changes: 37 additions & 17 deletions sdk/python/feast/offline_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,18 @@
import pyarrow.flight as fl

from feast import FeatureStore, FeatureView, utils
from feast.arrow_error_handler import arrow_server_error_handling_decorator
from feast.feature_logging import FeatureServiceLoggingSource
from feast.feature_view import DUMMY_ENTITY_NAME
from feast.infra.offline_stores.offline_utils import get_offline_store_from_config
from feast.permissions.action import AuthzedAction
from feast.permissions.security_manager import assert_permissions
from feast.permissions.server.arrow import (
arrowflight_middleware,
AuthorizationMiddlewareFactory,
inject_user_details_decorator,
)
from feast.permissions.server.utils import (
AuthManagerType,
ServerType,
init_auth_manager,
init_security_manager,
Expand All @@ -34,7 +36,7 @@ class OfflineServer(fl.FlightServerBase):
def __init__(self, store: FeatureStore, location: str, **kwargs):
super(OfflineServer, self).__init__(
location,
middleware=arrowflight_middleware(
middleware=self.arrow_flight_auth_middleware(
str_to_auth_manager_type(store.config.auth_config.type)
),
**kwargs,
Expand All @@ -45,6 +47,25 @@ def __init__(self, store: FeatureStore, location: str, **kwargs):
self.store = store
self.offline_store = get_offline_store_from_config(store.config.offline_store)

def arrow_flight_auth_middleware(
self,
auth_type: AuthManagerType,
) -> dict[str, fl.ServerMiddlewareFactory]:
"""
A dictionary with the configured middlewares to support extracting the user details when the authorization manager is defined.
The authorization middleware key is `auth`.

Returns:
dict[str, fl.ServerMiddlewareFactory]: Optional dictionary of middlewares. If the authorization type is set to `NONE`, it returns an empty dict.
"""

if auth_type == AuthManagerType.NONE:
return {}

return {
"auth": AuthorizationMiddlewareFactory(),
}

@classmethod
def descriptor_to_key(self, descriptor: fl.FlightDescriptor):
return (
Expand All @@ -61,15 +82,7 @@ def _make_flight_info(self, key: Any, descriptor: fl.FlightDescriptor):
return fl.FlightInfo(schema, descriptor, endpoints, -1, -1)

@inject_user_details_decorator
def get_flight_info(
self, context: fl.ServerCallContext, descriptor: fl.FlightDescriptor
):
key = OfflineServer.descriptor_to_key(descriptor)
if key in self.flights:
return self._make_flight_info(key, descriptor)
raise KeyError("Flight not found.")

@inject_user_details_decorator
@arrow_server_error_handling_decorator
def list_flights(self, context: fl.ServerCallContext, criteria: bytes):
for key, table in self.flights.items():
if key[1] is not None:
Expand All @@ -79,9 +92,20 @@ def list_flights(self, context: fl.ServerCallContext, criteria: bytes):

yield self._make_flight_info(key, descriptor)

@inject_user_details_decorator
@arrow_server_error_handling_decorator
def get_flight_info(
self, context: fl.ServerCallContext, descriptor: fl.FlightDescriptor
):
key = OfflineServer.descriptor_to_key(descriptor)
if key in self.flights:
return self._make_flight_info(key, descriptor)
raise KeyError("Flight not found.")

# Expects to receive request parameters and stores them in the flights dictionary
# Indexed by the unique command
@inject_user_details_decorator
@arrow_server_error_handling_decorator
def do_put(
self,
context: fl.ServerCallContext,
Expand Down Expand Up @@ -179,6 +203,7 @@ def _validate_do_get_parameters(self, command: dict):
# Extracts the API parameters from the flights dictionary, delegates the execution to the FeatureStore instance
# and returns the stream of data
@inject_user_details_decorator
@arrow_server_error_handling_decorator
def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket):
key = ast.literal_eval(ticket.ticket.decode())
if key not in self.flights:
Expand Down Expand Up @@ -337,6 +362,7 @@ def pull_latest_from_table_or_query(self, command: dict):
utils.make_tzaware(datetime.fromisoformat(command["end_date"])),
)

@arrow_server_error_handling_decorator
def list_actions(self, context):
return [
(
Expand Down Expand Up @@ -431,12 +457,6 @@ def persist(self, command: dict, key: str):
traceback.print_exc()
raise e

def do_action(self, context: fl.ServerCallContext, action: fl.Action):
pass

def do_drop_dataset(self, dataset):
pass


def remove_dummies(fv: FeatureView) -> FeatureView:
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pyarrow.flight as fl

from feast.permissions.auth.auth_type import AuthType
from feast.permissions.auth_model import AuthConfig
from feast.permissions.client.client_auth_token import get_auth_token

Expand Down Expand Up @@ -28,11 +27,3 @@ def __init__(self, auth_config: AuthConfig):

def start_call(self, info):
return FlightBearerTokenInterceptor(self.auth_config)


def build_arrow_flight_client(host: str, port, auth_config: AuthConfig):
if auth_config.type != AuthType.NONE.value:
middleware_factory = FlightAuthInterceptorFactory(auth_config)
return fl.FlightClient(f"grpc://{host}:{port}", middleware=[middleware_factory])
else:
return fl.FlightClient(f"grpc://{host}:{port}")
31 changes: 5 additions & 26 deletions sdk/python/feast/permissions/server/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import asyncio
import functools
import logging
from typing import Optional, cast
from typing import cast

import pyarrow.flight as fl
from pyarrow.flight import ServerCallContext
Expand All @@ -14,41 +14,19 @@
get_auth_manager,
)
from feast.permissions.security_manager import get_security_manager
from feast.permissions.server.utils import (
AuthManagerType,
)
from feast.permissions.user import User

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def arrowflight_middleware(
auth_type: AuthManagerType,
) -> Optional[dict[str, fl.ServerMiddlewareFactory]]:
"""
A dictionary with the configured middlewares to support extracting the user details when the authorization manager is defined.
The authorization middleware key is `auth`.

Returns:
dict[str, fl.ServerMiddlewareFactory]: Optional dictionary of middlewares. If the authorization type is set to `NONE`, it returns `None`.
"""

if auth_type == AuthManagerType.NONE:
return None

return {
"auth": AuthorizationMiddlewareFactory(),
}


class AuthorizationMiddlewareFactory(fl.ServerMiddlewareFactory):
"""
A middleware factory to intercept the authorization header and propagate it to the authorization middleware.
"""

def __init__(self):
pass
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def start_call(self, info, headers):
"""
Expand All @@ -65,7 +43,8 @@ class AuthorizationMiddleware(fl.ServerMiddleware):
A server middleware holding the authorization header and offering a method to extract the user credentials.
"""

def __init__(self, access_token: str):
def __init__(self, access_token: str, *args, **kwargs):
super().__init__(*args, **kwargs)
self.access_token = access_token

def call_completed(self, exception):
Expand Down
Loading
Loading