|
| 1 | +from numba import compiler, sigutils, types, dispatcher |
| 2 | +from numba.targets.descriptors import TargetDescriptor |
| 3 | +from numba.targets.options import TargetOptions |
| 4 | +from numba import typing, config |
| 5 | +from numba.targets.base import BaseContext |
| 6 | +from numba.targets import csa as tcsa |
| 7 | +from numba.targets import registry |
| 8 | +from numba.ir_utils import get_call_table |
| 9 | +from numba.npyufunc import dufunc |
| 10 | + |
| 11 | +import pdb |
| 12 | +import copy |
| 13 | + |
class CSATypingContext(typing.BaseContext):
    """Typing context used when compiling for the CSA target.

    CPU dispatchers that show up as values are swapped for a CSA device
    function before the base class resolves their type.
    """

    def resolve_value_type(self, val):
        """Return the type of *val*, converting CPU dispatchers first.

        When *val* is a ``dispatcher.Dispatcher`` it is replaced by the
        device function produced by ``jitdevice``.  That device function is
        cached on the dispatcher so it is only built once per dispatcher.

        Raises:
            ValueError: if the dispatcher has compilation disabled and no
                device function has been cached on it yet.
        """
        if isinstance(val, dispatcher.Dispatcher):
            try:
                # Inside this class body the attribute name is mangled by
                # Python to ``_CSATypingContext__csajitdevice``; the read here
                # and the write below therefore refer to the same attribute.
                val = val.__csajitdevice
            except AttributeError:
                if not val._can_compile:
                    raise ValueError("using cpu function on device but its compilation is disabled")
                # NOTE(review): ``jitdevice`` is not imported in the visible
                # part of this module -- confirm it is in scope, otherwise
                # this line raises NameError the first time it is reached.
                jd = jitdevice(val, debug=val.targetoptions.get('debug'))
                val.__csajitdevice = jd
                val = jd
        return super(CSATypingContext, self).resolve_value_type(val)

    def load_additional_registries(self):
        """Install the typing registries made available to CSA code."""
        # Imported lazily, as in the original, so the registries are only
        # loaded when the context asks for them.
        from numba.typing import (cffi_utils, cmathdecl, enumdecl, listdecl, mathdecl,
                                  npydecl, operatordecl, randomdecl, setdecl)
        self.install_registry(cffi_utils.registry)
        self.install_registry(cmathdecl.registry)
        self.install_registry(enumdecl.registry)
        self.install_registry(listdecl.registry)
        self.install_registry(mathdecl.registry)
        self.install_registry(npydecl.registry)
        self.install_registry(operatordecl.registry)
        self.install_registry(randomdecl.registry)
        self.install_registry(setdecl.registry)
| 42 | + |
class CSATargetOptions(TargetOptions):
    """Target options for the CSA target; no options are declared."""

    OPTIONS = dict()
| 45 | + |
class CSATargetDesc(TargetDescriptor):
    """Target descriptor bundling the CSA options and contexts.

    Both contexts are created once, when the class body is executed, and are
    shared through the class attributes.
    """

    options = CSATargetOptions
    typingctx = CSATypingContext()
    # The target context is constructed from the typing context created above.
    targetctx = tcsa.CSAContext(typingctx)
| 50 | + |
def compile_csa(func_ir, return_type, args, inflags):
    """Compile *func_ir* with the CSA target context and return the result.

    Args:
        func_ir: IR object; its ``blocks`` are scanned for calls and the
            object itself is handed to ``compiler.compile_ir``.
        return_type: return type forwarded to ``compiler.compile_ir``.
        args: argument types forwarded to ``compiler.compile_ir``.
        inflags: incoming flags; only ``inflags.noalias`` is read.

    Returns:
        The object returned by ``compiler.compile_ir``, with its library
        finalized.

    When ``config.DEBUG_CSA`` is set, progress is printed and the LLVM IR of
    the compiled library is written to ``compile_csa.ll`` in the current
    working directory.
    """
    if config.DEBUG_CSA:
        print("compile_csa", func_ir, return_type, args)

    # Typing uses the CPU target's typing context (not
    # CSATargetDesc.typingctx); code generation uses the CSA target context.
    cpu_target = registry.dispatcher_registry['cpu'].targetdescr
    typingctx = cpu_target.typing_context
    targetctx = CSATargetDesc.targetctx

    # Start from fresh flags; ``noalias`` is the only setting carried over
    # from the caller's flags.
    flags = compiler.Flags()
    if inflags.noalias:
        if config.DEBUG_CSA:
            print("Propagating noalias")
        flags.set('noalias')
    flags.set('no_compile')
    flags.set('no_cpython_wrapper')

    call_table, _ = get_call_table(func_ir.blocks)
    if config.DEBUG_CSA:
        print("call_table", call_table)

    # Entries that hold exactly one DUFunc get their code generation
    # installed on the CSA target context before compilation.
    for entry in call_table.values():
        if len(entry) != 1:
            continue
        callee = entry[0]
        if not isinstance(callee, dufunc.DUFunc):
            continue
        if config.DEBUG_CSA:
            print("Calling install_cg on", callee)
        callee._install_cg(targetctx)

    cres = compiler.compile_ir(typingctx, targetctx, func_ir, args,
                               return_type, flags, locals={})
    cres.library.finalize()

    if config.DEBUG_CSA:
        print("compile_csa cres", cres, type(cres))
        print("LLVM")

        llvm_str = cres.library.get_llvm_str()
        llvm_out = "compile_csa.ll"
        print(llvm_out, llvm_str)
        with open(llvm_out, "w") as llvm_file:
            llvm_file.write(llvm_str)

    return cres
