From ec4c15c0104fa8f4cebdbf29f9e067baab07b09b Mon Sep 17 00:00:00 2001 From: Shuchu Han Date: Wed, 3 Apr 2024 15:46:20 -0400 Subject: [PATCH] fix: Upgrade sqlalchemy from 1.x to 2.x regarding PVE-2022-51668. (#4065) * fix: Upgrade sqlalchemy from 1.x to 2.x regarding PVE-2022-51668. Signed-off-by: Shuchu Han * fix: fix typo. Signed-off-by: Shuchu Han --------- Signed-off-by: Shuchu Han --- sdk/python/feast/infra/registry/sql.py | 45 ++++++++++++++------------ setup.py | 2 +- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/sdk/python/feast/infra/registry/sql.py b/sdk/python/feast/infra/registry/sql.py index f9030a6875..98b23e1943 100644 --- a/sdk/python/feast/infra/registry/sql.py +++ b/sdk/python/feast/infra/registry/sql.py @@ -205,7 +205,7 @@ def teardown(self): saved_datasets, validation_references, }: - with self.engine.connect() as conn: + with self.engine.begin() as conn: stmt = delete(t) conn.execute(stmt) @@ -399,7 +399,7 @@ def apply_feature_service( ) def delete_data_source(self, name: str, project: str, commit: bool = True): - with self.engine.connect() as conn: + with self.engine.begin() as conn: stmt = delete(data_sources).where( data_sources.c.data_source_name == name, data_sources.c.project_id == project, @@ -441,7 +441,7 @@ def _list_on_demand_feature_views(self, project: str) -> List[OnDemandFeatureVie ) def _list_project_metadata(self, project: str) -> List[ProjectMetadata]: - with self.engine.connect() as conn: + with self.engine.begin() as conn: stmt = select(feast_metadata).where( feast_metadata.c.project_id == project, ) @@ -449,8 +449,11 @@ def _list_project_metadata(self, project: str) -> List[ProjectMetadata]: if rows: project_metadata = ProjectMetadata(project_name=project) for row in rows: - if row["metadata_key"] == FeastMetadataKeys.PROJECT_UUID.value: - project_metadata.project_uuid = row["metadata_value"] + if ( + row._mapping["metadata_key"] + == FeastMetadataKeys.PROJECT_UUID.value + ): + project_metadata.project_uuid = row._mapping["metadata_value"] break # TODO(adchia): Add other project metadata in a structured way return [project_metadata] @@ -557,7 +560,7 @@ def apply_user_metadata( table = self._infer_fv_table(feature_view) name = feature_view.name - with self.engine.connect() as conn: + with self.engine.begin() as conn: stmt = select(table).where( getattr(table.c, "feature_view_name") == name, table.c.project_id == project, @@ -612,11 +615,11 @@ def get_user_metadata( table = self._infer_fv_table(feature_view) name = feature_view.name - with self.engine.connect() as conn: + with self.engine.begin() as conn: stmt = select(table).where(getattr(table.c, "feature_view_name") == name) row = conn.execute(stmt).first() if row: - return row["user_metadata"] + return row._mapping["user_metadata"] else: raise FeatureViewNotFoundException(feature_view.name, project=project) @@ -674,7 +677,7 @@ def _apply_object( name = name or (obj.name if hasattr(obj, "name") else None) assert name, f"name needs to be provided for {obj}" - with self.engine.connect() as conn: + with self.engine.begin() as conn: update_datetime = datetime.utcnow() update_time = int(update_datetime.timestamp()) stmt = select(table).where( @@ -723,7 +726,7 @@ def _apply_object( def _maybe_init_project_metadata(self, project): # Initialize project metadata if needed - with self.engine.connect() as conn: + with self.engine.begin() as conn: update_datetime = datetime.utcnow() update_time = int(update_datetime.timestamp()) stmt = select(feast_metadata).where( @@ -732,7 +735,7 @@ def _maybe_init_project_metadata(self, project): ) row = conn.execute(stmt).first() if row: - usage.set_current_project_uuid(row["metadata_value"]) + usage.set_current_project_uuid(row._mapping["metadata_value"]) else: new_project_uuid = f"{uuid.uuid4()}" values = { @@ -753,7 +756,7 @@ def _delete_object( id_field_name: str, not_found_exception: Optional[Callable], ): - with self.engine.connect() as conn: + with self.engine.begin() as conn: stmt = delete(table).where( getattr(table.c, id_field_name) == name, table.c.project_id == project ) @@ -777,13 +780,13 @@ def _get_object( ): self._maybe_init_project_metadata(project) - with self.engine.connect() as conn: + with self.engine.begin() as conn: stmt = select(table).where( getattr(table.c, id_field_name) == name, table.c.project_id == project ) row = conn.execute(stmt).first() if row: - _proto = proto_class.FromString(row[proto_field_name]) + _proto = proto_class.FromString(row._mapping[proto_field_name]) return python_class.from_proto(_proto) if not_found_exception: raise not_found_exception(name, project) @@ -799,20 +802,20 @@ def _list_objects( proto_field_name: str, ): self._maybe_init_project_metadata(project) - with self.engine.connect() as conn: + with self.engine.begin() as conn: stmt = select(table).where(table.c.project_id == project) rows = conn.execute(stmt).all() if rows: return [ python_class.from_proto( - proto_class.FromString(row[proto_field_name]) + proto_class.FromString(row._mapping[proto_field_name]) ) for row in rows ] return [] def _set_last_updated_metadata(self, last_updated: datetime, project: str): - with self.engine.connect() as conn: + with self.engine.begin() as conn: stmt = select(feast_metadata).where( feast_metadata.c.metadata_key == FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value, @@ -846,7 +849,7 @@ def _set_last_updated_metadata(self, last_updated: datetime, project: str): conn.execute(insert_stmt) def _get_last_updated_metadata(self, project: str): - with self.engine.connect() as conn: + with self.engine.begin() as conn: stmt = select(feast_metadata).where( feast_metadata.c.metadata_key == FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value, @@ -855,13 +858,13 @@ def _get_last_updated_metadata(self, project: str): row = conn.execute(stmt).first() if not row: return None - update_time = int(row["last_updated_timestamp"]) + update_time = int(row._mapping["last_updated_timestamp"]) return datetime.utcfromtimestamp(update_time) def _get_all_projects(self) -> Set[str]: projects = set() - with self.engine.connect() as conn: + with self.engine.begin() as conn: for table in { entities, data_sources, @@ -872,6 +875,6 @@ def _get_all_projects(self) -> Set[str]: stmt = select(table) rows = conn.execute(stmt).all() for row in rows: - projects.add(row["project_id"]) + projects.add(row._mapping["project_id"]) return projects diff --git a/setup.py b/setup.py index e14e723d0e..f94fb25bb5 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,7 @@ "pygments>=2.12.0,<3", "PyYAML>=5.4.0,<7", "requests", - "SQLAlchemy[mypy]>1,<2", + "SQLAlchemy[mypy]>1", "tabulate>=0.8.0,<1", "tenacity>=7,<9", "toml>=0.10.0,<1",