Skip to content

Commit

Permalink
[mypyc] Implement async with
Browse files Browse the repository at this point in the history
Also fix returning a value from inside a try block when the finally block
does a yield. (Which happens when returning from an async with).

Progress on mypyc/mypyc##868.
  • Loading branch information
msullivan committed Aug 17, 2022
1 parent c3f2e8a commit aff843e
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 88 deletions.
13 changes: 6 additions & 7 deletions mypyc/irbuild/nonlocalcontrol.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
from __future__ import annotations

from abc import abstractmethod
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Union

from mypyc.ir.ops import (
NO_TRACEBACK_LINE_NO,
Assign,
BasicBlock,
Branch,
Goto,
Expand Down Expand Up @@ -142,7 +141,7 @@ class TryFinallyNonlocalControl(NonlocalControl):

def __init__(self, target: BasicBlock) -> None:
self.target = target
self.ret_reg: Optional[Register] = None
self.ret_reg: Union[None, Register, AssignmentTarget] = None

def gen_break(self, builder: IRBuilder, line: int) -> None:
builder.error("break inside try/finally block is unimplemented", line)
Expand All @@ -152,9 +151,10 @@ def gen_continue(self, builder: IRBuilder, line: int) -> None:

def gen_return(self, builder: IRBuilder, value: Value, line: int) -> None:
if self.ret_reg is None:
self.ret_reg = Register(builder.ret_types[-1])
self.ret_reg = builder.maybe_spill_assignable(value)
else:
builder.assign(self.ret_reg, value, line)

builder.add(Assign(self.ret_reg, value))
builder.add(Goto(self.target))


Expand All @@ -180,9 +180,8 @@ class FinallyNonlocalControl(CleanupNonlocalControl):
leave and the return register is decrefed if it isn't null.
"""

def __init__(self, outer: NonlocalControl, ret_reg: Optional[Value], saved: Value) -> None:
def __init__(self, outer: NonlocalControl, saved: Value) -> None:
super().__init__(outer)
self.ret_reg = ret_reg
self.saved = saved

def gen_cleanup(self, builder: IRBuilder, line: int) -> None:
Expand Down
80 changes: 47 additions & 33 deletions mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from __future__ import annotations

import importlib.util
from typing import Callable, List, Optional, Sequence, Tuple
from typing import Callable, List, Optional, Sequence, Tuple, Union

from mypy.nodes import (
AssertStmt,
Expand Down Expand Up @@ -464,7 +464,7 @@ def try_finally_try(
return_entry: BasicBlock,
main_entry: BasicBlock,
try_body: GenFunc,
) -> Optional[Register]:
) -> Union[Register, AssignmentTarget, None]:
# Compile the try block with an error handler
control = TryFinallyNonlocalControl(return_entry)
builder.builder.push_error_handler(err_handler)
Expand All @@ -485,14 +485,14 @@ def try_finally_entry_blocks(
return_entry: BasicBlock,
main_entry: BasicBlock,
finally_block: BasicBlock,
ret_reg: Optional[Register],
ret_reg: Union[Register, AssignmentTarget, None],
) -> Value:
old_exc = Register(exc_rtuple)

# Entry block for non-exceptional flow
builder.activate_block(main_entry)
if ret_reg:
builder.add(Assign(ret_reg, builder.add(LoadErrorValue(builder.ret_types[-1]))))
builder.assign(ret_reg, builder.add(LoadErrorValue(builder.ret_types[-1])), -1)
builder.goto(return_entry)

builder.activate_block(return_entry)
Expand All @@ -502,24 +502,20 @@ def try_finally_entry_blocks(
# Entry block for errors
builder.activate_block(err_handler)
if ret_reg:
builder.add(Assign(ret_reg, builder.add(LoadErrorValue(builder.ret_types[-1]))))
builder.assign(ret_reg, builder.add(LoadErrorValue(builder.ret_types[-1])), -1)
builder.add(Assign(old_exc, builder.call_c(error_catch_op, [], -1)))
builder.goto(finally_block)

return old_exc


def try_finally_body(
builder: IRBuilder,
finally_block: BasicBlock,
finally_body: GenFunc,
ret_reg: Optional[Value],
old_exc: Value,
builder: IRBuilder, finally_block: BasicBlock, finally_body: GenFunc, old_exc: Value
) -> Tuple[BasicBlock, FinallyNonlocalControl]:
cleanup_block = BasicBlock()
# Compile the finally block with the nonlocal control flow overridden to restore exc_info
builder.builder.push_error_handler(cleanup_block)
finally_control = FinallyNonlocalControl(builder.nonlocal_control[-1], ret_reg, old_exc)
finally_control = FinallyNonlocalControl(builder.nonlocal_control[-1], old_exc)
builder.nonlocal_control.append(finally_control)
builder.activate_block(finally_block)
finally_body()
Expand All @@ -533,7 +529,7 @@ def try_finally_resolve_control(
cleanup_block: BasicBlock,
finally_control: FinallyNonlocalControl,
old_exc: Value,
ret_reg: Optional[Value],
ret_reg: Union[Register, AssignmentTarget, None],
) -> BasicBlock:
"""Resolve the control flow out of a finally block.
Expand All @@ -553,10 +549,10 @@ def try_finally_resolve_control(
if ret_reg:
builder.activate_block(rest)
return_block, rest = BasicBlock(), BasicBlock()
builder.add(Branch(ret_reg, rest, return_block, Branch.IS_ERROR))
builder.add(Branch(builder.read(ret_reg), rest, return_block, Branch.IS_ERROR))

builder.activate_block(return_block)
builder.nonlocal_control[-1].gen_return(builder, ret_reg, -1)
builder.nonlocal_control[-1].gen_return(builder, builder.read(ret_reg), -1)

# TODO: handle break/continue
builder.activate_block(rest)
Expand Down Expand Up @@ -598,7 +594,7 @@ def transform_try_finally_stmt(

# Compile the body of the finally
cleanup_block, finally_control = try_finally_body(
builder, finally_block, finally_body, ret_reg, old_exc
builder, finally_block, finally_body, old_exc
)

# Resolve the control flow out of the finally block
Expand Down Expand Up @@ -636,18 +632,28 @@ def get_sys_exc_info(builder: IRBuilder) -> List[Value]:


def transform_with(
builder: IRBuilder, expr: Expression, target: Optional[Lvalue], body: GenFunc, line: int
builder: IRBuilder,
expr: Expression,
target: Optional[Lvalue],
body: GenFunc,
is_async: bool,
line: int,
) -> None:
# This is basically a straight transcription of the Python code in PEP 343.
# I don't actually understand why a bunch of it is the way it is.
# We could probably optimize the case where the manager is compiled by us,
# but that is not our common case at all, so.

al = "a" if is_async else ""

mgr_v = builder.accept(expr)
typ = builder.call_c(type_op, [mgr_v], line)
exit_ = builder.maybe_spill(builder.py_get_attr(typ, "__exit__", line))
value = builder.py_call(builder.py_get_attr(typ, "__enter__", line), [mgr_v], line)
exit_ = builder.maybe_spill(builder.py_get_attr(typ, f"__{al}exit__", line))
value = builder.py_call(builder.py_get_attr(typ, f"__{al}enter__", line), [mgr_v], line)
mgr = builder.maybe_spill(mgr_v)
exc = builder.maybe_spill_assignable(builder.true())
if is_async:
value = emit_await(builder, value, line)

def try_body() -> None:
if target:
Expand All @@ -657,13 +663,13 @@ def try_body() -> None:
def except_body() -> None:
builder.assign(exc, builder.false(), line)
out_block, reraise_block = BasicBlock(), BasicBlock()
builder.add_bool_branch(
builder.py_call(
builder.read(exit_), [builder.read(mgr)] + get_sys_exc_info(builder), line
),
out_block,
reraise_block,
exit_val = builder.py_call(
builder.read(exit_), [builder.read(mgr)] + get_sys_exc_info(builder), line
)
if is_async:
exit_val = emit_await(builder, exit_val, line)

builder.add_bool_branch(exit_val, out_block, reraise_block)
builder.activate_block(reraise_block)
builder.call_c(reraise_exception_op, [], NO_TRACEBACK_LINE_NO)
builder.add(Unreachable())
Expand All @@ -674,7 +680,12 @@ def finally_body() -> None:
builder.add(Branch(builder.read(exc), exit_block, out_block, Branch.BOOL))
builder.activate_block(exit_block)
none = builder.none_object()
builder.py_call(builder.read(exit_), [builder.read(mgr), none, none, none], line)
exit_val = builder.py_call(
builder.read(exit_), [builder.read(mgr), none, none, none], line
)
if is_async:
emit_await(builder, exit_val, line)

builder.goto_and_activate(out_block)

transform_try_finally_stmt(
Expand All @@ -685,15 +696,14 @@ def finally_body() -> None:


def transform_with_stmt(builder: IRBuilder, o: WithStmt) -> None:
if o.is_async:
builder.error("async with is unimplemented", o.line)

# Generate separate logic for each expr in it, left to right
def generate(i: int) -> None:
if i >= len(o.expr):
builder.accept(o.body)
else:
transform_with(builder, o.expr[i], o.target[i], lambda: generate(i + 1), o.line)
transform_with(
builder, o.expr[i], o.target[i], lambda: generate(i + 1), o.is_async, o.line
)

generate(0)

Expand Down Expand Up @@ -779,7 +789,7 @@ def emit_yield(builder: IRBuilder, val: Value, line: int) -> Value:


def emit_yield_from_or_await(
builder: IRBuilder, expr: Expression, line: int, *, is_await: bool
builder: IRBuilder, val: Value, line: int, *, is_await: bool
) -> Value:
# This is basically an implementation of the code in PEP 380.

Expand All @@ -789,7 +799,7 @@ def emit_yield_from_or_await(
received_reg = Register(object_rprimitive)

get_op = coro_op if is_await else iter_op
iter_val = builder.call_c(get_op, [builder.accept(expr)], line)
iter_val = builder.call_c(get_op, [val], line)

iter_reg = builder.maybe_spill_assignable(iter_val)

Expand Down Expand Up @@ -860,6 +870,10 @@ def else_body() -> None:
return builder.read(result)


def emit_await(builder: IRBuilder, val: Value, line: int) -> Value:
return emit_yield_from_or_await(builder, val, line, is_await=True)


def transform_yield_expr(builder: IRBuilder, expr: YieldExpr) -> Value:
if builder.fn_info.is_coroutine:
builder.error("async generators are unimplemented", expr.line)
Expand All @@ -872,8 +886,8 @@ def transform_yield_expr(builder: IRBuilder, expr: YieldExpr) -> Value:


def transform_yield_from_expr(builder: IRBuilder, o: YieldFromExpr) -> Value:
return emit_yield_from_or_await(builder, o.expr, o.line, is_await=False)
return emit_yield_from_or_await(builder, builder.accept(o.expr), o.line, is_await=False)


def transform_await_expr(builder: IRBuilder, o: AwaitExpr) -> Value:
return emit_yield_from_or_await(builder, o.expr, o.line, is_await=True)
return emit_yield_from_or_await(builder, builder.accept(o.expr), o.line, is_await=True)
8 changes: 0 additions & 8 deletions mypyc/test-data/commandline.test
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,6 @@ async def async_for(xs: AsyncIterable[int]) -> None:
{x async for x in xs} # E: async comprehensions are unimplemented
{x: x async for x in xs} # E: async comprehensions are unimplemented

class async_ctx:
async def __aenter__(self) -> int: pass
async def __aexit__(self, x, y, z) -> None: pass

async def async_with() -> None:
async with async_ctx() as x: # E: async with is unimplemented
print(x)

async def async_generators() -> AsyncIterable[int]:
yield 1 # E: async generators are unimplemented

Expand Down
14 changes: 13 additions & 1 deletion mypyc/test-data/fixtures/testutil.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Simple support library for our run tests.

from contextlib import contextmanager
from collections.abc import Iterator
from typing import (
Any, Iterator, TypeVar, Generator, Optional, List, Tuple, Sequence,
Union, Callable,
Union, Callable, Generic, Awaitable,
)

@contextmanager
Expand All @@ -30,6 +31,8 @@ def run_generator(gen: Generator[T, V, U],
if i >= 0 and inputs:
# ... fixtures don't have send
val = gen.send(inputs[i]) # type: ignore
elif not hasattr(gen, '__next__'): # type: ignore
val = gen.send(None) # type: ignore
else:
val = next(gen)
except StopIteration as e:
Expand All @@ -44,6 +47,15 @@ def run_generator(gen: Generator[T, V, U],
F = TypeVar('F', bound=Callable)


class async_val(Awaitable[V]):
def __init__(self, val: T) -> None:
self.val = val

def __await__(self) -> Generator[T, V, V]:
z = yield self.val
return z


# Wrap a mypyc-generated function in a real python function, to allow it to be
# stuck into classes and the like.
def make_python_function(f: F) -> F:
Expand Down
66 changes: 66 additions & 0 deletions mypyc/test-data/run-async.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# async test cases (compile and run)

[case testAsync]
import asyncio

async def h() -> int:
return 1

async def g() -> int:
await asyncio.sleep(0)
return await h()

async def f() -> int:
return await g()

[typing fixtures/typing-full.pyi]

[file driver.py]
from native import f
import asyncio

result = asyncio.run(f())
assert result == 1

[case testAsyncWith]
from testutil import async_val

class async_ctx:
async def __aenter__(self) -> str:
await async_val("enter")
return "test"

async def __aexit__(self, x, y, z) -> None:
await async_val("exit")


async def async_with() -> str:
async with async_ctx() as x:
return await async_val("body")


[file driver.py]
from native import async_with
from testutil import run_generator

yields, val = run_generator(async_with(), [None, 'x', None])
assert yields == ('enter', 'body', 'exit'), yields
assert val == 'x', val


[case testAsyncReturn]
from testutil import async_val

async def async_return() -> str:
try:
return 'test'
finally:
await async_val('foo')

[file driver.py]
from native import async_return
from testutil import run_generator

yields, val = run_generator(async_return())
assert yields == ('foo',)
assert val == 'test', val
Loading

0 comments on commit aff843e

Please sign in to comment.