Skip to content

Commit de862cb

Browse files
committed
Fix cache_column_metadata behavior and identifier normalization
1 parent 3106287 commit de862cb

File tree

7 files changed

+176
-16
lines changed

7 files changed

+176
-16
lines changed

README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,8 +334,6 @@ engine = create_engine(URL(
334334
))
335335
```
336336

337-
Note that this flag has been deprecated and removed: our caching now uses the built-in SQLAlchemy reflection cache. Caching has been improved and, where possible, extra data will be fetched and cached.
338-
339337
### VARIANT, ARRAY and OBJECT Support
340338

341339
Snowflake SQLAlchemy supports fetching `VARIANT`, `ARRAY` and `OBJECT` data types. All types are converted into `str` in Python so that you can convert them to native data types using `json.loads`.

src/snowflake/sqlalchemy/name_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77

88
class _NameUtils:
9-
109
def __init__(self, identifier_preparer: IdentifierPreparer) -> None:
1110
self.identifier_preparer = identifier_preparer
1211

@@ -19,7 +18,9 @@ def normalize_name(self, name):
1918
name.lower()
2019
):
2120
return name.lower()
22-
elif name.lower() == name:
21+
elif name.lower() == name and self.identifier_preparer._requires_quotes(
22+
name.lower()
23+
):
2324
return quoted_name(name, quote=True)
2425
else:
2526
return name

src/snowflake/sqlalchemy/snowdialect.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ def _get_schema_columns(self, connection, schema, **kw):
529529
elif issubclass(col_type, sqltypes.Numeric):
530530
col_type_kw["precision"] = numeric_precision
531531
col_type_kw["scale"] = numeric_scale
532-
elif issubclass(col_type, (sqltypes.String, sqltypes.BINARY)):
532+
elif issubclass(col_type, sqltypes.String):
533533
col_type_kw["length"] = character_maximum_length
534534
elif issubclass(col_type, StructuredType):
535535
column_info = structured_type_info_manager.get_column_info(
@@ -582,7 +582,11 @@ def get_columns(self, connection, table_name, schema=None, **kw):
582582
if not schema:
583583
_, schema = self._current_database_schema(connection, **kw)
584584

585-
schema_columns = self._get_schema_columns(connection, schema, **kw)
585+
if self._cache_column_metadata:
586+
schema_columns = self._get_schema_columns(connection, schema, **kw)
587+
else:
588+
schema_columns = None
589+
586590
if schema_columns is None:
587591
column_info_manager = _StructuredTypeInfoManager(
588592
connection, self.name_utils, self.default_schema_name

src/snowflake/sqlalchemy/structured_type_info_manager.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55

66
from sqlalchemy import util as sa_util
77
from sqlalchemy.sql import text
8+
import sqlalchemy.sql.sqltypes as sqltypes
9+
from sqlalchemy.sql.elements import quoted_name
10+
811

912
from snowflake.sqlalchemy.name_utils import _NameUtils
1013
from snowflake.sqlalchemy.parser.custom_type_parser import NullType, parse_type
@@ -45,7 +48,6 @@ def get_column_info(
4548
def _load_structured_type_info(self, schema_name: str, table_name: str):
4649
"""Get column information for a structured type"""
4750
if (schema_name, table_name) not in self.full_columns_descriptions:
48-
4951
column_definitions = self.get_table_columns(table_name, schema_name)
5052
if not column_definitions:
5153
self.full_columns_descriptions[(schema_name, table_name)] = {}
@@ -68,8 +70,8 @@ def get_table_columns(self, table_name: str, schema: str = None):
6870

6971
schema = schema if schema else self.default_schema
7072

71-
table_schema = self.name_utils.denormalize_name(schema)
72-
table_name = self.name_utils.denormalize_name(table_name)
73+
table_schema = self.name_utils.normalize_name(schema)
74+
table_name = self.name_utils.normalize_name(table_name)
7375
result = self._execute_desc(table_schema, table_name)
7476
if not result:
7577
return []
@@ -100,10 +102,18 @@ def get_table_columns(self, table_name: str, schema: str = None):
100102
identity = {
101103
"start": int(match.group("start")),
102104
"increment": int(match.group("increment")),
103-
"order_type": match.group("order_type"),
105+
"order": match.group("order_type"),
104106
}
105107
is_identity = identity is not None
106108

109+
# Normalize BINARY type length for consistency with _get_schema_columns().
110+
# DESC TABLE returns the type with the length attribute, but information_schema.columns does not (character_maximum_length is None).
111+
# Setting length to None ensures both code paths return identical column metadata,
112+
# which is important when cache_column_metadata toggles between the two approaches.
113+
# See: tests/test_core.py::test_column_metadata
114+
if isinstance(type_instance, sqltypes.BINARY):
115+
type_instance.length = None
116+
107117
ans.append(
108118
{
109119
"name": column_name,
@@ -129,6 +139,16 @@ def _execute_desc(self, table_schema: str, table_name: str):
129139
Exception can be caused by another session dropping the table
130140
once this process has started"""
131141
try:
142+
table_schema = (
143+
self.name_utils.identifier_preparer.quote(table_schema)
144+
if isinstance(table_schema, quoted_name)
145+
else table_schema
146+
)
147+
table_name = (
148+
self.name_utils.identifier_preparer.quote(table_name)
149+
if isinstance(table_name, quoted_name)
150+
else table_name
151+
)
132152
return self.connection.execute(
133153
text(
134154
"DESC /* sqlalchemy:_get_schema_columns */"

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def help():
107107
"protocol": "https",
108108
"host": "<host>",
109109
"port": "443",
110+
"cache_column_metadata": False,
110111
}
111112

112113

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
#
2+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3+
#
4+
from unittest.mock import patch
5+
6+
import pytest
7+
from sqlalchemy import Column, Integer, Sequence, String, inspect
8+
from sqlalchemy.orm import declarative_base
9+
10+
from snowflake.sqlalchemy.custom_types import OBJECT
11+
12+
13+
@pytest.mark.parametrize(
14+
"cache_column_metadata,expected_schema_count,expected_desc_count",
15+
[
16+
(False, 0, 1),
17+
(True, 1, 3),
18+
],
19+
)
20+
def test_cache_column_metadata(
21+
cache_column_metadata,
22+
expected_schema_count,
23+
expected_desc_count,
24+
engine_testaccount,
25+
):
26+
"""
27+
Test cache_column_metadata behavior for column reflection.
28+
29+
This test verifies that the _cache_column_metadata flag controls whether
30+
the dialect prefetches all columns from a schema or queries individual tables.
31+
32+
When cache_column_metadata=False (default):
33+
- _get_schema_columns is NOT called
34+
- Only the requested table is queried via DESC TABLE
35+
- Results in 1 DESC call for the User table
36+
37+
When cache_column_metadata=True:
38+
- _get_schema_columns IS called (fetches all columns via information_schema)
39+
- Additional DESC TABLE calls are made for tables with structured types
40+
(MAP, ARRAY, OBJECT) to get detailed type information
41+
- Results in 1 schema query + 3 DESC calls (User, OtherTableA, OtherTableB)
42+
43+
Note: OtherTableC does not trigger a DESC call because it has no structured types.
44+
"""
45+
Base = declarative_base()
46+
47+
class User(Base):
48+
__tablename__ = "user"
49+
50+
id = Column(Integer, Sequence("user_id_seq"), primary_key=True)
51+
name = Column(String)
52+
object = Column(OBJECT)
53+
54+
class OtherTableA(Base):
55+
__tablename__ = "other_a"
56+
57+
id = Column(Integer, primary_key=True)
58+
name = Column(String)
59+
payload = Column(OBJECT)
60+
61+
class OtherTableB(Base):
62+
__tablename__ = "other_b"
63+
64+
id = Column(Integer, primary_key=True)
65+
name = Column(String)
66+
payload = Column(OBJECT)
67+
68+
class OtherTableC(Base):
69+
__tablename__ = "other_c"
70+
71+
id = Column(Integer, primary_key=True)
72+
name = Column(String)
73+
74+
models = [User, OtherTableA, OtherTableB, OtherTableC]
75+
76+
Base.metadata.create_all(engine_testaccount)
77+
78+
inspector = inspect(engine_testaccount)
79+
schema = inspector.default_schema_name
80+
81+
# Verify cache_column_metadata is False by default
82+
assert not engine_testaccount.dialect._cache_column_metadata
83+
84+
# Track calls to _get_schema_columns
85+
schema_columns_count = []
86+
original_schema_columns = engine_testaccount.dialect._get_schema_columns
87+
88+
def tracked_schema_columns(*args, **kwargs):
89+
"""Wrapper to count calls to _get_schema_columns."""
90+
schema_columns_count.append(1)
91+
return original_schema_columns(*args, **kwargs)
92+
93+
# Track DESC TABLE commands executed by the dialect
94+
desc_call_count = []
95+
96+
def tracked_execute(statement, *args, **kwargs):
97+
"""
98+
Wrapper to count DESC TABLE commands for our test tables.
99+
100+
Only counts DESC commands with the sqlalchemy:_get_schema_columns comment
101+
that target one of our test tables (filters out unrelated DESC calls).
102+
"""
103+
stmt_str = str(statement)
104+
if (
105+
"DESC" in stmt_str
106+
and "sqlalchemy:_get_schema_columns" in stmt_str
107+
and any(model.__tablename__.lower() in stmt_str for model in models)
108+
):
109+
desc_call_count.append(stmt_str)
110+
return original_execute(statement, *args, **kwargs)
111+
112+
with patch.object(
113+
engine_testaccount.dialect,
114+
"_cache_column_metadata",
115+
cache_column_metadata,
116+
), patch.object(
117+
engine_testaccount.dialect,
118+
"_get_schema_columns",
119+
side_effect=tracked_schema_columns,
120+
):
121+
with engine_testaccount.connect() as conn:
122+
original_execute = conn.execute
123+
124+
with patch.object(conn, "execute", side_effect=tracked_execute):
125+
tracked_inspector = inspect(conn)
126+
127+
# Reflect columns for User table
128+
_ = tracked_inspector.get_columns(User.__tablename__, schema)
129+
130+
# Verify expected behavior based on cache_column_metadata setting
131+
assert len(schema_columns_count) == expected_schema_count, (
132+
f"Expected {expected_schema_count} _get_schema_columns call(s), got {len(schema_columns_count)}"
133+
)
134+
assert len(desc_call_count) == expected_desc_count, (
135+
f"Expected {expected_desc_count} DESC call(s), got {len(desc_call_count)}"
136+
)

tests/test_structured_datatypes.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -603,12 +603,12 @@ def test_structured_type_not_supported_in_table_columns_error(
603603

604604

605605
@patch.object(_StructuredTypeInfoManager, "_execute_desc")
606-
@patch.object(_NameUtils, "denormalize_name")
606+
@patch.object(_NameUtils, "normalize_name")
607607
def test_structured_type_on_dropped_table(
608-
mocked_execute_desc_method, mocked_denormalize_name_method
608+
mocked_normalize_name_method, mocked_execute_desc_method
609609
):
610610
mocked_execute_desc_method.return_value = None
611-
mocked_denormalize_name_method.side_effect = lambda self, v: v
611+
mocked_normalize_name_method.side_effect = lambda v: v
612612
structured_type_info = _StructuredTypeInfoManager(
613613
None, _NameUtils(None), "mySchema"
614614
)
@@ -619,9 +619,9 @@ def test_structured_type_on_dropped_table(
619619

620620

621621
@patch.object(_StructuredTypeInfoManager, "_execute_desc")
622-
@patch.object(_NameUtils, "denormalize_name")
622+
@patch.object(_NameUtils, "normalize_name")
623623
def test_structured_type_on_table_with_map(
624-
mocked_execute_desc_method, mocked_denormalize_name_method
624+
mocked_normalize_name_method, mocked_execute_desc_method
625625
):
626626
mocked_execute_desc_method.return_value = [
627627
[
@@ -637,7 +637,7 @@ def test_structured_type_on_table_with_map(
637637
"MapColumn",
638638
]
639639
]
640-
mocked_denormalize_name_method.side_effect = lambda self, v: v
640+
mocked_normalize_name_method.side_effect = lambda v: v
641641
structured_type_info = _StructuredTypeInfoManager(
642642
None, _NameUtils(None), "mySchema"
643643
)

0 commit comments

Comments
 (0)