Skip to content

Commit

Permalink
0.12.0-alpha.7
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer committed Sep 28, 2023
1 parent c9820fb commit f5a4286
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 44 deletions.
2 changes: 1 addition & 1 deletion opteryx/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
"""

# __version__ = "0.4.0-alpha.6"
__version__ = "0.12.0-alpha.6"
__version__ = "0.12.0-alpha.7"
5 changes: 2 additions & 3 deletions opteryx/components/binder/binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def create_variable_node(node: Node, context: Dict[str, Any]) -> Node:
return node

# Check if the identifier is a variable
if node.value[0] == "@":
if node.current_name[0] == "@":
node = create_variable_node(node, context)
return node, context

Expand Down Expand Up @@ -200,8 +200,7 @@ def inner_binder(node: Node, context: Dict[str, Any], step: str) -> Tuple[Node,

# If the column exists in the schema, update node and context accordingly.
if found_column:
# Convert to a FLATCOLUMN (an EVALUATED identifier)
node.schema_column = found_column # .to_flatcolumn()
node.schema_column = found_column
node.query_column = node.alias or column_name

return node, context
Expand Down
32 changes: 20 additions & 12 deletions opteryx/components/binder/binder_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,6 @@ def visit_exit(self, node: Node, context: BindingContext) -> Tuple[Node, Binding

def name_column(qualifier, column):
if len(context.schemas) > 1 or needs_qualifier:
# if len(column.aliases) == 1:
# return column.aliases[0]
return f"{qualifier}.{column.name}"
return column.name

Expand Down Expand Up @@ -453,6 +451,8 @@ def visit_join(self, node: Node, context: BindingContext) -> Tuple[Node, Binding

def visit_project(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
columns = []

# Handle wildcards, including qualified wildcards.
for column in node.columns:
if not column.node_type == NodeType.WILDCARD:
columns.append(column)
Expand All @@ -471,26 +471,26 @@ def visit_project(self, node: Node, context: BindingContext) -> Tuple[Node, Bind

node.columns = columns

# Bind the local columns to physical columns
node.columns, group_contexts = zip(
*(inner_binder(col, context, node.identity) for col in node.columns)
)
context.schemas = merge_schemas(*[ctx.schemas for ctx in group_contexts])

# Check for duplicates
all_identities = [c.schema_column.identity for c in node.columns]

if len(set(all_identities)) != len(all_identities):
from collections import Counter

from opteryx.exceptions import AmbiguousIdentifierError

duplicates = [column for column, count in Counter(all_identities).items() if count > 1]
matches = {
c.query_column for c in node.columns if c.schema_column.identity in duplicates
}
matches = {c.value for c in node.columns if c.schema_column.identity in duplicates}
raise AmbiguousIdentifierError(
message=f"Query result contains multiple instances of the same column(s) - `{'`, `'.join(matches)}`"
)

# Remove columns not being projected from the schemas, and remove empty schemas
columns = []
for relation, schema in list(context.schemas.items()):
schema_columns = [
Expand All @@ -500,12 +500,20 @@ def visit_project(self, node: Node, context: BindingContext) -> Tuple[Node, Bind
context.schemas.pop(relation)
else:
schema.columns = schema_columns
columns += [
column
for column in node.columns
if column.schema_column.identity in [i.identity for i in schema_columns]
]

for column in node.columns:
if column.schema_column.identity in [i.identity for i in schema_columns]:
# If .alias is set, update .value and set .alias to None
if column.alias is not None:
column.value = column.alias
column.query_column = column.alias
current_name = column.schema_column.name
column.schema_column.name = column.alias
context.schemas[relation].pop_column(current_name)
context.schemas[relation].columns.append(column.schema_column)
column.alias = None
columns.append(column)

# We always have a $derived schema, even if it's empty
if "$derived" in context.schemas:
context.schemas["$project"] = context.schemas.pop("$derived")
if not "$derived" in context.schemas:
Expand Down
31 changes: 14 additions & 17 deletions opteryx/components/logical_planner_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from opteryx.exceptions import UnsupportedSyntaxError
from opteryx.functions.binary_operators import BINARY_OPERATORS
from opteryx.managers.expression import NodeType
from opteryx.models import LogicalColumn
from opteryx.models import Node
from opteryx.utils import dates
from opteryx.utils import suggest_alternative
Expand Down Expand Up @@ -154,24 +155,20 @@ def qualified_wildcard(branch, alias=None, key=None):


def identifier(branch, alias=None, key=None):
return Node(
node_type=NodeType.IDENTIFIER,
value=branch["value"],
alias=alias,
query_column=branch["value"],
"""idenitifier doesn't have a qualifier (recorded in source)"""
return LogicalColumn(
node_type=NodeType.IDENTIFIER, # column type
alias=alias, # AS alias, if provided
source_column=branch["value"], # the source column
)


