Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Compile array comprehensions to Hugr #616

Open
wants to merge 5 commits into
base: array-compr/linearity
Choose a base branch
from

Conversation

mark-koch
Copy link
Collaborator

  • Refactor compilation of list comprehensions so we can use the same logic for array comprehensions. The only difference is that lists call push to add the next element to the accumulator whereas arrays use set and update a counter. This is now implemented via a build_update hook provided to _compile_generators
  • Store checked Globals inside the compilation context so we can invoke Guppy methods by name while building Hugr

@mark-koch mark-koch requested a review from a team as a code owner November 4, 2024 13:29
@mark-koch mark-koch requested review from acl-cqc and removed request for a team November 4, 2024 13:29
Comment on lines +71 to +76
def array_new_uninitialized(elem_ty: ht.Type, length: int) -> ops.ExtOp:
"""Returns an array `uninitialized` operation."""
# TODO
return UnsupportedOp(
op_name="array.uninitialized", inputs=[], outputs=[array_type(elem_ty, length)]
).ext_op
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@codecov-commenter
Copy link

codecov-commenter commented Nov 4, 2024

Codecov Report

Attention: Patch coverage is 98.46154% with 1 line in your changes missing coverage. Please review.

Project coverage is 91.85%. Comparing base (ced00e1) to head (f35ea05).

Files with missing lines Patch % Lines
guppylang/compiler/core.py 92.30% 1 Missing ⚠️
Additional details and impacted files
@@                    Coverage Diff                    @@
##           array-compr/linearity     #616      +/-   ##
=========================================================
+ Coverage                  91.63%   91.85%   +0.21%     
=========================================================
  Files                         60       60              
  Lines                       6495     6522      +27     
=========================================================
+ Hits                        5952     5991      +39     
+ Misses                       543      531      -12     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@acl-cqc acl-cqc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some initial thoughts, am still thinking about the interface for build_update.

@@ -26,15 +28,19 @@ class CompiledGlobals:
compiled: dict[DefId, CompiledDef]
worklist: set[DefId]

checked_globals: Globals
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should do some renaming here - not necessarily in this PR, in fact probably not. But this is in class CompiledGlobals, so the _globals here feels redundant, it should be just checked. But then that conflicts with the dict[DefId, CheckedDef] - that might want to be checked_defs or just defs, thus...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And/or see other comment

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, we should do some refactoring here. Also see #497

compiler.compile_stmts([gen.hasnext_assign], self.dfg)
with self._if_true(gen.hasnext, inputs):

def compile_ifs(ifs: list[ast.expr]) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you familiar with contextlib.ExitStack ? This would let you do this in a non-recursive loop while ifs

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice, I have never used that. Makes this function a lot cleaner 👍

gen, *gens = gens
compiler = StmtCompiler(self.globals)
compiler.compile_stmts([gen.iter_assign], self.dfg)
assert isinstance(gen.iter, PlaceNode)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There do seem to be a lot of asserts of this kind throughout this PR. Is it possible to e.g. update DesugaredGenerator so that these are statically checked?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is that they start of as regular ast.Name nodes and are only turned into places during type checking. We could add a separate CheckedGenerator class with better type info though, if you think it's a good idea?

array, count = self.dfg[array_var], self.dfg[count_var]
(self.dfg[array_var],) = self._build_method_call(
array_ty, "__setitem__", node, [array, count, elt], array_ty.args
).inout_returns
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very-optional, probably for another PR - but how about all_returns that gives back a tuple of both inout and regular, so you can check the other is empty in your (non-exhaustive) patterm-match, and so you can't accidentally forget which one to call?
(I slightly dislike that which of inout_returns and regular_returns you want, depends on the sugar/style with which the function was declared, rather than the actual functionality, right?)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a NamedTuple, so we can actually unpack it directly

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I slightly dislike that which of inout_returns and regular_returns you want, depends on the sugar/style with which the function was declared, rather than the actual functionality, right?

Yes, but I don't see a way around that...

