Skip to content

Commit

Permalink
Major changes to workflow using multiprocessing for computing CTP fro…
Browse files Browse the repository at this point in the history
…m UFS/HR1 forecast data.
  • Loading branch information
DanielAdriaansen committed Dec 23, 2024
1 parent fc07382 commit e56cf2c
Showing 1 changed file with 49 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,35 @@ def apply_args_and_kwargs(fn, args, kwargs):
input_file = sys.argv[1]
tmpvarname = sys.argv[2]
prsvarname = sys.argv[3]
mask_file = sys.argv[4]

# Open the input_file as an Xarray Dataset
if os.path.splitext(input_file)[1]=='.nc':
ds = xr.open_dataset(input_file)
ds = ds[[tmpvarname,prsvarname,'pressfc']]
else:
print("FATAL! pyembed_ctp_fcst_HR1.py.")
print("Unable to open input file.")
sys.exit(1)

# Determine the input dims
indims = ds.sizes
if not ('grid_xt' in ds.coords) or not ('grid_yt' in ds.coords):
print("FATAL! unexpected dimension names in FCST file.")
sys.exit(1)
else:
ny = indims['grid_yt']
nx = indims['grid_xt']

# Flip the latitudes across the equator
#ds = ds.reindex(grid_yt=ds.grid_yt[::-1])

# Open the mask file
maskdata = xr.open_dataset(mask_file)

# Add the mask variable to the data
ds['maskvar'] = xr.DataArray(maskdata['RAOB_SITES'].values,dims=['grid_yt','grid_xt'],coords={'grid_yt':ds.grid_yt,'grid_xt':ds.grid_xt})

# The files that were used to develop this use case need special treatment of the pressure field.
# Find the "bk_interp" attribute
try:
Expand Down Expand Up @@ -73,79 +93,58 @@ def apply_args_and_kwargs(fn, args, kwargs):
tmp3d = tmp3d.isel(pfull=slice(0,len(z0)))

# Change the vertical coordinate and dimension for the temperature data to be z0
tmp3d = tmp3d.expand_dims(dim={'z0':z0}).assign_coords({'z0':z0}).isel(pfull=0).squeeze()
#tmp3d = tmp3d.expand_dims(dim={'z0':z0}).assign_coords({'z0':z0}).isel(pfull=0).squeeze()
#tmp3d = tmp3d.expand_dims(dim={'z0':z0}).assign_coords({'z0':z0}).squeeze()
tmp3d = tmp3d.rename({'pfull':'z0'}).assign_coords({'z0':z0})
tmp3d = tmp3d*units('degK')

# Create the 3D pressure variable
prs3d = xr.DataArray(bk_interp,dims=['z0'],coords={'z0':z0},attrs={'units':'Pa'}).broadcast_like(tmp3d)
prs3d = (prs3d*(ds['pressfc'].squeeze()))*units('Pa').to('hPa')

print(tmp3d)
print(prs3d)

# Get a pool of workers
mp = multiprocessing.Pool(multiprocessing.cpu_count()-2)

# Stack the data in the x-y dimension into a single dimension named "sid".
# This treats each grid cell/column like a "site"
tmpstack = tmp3d.stack(sid=("grid_yt","grid_xt"))
prsstack = prs3d.stack(sid=("grid_yt","grid_xt"))
mskstack = ds['maskvar'].stack(sid=("grid_yt","grid_xt"))

print("COMPUTING CTP FOR %10d CELLS." % (int(tmpstack.sizes['sid'])))

# Create a list of dictionaries equal to the number of times the function will be called.
# Each dictionary entry is another keyword argument for the function.
kwargs_iter = [{'station_index':idx} for idx in list(range(0,tmpstack.sizes['sid']))]
print("KWARGS OK")
# Create an Xarray DataArray like the stacked variables to hold the results
resstack = xr.full_like(mskstack,-9999.).rename('ctp')

# Create a iterable for each of the positional arguments for the function to be called.
# Using "zip" creates the iterable and "repeat" will repeat the item equivalent to the length
# of the kwargs iterator, which in this case is what's varying for each function call.
args_iter = zip(repeat(prsstack),repeat(tmpstack))
print("ITER OK")
# Subset the data to only the points where the mask is
prs_mask = prsstack[:,mskstack>0]
tmp_mask = tmpstack[:,mskstack>0]

# Pass the function name, the pool of workers, and the positional and keyword arg iterators
# to the multiprocessing helper function
print("CALLING STARMAP_WITH_KWARGS")
result = starmap_with_kwargs(mp,calc_ctp,args_iter,kwargs_iter)

# A pint.Quantity is returned from the function, so get the magnitude
print("COMPUTING CTP FOR %10d CELLS." % (int(tmpstack[:,mskstack>0].sizes['sid'])))
result = mp.starmap(calc_ctp,([prs_mask,tmp_mask,sidx] for sidx in list(range(0,tmp_mask.sizes['sid']))))
result = [x.m for x in result]

# Re-populate the stacked array with the values at the correct locations
resstack[mskstack>0] = result

# Put the results back into an Xarray DataArray and assign the multi-index variable from
# stacking earlier so we can unstack the data into a 2D grid
#met_data = xr.DataArray(result,dims=['sid'],coords={'sid':tmpstack.sid},attrs={'units':'J/kg'}).unstack().to_netcdf('test.nc')
met_data = xr.DataArray(result,dims=['sid'],coords={'sid':tmpstack.sid},attrs={'units':'J/kg'}).unstack()
print(met_data)
exit()

#
print(result)
exit()

# Compute the CTP over all grid cells
result = mp.starmap(calc_ctp,[(prsvar,stack,idx) for idx in list(range(0,stack.sizes['sid']))])
result = [x.m for x in result]
#print(result)
result2 = xr.DataArray(result,dims=['sid'],coords={'sid':stack.sid},attrs={})
#print(result2)
print(result2.unstack())
exit()

# Unpack the result at each grid cell back to the 2D grid
met_data = unpack_results(result,nx,ny)
print(met_data)

exit()
#met_data = xr.DataArray(result,dims=['sid'],coords={'sid':tmpstack.sid},attrs={'units':'J/kg'}).unstack()

met_data = ds['soilw'].isel(depthBelowLandLayer=0).values
print(met_data)
# Unstack the data from the `sid` dimension back to just grid_xt and grid_yt (2D) and obtain the NumPy N-D array
met_data = resstack.unstack()
met_data = met_data.reindex(grid_yt=met_data.grid_yt[::-1])
met_data.to_netcdf('test.nc')
met_data = met_data.values

grid_attrs = {}
grid_attrs['type'] = 'LatLon'
grid_attrs['type'] = 'Gaussian'
grid_attrs['name'] = 'HR1'
grid_attrs['lat_ll'] = -89.910324
grid_attrs['lon_ll'] = 0.0
grid_attrs['delta_lat'] = 0.117188
grid_attrs['delta_lon'] = 0.117188
grid_attrs['Nlat'] = 1536
grid_attrs['Nlon'] = 3072
grid_attrs['lon_zero'] = 0.0
grid_attrs['nx'] = nx
grid_attrs['ny'] = ny

attrs = {}
attrs['valid'] = '20200805_120000'
Expand All @@ -156,5 +155,5 @@ def apply_args_and_kwargs(fn, args, kwargs):
attrs['long_name'] = 'long_test'
attrs['level'] = 'surface'
attrs['units'] = 'test'
attrs['fill_value'] = -9999.
attrs['grid'] = grid_attrs
#attrs['grid'] = "G040"

0 comments on commit e56cf2c

Please sign in to comment.