Skip to content

Commit

Permalink
Clean-up mesh sorting
Browse files Browse the repository at this point in the history
  • Loading branch information
dengwirda committed Aug 4, 2023
1 parent 2956e63 commit 42cd356
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 66 deletions.
186 changes: 120 additions & 66 deletions conda_package/mpas_tools/mesh/creation/sort_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,20 @@
from scipy.sparse.csgraph import reverse_cuthill_mckee


def sort_fwd(data, fwd):
vals = data.values
vals = vals[fwd - 1]
return vals


def sort_rev(data, rev):
vals = data.values
mask = vals > 0
vals[mask] = rev[vals[mask] - 1]
return vals


def sort_mesh(mesh):
"""
SORT-MESH: sort cells, edges and duals in the mesh
to improve cache-locality.
Sort cells, edges and duals in the mesh
to improve cache-locality
Parameters
----------
mesh : xarray.Dataset
A dataset containing an MPAS mesh to sort
Returns
-------
mesh : xarray.Dataset
A dataset containing the sorted MPAS mesh
"""
# Authors: Darren Engwirda

Expand All @@ -41,22 +37,22 @@ def sort_mesh(mesh):

# sort cells via RCM ordering of adjacency matrix

cell_fwd = reverse_cuthill_mckee(cell_del2(mesh)) + 1
cell_fwd = reverse_cuthill_mckee(_cell_del2(mesh)) + 1

cell_rev = np.zeros(ncells, dtype=np.int32)
cell_rev[cell_fwd - 1] = np.arange(ncells) + 1

mesh["cellsOnCell"][:] = \
sort_rev(mesh["cellsOnCell"], cell_rev)
_sort_rev(mesh["cellsOnCell"], cell_rev)
mesh["cellsOnEdge"][:] = \
sort_rev(mesh["cellsOnEdge"], cell_rev)
_sort_rev(mesh["cellsOnEdge"], cell_rev)
mesh["cellsOnVertex"][:] = \
sort_rev(mesh["cellsOnVertex"], cell_rev)
_sort_rev(mesh["cellsOnVertex"], cell_rev)

for var in mesh.keys():
dims = mesh.variables[var].dims
if ("nCells" in dims):
mesh[var][:] = sort_fwd(mesh[var], cell_fwd)
mesh[var][:] = _sort_fwd(mesh[var], cell_fwd)

mesh["indexToCellID"][:] = np.arange(ncells) + 1

Expand All @@ -73,15 +69,15 @@ def sort_mesh(mesh):
dual_rev[dual_fwd - 1] = np.arange(nduals) + 1

mesh["verticesOnCell"][:] = \
sort_rev(mesh["verticesOnCell"], dual_rev)
_sort_rev(mesh["verticesOnCell"], dual_rev)

mesh["verticesOnEdge"][:] = \
sort_rev(mesh["verticesOnEdge"], dual_rev)
_sort_rev(mesh["verticesOnEdge"], dual_rev)

for var in mesh.keys():
dims = mesh.variables[var].dims
if ("nVertices" in dims):
mesh[var][:] = sort_fwd(mesh[var], dual_fwd)
mesh[var][:] = _sort_fwd(mesh[var], dual_fwd)

mesh["indexToVertexID"][:] = np.arange(nduals) + 1

Expand All @@ -98,29 +94,120 @@ def sort_mesh(mesh):
edge_rev[edge_fwd - 1] = np.arange(nedges) + 1

mesh["edgesOnCell"][:] = \
sort_rev(mesh["edgesOnCell"], edge_rev)
_sort_rev(mesh["edgesOnCell"], edge_rev)

mesh["edgesOnEdge"][:] = \
sort_rev(mesh["edgesOnEdge"], edge_rev)
_sort_rev(mesh["edgesOnEdge"], edge_rev)

mesh["edgesOnVertex"][:] = \
sort_rev(mesh["edgesOnVertex"], edge_rev)
_sort_rev(mesh["edgesOnVertex"], edge_rev)

for var in mesh.keys():
dims = mesh.variables[var].dims
if ("nEdges" in dims):
mesh[var][:] = sort_fwd(mesh[var], edge_fwd)
mesh[var][:] = _sort_fwd(mesh[var], edge_fwd)

mesh["indexToEdgeID"][:] = np.arange(nedges) + 1

return mesh


def main():
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawTextHelpFormatter)

parser.add_argument(
"--mesh-file", dest="mesh_file", type=str,
required=True, help="Path+name to unsorted mesh file.")

parser.add_argument(
"--sort-file", dest="sort_file", type=str,
required=True, help="Path+name to sorted output file.")

args = parser.parse_args()

mesh = xarray.open_dataset(args.mesh_file)

