diff --git a/cordex/cf.py b/cordex/cf.py index 512fe67..9842188 100644 --- a/cordex/cf.py +++ b/cordex/cf.py @@ -165,12 +165,12 @@ "CMIP6": dict( domain_id="domain_id", dims={ - "LAT": "latitude", - "LON": "longitude", + "LAT": "lat", + "LON": "lon", "Y": "rlat", "X": "rlon", - "LON_BOUNDS": "lon_vertices", - "LAT_BOUNDS": "lat_vertices", + "LON_BOUNDS": "vertices_lon", + "LAT_BOUNDS": "vertices_lat", "BOUNDS_DIM": "vertices", }, coords={ diff --git a/cordex/domain.py b/cordex/domain.py index 3edb906..4c2a5bb 100644 --- a/cordex/domain.py +++ b/cordex/domain.py @@ -566,13 +566,15 @@ def vertices(rlon, rlat, src_crs, trg_crs=None): return xr.merge([lat_vertices, lon_vertices]) -def rewrite_coords(ds, coords="xy", domain_id=None, mip_era="CMIP5", method="nearest"): +def rewrite_coords( + ds, coords="xy", bounds=False, domain_id=None, mip_era="CMIP5", method="nearest" +): """ - Rewrite coordinates in a dataset to correct rounding errors. + Rewrite coordinates in a dataset. - This function is useful for ensuring that the coordinates in a dataset are consistent and - can be compared to other datasets. It can reindex the dataset based on specified coordinates - or domain information by trying to keep the original coordinate attributes. + This function ensures that the coordinates in a dataset are consistent and can be + compared to other datasets. It can reindex the dataset based on specified coordinates + or domain information while trying to keep the original coordinate attributes. Parameters ---------- @@ -583,13 +585,19 @@ def rewrite_coords(ds, coords="xy", domain_id=None, mip_era="CMIP5", method="nea - "xy": Rewrite only the X and Y coordinates. - "lonlat": Rewrite only the longitude and latitude coordinates. - "all": Rewrite both X, Y, longitude, and latitude coordinates. - Default is "xy". + Default is "xy". If longitude and latitude coordinates are not present in the dataset, they will be added. + Rewriting longitude and latitude coordinates is only possible if the dataset contains a grid mapping variable. + bounds : bool, optional + If True, the function will also handle the bounds of the coordinates. If the dataset already has bounds, + they will be updated while preserving attributes and shape. If not, the bounds will be assigned. domain_id : str, optional - The domain identifier used to obtain grid information. If not provided, the function will attempt to use the grid mapping information from the dataset. + The domain identifier used to obtain grid information. If not provided, the function will attempt + to use the domain_id attribute from the dataset. mip_era : str, optional The MIP era (e.g., "CMIP5", "CMIP6") used to determine coordinate attributes. Default is "CMIP5". + Only used if the dataset does not already contain coordinate attributes. method : str, optional - The method used for reindexing. Options include "nearest", "linear", etc. Default is "nearest". + The method used for reindexing the X and Y axis. Options include "nearest", "linear", etc. Default is "nearest". Returns ------- @@ -647,6 +655,19 @@ def rewrite_coords(ds, coords="xy", domain_id=None, mip_era="CMIP5", method="nea ds[trg_dims[0]][:] = dst[trg_dims[0]] ds[trg_dims[1]][:] = dst[trg_dims[1]] + if bounds is True: + # check if the dataset already has bounds + # if so, overwrite them (take care to keep attributes though) + overwrite = "longitude" in ds.cf.bounds and "latitude" in ds.cf.bounds + dst = transform_bounds(ds) + if overwrite is False: + ds = dst + else: + lon_bounds = ds.cf.bounds["longitude"] + lat_bounds = ds.cf.bounds["latitude"] + ds[lon_bounds[0]][:] = dst.cf.get_bounds("longitude") + ds[lat_bounds[0]][:] = dst.cf.get_bounds("latitude") + return ds diff --git a/cordex/preprocessing/preprocessing.py b/cordex/preprocessing/preprocessing.py index f1835a6..f2e12d4 100644 --- a/cordex/preprocessing/preprocessing.py +++ b/cordex/preprocessing/preprocessing.py @@ -269,6 +269,11 @@ def replace_rlon_rlat(ds, domain=None): Dataset with updated rlon, rlat. """ + warn( + "replace_rlon_rlat is deprecated, please use rewrite_coords instead", + DeprecationWarning, + stacklevel=2, + ) ds = ds.copy() if domain is None: domain = ds.cx.domain_id @@ -297,6 +302,11 @@ def replace_vertices(ds, domain=None): Dataset with updated vertices. """ + warn( + "replace_vertices is deprecated, please use rewrite_coords instead", + DeprecationWarning, + stacklevel=2, + ) ds = ds.copy() if domain is None: domain = ds.attrs.get("CORDEX_domain", None) @@ -325,6 +335,11 @@ def replace_lon_lat(ds, domain=None): Dataset with updated lon, lat. """ + warn( + "replace_lon_lat is deprecated, please use rewrite_coords instead", + DeprecationWarning, + stacklevel=2, + ) ds = ds.copy() if domain is None: domain = ds.cx.domain_id @@ -355,6 +370,11 @@ def replace_coords(ds, domain=None): """ + warn( + "replace_coords is deprecated, please use rewrite_coords instead", + DeprecationWarning, + stacklevel=2, + ) ds = ds.copy() ds = replace_rlon_rlat(ds, domain) ds = replace_lon_lat(ds, domain) @@ -380,6 +400,11 @@ def replace_grid(ds, domain=None): """ + warn( + "replace_grid is deprecated, please use rewrite_coords instead", + DeprecationWarning, + stacklevel=2, + ) ds = ds.copy() ds = replace_rlon_rlat(ds, domain) ds = replace_lon_lat(ds, domain) diff --git a/cordex/transform.py b/cordex/transform.py index b57ae94..3e1bf21 100644 --- a/cordex/transform.py +++ b/cordex/transform.py @@ -265,7 +265,9 @@ def transform_coords(ds, src_crs=None, trg_crs=None, trg_dims=None): return ds.assign_coords({trg_dims[0]: xt, trg_dims[1]: yt}) -def transform_bounds(ds, src_crs=None, trg_crs=None, trg_dims=None, bnds_dim=None): +def transform_bounds( + ds, src_crs=None, trg_crs=None, trg_dims=None, bnds_dim=None, keep_xy_bounds=False +): """Transform linear X and Y bounds of a Dataset. Transformation of of the bounds of linear X and Y coordinates @@ -315,7 +317,7 @@ def transform_bounds(ds, src_crs=None, trg_crs=None, trg_dims=None, bnds_dim=Non if bnds_dim is None: bnds_dim = cf.BOUNDS_DIM - bnds = ds.cf.add_bounds((ds.cf["X"].name, ds.cf["Y"].name)) + bnds = ds.cf.add_bounds(("X", "Y")) x_bnds = bnds.cf.get_bounds("X").drop(bnds.cf.bounds["X"]) y_bnds = bnds.cf.get_bounds("Y").drop(bnds.cf.bounds["Y"]) @@ -338,8 +340,8 @@ def transform_bounds(ds, src_crs=None, trg_crs=None, trg_dims=None, bnds_dim=Non ds.cf["Y"].dims[0], ds.cf["X"].dims[0], bnds_dim ) - ds.cf["longitude"].attrs["bounds"] = cf.LON_BOUNDS - ds.cf["latitude"].attrs["bounds"] = cf.LAT_BOUNDS + ds[ds.cf["longitude"].name].attrs["bounds"] = cf.LON_BOUNDS + ds[ds.cf["latitude"].name].attrs["bounds"] = cf.LAT_BOUNDS return ds.assign_coords( { diff --git a/docs/whats_new.rst b/docs/whats_new.rst index f092ed5..10abcc6 100644 --- a/docs/whats_new.rst +++ b/docs/whats_new.rst @@ -12,7 +12,7 @@ rounding errors. This version drops python3.8 support. New Features ~~~~~~~~~~~~ -- New function :py:meth:`cordex.rewrite_coords` (:pull:`306`). +- New function :py:meth:`cordex.rewrite_coords` (:pull:`306`, :pull:`307`). Breaking Changes ~~~~~~~~~~~~~~~~ diff --git a/tests/test_domain.py b/tests/test_domain.py index 6fb310a..9b23205 100644 --- a/tests/test_domain.py +++ b/tests/test_domain.py @@ -170,3 +170,19 @@ def test_rewrite_coords(domain_id): np.testing.assert_array_equal(rewritten_data.lon, grid.lon) np.testing.assert_array_equal(rewritten_data.lat, grid.lat) xr.testing.assert_identical(rewritten_data, grid) + + grid = cx.domain(domain_id, bounds=True, mip_era="CMIP6") + grid["vertices_lon"][:] = 0.0 + grid["vertices_lat"][:] = 0.0 + grid.vertices_lon.attrs["hello"] = "world" + grid.vertices_lat.attrs["hello"] = "world" + + rewritten_data = cx.rewrite_coords(grid, bounds=True) + grid = cx.domain(domain_id, bounds=True, mip_era="CMIP6") + + np.testing.assert_array_equal(rewritten_data.vertices_lon, grid.vertices_lon) + np.testing.assert_array_equal(rewritten_data.vertices_lat, grid.vertices_lat) + + # check if attributes are now overwritten + assert rewritten_data.vertices_lon.attrs["hello"] == "world" + assert rewritten_data.vertices_lat.attrs["hello"] == "world" diff --git a/tests/test_transform.py b/tests/test_transform.py index 8934f7d..3e87066 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -89,7 +89,6 @@ def test_derotate_vector(): assert np.allclose(v1, v1_expect) -@requires_cartopy def test_bounds(): # assert that we get the same bounds as before ds = domain("EUR-11", bounds=True) @@ -97,3 +96,6 @@ def test_bounds(): v = transform_bounds(ds) np.array_equal(v.lon_vertices, _v.lon_vertices) np.array_equal(v.lat_vertices, _v.lat_vertices) + + assert "longitude" in v.cf.bounds + assert "latitude" in v.cf.bounds