-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #39 from NOAA-GFDL/feature/boilerplate
Boilerplate code for fast deploy of NDSL
- Loading branch information
Showing
2 changed files
with
169 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
from typing import Tuple | ||
|
||
import numpy as np | ||
|
||
from ndsl import ( | ||
CompilationConfig, | ||
DaceConfig, | ||
DaCeOrchestration, | ||
GridIndexing, | ||
NullComm, | ||
QuantityFactory, | ||
RunMode, | ||
StencilConfig, | ||
StencilFactory, | ||
SubtileGridSizer, | ||
TileCommunicator, | ||
TilePartitioner, | ||
) | ||
|
||
|
||
def _get_factories( | ||
nx: int, | ||
ny: int, | ||
nz: int, | ||
nhalo, | ||
backend: str, | ||
orchestration: DaCeOrchestration, | ||
topology: str, | ||
) -> Tuple[StencilFactory, QuantityFactory]: | ||
"""Build a Stencil & Quantity factory for a combination of options. | ||
Dev Note: We don't expose this function because we want the boilerplate to remain | ||
as easy and self describing as possible. It should be a very easy call to make. | ||
The other reason is that the orchestration requires two inputs instead of change | ||
a backend name for now, making it confusing. Until refactor, we choose to hide this | ||
pattern for boilerplate use. | ||
""" | ||
dace_config = DaceConfig( | ||
communicator=None, | ||
backend=backend, | ||
orchestration=orchestration, | ||
) | ||
|
||
compilation_config = CompilationConfig( | ||
backend=backend, | ||
rebuild=True, | ||
validate_args=True, | ||
format_source=False, | ||
device_sync=False, | ||
run_mode=RunMode.BuildAndRun, | ||
use_minimal_caching=False, | ||
) | ||
|
||
stencil_config = StencilConfig( | ||
compare_to_numpy=False, | ||
compilation_config=compilation_config, | ||
dace_config=dace_config, | ||
) | ||
|
||
if topology == "tile": | ||
partitioner = TilePartitioner((1, 1)) | ||
sizer = SubtileGridSizer.from_tile_params( | ||
nx_tile=nx, | ||
ny_tile=ny, | ||
nz=nz, | ||
n_halo=nhalo, | ||
extra_dim_lengths={}, | ||
layout=partitioner.layout, | ||
tile_partitioner=partitioner, | ||
) | ||
comm = TileCommunicator(comm=NullComm(0, 1, 42), partitioner=partitioner) | ||
else: | ||
raise NotImplementedError(f"Topology {topology} is not implemented.") | ||
|
||
grid_indexing = GridIndexing.from_sizer_and_communicator(sizer, comm) | ||
stencil_factory = StencilFactory(config=stencil_config, grid_indexing=grid_indexing) | ||
quantity_factory = QuantityFactory(sizer, np) | ||
|
||
return stencil_factory, quantity_factory | ||
|
||
|
||
def get_factories_single_tile_orchestrated_cpu( | ||
nx, ny, nz, nhalo | ||
) -> Tuple[StencilFactory, QuantityFactory]: | ||
"""Build a Stencil & Quantity factory for orchestrated CPU, on a single tile toplogy.""" | ||
return _get_factories( | ||
nx=nx, | ||
ny=ny, | ||
nz=nz, | ||
nhalo=nhalo, | ||
backend="dace:cpu", | ||
orchestration=DaCeOrchestration.BuildAndRun, | ||
topology="tile", | ||
) | ||
|
||
|
||
def get_factories_single_tile_numpy( | ||
nx, ny, nz, nhalo | ||
) -> Tuple[StencilFactory, QuantityFactory]: | ||
"""Build a Stencil & Quantity factory for Numpy, on a single tile toplogy.""" | ||
return _get_factories( | ||
nx=nx, | ||
ny=ny, | ||
nz=nz, | ||
nhalo=nhalo, | ||
backend="numpy", | ||
orchestration=DaCeOrchestration.Python, | ||
topology="tile", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import numpy as np | ||
from gt4py.cartesian.gtscript import PARALLEL, computation, interval | ||
|
||
from ndsl import QuantityFactory, StencilFactory | ||
from ndsl.constants import X_DIM, Y_DIM, Z_DIM | ||
from ndsl.dsl.typing import FloatField | ||
|
||
|
||
def _copy_ops(stencil_factory: StencilFactory, quantity_factory: QuantityFactory): | ||
# Allocate data and fill input | ||
qty_out = quantity_factory.zeros(dims=[X_DIM, Y_DIM, Z_DIM], units="n/a") | ||
qty_in = quantity_factory.zeros(dims=[X_DIM, Y_DIM, Z_DIM], units="n/a") | ||
qty_in.view[:] = np.indices( | ||
dimensions=quantity_factory.sizer.get_extent([X_DIM, Y_DIM, Z_DIM]), | ||
dtype=np.float64, | ||
).sum( | ||
axis=0 | ||
) # Value of each entry is sum of the I and J index at each point | ||
|
||
# Define a stencil | ||
def copy_stencil(input_field: FloatField, output_field: FloatField): | ||
with computation(PARALLEL), interval(...): | ||
output_field = input_field | ||
|
||
# Execute | ||
copy = stencil_factory.from_dims_halo( | ||
func=copy_stencil, compute_dims=[X_DIM, Y_DIM, Z_DIM] | ||
) | ||
copy(qty_in, qty_out) | ||
assert (qty_in.view[:] == qty_out.view[:]).all() | ||
|
||
|
||
def test_boilerplate_import_numpy(): | ||
"""Test make sure the basic numpy boilerplate works as expected. | ||
Dev Note: the import inside the function are part of the test. | ||
""" | ||
from ndsl.boilerplate import get_factories_single_tile_numpy | ||
|
||
# Boilerplate | ||
stencil_factory, quantity_factory = get_factories_single_tile_numpy( | ||
nx=5, ny=5, nz=2, nhalo=1 | ||
) | ||
|
||
_copy_ops(stencil_factory, quantity_factory) | ||
|
||
|
||
def test_boilerplate_import_orchestrated_cpu(): | ||
"""Test make sure the basic orchestrate boilerplate works as expected. | ||
Dev Note: the import inside the function are part of the test. | ||
""" | ||
from ndsl.boilerplate import get_factories_single_tile_orchestrated_cpu | ||
|
||
# Boilerplate | ||
stencil_factory, quantity_factory = get_factories_single_tile_orchestrated_cpu( | ||
nx=5, ny=5, nz=2, nhalo=1 | ||
) | ||
|
||
_copy_ops(stencil_factory, quantity_factory) |