sort_mesh(mesh)

with open(os.path.join(os.path.dirname(
args.sort_file), "graph.info"), "w") as fptr:
cellsOnCell = mesh["cellsOnCell"].values

ncells = mesh.dims["nCells"]
nedges = np.count_nonzero(cellsOnCell) // 2

fptr.write(f"{ncells} {nedges}\n")
for cell in range(ncells):
data = cellsOnCell[cell, :]
data = data[data > 0]
for item in data:
fptr.write(f"{item} ")
fptr.write("\n")

mesh.to_netcdf(args.sort_file, format="NETCDF4")


def _sort_fwd(data, fwd):
"""
Apply a forward permutation to a mesh array
Parameters
----------
data : array-like
An MPAS mesh array to permute
fwd : numpy.ndarray
An array of integers defining the permutation
Returns
-------
data : numpy.ndarray
The forward permuted MPAS mesh array
"""
vals = data.values
vals = vals[fwd - 1]
return vals


def cell_del2(mesh):
def _sort_rev(data, rev):
"""
CELL-DEL2: form cell-to-cell sparse adjacency graph
Apply a reverse permutation to a mesh array
Parameters
----------
data : array-like
An MPAS mesh array to permute
rev : numpy.ndarray
An array of integers defining the permutation
Returns
-------
data : numpy.ndarray
The reverse permuted MPAS mesh array
"""
vals = data.values
mask = vals > 0
vals[mask] = rev[vals[mask] - 1]
return vals


def _cell_del2(mesh):
"""
Form cell-to-cell sparse adjacency graph
Parameters
----------
mesh : xarray.Dataset
A dataset containing an MPAS mesh to sort
Returns
-------
del2 : scipy.sparse.csr_matrix
The cell-to-cell adjacency graph as a sparse matrix
"""
xvec = np.array([], dtype=np.int8)
ivec = np.array([], dtype=np.int32)
jvec = np.array([], dtype=np.int32)
Expand Down Expand Up @@ -149,39 +236,6 @@ def cell_del2(mesh):

return csr_matrix((xvec, (ivec, jvec)))


if (__name__ == "__main__"):
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawTextHelpFormatter)

parser.add_argument(
"--mesh-file", dest="mesh_file", type=str,
required=True, help="Path+name to unsorted mesh file.")

parser.add_argument(
"--sort-file", dest="sort_file", type=str,
required=True, help="Path+name to sorted output file.")

args = parser.parse_args()

mesh = xarray.open_dataset(args.mesh_file)

sort_mesh(mesh)

with open(os.path.join(os.path.dirname(
args.sort_file), "graph.info"), "w") as fptr:
cellsOnCell = mesh["cellsOnCell"].values

ncells = mesh.dims["nCells"]
nedges = np.count_nonzero(cellsOnCell) // 2

fptr.write("{} {}\n".format(ncells, nedges))
for cell in range(ncells):
data = cellsOnCell[cell, :]
data = data[data > 0]
for item in data:
fptr.write("{} ".format(item))
fptr.write("\n")

mesh.to_netcdf(args.sort_file, format="NETCDF4")

if (__name__ == "__main__"):\
main()
2 changes: 2 additions & 0 deletions conda_package/recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ build:
- prepare_seaice_partitions = mpas_tools.seaice.partition:prepare_partitions
- create_seaice_partitions = mpas_tools.seaice.partition:create_partitions
- simple_seaice_partitions = mpas_tools.seaice.partition:simple_partitions
- sort_mesh = mpas_tools.mesh.creation.sort_mesh:main

requirements:
build:
Expand Down Expand Up @@ -101,6 +102,7 @@ test:
- planar_hex --nx=20 --ny=40 --dc=1000. --outFileName='periodic_mesh_20x40_1km.nc'
- translate_planar_grid -f 'periodic_mesh_10x20_1km.nc' -d 'periodic_mesh_20x40_1km.nc'
- MpasMeshConverter.x mesh_tools/mesh_conversion_tools/test/mesh.QU.1920km.151026.nc mesh.nc
- sort_mesh --mesh-file mesh.nc --sort-file sorted_mesh.nc
- MpasCellCuller.x mesh.nc culled_mesh.nc -m mesh_tools/mesh_conversion_tools/test/land_mask_final.nc
- MpasMaskCreator.x mesh.nc arctic_mask.nc -f mesh_tools/mesh_conversion_tools/test/Arctic_Ocean.geojson
- planar_hex --nx=30 --ny=20 --dc=1000. --npx --npy --outFileName='nonperiodic_mesh_30x20_1km.nc'
Expand Down

0 comments on commit 42cd356

Please sign in to comment.