@@ -794,8 +794,6 @@ def call_function(self, builder, callee, resty, argtys, args, name, attrs=None):
794
794
if isinstance (callee .function_type , lir .FunctionType ):
795
795
if config .DEBUG_OPENMP >= 2 :
796
796
print ("call_function:" , callee , callee .name , type (callee ), argtys , args )
797
- #if "compute_velocity" in callee.name:
798
- # breakpoint()
799
797
ft = callee .function_type
800
798
retty = ft .return_type
801
799
arginfo = self ._get_arg_packer (argtys )
@@ -1508,6 +1506,8 @@ def add_struct_tags(self, var_table):
1508
1506
elif target_num is not None and self .target_copy != True :
1509
1507
var_table = get_name_var_table (lowerer .func_ir .blocks )
1510
1508
1509
+ ompx_attrs = list (filter (lambda x : x .name == "QUAL.OMP.OMPX_ATTRIBUTE" , self .tags ))
1510
+ self .tags = list (filter (lambda x : x .name != "QUAL.OMP.OMPX_ATTRIBUTE" , self .tags ))
1511
1511
selected_device = 0
1512
1512
device_tags = get_tags_of_type (self .tags , "QUAL.OMP.DEVICE" )
1513
1513
if len (device_tags ) > 0 :
@@ -1932,11 +1932,6 @@ def call_conv(self):
1932
1932
dprint_func_ir (outlined_ir , "outlined_ir after replace np.empty" )
1933
1933
#if 'arrayexpr' not in device_target.special_ops:
1934
1934
# device_target.special_ops['arrayexpr'] = array_exprs._lower_array_expr
1935
-
1936
- #breakpoint()
1937
- #tfnty = device_target.typing_context.resolve_value_type(omp_shared_array)
1938
- #tsig = tfnty.get_call_type(device_target.typing_context, (types.IntegerLiteral(7), types.DType(types.float32)), {})
1939
- #timpl = device_target.get_function(tfnty, tsig)
1940
1935
else :
1941
1936
raise NotImplementedError ("Unsupported OpenMP device number" )
1942
1937
@@ -1984,7 +1979,6 @@ def call_conv(self):
1984
1979
is_lifted_loop = False , # tried this as True since code derived from loop lifting code but it goes through the pipeline twice and messes things up
1985
1980
parent_state = state_copy )
1986
1981
1987
- #breakpoint()
1988
1982
if selected_device == 0 :
1989
1983
#from numba.cpython import printimpl
1990
1984
##subtarget.install_registry(printimpl.registry)
@@ -2126,7 +2120,10 @@ def _get_target_image_in_memory(self, mod, filename_prefix):
2126
2120
with open (filename_prefix + "-intrinsics_omp-linked-opt.s" , "w" ) as f :
2127
2121
f .write (str (ptx ))
2128
2122
2129
- linker = driver .Linker .new (cc = self .cc )
2123
+ linker_kwargs = {}
2124
+ for x in ompx_attrs :
2125
+ linker_kwargs [x .arg [0 ]] = tuple (x .arg [1 ]) if len (x .arg [1 ]) > 1 else x .arg [1 ][0 ]
2126
+ linker = driver .Linker .new (cc = self .cc , ** linker_kwargs )
2130
2127
linker .add_ptx (ptx .encode ())
2131
2128
cubin = linker .complete ()
2132
2129
@@ -5325,6 +5322,21 @@ def var_list(self, args):
5325
5322
args [0 ].append (args [1 ])
5326
5323
return args [0 ]
5327
5324
5325
+ def number_list (self , args ):
5326
+ if config .DEBUG_OPENMP >= 1 :
5327
+ print ("visit number_list" , args , type (args ))
5328
+ if len (args ) == 1 :
5329
+ return args
5330
+ else :
5331
+ args [0 ].append (args [1 ])
5332
+ return args [0 ]
5333
+
5334
+ def ompx_attribute (self , args ):
5335
+ if config .DEBUG_OPENMP >= 1 :
5336
+ print ("visit ompx_attribute" , args , type (args ), args [0 ])
5337
+ (_ , attr , number_list ) = args
5338
+ return openmp_tag ("QUAL.OMP.OMPX_ATTRIBUTE" , (attr , number_list ))
5339
+
5328
5340
def PLUS (self , args ):
5329
5341
if config .DEBUG_OPENMP >= 1 :
5330
5342
print ("visit PLUS" , args , type (args ))
@@ -5617,6 +5629,7 @@ def NUMBER(self, args):
5617
5629
| allocate_clause
5618
5630
| depend_with_modifier_clause
5619
5631
// | uses_allocators_clause
5632
+ | ompx_attribute
5620
5633
teams_clause: num_teams_clause
5621
5634
| thread_limit_clause
5622
5635
| data_default_clause
@@ -5668,6 +5681,7 @@ def NUMBER(self, args):
5668
5681
// | simdlen_clause
5669
5682
| aligned_clause
5670
5683
// | nontemporal_clause
5684
+ | ompx_attribute
5671
5685
5672
5686
target_teams_distribute_parallel_for_directive: TARGET TEAMS DISTRIBUTE PARALLEL FOR [target_teams_distribute_parallel_for_clause*]
5673
5687
target_teams_distribute_parallel_for_clause: if_clause
@@ -5697,9 +5711,12 @@ def NUMBER(self, args):
5697
5711
| ORDERED
5698
5712
// | order_clause
5699
5713
| dist_schedule_clause
5714
+ | ompx_attribute
5700
5715
5701
5716
LOOP: "loop"
5702
5717
5718
+ ompx_attribute: OMPX_ATTRIBUTE "(" PYTHON_NAME "(" number_list ")" ")"
5719
+ OMPX_ATTRIBUTE: "ompx_attribute"
5703
5720
//target_teams_loop_directive: TARGET TEAMS LOOP [target_teams_loop_clause*]
5704
5721
target_teams_loop_directive: TARGET TEAMS LOOP [target_teams_distribute_parallel_for_simd_clause*]
5705
5722
target_teams_loop_clause: if_clause
@@ -5723,6 +5740,7 @@ def NUMBER(self, args):
5723
5740
| collapse_clause
5724
5741
| ORDERED
5725
5742
| data_privatization_out_clause
5743
+ | ompx_attribute
5726
5744
5727
5745
target_teams_directive: TARGET TEAMS [target_teams_clause*]
5728
5746
target_teams_clause: if_clause
@@ -5742,6 +5760,7 @@ def NUMBER(self, args):
5742
5760
| data_default_clause
5743
5761
| data_sharing_clause
5744
5762
// | reduction_default_only_clause
5763
+ | ompx_attribute
5745
5764
5746
5765
target_teams_distribute_directive: TARGET TEAMS DISTRIBUTE [target_teams_clause*]
5747
5766
target_teams_distribute_clause: num_teams_clause
@@ -5755,6 +5774,7 @@ def NUMBER(self, args):
5755
5774
| data_privatization_out_clause
5756
5775
| collapse_clause
5757
5776
| dist_schedule_clause
5777
+ | ompx_attribute
5758
5778
5759
5779
IS_DEVICE_PTR: "is_device_ptr"
5760
5780
is_device_ptr_clause: IS_DEVICE_PTR "(" var_list ")"
@@ -5926,6 +5946,7 @@ def NUMBER(self, args):
5926
5946
slice_list: oslice | slice_list "," oslice
5927
5947
name_slice: PYTHON_NAME [ "[" slice_list "]" ]
5928
5948
var_list: name_slice | var_list "," name_slice
5949
+ number_list: NUMBER | number_list "," NUMBER
5929
5950
PLUS: "+"
5930
5951
reduction_operator: PLUS | "\\" | "*" | "-" | "&" | "^" | "|" | "&&" | "||"
5931
5952
threadprivate_directive: "threadprivate" "(" var_list ")"
0 commit comments