diff --git a/polaris/ocean/tasks/sphere_transport/filament_analysis.py b/polaris/ocean/tasks/sphere_transport/filament_analysis.py index bb7ff8404..a3f196439 100644 --- a/polaris/ocean/tasks/sphere_transport/filament_analysis.py +++ b/polaris/ocean/tasks/sphere_transport/filament_analysis.py @@ -1,11 +1,10 @@ -import datetime - import matplotlib.pyplot as plt import numpy as np import pandas as pd import xarray as xr from polaris import Step +from polaris.mpas import time_index_from_xtime from polaris.ocean.resolution import resolution_to_subdir from polaris.viz import use_mplstyle @@ -93,8 +92,8 @@ def run(self): for i, resolution in enumerate(resolutions): mesh_name = resolution_to_subdir(resolution) ds = xr.open_dataset(f'{mesh_name}_output.nc') - tidx = _time_index_from_xtime(ds.xtime.values, - eval_time * s_per_day) + tidx = time_index_from_xtime(ds.xtime.values, + eval_time * s_per_day) tracer = ds[variable_name] area_cell = ds["areaCell"] for j, tau in enumerate(filament_tau): @@ -122,29 +121,3 @@ def run(self): col_headers.append(f'{tau:g}') df = pd.DataFrame(data, columns=col_headers) df.to_csv('filament.csv', index=False) - - -def _time_index_from_xtime(xtime, dt_target): - """ - Determine the time index at which to evaluate convergence - - Parameters - ---------- - xtime: list of str - Times in the dataset - dt_target: float - Time in seconds at which to evaluate convergence - - Returns - ------- - tidx: int - Index in xtime that is closest to dt_target - """ - t0 = datetime.datetime.strptime(xtime[0].decode(), - '%Y-%m-%d_%H:%M:%S') - dt = np.zeros((len(xtime))) - for idx, xt in enumerate(xtime): - t = datetime.datetime.strptime(xt.decode(), - '%Y-%m-%d_%H:%M:%S') - dt[idx] = (t - t0).total_seconds() - return np.argmin(np.abs(np.subtract(dt, dt_target))) diff --git a/polaris/ocean/tasks/sphere_transport/mixing_analysis.py b/polaris/ocean/tasks/sphere_transport/mixing_analysis.py index 60b6de0ec..750518a7f 100644 --- a/polaris/ocean/tasks/sphere_transport/mixing_analysis.py +++ b/polaris/ocean/tasks/sphere_transport/mixing_analysis.py @@ -1,4 +1,3 @@ -import datetime from math import ceil import matplotlib.pyplot as plt @@ -7,6 +6,7 @@ from matplotlib.lines import Line2D from polaris import Step +from polaris.mpas import time_index_from_xtime from polaris.ocean.resolution import resolution_to_subdir from polaris.viz import use_mplstyle @@ -99,8 +99,8 @@ def run(self): ax.set_ylabel("tracer3") if int(i / 2) == nrows - 1: ax.set_xlabel("tracer2") - tidx = _time_index_from_xtime(ds.xtime.values, - eval_time * s_per_day) + tidx = time_index_from_xtime(ds.xtime.values, + eval_time * s_per_day) ds = ds.isel(Time=tidx) ds = ds.isel(nVertLevels=zidx) tracer2 = ds["tracer2"].values @@ -147,29 +147,3 @@ def _init_triplot_axes(ax): ax.text(0.5, 0.27, 'Real mixing', rotation=-40., fontsize=8) ax.text(0.02, 0.1, 'Overshooting', rotation=90., fontsize=8) ax.grid(color='lightgray') - - -def _time_index_from_xtime(xtime, dt_target): - """ - Determine the time index at which to evaluate convergence - - Parameters - ---------- - xtime: list of str - Times in the dataset - dt_target: float - Time in seconds at which to evaluate convergence - - Returns - ------- - tidx: int - Index in xtime that is closest to dt_target - """ - t0 = datetime.datetime.strptime(xtime[0].decode(), - '%Y-%m-%d_%H:%M:%S') - dt = np.zeros((len(xtime))) - for idx, xt in enumerate(xtime): - t = datetime.datetime.strptime(xt.decode(), - '%Y-%m-%d_%H:%M:%S') - dt[idx] = (t - t0).total_seconds() - return np.argmin(np.abs(np.subtract(dt, dt_target))) diff --git a/polaris/ocean/tasks/sphere_transport/viz.py b/polaris/ocean/tasks/sphere_transport/viz.py index 0ca988a1e..c943eebc5 100644 --- a/polaris/ocean/tasks/sphere_transport/viz.py +++ b/polaris/ocean/tasks/sphere_transport/viz.py @@ -1,10 +1,8 @@ -import datetime - import cmocean # noqa: F401 -import numpy as np import xarray as xr from polaris import Step +from polaris.mpas import time_index_from_xtime from polaris.remap import MappingFileStep from polaris.viz import plot_global_field @@ -149,16 +147,16 @@ def run(self): # Visualization at halfway around the globe (provided run duration is # set to the time needed to circumnavigate the globe) - tidx = _time_index_from_xtime(ds_out.xtime.values, - run_duration * seconds_per_day / 2.) + tidx = time_index_from_xtime(ds_out.xtime.values, + run_duration * seconds_per_day / 2.) ds_mid = ds_out[variables_to_plot.keys()].isel(Time=tidx, nVertLevels=0) ds_mid = remapper.remap(ds_mid) ds_mid.to_netcdf('remapped_mid.nc') # Visualization at all the way around the globe - tidx = _time_index_from_xtime(ds_out.xtime.values, - run_duration * seconds_per_day) + tidx = time_index_from_xtime(ds_out.xtime.values, + run_duration * seconds_per_day) ds_final = ds_out[variables_to_plot.keys()].isel(Time=tidx, nVertLevels=0) ds_final = remapper.remap(ds_final) @@ -192,30 +190,3 @@ def run(self): title=f'Difference in {mesh_name} {var} from ' f'initial condition after {run_duration:g} days', plot_land=False) - - -def _time_index_from_xtime(xtime, dt_target): - """ - Determine the time index at which to evaluate convergence - - Parameters - ---------- - xtime : list of str - Times in the dataset - - dt_target : float - Time in seconds at which to evaluate convergence - - Returns - ------- - tidx : int - Index in xtime that is closest to dt_target - """ - t0 = datetime.datetime.strptime(xtime[0].decode(), - '%Y-%m-%d_%H:%M:%S') - dt = np.zeros((len(xtime))) - for idx, xt in enumerate(xtime): - t = datetime.datetime.strptime(xt.decode(), - '%Y-%m-%d_%H:%M:%S') - dt[idx] = (t - t0).total_seconds() - return np.argmin(np.abs(np.subtract(dt, dt_target)))