From 881afcbde44a7d72cdba4a39ec673b462ece8313 Mon Sep 17 00:00:00 2001 From: mloubout Date: Tue, 12 Mar 2024 20:25:40 -0400 Subject: [PATCH] api: add priority to fd coefficients for mixed derivatives --- devito/finite_differences/differentiable.py | 9 +++------ devito/finite_differences/tools.py | 1 + tests/test_symbolic_coefficients.py | 7 +++++++ 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index ee3e2ba38b..d2df797df6 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -11,7 +11,7 @@ from sympy.core.decorators import call_highest_priority from sympy.core.evalf import evalf_table -from devito.finite_differences.tools import make_shift_x0 +from devito.finite_differences.tools import make_shift_x0, coeff_priority from devito.logger import warning from devito.tools import (as_tuple, filter_ordered, flatten, frozendict, infer_dtype, is_integer, split) @@ -130,11 +130,8 @@ def coefficients(self): coefficients = {f.coefficients for f in self._functions} # If there is multiple ones, we have to revert to the highest priority # i.e we have to remove symbolic - if len(coefficients) == 2: - return (coefficients - {'symbolic'}).pop() - else: - assert len(coefficients) == 1 - return coefficients.pop() + key = lambda x: coeff_priority[x] + return sorted(coefficients, key=key, reverse=True)[0] @cached_property def _coeff_symbol(self, *args, **kwargs): diff --git a/devito/finite_differences/tools.py b/devito/finite_differences/tools.py index 0aa56f316b..a52b9625be 100644 --- a/devito/finite_differences/tools.py +++ b/devito/finite_differences/tools.py @@ -248,6 +248,7 @@ def numeric_weights(function, deriv_order, indices, x0): fd_weights_registry = {'taylor': numeric_weights, 'standard': numeric_weights, 'symbolic': symbolic_weights} +coeff_priority = {'taylor': 1, 'standard': 1, 'symbolic': 0} def generate_indices(expr, dim, order, side=None, matvec=None, x0=None): diff --git a/tests/test_symbolic_coefficients.py b/tests/test_symbolic_coefficients.py index 1b66aac3e5..abbdbdb3b1 100644 --- a/tests/test_symbolic_coefficients.py +++ b/tests/test_symbolic_coefficients.py @@ -510,3 +510,10 @@ def test_compound_nested_subs(self): # `str` simply because some objects are of type EvalDerivative assert str(eq.evaluate.rhs) == str(term0 + term1 + term2) + + def test_priority(self): + grid = Grid(shape=(11,)) + m = Function(name='m', grid=grid, space_order=2, coefficients='symbolic') + p = Function(name='p', grid=grid, space_order=2) + + assert (p*m).coefficients == 'taylor'