Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API: Introducing complex (np.complex64/128) native support #2375

Open
wants to merge 11 commits into
base: master
Choose a base branch
from

Conversation

mloubout
Copy link
Contributor

Support complex float data type.

@mloubout mloubout added API api (symbolics, types, ...) feature-request labels May 26, 2024
Copy link

codecov bot commented May 26, 2024

Codecov Report

Attention: Patch coverage is 82.52033% with 43 lines in your changes missing coverage. Please review.

Project coverage is 86.69%. Comparing base (156db8e) to head (94d5571).
Report is 69 commits behind head on master.

Files Patch % Lines
tests/test_gpu_common.py 7.69% 12 Missing ⚠️
devito/symbolics/printer.py 61.53% 7 Missing and 3 partials ⚠️
devito/passes/iet/dtypes.py 73.33% 6 Missing and 2 partials ⚠️
devito/symbolics/inspection.py 81.81% 1 Missing and 1 partial ⚠️
devito/tools/dtypes_lowering.py 33.33% 1 Missing and 1 partial ⚠️
devito/types/basic.py 88.23% 1 Missing and 1 partial ⚠️
tests/test_operator.py 87.50% 1 Missing and 1 partial ⚠️
devito/arch/compiler.py 75.00% 1 Missing ⚠️
devito/finite_differences/differentiable.py 50.00% 1 Missing ⚠️
devito/ir/iet/visitors.py 95.23% 1 Missing ⚠️
... and 2 more
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2375      +/-   ##
==========================================
- Coverage   86.74%   86.69%   -0.06%     
==========================================
  Files         235      238       +3     
  Lines       44521    44688     +167     
  Branches     8242     8269      +27     
==========================================
+ Hits        38619    38741     +122     
- Misses       5177     5218      +41     
- Partials      725      729       +4     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@mloubout mloubout force-pushed the complex branch 4 times, most recently from 3195b9e to 6ac5b99 Compare May 27, 2024 16:52
devito/operator/operator.py Outdated Show resolved Hide resolved
"""
Add headers for complex arithmetic
"""
if configuration['language'] == 'cuda':
Copy link
Contributor

@EdCaunt EdCaunt May 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could shorten this using:

headers = {'cuda': 'cuComplex.h', 'hip': 'hip/hip_complex.h'}
lib = headers.get(configuration['language'], 'complex.h')

devito/passes/iet/misc.py Outdated Show resolved Hide resolved
dtype = self.dtype
if np.issubdtype(dtype, np.complexfloating):
func_name = 'c%s' % func_name
dtype = self.dtype(0).real.dtype
Copy link
Contributor

@EdCaunt EdCaunt May 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Newline between if blocks would improve readability

tests/test_gpu_common.py Outdated Show resolved Hide resolved
@@ -640,6 +640,25 @@ def test_tensor(self, func1):
op2 = Operator([Eq(f, f.dx) for f in f1.values()])
assert str(op1.ccode) == str(op2.ccode)

