Skip to content
This repository was archived by the owner on Jan 25, 2023. It is now read-only.

Commit b4a325d

Browse files
committed
Changes
1 parent b811057 commit b4a325d

File tree

5 files changed

+108
-9
lines changed

5 files changed

+108
-9
lines changed

numba/npyufunc/csa.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import pdb
1212
import copy
13+
import sys
1314

1415
class CSATypingContext(typing.BaseContext):
1516
def load_additional_registries(self):
@@ -51,6 +52,7 @@ class CSATargetDesc(TargetDescriptor):
5152
def compile_csa(func_ir, return_type, args, inflags):
5253
if config.DEBUG_CSA:
5354
print("compile_csa", func_ir, return_type, args)
55+
sys.stdout.flush()
5456

5557
cput = registry.dispatcher_registry['cpu'].targetdescr
5658
typingctx = cput.typing_context
@@ -85,11 +87,15 @@ def compile_csa(func_ir, return_type, args, inflags):
8587
flags,
8688
locals={})
8789
library = cres.library
90+
if config.DEBUG_CSA:
91+
print("library", library, type(library))
92+
sys.stdout.flush()
8893
library.finalize()
8994

9095
if config.DEBUG_CSA:
9196
print("compile_csa cres", cres, type(cres))
9297
print("LLVM")
98+
sys.stdout.flush()
9399

