Skip to content

Commit

Permalink
feat: make Builders, Patterns and various Annotation classes hashable
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Sep 11, 2024
1 parent 8961fbd commit 37ef82e
Show file tree
Hide file tree
Showing 8 changed files with 478 additions and 43 deletions.
26 changes: 25 additions & 1 deletion koerce/annots.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
_any,
pattern,
)
from .utils import get_type_hints, get_type_origin
from .utils import PseudoHashable, get_type_hints, get_type_origin

EMPTY = inspect.Parameter.empty
_ensure_pattern = pattern
Expand All @@ -38,6 +38,9 @@ def __init__(self, pattern: Any = _any, default: Any = EMPTY):
def __repr__(self):
return f"<{self.__class__.__name__} pattern={self.pattern!r} default={self.default_!r}>"

def __hash__(self) -> int:
return hash((self.__class__, self.pattern, PseudoHashable(self.default_)))

def __eq__(self, other: Any) -> bool:
if not isinstance(other, Attribute):
return NotImplemented
Expand Down Expand Up @@ -124,6 +127,17 @@ def __eq__(self, other: Any) -> bool:
and self.typehint == right.typehint
)

def __hash__(self) -> int:
return hash(
(
self.__class__,
self.kind,
self.pattern,
PseudoHashable(self.default_),
self.typehint,
)
)


@cython.final
@cython.cclass
Expand Down Expand Up @@ -266,6 +280,16 @@ def __eq__(self, other: Any) -> bool:
and self.return_typehint == right.return_typehint
)

def __hash__(self) -> int:
return hash(
(
self.__class__,
PseudoHashable(self.parameters),
self.return_pattern,
self.return_typehint,
)
)

def __call__(self, /, *args, **kwargs):
return self.bind(args, kwargs)

Expand Down
65 changes: 57 additions & 8 deletions koerce/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
import functools
import inspect
import operator
from typing import Any
from typing import Any, Optional

import cython

from .utils import PseudoHashable

Context = dict[str, Any]


Expand Down Expand Up @@ -64,9 +66,6 @@ def __getitem__(self, name):
def __call__(self, *args, **kwargs):
return Deferred(Call(self, *args, **kwargs))

# def __contains__(self, item):
# return Deferred(Binop(operator.contains, self, item))

def __invert__(self) -> Deferred:
return Deferred(Unop(operator.invert, self))

Expand Down Expand Up @@ -187,6 +186,12 @@ def build(self, ctx: Context): ...
def __eq__(self, other: Any) -> bool:
return type(self) is type(other) and self.equals(other)

def __hash__(self):
return self._hash()

def __repr__(self):
raise NotImplementedError(f"{self.__class__.__name__} is not reprable")


def _deferred_repr(obj):
try:
Expand Down Expand Up @@ -223,6 +228,9 @@ def __init__(self, func: Any):
def __repr__(self):
return _deferred_repr(self.func)

def _hash(self):
return hash((self.__class__, self.func))

def equals(self, other: Func) -> bool:
return self.func == other.func

Expand All @@ -247,12 +255,17 @@ class Just(Builder):
def __init__(self, value: Any):
if isinstance(value, Just):
self.value = cython.cast(Just, value).value
elif isinstance(value, (Builder, Deferred)):
raise TypeError(f"`{value}` cannot be used as a Just value")
else:
self.value = value

def __repr__(self):
return _deferred_repr(self.value)

def _hash(self):
return hash((self.__class__, PseudoHashable(self.value)))

def equals(self, other: Just) -> bool:
return self.value == other.value

Expand All @@ -279,6 +292,9 @@ def __init__(self, name: str):
def __repr__(self):
return f"${self.name}"

def _hash(self):
return hash((self.__class__, self.name))

def equals(self, other: Var) -> bool:
return self.name == other.name

Expand Down Expand Up @@ -332,6 +348,9 @@ def __init__(self, func):
def __repr__(self):
return f"{self.func!r}()"

def _hash(self):
return hash((self.__class__, self.func))

def equals(self, other: Call0) -> bool:
return self.func == other.func

Expand Down Expand Up @@ -364,6 +383,9 @@ def __init__(self, func, arg):
def __repr__(self):
return f"{self.func!r}({self.arg!r})"

