Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support auxiliary variables in weight files for chunked regridding #513

Merged
merged 12 commits into from
Apr 8, 2020
3 changes: 2 additions & 1 deletion src/ocgis/driver/nc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import six
from netCDF4._netCDF4 import VLType, MFDataset, MFTime

from ocgis import constants, vm
from ocgis import env
from ocgis.base import orphaned, raise_if_empty
Expand Down Expand Up @@ -262,7 +263,7 @@ def _open_(uri, mode='r', **kwargs):
if kwargs.get('parallel') and kwargs.get('comm') is None:
kwargs['comm'] = lvm.comm
ret = nc.Dataset(uri, mode=mode, **kwargs)
# tdk:FIX: this should be enabled for MFDataset as well. see https://github.com/Unidata/netcdf4-python/issues/809#issuecomment-435144221
# tdk:RELEASE:FIX: this should be enabled for MFDataset as well. see https://github.com/Unidata/netcdf4-python/issues/809#issuecomment-435144221
# netcdf4 >= 1.4.0 always returns masked arrays. This is inefficient and is turned off by default by ocgis.
if hasattr(ret, 'set_always_mask'):
ret.set_always_mask(False)
Expand Down
1 change: 1 addition & 0 deletions src/ocgis/messages.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
M1 = 'No dimensioned variables found. This typically means no target variables in the dataset have space and/or time dimensions associated with them. Consider using a dimension map if file metadata is non-standard. Overloading "variable" is also an option to avoid dimension checking.'
M3 = 'Output path exists "{0}" and must be removed before proceeding. Set "overwrite" argument or env.OVERWRITE to True to overwrite.'
M4 = """A level subset was requested but the target dataset does not have a level dimension. The dataset's alias is: {0}"""
M5 = "The ESMF FileMode constant value. BASIC (the default) only writes the factor index list and weight factor variables. WITHAUX adds auxiliary variables and additional file metadata to the output weight file."
48 changes: 45 additions & 3 deletions src/ocgis/ocli.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
#!/usr/bin/env python

import datetime
import logging
import os
import shutil
import tempfile

import click
import netCDF4 as nc
from shapely.geometry import box

import ocgis
from ocgis import RequestDataset, GeometryVariable, constants
from ocgis.base import grid_abstraction_scope, raise_if_empty
from ocgis.constants import DriverKey, Topology, GridChunkerConstants, DecompositionType
from ocgis.driver.nc_ugrid import DriverNetcdfUGRID
from ocgis.messages import M5
from ocgis.spatial.grid_chunker import GridChunker
from ocgis.spatial.spatial_subset import SpatialSubsetOperation
from ocgis.util.logging_ocgis import ocgis_lh
Expand Down Expand Up @@ -91,9 +93,14 @@ def ocli():
@click.option('--verbose/--not_verbose', default=False, help='If True, log to standard out using verbosity level.')
@click.option('--loglvl', default="INFO", help='Verbosity level for standard out logging. Default is '
'"INFO". See Python logging level docs for additional values: https://docs.python.org/3/howto/logging.html')
@click.option('--weightfilemode', default="BASIC", help=M5)
def chunked_rwg(source, destination, weight, nchunks_dst, merge, esmf_src_type, esmf_dst_type, genweights,
esmf_regrid_method, spatial_subset, src_resolution, dst_resolution, buffer_distance, wd, persist,
eager, ignore_degenerate, data_variables, spatial_subset_path, verbose, loglvl):
eager, ignore_degenerate, data_variables, spatial_subset_path, verbose, loglvl, weightfilemode):

# Used for creating the history string.
the_locals = locals()

