Skip to content

Commit

Permalink
refactor: Add parameters validation to OfflineServer (#4289)
Browse files Browse the repository at this point in the history
Add parameters validation to OfflineServer

Signed-off-by: Theodor Mihalache <tmihalac@redhat.com>
  • Loading branch information
tmihalac authored Jun 20, 2024
1 parent 6c75e84 commit de5b0eb
Showing 1 changed file with 88 additions and 8 deletions.
96 changes: 88 additions & 8 deletions sdk/python/feast/offline_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,22 +74,23 @@ def do_put(
logger.debug(f"do_put: command is{command}, data is {data}")
self.flights[key] = data

self._call_api(command, key)
self._call_api(command["api"], command, key)
else:
logger.warning(f"No 'api' field in command: {command}")

def _call_api(self, command: dict, key: str):
def _call_api(self, api: str, command: dict, key: str):
assert api is not None, "api can not be empty"

remove_data = False
try:
api = command["api"]
if api == OfflineServer.offline_write_batch.__name__:
self.offline_write_batch(command, key)
remove_data = True
elif api == OfflineServer.write_logged_features.__name__:
self.write_logged_features(command, key)
remove_data = True
elif api == OfflineServer.persist.__name__:
self.persist(command["retrieve_func"], command, key)
self.persist(command, key)
remove_data = True
except Exception as e:
remove_data = True
Expand Down Expand Up @@ -150,6 +151,9 @@ def list_feature_views_by_name(
for index, fv_name in enumerate(feature_view_names)
]

def _validate_do_get_parameters(self, command: dict):
assert "api" in command, "api parameter is mandatory"

# Extracts the API parameters from the flights dictionary, delegates the execution to the FeatureStore instance
# and returns the stream of data
def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket):
Expand All @@ -159,6 +163,9 @@ def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket):
return None

command = json.loads(key[1])

self._validate_do_get_parameters(command)

api = command["api"]
logger.debug(f"get command is {command}")
logger.debug(f"requested api is {api}")
Expand All @@ -180,33 +187,52 @@ def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket):
del self.flights[key]
return fl.RecordBatchStream(table)

def offline_write_batch(self, command: dict, key: str):
def _validate_offline_write_batch_parameters(self, command: dict):
assert (
"feature_view_names" in command
), "feature_view_names is a mandatory parameter"
assert "name_aliases" in command, "name_aliases is a mandatory parameter"

feature_view_names = command["feature_view_names"]
assert (
len(feature_view_names) == 1
), "feature_view_names list should only have one item"

name_aliases = command["name_aliases"]
assert len(name_aliases) == 1, "name_aliases list should only have one item"

def offline_write_batch(self, command: dict, key: str):
self._validate_offline_write_batch_parameters(command)

feature_view_names = command["feature_view_names"]
name_aliases = command["name_aliases"]

project = self.store.config.project
feature_views = self.list_feature_views_by_name(
feature_view_names=feature_view_names,
name_aliases=name_aliases,
project=project,
)

assert len(feature_views) == 1
assert len(feature_views) == 1, "incorrect feature view"
table = self.flights[key]
self.offline_store.offline_write_batch(
self.store.config, feature_views[0], table, command["progress"]
)

def _validate_write_logged_features_parameters(self, command: dict):
assert "feature_service_name" in command

def write_logged_features(self, command: dict, key: str):
self._validate_write_logged_features_parameters(command)
table = self.flights[key]
feature_service = self.store.get_feature_service(
command["feature_service_name"]
)

assert feature_service.logging_config is not None
assert (
feature_service.logging_config is not None
), "feature service must have logging_config set"

self.offline_store.write_logged_features(
config=self.store.config,
Expand All @@ -218,7 +244,23 @@ def write_logged_features(self, command: dict, key: str):
registry=self.store.registry,
)

def _validate_pull_all_from_table_or_query_parameters(self, command: dict):
assert (
"data_source_name" in command
), "data_source_name is a mandatory parameter"
assert (
"join_key_columns" in command
), "join_key_columns is a mandatory parameter"
assert (
"feature_name_columns" in command
), "feature_name_columns is a mandatory parameter"
assert "timestamp_field" in command, "timestamp_field is a mandatory parameter"
assert "start_date" in command, "start_date is a mandatory parameter"
assert "end_date" in command, "end_date is a mandatory parameter"

def pull_all_from_table_or_query(self, command: dict):
self._validate_pull_all_from_table_or_query_parameters(command)

return self.offline_store.pull_all_from_table_or_query(
self.store.config,
self.store.get_data_source(command["data_source_name"]),
Expand All @@ -229,7 +271,23 @@ def pull_all_from_table_or_query(self, command: dict):
utils.make_tzaware(datetime.fromisoformat(command["end_date"])),
)

def _validate_pull_latest_from_table_or_query_parameters(self, command: dict):
assert (
"data_source_name" in command
), "data_source_name is a mandatory parameter"
assert (
"join_key_columns" in command
), "join_key_columns is a mandatory parameter"
assert (
"feature_name_columns" in command
), "feature_name_columns is a mandatory parameter"
assert "timestamp_field" in command, "timestamp_field is a mandatory parameter"
assert "start_date" in command, "start_date is a mandatory parameter"
assert "end_date" in command, "end_date is a mandatory parameter"

def pull_latest_from_table_or_query(self, command: dict):
self._validate_pull_latest_from_table_or_query_parameters(command)

return self.offline_store.pull_latest_from_table_or_query(
self.store.config,
self.store.get_data_source(command["data_source_name"]),
Expand Down Expand Up @@ -258,20 +316,33 @@ def list_actions(self, context):
),
]

def _validate_get_historical_features_parameters(self, command: dict, key: str):
assert key in self.flights, f"missing key={key}"
assert "feature_view_names" in command, "feature_view_names is mandatory"
assert "name_aliases" in command, "name_aliases is mandatory"
assert "feature_refs" in command, "feature_refs is mandatory"
assert "project" in command, "project is mandatory"
assert "full_feature_names" in command, "full_feature_names is mandatory"

def get_historical_features(self, command: dict, key: str):
self._validate_get_historical_features_parameters(command, key)

# Extract parameters from the internal flights dictionary
entity_df_value = self.flights[key]
entity_df = pa.Table.to_pandas(entity_df_value)

feature_view_names = command["feature_view_names"]
name_aliases = command["name_aliases"]
feature_refs = command["feature_refs"]
project = command["project"]
full_feature_names = command["full_feature_names"]

feature_views = self.list_feature_views_by_name(
feature_view_names=feature_view_names,
name_aliases=name_aliases,
project=project,
)

retJob = self.offline_store.get_historical_features(
config=self.store.config,
feature_views=feature_views,
Expand All @@ -281,10 +352,19 @@ def get_historical_features(self, command: dict, key: str):
project=project,
full_feature_names=full_feature_names,
)

return retJob

def persist(self, retrieve_func: str, command: dict, key: str):
def _validate_persist_parameters(self, command: dict):
assert "retrieve_func" in command, "retrieve_func is mandatory"
assert "data_source_name" in command, "data_source_name is mandatory"
assert "allow_overwrite" in command, "allow_overwrite is mandatory"

def persist(self, command: dict, key: str):
self._validate_persist_parameters(command)

try:
retrieve_func = command["retrieve_func"]
if retrieve_func == OfflineServer.get_historical_features.__name__:
ret_job = self.get_historical_features(command, key)
elif (
Expand Down

0 comments on commit de5b0eb

Please sign in to comment.