def _hash(self):
return hash((self.__class__, self.func, self.arg))

def equals(self, other: Call1) -> bool:
return self.func == other.func and self.arg == other.arg

Expand Down Expand Up @@ -401,6 +423,9 @@ def __init__(self, func, arg1, arg2):
def __repr__(self):
return f"{self.func!r}({self.arg1!r}, {self.arg2!r})"

def _hash(self):
return hash((self.__class__, self.func, self.arg1, self.arg2))

def equals(self, other: Call2) -> bool:
return (
self.func == other.func
Expand Down Expand Up @@ -447,6 +472,9 @@ def __init__(self, func, arg1, arg2, arg3):
def __repr__(self):
return f"{self.func!r}({self.arg1!r}, {self.arg2!r}, {self.arg3!r})"

def _hash(self):
return hash((self.__class__, self.func, self.arg1, self.arg2, self.arg3))

def equals(self, other: Call3) -> bool:
return (
self.func == other.func
Expand Down Expand Up @@ -482,12 +510,12 @@ class CallN(Builder):
"""

func: Builder
args: list[Builder]
args: tuple[Builder, ...]
kwargs: dict[str, Builder]

def __init__(self, func, *args, **kwargs):
self.func = builder(func)
self.args = [builder(arg) for arg in args]
self.args = tuple(builder(arg) for arg in args)
self.kwargs = {k: builder(v) for k, v in kwargs.items()}

def __repr__(self):
Expand All @@ -502,6 +530,9 @@ def __repr__(self):
else:
return f"{self.func!r}()"

def _hash(self):
return hash((self.__class__, self.func, self.args, PseudoHashable(self.kwargs)))

def equals(self, other: CallN) -> bool:
return (
self.func == other.func
Expand Down Expand Up @@ -573,6 +604,9 @@ def __repr__(self):
symbol = _operator_symbols[self.op]
return f"{symbol}{self.arg!r}"

def _hash(self):
return hash((self.__class__, self.op, self.arg))

def equals(self, other: Unop) -> bool:
return self.op == other.op and self.arg == other.arg

Expand Down Expand Up @@ -610,6 +644,9 @@ def __repr__(self):
symbol = _operator_symbols[self.op]
return f"({self.arg1!r} {symbol} {self.arg2!r})"

def _hash(self):
return hash((self.__class__, self.op, self.arg1, self.arg2))

def equals(self, other: Binop) -> bool:
return (
self.op == other.op and self.arg1 == other.arg1 and self.arg2 == other.arg2
Expand Down Expand Up @@ -645,6 +682,9 @@ def __init__(self, obj, key):
def __repr__(self):
return f"{self.obj!r}[{self.key!r}]"

def _hash(self):
return hash((self.__class__, self.obj, self.key))

def equals(self, other: Item) -> bool:
return self.obj == other.obj and self.key == other.key

Expand Down Expand Up @@ -678,6 +718,9 @@ def __init__(self, obj: Any, attr: str):
def __repr__(self):
return f"{self.obj!r}.{self.attr}"

def _hash(self):
return hash((self.__class__, self.obj, self.attr))

def equals(self, other: Attr) -> bool:
return self.obj == other.obj and self.attr == other.attr

Expand All @@ -699,11 +742,11 @@ class Seq(Builder):
"""

type_: Any
items: list[Builder]
items: tuple[Builder, ...]

def __init__(self, items):
self.type_ = type(items)
self.items = [builder(item) for item in items]
self.items = tuple(builder(item) for item in items)

def __repr__(self):
elems = ", ".join(map(repr, self.items))
Expand All @@ -714,6 +757,9 @@ def __repr__(self):
else:
return f"{self.type_.__name__}({elems})"

def _hash(self):
return hash((self.__class__, self.type_, self.items))

def equals(self, other: Seq) -> bool:
return self.type_ == other.type_ and self.items == other.items

Expand Down Expand Up @@ -751,6 +797,9 @@ def __repr__(self):
else:
return f"{self.type_.__name__}({{{items}}})"

def _hash(self):
return hash((self.__class__, self.type_, PseudoHashable(self.items)))

def equals(self, other: Map) -> bool:
return self.type_ == other.type_ and self.items == other.items

Expand Down
Loading

0 comments on commit 37ef82e

Please sign in to comment.