feat: Compile array comprehensions to Hugr #616
base: array-compr/linearity
def array_new_uninitialized(elem_ty: ht.Type, length: int) -> ops.ExtOp:
    """Returns an array `uninitialized` operation."""
    # TODO
    return UnsupportedOp(
        op_name="array.uninitialized", inputs=[], outputs=[array_type(elem_ty, length)]
    ).ext_op
See CQCL/hugr#1627
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##   array-compr/linearity   #616   +/-   ##
===============================================
+ Coverage   91.63%   91.85%   +0.21%
===============================================
  Files          60       60
  Lines        6495     6522      +27
===============================================
+ Hits         5952     5991      +39
+ Misses        543      531      -12
☔ View full report in Codecov by Sentry.
Some initial thoughts, am still thinking about the interface for `build_update`.
@@ -26,15 +28,19 @@ class CompiledGlobals:
    compiled: dict[DefId, CompiledDef]
    worklist: set[DefId]

    checked_globals: Globals
We should do some renaming here - not necessarily in this PR, in fact probably not. But this is in class `CompiledGlobals`, so the `_globals` here feels redundant, it should be just `checked`. But then that conflicts with the `dict[DefId, CheckedDef]` - that might want to be `checked_defs` or just `defs`, thus...
And/or see other comment
Agreed, we should do some refactoring here. Also see #497
guppylang/compiler/expr_compiler.py
Outdated
compiler.compile_stmts([gen.hasnext_assign], self.dfg)
with self._if_true(gen.hasnext, inputs):

    def compile_ifs(ifs: list[ast.expr]) -> None:
Are you familiar with `contextlib.ExitStack`? This would let you do this in a non-recursive loop `while ifs`
Very nice, I have never used that. Makes this function a lot cleaner 👍
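For readers who have not used it, here is a minimal sketch of the `ExitStack` pattern suggested above. It is not the PR's code: `if_true` is a hypothetical stand-in for a builder context such as `_if_true`, and conditions are plain strings so the example runs on its own.

```python
from contextlib import ExitStack, contextmanager

log: list[str] = []


@contextmanager
def if_true(cond: str):
    """Hypothetical stand-in for a builder context like `_if_true`."""
    log.append(f"enter {cond}")
    try:
        yield
    finally:
        log.append(f"exit {cond}")


def compile_ifs(ifs: list[str]) -> None:
    """Enter one nested context per condition in a loop, without recursion."""
    with ExitStack() as stack:
        remaining = list(ifs)
        while remaining:
            cond, *remaining = remaining
            stack.enter_context(if_true(cond))
        # Everything here runs inside all of the contexts entered above.
        log.append("body")
    # Leaving the `with` block unwinds the contexts in reverse order.


compile_ifs(["a", "b"])
```

The stack closes contexts last-in, first-out, which matches what the recursive version gets from nested `with` statements.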
guppylang/compiler/expr_compiler.py
Outdated
gen, *gens = gens
compiler = StmtCompiler(self.globals)
compiler.compile_stmts([gen.iter_assign], self.dfg)
assert isinstance(gen.iter, PlaceNode)
There do seem to be a lot of `assert`s of this kind throughout this PR. Is it possible to e.g. update `DesugaredGenerator` so that these are statically checked?
The problem is that they start off as regular `ast.Name` nodes and are only turned into places during type checking. We could add a separate `CheckedGenerator` class with better type info though, if you think it's a good idea?
guppylang/compiler/expr_compiler.py
Outdated
array, count = self.dfg[array_var], self.dfg[count_var]
(self.dfg[array_var],) = self._build_method_call(
    array_ty, "__setitem__", node, [array, count, elt], array_ty.args
).inout_returns
very-optional, probably for another PR - but how about `all_returns` that gives back a tuple of both inout and regular, so you can check the other is empty in your (non-exhaustive) pattern-match, and so you can't accidentally forget which one to call?
(I slightly dislike that which of `inout_returns` and `regular_returns` you want, depends on the sugar/style with which the function was declared, rather than the actual functionality, right?)
It's a `NamedTuple`, so we can actually unpack it directly
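As an illustration of the point about unpacking, the sketch below uses a hypothetical `CallReturns` type. Only the field names `regular_returns` and `inout_returns` come from the discussion; the class name, `build_call` and the string "wires" are assumptions made for the example.

```python
from typing import NamedTuple


class CallReturns(NamedTuple):
    """Hypothetical result type; only the field names come from the thread."""

    regular_returns: list[str]
    inout_returns: list[str]


def build_call() -> CallReturns:
    # Stand-in for a call builder: no regular result, one in-out result.
    return CallReturns(regular_returns=[], inout_returns=["array_wire"])


# Unpacking the tuple names both fields at the call site, so neither
# can be silently forgotten; the nested pattern checks the arity too.
regular, (array_wire,) = build_call()
assert not regular, "expected no regular returns"
```

Because a `NamedTuple` is an ordinary tuple, this gives the effect of the proposed `all_returns` without adding a new accessor.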
> I slightly dislike that which of `inout_returns` and `regular_returns` you want, depends on the sugar/style with which the function was declared, rather than the actual functionality, right?
Yes, but I don't see a way around that...
def get_instance_func(
    self, ty: Type | TypeDef, name: str
) -> CompiledCallableDef | None:
    checked_func = self.checked_globals.get_instance_func(ty, name)
This does feel like some refactoring could help, perhaps after the array stuff is all merged in; all we actually want from the checked-globals is the mapping from name to id...
We also need the mapping from class definitions to their methods
This is the only place we're reading `checked_globals` AFAICS.
I think the refactoring might want to look at the relationship between "compiled" and "checked" objects/types/etc., so let's leave all that until after arrays is done.
guppylang/compiler/expr_compiler.py
Outdated
    assert func is not None
    return func.compile_call(args, type_args or [], self.dfg, self.globals, node)

def _compile_generators(
Ok, I was just wondering whether you could use `ExitStack` to make the whole `_compile_generators` nonrecursive too....and you'd gone and done it 😆 😄 Good job! 👍
Next...I am wondering if you can turn the whole `_compile_generators` into a context manager. The `build_update` callback is a bit restrictive (lots of hidden state because `build_update` has a fixed single-Wire interface, etc.), so rather than `self._compile_generators(node.elt, node.generators, [list_place], build_update)`, you could call it by
    with _compile_generators(self, node.generators, [list_place]) as _:
        build_update(node.elt)
....which would allow inlining `build_update` at the use site.
For the implementation of `_compile_generators`, see the `@contextmanager` decorator - you'd put your `yield` in place of the existing call to the `build_update` Callable.
Love the idea! 👍
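The callback-to-context-manager change proposed above could look roughly like this. It is a sketch under simplifying assumptions: generators are plain strings, the "builder" is just an event log, and `compile_generators` here is a free function written for the example, not the PR's `_compile_generators`.

```python
from collections.abc import Iterator
from contextlib import contextmanager

events: list[str] = []


@contextmanager
def compile_generators(gens: list[str]) -> Iterator[None]:
    """Set up all loops, hand control to the caller, then close the loops."""
    for gen in gens:
        events.append(f"open {gen}")
    # The caller's `with` body runs here, where a callback-based
    # version would have invoked `build_update(...)`.
    yield
    for gen in reversed(gens):
        events.append(f"close {gen}")


with compile_generators(["x", "y"]):
    # The update logic is written inline instead of being passed as a callable.
    events.append("update")
```

The caller's body can now use local variables directly, which avoids threading state through a callback with a fixed signature. This sketch does not close the loops if the body raises; whether that matters depends on how the real builder treats errors.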
assert isinstance(gen.hasnext, PlaceNode)
inputs = [gen.iter] + [PlaceNode(place=var) for var in loop_vars]
# Remember to finalize the iterator once we are done with it. Note that
# we need to use partial in the callback, so that we bind the *current*
This is the seriously awkward bit, which of course I hadn't spotted when I suggested it ;-), well done!
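The awkwardness comes from Python's late-binding closures: a callback registered inside a loop sees the loop variable's final value unless the current value is bound explicitly. The sketch below shows the difference using `ExitStack.callback`; `finalize` and the generator names are placeholders, not the PR's code.

```python
from contextlib import ExitStack
from functools import partial

finalized: list[str] = []
late_bound: list[str] = []


def finalize(name: str) -> None:
    """Placeholder for the code that finalizes an iterator."""
    finalized.append(name)


with ExitStack() as stack:
    for gen in ["outer", "inner"]:
        # partial freezes the value that `gen` has in this iteration.
        stack.callback(partial(finalize, gen))
        # A plain lambda looks `gen` up only when it is eventually called,
        # by which time the loop has finished and `gen` is "inner".
        stack.callback(lambda: late_bound.append(gen))

# Callbacks run in reverse registration order when the stack exits.
```

After the block, `finalized` contains each generator once, while `late_bound` contains the last generator twice.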
def test_zero_length(validate):
    @compile_guppy
    def test() -> array[int, 0]:
        return array(i for i in range(0))
-        return array(i for i in range(0))
+        return array(i/0 for i in range(0))
I.e. something that would runtime panic if it were ever actually evaluated
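The idea behind the suggestion can be seen in plain Python, used here only as an analogue since Guppy's runtime behaviour is not shown in this thread: the element expression of a comprehension over an empty range is never evaluated, so a division by zero in it is harmless.

```python
def build(n: int) -> list[float]:
    # The element expression divides by zero, but only if it is evaluated.
    return [i / 0 for i in range(n)]


# With an empty range the body never runs, so no error is raised.
empty = build(0)

# With a non-empty range the same expression fails on the first element.
try:
    build(1)
    raised = False
except ZeroDivisionError:
    raised = True
```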
def test_capture(validate):
    @compile_guppy
    def test(x: int) -> array[int, 42]:
        return array(i + x for i in range(42))
Would be good to make this one an execution test to check it does actually capture the correct value of `x`....
Agreed, but unfortunately it won't run until we have CQCL/hugr#1627 and the corresponding lowering in hugr-llvm.
    def test() -> array[array[int, 10], 20]:
        return array(array(x + y for y in range(10)) for x in range(20))

    validate(test)
This might be good to sum the whole lot too
Array comprehensions with multiple generators aren't supported yet since we can't express arrays of size `n * m` if `n` or `m` are generic arguments. We need more advanced const expressions to allow that
can't we do
    @compile_guppy
    def test2() -> int:
        arr = test()
        total = 0
        for inner in arr:
            for elem in inner:
                total += elem
        return total
    ...
    run_int_fn(test2, ....)
can we only iterate over arrays in a comprehension?
If we have 1d `sum` only (of type `array[int,_] -> int`) then you can lift to 2d via a comprehension `sum(array(sum(inner) for inner in test()))`
modulo there are no execution tests!
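For reference, a plain-Python analogue of the "lift a 1d sum to 2d" idea, with lists standing in for Guppy arrays. `nested` and `sum_1d` are names chosen for this sketch and are not part of the test suite.

```python
def nested() -> list[list[int]]:
    # Same shape as the nested comprehension under test: 20 rows of 10.
    return [[x + y for y in range(10)] for x in range(20)]


def sum_1d(values: list[int]) -> int:
    """One-dimensional sum, the only reduction assumed to exist."""
    total = 0
    for value in values:
        total += value
    return total


# Lift to two dimensions: sum each row, then sum the row totals.
total = sum_1d([sum_1d(inner) for inner in nested()])
```

An execution test could compare the compiled result against this value, which works out to 10 * 190 + 20 * 45 = 2800.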
I'd approve now, but I'm just wondering whether the contextmanager trick works (i.e. I might have other ideas if it doesn't)....
Looks good, thanks @mark-koch !
Looks great, thanks @mark-koch !
Not sure if it's worth adding execution tests with `@pytest.mark.xfail` or anything like that...
array_var = Variable(next(tmp_vars), array_ty, node)
count_var = Variable(next(tmp_vars), int_type(), node)
hugr_elt_ty = (
    ht.Option(node.elt_ty.to_hugr())
It might possibly be worth making this into a method that you call node.elt_ty.array_element_type()
`push` to add the next element to the accumulator whereas arrays use `set` and update a counter. This is now implemented via a `build_update` hook provided to `_compile_generators`
`Globals` inside the compilation context so we can invoke Guppy methods by name while building Hugr