Skip to content
This repository was archived by the owner on Jan 25, 2023. It is now read-only.

Commit 2a9fdeb

Browse files
committed
CSA changes.
1 parent ca731bf commit 2a9fdeb

File tree

9 files changed

+1335
-57
lines changed

9 files changed

+1335
-57
lines changed

numba/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ def optional_str(x):
193193
# print debug info of analysis and optimization on array operations
194194
DEBUG_ARRAY_OPT = _readenv("NUMBA_DEBUG_ARRAY_OPT", int, 0)
195195

196+
# print debug info while compiling functions for the CSA target
DEBUG_CSA = _readenv("NUMBA_DEBUG_CSA", int, 0)
197+
196198
# insert debug stmts to print information at runtime
197199
DEBUG_ARRAY_OPT_RUNTIME = _readenv(
198200
"NUMBA_DEBUG_ARRAY_OPT_RUNTIME", int, 0)

numba/lowering.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ def from_fndesc(cls, fndesc):
3838
cls._memo[fndesc.env_name] = inst
3939
return inst
4040

41+
@classmethod
def from_mod(cls, mod):
    """Alternate constructor: build an instance from module *mod*.

    The module's attribute dictionary (``mod.__dict__``) is handed to the
    class constructor as its single argument.
    """
    module_namespace = mod.__dict__
    return cls(module_namespace)
