|
2 | 2 | from __future__ import annotations |
3 | 3 |
|
4 | 4 | import math |
5 | | -import re |
6 | 5 | from datetime import datetime, timedelta, timezone |
7 | 6 | from typing import TYPE_CHECKING, Any, Callable, Generator |
8 | 7 | from unittest import mock |
9 | | -from unittest.mock import MagicMock |
| 8 | +from unittest.mock import AsyncMock, MagicMock |
10 | 9 |
|
11 | 10 | import numpy.random |
12 | 11 | import pytest |
|
24 | 23 | try: |
25 | 24 | from snowflake.connector.aio._pandas_tools import write_pandas |
26 | 25 | from snowflake.connector.options import pandas |
27 | | - from snowflake.connector.pandas_tools import _iceberg_config_statement_helper |
28 | 26 | except ImportError: |
29 | 27 | pandas = None |
30 | 28 | write_pandas = None |
31 | | - _iceberg_config_statement_helper = None |
32 | 29 |
|
33 | 30 | if TYPE_CHECKING: |
34 | 31 | from snowflake.connector.aio import SnowflakeConnection |
@@ -520,7 +517,10 @@ async def mocked_execute(*args, **kwargs): |
520 | 517 | if len(args) >= 1 and args[0].startswith("COPY INTO"): |
521 | 518 | assert kwargs["params"][0] == expected_location |
522 | 519 | cur = SnowflakeCursor(cnx) |
523 | | - cur._result = iter([]) |
|  520 | +        # Mock the result object: its async fetch_all_data() returns an empty row set |
| 521 | + mock_result = MagicMock() |
| 522 | + mock_result.fetch_all_data = AsyncMock(return_value=[]) |
| 523 | + cur._result = mock_result |
524 | 524 | return cur |
525 | 525 |
|
526 | 526 | with mock.patch( |
@@ -572,7 +572,10 @@ async def mocked_execute(*args, **kwargs): |
572 | 572 | db_schema = ".".join(args[0].split(" ")[-1].split(".")[:-1]) |
573 | 573 | assert db_schema == expected_db_schema |
574 | 574 | cur = SnowflakeCursor(cnx) |
575 | | - cur._result = iter([]) |
|  575 | +        # Mock the result object: its async fetch_all_data() returns an empty row set |
| 576 | + mock_result = MagicMock() |
| 577 | + mock_result.fetch_all_data = AsyncMock(return_value=[]) |
| 578 | + cur._result = mock_result |
576 | 579 | return cur |
577 | 580 |
|
578 | 581 | with mock.patch( |
@@ -626,7 +629,10 @@ async def mocked_execute(*args, **kwargs): |
626 | 629 | db_schema = ".".join(args[0].split(" ")[-1].split(".")[:-1]) |
627 | 630 | assert db_schema == expected_db_schema |
628 | 631 | cur = SnowflakeCursor(cnx) |
629 | | - cur._result = iter([]) |
|  632 | +        # Mock the result object: its async fetch_all_data() returns an empty row set |
| 633 | + mock_result = MagicMock() |
| 634 | + mock_result.fetch_all_data = AsyncMock(return_value=[]) |
| 635 | + cur._result = mock_result |
630 | 636 | return cur |
631 | 637 |
|
632 | 638 | with mock.patch( |
@@ -682,13 +688,18 @@ async def mocked_execute(*args, **kwargs): |
682 | 688 | db_schema = ".".join(args[0].split(" ")[3].split(".")[:-1]) |
683 | 689 | assert db_schema == expected_db_schema |
684 | 690 | cur = SnowflakeCursor(cnx) |
| 691 | + mock_result = MagicMock() |
685 | 692 | if args[0].startswith("SELECT"): |
686 | 693 | cur._rownumber = 0 |
687 | | - cur._result = iter( |
688 | | - [(col, "") for col in sf_connector_version_df.get().columns] |
|  694 | +            # Mock the result object: its async fetch_all_data() returns the column rows |
| 695 | + mock_result.fetch_all_data = AsyncMock( |
| 696 | + return_value=[ |
| 697 | + (col, "") for col in sf_connector_version_df.get().columns |
| 698 | + ] |
689 | 699 | ) |
690 | 700 | else: |
691 | | - cur._result = iter([]) |
| 701 | + mock_result.fetch_all_data = AsyncMock(return_value=[]) |
| 702 | + cur._result = mock_result |
692 | 703 | return cur |
693 | 704 |
|
694 | 705 | with mock.patch( |
@@ -1032,34 +1043,6 @@ async def mock_execute(*args, **kwargs): |
1032 | 1043 | await cnx.execute_string(f"drop schema if exists {target_schema}") |
1033 | 1044 |
|
1034 | 1045 |
|
1035 | | -def test__iceberg_config_statement_helper(): |
1036 | | - config = { |
1037 | | - "EXTERNAL_VOLUME": "vol", |
1038 | | - "CATALOG": "'SNOWFLAKE'", |
1039 | | - "BASE_LOCATION": "/root", |
1040 | | - "CATALOG_SYNC": "foo", |
1041 | | - "STORAGE_SERIALIZATION_POLICY": "bar", |
1042 | | - } |
1043 | | - assert ( |
1044 | | - _iceberg_config_statement_helper(config) |
1045 | | - == "EXTERNAL_VOLUME='vol' CATALOG='SNOWFLAKE' BASE_LOCATION='/root' CATALOG_SYNC='foo' STORAGE_SERIALIZATION_POLICY='bar'" |
1046 | | - ) |
1047 | | - |
1048 | | - config["STORAGE_SERIALIZATION_POLICY"] = None |
1049 | | - assert ( |
1050 | | - _iceberg_config_statement_helper(config) |
1051 | | - == "EXTERNAL_VOLUME='vol' CATALOG='SNOWFLAKE' BASE_LOCATION='/root' CATALOG_SYNC='foo'" |
1052 | | - ) |
1053 | | - |
1054 | | - config["foo"] = True |
1055 | | - config["bar"] = True |
1056 | | - with pytest.raises( |
1057 | | - ProgrammingError, |
1058 | | - match=re.escape("Invalid iceberg configurations option(s) provided BAR, FOO"), |
1059 | | - ): |
1060 | | - _iceberg_config_statement_helper(config) |
1061 | | - |
1062 | | - |
1063 | 1046 | async def test_write_pandas_with_on_error( |
1064 | 1047 | conn_cnx: Callable[..., Generator[SnowflakeConnection]], |
1065 | 1048 | ): |
|
0 commit comments