Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Type annotate some things ("no-brainers") #827

Merged
merged 6 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions data_diff/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _get_schema(pair: Tuple[Database, DbPath]) -> Dict[str, RawColumnInfo]:
return db.query_table_schema(table_path)


def diff_schemas(table1, table2, schema1, schema2, columns):
def diff_schemas(table1, table2, schema1, schema2, columns) -> None:
logging.info("Diffing schemas...")
attrs = "name", "type", "datetime_precision", "numeric_precision", "numeric_scale"
for c in columns:
Expand All @@ -103,7 +103,7 @@ def diff_schemas(table1, table2, schema1, schema2, columns):


class MyHelpFormatter(click.HelpFormatter):
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__(self, **kwargs)
self.indent_increment = 6

Expand Down Expand Up @@ -281,7 +281,7 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
default=None,
help="Override the dbt production schema configuration within dbt_project.yml",
)
def main(conf, run, **kw):
def main(conf, run, **kw) -> None:
log_handlers = _get_log_handlers(kw["dbt"])
if kw["table2"] is None and kw["database2"]:
# Use the "database table table" form
Expand Down Expand Up @@ -341,9 +341,7 @@ def main(conf, run, **kw):
production_schema_flag=kw["prod_schema"],
)
else:
return _data_diff(
dbt_project_dir=project_dir_override, dbt_profiles_dir=profiles_dir_override, state=state, **kw
)
_data_diff(dbt_project_dir=project_dir_override, dbt_profiles_dir=profiles_dir_override, state=state, **kw)
except Exception as e:
logging.error(e)
raise
Expand Down Expand Up @@ -389,7 +387,7 @@ def _data_diff(
threads1=None,
threads2=None,
__conf__=None,
):
) -> None:
if limit and stats:
logging.error("Cannot specify a limit when using the -s/--stats switch")
return
Expand Down
2 changes: 1 addition & 1 deletion data_diff/abcs/database_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ class Integer(NumericType, IKey):
precision: int = 0
python_type: type = int

def __attrs_post_init__(self):
def __attrs_post_init__(self) -> None:
assert self.precision == 0


Expand Down
2 changes: 1 addition & 1 deletion data_diff/cloud/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def process_response(self, value: str) -> str:
return value


def _validate_temp_schema(temp_schema: str):
def _validate_temp_schema(temp_schema: str) -> None:
if len(temp_schema.split(".")) != 2:
raise ValueError("Temporary schema should have a format <database>.<schema>")

Expand Down
2 changes: 1 addition & 1 deletion data_diff/cloud/datafold_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ class DatafoldAPI:
host: str = "https://app.datafold.com"
timeout: int = 30