def get_instance_func(
self, ty: Type | TypeDef, name: str
) -> CompiledCallableDef | None:
checked_func = self.checked_globals.get_instance_func(ty, name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does feel like some refactoring could help, perhaps after the array stuff is all merged in; all we actually want from the checked-globals is the mapping from name to id...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also need the mapping from class definitions to their methods

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only place we're reading checked_globals AFAICS.

I think the refactoring we might want to look at the relationship between "compiled" and "checked" objects/types/etc., so let's leave all that until after arrays is done.

assert func is not None
return func.compile_call(args, type_args or [], self.dfg, self.globals, node)

def _compile_generators(
Copy link
Contributor

@acl-cqc acl-cqc Nov 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I was just wondering whether you could use ExitStack to make the whole _compile_generators nonrecursive too....and you'd gone and done it 😆 😄 Good job! 👍

Next...I am wondering if you can turn the whole _compile_generators into a context manager. The build_update callback is a bit restrictive (lots of hidden state because build_update has a fixed single-Wire interface, etc.), so rather than self._compile_generators(node.elt, node.generators, [list_place], build_update), you could call it by

with _compile_generators(self, node.generators, [list_place]) as _:
  build_update(node.elt)

....which would allow inlining build_update at the use site.

For the implementation of _compile_generators, see the @contextmanager decorator - you'd put your yield in place of the existing call to the build_update Callable.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Love the idea! 👍

assert isinstance(gen.hasnext, PlaceNode)
inputs = [gen.iter] + [PlaceNode(place=var) for var in loop_vars]
# Remember to finalize the iterator once we are done with it. Note that
# we need to use partial in the callback, so that we bind the *current*
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the seriously awkward bit, which of course I hadn't spotted when I suggested it ;-), well done!

def test_zero_length(validate):
@compile_guppy
def test() -> array[int, 0]:
return array(i for i in range(0))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return array(i for i in range(0))
return array(i/0 for i in range(0))

I.e. something that would runtime panic if it were ever actually evaluated

def test_capture(validate):
@compile_guppy
def test(x: int) -> array[int, 42]:
return array(i + x for i in range(42))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to make this one an execution test to check it does actually capture the correct value of x....

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, but unfortunately it doesn't run yet until we have CQCL/hugr#1627 and corresponding lowering in hugr-llvm.

def test() -> array[array[int, 10], 20]:
return array(array(x + y for y in range(10)) for x in range(20))

validate(test)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be good to sum the whole lot too

Copy link
Collaborator Author

@mark-koch mark-koch Nov 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Array comprehensions with multiple generators aren't supported yet since we can't express arrays of size n * m if n or m are generic arguments. We need more advanced const expression to allow that

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can't we do

@compile_guppy
def test2() -> int:
  arr = test()
  total = 0
  for inner in arr:
    for elem in inner:
      total += elem
  return total

...
run_int_fn(test2, ....)

can we only iterate over arrays in a comprehension?

If we have 1d sum only (of type array[int,_] -> int) then you can lift to 2d via a comprehension sum(array(sum(inner) for inner in test()))

modulo there are no execution tests!

Copy link
Contributor

@acl-cqc acl-cqc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd approve now, but I'm just wondering whether the contextmanager trick works (i.e. I might have other ideas if it doesn't)....

Looks good, thanks @mark-koch !

Copy link
Contributor

@acl-cqc acl-cqc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks @mark-koch !

Not sure if it's worth adding execution tests with @pytest.mark.xfail or anything like that...

array_var = Variable(next(tmp_vars), array_ty, node)
count_var = Variable(next(tmp_vars), int_type(), node)
hugr_elt_ty = (
ht.Option(node.elt_ty.to_hugr())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might possibly be worth making this into a method that you call node.elt_ty.array_element_type()

def test() -> array[array[int, 10], 20]:
return array(array(x + y for y in range(10)) for x in range(20))

validate(test)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can't we do

@compile_guppy
def test2() -> int:
  arr = test()
  total = 0
  for inner in arr:
    for elem in inner:
      total += elem
  return total

...
run_int_fn(test2, ....)

can we only iterate over arrays in a comprehension?

If we have 1d sum only (of type array[int,_] -> int) then you can lift to 2d via a comprehension sum(array(sum(inner) for inner in test()))

modulo there are no execution tests!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants