Skip to content

Commit

Permalink
[mypyc] Implement async for as a statement and in comprehensions
Browse files Browse the repository at this point in the history
Progress on mypyc/mypyc##868.
  • Loading branch information
msullivan committed Aug 18, 2022
1 parent 3ec1849 commit 976774e
Show file tree
Hide file tree
Showing 11 changed files with 395 additions and 50 deletions.
18 changes: 4 additions & 14 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,31 +873,24 @@ def _visit_display(


def transform_list_comprehension(builder: IRBuilder, o: ListComprehension) -> Value:
if any(o.generator.is_async):
builder.error("async comprehensions are unimplemented", o.line)
return translate_list_comprehension(builder, o.generator)


def transform_set_comprehension(builder: IRBuilder, o: SetComprehension) -> Value:
if any(o.generator.is_async):
builder.error("async comprehensions are unimplemented", o.line)
return translate_set_comprehension(builder, o.generator)


def transform_dictionary_comprehension(builder: IRBuilder, o: DictionaryComprehension) -> Value:
if any(o.is_async):
builder.error("async comprehensions are unimplemented", o.line)

d = builder.call_c(dict_new_op, [], o.line)
loop_params = list(zip(o.indices, o.sequences, o.condlists))
d = builder.maybe_spill(builder.call_c(dict_new_op, [], o.line))
loop_params = list(zip(o.indices, o.sequences, o.condlists, o.is_async))

def gen_inner_stmts() -> None:
k = builder.accept(o.key)
v = builder.accept(o.value)
builder.call_c(dict_set_item_op, [d, k, v], o.line)
builder.call_c(dict_set_item_op, [builder.read(d), k, v], o.line)

comprehension_helper(builder, loop_params, gen_inner_stmts, o.line)
return d
return builder.read(d)


# Misc
Expand All @@ -915,9 +908,6 @@ def get_arg(arg: Optional[Expression]) -> Value:


def transform_generator_expr(builder: IRBuilder, o: GeneratorExpr) -> Value:
if any(o.is_async):
builder.error("async comprehensions are unimplemented", o.line)

builder.warning("Treating generator comprehension as list", o.line)
return builder.call_c(iter_op, [translate_list_comprehension(builder, o)], o.line)

Expand Down
136 changes: 119 additions & 17 deletions mypyc/irbuild/for_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,30 @@
TupleExpr,
TypeAlias,
)
from mypyc.ir.ops import BasicBlock, Branch, Integer, IntOp, Register, TupleGet, TupleSet, Value
from mypyc.ir.ops import (
BasicBlock,
Branch,
Integer,
IntOp,
LoadAddress,
LoadMem,
Register,
TupleGet,
TupleSet,
Value,
)
from mypyc.ir.rtypes import (
RTuple,
RType,
bool_rprimitive,
int_rprimitive,
is_dict_rprimitive,
is_list_rprimitive,
is_sequence_rprimitive,
is_short_int_rprimitive,
is_str_rprimitive,
is_tuple_rprimitive,
pointer_rprimitive,
short_int_rprimitive,
)
from mypyc.irbuild.builder import IRBuilder
Expand All @@ -45,8 +58,9 @@
dict_value_iter_op,
)
from mypyc.primitives.exc_ops import no_err_occurred_op
from mypyc.primitives.generic_ops import iter_op, next_op
from mypyc.primitives.generic_ops import aiter_op, anext_op, iter_op, next_op
from mypyc.primitives.list_ops import list_append_op, list_get_item_unsafe_op, new_list_set_item_op
from mypyc.primitives.misc_ops import stop_async_iteration_op
from mypyc.primitives.registry import CFunctionDescription
from mypyc.primitives.set_ops import set_add_op

