diff --git a/conda_package/mpas_tools/mesh/creation/sort_mesh.py b/conda_package/mpas_tools/mesh/creation/sort_mesh.py new file mode 100644 index 000000000..2eb75ad79 --- /dev/null +++ b/conda_package/mpas_tools/mesh/creation/sort_mesh.py @@ -0,0 +1,241 @@ + +import numpy as np +import os +import xarray +import argparse +from scipy.sparse import csr_matrix +from scipy.sparse.csgraph import reverse_cuthill_mckee + + +def sort_mesh(mesh): + """ + Sort cells, edges and duals in the mesh + to improve cache-locality + + Parameters + ---------- + mesh : xarray.Dataset + A dataset containing an MPAS mesh to sort + + Returns + ------- + mesh : xarray.Dataset + A dataset containing the sorted MPAS mesh + """ + # Authors: Darren Engwirda + + ncells = mesh.dims["nCells"] + nedges = mesh.dims["nEdges"] + nduals = mesh.dims["nVertices"] + + cell_fwd = np.arange(0, ncells) + 1 + cell_rev = np.arange(0, ncells) + 1 + edge_fwd = np.arange(0, nedges) + 1 + edge_rev = np.arange(0, nedges) + 1 + dual_fwd = np.arange(0, nduals) + 1 + dual_rev = np.arange(0, nduals) + 1 + + # sort cells via RCM ordering of adjacency matrix + + cell_fwd = reverse_cuthill_mckee(_cell_del2(mesh)) + 1 + + cell_rev = np.zeros(ncells, dtype=np.int32) + cell_rev[cell_fwd - 1] = np.arange(ncells) + 1 + + mesh["cellsOnCell"][:] = \ + _sort_rev(mesh["cellsOnCell"], cell_rev) + mesh["cellsOnEdge"][:] = \ + _sort_rev(mesh["cellsOnEdge"], cell_rev) + mesh["cellsOnVertex"][:] = \ + _sort_rev(mesh["cellsOnVertex"], cell_rev) + + for var in mesh.keys(): + dims = mesh.variables[var].dims + if ("nCells" in dims): + mesh[var][:] = _sort_fwd(mesh[var], cell_fwd) + + mesh["indexToCellID"][:] = np.arange(ncells) + 1 + + # sort duals via pseudo-linear cell-wise ordering + + dual_fwd = np.ravel(mesh["verticesOnCell"].values) + dual_fwd = dual_fwd[dual_fwd > 0] + + __, imap = np.unique(dual_fwd, return_index=True) + + dual_fwd = dual_fwd[np.sort(imap)] + + dual_rev = np.zeros(nduals, dtype=np.int32) + dual_rev[dual_fwd - 1] = np.arange(nduals) + 1 + + mesh["verticesOnCell"][:] = \ + _sort_rev(mesh["verticesOnCell"], dual_rev) + + mesh["verticesOnEdge"][:] = \ + _sort_rev(mesh["verticesOnEdge"], dual_rev) + + for var in mesh.keys(): + dims = mesh.variables[var].dims + if ("nVertices" in dims): + mesh[var][:] = _sort_fwd(mesh[var], dual_fwd) + + mesh["indexToVertexID"][:] = np.arange(nduals) + 1 + + # sort edges via pseudo-linear cell-wise ordering + + edge_fwd = np.ravel(mesh["edgesOnCell"].values) + edge_fwd = edge_fwd[edge_fwd > 0] + + __, imap = np.unique(edge_fwd, return_index=True) + + edge_fwd = edge_fwd[np.sort(imap)] + + edge_rev = np.zeros(nedges, dtype=np.int32) + edge_rev[edge_fwd - 1] = np.arange(nedges) + 1 + + mesh["edgesOnCell"][:] = \ + _sort_rev(mesh["edgesOnCell"], edge_rev) + + mesh["edgesOnEdge"][:] = \ + _sort_rev(mesh["edgesOnEdge"], edge_rev) + + mesh["edgesOnVertex"][:] = \ + _sort_rev(mesh["edgesOnVertex"], edge_rev) + + for var in mesh.keys(): + dims = mesh.variables[var].dims + if ("nEdges" in dims): + mesh[var][:] = _sort_fwd(mesh[var], edge_fwd) + + mesh["indexToEdgeID"][:] = np.arange(nedges) + 1 + + return mesh + + +def main(): + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawTextHelpFormatter) + + parser.add_argument( + "--mesh-file", dest="mesh_file", type=str, + required=True, help="Path+name to unsorted mesh file.") + + parser.add_argument( + "--sort-file", dest="sort_file", type=str, + required=True, help="Path+name to sorted output file.") + + args = parser.parse_args() + + mesh = xarray.open_dataset(args.mesh_file) + + sort_mesh(mesh) + + with open(os.path.join(os.path.dirname( + args.sort_file), "graph.info"), "w") as fptr: + cellsOnCell = mesh["cellsOnCell"].values + + ncells = mesh.dims["nCells"] + nedges = np.count_nonzero(cellsOnCell) // 2 + + fptr.write(f"{ncells} {nedges}\n") + for cell in range(ncells): + data = cellsOnCell[cell, :] + data = data[data > 0] + for item in data: + fptr.write(f"{item} ") + fptr.write("\n") + + mesh.to_netcdf(args.sort_file, format="NETCDF4") + + +def _sort_fwd(data, fwd): + """ + Apply a forward permutation to a mesh array + + Parameters + ---------- + data : array-like + An MPAS mesh array to permute + + fwd : numpy.ndarray + An array of integers defining the permutation + + Returns + ------- + data : numpy.ndarray + The forward permuted MPAS mesh array + """ + vals = data.values + vals = vals[fwd - 1] + return vals + + +def _sort_rev(data, rev): + """ + Apply a reverse permutation to a mesh array + + Parameters + ---------- + data : array-like + An MPAS mesh array to permute + + rev : numpy.ndarray + An array of integers defining the permutation + + Returns + ------- + data : numpy.ndarray + The reverse permuted MPAS mesh array + """ + vals = data.values + mask = vals > 0 + vals[mask] = rev[vals[mask] - 1] + return vals + + +def _cell_del2(mesh): + """ + Form cell-to-cell sparse adjacency graph + + Parameters + ---------- + mesh : xarray.Dataset + A dataset containing an MPAS mesh to sort + + Returns + ------- + del2 : scipy.sparse.csr_matrix + The cell-to-cell adjacency graph as a sparse matrix + """ + xvec = np.array([], dtype=np.int8) + ivec = np.array([], dtype=np.int32) + jvec = np.array([], dtype=np.int32) + + topolOnCell = mesh["nEdgesOnCell"].values + cellsOnCell = mesh["cellsOnCell"].values + + for edge in range(np.max(topolOnCell)): + + # cell-to-cell pairs, if edges exist + mask = topolOnCell > edge + idx_self = np.argwhere(mask).ravel() + idx_next = cellsOnCell[mask, edge] - 1 + + # cell-to-cell pairs, if cells exist + mask = idx_next >= 0 + idx_self = idx_self[mask] + idx_next = idx_next[mask] + + # dummy matrix values, just topol. needed + val_edge = np.ones(idx_next.size, dtype=np.int8) + + ivec = np.hstack((ivec, idx_self)) + jvec = np.hstack((jvec, idx_next)) + xvec = np.hstack((xvec, val_edge)) + + return csr_matrix((xvec, (ivec, jvec))) + + +if (__name__ == "__main__"):\ + main() diff --git a/conda_package/recipe/meta.yaml b/conda_package/recipe/meta.yaml index 8312dceb2..572cc292c 100644 --- a/conda_package/recipe/meta.yaml +++ b/conda_package/recipe/meta.yaml @@ -33,6 +33,7 @@ build: - prepare_seaice_partitions = mpas_tools.seaice.partition:prepare_partitions - create_seaice_partitions = mpas_tools.seaice.partition:create_partitions - simple_seaice_partitions = mpas_tools.seaice.partition:simple_partitions + - sort_mesh = mpas_tools.mesh.creation.sort_mesh:main requirements: build: @@ -101,6 +102,7 @@ test: - planar_hex --nx=20 --ny=40 --dc=1000. --outFileName='periodic_mesh_20x40_1km.nc' - translate_planar_grid -f 'periodic_mesh_10x20_1km.nc' -d 'periodic_mesh_20x40_1km.nc' - MpasMeshConverter.x mesh_tools/mesh_conversion_tools/test/mesh.QU.1920km.151026.nc mesh.nc + - sort_mesh --mesh-file mesh.nc --sort-file sorted_mesh.nc - MpasCellCuller.x mesh.nc culled_mesh.nc -m mesh_tools/mesh_conversion_tools/test/land_mask_final.nc # - MpasMaskCreator.x mesh.nc arctic_mask.nc -f mesh_tools/mesh_conversion_tools/test/Arctic_Ocean.geojson - planar_hex --nx=30 --ny=20 --dc=1000. --npx --npy --outFileName='nonperiodic_mesh_30x20_1km.nc' diff --git a/conda_package/setup.py b/conda_package/setup.py index 82996002a..5506aadcb 100755 --- a/conda_package/setup.py +++ b/conda_package/setup.py @@ -89,6 +89,7 @@ 'mpas_to_triangle = mpas_tools.mesh.creation.mpas_to_triangle:main', 'triangle_to_netcdf = mpas_tools.mesh.creation.triangle_to_netcdf:main', 'jigsaw_to_netcdf = mpas_tools.mesh.creation.jigsaw_to_netcdf:main', + 'sort_mesh = mpas_tools.mesh.creation.sort_mesh:main', 'scrip_from_mpas = mpas_tools.scrip.from_mpas:main', 'compute_mpas_region_masks = mpas_tools.mesh.mask:entry_point_compute_mpas_region_masks', 'compute_mpas_transect_masks = mpas_tools.mesh.mask:entry_point_compute_mpas_transect_masks',