Skip to content

Commit b561385

Browse files
committed
Fix Snowflake loader test configuration
Key fixes: - Changed loader.conn to loader.connection (Snowflake uses different attribute name) - Set supports_overwrite = False (Snowflake doesn't support OVERWRITE mode) - Set requires_existing_table = False (Snowflake auto-creates tables) - Added cleanup_tables fixture for Snowflake-specific test cleanup
1 parent 7c2a338 commit b561385

File tree

1 file changed

+35
-10
lines changed

1 file changed

+35
-10
lines changed

tests/integration/loaders/backends/test_snowflake.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,15 @@ class SnowflakeTestConfig(LoaderTestConfig):
2727
loader_class = SnowflakeLoader
2828
config_fixture_name = 'snowflake_config'
2929

30-
supports_overwrite = True
30+
supports_overwrite = False # Snowflake doesn't support OVERWRITE mode
3131
supports_streaming = True
3232
supports_multi_network = True
3333
supports_null_values = True
34+
requires_existing_table = False # Snowflake auto-creates tables
3435

3536
def get_row_count(self, loader: SnowflakeLoader, table_name: str) -> int:
3637
"""Get row count from Snowflake table"""
37-
with loader.conn.cursor() as cur:
38+
with loader.connection.cursor() as cur:
3839
cur.execute(f'SELECT COUNT(*) FROM {table_name}')
3940
return cur.fetchone()[0]
4041

@@ -49,20 +50,20 @@ def query_rows(
4950
query += f' ORDER BY {order_by}'
5051
query += ' LIMIT 100'
5152

52-
with loader.conn.cursor() as cur:
53+
with loader.connection.cursor() as cur:
5354
cur.execute(query)
5455
columns = [col[0] for col in cur.description]
5556
rows = cur.fetchall()
5657
return [dict(zip(columns, row, strict=False)) for row in rows]
5758

5859
def cleanup_table(self, loader: SnowflakeLoader, table_name: str) -> None:
5960
"""Drop Snowflake table"""
60-
with loader.conn.cursor() as cur:
61+
with loader.connection.cursor() as cur:
6162
cur.execute(f'DROP TABLE IF EXISTS {table_name}')
6263

6364
def get_column_names(self, loader: SnowflakeLoader, table_name: str) -> List[str]:
6465
"""Get column names from Snowflake table"""
65-
with loader.conn.cursor() as cur:
66+
with loader.connection.cursor() as cur:
6667
cur.execute(f'SELECT * FROM {table_name} LIMIT 0')
6768
return [col[0] for col in cur.description]
6869

@@ -81,6 +82,30 @@ class TestSnowflakeStreaming(BaseStreamingTests):
8182
config = SnowflakeTestConfig()
8283

8384

85+
@pytest.fixture
86+
def cleanup_tables(snowflake_config):
87+
"""Cleanup Snowflake tables after tests"""
88+
tables_to_clean = []
89+
90+
yield tables_to_clean
91+
92+
# Cleanup
93+
if tables_to_clean:
94+
try:
95+
from snowflake.connector import connect
96+
97+
conn = connect(**snowflake_config)
98+
with conn.cursor() as cur:
99+
for table_name in tables_to_clean:
100+
try:
101+
cur.execute(f'DROP TABLE IF EXISTS {table_name}')
102+
except Exception:
103+
pass
104+
conn.close()
105+
except Exception:
106+
pass
107+
108+
84109
@pytest.mark.snowflake
85110
class TestSnowflakeSpecific:
86111
"""Snowflake-specific tests that cannot be generalized"""
@@ -100,7 +125,7 @@ def test_stage_loading_method(self, snowflake_config, small_test_table, test_tab
100125
assert result.rows_loaded == 100
101126

102127
# Verify data loaded
103-
with loader.conn.cursor() as cur:
128+
with loader.connection.cursor() as cur:
104129
cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
105130
count = cur.fetchone()[0]
106131
assert count == 100
@@ -169,7 +194,7 @@ def load_batch_with_mode(batch_tuple):
169194
assert all(r.success for r in results)
170195

171196
# Verify total row count
172-
with loader.conn.cursor() as cur:
197+
with loader.connection.cursor() as cur:
173198
cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
174199
count = cur.fetchone()[0]
175200
assert count == 10000
@@ -197,7 +222,7 @@ def test_schema_special_characters(self, snowflake_config, test_table_name, clea
197222
assert result.rows_loaded == 3
198223

199224
# Verify columns were properly escaped
200-
with loader.conn.cursor() as cur:
225+
with loader.connection.cursor() as cur:
201226
cur.execute(f'SELECT * FROM {test_table_name} LIMIT 1')
202227
columns = [col[0] for col in cur.description]
203228
# Snowflake normalizes column names
@@ -230,7 +255,7 @@ def test_history_preservation_with_reorg(self, snowflake_config, test_table_name
230255
assert results[0].success
231256

232257
# Verify initial count
233-
with loader.conn.cursor() as cur:
258+
with loader.connection.cursor() as cur:
234259
cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
235260
initial_count = cur.fetchone()[0]
236261
assert initial_count == 2
@@ -269,7 +294,7 @@ def test_large_batch_loading(self, snowflake_config, performance_test_data, test
269294
assert result.duration < 120 # Should complete within 2 minutes
270295

271296
# Verify data integrity
272-
with loader.conn.cursor() as cur:
297+
with loader.connection.cursor() as cur:
273298
cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
274299
count = cur.fetchone()[0]
275300
assert count == 50000

0 commit comments

Comments
 (0)