Skip to content

Commit

Permalink
Merge pull request #1208 from PCMDI/monsoon-wang-fig-update
Browse files Browse the repository at this point in the history
Monsoon wang fig update
  • Loading branch information
lee1043 authored Dec 24, 2024
2 parents 9205c7f + d398d34 commit 025d2f0
Show file tree
Hide file tree
Showing 7 changed files with 363 additions and 193 deletions.
274 changes: 174 additions & 100 deletions doc/jupyter/Demo/Demo_2a_monsoon_wang.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pcmdi_metrics/io/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def load_regions_specs() -> dict:
+========+==================+====================+
| AllMW | (-40.0, 45.0) | (0.0, 360.0) |
+--------+------------------+--------------------+
| AllM | (-45.0, 45.0) | (0.0, 360.0) |
+--------+------------------+--------------------+
| NAMM | (0.0, 45.0) | (210.0, 310.0) |
+--------+------------------+--------------------+
| SAMM | (-45.0, 0.0) | (240.0, 330.0) |
Expand Down
2 changes: 1 addition & 1 deletion pcmdi_metrics/monsoon_wang/lib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .argparse_functions import create_monsoon_wang_parser
from .monsoon_precip_index_fncs import mpd, mpi_skill_scores, regrid
from .plot import plot_monsoon_wang_maps
from .plot import map_plotter, plot_monsoon_wang_maps
198 changes: 152 additions & 46 deletions pcmdi_metrics/monsoon_wang/lib/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,76 @@
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter


def map_plotter(domain, title, ds, save_path=None):
"""
Plot a map of the domain with a given title and data from a dataset.
Parameters
----------
domain : str
Domain name of monsoon domain for the map
title : str
Title for the figure
ds : xarray.Dataset
Dataset containing the following variables:
- 'obsmap': observation map
- 'modmap': model map
- 'obsmask': observation mask (optional)
- 'modmask': model mask (optional)
- 'hitmap': hit map (optional)
- 'missmap': miss map (optional)
- 'falarmmap': false alarm map (optional)
save_path : str, optional
Path to save the figure, by default None
Returns
-------
None
"""
if domain in ["ASM"]:
central_longitude = 180
else:
central_longitude = 0

if domain in ["ASM", "AllMW", "NAFM", "NAMM", "AllM"]:
legend_loc = "upper left"
else:
legend_loc = "lower left"

if domain in ["AllMW", "AllM"]:
hit_size = 1
miss_size = 2
falarm_size = 1
else:
hit_size = 2
miss_size = 4
falarm_size = 3

plot_monsoon_wang_maps(
ds,
central_longitude=central_longitude,
title=title,
colorbar_label="Monsoon Annual Range (mm/day)",
colormap="Spectral_r",
legend_loc=legend_loc,
hit_size=hit_size,
miss_size=miss_size,
falarm_size=falarm_size,
save_path=save_path,
)


