-
Notifications
You must be signed in to change notification settings - Fork 230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
API: Introducing complex (np.complex64/128) native support #2375
base: master
Are you sure you want to change the base?
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #2375 +/- ##
==========================================
+ Coverage 87.29% 87.32% +0.02%
==========================================
Files 238 243 +5
Lines 45749 46093 +344
Branches 4059 4083 +24
==========================================
+ Hits 39937 40250 +313
- Misses 5127 5148 +21
- Partials 685 695 +10 ☔ View full report in Codecov by Sentry. |
3195b9e
to
6ac5b99
Compare
devito/passes/iet/misc.py
Outdated
""" | ||
Add headers for complex arithmetic | ||
""" | ||
if configuration['language'] == 'cuda': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could shorten this using:
headers = {'cuda': 'cuComplex.h', 'hip': 'hip/hip_complex.h'}
lib = headers.get(configuration['language'], 'complex.h')
devito/symbolics/printer.py
Outdated
dtype = self.dtype | ||
if np.issubdtype(dtype, np.complexfloating): | ||
func_name = 'c%s' % func_name | ||
dtype = self.dtype(0).real.dtype |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Newline between if blocks would improve readability
tests/test_operator.py
Outdated
@@ -640,6 +640,25 @@ def test_tensor(self, func1): | |||
op2 = Operator([Eq(f, f.dx) for f in f1.values()]) | |||
assert str(op1.ccode) == str(op2.ccode) | |||
|
|||
def test_complex(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Duplicate test?
if exp not in parameters + boilerplate: | ||
error("Missing parameter: %s" % exp) | ||
assert exp in parameters + boilerplate | ||
for expi in expected: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe ex in expected
?
0268781
to
2c80bf8
Compare
e7a2791
to
05f8528
Compare
a655632
to
7cee7fb
Compare
@@ -66,6 +67,23 @@ def test_maxpar_option(self): | |||
assert trees[0][0] is trees[1][0] | |||
assert trees[0][1] is not trees[1][1] | |||
|
|||
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens if you try to take derivatives of an expression containing the imaginary unit? Something like (sympy.I*u).dx
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sympy.I is just a sympy Atomic it's treated like any other symbol or number such as S.One or S.NegativeOne
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, I'm just inclined to add tests to check that things work the way they 'should' since I've been tripped up in the past
devito/passes/iet/misc.py
Outdated
@@ -270,3 +306,39 @@ def _rename_subdims(target, dimensions): | |||
return {d: d._rebuild(d.root.name) for d in dims | |||
if d.root not in dimensions | |||
and names.count(d.root.name) < 2} | |||
|
|||
|
|||
_stdcomplex_defs = """ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
imho this belongs to a complex.h
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually, given all the other comments (this is the final one I'm writing), you may as well move the entire complex number lowering machinery to a separate python module such as complex.py
within passes/iet/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
imho this belongs to a complex.h
This is a lot more robust to generate the header in the same dir as the the generated code and avoid having to infer path from the devito dir.
complex.py
That's fine
devito/passes/iet/misc.py
Outdated
@@ -192,6 +192,42 @@ def minimize_symbols(iet): | |||
return iet, {} | |||
|
|||
|
|||
@iet_pass | |||
def complex_include(iet, language, compiler): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
include_complex
devito/passes/iet/misc.py
Outdated
@iet_pass | ||
def complex_include(iet, language, compiler): | ||
""" | ||
Add headers for complex arithmetic |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
full stop
devito/arch/compiler.py
Outdated
@@ -243,6 +245,20 @@ def version(self): | |||
|
|||
return version | |||
|
|||
@property | |||
def _complex_ctype(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no, we definitely don't want code-generation-related machinery in our Compiler classes.
The right thing to do is, instead, single-dispatching the Compiler class within our own compilation pass, which is responsible for the lowering of complex
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
as such, I don't think we need to add a custom name
to Compiler?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't quite agree here, the complex type are defined by the actual compiler and their standard, i.e gnu has _Complex
and cpp has std::complex
, adding complicated dispatch is overkill for something that is standardized at the language level
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I disagree, singledispatch would achieve the same exact objective by dispatching based on the type.
Instead, what you've done here violates a crucial principle of the OO paradigm, that is, classes should have a well defined purpose. These classes are for jit-compiling a given string. They're not supposed to provide compiler-specific code generation (C- or C++ specific) information
adding complicated dispatch
I don't think it's complicated at all. An Iet_pass receives the compiler
and all you have to do is creating a series of functions based on single dispatch doing the same exact thing it's being done here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
g. They're not supposed to provide compiler-specific code generation (C- or C++ specific) information
But that's not what it is, this defines the standard associated with the compiler which is c99->_Complex, c++11->std:complex
adding a pass that move the standard out of the compiler doesn't really make sense.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't say (I hope) to add a pass. A single-dispatch function to retrieve some sort of type-specific information doesn't have to be a compiler pass.
But obviously our compiler pass would use it to get the code it needs
devito/data/allocators.py
Outdated
@@ -92,8 +92,12 @@ def initialize(cls): | |||
return | |||
|
|||
def alloc(self, shape, dtype, padding=0): | |||
datasize = int(reduce(mul, shape)) | |||
ctype = dtype_to_ctype(dtype) | |||
# For complex number, allocate double the size of its real/imaginary part |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
potentially useful elsewhere, so I'd move it into a function inside devito/tools/dtypes_lowering
maybe?
devito/operator/operator.py
Outdated
@@ -460,6 +460,12 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs): | |||
|
|||
# Lower IET to a target-specific IET | |||
graph = Graph(iet, **kwargs) | |||
|
|||
# Complex header if needed. Needs to be done before specialization | |||
# as some specific cases require complex to be loaded first |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for instance?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FFTW requires complex.h
to be loaded first so that it's the type used rather than fftw_complex
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
by "loaded first" you mean that the header file should stay at the very top of the includes list?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It doesn't really matter right now but might later
devito/passes/iet/misc.py
Outdated
|
||
# For (cpp), need to define constant _Complex_I and missing mix-type | ||
# std::complex arithmetic | ||
if compiler._cpp: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if not ... : return
devito/tools/dtypes_lowering.py
Outdated
if np.issubdtype(dtype, np.complexfloating): | ||
rtype = dtype(0).real.__class__ | ||
from devito import configuration | ||
make = configuration['compiler']._complex_ctype |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you really can't use global information here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not global because this is called within the switchconfig
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from that code path above yes, but from another one?
we shouldn't access configuration
in these remote places
@@ -301,7 +313,8 @@ def infer_dtype(dtypes): | |||
# Resolve the vector types, if any | |||
dtypes = {dtypes_vector_mapper.get_base_dtype(i, i) for i in dtypes} | |||
|
|||
fdtypes = {i for i in dtypes if np.issubdtype(i, np.floating)} | |||
fdtypes = {i for i in dtypes if np.issubdtype(i, np.floating) or | |||
np.issubdtype(i, np.complexfloating)} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
really isn't np.issubdtype(i, (np.floating, np.complexfloating))
supported?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nope :(
devito/symbolics/extended_sympy.py
Outdated
|
||
@property | ||
def _base_typ(self): | ||
return configuration['compiler']._complex_ctype('float') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can't use global objects here
What we should do instead: leave CFLOAT generic
Extend the existing compiler pass to lower CFLOAT into something more specific such as CFLOAT_GCC
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again, with the cgen visitor now always using the local config from Oeprator this is never called with a global config.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it doesn't matter, it's conceptually wrong
you're assuming you only go through _base_typ
via that visitor, but who/what imposes that?
this is basically just a workaround to avoid a more graceful lowering, which you can do as explained in my first message
3f0b8e2
to
fdf1b36
Compare
edb98c0
to
aef841b
Compare
Support complex float data type.