Skip to content

Commit

Permalink
Switch to mosaic for planar plotting.
Browse files Browse the repository at this point in the history
- Interface to the `plot_horiz_field` has been simplified slightly.

- The mask array now must be the same shape as the input field array.

- Added a framework function to convert a cell mask to and edge mask.
  • Loading branch information
andrewdnolan committed Dec 24, 2024
1 parent d6bc3a1 commit 260b8cf
Show file tree
Hide file tree
Showing 10 changed files with 148 additions and 268 deletions.
1 change: 1 addition & 0 deletions polaris/mpas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from polaris.mpas.area import area_for_field
from polaris.mpas.mask import cell_mask_2_edge_mask
from polaris.mpas.time import time_index_from_xtime, time_since_start
35 changes: 35 additions & 0 deletions polaris/mpas/mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@

def cell_mask_2_edge_mask(ds_mesh, cell_mask):
"""Convert a cell mask to edge mask using mesh connectivity information
True corresponds to valid cells and False are invalid cells
Parameters
----------
ds_mesh : xarray.Dataset
The MPAS mesh
cell_mask : xarray.DataArray
The cell mask we want to convert to an edge mask
Returns
-------
edge_mask : xarray.DataArray
The edge mask corresponding to the input cell mask
"""

# test if any are False
if ~cell_mask.any():
return ds_mesh.nEdges > -1

# zero index the connectivity array
cellsOnEdge = (ds_mesh.cellsOnEdge - 1)

# using nCells (dim) instead of indexToCellID since it's already 0 indexed
masked_cells = ds_mesh.nCells.where(~cell_mask, drop=True).astype(int)

# use inverse so True/False convention matches input cell_mask
edge_mask = ~cellsOnEdge.isin(masked_cells).any("TWO")

return edge_mask
10 changes: 6 additions & 4 deletions polaris/ocean/tasks/baroclinic_channel/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from polaris import Step
from polaris.mesh.planar import compute_planar_hex_nx_ny
from polaris.mpas import cell_mask_2_edge_mask
from polaris.ocean.vertical import init_vertical_coord
from polaris.ocean.viz import compute_transect, plot_transect
from polaris.viz import plot_horiz_field
Expand Down Expand Up @@ -163,10 +164,11 @@ def run(self):
write_netcdf(ds, 'initial_state.nc')

cell_mask = ds.maxLevelCell >= 1
edge_mask = cell_mask_2_edge_mask(ds, cell_mask)

plot_horiz_field(ds, ds_mesh, 'normalVelocity',
plot_horiz_field(ds_mesh, ds['normalVelocity'],
'initial_normal_velocity.png', cmap='cmo.balance',
show_patch_edges=True, cell_mask=cell_mask)
show_patch_edges=True, field_mask=edge_mask)

y_min = ds_mesh.yVertex.min().values
y_max = ds_mesh.yVertex.max().values
Expand All @@ -191,6 +193,6 @@ def run(self):
vmin=vmin, vmax=vmax, cmap='cmo.thermal',
colorbar_label=r'$^\circ$C', color_start_and_end=True)

plot_horiz_field(ds, ds_mesh, 'temperature', 'initial_temperature.png',
plot_horiz_field(ds_mesh, ds['temperature'], 'initial_temperature.png',
vmin=vmin, vmax=vmax, cmap='cmo.thermal',
cell_mask=cell_mask, transect_x=x, transect_y=y)
field_mask=cell_mask, transect_x=x, transect_y=y)
4 changes: 2 additions & 2 deletions polaris/ocean/tasks/baroclinic_channel/rpe/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ def run(self):
time_index = np.argmin(np.abs(times - time))

cell_mask = ds_init.maxLevelCell >= 1
plot_horiz_field(ds, ds_mesh, 'temperature', ax=ax,
plot_horiz_field(ds_mesh, ds['temperature'], ax=ax,
cmap='cmo.thermal', t_index=time_index,
vmin=min_temp, vmax=max_temp,
cmap_title='SST (C)', cell_mask=cell_mask)
cmap_title='SST (C)', field_mask=cell_mask)
ax.set_title(f'day {times[time_index]:g}, $\\nu_h=${nu:g}')

plt.savefig(output_filename)
12 changes: 7 additions & 5 deletions polaris/ocean/tasks/baroclinic_channel/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import xarray as xr

from polaris import Step
from polaris.mpas import cell_mask_2_edge_mask
from polaris.ocean.viz import compute_transect, plot_transect
from polaris.viz import plot_horiz_field

Expand Down Expand Up @@ -43,13 +44,14 @@ def run(self):
ds = xr.load_dataset('output.nc')
t_index = ds.sizes['Time'] - 1
cell_mask = ds_init.maxLevelCell >= 1
edge_mask = cell_mask_2_edge_mask(ds_init, cell_mask)
max_velocity = np.max(np.abs(ds.normalVelocity.values))
plot_horiz_field(ds, ds_mesh, 'normalVelocity',
plot_horiz_field(ds_mesh, ds['normalVelocity'],
'final_normalVelocity.png',
t_index=t_index,
vmin=-max_velocity, vmax=max_velocity,
cmap='cmo.balance', show_patch_edges=True,
cell_mask=cell_mask)
field_mask=edge_mask)

