-
Notifications
You must be signed in to change notification settings - Fork 231
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
API: Introducing complex (np.complex64/128) native support #2375
base: master
Are you sure you want to change the base?
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #2375 +/- ##
==========================================
- Coverage 86.74% 86.69% -0.06%
==========================================
Files 235 238 +3
Lines 44521 44688 +167
Branches 8242 8269 +27
==========================================
+ Hits 38619 38741 +122
- Misses 5177 5218 +41
- Partials 725 729 +4 ☔ View full report in Codecov by Sentry. |
3195b9e
to
6ac5b99
Compare
devito/passes/iet/misc.py
Outdated
""" | ||
Add headers for complex arithmetic | ||
""" | ||
if configuration['language'] == 'cuda': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could shorten this using:
headers = {'cuda': 'cuComplex.h', 'hip': 'hip/hip_complex.h'}
lib = headers.get(configuration['language'], 'complex.h')
devito/symbolics/printer.py
Outdated
dtype = self.dtype | ||
if np.issubdtype(dtype, np.complexfloating): | ||
func_name = 'c%s' % func_name | ||
dtype = self.dtype(0).real.dtype |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Newline between if blocks would improve readability
tests/test_operator.py
Outdated
@@ -640,6 +640,25 @@ def test_tensor(self, func1): | |||
op2 = Operator([Eq(f, f.dx) for f in f1.values()]) | |||
assert str(op1.ccode) == str(op2.ccode) | |||
|
|||
def test_complex(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Duplicate test?
if exp not in parameters + boilerplate: | ||
error("Missing parameter: %s" % exp) | ||
assert exp in parameters + boilerplate | ||
for expi in expected: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe ex in expected
?
0268781
to
2c80bf8
Compare
e7a2791
to
05f8528
Compare
a655632
to
7cee7fb
Compare
@@ -66,6 +67,23 @@ def test_maxpar_option(self): | |||
assert trees[0][0] is trees[1][0] | |||
assert trees[0][1] is not trees[1][1] | |||
|
|||
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens if you try to take derivatives of an expression containing the imaginary unit? Something like (sympy.I*u).dx
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sympy.I is just a sympy Atomic it's treated like any other symbol or number such as S.One or S.NegativeOne
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, I'm just inclined to add tests to check that things work the way they 'should' since I've been tripped up in the past
devito/passes/iet/misc.py
Outdated
@@ -270,3 +306,39 @@ def _rename_subdims(target, dimensions): | |||
return {d: d._rebuild(d.root.name) for d in dims | |||
if d.root not in dimensions | |||
and names.count(d.root.name) < 2} | |||
|
|||
|
|||
_stdcomplex_defs = """ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
imho this belongs to a complex.h
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually, given all the other comments (this is the final one I'm writing), you may as well move the entire complex number lowering machinery to a separate python module such as complex.py
within passes/iet/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
imho this belongs to a complex.h
This is a lot more robust to generate the header in the same dir as the the generated code and avoid having to infer path from the devito dir.
complex.py
That's fine
devito/passes/iet/misc.py
Outdated
@@ -192,6 +192,42 @@ def minimize_symbols(iet): | |||
return iet, {} | |||
|
|||
|
|||
@iet_pass | |||
def complex_include(iet, language, compiler): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
include_complex
devito/passes/iet/misc.py
Outdated
@iet_pass | ||
def complex_include(iet, language, compiler): | ||
""" | ||
Add headers for complex arithmetic |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
full stop
devito/arch/compiler.py
Outdated
@@ -243,6 +245,20 @@ def version(self): | |||
|
|||
return version | |||
|
|||
@property | |||
def _complex_ctype(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no, we definitely don't want code-generation-related machinery in our Compiler classes.
The right thing to do is, instead, single-dispatching the Compiler class within our own compilation pass, which is responsible for the lowering of complex
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
as such, I don't think we need to add a custom name
to Compiler?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't quite agree here, the complex type are defined by the actual compiler and their standard, i.e gnu has _Complex
and cpp has std::complex
, adding complicated dispatch is overkill for something that is standardized at the language level
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I disagree, singledispatch would achieve the same exact objective by dispatching based on the type.
Instead, what you've done here violates a crucial principle of the OO paradigm, that is, classes should have a well defined purpose. These classes are for jit-compiling a given string. They're not supposed to provide compiler-specific code generation (C- or C++ specific) information
adding complicated dispatch
I don't think it's complicated at all. An Iet_pass receives the compiler
and all you have to do is creating a series of functions based on single dispatch doing the same exact thing it's being done here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
g. They're not supposed to provide compiler-specific code generation (C- or C++ specific) information
But that's not what it is, this defines the standard associated with the compiler which is c99->_Complex, c++11->std:complex
adding a pass that move the standard out of the compiler doesn't really make sense.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't say (I hope) to add a pass. A single-dispatch function to retrieve some sort of type-specific information doesn't have to be a compiler pass.
But obviously our compiler pass would use it to get the code it needs
@@ -92,8 +92,12 @@ def initialize(cls): | |||
return | |||
|
|||
def alloc(self, shape, dtype, padding=0): | |||
datasize = int(reduce(mul, shape)) | |||
ctype = dtype_to_ctype(dtype) | |||
# For complex number, allocate double the size of its real/imaginary part |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
potentially useful elsewhere, so I'd move it into a function inside devito/tools/dtypes_lowering
maybe?
devito/operator/operator.py
Outdated
@@ -460,6 +460,12 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): | |||
|
|||
# Lower IET to a target-specific IET | |||
graph = Graph(iet, **kwargs) | |||
|
|||
# Complex header if needed. Needs to be done before specialization | |||
# as some specific cases require complex to be loaded first |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for instance?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FFTW requires complex.h
to be loaded first so that it's the type used rather than fftw_complex
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
by "loaded first" you mean that the header file should stay at the very top of the includes list?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It doesn't really matter right now but might later
devito/passes/iet/misc.py
Outdated
|
||
# For (cpp), need to define constant _Complex_I and missing mix-type | ||
# std::complex arithmetic | ||
if compiler._cpp: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if not ... : return
devito/tools/dtypes_lowering.py
Outdated
if np.issubdtype(dtype, np.complexfloating): | ||
rtype = dtype(0).real.__class__ | ||
from devito import configuration | ||
make = configuration['compiler']._complex_ctype |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you really can't use global information here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not global because this is called within the switchconfig
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from that code path above yes, but from another one?
we shouldn't access configuration
in these remote places
@@ -301,7 +313,8 @@ def infer_dtype(dtypes): | |||
# Resolve the vector types, if any | |||
dtypes = {dtypes_vector_mapper.get_base_dtype(i, i) for i in dtypes} | |||
|
|||
fdtypes = {i for i in dtypes if np.issubdtype(i, np.floating)} | |||
fdtypes = {i for i in dtypes if np.issubdtype(i, np.floating) or | |||
np.issubdtype(i, np.complexfloating)} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
really isn't np.issubdtype(i, (np.floating, np.complexfloating))
supported?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nope :(
devito/symbolics/extended_sympy.py
Outdated
|
||
@property | ||
def _base_typ(self): | ||
return configuration['compiler']._complex_ctype('float') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can't use global objects here
What we should do instead: leave CFLOAT generic
Extend the existing compiler pass to lower CFLOAT into something more specific such as CFLOAT_GCC
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again, with the cgen visitor now always using the local config from Oeprator this is never called with a global config.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it doesn't matter, it's conceptually wrong
you're assuming you only go through _base_typ
via that visitor, but who/what imposes that?
this is basically just a workaround to avoid a more graceful lowering, which you can do as explained in my first message
3f0b8e2
to
fdf1b36
Compare
ac6d213
to
6b2b908
Compare
@@ -189,6 +192,8 @@ def _(expr, estimate, seen): | |||
flops, flags = _estimate_cost.registry[object](expr, estimate, seen) | |||
if {S.One, S.NegativeOne}.intersection(expr.args): | |||
flops -= 1 | |||
if ImaginaryUnit in expr.args: | |||
flops *= 2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For all ops *=2 ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Complex arithmetic cost 2 FMA
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, is it guranteed that all args in an expr will be Imaginary?
Am I seeing sth wrong?
@@ -180,6 +180,8 @@ def __init__(self): | |||
_cpp = False | |||
|
|||
def __init__(self, **kwargs): | |||
self._name = kwargs.pop('name', self.__class__.__name__) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think you need the name
anymore no?
c_scale = 2 if np.issubdtype(dtype, np.complexfloating) else 1 | ||
|
||
datasize = int(reduce(mul, shape) * c_scale) | ||
ctype = dtype_to_ctype(alloc_dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you should pass in dtype
and extend dtype_to_ctype
figure out the rest
alloc_dtype = dtype(0).real.__class__ | ||
c_scale = 2 if np.issubdtype(dtype, np.complexfloating) else 1 | ||
|
||
datasize = int(reduce(mul, shape) * c_scale) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
potentially PITA comment...
An observation: for complex we implement SoA, while for bundles we implement AoS
it's true we don't support it yet, but one day we might want to be able to allocate bundles-like Functions (hence AoS) in user-land
so, long story short, here I'd have
datasize = infer_datasize(shape, dtype)
where for now you just put in the logic above
as I said, this is probably a nitpick, so feel free to ignore
@@ -68,7 +67,7 @@ def grid(self): | |||
|
|||
@cached_property | |||
def dtype(self): | |||
dtypes = {f.dtype for f in self.find(Indexed)} - {None} | |||
dtypes = {f.dtype for f in self._functions} - {None} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
.find is so expensive! Good we can get rid of it
@@ -465,6 +465,8 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): | |||
|
|||
# Lower IET to a target-specific IET | |||
graph = Graph(iet, **kwargs) | |||
|
|||
# Specialize |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
irrelevant
@@ -0,0 +1,123 @@ | |||
import numpy as np |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure honestly about this file since after all it's still all about symbolic objects
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was annoying me to find it in extended_sympy with all the rest :D can revert
@@ -189,6 +192,8 @@ def _(expr, estimate, seen): | |||
flops, flags = _estimate_cost.registry[object](expr, estimate, seen) | |||
if {S.One, S.NegativeOne}.intersection(expr.args): | |||
flops -= 1 | |||
if ImaginaryUnit in expr.args: | |||
flops *= 2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
?
devito/tools/dtypes_lowering.py
Outdated
|
||
def __call__(self, val): | ||
return self.nptype(val) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one more blank line here before the comment
devito/tools/dtypes_lowering.py
Outdated
@@ -123,6 +123,18 @@ def __repr__(self): | |||
__str__ = __repr__ | |||
|
|||
|
|||
class CustomNpType(CustomDtype): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we actually need a subclass? Might just be overspecialization
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, basically need to be able to call dtype(0)
from it for the printer/real dtype inferance
@@ -1170,7 +1176,8 @@ def bound_symbols(self): | |||
@cached_property | |||
def indexed(self): | |||
"""The wrapped IndexedData object.""" | |||
return IndexedData(self.name, shape=self._shape, function=self.function) | |||
return IndexedData(self.name, shape=self._shape, function=self.function, | |||
dtype=self.dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this should be carried by self.function
already... how can it be different?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is, can leave it out since it will use function.dtype
if not provided so doesn't really matter
Support complex float data type.