def plot_monsoon_wang_maps(
ds: xr.Dataset,
central_longitude: float = 180,
title: str = None,
fig_size: tuple = (12, 10),
fig_size: tuple = (8, 6),
colormap: str = "viridis",
colorbar_label: str = None,
levels: int = 21,
num_levels: int = 21,
legend_loc: str = "lower left",
projection: ccrs.Projection = None,
hit_color: str = "blue",
Expand Down Expand Up @@ -44,12 +104,12 @@ def plot_monsoon_wang_maps(
title : str, optional
Title for the figure, by default None
fig_size : tuple, optional
Figure size (width, height) in inches, by default (12, 10)
Figure size (width, height) in inches, by default (8, 6)
colormap : str, optional
Colormap name for the maps from matplotlib, by default 'viridis'
colorbar_label : str, optional
Label for the colorbar, by default None
levels : int, optional
num_levels : int, optional
Number of discrete color levels for the maps, by default 21
legend_loc : str, optional
Location of the legend for the markers, by default 'lower left'
Expand All @@ -76,6 +136,10 @@ def plot_monsoon_wang_maps(
plot_falarm : bool, optional
Whether to plot false alarm map, by default True
"""
ds = ds.copy(deep=True)
ds["obsmap"] = ds["obsmap"] * 86400 # Convert from kg m-2 s-1 to mm/day
ds["modmap"] = ds["modmap"] * 86400 # Convert from kg m-2 s-1 to mm/day

# Check if required variables exist in the dataset
required_vars = ["obsmap", "modmap"]
if not all(var in ds for var in required_vars):
Expand All @@ -87,8 +151,12 @@ def plot_monsoon_wang_maps(
vmin = min(ds["obsmap"].min().item(), ds["modmap"].min().item())
vmax = max(ds["obsmap"].max().item(), ds["modmap"].max().item())

# Round vmin and vmax to the nearest nice value (e.g., nearest integer)
vmin = np.floor(vmin) # Round down
vmax = np.ceil(vmax) # Round up

# Create discrete color levels
levels = np.linspace(vmin, vmax, levels)
levels = np.linspace(vmin, vmax, num_levels)
cmap = plt.get_cmap(colormap, len(levels) - 1)
norm = mcolors.BoundaryNorm(boundaries=levels, ncolors=cmap.N)

Expand All @@ -112,48 +180,18 @@ def plot_monsoon_wang_maps(
ax2 = fig.add_subplot(grid[1, 0], projection=projection)

# Plot obsmap
ds["obsmap"].plot(
ax=ax1, transform=ccrs.PlateCarree(), cmap=cmap, norm=norm, add_colorbar=False
)
ax1.coastlines()
ax1.set_title("Observation Map")
plot_map(ds, "obsmap", ax1, cmap, norm, title="Observation Map")
add_monsoon_domain(ds, "obsmask", ax1)

# Plot modmap
ds["modmap"].plot(
ax=ax2, transform=ccrs.PlateCarree(), cmap=cmap, norm=norm, add_colorbar=False
)
ax2.coastlines()
ax2.set_title("Model Map")

# Function to plot overlay if available and requested
def plot_overlay(map_name, plot_flag, marker, color, size, label):
if plot_flag and map_name in ds:
# Extract coordinates and boolean arrays for overlays
mask = ds[map_name] # Boolean array for hitmap
# Extract lat, lon values where hitmap and missmap are True
mask_coords = mask == 1 # Mask of where hitmap is True
mask_lons, mask_lats = np.meshgrid(
ds["lon"], ds["lat"]
) # Create meshgrid for coordinates
# Filter coordinates where the hitmap and missmap are True
mask_lons = mask_lons[mask_coords]
mask_lats = mask_lats[mask_coords]
# Overlay hitmap (circles) on modmap
ax2.plot(
mask_lons,
mask_lats,
marker,
color=color,
markersize=size,
transform=ccrs.PlateCarree(),
label=label,
)

# Plot overlays
plot_overlay("hitmap", plot_hit, "o", hit_color, hit_size, "Hit")
plot_overlay("missmap", plot_miss, "x", miss_color, miss_size, "Miss")
plot_map(ds, "modmap", ax2, cmap, norm, title="Model Map")
add_monsoon_domain(ds, "modmask", ax2)

# Plot overlaying markers
plot_overlay(ds, ax2, "hitmap", plot_hit, "o", hit_color, hit_size, "Hit")
plot_overlay(ds, ax2, "missmap", plot_miss, "x", miss_color, miss_size, "Miss")
plot_overlay(
"falarmmap", plot_falarm, "^", falarm_color, falarm_size, "False Alarm"
ds, ax2, "falarmmap", plot_falarm, "^", falarm_color, falarm_size, "False Alarm"
)

# Add a legend for the markers if any overlay is plotted
Expand All @@ -166,9 +204,9 @@ def plot_overlay(map_name, plot_flag, marker, color, size, label):
cm.ScalarMappable(norm=norm, cmap=cmap), cax=cbar_ax, orientation="vertical"
)
cbar.set_ticks(levels)
cbar.ax.tick_params(labelsize=8)
cbar.ax.tick_params(labelsize=10)
if colorbar_label is not None:
cbar.set_label(colorbar_label)
cbar.set_label(colorbar_label, fontsize=15)

# Add a title
if title is not None:
Expand All @@ -179,3 +217,71 @@ def plot_overlay(map_name, plot_flag, marker, color, size, label):
plt.savefig(save_path, dpi=300, bbox_inches="tight")

plt.show()


# Support functions for plot_monsoon_wang_maps


def plot_map(ds, var, ax, cmap, norm, title=None):
da = ds[var]
da.plot(
ax=ax, transform=ccrs.PlateCarree(), cmap=cmap, norm=norm, add_colorbar=False
)
ax.coastlines()
if title is not None:
ax.set_title(title)

# draw parallels and meridians by adding grid lines
gl = ax.gridlines(
draw_labels=True, crs=ccrs.PlateCarree(), linestyle="--", color="k"
)
gl.xformatter = LongitudeFormatter()
gl.yformatter = LatitudeFormatter()
gl.top_labels = False
gl.right_labels = False
gl.xlabel_style = {"size": 12}
gl.ylabel_style = {"size": 12}


def add_monsoon_domain(ds, mask_var_name, ax):
has_nan = np.isnan(ds[mask_var_name]).any()

if has_nan:
# Create a binary mask: 1 for NaN, 0 for non-NaN
binary_mask = np.isnan(ds[mask_var_name])

# Draw a contour along the boundary between NaN and non-NaN
ax.contour(
ds[mask_var_name].lon, # Replace with your longitude coordinates
ds[mask_var_name].lat, # Replace with your latitude coordinates
binary_mask,
levels=[0.5], # This draws the contour along the transition from 0 to 1
transform=ccrs.PlateCarree(),
colors="grey",
linewidths=1,
)


# Function to plot overlay if available and requested
def plot_overlay(ds, ax, map_name, plot_flag, marker, color, size, label):
if plot_flag and map_name in ds:
# Extract coordinates and boolean arrays for overlays
mask = ds[map_name] # Boolean array for hitmap
# Extract lat, lon values where hitmap and missmap are True
mask_coords = mask == 1 # Mask of where hitmap is True
mask_lons, mask_lats = np.meshgrid(
ds["lon"], ds["lat"]
) # Create meshgrid for coordinates
# Filter coordinates where the hitmap and missmap are True
mask_lons = mask_lons[mask_coords]
mask_lats = mask_lats[mask_coords]
# Overlay hitmap (circles) on modmap
ax.plot(
mask_lons,
mask_lats,
marker,
color=color,
markersize=size,
transform=ccrs.PlateCarree(),
label=label,
)
75 changes: 33 additions & 42 deletions pcmdi_metrics/monsoon_wang/monsoon_wang_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,20 @@
from pcmdi_metrics.io import da_to_ds, region_subset
from pcmdi_metrics.monsoon_wang.lib import (
create_monsoon_wang_parser,
map_plotter,
mpd,
mpi_skill_scores,
plot_monsoon_wang_maps,
regrid,
)
from pcmdi_metrics.utils import StringConstructor


def main():
P = create_monsoon_wang_parser()
args = P.get_parameter(argparse_vals_only=False)
monsoon_wang_runner(args)


def monsoon_wang_runner(args):
modpath = StringConstructor(args.test_data_path)
modpath.variable = args.modvar
Expand All @@ -42,26 +48,6 @@ def monsoon_wang_runner(args):
# Get flag for CMEC output
cmec = args.cmec

# ########################################
# PMP monthly default PR obs
fobs = xr.open_dataset(args.reference_data_path, decode_times=False)
dobs_orig = fobs[args.obsvar]
fobs.close()

# #######################################

# FCN TO COMPUTE GLOBAL ANNUAL RANGE AND MONSOON PRECIP INDEX

annrange_obs, mpi_obs = mpd(dobs_orig)

# create monsoon domain mask based on observations: annual range > 2.5 mm/day
if args.obs_mask:
domain_mask_obs = xr.where(annrange_obs > thr, 1, 0)
domain_mask_obs.name = "mask"
mpi_obs = mpi_obs.where(domain_mask_obs)
nout_mpi_obs = os.path.join(outpathdata, "mpi_obs_masked.nc")
da_to_ds(mpi_obs).to_netcdf(nout_mpi_obs)

# ########################################
# SETUP WHERE TO OUTPUT RESULTING DATA (netcdf)
nout = os.path.join(
Expand Down Expand Up @@ -99,10 +85,29 @@ def monsoon_wang_runner(args):
raise RuntimeError("No model file found!")

# ########################################
# PMP monthly default PR obs

fobs = xr.open_dataset(args.reference_data_path, decode_times=False)
dobs_orig = fobs[args.obsvar]
fobs.close()

# #######################################
# FCN TO COMPUTE GLOBAL ANNUAL RANGE AND MONSOON PRECIP INDEX

annrange_obs, mpi_obs = mpd(dobs_orig)

# create monsoon domain mask based on observations: annual range > 2.5 mm/day
if args.obs_mask:
domain_mask_obs = xr.where(annrange_obs > thr, 1, 0)
domain_mask_obs.name = "mask"
mpi_obs = mpi_obs.where(domain_mask_obs)

nout_mpi_obs = os.path.join(nout, "mpi_obs_masked.nc")
da_to_ds(mpi_obs).to_netcdf(nout_mpi_obs)

egg_pth = resources.resource_path()

doms = ["AllMW", "AllM", "NAMM", "SAMM", "NAFM", "SAFM", "ASM", "AUSM"]
doms = ["AllMW", "NAMM", "SAMM", "NAFM", "SAFM", "ASM", "AUSM"]

mpi_stats_dic = {}
for i, mod in enumerate(gmods):
Expand Down Expand Up @@ -178,31 +183,19 @@ def monsoon_wang_runner(args):
"hitmap": hitmap,
"missmap": missmap,
"falarmmap": falarmmap,
"obsmask": mpi_obs_reg,
"modmask": mpi_mod_reg,
}
)
ds_out.to_netcdf(fm)

# PLOT FIGURES
if dom in ["ASM"]:
central_longitude = 180
else:
central_longitude = 0

if dom in ["ASM", "AllMW", "NAFM", "NAMM", "AllM"]:
legend_loc = "upper left"
else:
legend_loc = "lower left"

title = f"{mod}, {dom}"

save_path = os.path.join(nout, "_".join([mod, dom, "wang-monsoon.png"]))

plot_monsoon_wang_maps(
map_plotter(
dom,
title,
ds_out,
central_longitude=central_longitude,
title=title,
colormap="Spectral_r",
legend_loc=legend_loc,
save_path=save_path,
)

Expand Down Expand Up @@ -240,6 +233,4 @@ def monsoon_wang_runner(args):


if __name__ == "__main__":
P = create_monsoon_wang_parser()
args = P.get_parameter(argparse_vals_only=False)
monsoon_wang_runner(args)
main()
Loading

0 comments on commit 025d2f0

Please sign in to comment.