Skip to content

Implement JOIN #245

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 146 additions & 27 deletions beanquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,22 @@

from .query_compile import (
EvalAggregator,
EvalAnd,
EvalAll,
EvalAnd,
EvalAny,
EvalCoalesce,
EvalColumn,
EvalConstant,
EvalConstantSubquery1D,
EvalCreateTable,
EvalGetItem,
EvalGetter,
EvalHashJoin,
EvalInsert,
EvalOr,
EvalPivot,
EvalProjection,
EvalQuery,
EvalConstantSubquery1D,
EvalRow,
EvalTarget,
FUNCTIONS,
Expand All @@ -54,7 +56,8 @@ def __init__(self, message, node=None):
class Compiler:
def __init__(self, context):
self.context = context
self.stack = [context.tables.get(None)]
self.stack = []
self.columns = {}

@property
def table(self):
Expand Down Expand Up @@ -98,7 +101,11 @@ def _compile(self, node: Optional[ast.Node]):

@_compile.register
def _select(self, node: ast.Select):
self.stack.append(self.table)
self.stack.append(self.context.tables.get(None))

# JOIN.
if isinstance(node.from_clause, ast.Join):
return self._compile_join(node)

# Compile the FROM clause.
c_from_expr = self._compile_from(node.from_clause)
Expand Down Expand Up @@ -177,11 +184,12 @@ def _compile_from(self, node):
# FROM expression.
if isinstance(node, ast.From):
# Check if the FROM expression is a column name belonging to the current table.
if isinstance(node.expression, ast.Column):
column = self.table.columns.get(node.expression.name)
if isinstance(node.expression, ast.Column) and len(node.expression.ids) == 1:
name = node.expression.ids[0].name
column = self.table.columns.get(name)
if column is None:
# When it is not, treat it as a table name.
table = self.context.tables.get(node.expression.name)
table = self.context.tables.get(name)
if table is not None:
self.table = table
return None
Expand Down Expand Up @@ -214,8 +222,7 @@ def _compile_targets(self, targets):
# Bind the targets expressions to the execution context.
if isinstance(targets, ast.Asterisk):
# Insert the full list of available columns.
targets = [ast.Target(ast.Column(name), None)
for name in self.table.wildcard_columns]
targets = [ast.Target(ast.Column([ast.Name(name)]), None) for name in self.table.wildcard_columns]

# Compile targets.
c_targets = []
Expand Down Expand Up @@ -287,7 +294,7 @@ def _compile_order_by(self, order_by, c_targets):
# simple Column expressions. If they refer to a target name, we
# resolve them.
if isinstance(column, ast.Column):
name = column.name
name = '.'.join(i.name for i in column.ids)
index = targets_name_map.get(name, None)

# Otherwise we compile the expression and add it to the list of
Expand Down Expand Up @@ -337,7 +344,7 @@ def _compile_pivot_by(self, pivot_by, targets, group_indexes):
continue

# Process target references by name.
if isinstance(column, ast.Column):
if isinstance(column, ast.Name):
index = names.get(column.name, None)
if index is None:
raise CompilationError(f'PIVOT BY column {column!r} is not in the targets list')
Expand Down Expand Up @@ -403,7 +410,7 @@ def _compile_group_by(self, group_by, c_targets):
# simple Column expressions. If they refer to a target name, we
# resolve them.
if isinstance(column, ast.Column):
name = column.name
name = '.'.join(i.name for i in column.ids)
index = targets_name_map.get(name, None)

# Otherwise we compile the expression and add it to the list of
Expand Down Expand Up @@ -475,12 +482,121 @@ def _compile_group_by(self, group_by, c_targets):

return new_targets[len(c_targets):], group_indexes, having_index

def _compile_join(self, node):
    """Compile a SELECT statement whose FROM clause is a JOIN.

    Compilation happens in two passes. The first pass compiles the join
    constraint and the targets against the two source tables purely for
    its side effect: every column resolved along the way is recorded in
    ``self.columns``. From that record a projection is built for each
    side of the join, plus a synthetic table describing the joined rows.
    The second pass compiles the constraint and the targets again, this
    time against the synthetic table.

    Args:
      node: The ``ast.Select`` node; ``node.from_clause`` is the
        ``ast.Join`` node and ``node.targets`` the SELECT targets.

    Returns:
      An ``EvalProjection`` of the compiled targets over an
      ``EvalHashJoin`` of the two per-table projections.

    Raises:
      CompilationError: If either joined table does not exist.
    """
    from beanquery.tables import Table

    def row_getter(item, datatype):
        # Build a column accessor that reads position ``item`` of a row.
        def func(row):
            return row[item]
        func.dtype = datatype
        func.__qualname__ = func.__name__ = f'column[{item}, {datatype.__name__}]'
        return func

    join_node = node.from_clause

    left = self.context.tables.get(join_node.left.name)
    if left is None:
        raise CompilationError(f'table "{join_node.left.name}" does not exist', join_node.left)
    right = self.context.tables.get(join_node.right.name)
    if right is None:
        raise CompilationError(f'table "{join_node.right.name}" does not exist', join_node.right)
    self.table = right
    self.stack.append(left)

    # ``USING name`` is shorthand for ``ON left.name = right.name``. The
    # equivalent constraint is built in a local variable so that the
    # parsed AST is left untouched.
    constraint_node = join_node.constraint
    if join_node.using is not None:
        constraint_node = ast.Equal(
            ast.Column([ast.Name(join_node.left.name), ast.Name(join_node.using.name)]),
            ast.Column([ast.Name(join_node.right.name), ast.Name(join_node.using.name)]),
        )

    # First pass: the compiled results are discarded, only the columns
    # recorded in self.columns matter. The key column names must be
    # collected after the constraint but before the targets, so that they
    # contain only the left table columns the constraint refers to.
    self._compile(constraint_node)
    keycolnames = [col for t, col in self.columns if t == left.name]
    self._compile_targets(node.targets)

    left_column_names = [col for t, col in self.columns if t == left.name]
    right_column_names = [col for t, col in self.columns if t == right.name]

    left_p = EvalProjection(left, [left[col] for col in left_column_names])
    right_p = EvalProjection(right, [right[col] for col in right_column_names])

    # Synthetic table describing the joined rows: every recorded column is
    # reachable both as ``table.column`` and as plain ``column``.
    # NOTE(review): positions follow self.columns sorted by table name,
    # which presumably has to match the row layout produced by
    # EvalHashJoin, and entries recorded for other tables would shift the
    # positions -- confirm.
    # NOTE(review): when both tables have a column of the same name, the
    # unqualified entry silently refers to the last one seen.
    table = Table()
    table.columns = {}
    ordered = sorted(self.columns.items(), key=lambda item: item[0][0])
    for i, ((tname, colname), col) in enumerate(ordered):
        getter_dtype = col.dtype
        table.columns[f'{tname}.{colname}'] = row_getter(i, getter_dtype)
        table.columns[f'{colname}'] = row_getter(i, getter_dtype)

    # Second pass: compile against the joined table only.
    # NOTE(review): the previous stack content is dropped and not restored.
    self.stack = [table]
    constraint = self._compile(constraint_node)

    # The hash key of a left row is the tuple of the values of the columns
    # the constraint refers to, read by position in the left projection.
    left_columns = {col: row_getter(i, left.columns[col].dtype) for i, col in enumerate(left_column_names)}
    keycols = [left_columns[name] for name in keycolnames]

    def keyfunc(lrow, keycols=keycols):
        return tuple(keycol(lrow) for keycol in keycols)

    hash_join = EvalHashJoin(left_p, right_p, constraint, keyfunc)

    targets = self._compile_targets(node.targets)
    cols = []
    for target in targets:
        expr = target.c_expr
        expr.name = target.name
        cols.append(expr)

    return EvalProjection(hash_join, cols)

def _resolve_column(self, node: ast.Column):
    """Resolve a possibly dotted column reference against the table stack.

    Tables are searched from the top of ``self.stack`` downwards. Every
    successful resolution is recorded in ``self.columns`` under the key
    ``(table.name, column_name)``.

    Args:
      node: The ``ast.Column`` node; ``node.ids`` holds the dot-separated
        names of the reference.

    Returns:
      A ``(column, parts)`` tuple where ``column`` is the resolved column
      and ``parts`` are the names left over after resolution, in reverse
      source order. The caller resolves the leftover names as attributes
      of the column.

    Raises:
      CompilationError: If the reference does not match any column.
    """
    # Work on a reversed copy so that names are consumed with pop().
    parts = list(reversed(node.ids))

    # FIXME!! A table may expose columns whose name itself contains a dot
    # ("table.column"); try the first two names joined as one column name
    # before anything else.
    if len(parts) > 1:
        colname = f'{parts[-1].name}.{parts[-2].name}'
        for table in reversed(self.stack):
            column = table.columns.get(colname)
            if column is not None:
                self.columns[(table.name, colname)] = column
                return column, parts[:-2]

    # Plain column name, possibly followed by attribute names.
    colname = parts.pop().name
    for table in reversed(self.stack):
        column = table.columns.get(colname)
        if column is not None:
            self.columns[(table.name, colname)] = column
            return column, parts

    # ``table.column``: the first name is a table name, the second one a
    # column of that table.
    qualifier = None
    if parts:
        qualifier = colname
        colname = parts.pop().name
        for table in reversed(self.stack):
            if table.name == qualifier:
                column = table.columns.get(colname)
                if column is not None:
                    self.columns[(table.name, colname)] = column
                    return column, parts

    # Name the table the reference was looked up in. Do not rely on the
    # loop variable of the searches above: it is unbound when the stack is
    # empty and otherwise names the bottom of the stack.
    if qualifier is not None:
        raise CompilationError(f'column "{colname}" not found in table "{qualifier}"', node)
    if self.stack:
        raise CompilationError(f'column "{colname}" not found in table "{self.stack[-1].name}"', node)
    raise CompilationError(f'column "{colname}" not found', node)

@_compile.register
def _column(self, node: ast.Column):
column = self.table.columns.get(node.name)
if column is not None:
return column
raise CompilationError(f'column "{node.name}" not found in table "{self.table.name}"', node)
column, parts = self._resolve_column(node)
for part in parts:
column = self._resolve_attribute(column, part)
return column

# operand = None
# if isinstance(node.operand, ast.Column):
# # This can be table.column or column.attribute.
# if node.operand.name in self.table.columns:
# operand = self._column(node.operand)
# elif f'{node.operand.name}.{node.name}' in self.table.columns:
# return self._column(ast.Column(f'{node.operand.name}.{node.name}'))
# else:
# for table in reversed(self.stack):
# if table.name == node.operand.name:
# column = table.columns.get(node.name)
# if column:
# self.columns.append((table.name, node.name))
# return column

@_compile.register
def _or(self, node: ast.Or):
Expand Down Expand Up @@ -567,25 +683,25 @@ def _function(self, node: ast.Function):
# Replace ``meta(key)`` with ``meta[key]``.
if node.fname == 'meta':
key = node.operands[0]
node = ast.Function('getitem', [ast.Column('meta', parseinfo=node.parseinfo), key])
node = ast.Function('getitem', [ast.Column([ast.Name('meta')], parseinfo=node.parseinfo), key])
return self._compile(node)

# Replace ``entry_meta(key)`` with ``entry.meta[key]``.
if node.fname == 'entry_meta':
key = node.operands[0]
node = ast.Function('getitem', [ast.Attribute(ast.Column('entry', parseinfo=node.parseinfo), 'meta'), key])
node = ast.Function('getitem', [ast.Attribute(ast.Column([ast.Name('entry')], parseinfo=node.parseinfo), 'meta'), key])
return self._compile(node)

# Replace ``any_meta(key)`` with ``getitem(meta, key, entry.meta[key])``.
if node.fname == 'any_meta':
key = node.operands[0]
node = ast.Function('getitem', [ast.Column('meta', parseinfo=node.parseinfo), key, ast.Function('getitem', [
ast.Attribute(ast.Column('entry', parseinfo=node.parseinfo), 'meta'), key])])
node = ast.Function('getitem', [ast.Column([ast.Name('meta')], parseinfo=node.parseinfo), key, ast.Function('getitem', [
ast.Attribute(ast.Column([ast.Name('entry')], parseinfo=node.parseinfo), 'meta'), key])])
return self._compile(node)

# Replace ``has_account(regexp)`` with ``('(?i)' + regexp) ~? any (accounts)``.
if node.fname == 'has_account':
node = ast.Any(ast.Add(ast.Constant('(?i)'), node.operands[0]), '?~', ast.Column('accounts'))
node = ast.Any(ast.Add(ast.Constant('(?i)'), node.operands[0]), '?~', ast.Column([ast.Name('accounts')]))
return self._compile(node)

function = function(self.context, operands)
Expand All @@ -601,9 +717,7 @@ def _subscript(self, node: ast.Subscript):
return EvalGetItem(operand, node.key)
raise CompilationError('column type is not subscriptable', node)

@_compile.register
def _attribute(self, node: ast.Attribute):
operand = self._compile(node.operand)
def _resolve_attribute(self, operand, node):
dtype = types.ALIASES.get(operand.dtype, operand.dtype)
if issubclass(dtype, types.Structure):
getter = dtype.columns.get(node.name)
Expand All @@ -612,6 +726,11 @@ def _attribute(self, node: ast.Attribute):
return EvalGetter(operand, getter, getter.dtype)
raise CompilationError('column type is not structured', node)

@_compile.register
def _attribute(self, node: ast.Attribute):
    """Compile an attribute access: compile the operand, then resolve the attribute on it."""
    return self._resolve_attribute(self._compile(node.operand), node)

@_compile.register
def _unaryop(self, node: ast.UnaryOp):
operand = self._compile(node.operand)
Expand Down Expand Up @@ -711,7 +830,7 @@ def _constant(self, node: ast.Constant):
# in the current table.
if isinstance(node.value, str) and node.text and node.text[0] == '"':
if node.value in self.table.columns:
return self._column(ast.Column(node.value))
return self._column(ast.Column([ast.Name(node.value)]))
return EvalConstant(node.value)

@_compile.register
Expand All @@ -732,7 +851,7 @@ def _journal(self, node: ast.Journal):

@_compile.register
def _print(self, node: ast.Print):
self.table = self.context.tables.get('entries')
self.stack.append(self.context.tables.get('entries'))
expr = self._compile_from(node.from_clause)
targets = [EvalTarget(EvalRow(), 'ROW(*)', False)]
return EvalQuery(self.table, targets, expr, None, None, None, None, False)
Expand Down Expand Up @@ -854,7 +973,7 @@ def get_target_name(target):
if target.name is not None:
return target.name
if isinstance(target.expression, ast.Column):
return target.expression.name
return '.'.join(i.name for i in target.expression.ids)
return target.expression.text.strip()


Expand Down
7 changes: 6 additions & 1 deletion beanquery/parser/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ class From(Node):
clear: Optional[bool] = None
parseinfo: Any = dataclasses.field(default=None, compare=False, repr=False)

# A JOIN of two tables in a FROM clause.
#
# Attributes:
#   left: The table on the left-hand side of the JOIN.
#   right: The table on the right-hand side of the JOIN.
#   constraint: The join condition expression given with ON, if any.
#   using: The name of the column given with USING, if any.
Join = node('Join', 'left right constraint using')

# A GROUP BY clause.
#
# Attributes:
Expand Down Expand Up @@ -180,11 +182,14 @@ def __repr__(self):
# name: The table name.
Table = node('Table', 'name')

# A bare identifier.
#
# Attributes:
#   name: A string, the identifier.
Name = node('Name', 'name')

# A reference to a column.
#
# Attributes:
#   ids: A list of Name nodes, the dot-separated parts of the column reference.
Column = node('Column', 'name')
Column = node('Column', 'ids')


# A function call.
#
Expand Down
18 changes: 13 additions & 5 deletions beanquery/parser/bql.ebnf
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ statement

select::Select
= 'SELECT' ['DISTINCT' distinct:`True`] targets:(','.{ target }+ | asterisk)
['FROM' from_clause:(_table | subselect | from)]
['FROM' from_clause:(_table | join | subselect | from)]
['WHERE' where_clause:expression]
['GROUP' 'BY' group_by:groupby]
['ORDER' 'BY' order_by:','.{order}+]
Expand All @@ -45,6 +45,10 @@ from::From
| expression:expression ['OPEN' 'ON' open:date] ['CLOSE' ('ON' close:date | {} close:`True`)] ['CLEAR' clear:`True`]
;

# A JOIN of two tables, constrained either by an ON expression or by the
# name of a column given with USING.
join::Join
    = left:table 'JOIN' ~ right:table ('ON' constraint:expression | 'USING' using:name)
    ;

_table::Table
=
| name:/#([a-zA-Z_][a-zA-Z0-9_]*)?/
Expand All @@ -68,7 +72,7 @@ ordering
;

pivotby::PivotBy
= columns+:(integer | column) ',' columns+:(integer | column)
= columns+:(integer | name) ',' columns+:(integer | name)
;

target::Target
Expand Down Expand Up @@ -301,10 +305,14 @@ function::Function
| fname:identifier '(' operands+:asterisk ')'
;

column::Column
name::Name
= name:identifier
;

# A column reference: one or more names separated by dots.
column::Column
    = ids:'.'.{ name }+
    ;

literal
=
| date
Expand Down Expand Up @@ -396,6 +404,6 @@ create_table::CreateTable

insert::Insert
= 'INSERT' 'INTO' ~ table:table
['(' columns:','.{column} ')']
'VALUES' '(' values:','.{expression} ')'
['(' columns:','.{ name } ')']
'VALUES' '(' values:','.{ expression } ')'
;
Loading