Skip to content

Commit

Permalink
More type hints (#584)
Browse files Browse the repository at this point in the history
* Add some type hints in metadata.py

* Add some type hints in sql/reader/base.py

* Fix a bug in sql/reader/base.py

* Add pyright to dev tools in pyproject.toml

* Stop using TypeAlias in metadata.py
  • Loading branch information
mhauru authored Nov 30, 2023
1 parent a1d5b47 commit 7d41c3b
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 24 deletions.
9 changes: 9 additions & 0 deletions sql/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,16 @@ sqlalchemy = "^2.0.0"
pandas = "^2.0.1"

[tool.poetry.dev-dependencies]
pyright = "^1"

[build-system]
requires = ["setuptools", "poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.pyright]
# These are the only files for which, so far, we have checked that pyright passes.
# To be expanded as more type hints are added.
include = [
"snsql/metadata.py",
"snsql/sql/reader/base.py",
]
40 changes: 25 additions & 15 deletions sql/snsql/metadata.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Iterable, Union
import yaml
import io
from os import path
import warnings
if TYPE_CHECKING:
from collections.abc import Mapping
from pathlib import Path

from snsql.sql.reader.base import NameCompare

Expand All @@ -10,7 +15,7 @@
class Metadata:
"""Information about a collection of tabular data sources"""

def __init__(self, tables, engine=None, compare=None, dbname=None):
def __init__(self, tables: Iterable[Table], engine=None, compare=None, dbname=None):
"""Instantiate a metadata object with information about tabular data sources
:param tables: A list of Table descriptions
Expand All @@ -25,7 +30,7 @@ def __init__(self, tables, engine=None, compare=None, dbname=None):
self.compare = NameCompare.get_name_compare(engine) if compare is None else compare
self.dbname = dbname if dbname else None

def __getitem__(self, tablename):
def __getitem__(self, tablename: str):
schema_name = ""
dbname = ""
parts = tablename.split(".")
Expand Down Expand Up @@ -72,19 +77,19 @@ def __iter__(self):
return self.tables()

@staticmethod
def from_file(file):
def from_file(file: Union[str, io.IOBase]) -> Metadata:
"""Load the metadata about this collection from a YAML file"""
ys = CollectionYamlLoader(file)
return ys.read_file()

@staticmethod
def from_dict(schema_dict):
def from_dict(schema_dict: dict):
"""Load the metadata from a dict object"""
ys = CollectionYamlLoader("dummy")
return ys._create_metadata_object(schema_dict)

@classmethod
def from_(cls, val):
def from_(cls, val : Union[Metadata, str, io.IOBase, dict]):
if isinstance(val, Metadata):
return val
elif isinstance(val, (str, io.IOBase)):
Expand All @@ -109,9 +114,9 @@ def __init__(
self,
schema,
name,
columns,
columns: Iterable[Column],
*ignore,
rowcount=0,
rowcount:int=0,
rows_exact=None,
row_privacy=False,
max_ids=1,
Expand Down Expand Up @@ -147,8 +152,11 @@ def __init__(

if clamp_columns:
for col in self.m_columns.values():
if col.typename() in ["int", "float"] and (col.lower is None or col.upper is None):
if col.sensitivity is not None:
if (
col.typename() in ["int", "float"]
and (col.lower is None or col.upper is None) # type: ignore
and col.sensitivity is not None # type: ignore
):
raise ValueError(
f"Column {col.name} has sensitivity and no bounds, but table specifies clamp_columns. "
"clamp_columns should be False, or bounds should be provided."
Expand Down Expand Up @@ -355,11 +363,13 @@ def typename(self):
def unbounded(self):
return True

Column = Union[Boolean, DateTime, Int, Float, String, Unknown]

class CollectionYamlLoader:
def __init__(self, file):
def __init__(self, file: Union[Path, str, io.IOBase]) -> None:
self.file = file

def read_file(self):
def read_file(self) -> Metadata:
if isinstance(self.file, io.IOBase):
try:
c_s = yaml.safe_load(self.file)
Expand All @@ -376,7 +386,7 @@ def read_file(self):
raise
return self._create_metadata_object(c_s)

def _create_metadata_object(self, c_s):
def _create_metadata_object(self, c_s: Mapping) -> Metadata:
if not hasattr(c_s, "keys"):
raise ValueError("Metadata must be a YAML dictionary")
keys = list(c_s.keys())
Expand Down Expand Up @@ -407,7 +417,7 @@ def _create_metadata_object(self, c_s):

return Metadata(tables, engine, dbname=collection)

def load_table(self, schema, table, t):
def load_table(self, schema, table, t) -> Table:
rowcount = int(t["rows"]) if "rows" in t else 0
rows_exact = int(t["rows_exact"]) if "rows_exact" in t else None
row_privacy = bool(t["row_privacy"]) if "row_privacy" in t else False
Expand Down Expand Up @@ -453,7 +463,7 @@ def load_table(self, schema, table, t):
censor_dims=censor_dims,
)

def load_column(self, column, c):
def load_column(self, column, c) -> Column:
lower = float(c["lower"]) if "lower" in c else None
upper = float(c["upper"]) if "upper" in c else None
is_key = False if "private_id" not in c else bool(c["private_id"])
Expand Down Expand Up @@ -492,7 +502,7 @@ def load_column(self, column, c):
else:
raise ValueError("Unknown column type for column {0}: {1}".format(column, c))

def write_file(self, collection_metadata, collection_name):
def write_file(self, collection_metadata, collection_name) -> None:

engine = collection_metadata.engine
schemas = {}
Expand Down
20 changes: 11 additions & 9 deletions sql/snsql/sql/reader/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations
from typing import Type
from snsql.reader.base import Reader
from snsql.sql.reader.engine import Engine
import importlib
Expand All @@ -6,13 +8,13 @@

class SqlReader(Reader):
@classmethod
def get_reader_class(cls, engine):
def get_reader_class(cls, engine) -> Type[Reader]:
prefix = ""
for eng in Engine.known_engines:
if str(eng).lower() == engine.lower():
prefix = str(eng)
if prefix == "":
return SqlReader() # should this throw?
return SqlReader # should this throw?
else:
mod_path = f"snsql.sql.reader.{Engine.class_map[prefix]}"
module = importlib.import_module(mod_path)
Expand Down Expand Up @@ -72,10 +74,10 @@ def __init__(self, search_path=None):
path. Pass in only the schema part.
"""

def reserved(self):
def reserved(self) -> list[str]:
return ["select", "group", "on"]

def schema_match(self, from_query, from_meta):
def schema_match(self, from_query: str, from_meta: str) -> bool:
if from_query.strip() == "" and from_meta in self.search_path:
return True
if from_meta.strip() == "" and from_query in self.search_path:
Expand All @@ -87,25 +89,25 @@ def schema_match(self, from_query, from_meta):
if identifier used in query matches identifier
of metadata object. Pass in one part at a time.
"""
def identifier_match(self, from_query, from_meta):
def identifier_match(self, from_query: str, from_meta: str) -> bool:
return from_query == from_meta

"""
Removes all escaping characters, keeping identifiers unchanged
"""
def strip_escapes(self, value):
def strip_escapes(self, value: str) -> str:
return value.replace('"', "").replace("`", "").replace("[", "").replace("]", "")

"""
True if any part of identifier is escaped
"""
def is_escaped(self, identifier):
def is_escaped(self, identifier: str) -> bool:
return any([p[0] in ['"', "[", "`"] for p in identifier.split(".") if p != ""])

"""
Converts proprietary escaping to SQL-92. Supports multi-part identifiers
"""
def clean_escape(self, identifier):
def clean_escape(self, identifier: str) -> str:
escaped = []
for p in identifier.split("."):
if self.is_escaped(p):
Expand All @@ -118,7 +120,7 @@ def clean_escape(self, identifier):
Returns true if an identifier should
be escaped. Checks only one part per call.
"""
def should_escape(self, identifier):
def should_escape(self, identifier: str) -> bool:
if self.is_escaped(identifier):
return False
if identifier.lower() in self.reserved():
Expand Down

0 comments on commit 7d41c3b

Please sign in to comment.