Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API: Introducing complex (np.complex64/128) native support #2375

Open
wants to merge 41 commits into
base: master
Choose a base branch
from

Conversation

mloubout
Copy link
Contributor

Support complex float data type.

@mloubout mloubout added API api (symbolics, types, ...) feature-request labels May 26, 2024
Copy link

codecov bot commented May 26, 2024

Codecov Report

Attention: Patch coverage is 90.43977% with 50 lines in your changes missing coverage. Please review.

Project coverage is 87.32%. Comparing base (da2c9a4) to head (1e4a125).

Files with missing lines Patch % Lines
tests/test_gpu_common.py 7.14% 13 Missing ⚠️
devito/tools/dtypes_lowering.py 47.36% 5 Missing and 5 partials ⚠️
devito/ir/iet/nodes.py 44.44% 3 Missing and 2 partials ⚠️
devito/symbolics/inspection.py 75.00% 2 Missing and 2 partials ⚠️
devito/symbolics/printer.py 95.40% 2 Missing and 2 partials ⚠️
devito/types/dense.py 20.00% 2 Missing and 2 partials ⚠️
devito/symbolics/extended_sympy.py 94.87% 2 Missing ⚠️
devito/types/basic.py 50.00% 1 Missing and 1 partial ⚠️
tests/test_operator.py 88.23% 1 Missing and 1 partial ⚠️
devito/arch/compiler.py 92.30% 1 Missing ⚠️
... and 3 more
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2375      +/-   ##
==========================================
+ Coverage   87.29%   87.32%   +0.02%     
==========================================
  Files         238      243       +5     
  Lines       45749    46093     +344     
  Branches     4059     4083      +24     
==========================================
+ Hits        39937    40250     +313     
- Misses       5127     5148      +21     
- Partials      685      695      +10     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@mloubout mloubout force-pushed the complex branch 4 times, most recently from 3195b9e to 6ac5b99 Compare May 27, 2024 16:52
devito/operator/operator.py Outdated Show resolved Hide resolved
"""
Add headers for complex arithmetic
"""
if configuration['language'] == 'cuda':
Copy link
Contributor

@EdCaunt EdCaunt May 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could shorten this using:

headers = {'cuda': 'cuComplex.h', 'hip': 'hip/hip_complex.h'}
lib = headers.get(configuration['language'], 'complex.h')

devito/passes/iet/misc.py Outdated Show resolved Hide resolved
dtype = self.dtype
if np.issubdtype(dtype, np.complexfloating):
func_name = 'c%s' % func_name
dtype = self.dtype(0).real.dtype
Copy link
Contributor

@EdCaunt EdCaunt May 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Newline between if blocks would improve readability

tests/test_gpu_common.py Outdated Show resolved Hide resolved
@@ -640,6 +640,25 @@ def test_tensor(self, func1):
op2 = Operator([Eq(f, f.dx) for f in f1.values()])
assert str(op1.ccode) == str(op2.ccode)

