Skip to content

Commit

Permalink
added regridding
Browse files Browse the repository at this point in the history
  • Loading branch information
larsbuntemeyer committed Jan 11, 2025
1 parent b42eb5a commit e3d0c71
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 75 deletions.
198 changes: 126 additions & 72 deletions evaltools/eval.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import numpy as np
import xarray as xr
import cordex as cx
import cf_xarray as cfxr
import xesmf as xe
from warnings import warn


def regional_mean(ds, regions=None, weights=None):
Expand Down Expand Up @@ -100,17 +104,12 @@ def height_correction(height1, height2):
return (height1 - height2) * 0.0065


def seasonal_mean(da):
"""
Calculate seasonal averages from a time series of monthly means.
Parameters:
da (xarray.DataArray): The DataArray to compute seasonal means for.
def seasonal_mean(da, skipna=True, min_count=1):
"""Calculate seasonal averages from time series of monthly means
Returns:
xarray.DataArray: The seasonal mean values.
based on: https://xarray.pydata.org/en/stable/examples/monthly-means.html
"""
# Get number of days for each month
# Get number of days for each month
month_length = da.time.dt.days_in_month
# Calculate the weights by grouping by 'time.season'.
weights = (
Expand All @@ -121,79 +120,134 @@ def seasonal_mean(da):
# np.testing.assert_allclose(weights.groupby("time.season").sum().values, np.ones(4))

# Calculate the weighted average
return (da * weights).groupby("time.season").sum(dim="time")
return (
(da * weights)
.groupby("time.season")
.sum(dim="time", skipna=skipna, min_count=min_count)
)


def add_bounds(ds):
    """Add the vertex coordinates ``lon_b`` and ``lat_b`` to a dataset.

    If cf_xarray lists bounds for neither "longitude" nor "latitude",
    ``cx.transform_bounds`` is called first with the target names
    ``vertices_lon`` and ``vertices_lat``. Afterwards both bounds variables
    are converted with ``cfxr.bounds_to_vertices`` (counterclockwise order)
    and assigned as the coordinates ``lon_b`` and ``lat_b``.

    NOTE(review): the conversion is assumed to run unconditionally, i.e. also
    for datasets that already carry bounds -- confirm. Those datasets must
    then expose ``vertices_lon`` / ``vertices_lat`` with a bounds dimension
    named "vertices".

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to extend.

    Returns
    -------
    xarray.Dataset
        The dataset with ``lon_b`` and ``lat_b`` assigned as coordinates.
    """
    known_bounds = ds.cf.bounds
    bounds_missing = (
        "longitude" not in known_bounds and "latitude" not in known_bounds
    )
    if bounds_missing:
        ds = cx.transform_bounds(ds, trg_dims=("vertices_lon", "vertices_lat"))

    # Both conversions share the same options.
    vertex_options = {"bounds_dim": "vertices", "order": "counterclockwise"}
    lon_b = cfxr.bounds_to_vertices(ds.vertices_lon, **vertex_options)
    lat_b = cfxr.bounds_to_vertices(ds.vertices_lat, **vertex_options)
    return ds.assign_coords(lon_b=lon_b, lat_b=lat_b)


def mask_with_sftlf(ds, sftlf=None):
    """Mask the data variables of a dataset with the land area fraction.

    Every data variable except "sftlf" is kept only where ``sftlf > 0``, and
    the boolean condition itself is stored in the dataset as "mask".

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to mask. It is modified in place and also returned.
    sftlf : xarray.DataArray, optional
        Land area fraction to mask with. If omitted, ``ds["sftlf"]`` is used.
        An explicitly passed ``sftlf`` takes precedence over the one stored
        in ``ds``.

    Returns
    -------
    xarray.Dataset
        The masked dataset, or the unchanged dataset (after a warning) when
        no land area fraction is available.
    """
    if sftlf is None:
        if "sftlf" not in ds:
            # Fall back to a placeholder so a missing source_id attribute does
            # not turn the warning into an AttributeError.
            source_id = getattr(ds, "source_id", "unknown")
            warn(f"sftlf not found in dataset: {source_id}")
            return ds
        sftlf = ds["sftlf"]

    is_land = sftlf > 0
    # Iterate over a snapshot of the names because variables are reassigned
    # while looping.
    for var in list(ds.data_vars):
        if var != "sftlf":
            ds[var] = ds[var].where(is_land)
    ds["mask"] = is_land
    return ds

def get_regridder(finer, coarser, method="bilinear", **kwargs):

def create_cordex_grid(domain_id):
"""
Regrid data bilinearly to a coarser grid.
Creates a CORDEX grid for the specified domain.
Parameters:
finer (xarray.Dataset): The dataset to regrid.
coarser (xarray.Dataset): The target grid dataset.
method (str, optional): The regridding method to use. Defaults to "bilinear".
**kwargs: Additional keyword arguments to pass to the regridding function.
Parameters
----------
domain_id : str
The domain ID for the CORDEX grid.
Returns:
xesmf.Regridder: The regridder object.
Returns
-------
xarray.Dataset
The CORDEX grid with assigned coordinates for longitude and latitude bounds.
"""
import xesmf as xe
grid = cx.domain(domain_id, bounds=True, mip_era="CMIP6")
lon_b = cfxr.bounds_to_vertices(
grid.vertices_lon, bounds_dim="vertices", order="counterclockwise"
)
lat_b = cfxr.bounds_to_vertices(
grid.vertices_lat, bounds_dim="vertices", order="counterclockwise"
)
return grid.assign_coords(lon_b=lon_b, lat_b=lat_b)


regridder = xe.Regridder(finer, coarser, method, **kwargs)
def create_regridder(source, target, method="bilinear"):
    """
    Build an xesmf regridder that maps the source grid onto the target grid.

    Parameters
    ----------
    source : xarray.Dataset
        Dataset describing the grid to regrid from.
    target : xarray.Dataset
        Dataset describing the grid to regrid to.
    method : str, optional
        Regridding method handed to ``xe.Regridder``. Default is "bilinear".

    Returns
    -------
    xesmf.Regridder
        The regridder object.
    """
    return xe.Regridder(source, target, method=method)


def compare_seasons(
ds1, ds2, regrid="ds1", do_height_correction=False, orog1=None, orog2=None
):
"""
Function to compare seasonal means of two datasets.
Parameters
---------
ds1 : xarray.Dataset
First variable data for comparison. Temporal resolution has to be less than monthly.
ds1 is mainly model output data.
ds2 is subtracted from ds1.
ds2 : xarray.Dataset
Second variable data for comparison. Temporal resolution has to be less than monthly.
ds2 is mainly observational or reanalysis data.
ds2 is subtracted from ds1.
regrid : {"ds1", "ds2"}, optional
Denotes the dataset to be bilinearly regridded. Specify the dataset with the finer spatial resolution:
- "ds1": Regrid ds1 to ds2's grid with coarser spatial resolution.
- "ds2": Regrid ds2 to ds1's grid with coarser spatial resolution.
do_height_correction : bool, optional
If ``do_height_correction=True``, do a height correction on ds1 using two orography files orog1 and orog2.
orog1 : xarray.Dataset, optional
Use only if ``do_height_correction=True``.
Specify a orography file referring to ds1.
orog2 : xarray.Dataset, optional
Use only if ``do_height_correction=True``.
Specify a orography file referring to ds2.
def regrid(ds, regridder):
    """
    Regrid a dataset and re-apply its mask on the target grid.

    After regridding, every data variable except "mask" and "sftlf" is kept
    only where the regridded "mask" is greater than zero. If the regridded
    dataset has no "mask" variable, a warning is issued and the regridded
    data is returned unmasked instead of failing with a KeyError.

    Parameters
    ----------
    ds : xarray.Dataset
        The dataset to be regridded.
    regridder : xesmf.Regridder
        The regridder object; it is called with ``ds``.

    Returns
    -------
    xarray.Dataset
        The regridded (and, where possible, masked) dataset.
    """
    ds_regrid = regridder(ds)
    if "mask" not in ds_regrid:
        warn("no 'mask' variable found after regridding, data is left unmasked")
        return ds_regrid

    # "mask" itself is never modified below, so the condition is computed once.
    valid = ds_regrid["mask"] > 0.0
    for var in ds.data_vars:
        if var not in ("mask", "sftlf"):
            ds_regrid[var] = ds_regrid[var].where(valid)
    return ds_regrid


