Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit fe17768

Browse files
committed
Support ompx_attribute.
1 parent 6ba1cfa commit fe17768

File tree

1 file changed

+30
-9
lines changed

1 file changed

+30
-9
lines changed

numba/openmp.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -794,8 +794,6 @@ def call_function(self, builder, callee, resty, argtys, args, name, attrs=None):
794794
if isinstance(callee.function_type, lir.FunctionType):
795795
if config.DEBUG_OPENMP >= 2:
796796
print("call_function:", callee, callee.name, type(callee), argtys, args)
797-
#if "compute_velocity" in callee.name:
798-
# breakpoint()
799797
ft = callee.function_type
800798
retty = ft.return_type
801799
arginfo = self._get_arg_packer(argtys)
@@ -1508,6 +1506,8 @@ def add_struct_tags(self, var_table):
15081506
elif target_num is not None and self.target_copy != True:
15091507
var_table = get_name_var_table(lowerer.func_ir.blocks)
15101508

1509+
ompx_attrs = list(filter(lambda x: x.name == "QUAL.OMP.OMPX_ATTRIBUTE", self.tags))
1510+
self.tags = list(filter(lambda x: x.name != "QUAL.OMP.OMPX_ATTRIBUTE", self.tags))
15111511
selected_device = 0
15121512
device_tags = get_tags_of_type(self.tags, "QUAL.OMP.DEVICE")
15131513
if len(device_tags) > 0:
@@ -1932,11 +1932,6 @@ def call_conv(self):
19321932
dprint_func_ir(outlined_ir, "outlined_ir after replace np.empty")
19331933
#if 'arrayexpr' not in device_target.special_ops:
19341934
# device_target.special_ops['arrayexpr'] = array_exprs._lower_array_expr
1935-
1936-
#breakpoint()
1937-
#tfnty = device_target.typing_context.resolve_value_type(omp_shared_array)
1938-
#tsig = tfnty.get_call_type(device_target.typing_context, (types.IntegerLiteral(7), types.DType(types.float32)), {})
1939-
#timpl = device_target.get_function(tfnty, tsig)
19401935
else:
19411936
raise NotImplementedError("Unsupported OpenMP device number")
19421937

@@ -1984,7 +1979,6 @@ def call_conv(self):
19841979
is_lifted_loop=False, # tried this as True since code derived from loop lifting code but it goes through the pipeline twice and messes things up
19851980
parent_state=state_copy)
19861981

1987-
#breakpoint()
19881982
if selected_device == 0:
19891983
#from numba.cpython import printimpl
19901984
##subtarget.install_registry(printimpl.registry)
@@ -2126,7 +2120,10 @@ def _get_target_image_in_memory(self, mod, filename_prefix):
21262120
with open(filename_prefix + "-intrinsics_omp-linked-opt.s", "w") as f:
21272121
f.write(str(ptx))
21282122

2129-
linker = driver.Linker.new(cc=self.cc)
2123+
linker_kwargs = {}
2124+
for x in ompx_attrs:
2125+
linker_kwargs[x.arg[0]] = tuple(x.arg[1]) if len(x.arg[1]) > 1 else x.arg[1][0]
2126+
linker = driver.Linker.new(cc=self.cc, **linker_kwargs)
21302127
linker.add_ptx(ptx.encode())
21312128
cubin = linker.complete()
21322129

@@ -5325,6 +5322,21 @@ def var_list(self, args):
53255322
args[0].append(args[1])
53265323
return args[0]
53275324

5325+
def number_list(self, args):
    """Collect the children of a ``number_list`` rule into a single list.

    The grammar rule is ``number_list: NUMBER | number_list "," NUMBER``,
    so ``args`` holds either one item (base case) or an already-built list
    followed by the next item (recursive case).

    Returns ``args`` itself in the single-item case; otherwise appends
    ``args[1]`` to ``args[0]`` in place and returns ``args[0]``.
    """
    if config.DEBUG_OPENMP >= 1:
        print("visit number_list", args, type(args))
    if len(args) != 1:
        # Recursive case: extend the list built so far with the new item.
        collected = args[0]
        collected.append(args[1])
        return collected
    # Base case: the child list already is the one-element result.
    return args
