diff --git a/CHANGELOG.md b/CHANGELOG.md index c76e7ba2af..070763bb16 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ Code freeze date: YYYY-MM-DD ### Changed +- Use Geopandas GeoDataFrame.plot() for centroids plotting function [896](https://github.com/CLIMADA-project/climada_python/pull/896) - Update SALib sensitivity and sampling methods from newest version (SALib 1.4.7) [#828](https://github.com/CLIMADA-project/climada_python/issues/828) - Allow for computation of relative and absolute delta impacts in `CalcDeltaClimate` - Remove content tables and make minor improvements (fix typos and readability) in diff --git a/climada/hazard/centroids/centr.py b/climada/hazard/centroids/centr.py index 59b8068877..b02f7569d6 100644 --- a/climada/hazard/centroids/centr.py +++ b/climada/hazard/centroids/centr.py @@ -27,7 +27,9 @@ import warnings import h5py +import cartopy import cartopy.crs as ccrs +import cartopy.feature as cfeature import geopandas as gpd import matplotlib.pyplot as plt import numpy as np @@ -38,7 +40,6 @@ from climada.util.constants import DEF_CRS import climada.util.coordinates as u_coord -import climada.util.plot as u_plot __all__ = ['Centroids'] @@ -478,44 +479,43 @@ def select_mask(self, sel_cen=None, reg_id=None, extent=None): (self.lat >= lat_min) & (self.lat <= lat_max) ) return sel_cen - - #TODO replace with nice GeoDataFrame util plot method. - def plot(self, axis=None, figsize=(9, 13), **kwargs): - """Plot centroids scatter points over earth + + def plot(self, *, axis=None, figsize=(9, 13), **kwargs): + """Plot centroids geodataframe using geopandas and cartopy plotting functions. Parameters ---------- - axis : matplotlib.axes._subplots.AxesSubplot, optional - axis to use + axis: optional + user-defined cartopy.mpl.geoaxes.GeoAxes instance figsize: (float, float), optional figure size for plt.subplots The default is (9, 13) + args : optional + positional arguments for geopandas.GeoDataFrame.plot kwargs : optional - arguments for scatter matplotlib function - + keyword arguments for geopandas.GeoDataFrame.plot + Returns ------- - axis : matplotlib.axes._subplots.AxesSubplot + ax : cartopy.mpl.geoaxes.GeoAxes instance """ - pad = np.abs(u_coord.get_resolution(self.lat, self.lon)).min() - - proj_data, _ = u_plot.get_transformation(self.crs) - proj_plot = proj_data - if isinstance(proj_data, ccrs.PlateCarree): - # use different projections for plot and data to shift the central lon in the plot - xmin, ymin, xmax, ymax = u_coord.latlon_bounds(self.lat, self.lon, buffer=pad) - proj_plot = ccrs.PlateCarree(central_longitude=0.5 * (xmin + xmax)) - else: - xmin, ymin, xmax, ymax = (self.lon.min() - pad, self.lat.min() - pad, - self.lon.max() + pad, self.lat.max() + pad) + if axis == None: + fig, axis = plt.subplots(figsize=figsize, subplot_kw={"projection": ccrs.PlateCarree()}) + if type(axis) != cartopy.mpl.geoaxes.GeoAxes: + raise AttributeError( + f"The axis provided is of type: {type(axis)} " + "The function requires a cartopy.mpl.geoaxes.GeoAxes." + ) - if not axis: - _fig, axis, _fontsize = u_plot.make_map(proj=proj_plot, figsize=figsize) + axis.add_feature(cfeature.BORDERS) + axis.add_feature(cfeature.COASTLINE) + axis.gridlines(draw_labels=True, dms=True, x_inline=False, y_inline=False) - axis.set_extent((xmin, xmax, ymin, ymax), crs=proj_data) - u_plot.add_shapes(axis) - axis.scatter(self.lon, self.lat, transform=proj_data, **kwargs) - plt.tight_layout() + if self.gdf.crs != DEF_CRS: + centroids_plot = self.to_default_crs(inplace=False) + centroids_plot.gdf.plot(ax=axis, transform=ccrs.PlateCarree(), **kwargs) + else: + self.gdf.plot(ax=axis, transform=ccrs.PlateCarree(), **kwargs) return axis def set_region_id(self, level='country', overwrite=False): diff --git a/climada/hazard/centroids/test/test_centr.py b/climada/hazard/centroids/test/test_centr.py index 1ee90ee1bf..745e544d5a 100644 --- a/climada/hazard/centroids/test/test_centr.py +++ b/climada/hazard/centroids/test/test_centr.py @@ -991,6 +991,25 @@ def test_equal_pass(self): self.assertTrue(centr1 == centr1) self.assertTrue(centr2 == centr2) + def test_plot(self): + "Test Centroids.plot()" + centr = Centroids( + lat=np.array([-5, -3, 0, 3, 5]), + lon=np.array([-180, -175, -170, 170, 175]), + region_id=np.zeros(5), + crs=DEF_CRS + ) + centr.plot() + + def test_plot_non_def_crs(self): + "Test Centroids.plot() with non-default CRS" + centr = Centroids( + lat = np.array([10.0, 20.0, 30.0]), + lon = np.array([-10.0, -20.0, -30.0]), + region_id=np.zeros(3), + crs='epsg:32632' + ) + centr.plot() # Execute Tests if __name__ == "__main__":