I'm working on a small compiler that has multiple similar tree representations. For example:

```python
from typing import Literal, TypedDict, Union

type Surface = Union[
    Int,
    Add[Surface, Surface],
    Subtract[Surface, Surface],
    Bool,
    If[Surface, Surface, Surface],
    And[Surface, Surface],
]

type Core = Union[
    Int,
    Add[Core, Core],
    Subtract[Core, Core],
    Bool,
    If[Core, Core, Core],
]
```

where the elements of the unions (nodes) are defined as follows:

```python
class Int(TypedDict, total=False):
    tag: Literal["int"]
    value: int


class Add[E1, E2](TypedDict, total=False):
    tag: Literal["+"]
    operands: tuple[E1, E2]


class Subtract[E1, E2](TypedDict, total=False):
    tag: Literal["-"]
    operands: tuple[E1, E2]


class Bool(TypedDict, total=False):
    tag: Literal["bool"]
    value: bool


class If[Predicate, Consequent, Alternative](TypedDict, total=False):
    tag: Literal["if"]
    predicate: Predicate
    consequent: Consequent
    alternative: Alternative


class And[E1, E2](TypedDict, total=False):
    tag: Literal["and"]
    operands: tuple[E1, E2]
```

Many of the compiler passes are simple transformations between the representations. For example:

```python
def shrink(
    expr: Surface,
) -> Core:
    match expr:
        case {"tag": ("int" | "bool")}:
            return expr
        case {"tag": "+", "operands": [e1, e2]}:
            return {**expr, "operands": (shrink(e1), shrink(e2))}
        case {"tag": "-", "operands": [e1, e2]}:
            return {**expr, "operands": (shrink(e1), shrink(e2))}
        # case {"tag": ("+" | "-"), "operands": [e1, e2]}:
        #     return {**expr, "operands": (shrink(e1), shrink(e2))}
        case {
            "tag": "if",
            "predicate": predicate,
            "consequent": consequent,
            "alternative": alternative,
        }:
            return {
                **expr,
                "predicate": shrink(predicate),
                "consequent": shrink(consequent),
                "alternative": shrink(alternative),
            }
        case {"tag": "and", "operands": [e1, e2]}:
            return If(
                tag="if",
                predicate=shrink(e1),
                consequent=shrink(e2),
                alternative=Bool(tag="bool", value=False),
            )
        case _:
            raise ValueError(f"unhandled expression: {expr}")
```

The reason for …

The code above works fine, but there are many situations where nodes should be handled similarly. For non-recursive nodes (e.g., `Int`, `Bool`) I can combine the cases and the type checker is happy. However, for recursive nodes (e.g., `Add`, `Subtract`) I cannot combine them. For example, if the cases for `Add` and `Subtract` in the example above are replaced with the commented-out code, it does not type check. Is there a way to write this code so that similar cases are merged and the type checker is satisfied, without resorting to something like …

Also, why is the final …
Recursive nodes require a static type checker to employ bidirectional type inference to infer a result that is compatible with the return type (`Core`). When the "expected type" for bidirectional type inference involves a union, it succeeds only if the expression can be evaluated using one of the subtypes of the union. Your code is creating a situation where the expression must evaluate to more than one subtype. For example:

```python
def func(t1: Literal["+"], t2: Literal["+", "-"]):
    # This works because the dictionary expression evaluates
    # to `Add[Core, Core]`, which is one of the subtypes of the
    # expected type union.
    x1: Add[Core, Core] | Subtract[Core, Core] = {
        "tag": t1,
        "operands": ({"tag": "int", "value": 1}, {"tag": "int", "value": 1}),
    }

    # This does not work because the dictionary expression does not
    # evaluate to either `Add[Core, Core]` or `Subtract[Core, Core]`.
    x2: Add[Core, Core] | Subtract[Core, Core] = {
        "tag": t2,
        "operands": ({"tag": "int", "value": 1}, {"tag": "int", "value": 1}),
    }
```

I don't know of any static type checkers (in any language) that support this. Off the top of my head, I can't think of an algorithm that would enable this, but it's possible that one exists. If it does, it would undoubtedly be extremely expensive computationally, probably infeasible in practice.
There are a couple of reasons for this. The first is that your TypedDict definitions have …

The second reason is that pyright's type-narrowing algorithm for mapping types based on unions of tagged TypedDicts isn't as smart as it could be in the negative (fall-through) case. It doesn't narrow the fall-through type as much as it could in the case where the …