|
| 1 | +# |
| 2 | +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. |
| 3 | +# |
| 4 | +from unittest.mock import patch |
| 5 | + |
| 6 | +import pytest |
| 7 | +from sqlalchemy import Column, Integer, Sequence, String, inspect |
| 8 | +from sqlalchemy.orm import declarative_base |
| 9 | + |
| 10 | +from snowflake.sqlalchemy.custom_types import OBJECT |
| 11 | + |
| 12 | + |
@pytest.mark.parametrize(
    "cache_column_metadata,expected_schema_count,expected_desc_count",
    [
        (False, 0, 1),
        (True, 1, 3),
    ],
)
def test_cache_column_metadata(
    cache_column_metadata,
    expected_schema_count,
    expected_desc_count,
    engine_testaccount,
):
    """
    Test cache_column_metadata behavior for column reflection.

    This test verifies that the _cache_column_metadata flag controls whether
    the dialect prefetches all columns from a schema or queries individual tables.

    When cache_column_metadata=False (default):
    - _get_schema_columns is NOT called
    - Only the requested table is queried via DESC TABLE
    - Results in 1 DESC call for the User table

    When cache_column_metadata=True:
    - _get_schema_columns IS called (fetches all columns via information_schema)
    - Additional DESC TABLE calls are made for tables with structured types
      (MAP, ARRAY, OBJECT) to get detailed type information
    - Results in 1 schema query + 3 DESC calls (User, OtherTableA, OtherTableB)

    Note: OtherTableC does not trigger a DESC call because it has no structured types.

    The tables (and the sequence) created for the test are always dropped in a
    ``finally`` block so they do not leak into the shared test schema and
    perturb the DESC counts of later parametrized runs or other tests.
    """
    Base = declarative_base()

    class User(Base):
        __tablename__ = "user"

        id = Column(Integer, Sequence("user_id_seq"), primary_key=True)
        name = Column(String)
        object = Column(OBJECT)

    class OtherTableA(Base):
        __tablename__ = "other_a"

        id = Column(Integer, primary_key=True)
        name = Column(String)
        payload = Column(OBJECT)

    class OtherTableB(Base):
        __tablename__ = "other_b"

        id = Column(Integer, primary_key=True)
        name = Column(String)
        payload = Column(OBJECT)

    class OtherTableC(Base):
        __tablename__ = "other_c"

        id = Column(Integer, primary_key=True)
        name = Column(String)

    models = [User, OtherTableA, OtherTableB, OtherTableC]
    # Computed once; the execute wrapper checks every statement against these.
    table_names = tuple(model.__tablename__.lower() for model in models)

    dialect = engine_testaccount.dialect

    # DESC TABLE statements issued by the dialect for our test tables.
    desc_statements = []

    try:
        # create_all sits inside the try so a partial failure is still cleaned up.
        Base.metadata.create_all(engine_testaccount)

        schema = inspect(engine_testaccount).default_schema_name

        # Verify cache_column_metadata is False by default
        assert not dialect._cache_column_metadata

        # ``wraps`` keeps the real implementation while the mock counts calls.
        with patch.object(
            dialect,
            "_cache_column_metadata",
            cache_column_metadata,
        ), patch.object(
            dialect,
            "_get_schema_columns",
            wraps=dialect._get_schema_columns,
        ) as schema_columns_mock:
            with engine_testaccount.connect() as conn:
                # Capture the real bound method before it is patched.
                original_execute = conn.execute

                def tracked_execute(statement, *args, **kwargs):
                    """
                    Wrapper to count DESC TABLE commands for our test tables.

                    Only counts DESC commands with the
                    sqlalchemy:_get_schema_columns comment that target one of
                    our test tables (filters out unrelated DESC calls).
                    """
                    stmt_str = str(statement)
                    if (
                        "DESC" in stmt_str
                        and "sqlalchemy:_get_schema_columns" in stmt_str
                        and any(name in stmt_str for name in table_names)
                    ):
                        desc_statements.append(stmt_str)
                    return original_execute(statement, *args, **kwargs)

                with patch.object(conn, "execute", side_effect=tracked_execute):
                    # Reflect columns for User table
                    inspect(conn).get_columns(User.__tablename__, schema)

        # Verify expected behavior based on cache_column_metadata setting
        schema_call_count = schema_columns_mock.call_count
        assert schema_call_count == expected_schema_count, (
            f"Expected {expected_schema_count} _get_schema_columns call(s), got {schema_call_count}"
        )
        assert len(desc_statements) == expected_desc_count, (
            f"Expected {expected_desc_count} DESC call(s), got {len(desc_statements)}"
        )
    finally:
        # drop_all checks for existence first, so this is safe even if
        # create_all failed part-way through.
        Base.metadata.drop_all(engine_testaccount)
0 commit comments