44+
4145
def __reduce__(self):
4246
return _rebuild_env, (
4347
self.globals['__name__'],

numba/npyufunc/csa.py

Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
from numba import compiler, sigutils, types, dispatcher
2+
from numba.targets.descriptors import TargetDescriptor
3+
from numba.targets.options import TargetOptions
4+
from numba import typing, config
5+
from numba.targets.base import BaseContext
6+
from numba.targets import csa as tcsa
7+
from numba.targets import registry
8+
from numba.ir_utils import get_call_table
9+
from numba.npyufunc import dufunc
10+
11+
import pdb
12+
import copy
13+
14+
class CSATypingContext(typing.BaseContext):
    """Typing context used when typing code for the CSA target.

    Dispatcher objects are replaced by a device-function wrapper before
    their type is resolved, and a fixed set of typing registries is installed.
    """

    def load_additional_registries(self):
        """Install the typing registries available to CSA code."""
        # Imported lazily, as in the original code.
        from numba.typing import (cffi_utils, cmathdecl, enumdecl, listdecl,
                                  mathdecl, npydecl, operatordecl, randomdecl,
                                  setdecl)
        # Installation order is kept identical to the original sequence.
        for decl_module in (cffi_utils, cmathdecl, enumdecl, listdecl,
                            mathdecl, npydecl, operatordecl, randomdecl,
                            setdecl):
            self.install_registry(decl_module.registry)

    def resolve_value_type(self, val):
        """Resolve the Numba type of *val*.

        When *val* is a ``dispatcher.Dispatcher`` it is swapped for a device
        function built from it; that wrapper is cached on the dispatcher so
        it is only created once.

        Raises
        ------
        ValueError
            If the dispatcher has compilation disabled and no cached device
            function exists yet.
        """
        if isinstance(val, dispatcher.Dispatcher):
            try:
                # The double-underscore attribute is name-mangled with this
                # class name, so the cache slot is private to this class.
                val = val.__csajitdevice
            except AttributeError:
                if not val._can_compile:
                    raise ValueError("using cpu function on device but its compilation is disabled")
                # NOTE(review): ``jitdevice`` is neither imported nor defined
                # in the visible part of this module -- confirm where it is
                # expected to come from, otherwise this line raises NameError.
                jd = jitdevice(val, debug=val.targetoptions.get('debug'))
                val.__csajitdevice = jd
                val = jd
        return super(CSATypingContext, self).resolve_value_type(val)
42+
43+
class CSATargetOptions(TargetOptions):
    """Target options for the CSA target; no options are declared."""

    OPTIONS = dict()
45+
46+
class CSATargetDesc(TargetDescriptor):
    """Target descriptor bundling the CSA option class and contexts.

    Both contexts are created once, when this class body executes, and are
    shared as class attributes.
    """

    # Created first: the target context below is constructed from it.
    typingctx = CSATypingContext()
    targetctx = tcsa.CSAContext(typingctx)

    options = CSATargetOptions
50+
51+
def compile_csa(func_ir, return_type, args, inflags):
    """Compile *func_ir* with the CSA target context and return the result.

    Typing is done with the registered CPU target's typing context, while
    code generation uses ``CSATargetDesc.targetctx``.  A fresh
    ``compiler.Flags`` object is used: ``noalias`` is copied from *inflags*
    when it is set there, and ``no_compile`` and ``no_cpython_wrapper`` are
    always set.  The compiled library is finalized before returning.

    When ``config.DEBUG_CSA`` is enabled, progress information is printed
    and the generated LLVM IR is both printed and written to
    ``compile_csa.ll`` in the current working directory.

    Parameters
    ----------
    func_ir
        Function IR to compile; its ``blocks`` are scanned for calls.
    return_type
        Return type passed through to ``compiler.compile_ir``.
    args
        Argument types passed through to ``compiler.compile_ir``.
    inflags
        Incoming flags; only its ``noalias`` attribute is read.

    Returns
    -------
    The compile result returned by ``compiler.compile_ir``.
    """
    debug = config.DEBUG_CSA
    if debug:
        print("compile_csa", func_ir, return_type, args)

    cpu_target = registry.dispatcher_registry['cpu'].targetdescr
    typingctx = cpu_target.typing_context
    targetctx = CSATargetDesc.targetctx

    flags = compiler.Flags()
    if inflags.noalias:
        if debug:
            print("Propagating noalias")
        flags.set('noalias')
    flags.set('no_compile')
    flags.set('no_cpython_wrapper')

    call_table, _ = get_call_table(func_ir.blocks)
    if debug:
        print("call_table", call_table)
    # Only the call-table values are needed; single-entry values that hold a
    # DUFunc get their code generation installed on the CSA target context.
    for entry in call_table.values():
        if len(entry) != 1:
            continue
        callee = entry[0]
        if isinstance(callee, dufunc.DUFunc):
            if debug:
                print("Calling install_cg on", callee)
            callee._install_cg(targetctx)

    cres = compiler.compile_ir(typingctx,
                               targetctx,
                               func_ir,
                               args,
                               return_type,
                               flags,
                               locals={})
    library = cres.library
    library.finalize()

    if debug:
        print("compile_csa cres", cres, type(cres))
        print("LLVM")

        # NOTE(review): the dump is treated as debug-only output; confirm no
        # downstream tool depends on compile_csa.ll being written otherwise.
        llvm_str = library.get_llvm_str()
        llvm_out = "compile_csa" + ".ll"
        print(llvm_out, llvm_str)
        with open(llvm_out, "w") as llvm_file:
            llvm_file.write(llvm_str)

    return cres
101+
102+
#class CachedCSAAsm(object):
103+
# def __init__(self, name, llvmir, options):
104+
# print("CachedCSAAsm::__init__", name, llvmir, options)
105+
# self.name = name
106+
# self.llvmir = llvmir
107+
# self._extra_options = options.copy()
108+
#
109+
# def get(self):
110+
# print("CachedCSAAsm", self.name)
111+
# print(self.llvmir)
112+
# targetctx = CSATargetDesc.targetctx
113+
# print("targetctx", targetctx, type(targetctx))
114+
# cg = targetctx.codegen()
115+
# print("cg", cg, type(cg))
116+
## lib = cg.create_library(self.name)
117+
## csa_asm_name = self.name + '.csa_asm.s'
118+
## asm_str = lib.get_asm_str(csa_asm_name)
119+
# return csa_asm_name
120+
#
121+
#class CachedCSACUFunction(object):
122+
# def __init__(self, entry_name, csa_asm, linking):
123+
# self.entry_name = entry_name
124+
# self.csa_asm = csa_asm
125+
# self.linking = linking
126+
#
127+
# def get(self):
128+
# csa_asm = self.csa_asm.get()
129+
# return csa_asm
130+
131+
class CSAKernel(object):
    """Record of the artifacts produced when compiling one CSA kernel.

    The constructor only stores its arguments as attributes; it performs no
    compilation or binding itself.
    """

    def __init__(self, llvm_module, library, wrapper_module, kernel, wrapfnty,
                 name, pretty_name, argtypes, call_helper,
                 link=(), debug=False, fastmath=False, type_annotation=None):
        """Store the compiled-kernel artifacts.

        *argtypes* and *link* are stored as tuples in ``argument_types`` and
        ``linking``; *name* is stored as ``entry_name``; *type_annotation* is
        stored as ``_type_annotation``.  *pretty_name* and *fastmath* are
        accepted for interface compatibility but are not stored or used.
        """
        self.kernel = kernel
        self.wrapfnty = wrapfnty
        self.entry_name = name
        self.argument_types = tuple(argtypes)
        self.linking = tuple(link)
        self._type_annotation = type_annotation
        self.debug = debug
        self.call_helper = call_helper
        self.llvm_module = llvm_module
        self.wrapper_module = wrapper_module
        self.library = library

    def __repr__(self):
        """Return a multi-line dump of the main attributes and their types."""
        lines = [
            "CSAKernel object",
            "self.kernel",
            # The kernel's value and type are emitted back to back, with no
            # separator, exactly as before.
            str(self.kernel) + str(type(self.kernel)),
        ]
        for attr in ("library", "wrapfnty", "entry_name", "argument_types"):
            value = getattr(self, attr)
            lines.append("self.%s %s type=%s" % (attr, value, type(value)))
        return "\n".join(lines)
165+
166+
# def __call__(self, *args, **kwargs):
167+
# assert not kwargs
168+
# self._kernel_call(args=args,
169+
# griddim=self.griddim,
170+
# blockdim=self.blockdim,
171+
# stream=self.stream,
172+
# sharedmem=self.sharedmem)
173+
174+
# def bind(self):
175+
# print("self._func", type(self._func))
176+
# self._func.get()
177+
178+
# @property
179+
# def csa_asm(self):
180+
# return self._func.csa_asm.get().decode('utf8')
181+
182+
def compile_csa_kernel(func_ir, args, flags, link, fastmath=False):
    """Compile *func_ir* as a void-returning CSA kernel.

    The IR is compiled through ``compile_csa`` with ``types.void`` as the
    return type, the target context's ``prepare_csa_kernel`` is applied to
    the compiled library, and the pieces are bundled into a ``CSAKernel``
    (always created with ``debug=False``).  With ``config.DEBUG_CSA``
    enabled, the intermediate objects are printed.
    """
    compile_result = compile_csa(func_ir, types.void, args, flags)
    llvm_name = compile_result.fndesc.llvm_func_name
    prepared = compile_result.target_context.prepare_csa_kernel(
        compile_result.library, llvm_name, compile_result.signature.args)
    kernel_lib, kernel_fn, wrapper_fnty, wrapper_lib = prepared

    if config.DEBUG_CSA:
        print("compile_csa_kernel", func_ir, args, link, fastmath)
        report = (
            ("fname", llvm_name),
            ("cres.library", compile_result.library),
            ("cres.signature", compile_result.signature),
            ("lib", kernel_lib),
            ("kernel", kernel_fn),
            ("wrapfnty", wrapper_fnty),
        )
        for label, obj in report:
            print(label, obj, type(obj))

    return CSAKernel(
        llvm_module=kernel_lib._final_module,
        library=wrapper_lib,
        wrapper_module=wrapper_lib._final_module,
        kernel=kernel_fn,
        wrapfnty=wrapper_fnty,
        name=kernel_fn.name,
        pretty_name=compile_result.fndesc.qualname,
        argtypes=args,
        type_annotation=compile_result.type_annotation,
        link=link,
        debug=False,
        call_helper=compile_result.call_helper,
        fastmath=fastmath,
    )
206+
207+
class AutoJitCSAKernel(object):
    """Specialize a CSA kernel on demand and cache one kernel per signature.

    Compiled kernels are kept in ``self.definitions``, keyed by the tuple of
    argument types they were compiled for.
    """

    def __init__(self, func_ir, bind, flags, targetoptions):
        """Store the IR, flags and options used for later compilation.

        *bind* is only stored; no binding is performed by this class.
        """
        self.func_ir = func_ir
        self.bind = bind
        self.flags = flags
        self.targetoptions = targetoptions
        self.typingctx = CSATargetDesc.typingctx
        self.definitions = {}

    def __call__(self, *args):
        """Specialize the kernel for *args* and invoke it with them.

        Raises
        ------
        NotImplementedError
            If the specialized kernel object is not callable.
        """
        kernel = self.specialize(*args)
        # The previous code called an undefined name here.  Dispatch to the
        # specialized kernel and fail with an explicit error if it cannot
        # be launched.
        if not callable(kernel):
            raise NotImplementedError(
                "launching %s objects is not supported"
                % type(kernel).__name__)
        kernel(*args)

    def specialize(self, *args):
        """Return the kernel compiled for the types of the given *args*."""
        argtypes = tuple(self.typingctx.resolve_argument_type(a)
                         for a in args)
        return self.compile(argtypes)

    def compile(self, sig):
        """Return the kernel for signature *sig*, compiling it if needed.

        The signature must not carry a return type.  On a cache miss,
        ``targetoptions`` gets a default empty ``link`` entry (if absent)
        and is forwarded as keyword arguments to ``compile_csa_kernel``.
        """
        argtypes, return_type = sigutils.normalize_signature(sig)
        assert return_type is None
        kernel = self.definitions.get(argtypes)
        if kernel is None:
            self.targetoptions.setdefault('link', ())
            kernel = compile_csa_kernel(self.func_ir, argtypes, self.flags,
                                        **self.targetoptions)
            self.definitions[argtypes] = kernel
        return kernel

    def inspect_llvm(self, signature=None):
        '''
        Return the result of ``inspect_llvm()`` of the kernel compiled for
        *signature*, or a dictionary mapping every signature encountered
        thus far to that result when *signature* is None.
        '''
        if signature is not None:
            return self.definitions[signature].inspect_llvm()
        return {sig: defn.inspect_llvm()
                for sig, defn in self.definitions.items()}

    def inspect_asm(self, signature=None):
        '''
        Return the generated assembly code for all signatures encountered
        thus far, or the assembly code for a specific signature if given.
        '''
        if signature is not None:
            return self.definitions[signature].inspect_asm()
        return {sig: defn.inspect_asm()
                for sig, defn in self.definitions.items()}

    def inspect_types(self, file=None):
        '''
        Call ``inspect_types(file=file)`` on every kernel compiled thus far.
        The output goes to *file*, or *sys.stdout* if *file* is *None*.
        '''
        if file is None:
            # Imported locally: this module does not import ``sys`` itself.
            import sys
            file = sys.stdout

        for defn in self.definitions.values():
            defn.inspect_types(file=file)

0 commit comments

Comments
 (0)