if verbose:
ocgis_lh.configure(to_stream=True, level=getattr(logging, loglvl))
ocgis_lh(msg="Starting Chunked Regrid Weight Generation", level=logging.INFO, logger=CRWG_LOG)
Expand Down Expand Up @@ -185,7 +192,7 @@ def chunked_rwg(source, destination, weight, nchunks_dst, merge, esmf_src_type,
msg = "Writing ESMF weights..."
ocgis_lh(msg=msg, level=logging.INFO, logger=CRWG_LOG)
handle_weight_file_check(weight)
gs.write_esmf_weights(source, destination, weight)
gs.write_esmf_weights(source, destination, weight, filemode=weightfilemode)

# Create the global weight file. This does not apply to spatial subsets because there will always be one weight
# file.
Expand All @@ -206,6 +213,20 @@ def chunked_rwg(source, destination, weight, nchunks_dst, merge, esmf_src_type,

ocgis.vm.barrier()

# Append the history string if there is an output weight file.
if weight and ocgis.vm.rank == 0:
if os.path.exists(weight):
# Add some additional stuff for record keeping
import getpass
import socket
import datetime

with nc.Dataset(weight, 'a') as ds:
ds.setncattr('created_by_user', getpass.getuser())
ds.setncattr('created_on_hostname', socket.getfqdn())
ds.setncattr('history', create_history_string(the_locals))
ocgis.vm.barrier()

# Remove the working directory unless the persist flag is provided.
if not persist:
if ocgis.vm.rank == 0:
Expand All @@ -218,6 +239,27 @@ def chunked_rwg(source, destination, weight, nchunks_dst, merge, esmf_src_type,
return 0


def create_history_string(the_locals):
    """Create a netCDF ``history`` attribute string describing the CLI invocation.

    :param dict the_locals: Mapping of the command function's argument names to their
     values, as captured by ``locals()`` at the start of ``chunked_rwg``.
    :return: A history string containing a timestamp, the ocgis and ESMF versions, and
     the reconstructed ``ocli chunked-rwg`` command-line arguments.
    :rtype: str
    """
    history_parms = {}
    for k, v in the_locals.items():
        # Skip unset (None) arguments and this function's own accumulator should it
        # ever appear in the captured locals mapping.
        if v is not None and k != 'history_parms':
            # isinstance is the idiomatic type check; behavior-identical here since
            # bool cannot be subclassed.
            if isinstance(v, bool):
                # NOTE(review): only False-valued flags are recorded (as --no_<name>);
                # True-valued flags are intentionally omitted — confirm with caller.
                if not v:
                    history_parms['--no_' + k] = v
            else:
                history_parms['--' + k] = v
    try:
        # ESMF is an optional dependency; fall back to None for the version string.
        import ESMF
        ever = ESMF.__version__
    except ImportError:
        ever = None
    history = "{}: Created by ocgis (v{}) and ESMF (v{}) with CLI command: ocli chunked-rwg".format(
        datetime.datetime.now(), ocgis.__version__, ever)
    for k, v in history_parms.items():
        history += " {} {}".format(k, v)
    return history


@ocli.command(help='Apply weights in chunked files with an option to insert the global data.', name='chunked-smm')
@click.option('--wd', type=click.Path(exists=True, dir_okay=True), required=False,
help="Optional working directory containing destination chunk files. If empty, the current working "
Expand Down
81 changes: 68 additions & 13 deletions src/ocgis/spatial/grid_chunker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from shapely.geometry import box

from ocgis import constants
from ocgis.base import AbstractOcgisObject, grid_abstraction_scope
from ocgis.base import AbstractOcgisObject, grid_abstraction_scope, orphaned
from ocgis.collection.field import Field
from ocgis.constants import GridChunkerConstants, RegriddingRole, Topology
from ocgis.driver.request.core import RequestDataset
Expand All @@ -23,6 +23,7 @@

_LOCAL_LOGGER = "grid_chunker"


class GridChunker(AbstractOcgisObject):
"""
Splits source and destination grids into separate netCDF files. "Source" is intended to mean the source data for a
Expand Down Expand Up @@ -96,13 +97,15 @@ class GridChunker(AbstractOcgisObject):
:param bool eager: If ``True``, load grid data from disk before chunking. This avoids always loading the data from
disk for sourced datasets following a subset. There will be an improvement in performance but an increase in the
memory used.
:param str filemode: If ``'BASIC'`` (the default), only write source/destination indices and weight factors to the
output weight file. If ``'WITHAUX'`` also write geometry-related auxiliary variables to the output weight file.
:raises: ValueError
"""

def __init__(self, source, destination, nchunks_dst=None, paths=None, check_contains=False, allow_masked=True,
src_grid_resolution=None, dst_grid_resolution=None, optimized_bbox_subset='auto', iter_dst=None,
buffer_value=None, redistribute=False, genweights=False, esmf_kwargs=None, use_spatial_decomp='auto',
eager=True, debug=False):
eager=True, filemode="BASIC", debug=False):
self._src_grid = None
self._dst_grid = None
self._buffer_value = None
Expand All @@ -114,6 +117,7 @@ def __init__(self, source, destination, nchunks_dst=None, paths=None, check_cont
self.source = source
self.destination = destination
self.eager = eager
self.filemode = filemode
self.debug = debug

if esmf_kwargs is None:
Expand Down Expand Up @@ -305,39 +309,54 @@ def create_merged_weight_file(self, merged_weight_filename, strict=False):
raise ValueError("'create_merged_weight_file' does not work in parallel")

index_filename = self.create_full_path_from_template('index_file')
ifile = RequestDataset(uri=index_filename).get()
ifile = RequestDataset(uri=index_filename, driver='netcdf').get()
ifile.load()
ifc = GridChunkerConstants.IndexFile
gidx = ifile[ifc.NAME_INDEX_VARIABLE].attrs

src_global_shape = gidx[ifc.NAME_SRC_GRID_SHAPE]
dst_global_shape = gidx[ifc.NAME_DST_GRID_SHAPE]

vc = VariableCollection()
wf_varnames = ['row', 'col', 'S']

# Get the global weight dimension size.
n_s_size = 0
weight_filename = ifile[gidx[ifc.NAME_WEIGHTS_VARIABLE]]
wv = weight_filename.join_string_value()
split_weight_file_directory = self.paths['wd']
ctr = 0
for wfn in map(lambda x: os.path.join(split_weight_file_directory, os.path.split(x)[1]), wv):
ocgis_lh(msg="current merge weight file target: {}".format(wfn), level=logging.DEBUG, logger=_LOCAL_LOGGER)
if not os.path.exists(wfn):
if strict:
raise IOError(wfn)
else:
continue
curr_dimsize = RequestDataset(wfn).get().dimensions['n_s'].size
vc_target = RequestDataset(wfn, driver='netcdf').get()
curr_dimsize = vc_target.dimensions['n_s'].size
# ESMF writes the weight file, but it may be empty if there are no generated weights.
if curr_dimsize is not None:
n_s_size += curr_dimsize

# Copy over auxiliary variables if they are required.
if self.filemode == 'WITHAUX' and ctr == 0:
for var in vc_target.values():
if var.name not in wf_varnames:
var.load()
with orphaned(var, keep_dimensions=True):
vc.add_variable(var)
# Also copy over global attributes
vc.attrs = vc_target.attrs
ctr += 1

# Create output weight file.
wf_varnames = ['row', 'col', 'S']
wf_dtypes = [np.int32, np.int32, np.float64]
vc = VariableCollection()
dim = Dimension('n_s', n_s_size)
for w, wd in zip(wf_varnames, wf_dtypes):
var = Variable(name=w, dimensions=dim, dtype=wd)
vc.add_variable(var)

vc.write(merged_weight_filename)

# Transfer weights to the merged file.
Expand All @@ -352,7 +371,7 @@ def create_merged_weight_file(self, merged_weight_filename, strict=False):
raise IOError(wfn)
else:
continue
wdata = RequestDataset(wfn).get()
wdata = RequestDataset(wfn, driver='netcdf').get()
for wvn in wf_varnames:
odata = wdata[wvn].get_value()
try:
Expand Down Expand Up @@ -722,17 +741,22 @@ def write_chunks(self):
level=logging.DEBUG)
cc += 1

# Increment the counter outside of the loop to avoid counting empty subsets.
ctr += 1

# Generate an ESMF weights file if requested and at least one rank has data on it.
if self.genweights and len(vm.get_live_ranks_from_object(sub_src)) > 0:
vm.barrier()
if (ctr == 1) and (self.filemode == 'WITHAUX'):
filemode = 'WITHAUX'
else:
filemode = 'BASIC'
ocgis_lh(logger=_LOCAL_LOGGER, msg='write_chunks:writing esmf weights: {}'.format(wgt_path),
level=logging.DEBUG)
self.write_esmf_weights(src_path, dst_path, wgt_path, src_grid=sub_src, dst_grid=sub_dst)
self.write_esmf_weights(src_path, dst_path, wgt_path, src_grid=sub_src, dst_grid=sub_dst,
filemode=filemode)
vm.barrier()

# Increment the counter outside of the loop to avoid counting empty subsets.
ctr += 1

# Global shapes require a VM global scope to collect.
src_global_shape = global_grid_shape(self.src_grid)
dst_global_shape = global_grid_shape(self.dst_grid)
Expand Down Expand Up @@ -790,7 +814,7 @@ def write_chunks(self):

vm.barrier()

def write_esmf_weights(self, src_path, dst_path, wgt_path, src_grid=None, dst_grid=None):
def write_esmf_weights(self, src_path, dst_path, wgt_path, src_grid=None, dst_grid=None, filemode=None):
"""
Write ESMF regridding weights for a source and destination filename combination.

Expand All @@ -801,12 +825,21 @@ def write_esmf_weights(self, src_path, dst_path, wgt_path, src_grid=None, dst_gr
:type src_grid: :class:`ocgis.spatial.grid.AbstractGrid`
:param dst_grid: If provided, use this destination grid for identifying ESMF parameters
:type dst_grid: :class:`ocgis.spatial.grid.AbstractGrid`
:param str filemode: If ``'BASIC'`` (default when ``None``), only write source/destination indices and weight
factors to the output weight file. If ``'WITHAUX'`` also write geometry-related auxiliary variables to the
output weight file.
"""
assert wgt_path is not None
assert src_path is not None
assert dst_path is not None

from ocgis.regrid.base import create_esmf_field, create_esmf_regrid
import ESMF

if filemode is None:
filemode = "BASIC"
filemode = getattr(ESMF.FileMode, filemode)

if src_grid is None:
src_grid = self.src_grid
if dst_grid is None:
Expand All @@ -818,9 +851,25 @@ def write_esmf_weights(self, src_path, dst_path, wgt_path, src_grid=None, dst_gr
dstfield, dstgrid = create_esmf_field(dst_path, dst_grid, self.esmf_kwargs)
regrid = None

# If auxiliary weight file variables are being written, update the ESMF arguments with some additional metadata.
if filemode == ESMF.FileMode.WITHAUX:
try:
self.esmf_kwargs['src_file'] = get_file_path(self.source)
self.esmf_kwargs['dst_file'] = get_file_path(self.destination)
except Exception as e:
vm.abort(exc=e)
self.esmf_kwargs['src_file_type'] = self.src_grid.driver.get_esmf_fileformat()
self.esmf_kwargs['dst_file_type'] = self.dst_grid.driver.get_esmf_fileformat()
try:
ocgis_lh(msg="creating esmf regrid...", logger=_LOCAL_LOGGER, level=logging.DEBUG)
regrid = create_esmf_regrid(srcfield=srcfield, dstfield=dstfield, filename=wgt_path, **self.esmf_kwargs)

# Older versions of ESMPy do not support 'filemode'. If it is set to BASIC (the legacy default) then remove
# the filemode argument.
if filemode == ESMF.FileMode.BASIC:
regrid = create_esmf_regrid(srcfield=srcfield, dstfield=dstfield, filename=wgt_path, **self.esmf_kwargs)
else:
regrid = create_esmf_regrid(srcfield=srcfield, dstfield=dstfield, filename=wgt_path, filemode=filemode,
**self.esmf_kwargs)
finally:
to_destroy = [regrid, srcgrid, srcfield, dstgrid, dstfield]
for t in to_destroy:
Expand Down Expand Up @@ -895,6 +944,12 @@ def get_grid_object(obj, load=True):
return res


def get_file_path(rd):
    """Return the file path (URI) backing a request dataset.

    :param rd: The request dataset whose URI should be extracted.
    :type rd: :class:`ocgis.driver.request.core.RequestDataset`
    :return: The URI associated with the request dataset.
    :raises: ValueError if ``rd`` is not a request dataset.
    """
    if isinstance(rd, RequestDataset):
        return rd.uri
    raise ValueError('must be a RequestDataset')


def global_grid_shape(grid):
with vm.scoped_by_emptyable('global grid shape', grid):
if not vm.is_null:
Expand Down
22 changes: 19 additions & 3 deletions src/ocgis/test/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,11 +423,14 @@ def assertWarns(self, warning, meth):
meth()
self.assertTrue(any(item.category == warning for item in warning_list))

def assertWeightFilesEquivalent(self, src_filename, dst_filename):
def assertWeightFilesEquivalent(self, src_filename, dst_filename, special_history=True):
"""Assert weight files are equivalent."""

nwf = RequestDataset(dst_filename).get()
gwf = RequestDataset(src_filename).get()
nwf = RequestDataset(dst_filename, driver="netcdf").get()
gwf = RequestDataset(src_filename, driver="netcdf").get()

self.assertEqual(nwf.keys(), gwf.keys())

nwf_row = nwf['row'].get_value()
gwf_row = gwf['row'].get_value()
self.assertAsSetEqual(nwf_row, gwf_row)
Expand Down Expand Up @@ -455,6 +458,19 @@ def assertWeightFilesEquivalent(self, src_filename, dst_filename):
diffs = np.abs(diffs)
self.assertLess(diffs.max(), 1e-14)

if special_history:
actual = nwf.attrs.copy()
desired = gwf.attrs.copy()
removes = ['history', 'created_by_user', 'created_on_hostname']
for r in removes:
actual.pop(r, None)
desired.pop(r, None)
self.assertEqual(actual, desired)
else:
actual = nwf.attrs
desired = gwf.attrs
self.assertEqual(actual, desired)

@staticmethod
def barrier_print(*args, **kwargs):
from ocgis.vmachine.mpi import barrier_print
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def fixture_esmf_unstruct_field(self):
def test_system_converting_state_boundaries_shapefile(self):
verbose = False
if verbose: ocgis.vm.barrier_print("starting test")
ocgis.env.USE_NETCDF4_MPI = False # tdk:FIX: this hangs in the STATE_FIPS write for asynch might be nc4 bug...
ocgis.env.USE_NETCDF4_MPI = False # tdk:RELEASE:FIX: this hangs in the STATE_FIPS write for asynch might be nc4 bug...
keywords = {'transform_to_crs': [None, Spherical],
'use_geometry_iterator': [False, True]}
actual_xsums = []
Expand Down Expand Up @@ -216,7 +216,8 @@ def test_system_spatial_subsetting(self):

@attr('mpi', 'esmf')
def test_system_grid_chunking(self):
if vm.size != 4: raise SkipTest('vm.size != 4')
if vm.size != 4:
raise SkipTest('vm.size != 4')

from ocgis.spatial.grid_chunker import GridChunker
path = self.path_esmf_unstruct
Expand Down Expand Up @@ -247,7 +248,6 @@ def test_system_grid_chunking(self):
d = os.path.join(chunk_wd, 'split_dst_{}.nc'.format(ctr))
sf = Field.read(s, driver=DriverESMFUnstruct)
df = Field.read(d, driver=DriverESMFUnstruct)
self.assertLessEqual(sf.grid.shape[0] - df.grid.shape[0], 150)
self.assertGreater(sf.grid.shape[0], df.grid.shape[0])

wgt = os.path.join(chunk_wd, 'esmf_weights_{}.nc'.format(ctr))
Expand Down
Loading