def compound_identifier(branch, alias=None, key=None):
if alias is None:
alias = ".".join(p["value"] for p in branch)
return Node(
node_type=NodeType.IDENTIFIER,
value=".".join(p["value"] for p in branch),
alias=alias,
query_column=".".join(p["value"] for p in branch),
source_column=branch[-1]["value"],
source=".".join(p["value"] for p in branch[:-1]),
return LogicalColumn(
node_type=NodeType.IDENTIFIER, # column type
alias=alias, # AS alias, if provided
source_column=branch[-1]["value"], # the source column
source=".".join(p["value"] for p in branch[:-1]), # the source relation
)


Expand Down Expand Up @@ -329,7 +326,7 @@ def extract(branch, alias=None, key=None):
def map_access(branch, alias=None, key=None):
# Identifier[key] -> GET(Identifier, key)

field = branch["column"]["Identifier"]["value"]
identifier_node = build(branch["column"]) # ["Identifier"]["value"]
key_dict = branch["keys"][0]["Value"]
if "SingleQuotedString" in key_dict:
key = key_dict["SingleQuotedString"]
Expand All @@ -338,12 +335,12 @@ def map_access(branch, alias=None, key=None):
key = int(key_dict["Number"][0])
key_node = Node(NodeType.LITERAL, type=OrsoTypes.INTEGER, value=key)

identifier_node = Node(NodeType.IDENTIFIER, value=field)
return Node(
NodeType.FUNCTION,
value="GET",
parameters=[identifier_node, key_node],
alias=alias or f"{field}[{repr(key) if isinstance(key, str) else key}]",
alias=alias
or f"{identifier_node.current_name}[{repr(key) if isinstance(key, str) else key}]",
)


Expand Down
3 changes: 2 additions & 1 deletion opteryx/managers/expression/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,5 +96,6 @@ def format_expression(root):
return f"{format_expression(root.left)} {_map[node_type]} {format_expression(root.right)}"
if node_type == NodeType.NESTED:
return f"({format_expression(root.centre)})"

if node_type == NodeType.IDENTIFIER:
return root.current_name
return str(root.value)
3 changes: 3 additions & 0 deletions opteryx/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,8 @@

from opteryx.models.connection_context import ConnectionContext
from opteryx.models.execution_tree import ExecutionTree
from opteryx.models.logical_column import LogicalColumn
from opteryx.models.node import Node
from opteryx.models.query_properties import QueryProperties

__all__ = ("ConnectionContext", "ExecutionTree", "LogicalColumn", "Node", "QueryProperties")
77 changes: 77 additions & 0 deletions opteryx/models/logical_column.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional


class LogicalColumn:
"""
Represents a logical column in the binding phase, tied to schema columns later.
Parameters:
source_column: str
The original name of the column in its logical source (e.g., table, subquery).
source: str
The originating logical source for the column.
alias: Optional[str]
A temporary name assigned in the SQL query for the column, defaults to None.
"""

def __init__(
self,
node_type,
source_column: str,
source: Optional[str] = None,
alias: Optional[str] = None,
schema_column=None,
):
self.node_type = node_type
self.source_column = source_column
self.source = source
self.alias = alias
self.schema_column = schema_column

@property
def qualified_name(self) -> str:
"""
Returns the fully qualified column name based on the logical source and source_column.
Returns:
The fully qualified column name as a string.
"""
return f"{self.source}.{self.source_column}"

@property
def current_name(self) -> str:
"""
Returns the current name of the column, considering any alias.
Returns:
The current name of the column as a string.
"""
return self.alias or self.source_column

@property
def value(self) -> str:
return self.current_name

def __getattr__(self, name: str):
return None

def copy(self):
return LogicalColumn(
node_type=self.node_type,
source_column=self.source_column,
source=self.source,
alias=self.alias,
schema_column=self.schema_column,
)
22 changes: 13 additions & 9 deletions tests/sql_battery/test_shapes_and_errors_battery.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,8 @@
("SELECT * FROM $satellites WHERE name NOT SIMILAR TO '^C.'", 165, 8, None),
("SELECT * FROM $satellites WHERE name ~* '^c.'", 12, 8, None),
("SELECT * FROM $satellites WHERE name !~* '^c.'", 165, 8, None),

]
A = [
("SELECT COUNT(*) FROM $satellites", 1, 1, None),
("SELECT count(*) FROM $satellites", 1, 1, None),
("SELECT COUNT (*) FROM $satellites", 1, 1, None),
Expand Down Expand Up @@ -683,21 +684,20 @@
("SELECT id, name FROM $planets AS P_1 INNER JOIN $planets AS P_2 USING (id, name)", 9, 2, None),
("SELECT P_1.* FROM $planets AS P_1 INNER JOIN $planets AS P_2 USING (id, name)", 9, 18, None),
("SELECT * FROM $satellites AS P_1 INNER JOIN $satellites AS P_2 USING (id, name)", 177, 14, None),
]
A = [

("SELECT DISTINCT planetId FROM $satellites RIGHT OUTER JOIN $planets ON $satellites.planetId = $planets.id", 8, 1, None),
("SELECT DISTINCT planetId FROM $satellites RIGHT JOIN $planets ON $satellites.planetId = $planets.id", 8, 1, None),
("SELECT planetId FROM $satellites RIGHT JOIN $planets ON $satellites.planetId = $planets.id", 179, 1, None),
("SELECT DISTINCT planetId FROM $satellites FULL OUTER JOIN $planets ON $satellites.planetId = $planets.id", 8, 1, None),
("SELECT DISTINCT planetId FROM $satellites FULL JOIN $planets ON $satellites.planetId = $planets.id", 8, 1, None),
("SELECT planetId FROM $satellites FULL JOIN $planets ON $satellites.planetId = $planets.id", 179, 1, None),

("SELECT pid FROM ( SELECT id AS pid FROM $planets) WHERE pid > 5", 4, 1, None),
("SELECT * FROM ( SELECT id AS pid FROM $planets) WHERE pid > 5", 4, 1, None),
("SELECT * FROM ( SELECT COUNT(planetId) AS moons, planetId FROM $satellites GROUP BY planetId ) WHERE moons > 10", 4, 2, None),
("SELECT pid FROM ( SELECT id AS pid FROM $planets) AS SQ WHERE pid > 5", 4, 1, None),
("SELECT * FROM ( SELECT id AS pid FROM $planets) AS SQ WHERE pid > 5", 4, 1, None),
("SELECT * FROM ( SELECT COUNT(planetId) AS moons, planetId FROM $satellites GROUP BY planetId ) AS SQ WHERE moons > 10", 4, 2, None),

("SELECT * FROM $planets WHERE id = -1", 0, 20, None),
("SELECT COUNT(*) FROM (SELECT DISTINCT a FROM $astronauts CROSS JOIN UNNEST(alma_mater) AS a ORDER BY a)", 1, 1, None),
("SELECT COUNT(*) FROM (SELECT DISTINCT a FROM $astronauts CROSS JOIN UNNEST(alma_mater) AS a ORDER BY a) AS SQ", 1, 1, None),

("SELECT a.id, b.id, c.id FROM $planets AS a INNER JOIN $planets AS b ON a.id = b.id INNER JOIN $planets AS c ON c.id = b.id", 9, 3, None),
("SELECT * FROM $planets AS a INNER JOIN $planets AS b ON a.id = b.id RIGHT OUTER JOIN $satellites AS c ON c.planetId = b.id", 177, 48, None),
Expand Down Expand Up @@ -1144,6 +1144,8 @@
("SELECT LEFT('APPLE', 1) || LEFT('APPLE', 1)", 1, 1, None),
# 1153 temporal extract from cross joins
("SELECT p.name, s.name FROM $planets as p, $satellites as s WHERE p.id = s.planetId", 177, 2, None),
# Can't qualify fields used in subscripts
("SELECT d.birth_place['town'] FROM $astronauts AS d", 357, 1),
]
# fmt:on

