Skip to content

Commit

Permalink
Merge pull request #2283 from devitocodes/fix-pickle-cudahip
Browse files Browse the repository at this point in the history
compiler: Fix minor codegen issues after pickling
  • Loading branch information
FabioLuporini authored Dec 15, 2023
2 parents 2a9533f + 9ee3471 commit 8f45ba0
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 54 deletions.
68 changes: 24 additions & 44 deletions devito/arch/archinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,9 +597,19 @@ def get_platform():

class Platform(object):

registry = {}

def __init__(self, name):
self.name = name

self.registry[name] = self

def __eq__(self, other):
return isinstance(other, Platform) and self.name == other.name

def __hash__(self):
return hash(self.name)

@classmethod
def _mro(cls):
return [Platform]
Expand Down Expand Up @@ -815,17 +825,17 @@ def march(cls):
CPU64_DUMMY = Intel64('cpu64-dummy', cores_logical=2, cores_physical=1, isa='sse')

INTEL64 = Intel64('intel64')
SNB = Intel64('snb')
IVB = Intel64('ivb')
HSW = Intel64('hsw')
BDW = Intel64('bdw', isa='avx2')
KNL = Intel64('knl')
KNL7210 = Intel64('knl', cores_logical=256, cores_physical=64, isa='avx512')
SKX = IntelSkylake('skx')
KLX = IntelSkylake('klx')
CLX = IntelSkylake('clx')
CLK = IntelSkylake('clk')
SPR = IntelGoldenCove('spr')
SNB = Intel64('snb') # Sandy Bridge
IVB = Intel64('ivb') # Ivy Bridge
HSW = Intel64('hsw') # Haswell
BDW = Intel64('bdw', isa='avx2') # Broadwell
KNL = Intel64('knl') # Knights Landing
KNL7210 = Intel64('knl7210', cores_logical=256, cores_physical=64, isa='avx512')
SKX = IntelSkylake('skx') # Skylake
KLX = IntelSkylake('klx') # Kaby Lake
CLX = IntelSkylake('clx') # Coffee Lake
CLK = IntelSkylake('clk') # Cascade Lake
SPR = IntelGoldenCove('spr') # Sapphire Rapids

ARM = Arm('arm')
GRAVITON = Arm('graviton')
Expand All @@ -841,40 +851,10 @@ def march(cls):
AMDGPUX = AmdDevice('amdgpuX')
INTELGPUX = IntelDevice('intelgpuX')

PVC = IntelDevice('pvc', max_threads_per_block=4096)


platform_registry = {
'cpu64-dummy': CPU64_DUMMY,
'intel64': INTEL64,
'snb': SNB, # Sandy Bridge
'ivb': IVB, # Ivy Bridge
'hsw': HSW, # Haswell
'bdw': BDW, # Broadwell
'skx': SKX, # Skylake
'klx': KLX, # Kaby Lake
'clx': CLX, # Coffee Lake
'clk': CLK, # Cascade Lake
'spr': SPR, # Sapphire Rapids
'knl': KNL,
'knl7210': KNL7210,
'arm': ARM, # Generic ARM CPU
'graviton': GRAVITON, # AMS arm
'm1': M1,
'amd': AMD, # Generic AMD CPU
'power8': POWER8,
'power9': POWER9,
'nvidiaX': NVIDIAX, # Generic NVidia GPU
'amdgpuX': AMDGPUX, # Generic AMD GPU
'intelgpuX': INTELGPUX, # Generic Intel GPU
'pvc': PVC # Intel Ponte Vecchio GPU
}
"""
Registry dict for deriving Platform classes according to the environment variable
DEVITO_PLATFORM. Developers should add new platform classes here.
"""
platform_registry['cpu64'] = get_platform # Autodetection
PVC = IntelDevice('pvc', max_threads_per_block=4096) # Intel Ponte Vecchio GPU

platform_registry = Platform.registry
platform_registry['cpu64'] = get_platform # Autodetection