y_min = ds_mesh.yVertex.min().values
y_max = ds_mesh.yVertex.max().values
Expand All @@ -76,7 +78,7 @@ def run(self):
vmin=vmin, vmax=vmax, cmap='cmo.thermal',
colorbar_label=r'$^\circ$C', color_start_and_end=True)

plot_horiz_field(ds, ds_mesh, 'temperature', 'final_temperature.png',
plot_horiz_field(ds_mesh, ds['temperature'], 'final_temperature.png',
t_index=t_index, vmin=vmin, vmax=vmax,
cmap='cmo.thermal', cell_mask=cell_mask, transect_x=x,
transect_y=y)
cmap='cmo.thermal', field_mask=cell_mask,
transect_x=x, transect_y=y)
4 changes: 2 additions & 2 deletions polaris/ocean/tasks/barotropic_gyre/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def run(self):

cell_mask = ds.maxLevelCell >= 1

plot_horiz_field(ds_forcing, ds_mesh, 'windStressZonal',
plot_horiz_field(ds_mesh, ds_forcing['windStressZonal'],
'forcing_wind_stress_zonal.png', cmap='cmo.balance',
show_patch_edges=True, cell_mask=cell_mask,
show_patch_edges=True, field_mask=cell_mask,
vmin=-0.1, vmax=0.1)
42 changes: 22 additions & 20 deletions polaris/ocean/tasks/ice_shelf_2d/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import xarray as xr

from polaris import Step
from polaris.mpas import cell_mask_2_edge_mask
from polaris.ocean.viz import compute_transect, plot_transect
from polaris.viz import plot_horiz_field

Expand Down Expand Up @@ -118,9 +119,10 @@ def run(self):

# Plot water column thickness horizontal ds_init
cell_mask = ds_init.maxLevelCell >= 1
plot_horiz_field(ds_horiz, ds_mesh, 'columnThickness',
edge_mask = cell_mask_2_edge_mask(ds_init, cell_mask)
plot_horiz_field(ds_mesh, ds_horiz['columnThickness'],
'H_horiz_init.png', t_index=None,
cell_mask=cell_mask)
field_mask=cell_mask)

time_index = -1 # Plot the final state
ds_horiz = self._process_ds(ds, ds_ice, ds_init,
Expand All @@ -130,27 +132,27 @@ def run(self):
vmax_del_ssh = np.max(ds_horiz.delSsh.values)
vmax_del_p = np.amax(ds_horiz.delLandIcePressure.values)
# Plot water column thickness horizontal
plot_horiz_field(ds_horiz, ds_mesh, 'columnThickness',
plot_horiz_field(ds_mesh, ds_horiz['columnThickness'],
f'H_horiz_t{time_index}.png', t_index=None,
cell_mask=cell_mask)
plot_horiz_field(ds_horiz, ds_mesh, 'landIceFreshwaterFlux',
field_mask=cell_mask)
plot_horiz_field(ds_mesh, ds_horiz['landIceFreshwaterFlux'],
f'melt_horiz_t{time_index}.png', t_index=None,
cell_mask=cell_mask)
field_mask=cell_mask)
if 'wettingVelocityFactor' in ds_horiz.keys():
plot_horiz_field(ds_horiz, ds_mesh, 'wettingVelocityFactor',
plot_horiz_field(ds_mesh, ds_horiz['wettingVelocityFactor'],
f'wet_horiz_t{time_index}.png', t_index=None,
z_index=None, cell_mask=cell_mask,
z_index=None, field_mask=edge_mask,
vmin=0, vmax=1, cmap='cmo.ice')
# Plot difference in ssh
plot_horiz_field(ds_horiz, ds_mesh, 'delSsh',
plot_horiz_field(ds_mesh, ds_horiz['delSsh'],
f'del_ssh_horiz_t{time_index}.png', t_index=None,
cell_mask=cell_mask,
field_mask=cell_mask,
vmin=vmin_del_ssh, vmax=vmax_del_ssh)

# Plot difference in land ice pressure
plot_horiz_field(ds_horiz, ds_mesh, 'delLandIcePressure',
plot_horiz_field(ds_mesh, ds_horiz['delLandIcePressure'],
f'del_land_ice_pressure_horiz_t{time_index}.png',
t_index=None, cell_mask=cell_mask,
t_index=None, field_mask=cell_mask,
vmin=-vmax_del_p, vmax=vmax_del_p,
cmap='cmo.balance')

Expand All @@ -163,24 +165,24 @@ def run(self):
max_level_cell=ds_init.maxLevelCell - 1,
spherical=False)

plot_horiz_field(ds, ds_mesh, 'velocityX',
plot_horiz_field(ds_mesh, ds['velocityX'],
f'u_surf_horiz_t{time_index}.png', t_index=time_index,
z_index=0, cell_mask=cell_mask,
z_index=0, field_mask=cell_mask,
vmin=-vmax_uv, vmax=vmax_uv,
cmap_title=r'm/s', cmap='cmo.balance')
plot_horiz_field(ds, ds_mesh, 'velocityX',
plot_horiz_field(ds_mesh, ds['velocityX'],
f'u_bot_horiz_t{time_index}.png', t_index=time_index,
z_index=-1, cell_mask=cell_mask,
z_index=-1, field_mask=cell_mask,
vmin=-vmax_uv, vmax=vmax_uv,
cmap_title=r'm/s', cmap='cmo.balance')
plot_horiz_field(ds, ds_mesh, 'velocityY',
plot_horiz_field(ds_mesh, ds['velocityY'],
f'v_surf_horiz_t{time_index}.png', t_index=time_index,
z_index=0, cell_mask=cell_mask,
z_index=0, field_mask=cell_mask,
vmin=-vmax_uv, vmax=vmax_uv,
cmap_title=r'm/s', cmap='cmo.balance')
plot_horiz_field(ds, ds_mesh, 'velocityY',
plot_horiz_field(ds_mesh, ds['velocityY'],
f'v_bot_horiz_t{time_index}.png', t_index=time_index,
z_index=-1, cell_mask=cell_mask,
z_index=-1, field_mask=cell_mask,
vmin=-vmax_uv, vmax=vmax_uv,
cmap_title=r'm/s', cmap='cmo.balance')
plot_transect(ds_transect,
Expand Down
14 changes: 7 additions & 7 deletions polaris/ocean/tasks/inertial_gravity_wave/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,19 @@ def run(self):
error_range = np.max(np.abs(ds.ssh_error.values))

cell_mask = ds_init.maxLevelCell >= 1
patches, patch_mask = plot_horiz_field(
ds, ds_mesh, 'ssh', ax=axes[i, 0], cmap='cmo.balance',
descriptor = plot_horiz_field(
ds_mesh, ds['ssh'], ax=axes[i, 0], cmap='cmo.balance',
t_index=ds.sizes["Time"] - 1, vmin=-eta0, vmax=eta0,
cmap_title="SSH (m)", cell_mask=cell_mask)
plot_horiz_field(ds, ds_mesh, 'ssh_exact', ax=axes[i, 1],
cmap_title="SSH (m)", field_mask=cell_mask)
plot_horiz_field(ds_mesh, ds['ssh_exact'], ax=axes[i, 1],
cmap='cmo.balance',
vmin=-eta0, vmax=eta0, cmap_title="SSH (m)",
patches=patches, patch_mask=patch_mask)
plot_horiz_field(ds, ds_mesh, 'ssh_error', ax=axes[i, 2],
descriptor=descriptor)
plot_horiz_field(ds_mesh, ds['ssh_error'], ax=axes[i, 2],
cmap='cmo.balance',
cmap_title=r"$\Delta$ SSH (m)",
vmin=-error_range, vmax=error_range,
patches=patches, patch_mask=patch_mask)
descriptor=descriptor)

axes[0, 0].set_title('Numerical solution')
axes[0, 1].set_title('Analytical solution')
Expand Down
12 changes: 6 additions & 6 deletions polaris/ocean/tasks/manufactured_solution/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,18 +176,18 @@ def run(self):
error_range = np.max(np.abs(ds.ssh_error.values))

cell_mask = ds_init.maxLevelCell >= 1
patches, patch_mask = plot_horiz_field(
descriptor = plot_horiz_field(
ds, ds_mesh, 'ssh', ax=axes[i, 0], cmap='cmo.balance',
t_index=ds.sizes["Time"] - 1, vmin=-eta0, vmax=eta0,
cmap_title="SSH", cell_mask=cell_mask)
plot_horiz_field(ds, ds_mesh, 'ssh_exact', ax=axes[i, 1],
cmap_title="SSH", field_mask=cell_mask)
plot_horiz_field(ds_mesh, ds['ssh_exact'], ax=axes[i, 1],
cmap='cmo.balance',
vmin=-eta0, vmax=eta0, cmap_title="SSH",
patches=patches, patch_mask=patch_mask)
plot_horiz_field(ds, ds_mesh, 'ssh_error', ax=axes[i, 2],
descriptor=descriptor)
plot_horiz_field(ds_mesh, ds['ssh_error'], ax=axes[i, 2],
cmap='cmo.balance', cmap_title="dSSH",
vmin=-error_range, vmax=error_range,
patches=patches, patch_mask=patch_mask)
descriptor=descriptor)

axes[0, 0].set_title('Numerical solution')
axes[0, 1].set_title('Analytical solution')
Expand Down
Loading

0 comments on commit 260b8cf

Please sign in to comment.