def test_complex(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicate test?

if exp not in parameters + boilerplate:
error("Missing parameter: %s" % exp)
assert exp in parameters + boilerplate
for expi in expected:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe ex in expected?

@mloubout mloubout force-pushed the complex branch 4 times, most recently from 0268781 to 2c80bf8 Compare May 28, 2024 17:16
devito/ir/iet/visitors.py Outdated Show resolved Hide resolved
devito/passes/iet/misc.py Outdated Show resolved Hide resolved
@mloubout mloubout force-pushed the complex branch 3 times, most recently from e7a2791 to 05f8528 Compare May 30, 2024 18:19
devito/arch/compiler.py Outdated Show resolved Hide resolved
devito/tools/dtypes_lowering.py Outdated Show resolved Hide resolved
tests/test_gpu_common.py Outdated Show resolved Hide resolved
tests/test_gpu_common.py Outdated Show resolved Hide resolved
tests/test_operator.py Show resolved Hide resolved
tests/test_operator.py Outdated Show resolved Hide resolved
@mloubout mloubout force-pushed the complex branch 3 times, most recently from a655632 to 7cee7fb Compare May 31, 2024 15:10
@@ -66,6 +67,23 @@ def test_maxpar_option(self):
assert trees[0][0] is trees[1][0]
assert trees[0][1] is not trees[1][1]

@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if you try to take derivatives of an expression containing the imaginary unit? Something like (sympy.I*u).dx?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sympy.I is just a sympy Atomic it's treated like any other symbol or number such as S.One or S.NegativeOne

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I'm just inclined to add tests to check that things work the way they 'should' since I've been tripped up in the past

@@ -270,3 +306,39 @@ def _rename_subdims(target, dimensions):
return {d: d._rebuild(d.root.name) for d in dims
if d.root not in dimensions
and names.count(d.root.name) < 2}


_stdcomplex_defs = """
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

imho this belongs to a complex.h

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually, given all the other comments (this is the final one I'm writing), you may as well move the entire complex number lowering machinery to a separate python module such as complex.py within passes/iet/

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

imho this belongs to a complex.h

This is a lot more robust to generate the header in the same dir as the the generated code and avoid having to infer path from the devito dir.

complex.py

That's fine

@@ -192,6 +192,42 @@ def minimize_symbols(iet):
return iet, {}


@iet_pass
def complex_include(iet, language, compiler):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

include_complex

@iet_pass
def complex_include(iet, language, compiler):
"""
Add headers for complex arithmetic
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

full stop

@@ -243,6 +245,20 @@ def version(self):

return version

@property
def _complex_ctype(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, we definitely don't want code-generation-related machinery in our Compiler classes.

The right thing to do is, instead, single-dispatching the Compiler class within our own compilation pass, which is responsible for the lowering of complex

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as such, I don't think we need to add a custom name to Compiler?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite agree here, the complex type are defined by the actual compiler and their standard, i.e gnu has _Complex and cpp has std::complex, adding complicated dispatch is overkill for something that is standardized at the language level

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I disagree, singledispatch would achieve the same exact objective by dispatching based on the type.

Instead, what you've done here violates a crucial principle of the OO paradigm, that is, classes should have a well defined purpose. These classes are for jit-compiling a given string. They're not supposed to provide compiler-specific code generation (C- or C++ specific) information

adding complicated dispatch

I don't think it's complicated at all. An Iet_pass receives the compiler and all you have to do is creating a series of functions based on single dispatch doing the same exact thing it's being done here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

g. They're not supposed to provide compiler-specific code generation (C- or C++ specific) information

But that's not what it is, this defines the standard associated with the compiler which is c99->_Complex, c++11->std:complex

adding a pass that move the standard out of the compiler doesn't really make sense.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't say (I hope) to add a pass. A single-dispatch function to retrieve some sort of type-specific information doesn't have to be a compiler pass.

But obviously our compiler pass would use it to get the code it needs

@@ -92,8 +92,12 @@ def initialize(cls):
return

def alloc(self, shape, dtype, padding=0):
datasize = int(reduce(mul, shape))
ctype = dtype_to_ctype(dtype)
# For complex number, allocate double the size of its real/imaginary part
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

potentially useful elsewhere, so I'd move it into a function inside devito/tools/dtypes_lowering maybe?

@@ -460,6 +460,12 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs):

# Lower IET to a target-specific IET
graph = Graph(iet, **kwargs)

# Complex header if needed. Needs to be done before specialization
# as some specific cases require complex to be loaded first
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for instance?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FFTW requires complex.h to be loaded first so that it's the type used rather than fftw_complex

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

by "loaded first" you mean that the header file should stay at the very top of the includes list?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't really matter right now but might later


# For (cpp), need to define constant _Complex_I and missing mix-type
# std::complex arithmetic
if compiler._cpp:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if not ... : return

if np.issubdtype(dtype, np.complexfloating):
rtype = dtype(0).real.__class__
from devito import configuration
make = configuration['compiler']._complex_ctype
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you really can't use global information here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not global because this is called within the switchconfig

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from that code path above yes, but from another one?

we shouldn't access configuration in these remote places

@@ -301,7 +313,8 @@ def infer_dtype(dtypes):
# Resolve the vector types, if any
dtypes = {dtypes_vector_mapper.get_base_dtype(i, i) for i in dtypes}

fdtypes = {i for i in dtypes if np.issubdtype(i, np.floating)}
fdtypes = {i for i in dtypes if np.issubdtype(i, np.floating) or
np.issubdtype(i, np.complexfloating)}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

really isn't np.issubdtype(i, (np.floating, np.complexfloating)) supported?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope :(


@property
def _base_typ(self):
return configuration['compiler']._complex_ctype('float')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can't use global objects here

What we should do instead: leave CFLOAT generic

Extend the existing compiler pass to lower CFLOAT into something more specific such as CFLOAT_GCC

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, with the cgen visitor now always using the local config from Oeprator this is never called with a global config.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it doesn't matter, it's conceptually wrong

you're assuming you only go through _base_typ via that visitor, but who/what imposes that?

this is basically just a workaround to avoid a more graceful lowering, which you can do as explained in my first message

@mloubout mloubout force-pushed the complex branch 2 times, most recently from 3f0b8e2 to fdf1b36 Compare June 21, 2024 18:21
@mloubout mloubout force-pushed the complex branch 11 times, most recently from ac6d213 to 6b2b908 Compare June 26, 2024 20:29
@@ -189,6 +192,8 @@ def _(expr, estimate, seen):
flops, flags = _estimate_cost.registry[object](expr, estimate, seen)
if {S.One, S.NegativeOne}.intersection(expr.args):
flops -= 1
if ImaginaryUnit in expr.args:
flops *= 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For all ops *=2 ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Complex arithmetic cost 2 FMA

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, is it guranteed that all args in an expr will be Imaginary?
Am I seeing sth wrong?

@@ -180,6 +180,8 @@ def __init__(self):
_cpp = False

def __init__(self, **kwargs):
self._name = kwargs.pop('name', self.__class__.__name__)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need the name anymore no?

c_scale = 2 if np.issubdtype(dtype, np.complexfloating) else 1

datasize = int(reduce(mul, shape) * c_scale)
ctype = dtype_to_ctype(alloc_dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should pass in dtype and extend dtype_to_ctype figure out the rest

alloc_dtype = dtype(0).real.__class__
c_scale = 2 if np.issubdtype(dtype, np.complexfloating) else 1

datasize = int(reduce(mul, shape) * c_scale)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

potentially PITA comment...

An observation: for complex we implement SoA, while for bundles we implement AoS

it's true we don't support it yet, but one day we might want to be able to allocate bundles-like Functions (hence AoS) in user-land

so, long story short, here I'd have

datasize  = infer_datasize(shape, dtype)

where for now you just put in the logic above

as I said, this is probably a nitpick, so feel free to ignore

@@ -68,7 +67,7 @@ def grid(self):

@cached_property
def dtype(self):
dtypes = {f.dtype for f in self.find(Indexed)} - {None}
dtypes = {f.dtype for f in self._functions} - {None}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.find is so expensive! Good we can get rid of it

@@ -465,6 +465,8 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs):

# Lower IET to a target-specific IET
graph = Graph(iet, **kwargs)

# Specialize
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

irrelevant

@@ -0,0 +1,123 @@
import numpy as np
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure honestly about this file since after all it's still all about symbolic objects

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was annoying me to find it in extended_sympy with all the rest :D can revert

@@ -189,6 +192,8 @@ def _(expr, estimate, seen):
flops, flags = _estimate_cost.registry[object](expr, estimate, seen)
if {S.One, S.NegativeOne}.intersection(expr.args):
flops -= 1
if ImaginaryUnit in expr.args:
flops *= 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?


def __call__(self, val):
return self.nptype(val)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one more blank line here before the comment

@@ -123,6 +123,18 @@ def __repr__(self):
__str__ = __repr__


class CustomNpType(CustomDtype):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we actually need a subclass? Might just be overspecialization

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, basically need to be able to call dtype(0) from it for the printer/real dtype inferance

@@ -1170,7 +1176,8 @@ def bound_symbols(self):
@cached_property
def indexed(self):
"""The wrapped IndexedData object."""
return IndexedData(self.name, shape=self._shape, function=self.function)
return IndexedData(self.name, shape=self._shape, function=self.function,
dtype=self.dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be carried by self.function already... how can it be different?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is, can leave it out since it will use function.dtype if not provided so doesn't really matter

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
API api (symbolics, types, ...) compiler feature-request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants