Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Longitude normalization #160

Merged
merged 13 commits into from
Nov 21, 2024
157 changes: 78 additions & 79 deletions monet/monet_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,13 @@ def _monet_to_latlon(da):
return dset


def _dataset_to_monet(dset, lat_name="latitude", lon_name="longitude", latlon2d=False):
def _dataset_to_monet(
dset,
lat_name="latitude",
lon_name="longitude",
latlon2d=None,
lon180=None,
):
"""Rename xarray DataArray or Dataset coordinate variables for use with monet functions,
returning a new xarray object.

Expand All @@ -74,73 +80,68 @@ def _dataset_to_monet(dset, lat_name="latitude", lon_name="longitude", latlon2d=
Name of the latitude array.
lon_name : str
Name of the longitude array.
latlon2d : bool
latlon2d : bool, optional
Whether the latitude and longitude data is two-dimensional.
If unset (``None``), this is guessed from the number of dimensions of the latitude coordinate.
lon180 : bool, optional
Whether the longitude values are in the range [-180, 180) already.
If true, the longitude wrapping/normalization step,
which can introduce small floating-point errors, is skipped.
If unset (``None``), the longitude min/max are computed to determine this.
"""
if "grid_xt" in dset.dims:
# GFS v16 file
try:
if isinstance(dset, xr.DataArray):
dset = _dataarray_coards_to_netcdf(dset, lat_name="grid_yt", lon_name="grid_xt")
elif isinstance(dset, xr.Dataset):
dset = _dataarray_coards_to_netcdf(dset, lat_name="grid_yt", lon_name="grid_xt")
else:
raise ValueError
except ValueError:
print("dset must be an xarray.DataArray or xarray.Dataset")
if not isinstance(dset, (xr.DataArray, xr.Dataset)):
raise TypeError("dset must be an xarray.DataArray or xarray.Dataset")

if "grid_xt" in dset.dims: # GFS v16 file
if isinstance(dset, xr.DataArray):
dset = _dataarray_coards_to_netcdf(dset, lat_name="grid_yt", lon_name="grid_xt")
elif isinstance(dset, xr.Dataset):
dset = _dataarray_coards_to_netcdf(dset, lat_name="grid_yt", lon_name="grid_xt")

if "south_north" in dset.dims: # WRF WPS file
dset = dset.rename(dict(south_north="y", west_east="x"))
try:
if isinstance(dset, xr.Dataset):
if "XLAT_M" in dset.data_vars:
dset["XLAT_M"] = dset.XLAT_M.squeeze()
dset["XLONG_M"] = dset.XLONG_M.squeeze()
dset = dset.set_coords(["XLAT_M", "XLONG_M"])
elif "XLAT" in dset.data_vars:
dset["XLAT"] = dset.XLAT.squeeze()
dset["XLONG"] = dset.XLONG.squeeze()
dset = dset.set_coords(["XLAT", "XLONG"])
elif isinstance(dset, xr.DataArray):
if "XLAT_M" in dset.coords:
dset["XLAT_M"] = dset.XLAT_M.squeeze()
dset["XLONG_M"] = dset.XLONG_M.squeeze()
elif "XLAT" in dset.coords:
dset["XLAT"] = dset.XLAT.squeeze()
dset["XLONG"] = dset.XLONG.squeeze()
else:
raise ValueError
except ValueError:
print("dset must be an Xarray.DataArray or Xarray.Dataset")
if isinstance(dset, xr.Dataset):
if "XLAT_M" in dset.data_vars:
dset["XLAT_M"] = dset.XLAT_M.squeeze()
dset["XLONG_M"] = dset.XLONG_M.squeeze()
dset = dset.set_coords(["XLAT_M", "XLONG_M"])
elif "XLAT" in dset.data_vars:
dset["XLAT"] = dset.XLAT.squeeze()
dset["XLONG"] = dset.XLONG.squeeze()
dset = dset.set_coords(["XLAT", "XLONG"])
elif isinstance(dset, xr.DataArray):
if "XLAT_M" in dset.coords:
dset["XLAT_M"] = dset.XLAT_M.squeeze()
dset["XLONG_M"] = dset.XLONG_M.squeeze()
elif "XLAT" in dset.coords:
dset["XLAT"] = dset.XLAT.squeeze()
dset["XLONG"] = dset.XLONG.squeeze()

# Rename lat/lon coordinates to 'latitude'/'longitude'
dset = _rename_to_monet_latlon(dset) # common cases
if (isinstance(dset, xr.Dataset) and not {"latitude", "longitude"} <= set(dset.variables)) or (
isinstance(dset, xr.DataArray) and not {"latitude", "longitude"} <= set(dset.coords)
):
dset = dset.rename({lat_name: "latitude", lon_name: "longitude"})

# Unstructured Grid
# lat & lon are not coordinate variables in unstructured grid
if dset.attrs.get("mio_has_unstructured_grid", False):
# only call rename and wrap_longitudes
dset = _rename_to_monet_latlon(dset)
# Maybe wrap longitudes
if lon180 is None:
lon180 = dset["longitude"].min() >= -180 and dset["longitude"].max() < 180
if not lon180:
dset["longitude"] = wrap_longitudes(dset["longitude"])

else:
dset = _rename_to_monet_latlon(dset)
latlon2d = True
# print(len(dset[lat_name].shape))
# print(dset)
if len(dset[lat_name].shape) < 2:
# print(dset[lat_name].shape)
latlon2d = False
if latlon2d is False:
try:
if isinstance(dset, xr.DataArray):
dset = _dataarray_coards_to_netcdf(dset, lat_name=lat_name, lon_name=lon_name)
elif isinstance(dset, xr.Dataset):
dset = _coards_to_netcdf(dset, lat_name=lat_name, lon_name=lon_name)
else:
raise ValueError
except ValueError:
print("dset must be an Xarray.DataArray or Xarray.Dataset")
else:
dset = _rename_to_monet_latlon(dset)
dset["longitude"] = wrap_longitudes(dset["longitude"])
# lat & lon are not coordinate variables in unstructured grid, so we're done
if dset.attrs.get("mio_has_unstructured_grid", False):
return dset

# Maybe convert 1-D lat/lon coords to 2-D
if latlon2d is None:
latlon2d = dset["latitude"].ndim >= 2
if not latlon2d:
if isinstance(dset, xr.DataArray):
dset = _dataarray_coards_to_netcdf(dset, lat_name="latitude", lon_name="longitude")
elif isinstance(dset, xr.Dataset):
dset = _coards_to_netcdf(dset, lat_name="latitude", lon_name="longitude")

return dset

Expand Down Expand Up @@ -171,7 +172,7 @@ def _rename_to_monet_latlon(ds):
elif "XLAT" in check_list:
return ds.rename({"XLAT": "latitude", "XLONG": "longitude"})
else:
return ds
return ds.copy()


def _coards_to_netcdf(dset, lat_name="lat", lon_name="lon"):
Expand All @@ -189,7 +190,7 @@ def _coards_to_netcdf(dset, lat_name="lat", lon_name="lon"):
"""
from numpy import arange, meshgrid

lon = wrap_longitudes(dset[lon_name])
lon = dset[lon_name]
lat = dset[lat_name]
lons, lats = meshgrid(lon, lat)
x = arange(len(lon))
Expand Down Expand Up @@ -218,7 +219,7 @@ def _dataarray_coards_to_netcdf(dset, lat_name="lat", lon_name="lon"):
"""
from numpy import arange, meshgrid

lon = wrap_longitudes(dset[lon_name])
lon = dset[lon_name]
lat = dset[lat_name]
lons, lats = meshgrid(lon, lat)
x = arange(len(lon))
Expand Down Expand Up @@ -1191,7 +1192,7 @@ def _get_CoordinateDefinition(self, data=None):
g = geo.CoordinateDefinition(lats=self._obj.latitude, lons=self._obj.longitude)
return g

def remap_nearest(self, data, **kwargs):
def remap_nearest(self, data, radius_of_influence=1e6, **kwargs):
"""Remap `data` from another grid to the current self grid using pyresample
nearest-neighbor interpolation.

Expand All @@ -1213,16 +1214,20 @@ def remap_nearest(self, data, **kwargs):

# from .grids import get_generic_projection_from_proj4
# check to see if grid is supplied

source_data = _dataset_to_monet(data)
target_data = _dataset_to_monet(self._obj)
source = self._get_CoordinateDefinition(data=source_data)
target = self._get_CoordinateDefinition(data=target_data)
r = kd_tree.XArrayResamplerNN(source, target, **kwargs)
source = self._get_CoordinateDefinition(source_data)
target = self._get_CoordinateDefinition(target_data)
r = kd_tree.XArrayResamplerNN(
source, target, radius_of_influence=radius_of_influence, **kwargs
)
r.get_neighbour_info()
if isinstance(source_data, xr.DataArray):
result = r.get_sample_from_neighbour_info(source_data)
result.name = source_data.name
result["latitude"] = target_data.latitude
result["longitude"] = target_data.longitude

elif isinstance(source_data, xr.Dataset):
results = {}
Expand Down Expand Up @@ -1504,7 +1509,7 @@ def _get_CoordinateDefinition(self, data=None):
g = geo.CoordinateDefinition(lats=self._obj.latitude, lons=self._obj.longitude)
return g

def remap_nearest(self, data, radius_of_influence=1e6):
def remap_nearest(self, data, radius_of_influence=1e6, **kwargs):
"""Remap `data` from another grid to the current self grid using pyresample
nearest-neighbor interpolation.

Expand All @@ -1525,26 +1530,20 @@ def remap_nearest(self, data, radius_of_influence=1e6):

# from .grids import get_generic_projection_from_proj4
# check to see if grid is supplied
try:
check_error = False
if isinstance(data, xr.DataArray) or isinstance(data, xr.Dataset):
check_error = False
else:
check_error = True
if check_error:
raise TypeError
except TypeError:
print("data must be either an Xarray.DataArray or Xarray.Dataset")

source_data = _dataset_to_monet(data)
target_data = _dataset_to_monet(self._obj)
source = self._get_CoordinateDefinition(source_data)
target = self._get_CoordinateDefinition(target_data)
r = kd_tree.XArrayResamplerNN(source, target, radius_of_influence=radius_of_influence)
r = kd_tree.XArrayResamplerNN(
source, target, radius_of_influence=radius_of_influence, **kwargs
)
r.get_neighbour_info()
if isinstance(source_data, xr.DataArray):
result = r.get_sample_from_neighbour_info(source_data)
result.name = source_data.name
result["latitude"] = target_data.latitude
result["longitude"] = target_data.longitude

elif isinstance(source_data, xr.Dataset):
results = {}
Expand Down
11 changes: 6 additions & 5 deletions monet/util/combinetool.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def combine_da_to_da(source, target, *, merge=True, interp_time=False, **kwargs)
----------
source : xarray.DataArray or xarray.Dataset
Gridded data.
target : xarray.DataArray
target : xarray.DataArray or xarray.Dataset
Point observations.
merge : bool
If false, only return the interpolated source data.
Expand All @@ -87,13 +87,14 @@ def combine_da_to_da(source, target, *, merge=True, interp_time=False, **kwargs)
"""
from ..monet_accessor import _dataset_to_monet

target_fixed = _dataset_to_monet(target)
source_fixed = _dataset_to_monet(source)
output = target_fixed.monet.remap_nearest(source_fixed, **kwargs)
output = target.monet.remap_nearest(source, **kwargs)

if interp_time:
output = output.interp(time=target.time)

if merge:
output = xr.merge([target_fixed, output])
output = xr.merge([_dataset_to_monet(target), output])

return output


Expand Down
11 changes: 9 additions & 2 deletions tests/test_remap.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ def test_combine_da_da():
},
)

# Longitude normalization introduces floating point error
x_ = (x + 180) % 360 - 180
assert not (x_ == x).any()
assert np.abs(x_ - x).max() < 5e-14

# Combine (find closest model grid cell to each obs point)
# NOTE: to use `merge`, must have matching `level` dims
new = combine_da_to_da(model, obs, merge=False, interp_time=False)
Expand All @@ -100,8 +105,10 @@ def test_combine_da_da():
assert float(new.longitude.max()) == pytest.approx(0.9)
assert float(new.latitude.min()) == pytest.approx(0.1)
assert float(new.latitude.max()) == pytest.approx(0.9)
assert (new.latitude.isel(x=0).values == obs.latitude.values).all()
assert np.allclose(new.longitude.isel(y=0).values, obs.longitude.values)

assert (obs.longitude.values == x).all(), "preserved"
assert (new.latitude.isel(x=0).values == obs.latitude.values).all(), "same as target"
assert (new.longitude.isel(y=0).values == obs.longitude.values).all(), "same as target"

# Use orthogonal selection to get track
a = new.data.values[:, new.y, new.x]
Expand Down