From e703b40582e676d4ec92551e79a444a9c0949f66 Mon Sep 17 00:00:00 2001 From: locnt241 <73770977+ElliotNguyen68@users.noreply.github.com> Date: Tue, 26 Mar 2024 09:33:58 +0700 Subject: [PATCH] fix: Add __eq__, __hash__ to SparkSource for correct comparison (#4028) * feat: Enable Arrow-based columnar data transfers Signed-off-by: tanlocnguyen * fix: Add __eq__, __hash__ to SparkSource for comparision Signed-off-by: tanlocnguyen * chore: simplify the logic Signed-off-by: tanlocnguyen --------- Signed-off-by: tanlocnguyen Co-authored-by: tanlocnguyen --- .../contrib/spark_offline_store/spark_source.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py index 8cd392ce5d..0809043a01 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py @@ -185,6 +185,19 @@ def get_table_query_string(self) -> str: return f"`{tmp_table_name}`" + def __eq__(self, other): + base_eq = super().__eq__(other) + if not base_eq: + return False + return ( + self.table == other.table + and self.query == other.query + and self.path == other.path + ) + + def __hash__(self): + return super().__hash__() + class SparkOptions: allowed_formats = [format.value for format in SparkSourceFormat]