isa_registry = {
'cpp': 16,
Expand Down
6 changes: 4 additions & 2 deletions devito/ir/support/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class GuardBoundGt(BaseGuardBound, Gt):
# *** GuardBoundNext


class BaseGuardBoundNext(Guard):
class BaseGuardBoundNext(Guard, Pickable):

"""
A guard to avoid out-of-bounds iteration.
Expand All @@ -118,11 +118,13 @@ class BaseGuardBoundNext(Guard):
given `direction`.
"""

__rargs__ = ('d', 'direction')

def __new__(cls, d, direction, **kwargs):
assert isinstance(d, Dimension)
assert isinstance(direction, IterationDirection)

if direction is Forward:
if direction == Forward:
p0 = d.root
p1 = d.root.symbolic_max

Expand Down
7 changes: 3 additions & 4 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,13 +1088,12 @@ def comm(self):

@property
def opkwargs(self):
temp_registry = {v: k for k, v in platform_registry.items()}
platform = temp_registry[self.platform]

temp_registry = {v: k for k, v in compiler_registry.items()}
compiler = temp_registry[self.compiler.__class__]

return {'platform': platform, 'compiler': compiler, 'language': self.language}
return {'platform': self.platform.name,
'compiler': compiler,
'language': self.language}


def parse_kwargs(**kwargs):
Expand Down
25 changes: 24 additions & 1 deletion tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
Dimension, SubDimension, ConditionalDimension, IncrDimension,
TimeDimension, SteppingDimension, Operator, MPI, Min, solve,
PrecomputedSparseTimeFunction)
from devito.ir import GuardFactor
from devito.ir import Backward, Forward, GuardFactor, GuardBound, GuardBoundNext
from devito.data import LEFT, OWNED
from devito.mpi.halo_scheme import Halo
from devito.mpi.routines import (MPIStatusObject, MPIMsgEnriched, MPIRequestObject,
Expand Down Expand Up @@ -388,6 +388,29 @@ def test_guard_factor(self, pickle):

assert str(gf) == str(new_gf)

def test_guard_bound(self, pickle):
d = Dimension(name='d')

gb = GuardBound(d, 3)

pkl_gb = pickle.dumps(gb)
new_gb = pickle.loads(pkl_gb)

assert str(gb) == str(new_gb)

@pytest.mark.parametrize('direction', [Backward, Forward])
def test_guard_bound_next(self, pickle, direction):
d = Dimension(name='d')
cd = ConditionalDimension(name='cd', parent=d, factor=4)

for i in [d, cd]:
gbn = GuardBoundNext(i, direction)

pkl_gbn = pickle.dumps(gbn)
new_gbn = pickle.loads(pkl_gbn)

assert str(gbn) == str(new_gbn)

def test_temp_function(self, pickle):
grid = Grid(shape=(3, 3))
d = Dimension(name='d')
Expand Down
8 changes: 5 additions & 3 deletions tests/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import numpy as np
import scipy.sparse

from devito import Grid, TimeFunction, Eq, Operator, Dimension, Function
from devito import (SparseFunction, SparseTimeFunction, PrecomputedSparseFunction,
PrecomputedSparseTimeFunction, MatrixSparseTimeFunction)
from devito import (Grid, TimeFunction, Eq, Operator, Dimension, Function,
SparseFunction, SparseTimeFunction, PrecomputedSparseFunction,
PrecomputedSparseTimeFunction, MatrixSparseTimeFunction,
switchconfig)


_sptypes = [SparseFunction, SparseTimeFunction,
Expand Down Expand Up @@ -455,6 +456,7 @@ def test_subs(self, sptype):
assert getattr(sps, subf).indices[0] == new_spdim
assert np.all(getattr(sps, subf).data == getattr(sp, subf).data)

@switchconfig(safe_math=True)
@pytest.mark.parallel(mode=[1, 4])
def test_mpi_no_data(self):
grid = Grid((11, 11), extent=(10, 10))
Expand Down

0 comments on commit 8f45ba0

Please sign in to comment.