This repository was archived by the owner on Jan 25, 2023. It is now read-only.
forked from numba/numba
Semantics "with dppl_context" #40
Draft
1e-to wants to merge 2 commits into pydppl from feature/new_with_semantics
@@ -0,0 +1,22 @@
from numba.core import dispatcher, compiler
from numba.core.registry import cpu_target, dispatcher_registry
import numba.dppl_config as dppl_config
from numba.dppl.compiler import DPPLCompiler


class GPUDispatcher(dispatcher.Dispatcher):
    targetdescr = cpu_target

    def __init__(self, py_func, locals={}, targetoptions={}, impl_kind='direct', pipeline_class=compiler.Compiler):
        if dppl_config.dppl_present:
            dispatcher.Dispatcher.__init__(self, py_func, locals=locals,
                                           targetoptions=targetoptions, impl_kind=impl_kind,
                                           pipeline_class=DPPLCompiler)
        else:
            print("---------------------------------------------------------------------")
            print("WARNING : DPPL pipeline ignored. Ensure OpenCL drivers are installed.")
            print("---------------------------------------------------------------------")
            dispatcher.Dispatcher.__init__(self, py_func, locals=locals,
                                           targetoptions=targetoptions, impl_kind=impl_kind,
                                           pipeline_class=pipeline_class)


dispatcher_registry['gpu'] = GPUDispatcher
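For orientation (not part of the diff): registering GPUDispatcher under the 'gpu' key is what lets target="gpu" reach the DPPL pipeline, which is exactly how the tests in this PR invoke it. A minimal sketch, assuming the dppl backend and OpenCL drivers are available; the function name and array below are illustrative only:

import numpy as np
from numba import njit

# Illustrative sketch: target="gpu" resolves to GPUDispatcher via the
# dispatcher_registry entry above. With dppl present it compiles through
# DPPLCompiler; otherwise it prints the warning and falls back to the
# default pipeline.
@njit(target="gpu")
def add_one(a):
    return a + 1

print(add_one(np.arange(4)))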
@@ -0,0 +1,120 @@
from numba.dppl.testing import unittest
from numba.dppl.testing import DPPLTestCase
from numba.dppl.withcontexts import dppl_context
from numba.core import typing, cpu
from numba.core.compiler import compile_ir, DEFAULT_FLAGS
from numba.core.transforms import with_lifting
from numba.core.registry import cpu_target
from numba.core.bytecode import FunctionIdentity, ByteCode
from numba.core.interpreter import Interpreter
from numba.tests.support import captured_stdout
from numba import njit, prange
import numpy as np


def get_func_ir(func):
    func_id = FunctionIdentity.from_function(func)
    bc = ByteCode(func_id=func_id)
    interp = Interpreter(func_id)
    func_ir = interp.interpret(bc)
    return func_ir


def liftcall1():
    x = 1
    print("A", x)
    with dppl_context:
        x += 1
    print("B", x)
    return x


def liftcall2():
    x = 1
    print("A", x)
    with dppl_context:
        x += 1
    print("B", x)
    with dppl_context:
        x += 10
    print("C", x)
    return x


def liftcall3():
    x = 1
    print("A", x)
    with dppl_context:
        if x > 0:
            x += 1
    print("B", x)
    with dppl_context:
        for i in range(10):
            x += i
    print("C", x)
    return x


class BaseTestWithLifting(DPPLTestCase):
    def setUp(self):
        super(BaseTestWithLifting, self).setUp()
        self.typingctx = typing.Context()
        self.targetctx = cpu.CPUContext(self.typingctx)
        self.flags = DEFAULT_FLAGS

    def check_extracted_with(self, func, expect_count, expected_stdout):
        the_ir = get_func_ir(func)
        new_ir, extracted = with_lifting(
            the_ir, self.typingctx, self.targetctx, self.flags,
            locals={},
        )
        self.assertEqual(len(extracted), expect_count)
        cres = self.compile_ir(new_ir)

        with captured_stdout() as out:
            cres.entry_point()

        self.assertEqual(out.getvalue(), expected_stdout)

    def compile_ir(self, the_ir, args=(), return_type=None):
        typingctx = self.typingctx
        targetctx = self.targetctx
        flags = self.flags
        # Register the contexts in case for nested @jit or @overload calls
        with cpu_target.nested_context(typingctx, targetctx):
            return compile_ir(typingctx, targetctx, the_ir, args,
                              return_type, flags, locals={})


class TestLiftCall(BaseTestWithLifting):

    def check_same_semantic(self, func):
        """Ensure same semantic with non-jitted code
        """
        jitted = njit(target="gpu")(func)
        with captured_stdout() as got:
            jitted()

        with captured_stdout() as expect:
            func()

        self.assertEqual(got.getvalue(), expect.getvalue())

    def test_liftcall1(self):
        self.check_extracted_with(liftcall1, expect_count=1,
                                  expected_stdout="A 1\nB 2\n")
        self.check_same_semantic(liftcall1)

    def test_liftcall2(self):
        self.check_extracted_with(liftcall2, expect_count=2,
                                  expected_stdout="A 1\nB 2\nC 12\n")
        self.check_same_semantic(liftcall2)

    def test_liftcall3(self):
        self.check_extracted_with(liftcall3, expect_count=2,
                                  expected_stdout="A 1\nB 2\nC 47\n")
        self.check_same_semantic(liftcall3)


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,144 @@
from numba.core import compiler, typing, types, sigutils
from numba.core.compiler_lock import global_compiler_lock
from numba.core.dispatcher import _DispatcherBase
from numba.core.transforms import find_region_inout_vars
from numba.core.withcontexts import (WithContext, _mutate_with_block_callee,
                                     _mutate_with_block_caller, _clear_blocks)
from numba.dppl.compiler import DPPLCompiler
from numba.core.cpu_options import ParallelOptions


class _DPPLContextType(WithContext):
Review comment: Can you please explain what these functions are doing?
Reply: This is a class that creates a separate dispatcher for the lifted code with our semantics. This is needed to run the code with the 'offload': True option and our pipeline (see the usage sketch after this file).
    def mutate_with_body(self, func_ir, blocks, blk_start, blk_end,
                         body_blocks, dispatcher_factory, extra):
        assert extra is None
        vlt = func_ir.variable_lifetime

        inputs, outputs = find_region_inout_vars(
            blocks=blocks,
            livemap=vlt.livemap,
            callfrom=blk_start,
            returnto=blk_end,
            body_block_ids=set(body_blocks),
        )

        lifted_blks = {k: blocks[k] for k in body_blocks}
        _mutate_with_block_callee(lifted_blks, blk_start, blk_end,
                                  inputs, outputs)

        # XXX: transform body-blocks to return the output variables
        lifted_ir = func_ir.derive(
            blocks=lifted_blks,
            arg_names=tuple(inputs),
            arg_count=len(inputs),
            force_non_generator=True,
        )

        dispatcher = dispatcher_factory(lifted_ir, dppl_mode=True)

        newblk = _mutate_with_block_caller(
            dispatcher, blocks, blk_start, blk_end, inputs, outputs,
        )

        blocks[blk_start] = newblk
        _clear_blocks(blocks, body_blocks)
        return dispatcher


class DPPLLiftedCode(_DispatcherBase):
    """
    Implementation of the hidden dispatcher objects used for lifted code
    (a lifted loop is really compiled as a separate function).
    """
    _fold_args = False

    def __init__(self, func_ir, typingctx, targetctx, flags, locals):
        self.func_ir = func_ir
        self.lifted_from = None

        self.typingctx = typingctx
        self.targetctx = targetctx
        self.flags = flags
        self.locals = locals

        _DispatcherBase.__init__(self, self.func_ir.arg_count,
                                 self.func_ir.func_id.func,
                                 self.func_ir.func_id.pysig,
                                 can_fallback=True,
                                 exact_match_required=False)

    def get_source_location(self):
        """Return the starting line number of the loop.
        """
        return self.func_ir.loc.line

    def _pre_compile(self, args, return_type, flags):
        """Pre-compile actions
        """
        pass

    @global_compiler_lock
    def compile(self, sig):
        # Use counter to track recursion compilation depth
        with self._compiling_counter:
            # XXX this is mostly duplicated from Dispatcher.
            flags = self.flags
            args, return_type = sigutils.normalize_signature(sig)

            # Don't recompile if signature already exists
            # (e.g. if another thread compiled it before we got the lock)
            existing = self.overloads.get(tuple(args))
            if existing is not None:
                return existing.entry_point

            self._pre_compile(args, return_type, flags)

            # Clone IR to avoid (some of the) mutation in the rewrite pass
            cloned_func_ir = self.func_ir.copy()

            flags.auto_parallel = ParallelOptions({'offload': True})
            cres = compiler.compile_ir(typingctx=self.typingctx,
                                       targetctx=self.targetctx,
                                       func_ir=cloned_func_ir,
                                       args=args, return_type=return_type,
                                       flags=flags, locals=self.locals,
                                       lifted=(),
                                       lifted_from=self.lifted_from,
                                       is_lifted_loop=True,
                                       pipeline_class=DPPLCompiler)

            # Check typing error if object mode is used
            if cres.typing_error is not None and not flags.enable_pyobject:
                raise cres.typing_error
            self.add_overload(cres)
            return cres.entry_point


class DPPLLiftedWith(DPPLLiftedCode):
    @property
    def _numba_type_(self):
        return types.Dispatcher(self)

    def get_call_template(self, args, kws):
        """
        Get a typing.ConcreteTemplate for this dispatcher and the given
        *args* and *kws* types. This enables the resolving of the return type.

        A (template, pysig, args, kws) tuple is returned.
        """
        # Ensure an overload is available
        if self._can_compile:
            self.compile(tuple(args))

        pysig = None
        # Create function type for typing
        func_name = self.py_func.__name__
        name = "CallTemplate({0})".format(func_name)
        # The `key` isn't really used except for diagnosis here,
        # so avoid keeping a reference to `cfunc`.
        call_template = typing.make_concrete_template(
            name, key=func_name, signatures=self.nopython_signatures)
        return call_template, pysig, args, kws


dppl_context = _DPPLContextType()
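For reference, this is how dppl_context is meant to be used from user code (mirroring liftcall1 in the test file above): the body of the with block is lifted into its own function IR by with_lifting and compiled through DPPLLiftedWith, while the surrounding code goes through the normal pipeline. A minimal sketch, assuming an environment where dppl_config.dppl_present is True:

from numba import njit
from numba.dppl.withcontexts import dppl_context

@njit(target="gpu")
def lifted_example():
    x = 1
    # The block below is extracted by with_lifting; mutate_with_body hands
    # the lifted IR to a DPPLLiftedWith dispatcher, which compiles it with
    # DPPLCompiler and ParallelOptions({'offload': True}).
    with dppl_context:
        x += 1
    return x

print(lifted_example())  # expected: 2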
Review comment: What is it actually doing?
Reply: This is necessary to create a dispatcher specifically for lifted code with the new semantics. For now this is an intermediate solution; in the future it will need to be rewritten so that the numba files themselves do not have to be changed.
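To make that reply concrete: the dispatcher_factory that mutate_with_body calls with dppl_mode=True has to be able to hand back a DPPLLiftedWith instead of numba's stock LiftedWith. The actual wiring lives in the modified numba core files and is not shown in this diff; the sketch below is only a hypothetical illustration of such a factory, and the helper name make_dispatcher_factory is invented for the example.

from numba.core.dispatcher import LiftedWith
from numba.dppl.withcontexts import DPPLLiftedWith

def make_dispatcher_factory(typingctx, targetctx, flags, locals):
    # Hypothetical helper (not part of this PR): builds the factory that
    # with-lifting passes into mutate_with_body. dppl_mode=True selects the
    # DPPL-aware dispatcher defined above; otherwise the stock LiftedWith
    # is used.
    def factory(lifted_ir, dppl_mode=False):
        cls = DPPLLiftedWith if dppl_mode else LiftedWith
        return cls(lifted_ir, typingctx, targetctx, flags, locals)
    return factory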