| 101 | + |
| 102 | +#class CachedCSAAsm(object): |
| 103 | +# def __init__(self, name, llvmir, options): |
| 104 | +# print("CachedCSAAsm::__init__", name, llvmir, options) |
| 105 | +# self.name = name |
| 106 | +# self.llvmir = llvmir |
| 107 | +# self._extra_options = options.copy() |
| 108 | +# |
| 109 | +# def get(self): |
| 110 | +# print("CachedCSAAsm", self.name) |
| 111 | +# print(self.llvmir) |
| 112 | +# targetctx = CSATargetDesc.targetctx |
| 113 | +# print("targetctx", targetctx, type(targetctx)) |
| 114 | +# cg = targetctx.codegen() |
| 115 | +# print("cg", cg, type(cg)) |
| 116 | +## lib = cg.create_library(self.name) |
| 117 | +## csa_asm_name = self.name + '.csa_asm.s' |
| 118 | +## asm_str = lib.get_asm_str(csa_asm_name) |
| 119 | +# return csa_asm_name |
| 120 | +# |
| 121 | +#class CachedCSACUFunction(object): |
| 122 | +# def __init__(self, entry_name, csa_asm, linking): |
| 123 | +# self.entry_name = entry_name |
| 124 | +# self.csa_asm = csa_asm |
| 125 | +# self.linking = linking |
| 126 | +# |
| 127 | +# def get(self): |
| 128 | +# csa_asm = self.csa_asm.get() |
| 129 | +# return csa_asm |
| 130 | + |
class CSAKernel(object):
    """Holds the artifacts that describe one compiled CSA kernel."""

    def __init__(self, llvm_module, library, wrapper_module, kernel, wrapfnty,
                 name, pretty_name, argtypes, call_helper,
                 link=(), debug=False, fastmath=False, type_annotation=None):
        """Record the compiled pieces of a kernel on the instance.

        ``name`` is stored as ``entry_name``; ``argtypes`` and ``link`` are
        stored as tuples (``argument_types`` and ``linking``).
        ``pretty_name`` is accepted but not stored.
        """
        # The option set is assembled but not consumed here: the code that
        # used it is currently disabled.
        options = dict(debug=debug)
        if fastmath:
            options.update(ftz=True, prec_sqrt=False, prec_div=False, fma=True)

        # Compiled artifacts.
        self.llvm_module = llvm_module
        self.wrapper_module = wrapper_module
        self.library = library
        self.kernel = kernel
        self.wrapfnty = wrapfnty

        # Kernel metadata.
        self.entry_name = name
        self.argument_types = tuple(argtypes)
        self.linking = tuple(link)
        self._type_annotation = type_annotation
        self.debug = debug
        self.call_helper = call_helper

    def __repr__(self):
        """Return a multi-line dump of the main attributes and their types."""
        def describe(attr):
            value = getattr(self, attr)
            return "".join(["\nself.", attr, " ", str(value),
                            " type=", str(type(value))])

        # The kernel line has its own layout: value and type are printed
        # back to back on the line after the label.
        header = "CSAKernel object\nself.kernel\n" + str(self.kernel) + str(type(self.kernel))
        described = ("library", "wrapfnty", "entry_name", "argument_types")
        return header + "".join(describe(attr) for attr in described)
