Skip to content

Commit

Permalink
update bounds handling (#307)
Browse files Browse the repository at this point in the history
* update bounds handling

* cleaning

* update tests

* update docstrings

* added deprecation messages
  • Loading branch information
larsbuntemeyer authored Dec 19, 2024
1 parent dcfa641 commit 196622e
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 18 deletions.
8 changes: 4 additions & 4 deletions cordex/cf.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,12 @@
"CMIP6": dict(
domain_id="domain_id",
dims={
"LAT": "latitude",
"LON": "longitude",
"LAT": "lat",
"LON": "lon",
"Y": "rlat",
"X": "rlon",
"LON_BOUNDS": "lon_vertices",
"LAT_BOUNDS": "lat_vertices",
"LON_BOUNDS": "vertices_lon",
"LAT_BOUNDS": "vertices_lat",
"BOUNDS_DIM": "vertices",
},
coords={
Expand Down
37 changes: 29 additions & 8 deletions cordex/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,13 +566,15 @@ def vertices(rlon, rlat, src_crs, trg_crs=None):
return xr.merge([lat_vertices, lon_vertices])


def rewrite_coords(ds, coords="xy", domain_id=None, mip_era="CMIP5", method="nearest"):
def rewrite_coords(
ds, coords="xy", bounds=False, domain_id=None, mip_era="CMIP5", method="nearest"
):
"""
Rewrite coordinates in a dataset to correct rounding errors.
Rewrite coordinates in a dataset.
This function is useful for ensuring that the coordinates in a dataset are consistent and
can be compared to other datasets. It can reindex the dataset based on specified coordinates
or domain information by trying to keep the original coordinate attributes.
This function ensures that the coordinates in a dataset are consistent and can be
compared to other datasets. It can reindex the dataset based on specified coordinates
or domain information while trying to keep the original coordinate attributes.
Parameters
----------
Expand All @@ -583,13 +585,19 @@ def rewrite_coords(ds, coords="xy", domain_id=None, mip_era="CMIP5", method="nea
- "xy": Rewrite only the X and Y coordinates.
- "lonlat": Rewrite only the longitude and latitude coordinates.
- "all": Rewrite both X, Y, longitude, and latitude coordinates.
Default is "xy".
Default is "xy". If longitude and latitude coordinates are not present in the dataset, they will be added.
Rewriting longitude and latitude coordinates is only possible if the dataset contains a grid mapping variable.
bounds : bool, optional
If True, the function will also handle the bounds of the coordinates. If the dataset already has bounds,
they will be updated while preserving attributes and shape. If not, the bounds will be assigned.
domain_id : str, optional
The domain identifier used to obtain grid information. If not provided, the function will attempt to use the grid mapping information from the dataset.
The domain identifier used to obtain grid information. If not provided, the function will attempt
to use the domain_id attribute from the dataset.
mip_era : str, optional
The MIP era (e.g., "CMIP5", "CMIP6") used to determine coordinate attributes. Default is "CMIP5".
Only used if the dataset does not already contain coordinate attributes.
method : str, optional
The method used for reindexing. Options include "nearest", "linear", etc. Default is "nearest".
The method used for reindexing the X and Y axis. Options include "nearest", "linear", etc. Default is "nearest".
Returns
-------
Expand Down Expand Up @@ -647,6 +655,19 @@ def rewrite_coords(ds, coords="xy", domain_id=None, mip_era="CMIP5", method="nea
ds[trg_dims[0]][:] = dst[trg_dims[0]]
ds[trg_dims[1]][:] = dst[trg_dims[1]]

if bounds is True:
# check if the dataset already has bounds
# if so, overwrite them (take care to keep attributes though)
overwrite = "longitude" in ds.cf.bounds and "latitude" in ds.cf.bounds
dst = transform_bounds(ds)
if overwrite is False:
ds = dst
else:
lon_bounds = ds.cf.bounds["longitude"]
lat_bounds = ds.cf.bounds["latitude"]
ds[lon_bounds[0]][:] = dst.cf.get_bounds("longitude")
ds[lat_bounds[0]][:] = dst.cf.get_bounds("latitude")

return ds


Expand Down
25 changes: 25 additions & 0 deletions cordex/preprocessing/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,11 @@ def replace_rlon_rlat(ds, domain=None):
Dataset with updated rlon, rlat.
"""
warn(
"replace_rlon_rlat is deprecated, please use rewrite_coords instead",
DeprecationWarning,
stacklevel=2,
)
ds = ds.copy()
if domain is None:
domain = ds.cx.domain_id
Expand Down Expand Up @@ -297,6 +302,11 @@ def replace_vertices(ds, domain=None):
Dataset with updated vertices.
"""
warn(
"replace_vertices is deprecated, please use rewrite_coords instead",
DeprecationWarning,
stacklevel=2,
)
ds = ds.copy()
if domain is None:
domain = ds.attrs.get("CORDEX_domain", None)
Expand Down Expand Up @@ -325,6 +335,11 @@ def replace_lon_lat(ds, domain=None):
Dataset with updated lon, lat.
"""
warn(
"replace_lon_lat is deprecated, please use rewrite_coords instead",
DeprecationWarning,
stacklevel=2,
)
ds = ds.copy()
if domain is None:
domain = ds.cx.domain_id
Expand Down Expand Up @@ -355,6 +370,11 @@ def replace_coords(ds, domain=None):
"""
warn(
"replace_coords is deprecated, please use rewrite_coords instead",
DeprecationWarning,
stacklevel=2,
)
ds = ds.copy()
ds = replace_rlon_rlat(ds, domain)
ds = replace_lon_lat(ds, domain)
Expand All @@ -380,6 +400,11 @@ def replace_grid(ds, domain=None):
"""
warn(
"replace_grid is deprecated, please use rewrite_coords instead",
DeprecationWarning,
stacklevel=2,
)
ds = ds.copy()
ds = replace_rlon_rlat(ds, domain)
ds = replace_lon_lat(ds, domain)
Expand Down
10 changes: 6 additions & 4 deletions cordex/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,9 @@ def transform_coords(ds, src_crs=None, trg_crs=None, trg_dims=None):
return ds.assign_coords({trg_dims[0]: xt, trg_dims[1]: yt})


def transform_bounds(ds, src_crs=None, trg_crs=None, trg_dims=None, bnds_dim=None):
def transform_bounds(
ds, src_crs=None, trg_crs=None, trg_dims=None, bnds_dim=None, keep_xy_bounds=False
):
"""Transform linear X and Y bounds of a Dataset.
Transformation of of the bounds of linear X and Y coordinates
Expand Down Expand Up @@ -315,7 +317,7 @@ def transform_bounds(ds, src_crs=None, trg_crs=None, trg_dims=None, bnds_dim=Non
if bnds_dim is None:
bnds_dim = cf.BOUNDS_DIM

bnds = ds.cf.add_bounds((ds.cf["X"].name, ds.cf["Y"].name))
bnds = ds.cf.add_bounds(("X", "Y"))
x_bnds = bnds.cf.get_bounds("X").drop(bnds.cf.bounds["X"])
y_bnds = bnds.cf.get_bounds("Y").drop(bnds.cf.bounds["Y"])

Expand All @@ -338,8 +340,8 @@ def transform_bounds(ds, src_crs=None, trg_crs=None, trg_dims=None, bnds_dim=Non
ds.cf["Y"].dims[0], ds.cf["X"].dims[0], bnds_dim
)

ds.cf["longitude"].attrs["bounds"] = cf.LON_BOUNDS
ds.cf["latitude"].attrs["bounds"] = cf.LAT_BOUNDS
ds[ds.cf["longitude"].name].attrs["bounds"] = cf.LON_BOUNDS
ds[ds.cf["latitude"].name].attrs["bounds"] = cf.LAT_BOUNDS

return ds.assign_coords(
{
Expand Down
2 changes: 1 addition & 1 deletion docs/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ rounding errors. This version drops python3.8 support.
New Features
~~~~~~~~~~~~

- New function :py:meth:`cordex.rewrite_coords` (:pull:`306`).
- New function :py:meth:`cordex.rewrite_coords` (:pull:`306`, :pull:`307`).

Breaking Changes
~~~~~~~~~~~~~~~~
Expand Down
16 changes: 16 additions & 0 deletions tests/test_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,19 @@ def test_rewrite_coords(domain_id):
np.testing.assert_array_equal(rewritten_data.lon, grid.lon)
np.testing.assert_array_equal(rewritten_data.lat, grid.lat)
xr.testing.assert_identical(rewritten_data, grid)

grid = cx.domain(domain_id, bounds=True, mip_era="CMIP6")
grid["vertices_lon"][:] = 0.0
grid["vertices_lat"][:] = 0.0
grid.vertices_lon.attrs["hello"] = "world"
grid.vertices_lat.attrs["hello"] = "world"

rewritten_data = cx.rewrite_coords(grid, bounds=True)
grid = cx.domain(domain_id, bounds=True, mip_era="CMIP6")

np.testing.assert_array_equal(rewritten_data.vertices_lon, grid.vertices_lon)
np.testing.assert_array_equal(rewritten_data.vertices_lat, grid.vertices_lat)

# check if attributes are now overwritten
assert rewritten_data.vertices_lon.attrs["hello"] == "world"
assert rewritten_data.vertices_lat.attrs["hello"] == "world"
4 changes: 3 additions & 1 deletion tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,13 @@ def test_derotate_vector():
assert np.allclose(v1, v1_expect)


@requires_cartopy
def test_bounds():
# assert that we get the same bounds as before
ds = domain("EUR-11", bounds=True)
_v = vertices(ds.rlon, ds.rlat, CRS.from_cf(ds.cf["grid_mapping"].attrs))
v = transform_bounds(ds)
np.array_equal(v.lon_vertices, _v.lon_vertices)
np.array_equal(v.lat_vertices, _v.lat_vertices)

assert "longitude" in v.cf.bounds
assert "latitude" in v.cf.bounds

0 comments on commit 196622e

Please sign in to comment.