def test_complex(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicate test?

if exp not in parameters + boilerplate:
error("Missing parameter: %s" % exp)
assert exp in parameters + boilerplate
for expi in expected:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe ex in expected?

@mloubout mloubout force-pushed the complex branch 4 times, most recently from 0268781 to 2c80bf8 Compare May 28, 2024 17:16
devito/ir/iet/visitors.py Outdated Show resolved Hide resolved
devito/passes/iet/misc.py Outdated Show resolved Hide resolved
@mloubout mloubout force-pushed the complex branch 3 times, most recently from e7a2791 to 05f8528 Compare May 30, 2024 18:19
devito/arch/compiler.py Outdated Show resolved Hide resolved
devito/tools/dtypes_lowering.py Outdated Show resolved Hide resolved
tests/test_gpu_common.py Outdated Show resolved Hide resolved
tests/test_gpu_common.py Outdated Show resolved Hide resolved
tests/test_operator.py Show resolved Hide resolved
tests/test_operator.py Outdated Show resolved Hide resolved
@mloubout mloubout force-pushed the complex branch 3 times, most recently from a655632 to 7cee7fb Compare May 31, 2024 15:10
@@ -66,6 +67,23 @@ def test_maxpar_option(self):
assert trees[0][0] is trees[1][0]
assert trees[0][1] is not trees[1][1]

@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if you try to take derivatives of an expression containing the imaginary unit? Something like (sympy.I*u).dx?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sympy.I is just a sympy Atomic it's treated like any other symbol or number such as S.One or S.NegativeOne

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I'm just inclined to add tests to check that things work the way they 'should' since I've been tripped up in the past

@@ -270,3 +306,39 @@ def _rename_subdims(target, dimensions):
return {d: d._rebuild(d.root.name) for d in dims
if d.root not in dimensions
and names.count(d.root.name) < 2}


_stdcomplex_defs = """
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

imho this belongs to a complex.h

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually, given all the other comments (this is the final one I'm writing), you may as well move the entire complex number lowering machinery to a separate python module such as complex.py within passes/iet/

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

imho this belongs to a complex.h

This is a lot more robust to generate the header in the same dir as the the generated code and avoid having to infer path from the devito dir.

complex.py

That's fine

@@ -192,6 +192,42 @@ def minimize_symbols(iet):
return iet, {}


@iet_pass
def complex_include(iet, language, compiler):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

include_complex

@iet_pass
def complex_include(iet, language, compiler):
"""
Add headers for complex arithmetic
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

full stop

@@ -243,6 +245,20 @@ def version(self):

return version

@property
def _complex_ctype(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, we definitely don't want code-generation-related machinery in our Compiler classes.

The right thing to do is, instead, single-dispatching the Compiler class within our own compilation pass, which is responsible for the lowering of complex

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as such, I don't think we need to add a custom name to Compiler?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite agree here, the complex type are defined by the actual compiler and their standard, i.e gnu has _Complex and cpp has std::complex, adding complicated dispatch is overkill for something that is standardized at the language level

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I disagree, singledispatch would achieve the same exact objective by dispatching based on the type.

Instead, what you've done here violates a crucial principle of the OO paradigm, that is, classes should have a well defined purpose. These classes are for jit-compiling a given string. They're not supposed to provide compiler-specific code generation (C- or C++ specific) information

adding complicated dispatch

I don't think it's complicated at all. An Iet_pass receives the compiler and all you have to do is creating a series of functions based on single dispatch doing the same exact thing it's being done here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

g. They're not supposed to provide compiler-specific code generation (C- or C++ specific) information

But that's not what it is, this defines the standard associated with the compiler which is c99->_Complex, c++11->std:complex

adding a pass that move the standard out of the compiler doesn't really make sense.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't say (I hope) to add a pass. A single-dispatch function to retrieve some sort of type-specific information doesn't have to be a compiler pass.

But obviously our compiler pass would use it to get the code it needs

@@ -92,8 +92,12 @@ def initialize(cls):
return

def alloc(self, shape, dtype, padding=0):
datasize = int(reduce(mul, shape))
ctype = dtype_to_ctype(dtype)
# For complex number, allocate double the size of its real/imaginary part
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

potentially useful elsewhere, so I'd move it into a function inside devito/tools/dtypes_lowering maybe?

@@ -460,6 +460,12 @@ def _lower_iet(cls, uiet, profiler=None, **kwargs):

# Lower IET to a target-specific IET
graph = Graph(iet, **kwargs)

# Complex header if needed. Needs to be done before specialization
# as some specific cases require complex to be loaded first
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for instance?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FFTW requires complex.h to be loaded first so that it's the type used rather than fftw_complex

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

by "loaded first" you mean that the header file should stay at the very top of the includes list?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't really matter right now but might later


# For (cpp), need to define constant _Complex_I and missing mix-type
# std::complex arithmetic
if compiler._cpp:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if not ... : return

if np.issubdtype(dtype, np.complexfloating):
rtype = dtype(0).real.__class__
from devito import configuration
make = configuration['compiler']._complex_ctype
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you really can't use global information here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not global because this is called within the switchconfig

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from that code path above yes, but from another one?

we shouldn't access configuration in these remote places

@@ -301,7 +313,8 @@ def infer_dtype(dtypes):
# Resolve the vector types, if any
dtypes = {dtypes_vector_mapper.get_base_dtype(i, i) for i in dtypes}

fdtypes = {i for i in dtypes if np.issubdtype(i, np.floating)}
fdtypes = {i for i in dtypes if np.issubdtype(i, np.floating) or
np.issubdtype(i, np.complexfloating)}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

really isn't np.issubdtype(i, (np.floating, np.complexfloating)) supported?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope :(


@property
def _base_typ(self):
return configuration['compiler']._complex_ctype('float')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can't use global objects here

What we should do instead: leave CFLOAT generic

Extend the existing compiler pass to lower CFLOAT into something more specific such as CFLOAT_GCC

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, with the cgen visitor now always using the local config from Oeprator this is never called with a global config.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it doesn't matter, it's conceptually wrong

you're assuming you only go through _base_typ via that visitor, but who/what imposes that?

this is basically just a workaround to avoid a more graceful lowering, which you can do as explained in my first message

@mloubout mloubout force-pushed the complex branch 2 times, most recently from 3f0b8e2 to fdf1b36 Compare June 21, 2024 18:21
@mloubout mloubout force-pushed the complex branch 2 times, most recently from edb98c0 to aef841b Compare January 18, 2025 03:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
API api (symbolics, types, ...) compiler feature-request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants