@@ -87,6 +87,14 @@ impl CompileContext {
87
87
}
88
88
}
89
89
90
+ #[ derive( Debug , Clone , Copy , PartialEq ) ]
91
+ enum ComprehensionType {
92
+ Generator ,
93
+ List ,
94
+ Set ,
95
+ Dict ,
96
+ }
97
+
90
98
/// Compile an located_ast::Mod produced from rustpython_parser::parse()
91
99
pub fn compile_top (
92
100
ast : & located_ast:: Mod ,
@@ -2431,6 +2439,8 @@ impl Compiler {
2431
2439
) ;
2432
2440
Ok ( ( ) )
2433
2441
} ,
2442
+ ComprehensionType :: List ,
2443
+ Self :: contains_await ( elt) ,
2434
2444
) ?;
2435
2445
}
2436
2446
Expr :: SetComp ( located_ast:: ExprSetComp {
@@ -2452,6 +2462,8 @@ impl Compiler {
2452
2462
) ;
2453
2463
Ok ( ( ) )
2454
2464
} ,
2465
+ ComprehensionType :: Set ,
2466
+ Self :: contains_await ( elt) ,
2455
2467
) ?;
2456
2468
}
2457
2469
Expr :: DictComp ( located_ast:: ExprDictComp {
@@ -2480,19 +2492,28 @@ impl Compiler {
2480
2492
2481
2493
Ok ( ( ) )
2482
2494
} ,
2495
+ ComprehensionType :: Dict ,
2496
+ Self :: contains_await ( key) || Self :: contains_await ( value) ,
2483
2497
) ?;
2484
2498
}
2485
2499
Expr :: GeneratorExp ( located_ast:: ExprGeneratorExp {
2486
2500
elt, generators, ..
2487
2501
} ) => {
2488
- self . compile_comprehension ( "<genexpr>" , None , generators, & |compiler| {
2489
- compiler. compile_comprehension_element ( elt) ?;
2490
- compiler. mark_generator ( ) ;
2491
- emit ! ( compiler, Instruction :: YieldValue ) ;
2492
- emit ! ( compiler, Instruction :: Pop ) ;
2502
+ self . compile_comprehension (
2503
+ "<genexpr>" ,
2504
+ None ,
2505
+ generators,
2506
+ & |compiler| {
2507
+ compiler. compile_comprehension_element ( elt) ?;
2508
+ compiler. mark_generator ( ) ;
2509
+ emit ! ( compiler, Instruction :: YieldValue ) ;
2510
+ emit ! ( compiler, Instruction :: Pop ) ;
2493
2511
2494
- Ok ( ( ) )
2495
- } ) ?;
2512
+ Ok ( ( ) )
2513
+ } ,
2514
+ ComprehensionType :: Generator ,
2515
+ Self :: contains_await ( elt) ,
2516
+ ) ?;
2496
2517
}
2497
2518
Expr :: Starred ( _) => {
2498
2519
return Err ( self . error ( CodegenErrorType :: InvalidStarExpr ) ) ;
@@ -2744,9 +2765,35 @@ impl Compiler {
2744
2765
init_collection : Option < Instruction > ,
2745
2766
generators : & [ located_ast:: Comprehension ] ,
2746
2767
compile_element : & dyn Fn ( & mut Self ) -> CompileResult < ( ) > ,
2768
+ comprehension_type : ComprehensionType ,
2769
+ element_contains_await : bool ,
2747
2770
) -> CompileResult < ( ) > {
2748
2771
let prev_ctx = self . ctx ;
2749
- let is_async = generators. iter ( ) . any ( |g| g. is_async ) ;
2772
+ let has_an_async_gen = generators. iter ( ) . any ( |g| g. is_async ) ;
2773
+
2774
+ // async comprehensions are allowed in various contexts:
2775
+ // - list/set/dict comprehensions in async functions
2776
+ // - always for generator expressions
2777
+ // Note: generators have to be treated specially since their async version is a fundamentally
2778
+ // different type (aiter vs iter) instead of just an awaitable.
2779
+
2780
+ // for if it actually is async, we check if any generator is async or if the element contains await
2781
+
2782
+ // if the element expression contains await, but the context doesn't allow for async,
2783
+ // then we continue on here with is_async=false and will produce a syntax once the await is hit
2784
+
2785
+ let is_async_list_set_dict_comprehension = comprehension_type
2786
+ != ComprehensionType :: Generator
2787
+ && ( has_an_async_gen || element_contains_await) // does it have to be async? (uses await or async for)
2788
+ && prev_ctx. func == FunctionContext :: AsyncFunction ; // is it allowed to be async? (in an async function)
2789
+
2790
+ let is_async_generator_comprehension = comprehension_type == ComprehensionType :: Generator
2791
+ && ( has_an_async_gen || element_contains_await) ;
2792
+
2793
+ // since one is for generators, and one for not generators, they should never both be true
2794
+ debug_assert ! ( !( is_async_list_set_dict_comprehension && is_async_generator_comprehension) ) ;
2795
+
2796
+ let is_async = is_async_list_set_dict_comprehension || is_async_generator_comprehension;
2750
2797
2751
2798
self . ctx = CompileContext {
2752
2799
loop_data : None ,
@@ -2838,7 +2885,7 @@ impl Compiler {
2838
2885
2839
2886
// End of for loop:
2840
2887
self . switch_to_block ( after_block) ;
2841
- if is_async {
2888
+ if has_an_async_gen {
2842
2889
emit ! ( self , Instruction :: EndAsyncFor ) ;
2843
2890
}
2844
2891
}
@@ -2877,19 +2924,23 @@ impl Compiler {
2877
2924
self . compile_expression ( & generators[ 0 ] . iter ) ?;
2878
2925
2879
2926
// Get iterator / turn item into an iterator
2880
- if is_async {
2927
+ if has_an_async_gen {
2881
2928
emit ! ( self , Instruction :: GetAIter ) ;
2882
2929
} else {
2883
2930
emit ! ( self , Instruction :: GetIter ) ;
2884
2931
} ;
2885
2932
2886
2933
// Call just created <listcomp> function:
2887
2934
emit ! ( self , Instruction :: CallFunctionPositional { nargs: 1 } ) ;
2888
- if is_async {
2935
+ if is_async_list_set_dict_comprehension {
2936
+ // async, but not a generator and not an async for
2937
+ // in this case, we end up with an awaitable
2938
+ // that evaluates to the list/set/dict, so here we add an await
2889
2939
emit ! ( self , Instruction :: GetAwaitable ) ;
2890
2940
self . emit_load_const ( ConstantData :: None ) ;
2891
2941
emit ! ( self , Instruction :: YieldFrom ) ;
2892
2942
}
2943
+
2893
2944
Ok ( ( ) )
2894
2945
}
2895
2946
@@ -3016,6 +3067,117 @@ impl Compiler {
3016
3067
fn mark_generator ( & mut self ) {
3017
3068
self . current_code_info ( ) . flags |= bytecode:: CodeFlags :: IS_GENERATOR
3018
3069
}
3070
+
3071
+ /// Whether the expression contains an await expression and
3072
+ /// thus requires the function to be async.
3073
+ /// Async with and async for are statements, so I won't check for them here
3074
+ fn contains_await ( expression : & located_ast:: Expr ) -> bool {
3075
+ use located_ast:: * ;
3076
+
3077
+ match & expression {
3078
+ Expr :: Call ( ExprCall {
3079
+ func,
3080
+ args,
3081
+ keywords,
3082
+ ..
3083
+ } ) => {
3084
+ Self :: contains_await ( func)
3085
+ || args. iter ( ) . any ( Self :: contains_await)
3086
+ || keywords. iter ( ) . any ( |kw| Self :: contains_await ( & kw. value ) )
3087
+ }
3088
+ Expr :: BoolOp ( ExprBoolOp { values, .. } ) => values. iter ( ) . any ( Self :: contains_await) ,
3089
+ Expr :: BinOp ( ExprBinOp { left, right, .. } ) => {
3090
+ Self :: contains_await ( left) || Self :: contains_await ( right)
3091
+ }
3092
+ Expr :: Subscript ( ExprSubscript { value, slice, .. } ) => {
3093
+ Self :: contains_await ( value) || Self :: contains_await ( slice)
3094
+ }
3095
+ Expr :: UnaryOp ( ExprUnaryOp { operand, .. } ) => Self :: contains_await ( operand) ,
3096
+ Expr :: Attribute ( ExprAttribute { value, .. } ) => Self :: contains_await ( value) ,
3097
+ Expr :: Compare ( ExprCompare {
3098
+ left, comparators, ..
3099
+ } ) => Self :: contains_await ( left) || comparators. iter ( ) . any ( Self :: contains_await) ,
3100
+ Expr :: Constant ( ExprConstant { .. } ) => false ,
3101
+ Expr :: List ( ExprList { elts, .. } ) => elts. iter ( ) . any ( Self :: contains_await) ,
3102
+ Expr :: Tuple ( ExprTuple { elts, .. } ) => elts. iter ( ) . any ( Self :: contains_await) ,
3103
+ Expr :: Set ( ExprSet { elts, .. } ) => elts. iter ( ) . any ( Self :: contains_await) ,
3104
+ Expr :: Dict ( ExprDict { keys, values, .. } ) => {
3105
+ keys. iter ( )
3106
+ . any ( |key| key. as_ref ( ) . map_or ( false , Self :: contains_await) )
3107
+ || values. iter ( ) . any ( Self :: contains_await)
3108
+ }
3109
+ Expr :: Slice ( ExprSlice {
3110
+ lower, upper, step, ..
3111
+ } ) => {
3112
+ lower. as_ref ( ) . map_or ( false , |l| Self :: contains_await ( l) )
3113
+ || upper. as_ref ( ) . map_or ( false , |u| Self :: contains_await ( u) )
3114
+ || step. as_ref ( ) . map_or ( false , |s| Self :: contains_await ( s) )
3115
+ }
3116
+ Expr :: Yield ( ExprYield { value, .. } ) => {
3117
+ value. as_ref ( ) . map_or ( false , |v| Self :: contains_await ( v) )
3118
+ }
3119
+ Expr :: Await ( ExprAwait { .. } ) => true ,
3120
+ Expr :: YieldFrom ( ExprYieldFrom { value, .. } ) => Self :: contains_await ( value) ,
3121
+ Expr :: JoinedStr ( ExprJoinedStr { values, .. } ) => {
3122
+ values. iter ( ) . any ( Self :: contains_await)
3123
+ }
3124
+ Expr :: FormattedValue ( ExprFormattedValue {
3125
+ value,
3126
+ conversion : _,
3127
+ format_spec,
3128
+ ..
3129
+ } ) => {
3130
+ Self :: contains_await ( value)
3131
+ || format_spec
3132
+ . as_ref ( )
3133
+ . map_or ( false , |fs| Self :: contains_await ( fs) )
3134
+ }
3135
+ Expr :: Name ( located_ast:: ExprName { .. } ) => false ,
3136
+ Expr :: Lambda ( located_ast:: ExprLambda { body, .. } ) => Self :: contains_await ( body) ,
3137
+ Expr :: ListComp ( located_ast:: ExprListComp {
3138
+ elt, generators, ..
3139
+ } ) => {
3140
+ Self :: contains_await ( elt)
3141
+ || generators. iter ( ) . any ( |gen| Self :: contains_await ( & gen. iter ) )
3142
+ }
3143
+ Expr :: SetComp ( located_ast:: ExprSetComp {
3144
+ elt, generators, ..
3145
+ } ) => {
3146
+ Self :: contains_await ( elt)
3147
+ || generators. iter ( ) . any ( |gen| Self :: contains_await ( & gen. iter ) )
3148
+ }
3149
+ Expr :: DictComp ( located_ast:: ExprDictComp {
3150
+ key,
3151
+ value,
3152
+ generators,
3153
+ ..
3154
+ } ) => {
3155
+ Self :: contains_await ( key)
3156
+ || Self :: contains_await ( value)
3157
+ || generators. iter ( ) . any ( |gen| Self :: contains_await ( & gen. iter ) )
3158
+ }
3159
+ Expr :: GeneratorExp ( located_ast:: ExprGeneratorExp {
3160
+ elt, generators, ..
3161
+ } ) => {
3162
+ Self :: contains_await ( elt)
3163
+ || generators. iter ( ) . any ( |gen| Self :: contains_await ( & gen. iter ) )
3164
+ }
3165
+ Expr :: Starred ( expr) => Self :: contains_await ( & expr. value ) ,
3166
+ Expr :: IfExp ( located_ast:: ExprIfExp {
3167
+ test, body, orelse, ..
3168
+ } ) => {
3169
+ Self :: contains_await ( test)
3170
+ || Self :: contains_await ( body)
3171
+ || Self :: contains_await ( orelse)
3172
+ }
3173
+
3174
+ Expr :: NamedExpr ( located_ast:: ExprNamedExpr {
3175
+ target,
3176
+ value,
3177
+ range : _,
3178
+ } ) => Self :: contains_await ( target) || Self :: contains_await ( value) ,
3179
+ }
3180
+ }
3019
3181
}
3020
3182
3021
3183
trait EmitArg < Arg : OpArgType > {
0 commit comments