@@ -2809,8 +2809,6 @@ def get_dotted_type(x, typemap, lowerer):
2809
2809
2810
2810
2811
2811
def is_target_arg (name ):
2812
- #return name in ["QUAL.OMP.FIRSTPRIVATE", "QUAL.OMP.TARGET.IMPLICIT", "QUAL.OMP.THREAD_LIMIT", "QUAL.OMP.NUM_TEAMS"] or name.startswith("QUAL.OMP.MAP")
2813
- #or name.startswith("QUAL.OMP.NORMALIZED")
2814
2812
return name in ["QUAL.OMP.FIRSTPRIVATE" , "QUAL.OMP.TARGET.IMPLICIT" ] or name .startswith ("QUAL.OMP.MAP" ) or name .startswith ("QUAL.OMP.REDUCTION" )
2815
2813
2816
2814
@@ -2824,7 +2822,6 @@ def is_pointer_target_arg(name, typ):
2824
2822
if name .startswith ("QUAL.OMP.MAP" ):
2825
2823
if isinstance (typ , types .npytypes .Array ):
2826
2824
return True
2827
- #return False
2828
2825
else :
2829
2826
return True
2830
2827
if name in ["QUAL.OMP.FIRSTPRIVATE" , "QUAL.OMP.PRIVATE" ]:
@@ -3408,15 +3405,25 @@ def get_loops_in_region(all_loops):
3408
3405
3409
3406
# Copy all stmts from the loop entry block up to the ir.Global
3410
3407
# for range.
3408
+ call_offset = None
3411
3409
for entry_block_index , stmt in enumerate (loop_entry_block .body ):
3410
+ found_range = False
3412
3411
if isinstance (stmt , ir .Assign ) and isinstance (stmt .value , ir .Global ) and stmt .value .name == "range" :
3412
+ found_range = True
3413
3413
range_target = stmt .target
3414
- call_stmt = loop_entry_block .body [entry_block_index + 1 ]
3415
- assert isinstance (call_stmt , ir .Assign ) and isinstance (call_stmt .value , ir .Expr ) and call_stmt .value .op == 'call' and call_stmt .value .func == range_target
3416
- # Remove stmts that were retained.
3417
- loop_entry_block .body = loop_entry_block .body [entry_block_index :]
3414
+ found_call = False
3415
+ for call_index in range (entry_block_index + 1 , len (loop_entry_block .body )):
3416
+ call_stmt = loop_entry_block .body [call_index ]
3417
+ if isinstance (call_stmt , ir .Assign ) and isinstance (call_stmt .value , ir .Expr ) and call_stmt .value .op == 'call' and call_stmt .value .func == range_target :
3418
+ found_call = True
3419
+ # Remove stmts that were retained.
3420
+ loop_entry_block .body = loop_entry_block .body [entry_block_index :]
3421
+ call_offset = call_index - entry_block_index
3422
+ break
3423
+ assert found_call
3418
3424
break
3419
3425
stmts_to_retain .append (stmt )
3426
+ assert found_range
3420
3427
for header_block_index , stmt in enumerate (loop_header_block .body ):
3421
3428
if isinstance (stmt , ir .Assign ) and isinstance (stmt .value , ir .Expr ) and stmt .value .op == "iternext" :
3422
3429
iternext_inst = loop_header_block .body [header_block_index ]
@@ -3475,7 +3482,7 @@ def get_loops_in_region(all_loops):
3475
3482
new_stmts_for_iterspace .append (ir .Assign (mul_op , new_iterspace_var , self .loc ))
3476
3483
# Change iteration space of innermost loop to the product of all the
3477
3484
# loops' iteration spaces.
3478
- last_loop_entry_block .body [1 ].value .args [0 ] = new_iterspace_var
3485
+ last_loop_entry_block .body [call_offset ].value .args [0 ] = new_iterspace_var
3479
3486
3480
3487
last_eliminated_loop_header_block .body = new_stmts_for_iterspace + last_eliminated_loop_header_block .body
3481
3488
@@ -4324,14 +4331,16 @@ def target_teams_distribute_directive(self, args):
4324
4331
self .some_target_directive (args , "TARGET.TEAMS.DISTRIBUTE" , 3 , has_loop = True )
4325
4332
4326
4333
def target_teams_loop_directive (self , args ):
4327
- self .some_target_directive (args , "TARGET.TEAMS.DISTRIBUTE.PARALLEL.LOOP.SIMD" , 3 , has_loop = True )
4334
+ self .some_target_directive (args , "TARGET.TEAMS.DISTRIBUTE.PARALLEL.LOOP" , 3 , has_loop = True )
4335
+ #self.some_target_directive(args, "TARGET.TEAMS.DISTRIBUTE.PARALLEL.LOOP.SIMD", 3, has_loop=True)
4328
4336
#self.some_target_directive(args, "TARGET.TEAMS.LOOP", 3, has_loop=True)
4329
4337
4330
4338
def target_teams_distribute_parallel_for_directive (self , args ):
4331
4339
self .some_target_directive (args , "TARGET.TEAMS.DISTRIBUTE.PARALLEL.LOOP" , 5 , has_loop = True )
4332
4340
4333
4341
def target_teams_distribute_parallel_for_simd_directive (self , args ):
4334
- self .some_target_directive (args , "TARGET.TEAMS.DISTRIBUTE.PARALLEL.LOOP.SIMD" , 6 , has_loop = True )
4342
+ # Intentionally dropping "SIMD" from string as that typically isn't implemented on GPU.
4343
+ self .some_target_directive (args , "TARGET.TEAMS.DISTRIBUTE.PARALLEL.LOOP" , 6 , has_loop = True )
4335
4344
4336
4345
def get_clauses_by_name (self , clauses , names , remove_from_orig = False ):
4337
4346
if not isinstance (names , list ):
0 commit comments