def __attrs_post_init__(self):
def __attrs_post_init__(self) -> None:
self.host = self.host.rstrip("/")
self.headers = {
"Authorization": f"Key {self.api_key}",
Expand Down
2 changes: 1 addition & 1 deletion data_diff/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _apply_config(config: Dict[str, Any], run_name: str, kw: Dict[str, Any]):
_ENV_VAR_PATTERN = r"\$\{([A-Za-z0-9_]+)\}"


def _resolve_env(config: Dict[str, Any]):
def _resolve_env(config: Dict[str, Any]) -> None:
"""
Resolve environment variables referenced as ${ENV_VAR_NAME}.
Missing environment variables are replaced with an empty string.
Expand Down
2 changes: 1 addition & 1 deletion data_diff/databases/_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class Connect:
database_by_scheme: Dict[str, Database]
conn_cache: MutableMapping[Hashable, Database]

def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME):
def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME) -> None:
super().__init__()
self.database_by_scheme = database_by_scheme
self.conn_cache = weakref.WeakValueDictionary()
Expand Down
32 changes: 24 additions & 8 deletions data_diff/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,22 @@
import math
import sys
import logging
from typing import Any, Callable, ClassVar, Dict, Generator, Tuple, Optional, Sequence, Type, List, Union, TypeVar
from typing import (
Any,
Callable,
ClassVar,
Dict,
Generator,
Iterator,
NewType,
Tuple,
Optional,
Sequence,
Type,
List,
Union,
TypeVar,
)
from functools import partial, wraps
from concurrent.futures import ThreadPoolExecutor
import threading
Expand Down Expand Up @@ -116,7 +131,7 @@ def dialect(self) -> "BaseDialect":
def compile(self, elem, params=None) -> str:
return self.dialect.compile(self, elem, params)

def new_unique_name(self, prefix="tmp"):
def new_unique_name(self, prefix="tmp") -> str:
self._counter[0] += 1
return f"{prefix}{self._counter[0]}"

Expand Down Expand Up @@ -173,7 +188,7 @@ class ThreadLocalInterpreter:
compiler: Compiler
gen: Generator

def apply_queries(self, callback: Callable[[str], Any]):
def apply_queries(self, callback: Callable[[str], Any]) -> None:
q: Expr = next(self.gen)
while True:
sql = self.compiler.database.dialect.compile(self.compiler, q)
Expand Down Expand Up @@ -885,20 +900,21 @@ def optimizer_hints(self, hints: str) -> str:


T = TypeVar("T", bound=BaseDialect)
Row = Sequence[Any]


@attrs.define(frozen=True)
class QueryResult:
rows: list
rows: List[Row]
columns: Optional[list] = None

def __iter__(self):
def __iter__(self) -> Iterator[Row]:
return iter(self.rows)

def __len__(self):
def __len__(self) -> int:
return len(self.rows)

def __getitem__(self, i):
def __getitem__(self, i) -> Row:
return self.rows[i]


Expand Down Expand Up @@ -1209,7 +1225,7 @@ class ThreadedDatabase(Database):
_queue: Optional[ThreadPoolExecutor] = None
thread_local: threading.local = attrs.field(factory=threading.local)

def __attrs_post_init__(self):
def __attrs_post_init__(self) -> None:
self._queue = ThreadPoolExecutor(self.thread_count, initializer=self.set_conn)
logger.info(f"[{self.name}] Starting a threadpool, size={self.thread_count}.")

Expand Down
6 changes: 3 additions & 3 deletions data_diff/databases/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ class Dialect(BaseDialect):
def random(self) -> str:
return "RAND()"

def quote(self, s: str):
def quote(self, s: str) -> str:
return f"`{s}`"

def to_string(self, s: str):
def to_string(self, s: str) -> str:
return f"cast({s} as string)"

def type_repr(self, t) -> str:
Expand Down Expand Up @@ -212,7 +212,7 @@ class BigQuery(Database):
dataset: str
_client: Any

def __init__(self, project, *, dataset, bigquery_credentials=None, **kw):
def __init__(self, project, *, dataset, bigquery_credentials=None, **kw) -> None:
super().__init__()
credentials = bigquery_credentials
bigquery = import_bigquery()
Expand Down
2 changes: 1 addition & 1 deletion data_diff/databases/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ class Clickhouse(ThreadedDatabase):

_args: Dict[str, Any]

def __init__(self, *, thread_count: int, **kw):
def __init__(self, *, thread_count: int, **kw) -> None:
super().__init__(thread_count=thread_count)

self._args = kw
Expand Down
4 changes: 2 additions & 2 deletions data_diff/databases/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def type_repr(self, t) -> str:
except KeyError:
return super().type_repr(t)

def quote(self, s: str):
def quote(self, s: str) -> str:
return f"`{s}`"

def to_string(self, s: str) -> str:
Expand Down Expand Up @@ -118,7 +118,7 @@ class Databricks(ThreadedDatabase):
catalog: str
_args: Dict[str, Any]

def __init__(self, *, thread_count, **kw):
def __init__(self, *, thread_count, **kw) -> None:
super().__init__(thread_count=thread_count)
logging.getLogger("databricks.sql").setLevel(logging.WARNING)

Expand Down
2 changes: 1 addition & 1 deletion data_diff/databases/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class DuckDB(Database):
_args: Dict[str, Any] = attrs.field(init=False)
_conn: Any = attrs.field(init=False)

def __init__(self, **kw):
def __init__(self, **kw) -> None:
super().__init__()
self._args = kw
self._conn = self.create_connection()
Expand Down
6 changes: 3 additions & 3 deletions data_diff/databases/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class Dialect(BaseDialect):
"json": JSON,
}

def quote(self, s: str):
def quote(self, s: str) -> str:
return f"[{s}]"

def set_timezone_to_utc(self) -> str:
Expand All @@ -93,7 +93,7 @@ def current_schema(self) -> str:
FROM sys.database_principals
WHERE name = CURRENT_USER"""

def to_string(self, s: str):
def to_string(self, s: str) -> str:
# Both convert(varchar(max), …) and convert(text, …) do work.
return f"CONVERT(VARCHAR(MAX), {s})"

Expand Down Expand Up @@ -168,7 +168,7 @@ class MsSQL(ThreadedDatabase):
_args: Dict[str, Any]
_mssql: Any

def __init__(self, host, port, user, password, *, database, thread_count, **kw):
def __init__(self, host, port, user, password, *, database, thread_count, **kw) -> None:
super().__init__(thread_count=thread_count)

args = dict(server=host, port=port, database=database, user=user, password=password, **kw)
Expand Down
6 changes: 3 additions & 3 deletions data_diff/databases/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ class Dialect(BaseDialect):
"boolean": Boolean,
}

def quote(self, s: str):
def quote(self, s: str) -> str:
return f"`{s}`"

def to_string(self, s: str):
def to_string(self, s: str) -> str:
return f"cast({s} as char)"