5333+
5334+
def ompx_attribute(self, args):
    """Turn an ``ompx_attribute`` clause into a QUAL.OMP.OMPX_ATTRIBUTE tag.

    ``args`` must unpack into exactly three items, matching the grammar
    rule ``OMPX_ATTRIBUTE "(" PYTHON_NAME "(" number_list ")" ")"``: the
    keyword (ignored), the attribute name, and the list of numbers.

    Returns an ``openmp_tag`` whose argument is the pair
    ``(attribute name, number list)``.
    """
    if config.DEBUG_OPENMP >= 1:
        print("visit ompx_attribute", args, type(args), args[0])
    # The leading keyword carries no information beyond selecting this rule.
    _keyword, attr_name, values = args
    payload = (attr_name, values)
    return openmp_tag("QUAL.OMP.OMPX_ATTRIBUTE", payload)
5339+
53285340
def PLUS(self, args):
53295341
if config.DEBUG_OPENMP >= 1:
53305342
print("visit PLUS", args, type(args))
@@ -5617,6 +5629,7 @@ def NUMBER(self, args):
56175629
| allocate_clause
56185630
| depend_with_modifier_clause
56195631
// | uses_allocators_clause
5632+
| ompx_attribute
56205633
teams_clause: num_teams_clause
56215634
| thread_limit_clause
56225635
| data_default_clause
@@ -5668,6 +5681,7 @@ def NUMBER(self, args):
56685681
// | simdlen_clause
56695682
| aligned_clause
56705683
// | nontemporal_clause
5684+
| ompx_attribute
56715685
56725686
target_teams_distribute_parallel_for_directive: TARGET TEAMS DISTRIBUTE PARALLEL FOR [target_teams_distribute_parallel_for_clause*]
56735687
target_teams_distribute_parallel_for_clause: if_clause
@@ -5697,9 +5711,12 @@ def NUMBER(self, args):
56975711
| ORDERED
56985712
// | order_clause
56995713
| dist_schedule_clause
5714+
| ompx_attribute
57005715
57015716
LOOP: "loop"
57025717
5718+
ompx_attribute: OMPX_ATTRIBUTE "(" PYTHON_NAME "(" number_list ")" ")"
5719+
OMPX_ATTRIBUTE: "ompx_attribute"
57035720
//target_teams_loop_directive: TARGET TEAMS LOOP [target_teams_loop_clause*]
57045721
target_teams_loop_directive: TARGET TEAMS LOOP [target_teams_distribute_parallel_for_simd_clause*]
57055722
target_teams_loop_clause: if_clause
@@ -5723,6 +5740,7 @@ def NUMBER(self, args):
57235740
| collapse_clause
57245741
| ORDERED
57255742
| data_privatization_out_clause
5743+
| ompx_attribute
57265744
57275745
target_teams_directive: TARGET TEAMS [target_teams_clause*]
57285746
target_teams_clause: if_clause
@@ -5742,6 +5760,7 @@ def NUMBER(self, args):
57425760
| data_default_clause
57435761
| data_sharing_clause
57445762
// | reduction_default_only_clause
5763+
| ompx_attribute
57455764
57465765
target_teams_distribute_directive: TARGET TEAMS DISTRIBUTE [target_teams_clause*]
57475766
target_teams_distribute_clause: num_teams_clause
@@ -5755,6 +5774,7 @@ def NUMBER(self, args):
57555774
| data_privatization_out_clause
57565775
| collapse_clause
57575776
| dist_schedule_clause
5777+
| ompx_attribute
57585778
57595779
IS_DEVICE_PTR: "is_device_ptr"
57605780
is_device_ptr_clause: IS_DEVICE_PTR "(" var_list ")"
@@ -5926,6 +5946,7 @@ def NUMBER(self, args):
59265946
slice_list: oslice | slice_list "," oslice
59275947
name_slice: PYTHON_NAME [ "[" slice_list "]" ]
59285948
var_list: name_slice | var_list "," name_slice
5949+
number_list: NUMBER | number_list "," NUMBER
59295950
PLUS: "+"
59305951
reduction_operator: PLUS | "\\" | "*" | "-" | "&" | "^" | "|" | "&&" | "||"
59315952
threadprivate_directive: "threadprivate" "(" var_list ")"

0 commit comments

Comments
 (0)