Skip to content

Commit 6629361

Browse files
committed
Fix cache_column_metadata behavior and identifier normalization
1 parent 3106287 commit 6629361

File tree

5 files changed

+156
-10
lines changed

5 files changed

+156
-10
lines changed

README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,8 +334,6 @@ engine = create_engine(URL(
334334
))
335335
```
336336

337-
Note that this flag has been deprecated and removed: our caching now uses the built-in SQLAlchemy reflection cache. Caching has also been improved, and extra data will be fetched and cached where possible.
338-
339337
### VARIANT, ARRAY and OBJECT Support
340338

341339
Snowflake SQLAlchemy supports fetching `VARIANT`, `ARRAY` and `OBJECT` data types. All types are converted into `str` in Python so that you can convert them to native data types using `json.loads`.

src/snowflake/sqlalchemy/name_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
33

44
from sqlalchemy.sql.compiler import IdentifierPreparer
5-
from sqlalchemy.sql.elements import quoted_name
65

76

87
class _NameUtils:
@@ -20,7 +19,7 @@ def normalize_name(self, name):
2019
):
2120
return name.lower()
2221
elif name.lower() == name:
23-
return quoted_name(name, quote=True)
22+
return self.identifier_preparer.quote(name)
2423
else:
2524
return name
2625

src/snowflake/sqlalchemy/snowdialect.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ def _get_schema_columns(self, connection, schema, **kw):
529529
elif issubclass(col_type, sqltypes.Numeric):
530530
col_type_kw["precision"] = numeric_precision
531531
col_type_kw["scale"] = numeric_scale
532-
elif issubclass(col_type, (sqltypes.String, sqltypes.BINARY)):
532+
elif issubclass(col_type, sqltypes.String):
533533
col_type_kw["length"] = character_maximum_length
534534
elif issubclass(col_type, StructuredType):
535535
column_info = structured_type_info_manager.get_column_info(
@@ -582,7 +582,11 @@ def get_columns(self, connection, table_name, schema=None, **kw):
582582
if not schema:
583583
_, schema = self._current_database_schema(connection, **kw)
584584

585-
schema_columns = self._get_schema_columns(connection, schema, **kw)
585+
if self._cache_column_metadata:
586+
schema_columns = self._get_schema_columns(connection, schema, **kw)
587+
else:
588+
schema_columns = None
589+
586590
if schema_columns is None:
587591
column_info_manager = _StructuredTypeInfoManager(
588592
connection, self.name_utils, self.default_schema_name

src/snowflake/sqlalchemy/structured_type_info_manager.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from sqlalchemy import util as sa_util
77
from sqlalchemy.sql import text
8+
import sqlalchemy.sql.sqltypes as sqltypes
9+
810

911
from snowflake.sqlalchemy.name_utils import _NameUtils
1012
from snowflake.sqlalchemy.parser.custom_type_parser import NullType, parse_type
@@ -45,7 +47,6 @@ def get_column_info(
4547
def _load_structured_type_info(self, schema_name: str, table_name: str):
4648
"""Get column information for a structured type"""
4749
if (schema_name, table_name) not in self.full_columns_descriptions:
48-
4950
column_definitions = self.get_table_columns(table_name, schema_name)
5051
if not column_definitions:
5152
self.full_columns_descriptions[(schema_name, table_name)] = {}
@@ -68,8 +69,8 @@ def get_table_columns(self, table_name: str, schema: str = None):
6869

6970
schema = schema if schema else self.default_schema
7071

71-
table_schema = self.name_utils.denormalize_name(schema)
72-
table_name = self.name_utils.denormalize_name(table_name)
72+
table_schema = self.name_utils.normalize_name(schema)
73+
table_name = self.name_utils.normalize_name(table_name)
7374
result = self._execute_desc(table_schema, table_name)
7475
if not result:
7576
return []
@@ -100,10 +101,18 @@ def get_table_columns(self, table_name: str, schema: str = None):
100101
identity = {
101102
"start": int(match.group("start")),
102103
"increment": int(match.group("increment")),
103-
"order_type": match.group("order_type"),
104+
"order": match.group("order_type"),
104105
}
105106
is_identity = identity is not None
106107

108+
# Normalize BINARY type length for consistency with _get_schema_columns().
109+
# DESC TABLE returns the type with the length attribute, but information_schema.columns does not (character_maximum_length is None).
110+
# Setting length to None ensures both code paths return identical column metadata,
111+
# which is important when cache_column_metadata toggles between the two approaches.
112+
# See: tests/test_core.py::test_column_metadata
113+
if isinstance(type_instance, sqltypes.BINARY):
114+
type_instance.length = None
115+
107116
ans.append(
108117
{
109118
"name": column_name,
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
#
2+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3+
#
4+
from unittest.mock import patch
5+
6+
import pytest
7+
from sqlalchemy import Column, Integer, Sequence, String, inspect
8+
from sqlalchemy.orm import declarative_base
9+
10+
from snowflake.sqlalchemy.custom_types import OBJECT
11+
12+
13+
@pytest.mark.parametrize(
    "cache_column_metadata,expected_schema_count,expected_desc_count",
    [
        (False, 0, 1),
        (True, 1, 3),
    ],
)
def test_cache_column_metadata(
    cache_column_metadata,
    expected_schema_count,
    expected_desc_count,
    engine_testaccount,
):
    """
    Test cache_column_metadata behavior for column reflection.

    This test verifies that the _cache_column_metadata flag controls whether
    the dialect prefetches all columns from a schema or queries individual tables.

    When cache_column_metadata=False (default):
    - _get_schema_columns is NOT called
    - Only the requested table is queried via DESC TABLE
    - Results in 1 DESC call for the User table

    When cache_column_metadata=True:
    - _get_schema_columns IS called (fetches all columns via information_schema)
    - Additional DESC TABLE calls are made for tables with structured types
      (MAP, ARRAY, OBJECT) to get detailed type information
    - Results in 1 schema query + 3 DESC calls (User, OtherTableA, OtherTableB)

    Note: OtherTableC does not trigger a DESC call because it has no structured types.

    The tables created for this test are dropped again in a ``finally`` block so
    that they do not leak into the schema used by other tests.
    """
    Base = declarative_base()

    class User(Base):
        __tablename__ = "user"

        id = Column(Integer, Sequence("user_id_seq"), primary_key=True)
        name = Column(String)
        object = Column(OBJECT)

    class OtherTableA(Base):
        __tablename__ = "other_a"

        id = Column(Integer, primary_key=True)
        name = Column(String)
        payload = Column(OBJECT)

    class OtherTableB(Base):
        __tablename__ = "other_b"

        id = Column(Integer, primary_key=True)
        name = Column(String)
        payload = Column(OBJECT)

    class OtherTableC(Base):
        __tablename__ = "other_c"

        id = Column(Integer, primary_key=True)
        name = Column(String)

    models = [User, OtherTableA, OtherTableB, OtherTableC]

    Base.metadata.create_all(engine_testaccount)
    try:
        inspector = inspect(engine_testaccount)
        schema = inspector.default_schema_name

        # Verify cache_column_metadata is False by default
        assert not engine_testaccount.dialect._cache_column_metadata

        # Track calls to _get_schema_columns
        schema_columns_count = []
        original_schema_columns = engine_testaccount.dialect._get_schema_columns

        def tracked_schema_columns(*args, **kwargs):
            """Wrapper to count calls to _get_schema_columns."""
            schema_columns_count.append(1)
            return original_schema_columns(*args, **kwargs)

        # Track DESC TABLE commands executed by the dialect
        desc_call_count = []

        with patch.object(
            engine_testaccount.dialect,
            "_cache_column_metadata",
            cache_column_metadata,
        ), patch.object(
            engine_testaccount.dialect,
            "_get_schema_columns",
            side_effect=tracked_schema_columns,
        ):
            with engine_testaccount.connect() as conn:
                # Capture the real method before it is patched so the wrapper
                # below can delegate to it.
                original_execute = conn.execute

                def tracked_execute(statement, *args, **kwargs):
                    """
                    Wrapper to count DESC TABLE commands for our test tables.

                    Only counts DESC commands with the sqlalchemy:_get_schema_columns comment
                    that target one of our test tables (filters out unrelated DESC calls).
                    """
                    stmt_str = str(statement)
                    if (
                        "DESC" in stmt_str
                        and "sqlalchemy:_get_schema_columns" in stmt_str
                        and any(
                            model.__tablename__.lower() in stmt_str
                            for model in models
                        )
                    ):
                        desc_call_count.append(stmt_str)
                    return original_execute(statement, *args, **kwargs)

                with patch.object(conn, "execute", side_effect=tracked_execute):
                    tracked_inspector = inspect(conn)

                    # Reflect columns for User table
                    _ = tracked_inspector.get_columns(User.__tablename__, schema)

        # Verify expected behavior based on cache_column_metadata setting
        assert len(schema_columns_count) == expected_schema_count, (
            f"Expected {expected_schema_count} _get_schema_columns call(s), got {len(schema_columns_count)}"
        )
        assert len(desc_call_count) == expected_desc_count, (
            f"Expected {expected_desc_count} DESC call(s), got {len(desc_call_count)}"
        )
    finally:
        # Always remove the tables and sequence created above, even when an
        # assertion fails, so later tests see a clean schema.
        Base.metadata.drop_all(engine_testaccount)

0 commit comments

Comments
 (0)