def is_distinct_from(self, a: str, b: str) -> str:
Expand Down Expand Up @@ -129,7 +129,7 @@ class MySQL(ThreadedDatabase):

_args: Dict[str, Any]

def __init__(self, *, thread_count, **kw):
def __init__(self, *, thread_count, **kw) -> None:
super().__init__(thread_count=thread_count)
self._args = kw

Expand Down
6 changes: 3 additions & 3 deletions data_diff/databases/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ class Dialect(
ROUNDS_ON_PREC_LOSS = True
PLACEHOLDER_TABLE = "DUAL"

def quote(self, s: str):
def quote(self, s: str) -> str:
return f'"{s}"'

def to_string(self, s: str):
def to_string(self, s: str) -> str:
return f"cast({s} as varchar(1024))"

def limit_select(
Expand Down Expand Up @@ -164,7 +164,7 @@ class Oracle(ThreadedDatabase):
kwargs: Dict[str, Any]
_oracle: Any

def __init__(self, *, host, database, thread_count, **kw):
def __init__(self, *, host, database, thread_count, **kw) -> None:
super().__init__(thread_count=thread_count)
self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw)
self.default_schema = kw.get("user").upper()
Expand Down
2 changes: 1 addition & 1 deletion data_diff/databases/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class PostgreSQL(ThreadedDatabase):
_args: Dict[str, Any]
_conn: Any

def __init__(self, *, thread_count, **kw):
def __init__(self, *, thread_count, **kw) -> None:
super().__init__(thread_count=thread_count)
self._args = kw
self.default_schema = "public"
Expand Down
2 changes: 1 addition & 1 deletion data_diff/databases/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class Presto(Database):

_conn: Any

def __init__(self, **kw):
def __init__(self, **kw) -> None:
super().__init__()
self.default_schema = "public"
prestodb = import_presto()
Expand Down
2 changes: 1 addition & 1 deletion data_diff/databases/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class Snowflake(Database):

_conn: Any

def __init__(self, *, schema: str, key: Optional[str] = None, key_content: Optional[str] = None, **kw):
def __init__(self, *, schema: str, key: Optional[str] = None, key_content: Optional[str] = None, **kw) -> None:
super().__init__()
snowflake, serialization, default_backend = import_snowflake()
logging.getLogger("snowflake.connector").setLevel(logging.WARNING)
Expand Down
2 changes: 1 addition & 1 deletion data_diff/databases/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class Trino(presto.Presto):

_conn: Any

def __init__(self, **kw):
def __init__(self, **kw) -> None:
super().__init__()
trino = import_trino()

Expand Down
4 changes: 2 additions & 2 deletions data_diff/databases/vertica.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class Dialect(BaseDialect):
# https://www.vertica.com/docs/9.3.x/HTML/Content/Authoring/SQLReferenceManual/DataTypes/Numeric/NUMERIC.htm#Default
DEFAULT_NUMERIC_PRECISION = 15

def quote(self, s: str):
def quote(self, s: str) -> str:
return f'"{s}"'

def concat(self, items: List[str]) -> str:
Expand Down Expand Up @@ -137,7 +137,7 @@ class Vertica(ThreadedDatabase):

_args: Dict[str, Any]

def __init__(self, *, thread_count, **kw):
def __init__(self, *, thread_count, **kw) -> None:
super().__init__(thread_count=thread_count)
self._args = kw
self._args["AUTOCOMMIT"] = False
Expand Down
2 changes: 1 addition & 1 deletion data_diff/dbt_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def try_get_dbt_runner():

# ProfileRenderer.render_data() fails without instantiating global flag MACRO_DEBUGGING in dbt-core 1.5
# hacky but seems to be a bug on dbt's end
def try_set_dbt_flags():
def try_set_dbt_flags() -> None:
try:
from dbt.flags import set_flags

Expand Down
4 changes: 2 additions & 2 deletions data_diff/diff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from enum import Enum
from contextlib import contextmanager
from operator import methodcaller
from typing import Dict, Set, List, Tuple, Iterator, Optional, Union
from typing import Any, Dict, Set, List, Tuple, Iterator, Optional, Union
from concurrent.futures import ThreadPoolExecutor, as_completed

import attrs
Expand Down Expand Up @@ -89,7 +89,7 @@ class DiffResultWrapper:
stats: dict
result_list: list = attrs.field(factory=list)

def __iter__(self):
def __iter__(self) -> Iterator[Any]:
yield from self.result_list
for i in self.diff:
self.result_list.append(i)
Expand Down
2 changes: 1 addition & 1 deletion data_diff/hashdiff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class HashDiffer(TableDiffer):

stats: dict = attrs.field(factory=dict)

def __attrs_post_init__(self):
def __attrs_post_init__(self) -> None:
# Validate options
if self.bisection_factor >= self.bisection_threshold:
raise ValueError("Incorrect param values (bisection factor must be lower than threshold)")
Expand Down
Loading