-
Notifications
You must be signed in to change notification settings - Fork 231
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
18 changed files
with
336 additions
and
175 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import numpy as np | ||
import ctypes | ||
|
||
from devito.ir import FindSymbols, Uxreplace | ||
|
||
__all__ = ['lower_complex'] | ||
|
||
|
||
def lower_complex(iet, lang, compiler): | ||
""" | ||
Add headers for complex arithmetic | ||
""" | ||
# Check if there is complex numbers that always take dtype precedence | ||
types = {f.dtype for f in FindSymbols().visit(iet) | ||
if not issubclass(f.dtype, ctypes._Pointer)} | ||
|
||
if not any(np.issubdtype(d, np.complexfloating) for d in types): | ||
return iet, {} | ||
|
||
lib = (lang['header-complex'],) | ||
headers = lang.get('I-def') | ||
|
||
# Some languges such as c++11 need some extra arithmetic definitions | ||
if lang.get('def-complex'): | ||
dest = compiler.get_jit_dir() | ||
hfile = dest.joinpath('complex_arith.h') | ||
with open(str(hfile), 'w') as ff: | ||
ff.write(str(lang['def-complex'])) | ||
lib += (str(hfile),) | ||
|
||
iet = _complex_dtypes(iet, lang) | ||
|
||
return iet, {'includes': lib, 'headers': headers} | ||
|
||
|
||
def _complex_dtypes(iet, lang): | ||
""" | ||
Lower dtypes to language specific types | ||
""" | ||
mapper = {} | ||
|
||
for s in FindSymbols('indexeds').visit(iet): | ||
if s.dtype in lang['types']: | ||
mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) | ||
|
||
for s in FindSymbols().visit(iet): | ||
if s.dtype in lang['types']: | ||
mapper[s] = s._rebuild(dtype=lang['types'][s.dtype]) | ||
|
||
body = Uxreplace(mapper).visit(iet.body) | ||
params = Uxreplace(mapper).visit(iet.parameters) | ||
iet = iet._rebuild(body=body, parameters=params) | ||
|
||
return iet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import numpy as np | ||
|
||
from devito.ir import Call | ||
from devito.passes.iet.langbase import LangBB | ||
from devito.tools import CustomNpType | ||
|
||
__all__ = ['CXXBB'] | ||
|
||
|
||
std_arith = """ | ||
#include <complex> | ||
template<typename _Tp, typename _Ti> | ||
std::complex<_Tp> operator * (const _Ti & a, const std::complex<_Tp> & b){ | ||
return std::complex<_Tp>(b.real() * a, b.imag() * a); | ||
} | ||
template<typename _Tp, typename _Ti> | ||
std::complex<_Tp> operator * (const std::complex<_Tp> & b, const _Ti & a){ | ||
return std::complex<_Tp>(b.real() * a, b.imag() * a); | ||
} | ||
template<typename _Tp, typename _Ti> | ||
std::complex<_Tp> operator / (const _Ti & a, const std::complex<_Tp> & b){ | ||
_Tp denom = b.real() * b.real () + b.imag() * b.imag() | ||
return std::complex<_Tp>(b.real() * a / denom, - b.imag() * a / denom); | ||
} | ||
template<typename _Tp, typename _Ti> | ||
std::complex<_Tp> operator / (const std::complex<_Tp> & b, const _Ti & a){ | ||
return std::complex<_Tp>(b.real() / a, b.imag() / a); | ||
} | ||
template<typename _Tp, typename _Ti> | ||
std::complex<_Tp> operator + (const _Ti & a, const std::complex<_Tp> & b){ | ||
return std::complex<_Tp>(b.real() + a, b.imag()); | ||
} | ||
template<typename _Tp, typename _Ti> | ||
std::complex<_Tp> operator + (const std::complex<_Tp> & b, const _Ti & a){ | ||
return std::complex<_Tp>(b.real() + a, b.imag()); | ||
} | ||
""" | ||
|
||
CXXCFloat = CustomNpType('std::complex', np.complex64, template='float') | ||
CXXCDouble = CustomNpType('std::complex', np.complex128, template='double') | ||
|
||
|
||
class CXXBB(LangBB): | ||
|
||
mapper = { | ||
'header-memcpy': 'string.h', | ||
'host-alloc': lambda i, j, k: | ||
Call('posix_memalign', (i, j, k)), | ||
'host-alloc-pin': lambda i, j, k: | ||
Call('posix_memalign', (i, j, k)), | ||
'host-free': lambda i: | ||
Call('free', (i,)), | ||
'host-free-pin': lambda i: | ||
Call('free', (i,)), | ||
'alloc-global-symbol': lambda i, j, k: | ||
Call('memcpy', (i, j, k)), | ||
# Complex | ||
'header-complex': 'complex', | ||
'I-def': (('_Complex_I', ('std::complex<float>(0.0, 1.0)')),), | ||
'def-complex': std_arith, | ||
'types': {np.complex128: CXXCDouble, np.complex64: CXXCFloat}, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.