Skip to content

Commit 7996a10

Browse files
authored
await in list comprehension (RustPython#5334)
* check if comprehension element contains await * force execution to pause in async gen
1 parent db4562f commit 7996a10

File tree

3 files changed

+344
-13
lines changed

3 files changed

+344
-13
lines changed

Lib/test/test_asyncgen.py

-2
Original file line numberDiff line numberDiff line change
@@ -1568,8 +1568,6 @@ async def main():
15681568
self.assertIn('unhandled exception during asyncio.run() shutdown',
15691569
message['message'])
15701570

1571-
# TODO: RUSTPYTHON; TypeError: object async_generator can't be used in 'await' expression
1572-
@unittest.expectedFailure
15731571
def test_async_gen_expression_01(self):
15741572
async def arange(n):
15751573
for i in range(n):

compiler/codegen/src/compile.rs

+173-11
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,14 @@ impl CompileContext {
8787
}
8888
}
8989

90+
#[derive(Debug, Clone, Copy, PartialEq)]
91+
enum ComprehensionType {
92+
Generator,
93+
List,
94+
Set,
95+
Dict,
96+
}
97+
9098
/// Compile an located_ast::Mod produced from rustpython_parser::parse()
9199
pub fn compile_top(
92100
ast: &located_ast::Mod,
@@ -2431,6 +2439,8 @@ impl Compiler {
24312439
);
24322440
Ok(())
24332441
},
2442+
ComprehensionType::List,
2443+
Self::contains_await(elt),
24342444
)?;
24352445
}
24362446
Expr::SetComp(located_ast::ExprSetComp {
@@ -2452,6 +2462,8 @@ impl Compiler {
24522462
);
24532463
Ok(())
24542464
},
2465+
ComprehensionType::Set,
2466+
Self::contains_await(elt),
24552467
)?;
24562468
}
24572469
Expr::DictComp(located_ast::ExprDictComp {
@@ -2480,19 +2492,28 @@ impl Compiler {
24802492

24812493
Ok(())
24822494
},
2495+
ComprehensionType::Dict,
2496+
Self::contains_await(key) || Self::contains_await(value),
24832497
)?;
24842498
}
24852499
Expr::GeneratorExp(located_ast::ExprGeneratorExp {
24862500
elt, generators, ..
24872501
}) => {
2488-
self.compile_comprehension("<genexpr>", None, generators, &|compiler| {
2489-
compiler.compile_comprehension_element(elt)?;
2490-
compiler.mark_generator();
2491-
emit!(compiler, Instruction::YieldValue);
2492-
emit!(compiler, Instruction::Pop);
2502+
self.compile_comprehension(
2503+
"<genexpr>",
2504+
None,
2505+
generators,
2506+
&|compiler| {
2507+
compiler.compile_comprehension_element(elt)?;
2508+
compiler.mark_generator();
2509+
emit!(compiler, Instruction::YieldValue);
2510+
emit!(compiler, Instruction::Pop);
24932511

2494-
Ok(())
2495-
})?;
2512+
Ok(())
2513+
},
2514+
ComprehensionType::Generator,
2515+
Self::contains_await(elt),
2516+
)?;
24962517
}
24972518
Expr::Starred(_) => {
24982519
return Err(self.error(CodegenErrorType::InvalidStarExpr));
@@ -2744,9 +2765,35 @@ impl Compiler {
27442765
init_collection: Option<Instruction>,
27452766
generators: &[located_ast::Comprehension],
27462767
compile_element: &dyn Fn(&mut Self) -> CompileResult<()>,
2768+
comprehension_type: ComprehensionType,
2769+
element_contains_await: bool,
27472770
) -> CompileResult<()> {
27482771
let prev_ctx = self.ctx;
2749-
let is_async = generators.iter().any(|g| g.is_async);
2772+
let has_an_async_gen = generators.iter().any(|g| g.is_async);
2773+
2774+
// async comprehensions are allowed in various contexts:
2775+
// - list/set/dict comprehensions in async functions
2776+
// - always for generator expressions
2777+
// Note: generators have to be treated specially since their async version is a fundamentally
2778+
// different type (aiter vs iter) instead of just an awaitable.
2779+
2780+
// for if it actually is async, we check if any generator is async or if the element contains await
2781+
2782+
// if the element expression contains await, but the context doesn't allow for async,
2783+
// then we continue on here with is_async=false and will produce a syntax once the await is hit
2784+
2785+
let is_async_list_set_dict_comprehension = comprehension_type
2786+
!= ComprehensionType::Generator
2787+
&& (has_an_async_gen || element_contains_await) // does it have to be async? (uses await or async for)
2788+
&& prev_ctx.func == FunctionContext::AsyncFunction; // is it allowed to be async? (in an async function)
2789+
2790+
let is_async_generator_comprehension = comprehension_type == ComprehensionType::Generator
2791+
&& (has_an_async_gen || element_contains_await);
2792+
2793+
// since one is for generators, and one for not generators, they should never both be true
2794+
debug_assert!(!(is_async_list_set_dict_comprehension && is_async_generator_comprehension));
2795+
2796+
let is_async = is_async_list_set_dict_comprehension || is_async_generator_comprehension;
27502797

27512798
self.ctx = CompileContext {
27522799
loop_data: None,
@@ -2838,7 +2885,7 @@ impl Compiler {
28382885

28392886
// End of for loop:
28402887
self.switch_to_block(after_block);
2841-
if is_async {
2888+
if has_an_async_gen {
28422889
emit!(self, Instruction::EndAsyncFor);
28432890
}
28442891
}
@@ -2877,19 +2924,23 @@ impl Compiler {
28772924
self.compile_expression(&generators[0].iter)?;
28782925

28792926
// Get iterator / turn item into an iterator
2880-
if is_async {
2927+
if has_an_async_gen {
28812928
emit!(self, Instruction::GetAIter);
28822929
} else {
28832930
emit!(self, Instruction::GetIter);
28842931
};
28852932

28862933
// Call just created <listcomp> function:
28872934
emit!(self, Instruction::CallFunctionPositional { nargs: 1 });
2888-
if is_async {
2935+
if is_async_list_set_dict_comprehension {
2936+
// async, but not a generator and not an async for
2937+
// in this case, we end up with an awaitable
2938+
// that evaluates to the list/set/dict, so here we add an await
28892939
emit!(self, Instruction::GetAwaitable);
28902940
self.emit_load_const(ConstantData::None);
28912941
emit!(self, Instruction::YieldFrom);
28922942
}
2943+
28932944
Ok(())
28942945
}
28952946

@@ -3016,6 +3067,117 @@ impl Compiler {
30163067
fn mark_generator(&mut self) {
30173068
self.current_code_info().flags |= bytecode::CodeFlags::IS_GENERATOR
30183069
}
3070+
3071+
/// Whether the expression contains an await expression and
3072+
/// thus requires the function to be async.
3073+
/// Async with and async for are statements, so I won't check for them here
3074+
fn contains_await(expression: &located_ast::Expr) -> bool {
3075+
use located_ast::*;
3076+
3077+
match &expression {
3078+
Expr::Call(ExprCall {
3079+
func,
3080+
args,
3081+
keywords,
3082+
..
3083+
}) => {
3084+
Self::contains_await(func)
3085+
|| args.iter().any(Self::contains_await)
3086+
|| keywords.iter().any(|kw| Self::contains_await(&kw.value))
3087+
}
3088+
Expr::BoolOp(ExprBoolOp { values, .. }) => values.iter().any(Self::contains_await),
3089+
Expr::BinOp(ExprBinOp { left, right, .. }) => {
3090+
Self::contains_await(left) || Self::contains_await(right)
3091+
}
3092+
Expr::Subscript(ExprSubscript { value, slice, .. }) => {
3093+
Self::contains_await(value) || Self::contains_await(slice)
3094+
}
3095+
Expr::UnaryOp(ExprUnaryOp { operand, .. }) => Self::contains_await(operand),
3096+
Expr::Attribute(ExprAttribute { value, .. }) => Self::contains_await(value),
3097+
Expr::Compare(ExprCompare {
3098+
left, comparators, ..
3099+
}) => Self::contains_await(left) || comparators.iter().any(Self::contains_await),
3100+
Expr::Constant(ExprConstant { .. }) => false,
3101+
Expr::List(ExprList { elts, .. }) => elts.iter().any(Self::contains_await),
3102+
Expr::Tuple(ExprTuple { elts, .. }) => elts.iter().any(Self::contains_await),
3103+
Expr::Set(ExprSet { elts, .. }) => elts.iter().any(Self::contains_await),
3104+
Expr::Dict(ExprDict { keys, values, .. }) => {
3105+
keys.iter()
3106+
.any(|key| key.as_ref().map_or(false, Self::contains_await))
3107+
|| values.iter().any(Self::contains_await)
3108+
}
3109+
Expr::Slice(ExprSlice {
3110+
lower, upper, step, ..
3111+
}) => {
3112+
lower.as_ref().map_or(false, |l| Self::contains_await(l))
3113+
|| upper.as_ref().map_or(false, |u| Self::contains_await(u))
3114+
|| step.as_ref().map_or(false, |s| Self::contains_await(s))
3115+
}
3116+
Expr::Yield(ExprYield { value, .. }) => {
3117+
value.as_ref().map_or(false, |v| Self::contains_await(v))
3118+
}
3119+
Expr::Await(ExprAwait { .. }) => true,
3120+
Expr::YieldFrom(ExprYieldFrom { value, .. }) => Self::contains_await(value),
3121+
Expr::JoinedStr(ExprJoinedStr { values, .. }) => {
3122+
values.iter().any(Self::contains_await)
3123+
}
3124+
Expr::FormattedValue(ExprFormattedValue {
3125+
value,
3126+
conversion: _,
3127+
format_spec,
3128+
..
3129+
}) => {
3130+
Self::contains_await(value)
3131+
|| format_spec
3132+
.as_ref()
3133+
.map_or(false, |fs| Self::contains_await(fs))
3134+
}
3135+
Expr::Name(located_ast::ExprName { .. }) => false,
3136+
Expr::Lambda(located_ast::ExprLambda { body, .. }) => Self::contains_await(body),
3137+
Expr::ListComp(located_ast::ExprListComp {
3138+
elt, generators, ..
3139+
}) => {
3140+
Self::contains_await(elt)
3141+
|| generators.iter().any(|gen| Self::contains_await(&gen.iter))
3142+
}
3143+
Expr::SetComp(located_ast::ExprSetComp {
3144+
elt, generators, ..
3145+
}) => {
3146+
Self::contains_await(elt)
3147+
|| generators.iter().any(|gen| Self::contains_await(&gen.iter))
3148+
}
3149+
Expr::DictComp(located_ast::ExprDictComp {
3150+
key,
3151+
value,
3152+
generators,
3153+
..
3154+
}) => {
3155+
Self::contains_await(key)
3156+
|| Self::contains_await(value)
3157+
|| generators.iter().any(|gen| Self::contains_await(&gen.iter))
3158+
}
3159+
Expr::GeneratorExp(located_ast::ExprGeneratorExp {
3160+
elt, generators, ..
3161+
}) => {
3162+
Self::contains_await(elt)
3163+
|| generators.iter().any(|gen| Self::contains_await(&gen.iter))
3164+
}
3165+
Expr::Starred(expr) => Self::contains_await(&expr.value),
3166+
Expr::IfExp(located_ast::ExprIfExp {
3167+
test, body, orelse, ..
3168+
}) => {
3169+
Self::contains_await(test)
3170+
|| Self::contains_await(body)
3171+
|| Self::contains_await(orelse)
3172+
}
3173+
3174+
Expr::NamedExpr(located_ast::ExprNamedExpr {
3175+
target,
3176+
value,
3177+
range: _,
3178+
}) => Self::contains_await(target) || Self::contains_await(value),
3179+
}
3180+
}
30193181
}
30203182

30213183
trait EmitArg<Arg: OpArgType> {

0 commit comments

Comments
 (0)