| 165 | + |
| 166 | +# def __call__(self, *args, **kwargs): |
| 167 | +# assert not kwargs |
| 168 | +# self._kernel_call(args=args, |
| 169 | +# griddim=self.griddim, |
| 170 | +# blockdim=self.blockdim, |
| 171 | +# stream=self.stream, |
| 172 | +# sharedmem=self.sharedmem) |
| 173 | + |
| 174 | +# def bind(self): |
| 175 | +# print("self._func", type(self._func)) |
| 176 | +# self._func.get() |
| 177 | + |
| 178 | +# @property |
| 179 | +# def csa_asm(self): |
| 180 | +# return self._func.csa_asm.get().decode('utf8') |
| 181 | + |
def compile_csa_kernel(func_ir, args, flags, link, fastmath=False):
    """Compile *func_ir* as a void CSA kernel and wrap it in a ``CSAKernel``.

    Args:
        func_ir: IR to compile; forwarded to ``compile_csa``.
        args: argument types of the kernel.
        flags: flags forwarded to ``compile_csa``.
        link: forwarded to ``CSAKernel`` as ``link``.
        fastmath: forwarded to ``CSAKernel`` as ``fastmath``.

    Returns:
        The ``CSAKernel`` built from the compile result and from what the
        target context's ``prepare_csa_kernel`` returns.
    """
    cres = compile_csa(func_ir, types.void, args, flags)
    fname = cres.fndesc.llvm_func_name
    prepared = cres.target_context.prepare_csa_kernel(
        cres.library, fname, cres.signature.args)
    lib, kernel, wrapfnty, wrapper_library = prepared

    if config.DEBUG_CSA:
        print("compile_csa_kernel", func_ir, args, link, fastmath)
        traced = (
            ("fname", fname),
            ("cres.library", cres.library),
            ("cres.signature", cres.signature),
            ("lib", lib),
            ("kernel", kernel),
            ("wrapfnty", wrapfnty),
        )
        for label, obj in traced:
            print(label, obj, type(obj))

    # ``debug`` is always passed as False here, whatever the caller's flags.
    return CSAKernel(
        llvm_module=lib._final_module,
        library=wrapper_library,
        wrapper_module=wrapper_library._final_module,
        kernel=kernel,
        wrapfnty=wrapfnty,
        name=kernel.name,
        pretty_name=cres.fndesc.qualname,
        argtypes=args,
        call_helper=cres.call_helper,
        link=link,
        debug=False,
        fastmath=fastmath,
        type_annotation=cres.type_annotation,
    )
| 206 | + |
class AutoJitCSAKernel(object):
    """Compile CSA kernels on demand, one per tuple of argument types.

    Compiled kernels are cached in ``self.definitions``, keyed by the
    argument-type tuple they were compiled for.
    """

    def __init__(self, func_ir, bind, flags, targetoptions):
        """Store the compilation inputs; nothing is compiled yet.

        ``bind`` is stored but not acted on by this class.
        """
        self.func_ir = func_ir
        self.bind = bind
        self.flags = flags
        self.targetoptions = targetoptions
        self.typingctx = CSATargetDesc.typingctx
        self.definitions = {}

    def __call__(self, *args):
        """Specialize for the types of *args*, then invoke the kernel."""
        kernel = self.specialize(*args)
        # The original called an undefined name (``cfg``) here, which raised
        # NameError on every call; invoke the specialized kernel instead.
        # NOTE(review): this requires the kernel object returned by
        # ``specialize`` to be callable -- confirm against the kernel class.
        kernel(*args)

    def specialize(self, *args):
        """Return the kernel compiled for the types of the given arguments."""
        argtypes = tuple([self.typingctx.resolve_argument_type(a) for a in args])
        return self.compile(argtypes)

    def compile(self, sig):
        """Return the kernel for *sig*, compiling and caching it if needed.

        The signature must not carry a return type.
        """
        argtypes, return_type = sigutils.normalize_signature(sig)
        assert return_type is None
        kernel = self.definitions.get(argtypes)
        if kernel is None:
            # ``compile_csa_kernel`` has a required ``link`` parameter, so a
            # default is recorded in the target options when it is missing.
            self.targetoptions.setdefault('link', ())
            kernel = compile_csa_kernel(self.func_ir, argtypes, self.flags,
                                        **self.targetoptions)
            self.definitions[argtypes] = kernel
        return kernel

    def inspect_llvm(self, signature=None):
        '''
        Return the LLVM IR for all signatures encountered thus far, or the
        LLVM IR for a specific signature if given.

        NOTE(review): relies on the cached kernel objects providing
        ``inspect_llvm()`` -- confirm against the kernel class.
        '''
        if signature is not None:
            return self.definitions[signature].inspect_llvm()
        return {sig: defn.inspect_llvm()
                for sig, defn in self.definitions.items()}

    def inspect_asm(self, signature=None):
        '''
        Return the generated assembly code for all signatures encountered thus
        far, or the assembly code for a specific signature if given.

        NOTE(review): relies on the cached kernel objects providing
        ``inspect_asm()`` -- confirm against the kernel class.
        '''
        if signature is not None:
            return self.definitions[signature].inspect_asm()
        return {sig: defn.inspect_asm()
                for sig, defn in self.definitions.items()}

    def inspect_types(self, file=None):
        '''
        Produce a dump of the Python source of this function annotated with the
        corresponding Numba IR and type information. The dump is written to
        *file*, or *sys.stdout* if *file* is *None*.
        '''
        if file is None:
            # Imported here because the module does not import ``sys`` at the
            # top level; the original referenced it without any import.
            import sys
            file = sys.stdout

        # Only the definitions are needed; the original iterated through an
        # un-imported ``utils.iteritems`` and ignored the keys.
        for defn in self.definitions.values():
            defn.inspect_types(file=file)
0 commit comments