From 7f9d76b3655b210424b62ac18ed10371965396d3 Mon Sep 17 00:00:00 2001 From: Pablo Nicolas Estevez Date: Fri, 25 Oct 2024 11:52:59 -0300 Subject: [PATCH] add tupleCursor and type dictCursor --- src/snowflake/connector/connection.py | 7 ++++--- src/snowflake/connector/cursor.py | 27 ++++++++++++++++++++++++++- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py index 5205bafc1..45095b418 100644 --- a/src/snowflake/connector/connection.py +++ b/src/snowflake/connector/connection.py @@ -25,7 +25,7 @@ from threading import Lock from time import strptime from types import TracebackType -from typing import Any, Callable, Generator, Iterable, Iterator, NamedTuple, Sequence +from typing import Any, Callable, Generator, Iterable, Iterator, NamedTuple, Sequence, TypeVar from uuid import UUID from cryptography.hazmat.backends import default_backend @@ -113,6 +113,7 @@ DEFAULT_CLIENT_PREFETCH_THREADS = 4 MAX_CLIENT_PREFETCH_THREADS = 10 DEFAULT_BACKOFF_POLICY = exponential_backoff() +T = TypeVar('T', bound=SnowflakeCursor) def DefaultConverterClass() -> type: @@ -855,8 +856,8 @@ def rollback(self) -> None: self.cursor().execute("ROLLBACK") def cursor( - self, cursor_class: type[SnowflakeCursor] = SnowflakeCursor - ) -> SnowflakeCursor: + self, cursor_class: type[T] = SnowflakeCursor + ) -> T: """Creates a cursor object. Each statement will be executed in a new cursor object.""" logger.debug("cursor") if not self.rest: diff --git a/src/snowflake/connector/cursor.py b/src/snowflake/connector/cursor.py index 8b9d400e0..9d7409cf1 100644 --- a/src/snowflake/connector/cursor.py +++ b/src/snowflake/connector/cursor.py @@ -1724,13 +1724,38 @@ def get_result_batches(self) -> list[ResultBatch] | None: class DictCursor(SnowflakeCursor): """Cursor returning results in a dictionary.""" - def __init__(self, connection) -> None: super().__init__( connection, use_dict_result=True, ) + def fetchone(self) -> dict | None: + return super().fetchone() + + def fetchmany(self, size: int | None = None) -> list[dict]: + return super().fetchmany() + + def fetchall(self) -> list[dict]: + return super().fetchall() + +class TupleCursor(SnowflakeCursor): + """Cursor returning results in a dictionary.""" + + def __init__(self, connection) -> None: + super().__init__( + connection, + use_dict_result=False, + ) + + def fetchone(self) -> tuple | None: + return super().fetchone() + + def fetchmany(self, size: int | None = None) -> list[tuple]: + return super().fetchmany() + + def fetchall(self) -> list[tuple]: + return super().fetchall() def __getattr__(name): if name == "NanoarrowUsage":