Skip to content

Commit

Permalink
Adjust naming & verbose choice for private generic function
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianDeconinck committed May 20, 2024
1 parent 50b8d82 commit 0f8906b
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 28 deletions.
65 changes: 41 additions & 24 deletions ndsl/boilerplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@

from ndsl import (
CompilationConfig,
CubedSphereCommunicator,
CubedSpherePartitioner,
DaceConfig,
DaCeOrchestration,
GridIndexing,
LocalComm,
NullComm,
QuantityFactory,
RunMode,
Expand All @@ -18,12 +21,22 @@
)


def _get_one_tile_factory(
nx, ny, nz, nhalo, backend, orchestration
def _get_factories(
nx: int,
ny: int,
nz: int,
nhalo,
backend: str,
orchestration: DaCeOrchestration,
topology: str,
) -> Tuple[StencilFactory, QuantityFactory]:
"""Build a Stencil & Quantity factory for:
- one tile
- no MPI communicator
"""Build a Stencil & Quantity factory for a combination of options.
Dev Note: We don't expose this function because we want the boilerplate to remain
as easy and self describing as possible. It should be a very easy call to make.
The other reason is that the orchestration requires two inputs instead of change
a backend name for now, making it confusing. Until refactor, we choose to hide this
pattern for boilerplate use.
"""
dace_config = DaceConfig(
communicator=None,
Expand All @@ -47,49 +60,53 @@ def _get_one_tile_factory(
dace_config=dace_config,
)

partitioner = TilePartitioner((1, 1))
sizer = SubtileGridSizer.from_tile_params(
nx_tile=nx,
ny_tile=ny,
nz=nz,
n_halo=nhalo,
extra_dim_lengths={},
layout=partitioner.layout,
tile_partitioner=partitioner,
)

tile_comm = TileCommunicator(comm=NullComm(0, 1, 42), partitioner=partitioner)
if topology == "tile":
partitioner = TilePartitioner((1, 1))
sizer = SubtileGridSizer.from_tile_params(
nx_tile=nx,
ny_tile=ny,
nz=nz,
n_halo=nhalo,
extra_dim_lengths={},
layout=partitioner.layout,
tile_partitioner=partitioner,
)
comm = TileCommunicator(comm=NullComm(0, 1, 42), partitioner=partitioner)
else:
raise NotImplementedError(f"Topology {topology} is not implemented.")

grid_indexing = GridIndexing.from_sizer_and_communicator(sizer, tile_comm)
grid_indexing = GridIndexing.from_sizer_and_communicator(sizer, comm)
stencil_factory = StencilFactory(config=stencil_config, grid_indexing=grid_indexing)
quantity_factory = QuantityFactory(sizer, np)

return stencil_factory, quantity_factory


def get_one_tile_factory_orchestrated_cpu(
def get_factories_single_tile_orchestrated_cpu(
nx, ny, nz, nhalo
) -> Tuple[StencilFactory, QuantityFactory]:
"""Build a Stencil & Quantity factory for orchestrated CPU"""
return _get_one_tile_factory(
"""Build a Stencil & Quantity factory for orchestrated CPU, on a single tile toplogy."""
return _get_factories(
nx=nx,
ny=ny,
nz=nz,
nhalo=nhalo,
backend="dace:cpu",
orchestration=DaCeOrchestration.BuildAndRun,
topology="tile",
)


def get_one_tile_factory_numpy(
def get_factories_single_tile_numpy(
nx, ny, nz, nhalo
) -> Tuple[StencilFactory, QuantityFactory]:
"""Build a Stencil & Quantity factory for Numpy"""
return _get_one_tile_factory(
"""Build a Stencil & Quantity factory for Numpy, on a single tile toplogy."""
return _get_factories(
nx=nx,
ny=ny,
nz=nz,
nhalo=nhalo,
backend="numpy",
orchestration=DaCeOrchestration.Python,
topology="tile",
)
8 changes: 4 additions & 4 deletions tests/test_boilerplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def test_boilerplate_import_numpy():
Dev Note: the import inside the function are part of the test.
"""
from ndsl.boilerplate import get_one_tile_factory_numpy
from ndsl.boilerplate import get_factories_single_tile_numpy

# Boilerplate
stencil_factory, quantity_factory = get_one_tile_factory_numpy(
stencil_factory, quantity_factory = get_factories_single_tile_numpy(
nx=5, ny=5, nz=2, nhalo=1
)

Expand All @@ -50,10 +50,10 @@ def test_boilerplate_import_orchestrated_cpu():
Dev Note: the import inside the function are part of the test.
"""
from ndsl.boilerplate import get_one_tile_factory_orchestrated_cpu
from ndsl.boilerplate import get_factories_single_tile_orchestrated_cpu

# Boilerplate
stencil_factory, quantity_factory = get_one_tile_factory_orchestrated_cpu(
stencil_factory, quantity_factory = get_factories_single_tile_orchestrated_cpu(
nx=5, ny=5, nz=2, nhalo=1
)

Expand Down

0 comments on commit 0f8906b

Please sign in to comment.