Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 417d8f4

Browse files
authored
Merge pull request #35 from datafold/apr1
Various changes
2 parents 8c5e98c + 10159ef commit 417d8f4

File tree

4 files changed

+23
-6
lines changed

4 files changed

+23
-6
lines changed

sqeleton/databases/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,10 @@ class Database(AbstractDatabase[T]):
316316
def name(self):
317317
return type(self).__name__
318318

319+
def compile(self, sql_ast):
320+
compiler = Compiler(self)
321+
return compiler.compile(sql_ast)
322+
319323
def query(self, sql_ast: Union[Expr, Generator], res_type: type = None):
320324
"""Query the given SQL code/AST, and attempt to convert the result to type 'res_type'
321325
@@ -381,6 +385,8 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = None):
381385
return [_one(row) for row in res]
382386
elif res_type.__args__ in [(Tuple,), (tuple,)]:
383387
return [tuple(row) for row in res]
388+
elif res_type.__args__ == (dict,):
389+
return [dict(safezip(res.columns, row)) for row in res]
384390
else:
385391
raise ValueError(res_type)
386392
return res

sqeleton/queries/api.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
this = This()
99

1010

11-
def join(*tables: ITable):
11+
def join(*tables: ITable) -> Join:
1212
"""Inner-join a sequence of table expressions"
1313
1414
When joining, it's recommended to use explicit tables names, instead of `this`, in order to avoid potential name collisions.
@@ -110,6 +110,11 @@ def max_(expr: Expr):
110110
return Func("max", [expr])
111111

112112

113+
def exists(expr: Expr):
114+
"""Call EXISTS(expr)"""
115+
return Func("exists", [expr])
116+
117+
113118
def if_(cond: Expr, then: Expr, else_: Optional[Expr] = None):
114119
"""Conditional expression, shortcut to when-then-else.
115120

sqeleton/queries/ast_classes.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class ITable(AbstractTable):
9191
source_table: Any
9292
schema: Schema = None
9393

94-
def select(self, *exprs, distinct=SKIP, optimizer_hints=SKIP, **named_exprs):
94+
def select(self, *exprs, distinct=SKIP, optimizer_hints=SKIP, **named_exprs) -> "ITable":
9595
"""Create a new table with the specified fields"""
9696
exprs = args_as_tuple(exprs)
9797
exprs = _drop_skips(exprs)
@@ -248,6 +248,12 @@ def test_regex(self, other):
248248
def sum(self):
249249
return Func("SUM", [self])
250250

251+
def max(self):
252+
return Func("MAX", [self])
253+
254+
def min(self):
255+
return Func("MIN", [self])
256+
251257

252258
@dataclass
253259
class TestRegex(ExprNode, LazyOps):
@@ -261,7 +267,7 @@ def compile(self, c: Compiler) -> str:
261267
return c.compile(regex)
262268

263269

264-
@dataclass
270+
@dataclass(eq=False)
265271
class Func(ExprNode, LazyOps):
266272
name: str
267273
args: Sequence[Expr]
@@ -525,7 +531,7 @@ def schema(self):
525531
s = self.source_tables[0].schema # TODO validate types match between both tables
526532
return type(s)({c.name: c.type for c in self.columns})
527533

528-
def on(self, *exprs):
534+
def on(self, *exprs) -> "Join":
529535
"""Add an ON clause, for filtering the result of the cartesian product (i.e. the JOIN)"""
530536
if len(exprs) == 1:
531537
(e,) = exprs
@@ -538,7 +544,7 @@ def on(self, *exprs):
538544

539545
return self.replace(on_exprs=(self.on_exprs or []) + exprs)
540546

541-
def select(self, *exprs, **named_exprs):
547+
def select(self, *exprs, **named_exprs) -> ITable:
542548
"""Select fields to return from the JOIN operation
543549
544550
See Also: ``ITable.select()``

sqeleton/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def __repr__(self) -> str:
130130

131131
class CaseSensitiveDict(dict, CaseAwareMapping):
132132
def get_key(self, key):
133-
self[key] # Throw KeyError is key doesn't exist
133+
self[key] # Throw KeyError if key doesn't exist
134134
return key
135135

136136
def as_insensitive(self):

0 commit comments

Comments
 (0)