Expand All @@ -59,6 +73,7 @@ def for_loop_helper(
expr: Expression,
body_insts: GenFunc,
else_insts: Optional[GenFunc],
is_async: bool,
line: int,
) -> None:
"""Generate IR for a loop.
Expand All @@ -81,7 +96,9 @@ def for_loop_helper(
# Determine where we want to exit, if our condition check fails.
normal_loop_exit = else_block if else_insts is not None else exit_block

for_gen = make_for_loop_generator(builder, index, expr, body_block, normal_loop_exit, line)
for_gen = make_for_loop_generator(
builder, index, expr, body_block, normal_loop_exit, line, is_async=is_async
)

builder.push_loop_stack(step_block, exit_block)
condition_block = BasicBlock()
Expand Down Expand Up @@ -220,32 +237,33 @@ def translate_list_comprehension(builder: IRBuilder, gen: GeneratorExpr) -> Valu
if val is not None:
return val

list_ops = builder.new_list_op([], gen.line)
loop_params = list(zip(gen.indices, gen.sequences, gen.condlists))
list_ops = builder.maybe_spill(builder.new_list_op([], gen.line))

loop_params = list(zip(gen.indices, gen.sequences, gen.condlists, gen.is_async))

def gen_inner_stmts() -> None:
e = builder.accept(gen.left_expr)
builder.call_c(list_append_op, [list_ops, e], gen.line)
builder.call_c(list_append_op, [builder.read(list_ops), e], gen.line)

comprehension_helper(builder, loop_params, gen_inner_stmts, gen.line)
return list_ops
return builder.read(list_ops)


def translate_set_comprehension(builder: IRBuilder, gen: GeneratorExpr) -> Value:
set_ops = builder.new_set_op([], gen.line)
loop_params = list(zip(gen.indices, gen.sequences, gen.condlists))
set_ops = builder.maybe_spill(builder.new_set_op([], gen.line))
loop_params = list(zip(gen.indices, gen.sequences, gen.condlists, gen.is_async))

def gen_inner_stmts() -> None:
e = builder.accept(gen.left_expr)
builder.call_c(set_add_op, [set_ops, e], gen.line)
builder.call_c(set_add_op, [builder.read(set_ops), e], gen.line)

comprehension_helper(builder, loop_params, gen_inner_stmts, gen.line)
return set_ops
return builder.read(set_ops)


def comprehension_helper(
builder: IRBuilder,
loop_params: List[Tuple[Lvalue, Expression, List[Expression]]],
loop_params: List[Tuple[Lvalue, Expression, List[Expression], bool]],
gen_inner_stmts: Callable[[], None],
line: int,
) -> None:
Expand All @@ -260,20 +278,26 @@ def comprehension_helper(
gen_inner_stmts: function to generate the IR for the body of the innermost loop
"""

def handle_loop(loop_params: List[Tuple[Lvalue, Expression, List[Expression]]]) -> None:
def handle_loop(loop_params: List[Tuple[Lvalue, Expression, List[Expression], bool]]) -> None:
"""Generate IR for a loop.
Given a list of (index, expression, [conditions]) tuples, generate IR
for the nested loops the list defines.
"""
index, expr, conds = loop_params[0]
index, expr, conds, is_async = loop_params[0]
for_loop_helper(
builder, index, expr, lambda: loop_contents(conds, loop_params[1:]), None, line
builder,
index,
expr,
lambda: loop_contents(conds, loop_params[1:]),
None,
is_async=is_async,
line=line,
)

