From 1a1f0b1c56aa2ac00b1e1aa1e21cc200ea659334 Mon Sep 17 00:00:00 2001 From: Hao Xu Date: Wed, 17 Apr 2024 09:13:51 -0700 Subject: [PATCH] fix: Pgvector patch (#4108) --- sdk/python/feast/feature_store.py | 19 ++++--- .../infra/online_stores/contrib/postgres.py | 55 +++++++++++++------ .../feast/infra/online_stores/online_store.py | 9 ++- sdk/python/feast/infra/provider.py | 9 ++- sdk/python/tests/foo_provider.py | 9 ++- 5 files changed, 73 insertions(+), 28 deletions(-) diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 15598e1d60..f42cced11c 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -1740,12 +1740,14 @@ def _retrieve_online_documents( query, top_k, ) - document_feature_vals = [feature[2] for feature in document_features] - document_feature_distance_vals = [feature[3] for feature in document_features] - online_features_response = GetOnlineFeaturesResponse(results=[]) # TODO Refactor to better way of populating result # TODO populate entity in the response after returning entity in document_features is supported + # TODO the vector value is currently not returned since it is the same as the feature value; once embedding is supported, + # the feature value can be the raw text before it is embedded + document_feature_vals = [feature[2] for feature in document_features] + document_feature_distance_vals = [feature[4] for feature in document_features] + online_features_response = GetOnlineFeaturesResponse(results=[]) self._populate_result_rows_from_columnar( online_features_response=online_features_response, data={requested_feature: document_feature_vals}, @@ -1979,7 +1981,7 @@ def _retrieve_from_online_store( requested_feature: str, query: List[float], top_k: int, - ) -> List[Tuple[Timestamp, "FieldStatus.ValueType", Value, Value]]: + ) -> List[Tuple[Timestamp, "FieldStatus.ValueType", Value, Value, Value]]: """ Search and return document features from the online document store. 
""" @@ -1994,19 +1996,22 @@ def _retrieve_from_online_store( read_row_protos = [] row_ts_proto = Timestamp() - for row_ts, feature_val, distance_val in documents: + for row_ts, feature_val, vector_value, distance_val in documents: # Reset timestamp to default or update if row_ts is not None if row_ts is not None: row_ts_proto.FromDatetime(row_ts) - if feature_val is None or distance_val is None: + if feature_val is None or vector_value is None or distance_val is None: feature_val = Value() + vector_value = Value() distance_val = Value() status = FieldStatus.NOT_FOUND else: status = FieldStatus.PRESENT - read_row_protos.append((row_ts_proto, status, feature_val, distance_val)) + read_row_protos.append( + (row_ts_proto, status, feature_val, vector_value, distance_val) + ) return read_row_protos @staticmethod diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres.py b/sdk/python/feast/infra/online_stores/contrib/postgres.py index 2890f60746..6ed0885d13 100644 --- a/sdk/python/feast/infra/online_stores/contrib/postgres.py +++ b/sdk/python/feast/infra/online_stores/contrib/postgres.py @@ -75,10 +75,7 @@ def online_write_batch( for feature_name, val in values.items(): vector_val = None - if ( - "pgvector_enabled" in config.online_store - and config.online_store.pgvector_enabled - ): + if config.online_store.pgvector_enabled: vector_val = get_list_val_str(val) insert_values.append( ( @@ -226,10 +223,7 @@ def update( for table in tables_to_keep: table_name = _table_id(project, table) - if ( - "pgvector_enabled" in config.online_store - and config.online_store.pgvector_enabled - ): + if config.online_store.pgvector_enabled: vector_value_type = f"vector({config.online_store.vector_len})" else: # keep the vector_value_type as BYTEA if pgvector is not enabled, to maintain compatibility @@ -282,7 +276,14 @@ def retrieve_online_documents( requested_feature: str, embedding: List[float], top_k: int, - ) -> List[Tuple[Optional[datetime], Optional[ValueProto], 
Optional[ValueProto]]]: + ) -> List[ + Tuple[ + Optional[datetime], + Optional[ValueProto], + Optional[ValueProto], + Optional[ValueProto], + ] + ]: """ Args: @@ -297,10 +298,7 @@ def retrieve_online_documents( """ project = config.project - if ( - "pgvector_enabled" not in config.online_store - or not config.online_store.pgvector_enabled - ): + if not config.online_store.pgvector_enabled: raise ValueError( "pgvector is not enabled in the online store configuration" ) @@ -309,7 +307,12 @@ def retrieve_online_documents( query_embedding_str = f"[{','.join(str(el) for el in embedding)}]" result: List[ - Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]] + Tuple[ + Optional[datetime], + Optional[ValueProto], + Optional[ValueProto], + Optional[ValueProto], + ] ] = [] with self._get_conn(config) as conn, conn.cursor() as cur: table_name = _table_id(project, table) @@ -322,6 +325,7 @@ def retrieve_online_documents( SELECT entity_key, feature_name, + value, vector_value, vector_value <-> %s as distance, event_ts FROM {table_name} @@ -338,16 +342,31 @@ def retrieve_online_documents( ) rows = cur.fetchall() - for entity_key, feature_name, vector_value, distance, event_ts in rows: + for ( + entity_key, + feature_name, + value, + vector_value, + distance, + event_ts, + ) in rows: # TODO Deserialize entity_key to return the entity in response # entity_key_proto = EntityKeyProto() # entity_key_proto_bin = bytes(entity_key) - # TODO Convert to List[float] for value type proto - feature_value_proto = ValueProto(string_val=vector_value) + feature_value_proto = ValueProto() + feature_value_proto.ParseFromString(bytes(value)) + vector_value_proto = ValueProto(string_val=vector_value) distance_value_proto = ValueProto(float_val=distance) - result.append((event_ts, feature_value_proto, distance_value_proto)) + result.append( + ( + event_ts, + feature_value_proto, + vector_value_proto, + distance_value_proto, + ) + ) return result diff --git 
a/sdk/python/feast/infra/online_stores/online_store.py b/sdk/python/feast/infra/online_stores/online_store.py index fc1b3d4ad3..67c5a931dd 100644 --- a/sdk/python/feast/infra/online_stores/online_store.py +++ b/sdk/python/feast/infra/online_stores/online_store.py @@ -142,7 +142,14 @@ def retrieve_online_documents( requested_feature: str, embedding: List[float], top_k: int, - ) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]: + ) -> List[ + Tuple[ + Optional[datetime], + Optional[ValueProto], + Optional[ValueProto], + Optional[ValueProto], + ] + ]: """ Retrieves online feature values for the specified embeddings. diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index e71e87488d..a45051a1b6 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -303,7 +303,14 @@ def retrieve_online_documents( requested_feature: str, query: List[float], top_k: int, - ) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]: + ) -> List[ + Tuple[ + Optional[datetime], + Optional[ValueProto], + Optional[ValueProto], + Optional[ValueProto], + ] + ]: """ Searches for the top-k nearest neighbors of the given document in the online document store. diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py index 7ba4adb114..2a830d424c 100644 --- a/sdk/python/tests/foo_provider.py +++ b/sdk/python/tests/foo_provider.py @@ -111,5 +111,12 @@ def retrieve_online_documents( requested_feature: str, query: List[float], top_k: int, - ) -> List[Tuple[Optional[datetime], Optional[ValueProto], Optional[ValueProto]]]: + ) -> List[ + Tuple[ + Optional[datetime], + Optional[ValueProto], + Optional[ValueProto], + Optional[ValueProto], + ] + ]: return []