Expand Down Expand Up @@ -1194,6 +1196,8 @@ def test_sql_battery(statement, rows, columns, exception):

from tests.tools import trunc_printable

start_suite = time.monotonic_ns()

width = shutil.get_terminal_size((80, 20))[0] - 15

passed = 0
Expand Down Expand Up @@ -1226,7 +1230,7 @@ def test_sql_battery(statement, rows, columns, exception):

print("--- ✅ \033[0;32mdone\033[0m")
print(
f"\n\033[38;2;139;233;253m\033[3mCOMPLETE\033[0m\n"
f" \033[38;2;26;185;67m{passed} passed\033[0m\n"
f"\n\033[38;2;139;233;253m\033[3mCOMPLETE\033[0m ({(time.monotonic_ns() - start_suite) // 1e9} seconds)\n"
f" \033[38;2;26;185;67m{passed} passed ({(passed * 100) // (passed + failed)}%)\033[0m\n"
f" \033[38;2;255;121;198m{failed} failed\033[0m"
)
2 changes: 1 addition & 1 deletion tests/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def run_tests():

print(
f"\n\033[38;2;139;233;253m\033[3mCOMPLETE\033[0m\n"
f" \033[38;2;26;185;67m{passed} passed\033[0m\n"
f" \033[38;2;26;185;67m{passed} passed ({(passed * 100) // (passed + failed)}%)\033[0m\n"
f" \033[38;2;255;121;198m{failed} failed\033[0m"
)

Expand Down

0 comments on commit f5a4286

Please sign in to comment.