def loop_contents(
conds: List[Expression],
remaining_loop_params: List[Tuple[Lvalue, Expression, List[Expression]]],
remaining_loop_params: List[Tuple[Lvalue, Expression, List[Expression], bool]],
) -> None:
"""Generate the body of the loop.
Expand Down Expand Up @@ -319,13 +343,23 @@ def make_for_loop_generator(
body_block: BasicBlock,
loop_exit: BasicBlock,
line: int,
is_async: bool = False,
nested: bool = False,
) -> ForGenerator:
"""Return helper object for generating a for loop over an iterable.
If "nested" is True, this is a nested iterator such as "e" in "enumerate(e)".
"""

# Do an async loop if needed. async is always generic
if is_async:
expr_reg = builder.accept(expr)
async_obj = ForAsyncIterable(builder, index, body_block, loop_exit, line, nested)
item_type = builder._analyze_iterable_item_type(expr)
item_rtype = builder.type_to_rtype(item_type)
async_obj.init(expr_reg, item_rtype)
return async_obj

rtyp = builder.node_type(expr)
if is_sequence_rprimitive(rtyp):
# Special case "for x in <list>".
Expand Down Expand Up @@ -500,7 +534,7 @@ def load_len(self, expr: Union[Value, AssignmentTarget]) -> Value:


class ForIterable(ForGenerator):
"""Generate IR for a for loop over an arbitrary iterable (the normal case)."""
"""Generate IR for an async for loop."""

def need_cleanup(self) -> bool:
# Create a new cleanup block for when the loop is finished.
Expand Down Expand Up @@ -548,6 +582,74 @@ def gen_cleanup(self) -> None:
self.builder.call_c(no_err_occurred_op, [], self.line)


class ForAsyncIterable(ForGenerator):
"""Generate IR for a for loop over an arbitrary iterable (the general case)."""

def need_cleanup(self) -> bool:
# Create a new cleanup block for when the loop is finished.
return True

def init(self, expr_reg: Value, target_type: RType) -> None:
# Define targets to contain the expression, along with the
# iterator that will be used for the for-loop. We are inside
# of a generator function, so we will spill these into
# environment class.
builder = self.builder
iter_reg = builder.call_c(aiter_op, [expr_reg], self.line)
builder.maybe_spill(expr_reg)
self.iter_target = builder.maybe_spill(iter_reg)
self.target_type = target_type
self.stop_reg = Register(bool_rprimitive)

def gen_condition(self) -> None:
# This does the test and fetches the next value
# try:
# TARGET = await type(iter).__anext__(iter)
# stop = False
# except StopAsyncIteration:
# stop = True
#
# What a pain.
# There are optimizations available here if we punch through some abstractions.

from mypyc.irbuild.statement import emit_await, transform_try_except

builder = self.builder
line = self.line

def except_match() -> Value:
addr = builder.add(LoadAddress(pointer_rprimitive, stop_async_iteration_op.src, line))
return builder.add(LoadMem(stop_async_iteration_op.type, addr))

def try_body() -> None:
awaitable = builder.call_c(anext_op, [builder.read(self.iter_target)], line)
self.next_reg = emit_await(builder, awaitable, line)
builder.assign(self.stop_reg, builder.false(), -1)

def except_body() -> None:
builder.assign(self.stop_reg, builder.true(), line)

transform_try_except(
builder, try_body, [((except_match, line), None, except_body)], None, line
)

builder.add(Branch(self.stop_reg, self.loop_exit, self.body_block, Branch.BOOL))

def begin_body(self) -> None:
# Assign the value obtained from await __anext__ to the
# lvalue so that it can be referenced by code in the body of the loop.
builder = self.builder
line = self.line
# We unbox here so that iterating with tuple unpacking generates a tuple based
# unpack instead of an iterator based one.
next_reg = builder.coerce(self.next_reg, self.target_type, line)
builder.assign(builder.get_assignment_target(self.index), next_reg, line)

def gen_step(self) -> None:
# Nothing to do here, since we get the next item as part of gen_condition().
pass


def unsafe_index(builder: IRBuilder, target: Value, index: Value, line: int) -> Value:
"""Emit a potentially unsafe index into a target."""
# This doesn't really fit nicely into any of our data-driven frameworks
Expand Down
8 changes: 5 additions & 3 deletions mypyc/irbuild/specialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def any_all_helper(
) -> Value:
retval = Register(bool_rprimitive)
builder.assign(retval, initial_value(), -1)
loop_params = list(zip(gen.indices, gen.sequences, gen.condlists))
loop_params = list(zip(gen.indices, gen.sequences, gen.condlists, gen.is_async))
true_block, false_block, exit_block = BasicBlock(), BasicBlock(), BasicBlock()

def gen_inner_stmts() -> None:
Expand Down Expand Up @@ -421,7 +421,9 @@ def gen_inner_stmts() -> None:
call_expr = builder.accept(gen_expr.left_expr)
builder.assign(retval, builder.binary_op(retval, call_expr, "+", -1), -1)

loop_params = list(zip(gen_expr.indices, gen_expr.sequences, gen_expr.condlists))
loop_params = list(
zip(gen_expr.indices, gen_expr.sequences, gen_expr.condlists, gen_expr.is_async)
)
comprehension_helper(builder, loop_params, gen_inner_stmts, gen_expr.line)

return retval
Expand Down Expand Up @@ -471,7 +473,7 @@ def gen_inner_stmts() -> None:
builder.assign(retval, builder.accept(gen.left_expr), gen.left_expr.line)
builder.goto(exit_block)

loop_params = list(zip(gen.indices, gen.sequences, gen.condlists))
loop_params = list(zip(gen.indices, gen.sequences, gen.condlists, gen.is_async))
comprehension_helper(builder, loop_params, gen_inner_stmts, gen.line)

# Now we need the case for when nothing got hit. If there was
Expand Down
19 changes: 12 additions & 7 deletions mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
)

GenFunc = Callable[[], None]
ValueGenFunc = Callable[[], Value]


def transform_block(builder: IRBuilder, block: Block) -> None:
Expand Down Expand Up @@ -327,17 +328,16 @@ def transform_while_stmt(builder: IRBuilder, s: WhileStmt) -> None:


def transform_for_stmt(builder: IRBuilder, s: ForStmt) -> None:
if s.is_async:
builder.error("async for is unimplemented", s.line)

def body() -> None:
builder.accept(s.body)

def else_block() -> None:
assert s.else_body is not None
builder.accept(s.else_body)

for_loop_helper(builder, s.index, s.expr, body, else_block if s.else_body else None, s.line)
for_loop_helper(
builder, s.index, s.expr, body, else_block if s.else_body else None, s.is_async, s.line
)


def transform_break_stmt(builder: IRBuilder, node: BreakStmt) -> None:
Expand All @@ -362,7 +362,7 @@ def transform_raise_stmt(builder: IRBuilder, s: RaiseStmt) -> None:
def transform_try_except(
builder: IRBuilder,
body: GenFunc,
handlers: Sequence[Tuple[Optional[Expression], Optional[Expression], GenFunc]],
handlers: Sequence[Tuple[Optional[Tuple[ValueGenFunc, int]], Optional[Expression], GenFunc]],
else_body: Optional[GenFunc],
line: int,
) -> None:
Expand Down Expand Up @@ -399,8 +399,9 @@ def transform_try_except(
for type, var, handler_body in handlers:
next_block = None
if type:
type_f, type_line = type
next_block, body_block = BasicBlock(), BasicBlock()
matches = builder.call_c(exc_matches_op, [builder.accept(type)], type.line)
matches = builder.call_c(exc_matches_op, [type_f()], type_line)
builder.add(Branch(matches, body_block, next_block, Branch.BOOL))
builder.activate_block(body_block)
if var:
Expand Down Expand Up @@ -451,8 +452,12 @@ def body() -> None:
def make_handler(body: Block) -> GenFunc:
return lambda: builder.accept(body)

def make_entry(type: Expression) -> Tuple[ValueGenFunc, int]:
return (lambda: builder.accept(type), type.line)

handlers = [
(type, var, make_handler(body)) for type, var, body in zip(t.types, t.vars, t.handlers)
(make_entry(type) if type else None, var, make_handler(body))
for type, var, body in zip(t.types, t.vars, t.handlers)
]
else_body = (lambda: builder.accept(t.else_body)) if t.else_body else None
transform_try_except(builder, body, handlers, else_body, t.line)
Expand Down
4 changes: 4 additions & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,10 @@ PyObject *CPyImport_ImportFrom(PyObject *module, PyObject *package_name,

PyObject *CPySingledispatch_RegisterFunction(PyObject *singledispatch_func, PyObject *cls,
PyObject *func);

PyObject *CPy_GetAIter(PyObject *obj);
PyObject *CPy_GetANext(PyObject *aiter);

#ifdef __cplusplus
}
#endif
Expand Down
Loading

0 comments on commit 976774e

Please sign in to comment.