Skip to content

Commit

Permalink
Merge pull request #2436 from devitocodes/assumptions
Browse files Browse the repository at this point in the history
dsl: Fix missing sympy assumptions during rebuilding
  • Loading branch information
FabioLuporini authored Aug 2, 2024
2 parents 98453a6 + ac89150 commit cbd40c3
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 3 deletions.
8 changes: 8 additions & 0 deletions devito/tools/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,14 @@ def __init__(self, a, b, c=4):

kwargs.update({i: getattr(self, i) for i in self.__rkwargs__ if i not in kwargs})

# If this object has SymPy assumptions associated with it, which were not
# in the kwargs, then include them
try:
assumptions = self._assumptions_orig
kwargs.update({k: v for k, v in assumptions.items() if k not in kwargs})
except AttributeError:
pass

# Should we use a custom reconstructor?
try:
cls = self._rcls
Expand Down
5 changes: 3 additions & 2 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,11 @@ class AbstractSymbol(sympy.Symbol, Basic, Pickable, Evaluable):
def _filter_assumptions(cls, **kwargs):
"""Extract sympy.Symbol-specific kwargs."""
assumptions = {}
# pop predefined assumptions
# Pop predefined assumptions
for key in ('real', 'imaginary', 'commutative'):
kwargs.pop(key, None)
# extract sympy.Symbol-specific kwargs

# Extract sympy.Symbol-specific kwargs
for i in list(kwargs):
if i in _assume_rules.defined_facts:
assumptions[i] = kwargs.pop(i)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
TensorFunction, TensorTimeFunction, VectorTimeFunction)
from devito.types import (DeviceID, NThreadsBase, NPThreads, Object, LocalObject,
Scalar, Symbol, ThreadID)
from devito.types.basic import AbstractSymbol


@pytest.fixture
Expand Down Expand Up @@ -44,6 +45,16 @@ class TestHashing:
Test hashing of symbolic objects.
"""

def test_abstractsymbol(self):
"""Test that different Symbols have different hash values."""
s0 = AbstractSymbol('s')
s1 = AbstractSymbol('s')
assert s0 is not s1
assert hash(s0) == hash(s1)

s2 = AbstractSymbol('s', nonnegative=True)
assert hash(s0) != hash(s2)

def test_constant(self):
"""Test that different Constants have different hash value."""
c0 = Constant(name='c')
Expand Down
18 changes: 17 additions & 1 deletion tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
PointerArray, Lock, PThreadArray, SharedData, Timer,
DeviceID, NPThreads, ThreadID, TempFunction, Indirection,
FIndexed)
from devito.types.basic import BoundSymbol
from devito.types.basic import BoundSymbol, AbstractSymbol
from devito.tools import EnrichedTuple
from devito.symbolics import (IntDiv, ListInitializer, FieldFromPointer,
CallFromPointer, DefFunction)
Expand All @@ -29,6 +29,22 @@
@pytest.mark.parametrize('pickle', [pickle0, pickle1])
class TestBasic:

def test_abstractsymbol(self, pickle):
s0 = AbstractSymbol('s')
s1 = AbstractSymbol('s', nonnegative=True, integer=False)

pkl_s0 = pickle.dumps(s0)
pkl_s1 = pickle.dumps(s1)

new_s0 = pickle.loads(pkl_s0)
new_s1 = pickle.loads(pkl_s1)

assert s0.assumptions0 == new_s0.assumptions0
assert s1.assumptions0 == new_s1.assumptions0

assert s0 == new_s0
assert s1 == new_s1

def test_constant(self, pickle):
c = Constant(name='c')
assert c.data == 0.
Expand Down
41 changes: 41 additions & 0 deletions tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from devito.tools import as_tuple
from devito.types import (Array, Bundle, FIndexed, LocalObject, Object,
Symbol as dSymbol)
from devito.types.basic import AbstractSymbol


def test_float_indices():
Expand Down Expand Up @@ -70,6 +71,46 @@ def test_floatification_issue_1627(dtype, expected):
assert str(exprs[0]) == expected


def test_sympy_assumptions():
"""
Ensure that AbstractSymbol assumptions are set correctly and
preserved during rebuild.
"""
s0 = AbstractSymbol('s')
s1 = AbstractSymbol('s', nonnegative=True, integer=False, real=True)

assert s0.is_negative is None
assert s0.is_positive is None
assert s0.is_integer is None
assert s0.is_real is True
assert s1.is_negative is False
assert s1.is_positive is True
assert s1.is_integer is False
assert s1.is_real is True

s0r = s0._rebuild()
s1r = s1._rebuild()

assert s0.assumptions0 == s0r.assumptions0
assert s0 == s0r

assert s1.assumptions0 == s1r.assumptions0
assert s1 == s1r


def test_modified_sympy_assumptions():
"""
Check that sympy assumptions can be changed during a rebuild.
"""
s0 = AbstractSymbol('s')
s1 = AbstractSymbol('s', nonnegative=True, integer=False, real=True)

s2 = s0._rebuild(nonnegative=True, integer=False, real=True)

assert s2.assumptions0 == s1.assumptions0
assert s2 == s1


def test_constant():
c = Constant(name='c')

Expand Down

0 comments on commit cbd40c3

Please sign in to comment.