Merged
113 changes: 70 additions & 43 deletions ibis/expr/rewrites.py
@@ -3,20 +3,27 @@
from __future__ import annotations

from collections import defaultdict
+from typing import TYPE_CHECKING, Optional

import toolz
+from typing_extensions import Self

import ibis.expr.operations as ops
from ibis.common.collections import FrozenDict # noqa: TC001
from ibis.common.deferred import Item, _, deferred, var
from ibis.common.exceptions import ExpressionError, IbisInputError
from ibis.common.graph import Node as Traversable
from ibis.common.graph import traverse
-from ibis.common.grounds import Concrete
+from ibis.common.grounds import Annotable
from ibis.common.patterns import Check, pattern, replace
from ibis.common.typing import VarTuple # noqa: TC001
from ibis.util import Namespace, promote_list

+if TYPE_CHECKING:
+    from collections.abc import Iterator, Mapping
+
+    import ibis.expr.types as ir

p = Namespace(pattern, module=ops)
d = Namespace(deferred, module=ops)

@@ -26,7 +33,7 @@
name = var("name")


-class DerefMap(Annotable, Traversable):
+class DerefMap(Annotable, Traversable):
"""Trace and replace fields from earlier relations in the hierarchy.

In order to provide a nice user experience, we need to allow expressions
@@ -52,57 +59,42 @@ class DerefMap(Concrete, Traversable):
"""

"""The relations we want the values to point to."""
rels: VarTuple[ops.Relation]
rels: frozenset[ops.Relation]

"""Extra substitutions to be added to the dereference map. Stored on the
instance to facilitate lazy dereferencing."""
extra: Optional[FrozenDict[ops.Node, ops.Node]]

"""Substitution mapping from values of earlier relations to the fields of `rels`."""
subs: FrozenDict[ops.Value, ops.Field]
subs: Optional[FrozenDict[ops.Value, ops.Field]] = None

"""Ambiguous field references."""
ambigs: FrozenDict[ops.Value, VarTuple[ops.Value]]
ambigs: Optional[FrozenDict[ops.Value, VarTuple[ops.Value]]] = None

     @classmethod
-    def from_targets(cls, rels, extra=None):
+    def from_targets(
+        cls, rels, extra: Mapping[ops.Node, ops.Node] | None = None
+    ) -> Self:
         """Create a dereference map from a list of target relations.

         Usually a single relation is passed except for joins where multiple
         relations are involved.

         Parameters
         ----------
-        rels : list of ops.Relation
+        rels
             The target relations to dereference to.
-        extra : dict, optional
+        extra
             Extra substitutions to be added to the dereference map.

         Returns
         -------
         DerefMap
         """
-        rels = promote_list(rels)
-        mapping = defaultdict(dict)
-        for rel in rels:
-            for field in rel.fields.values():
-                for value, distance in cls.backtrack(field):
-                    mapping[value][field] = distance
-
-        subs, ambigs = {}, {}
-        for from_, to in mapping.items():
-            mindist = min(to.values())
-            minkeys = [k for k, v in to.items() if v == mindist]
-            # if all the closest fields are from the same relation, then we
-            # can safely substitute them and we pick the first one arbitrarily
-            if all(minkeys[0].relations == k.relations for k in minkeys):
-                subs[from_] = minkeys[0]
-            else:
-                ambigs[from_] = minkeys
-
-        if extra is not None:
-            subs.update(extra)
-
-        return cls(rels, subs, ambigs)
+        return cls(rels=frozenset(promote_list(rels)), extra=extra)

@classmethod
-    def backtrack(cls, value):
+    def backtrack(cls, value) -> Iterator[tuple[ops.Field, int]]:
"""Backtrack the field in the relation hierarchy.

The field is traced back until no modification is made, so only follow
@@ -132,28 +124,63 @@ def backtrack(cls, value):
):
yield value, distance

-    def dereference(self, value):
-        """Dereference a value to the target relations.
+    def _fill_substitution_mappings(self) -> None:
+        if self.subs is not None and self.ambigs is not None:
+            return
+
+        mapping = defaultdict(dict)
+
+        for rel in self.rels:
+            for field in rel.fields.values():
+                for val, distance in self.__class__.backtrack(field):
Contributor @JonAnCla commented on Aug 8, 2025:

In a separate PR, I think a useful performance improvement (for long chains of expressions) would be to cache each "level" of info extracted from backtracking on a per-relation basis. The cached info would have to be held at the relation level, or maybe via a weakref/finalizer, so that it gets GC'd when the associated relations are deleted.
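
A rough sketch of that idea (not part of this PR): a module-level cache keyed by weak references, so entries vanish when a relation is collected. The helper name `backtrack_info` is invented here, and this assumes relation operations are weak-referenceable.

    import weakref
    from collections import defaultdict

    from ibis.expr.rewrites import DerefMap

    # hypothetical cache: an entry disappears automatically once its
    # relation is garbage collected
    _backtrack_cache = weakref.WeakKeyDictionary()


    def backtrack_info(rel):
        # compute the value -> {field: distance} mapping once per relation
        try:
            return _backtrack_cache[rel]
        except KeyError:
            info = defaultdict(dict)
            for field in rel.fields.values():
                for value, distance in DerefMap.backtrack(field):
                    info[value][field] = distance
            _backtrack_cache[rel] = info = dict(info)
            return info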

Member Author commented:

It's not 100% clear to me from your description what kind of performance improvement you might expect here.

Thinking it through out loud, it seems like the improvement scales with the number of relations in a chain, or, in terms of implementation, with the number of DerefMaps constructed. I think this is in line with your supposition about long chains of expressions.

I think the complexity is in figuring out who owns the cache (another thing you're alluding to!).

What if, instead of adding caching to DerefMap, we gave every instance of Table, or the underlying operation, a lazily constructed deref map?

This would have the effect of tying the deref map to the object, effectively caching it for the instance, and we wouldn't have to add complexity to DerefMap to make it work.
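
A minimal sketch of that alternative, assuming a hypothetical `_deref_map` attribute that does not exist in ibis today (and that `cached_property` is usable on the class):

    from functools import cached_property

    from ibis.expr.rewrites import DerefMap


    class Table:  # illustrative stand-in for ibis's Table
        @cached_property
        def _deref_map(self) -> DerefMap:
            # built on first access, then reused by every later call to .bind()
            return DerefMap.from_targets(self.op())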

Contributor @JonAnCla replied on Aug 8, 2025:

Thanks, and sorry for not being very clear. Just to clarify:

  • at the moment a DerefMap is constructed for every call to .bind on a table
  • this has a performance impact if:
    • a table is used as the base for many queries
    • or a chain of operations is constructed (the map from the ancestors could possibly be used to build the descendant's map)

I agree that the map could be held as an attribute at the table level, and that might be better than having DerefMap manage a set of cached maps via weakrefs etc.

Another thing to note is that for chains of operations, deref maps grow in a way that might make naive caching a poor fit: the total memory requirement for a length-D chain of operations over a table with N fields is O(N*D^2), because every additional operation in the chain includes all deref map items from the level above.

Another approach would be to cache only one layer's depth of derefs on each table, and then do something like:

    while expr.table != self:  # assumes expressions have some concept of the table they are bound to
        expr = self._deref_maps_by_src_table[expr.table][expr]

This would have a total memory requirement of O(N*D) instead of O(N*D^2). Deref maps in deep chains would be fast to construct, but fields from relations far up the chain would take longer to dereference (my assumption is that users are unlikely to use distant fields, but I have no evidence).

There was a little more explanation in the previous PR, so I'll link to the relevant part rather than duplicating the whole thing: #11458 (comment)

Naive caching could be done as a starting point, and that's a decision I'm happy to leave with the ibis team :) I guess for a 10-item chain over a 100-column table it's still just ~10,000 dict items, but if people are doing much bigger things in practice it could be an issue.
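
To make the single-layer scheme above concrete, here is a toy, self-contained sketch; every name in it (`Node`, `layer`, `dereference`) is invented for illustration and direction-of-walk details differ from the pseudocode above:

    class Node:
        """Toy relation: stores its parent and one layer of substitutions."""

        def __init__(self, parent=None, layer=None):
            self.parent = parent
            # maps a field of `parent` to the corresponding field of this node
            self.layer = layer or {}


    def dereference(field, source, target):
        # collect the nodes between `source` (where the field is bound) and `target`
        path = []
        node = target
        while node is not source:
            path.append(node)
            node = node.parent
        # replay one cached layer per step, oldest first: O(D) lookups per
        # field, O(N*D) total memory across a depth-D chain with N fields
        for node in reversed(path):
            field = node.layer.get(field, field)
        return field


    t0 = Node()
    t1 = Node(parent=t0, layer={"t0.a": "t1.a"})
    t2 = Node(parent=t1, layer={"t1.a": "t2.a"})
    assert dereference("t0.a", t0, t2) == "t2.a"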

Member Author replied:

I think a collections.ChainMap might serve this purpose well, with each new child map containing only the current relation's new fields.

I believe then (as you say) that a lookup now costs an additional O(D) operations in the worst case, when the looked-up field comes from the first (oldest) relation in the chain.

I don't know whether the "chaining radius" is usually small, but I would guess that it is, as it seems difficult to reason about things the bigger the radius.
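
A small self-contained sketch of that layering; the field names and maps here are made up for illustration:

    from collections import ChainMap

    # one dict per relation, each holding only the fields that relation introduces
    base = {"t.a": "rel0.a", "t.b": "rel0.b"}
    proj = {"rel0.a": "rel1.a"}  # a projection that keeps only column a

    # lookups consult the newest layer first and fall back to older ones,
    # so building a child map is O(1) instead of copying the parent
    layered = ChainMap(proj, base)
    assert layered["rel0.a"] == "rel1.a"  # found in the newest layer
    assert layered["t.b"] == "rel0.b"     # falls back to the oldest layer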

Contributor replied:

Great, thanks for the reply. I think it might be worth looking into at a later stage. I've been testing main with typical queries we run and have seen a good reduction (>50%) in the time taken to build expressions. Most queries benefit from the "lazy DerefMap" changes and so are no longer affected by the deref map at all; the main bottlenecks are elsewhere within ibis, so I'll likely look at those next when that bubbles to the top of my list.

Member Author replied:

Sweet! Please make issues for the performance problems you encounter as they arise!

+                    mapping[val][field] = distance
+
+        subs, ambigs = {}, {}
+        for from_, to in mapping.items():
+            mindist = min(to.values())
+            minkeys = [k for k, v in to.items() if v == mindist]
+            # if all the closest fields are from the same relation, then we
+            # can safely substitute them and we pick the first one arbitrarily
+            if all(minkeys[0].relations == k.relations for k in minkeys):
+                subs[from_] = minkeys[0]
+            else:
+                ambigs[from_] = minkeys
+
+        if extra := self.extra:
+            subs.update(extra)
+
+        self.subs = subs
+        self.ambigs = ambigs

+    def dereference(self, *values: ir.Value) -> Iterator[ops.Value]:
+        """Dereference values to target relations.

         Also check for ambiguous field references. If a field reference is found
         which is marked as ambiguous, then raise an error.

         Parameters
         ----------
-        value : ops.Value
-            The value to dereference.
+        values
+            Expression values to dereference.

         Returns
         -------
-        ops.Value
-            The dereferenced value.
+        tuple[ops.Value]
+            The dereferenced values.
         """
-        ambigs = value.find(lambda x: x in self.ambigs, filter=ops.Value)
-        if ambigs:
-            raise IbisInputError(
-                f"Ambiguous field reference {ambigs!r} in expression {value!r}"
-            )
-        return value.replace(self.subs, filter=ops.Value)
+        for v in values:
+            if (rels := v.relations) and rels != self.rels:
+                # called on every iteration but only does work once per
+                # instance
+                self._fill_substitution_mappings()
+
+                if ambigs := v.find(self.ambigs.__contains__, filter=ops.Value):
+                    raise IbisInputError(
+                        f"Ambiguous field reference {ambigs!r} in expression {v!r}"
+                    )
+                yield v.replace(self.subs, filter=ops.Value)
+            else:
+                yield v
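
For orientation, an editorial example (not part of the diff) of the user-facing behavior this map supports, assuming standard ibis usage: expressions built against a parent table keep working against tables derived from it.

    import ibis

    t = ibis.table({"a": "int64", "b": "string"}, name="t")
    s = t.select("a")
    # t.a is dereferenced to s.a when bound against the projection
    expr = s.filter(t.a > 0)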


def flatten_predicates(node):
32 changes: 15 additions & 17 deletions ibis/expr/tests/test_dereference.py
@@ -1,5 +1,7 @@
from __future__ import annotations

+import pytest
+
import ibis
from ibis.expr.types.relations import DerefMap

@@ -21,26 +23,22 @@ def test_dereference_project():
p = t.select([t.int_col, t.double_col])

mapping = DerefMap.from_targets([p.op()])
-    expected = dereference_expect(
-        {
-            p.int_col: p.int_col,
-            p.double_col: p.double_col,
-            t.int_col: p.int_col,
-            t.double_col: p.double_col,
-        }
-    )
-    assert mapping.subs == expected
+    assert tuple(
+        mapping.dereference(
+            p.int_col.op(), p.double_col.op(), t.int_col.op(), t.double_col.op()
+        )
+    ) == (
+        p.int_col.op(),
+        p.double_col.op(),
+        p.int_col.op(),
+        p.double_col.op(),
+    )


-def test_dereference_mapping_self_reference():
+@pytest.mark.parametrize("column", ["int_col", "double_col", "string_col"])
+def test_dereference_mapping_self_reference(column):
     v = t.view()

     mapping = DerefMap.from_targets([v.op()])
-    expected = dereference_expect(
-        {
-            v.int_col: v.int_col,
-            v.double_col: v.double_col,
-            v.string_col: v.string_col,
-        }
-    )
-    assert mapping.subs == expected
+    assert tuple(mapping.dereference(v[column].op())) == (v[column].op(),)
4 changes: 2 additions & 2 deletions ibis/expr/types/groupby.py
@@ -108,7 +108,7 @@ def having(self, *predicates: ir.BooleanScalar) -> GroupedTable:
"""
table = self.table.to_expr()
havings = table.bind(*predicates)
-        return self.copy(havings=self.havings + havings)
+        return self.copy(havings=(*self.havings, *havings))

def order_by(self, *by: ir.Value) -> GroupedTable:
"""Sort a grouped table expression by `expr`.
@@ -129,7 +129,7 @@ def order_by(self, *by: ir.Value) -> GroupedTable:
"""
table = self.table.to_expr()
orderings = table.bind(*by)
-        return self.copy(orderings=self.orderings + orderings)
+        return self.copy(orderings=(*self.orderings, *orderings))

def mutate(
self, *exprs: ir.Value | Sequence[ir.Value], **kwexprs: ir.Value
19 changes: 13 additions & 6 deletions ibis/expr/types/joins.py
@@ -170,7 +170,7 @@ def prepare_predicates(
for pred in util.promote_list(predicates):
if isinstance(pred, (Value, Deferred, bool)):
for bound in bind(left, pred):
-                yield deref_both.dereference(bound.op())
+                yield from deref_both.dereference(bound.op())
else:
if isinstance(pred, tuple):
if len(pred) != 2:
@@ -180,9 +180,11 @@
lk = rk = pred

             for lhs, rhs in zip(bind(left, lk), bind(right, rk)):
-                lhs = deref_left.dereference(lhs.op())
-                rhs = deref_right.dereference(rhs.op())
-                yield comparison(lhs, rhs)
+                yield from map(
+                    comparison,
+                    deref_left.dereference(lhs.op()),
+                    deref_right.dereference(rhs.op()),
+                )


def finished(method):
@@ -390,9 +392,14 @@ def select(self, *args, **kwargs):
values = {
k: v.replace(peel_join_field, filter=ops.Value) for k, v in values.items()
}
-        values = {k: derefmap.dereference(v) for k, v in values.items()}
+        new_values = {}
+
+        for k, v in values.items():
+            # ensure that there's only a single thing returned from dereferencing
+            (new_value,) = derefmap.dereference(v)
+            new_values[k] = new_value

-        node = chain.copy(values=values)
+        node = chain.copy(values=new_values)
return Table(node)

@property
41 changes: 17 additions & 24 deletions ibis/expr/types/relations.py
@@ -658,24 +658,17 @@ def bind(self, *args: Any, **kwargs: Any) -> tuple[Value, ...]:
tuple[Value, ...]
A tuple of bound values
"""
-        values = self._fast_bind(*args, **kwargs)
-        dm = None  # delay creating a dereference map until one is definitely needed
-        this_table_relations_set = {self.op()}
-        result = []
-        for original in values:
-            rels = original.op().relations
-            if len(rels) and rels != this_table_relations_set:
-                # the expression needs dereferencing as it has relations (i.e. is not a literal) and is not bound to self
-                if dm is None:
-                    # create a dereference map as one has not been created yet
-                    dm = DerefMap.from_targets(self.op())
-                # dereference the values to `self`
-                value = dm.dereference(original.op()).to_expr()
-                value = value.name(original.get_name())
-            else:
-                value = original
-            result.append(value)
-        return tuple(result)
+        dm = DerefMap.from_targets(self.op())
+
+        bound = self._fast_bind(*args, **kwargs)
+        return (
+            derefed.to_expr().name(name) if original is not derefed else original
+            for name, original, derefed in zip(
+                (expr.get_name() for expr in bound),
+                bound,
+                dm.dereference(*(expr.op() for expr in bound)),
+            )
+        )

def as_scalar(self) -> ir.Scalar:
"""Inform ibis that the table expression should be treated as a scalar.
@@ -1085,7 +1078,7 @@ def __getitem__(self, what: str | int | slice | Sequence[str | int]):
FutureWarning,
stacklevel=2,
)
-        values = self.bind(args)
+        values = tuple(self.bind(args))

if util.all_of(values, BooleanValue):
return self.filter(values)
@@ -1272,9 +1265,9 @@ def group_by(
"""
from ibis.expr.types.groupby import GroupedTable

-        by = tuple(v for v in by if v is not None)
+        by = (v for v in by if v is not None)
         groups = self.bind(*by, **key_exprs)
-        return GroupedTable(self, groups)
+        return GroupedTable(self, tuple(groups))

# TODO(kszucs): shouldn't this be ibis.rowid() instead not bound to a specific table?
def rowid(self) -> ir.IntegerValue:
Expand Down Expand Up @@ -1395,7 +1388,7 @@ def aggregate(

groups = self.bind(by)
metrics = self.bind(metrics, **kwargs)
-        having = self.bind(having)
+        having = tuple(self.bind(having))

groups = unwrap_aliases(groups)
metrics = unwrap_aliases(metrics)
@@ -3069,7 +3062,7 @@ def drop_null(
└─────┘
"""
if subset is not None:
-            subset = self.bind(subset)
+            subset = tuple(self.bind(subset))
return ops.DropNull(self, how, subset).to_expr()

def fill_null(self, replacements: ir.Scalar | Mapping[str, ir.Scalar], /) -> Table:
@@ -5158,7 +5151,7 @@ def relocate(
def window_by(self, time_col: str | ir.Value, /) -> WindowedTable:
from ibis.expr.types.temporal_windows import WindowedTable

-        time_col = next(iter(self.bind(time_col)))
+        time_col = next(self.bind(time_col))

# validate time_col is a timestamp column
if not isinstance(time_col, TimestampColumn):
4 changes: 2 additions & 2 deletions ibis/tests/expr/test_table.py
@@ -2101,13 +2101,13 @@ def eq(left, right):
assert eq(exprs, expected)

# no args
-    assert t.bind() == ()
+    assert tuple(t.bind()) == ()

def utter_failure(_):
raise ValueError("¡moo!")

with pytest.raises(ValueError, match="¡moo!"):
-        t.bind(foo=utter_failure)
+        tuple(t.bind(foo=utter_failure))


# TODO: remove when dropna is fully deprecated