def regrid_dsets(dsets, target_grid, method="bilinear"):
"""
Regrids multiple datasets to the target grid.
Parameters
----------
dsets : dict
A dictionary of datasets to be regridded, with keys as dataset IDs and values as xarray.Datasets.
target_grid : xarray.Dataset
The target grid dataset.
method : str, optional
The regridding method to use. Default is "bilinear".
Returns
-------
seasonal_comparision : xarray.Dataset
Spatial mean differences of two datasets
"""
ds1 = ds1.copy()
ds2 = ds2.copy()
ds1_seasmean = seasonal_mean(ds1)
ds2_seasmean = seasonal_mean(ds2)
if regrid == "ds1":
regridder = get_regridder(ds1, ds2)
print(regridder)
ds1_seasmean = regridder(ds1_seasmean)
elif regrid == "ds2":
regridder = get_regridder(ds2, ds1)
print(regridder)
ds2_seasmean = regridder(ds2_seasmean)

if do_height_correction is True:
orog1 = regridder(orog1)
ds1_seasmean += height_correction(orog1, orog2)
return ds1_seasmean - ds2_seasmean
# return xr.where(ds1_seasmean.mask, ds2_seasmean - ds1_seasmean, np.nan)
dict
A dictionary of regridded datasets.
"""
for dset_id, ds in dsets.items():
print(dset_id)
mapping = ds.cf["grid_mapping"].grid_mapping_name
if mapping == "rotated_latitude_longitude":
dsets[dset_id] = ds.cx.rewrite_coords(coords="all")
else:
print(f"regridding {dset_id} with grid_mapping: {mapping}")
regridder = create_regridder(ds, target_grid, method=method)
print(regridder)
dsets[dset_id] = regrid(ds, regridder)
return dsets
14 changes: 13 additions & 1 deletion evaltools/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pandas as pd
from warnings import warn

from .utils import iid_to_dict, dict_to_iid
from .utils import iid_to_dict, dict_to_iid, mask_with_sftlf, add_bounds

xarray_open_kwargs = {"use_cftime": True, "decode_coords": "all", "chunks": None}
time_range_default = slice("1979", "2020")
Expand Down Expand Up @@ -119,3 +119,15 @@ def open_and_sort(catalog, merge=None, concat=False, time_range="auto"):
join="override",
)
return dsets


def open_datasets(variables, frequency="mon", mask=True, add_missing_bounds=True):
    """Open the datasets for the given variables, optionally masked and with bounds.

    The source collection is requested together with the fixed fields
    "areacella" and "sftlf" and opened with ``open_and_sort(..., merge=True)``.

    Parameters
    ----------
    variables
        Passed on to ``get_source_collection``.
    frequency : str, optional
        Passed on to ``get_source_collection``. Default is "mon".
    mask : bool, optional
        If exactly ``True``, ``mask_with_sftlf`` is called on every dataset.
    add_missing_bounds : bool, optional
        If exactly ``True``, every dataset is replaced by ``add_bounds(ds)``.

    Returns
    -------
    dict
        The mapping returned by ``open_and_sort``, updated in place.
    """
    fx_variables = ["areacella", "sftlf"]
    dsets = open_and_sort(
        get_source_collection(variables, frequency, add_fx=fx_variables),
        merge=True,
    )
    if mask is True:
        # The return value is discarded; this relies on mask_with_sftlf
        # modifying each dataset in place.
        for dataset in dsets.values():
            mask_with_sftlf(dataset)
    if add_missing_bounds is True:
        for key in dsets:
            dsets[key] = add_bounds(dsets[key])
    return dsets
37 changes: 35 additions & 2 deletions evaltools/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
from collections import defaultdict


def iid_to_dict(dset_id, attrs):
# Attribute names that iid_to_dict falls back to when no ``attrs`` are given.
# They are paired positionally with the "."-separated parts of a dataset ID,
# so the order of the entries matters.
default_attrs = [
    "project_id",
    "domain_id",
    "institution_id",
    "driving_source_id",
    "driving_experiment_id",
    "driving_variant_label",
    "source_id",
    "version_realization",
    #'frequency',
    #'variable_id',
    "version",
]


def iid_to_dict(iid, attrs=None):
"""
Convert a dataset ID and its attributes to a dictionary.
Expand All @@ -12,7 +27,9 @@ def iid_to_dict(dset_id, attrs):
Returns:
dict: The dataset ID and attributes as a dictionary.
"""
values = dset_id.split(".")
if attrs is None:
attrs = default_attrs
values = iid.split(".")
return dict(zip(attrs, values))


Expand All @@ -31,6 +48,22 @@ def dict_to_iid(attrs, drop=None):
return ".".join(v for k, v in attrs.items() if k not in drop)


def short_iid(iid, attrs=None):
    """
    Reduce a dataset ID to a short ID made of selected attributes.

    The ID is split with ``iid_to_dict`` (using its default attribute names),
    filtered down to the requested attributes and joined again with
    ``dict_to_iid``. The parts keep the order in which ``iid_to_dict``
    yields them, not the order of ``attrs``.

    Parameters:
        iid (str): The dataset ID.
        attrs (list of str, optional): Names of the attributes to keep.
            Defaults to institution_id, source_id, driving_source_id and
            experiment_id.

    Returns:
        str: The short ID.
    """
    if attrs is None:
        # NOTE(review): "experiment_id" is not among the default attribute
        # names of iid_to_dict ("driving_experiment_id" is), so it presumably
        # never matches -- confirm which name is intended.
        attrs = ["institution_id", "source_id", "driving_source_id", "experiment_id"]
    selected = {}
    for key, value in iid_to_dict(iid).items():
        if key in attrs:
            selected[key] = value
    return dict_to_iid(selected)


def sort_by_grid_mapping(dsets):
"""
Sort the datasets by their grid mapping.
Expand Down

0 comments on commit e3d0c71

Please sign in to comment.