94100
llvm_str = cres.library.get_llvm_str()
95101
llvm_out = "compile_csa" + ".ll"
@@ -191,6 +197,7 @@ def compile_csa_kernel(func_ir, args, flags, link, fastmath=False):
191197
print("lib", lib, type(lib))
192198
print("kernel", kernel, type(kernel))
193199
print("wrapfnty", wrapfnty, type(wrapfnty))
200+
sys.stdout.flush()
194201
csakern = CSAKernel(llvm_module=lib._final_module,
195202
library=wrapper_library,
196203
wrapper_module=wrapper_library._final_module,
@@ -228,6 +235,9 @@ def compile(self, sig):
228235
if kernel is None:
229236
if 'link' not in self.targetoptions:
230237
self.targetoptions['link'] = ()
238+
if config.DEBUG_CSA:
239+
print("Before compile_csa_kernel.")
240+
sys.stdout.flush()
231241
kernel = compile_csa_kernel(self.func_ir, argtypes, self.flags,
232242
**self.targetoptions)
233243
self.definitions[argtypes] = kernel

numba/npyufunc/parfor.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def lower(self, lowerer):
218218
builder = lowerer.builder
219219
library = lowerer.library
220220

221-
assert self.start_spm.prs_var != None
221+
assert self.start_spmd.prs_var != None
222222

223223
llvm_int32_t = lc.Type.int(32)
224224
fnty = lir.FunctionType(lc.Type.void(), [llvm_int32_t])
@@ -1123,9 +1123,9 @@ def _create_gufunc_for_parfor_body(
11231123
if config.DEBUG_CSA:
11241124
print("eachdim", eachdim)
11251125
if targetctx.auto_parallel.csa:
1126-
for indent in range(eachdim + 1):
1127-
gufunc_txt += " "
1128-
gufunc_txt += "numba.npyufunc.parfor.par_start_spmd(" + str(eachdim) + ", 4)\n"
1126+
# for indent in range(eachdim + 1):
1127+
# gufunc_txt += " "
1128+
# gufunc_txt += "numba.npyufunc.parfor.par_start_spmd(" + str(eachdim) + ", 4)\n"
11291129
for indent in range(eachdim + 1):
11301130
gufunc_txt += " "
11311131
gufunc_txt += "numba.npyufunc.parfor.par_start_region(" + str(eachdim) + ")\n"
@@ -1179,9 +1179,9 @@ def _create_gufunc_for_parfor_body(
11791179
for indent in range(eachdim + 1):
11801180
gufunc_txt += " "
11811181
gufunc_txt += "numba.npyufunc.parfor.par_end_region(" + str(eachdim) + ")\n"
1182-
for indent in range(eachdim + 1):
1183-
gufunc_txt += " "
1184-
gufunc_txt += "numba.npyufunc.parfor.par_end_spmd(" + str(eachdim) + ")\n"
1182+
# for indent in range(eachdim + 1):
1183+
# gufunc_txt += " "
1184+
# gufunc_txt += "numba.npyufunc.parfor.par_end_spmd(" + str(eachdim) + ")\n"
11851185

11861186

11871187
# Add assignments of reduction variables (for returning the value)
@@ -1285,7 +1285,7 @@ def _create_gufunc_for_parfor_body(
12851285
elif callname[0] == 'par_start_spmd':
12861286
pss_loc = gufunc_ir._definitions[rhs.args[0].name][0].value
12871287
pnum_threads = gufunc_ir._definitions[rhs.args[1].name][0].value
1288-
pspmd_dict[pss_loc] = parallel_spmd_start(loc, pnum_threads)
1288+
pspmd_dict[pss_loc] = parallel_spmd_start(pnum_threads, loc)
12891289
new_block.append(pspmd_dict[pss_loc])
12901290
continue
12911291
elif callname[0] == 'par_end_spmd':
@@ -1428,6 +1428,10 @@ def _create_gufunc_for_parfor_body(
14281428
gufunc_ir.blocks = rename_labels(gufunc_ir.blocks)
14291429
remove_dels(gufunc_ir.blocks)
14301430

1431+
if config.DEBUG_ARRAY_OPT:
1432+
print("flush")
1433+
sys.stdout.flush()
1434+
14311435
if config.DEBUG_ARRAY_OPT:
14321436
print("gufunc_ir last dump")
14331437
gufunc_ir.dump()
@@ -1450,6 +1454,8 @@ def _create_gufunc_for_parfor_body(
14501454

14511455
if config.DEBUG_CSA:
14521456
print("Before compile gufunc.")
1457+
if config.DEBUG_ARRAY_OPT:
1458+
sys.stdout.flush()
14531459
if targetctx.auto_parallel.csa:
14541460
ajck = csa.AutoJitCSAKernel(gufunc_ir, True, flags, {})
14551461
# Returns CSAKernel
@@ -1458,6 +1464,7 @@ def _create_gufunc_for_parfor_body(
14581464
if config.DEBUG_CSA:
14591465
print("After compile gufunc.")
14601466
print("kernel_func", kernel_func, type(kernel_func))
1467+
sys.stdout.flush()
14611468
kernel_func = compiler.compile_result(typing_context=typingctx,
14621469
target_context=targetctx,
14631470
entry_point=kernel_func.kernel,

numba/targets/codegen.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from numba.compiler_lock import require_global_compiler_lock
1818

1919
import pdb
20+
import sys
2021

2122
_x86arch = frozenset(['x86', 'i386', 'i486', 'i586', 'i686', 'i786',
2223
'i886', 'i986'])
@@ -1075,6 +1076,7 @@ def get_asm_str(self, filename):
10751076
if config.DEBUG_CSA:
10761077
print("CSACodeLibrary::get_asm_str", filename)
10771078
print(self._final_module)
1079+
sys.stdout.flush()
10781080
return str(self._codegen._tm.emit_assembly_file(self._final_module, filename))
10791081

10801082
#class AOTCSACodegen(BaseCSACodegen):

numba/targets/csa.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import print_function, absolute_import
22

33
import sys
4+
import os
45

56
import llvmlite.llvmpy.core as lc
67
import llvmlite as ll
@@ -16,6 +17,7 @@
1617
from . import fastmathpass
1718
from llvmlite.llvmpy.core import (Type, Builder, LINKAGE_INTERNAL, Constant, ICMP_EQ)
1819
from llvmlite import ir
20+
from llvmlite import binding
1921
from ctypes import *
2022

2123

@@ -50,6 +52,23 @@ def init(self):
5052
self.is32bit = (utils.MACHINE_BITS == 32)
5153
self._internal_codegen = codegen.JITCSACodegen("numba.csa.exec")
5254

55+
dir_path = os.path.dirname(os.path.realpath(__file__))
56+
57+
#with open(dir_path + "/libcsac.bc", 'rb') as f:
58+
# bccode = f.read()
59+
# self._csamod = binding.parse_bitcode(bccode)
60+
61+
#with open(dir_path + "/libcsamath.bc", 'rb') as f:
62+
# bccode = f.read()
63+
# self._csamath = binding.parse_bitcode(bccode)
64+
65+
with open(dir_path + "/libcsamath_fixed.bc", 'rb') as f:
66+
bccode = f.read()
67+
self._csamath = binding.parse_bitcode(bccode)
68+
69+
self.sqrt32extern = "sp_sqrt_RN"
70+
self.sqrt64extern = "dp_sqrt_RN"
71+
5372
# Map external C functions.
5473
#externals.c_math_functions.install(self)
5574

@@ -152,6 +171,9 @@ def post_lowering(self, mod, library):
152171
intrinsics.fix_divmod(mod)
153172

154173
library.add_linking_library(rtsys.library)
174+
print("CSA post-lowering adding csa library.", library, type(library))
175+
library._final_module.link_in(self._csamod)
176+
library._final_module.link_in(self._csamath)
155177

156178
def create_cpython_wrapper(self, library, fndesc, env, call_helper,
157179
release_gil=False):
@@ -216,11 +238,22 @@ def prepare_csa_kernel(self, codelib, fname, argtypes):
216238
print("fname", fname, type(fname))
217239
print("argtypes", argtypes, type(argtypes))
218240
print("csa_asm_name", csa_asm_name, type(csa_asm_name))
241+
print("codelib", codelib, type(codelib))
242+
sys.stdout.flush()
219243

220244
codelib.get_asm_str(csa_asm_name)
245+
if config.DEBUG_CSA:
246+
print("After get_asm_str")
247+
sys.stdout.flush()
221248
library = self.codegen().create_library('')
249+
if config.DEBUG_CSA:
250+
print("After create_library")
251+
sys.stdout.flush()
222252
#library.add_linking_library(codelib)
223253
wrapper, wrapfnty, wrapper_library = self.generate_kernel_wrapper(library, fname, argtypes, csa_asm_name)
254+
if config.DEBUG_CSA:
255+
print("After generate_kernel_wrapper")
256+
sys.stdout.flush()
224257
return library, wrapper, wrapfnty, wrapper_library
225258

226259
def generate_kernel_wrapper(self, library, fname, argtypes, csa_asm_name):
@@ -229,7 +262,13 @@ def generate_kernel_wrapper(self, library, fname, argtypes, csa_asm_name):
229262
The function being wrapped have the name ``fname`` and argument types
230263
``argtypes``. The wrapper function is returned.
231264
"""
265+
if config.DEBUG_CSA:
266+
print("Before finalize in generate_kernel_wrapper")
267+
sys.stdout.flush()
232268
library.finalize()
269+
if config.DEBUG_CSA:
270+
print("After finalize in generate_kernel_wrapper")
271+
sys.stdout.flush()
233272
cput = registry.dispatcher_registry['cpu'].targetdescr
234273
context = cput.target_context
235274

@@ -260,6 +299,7 @@ def generate_kernel_wrapper(self, library, fname, argtypes, csa_asm_name):
260299
print("prefixed", prefixed)
261300
print("wrapfn", wrapfn)
262301
print("builder", builder)
302+
sys.stdout.flush()
263303

264304
ll.binding.load_library_permanently("/home/taanders/numba/numba_csa2/numba/numba/targets/libgeneric.so")
265305

numba/targets/mathimpl.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,45 @@ def float_impl(context, builder, sig, args):
168168

169169
return float_impl
170170

171+
def unary_sqrt_extern(fn, int_restype=False):
    """
    Register lowering implementations of the Python function *fn* that
    call an external, single-argument square-root routine.

    Unlike a registration with fixed symbol names, the external symbol is
    looked up on the target *context* each time the implementation is
    lowered:

    * ``context.sqrt32extern`` is used for float32 inputs, falling back
      to ``"sqrtf"`` when the context does not define that attribute;
    * ``context.sqrt64extern`` is used for float64 inputs, falling back
      to ``"sqrt"`` when the context does not define that attribute.

    *int_restype* is accepted in the signature but is not consulted by
    this implementation; the result is always cast from the input type
    to ``sig.return_type``.

    The implementation is registered for ``types.Float`` arguments, and
    an integer-input wrapper is registered through
    ``unary_math_int_impl``.  The float implementation is returned.
    """

    def float_impl(context, builder, sig, args):
        """
        Implement *fn* for a types.Float input.
        """
        [val] = args
        input_type = sig.args[0]
        lty = context.get_value_type(input_type)

        # Resolve the symbol names from the context so that a target can
        # substitute its own math routines; otherwise use the libm names.
        sqrt32extern = getattr(context, 'sqrt32extern', "sqrtf")
        sqrt64extern = getattr(context, 'sqrt64extern', "sqrt")

        # Only float32 and float64 are mapped; any other input type
        # raises KeyError here.
        func_name = {
            types.float32: sqrt32extern,
            types.float64: sqrt64extern,
        }[input_type]

        # Declare (or reuse) the external function in the current module
        # with signature ``lty (lty)`` and call it on the argument.
        fnty = Type.function(lty, [lty])
        extern_fn = cgutils.insert_pure_function(builder.module, fnty,
                                                 name=func_name)
        res = builder.call(extern_fn, (val,))
        res = context.cast(builder, res, input_type, sig.return_type)
        return impl_ret_untracked(context, builder, sig.return_type, res)

    lower(fn, types.Float)(float_impl)

    # Implement wrapper for integer inputs
    unary_math_int_impl(fn, float_impl)

    return float_impl
208+
209+
171210

172211
unary_math_intr(math.fabs, lc.INTR_FABS)
173212
#unary_math_intr(math.sqrt, lc.INTR_SQRT)
@@ -206,7 +245,8 @@ def float_impl(context, builder, sig, args):
206245
ceil_impl = unary_math_extern(math.ceil, "ceilf", "ceil")
207246
floor_impl = unary_math_extern(math.floor, "floorf", "floor")
208247
gamma_impl = unary_math_extern(math.gamma, "numba_gammaf", "numba_gamma") # work-around
209-
sqrt_impl = unary_math_extern(math.sqrt, "sqrtf", "sqrt")
248+
sqrt_impl = unary_sqrt_extern(math.sqrt)
249+
#sqrt_impl = unary_math_extern(math.sqrt, "sqrtf", "sqrt")
210250
trunc_impl = unary_math_extern(math.trunc, "truncf", "trunc", True)
211251
lgamma_impl = unary_math_extern(math.lgamma, "lgammaf", "lgamma")
212252

0 commit comments

Comments
 (0)