"""Source code for sea_ice_regridder: regridding and geometric utilities (SeaIceRegridder) for Antarctic sea-ice analysis."""

from __future__ import annotations
import os
import xarray as xr
import numpy  as np
import pandas as pd
from pathlib       import Path
from scipy.spatial import cKDTree
from scipy         import sparse
from scipy.sparse  import save_npz, load_npz

# Public API of this module: only the regridder class is exported by `import *`.
__all__ = ["SeaIceRegridder"]

class SeaIceRegridder:
    """
    Regridding and geometric utilities for Antarctic sea-ice analysis workflows.

    This class provides three broad groups of functionality:

    1) CICE grid-to-grid remapping (B-grid -> T-grid)
       - xESMF-based remapping using persistent weight files (a consistent,
         reproducible mapping operator).
       - a lightweight 2x2 corner-averaging operator as a fast alternative when
         an exact xESMF regrid is unnecessary.
    2) Swath-to-grid remapping and EPSG:3031 helpers
       - projection of lat/lon swaths to Antarctic polar stereographic (EPSG:3031),
       - extent unioning / snap-to-grid utilities for building a common grid,
       - nearest-neighbour resampling of swath data onto an AreaDefinition grid,
       - adding lon/lat coordinates to EPSG:3031 grids,
       - geographic subsetting of curvilinear model grids with lon/lat masks
         that handle dateline/seam crossing.
    3) Station -> curvilinear-grid remapping (CAWCR wave stations)
       - building/caching a sparse inverse-distance-weighted (IDW) operator,
       - applying that operator to (time, station, frequency) arrays in chunks.

    Expected external configuration
    -------------------------------
    This class is designed to sit inside the broader AFIM stack. `__init__`
    stores nothing, so the following must be provided on `self` by the
    surrounding workflow (e.g. a base class or toolbox manager):

    Attributes
    - logger : logging.Logger
        Progress/debug logging.
    - CICE_dict : dict
        Keys read by this class: "x_dim", "y_dim", "x_dim_length",
        "y_dim_length", "bcoord_names" (e.g. ["ULON", "ULAT"]),
        "P_reG_u2t_weights", "P_reG_reg_weights", "lat_coord_name",
        "lon_coord_name".
    - AF_FI_dict : dict
        Must contain "P_reG_reg_weights" when `AF2020=True` is used.
    - G_u, G_t, kmt_org
        Source/target grid definitions and target mask (`kmt_org["kmt_org"]`)
        used for xESMF weight generation; populated by `load_cice_grid()`.

    Methods
    - load_cice_grid()
        Must populate `self.G_u`, `self.G_t` and `self.kmt_org`.
    - normalise_longitudes(lon, to)
        Called here both as `normalise_longitudes(lon, to="-180-180")` and
        with the convention passed positionally.
    - define_fast_ice_coordinates(...), define_regular_G(...)
        Used by `define_reG_regular_weights`.
    - _xy_to_lonlat(x, y)
        EPSG:3031 x/y -> lon/lat converter used by `add_lonlat_from_epsg3031`.

    Notes
    -----
    - xESMF weight generation is usually expensive; persistent weight files
      are reused where possible.
    - Longitudes are explicitly wrapped before building swath definitions and
      before geographic masking to avoid seam artefacts.
    """

    def __init__(self, **kwargs):
        """
        Initialise the regridding helper.

        Parameters
        ----------
        **kwargs
            Accepted for compatibility with the surrounding workflow but
            currently ignored: nothing is stored on `self`. Configuration
            (logger, CICE_dict, ...) must be provided by a base class or
            assigned by the caller.
        """
        return

    def _ensure_reG_defined(self):
        """
        Ensure that an xESMF regridder is available as `self.reG`.

        If `self.reG` is missing or None, `define_reG_weights()` is called to
        create (or reuse) xESMF weights and instantiate the regridder.
        Intended as a lightweight guard for code paths that need `self.reG`.

        Returns
        -------
        None
        """
        if getattr(self, "reG", None) is not None:
            return
        self.define_reG_weights()

    def define_reG_weights(self):
        """
        Define and store an xESMF regridder from CICE B-grid (U-point) to T-grid.

        Loads the CICE grids via `load_cice_grid()`, attaches the land/sea mask
        to the target grid, and instantiates an xESMF `Regridder`. An existing
        weight file is reused; otherwise new weights are generated and written.

        Regridding configuration
        ------------------------
        - Source grid       : `self.G_u`
        - Target grid       : `self.G_t`
        - Target mask       : `self.kmt_org["kmt_org"]`, stored as `G_t["mask"]`
                              (note: this mutates `self.G_t`)
        - Method            : "bilinear"
        - periodic          : True
        - ignore_degenerate : True
        - extrap_method     : "nearest_s2d"
        - Weight file       : `self.CICE_dict["P_reG_u2t_weights"]`
        - reuse_weights     : True if that file already exists, else False

        Side effects
        ------------
        Sets `self.reG` and `self.reG_weights_defined = True`.

        Raises
        ------
        KeyError
            If "P_reG_u2t_weights" is missing from `self.CICE_dict`.
        AttributeError
            If `load_cice_grid()` does not populate `G_u`, `G_t` or `kmt_org`.
        ImportError
            If `xesmf` is not installed.
        """
        import xesmf as xe

        self.load_cice_grid()
        G_u = self.G_u
        G_t = self.G_t
        G_t['mask'] = self.kmt_org['kmt_org']
        F_weights = self.CICE_dict["P_reG_u2t_weights"]
        weights_exist = os.path.exists(F_weights)
        self.logger.info(f"{'Reusing' if weights_exist else 'Creating'} regrid weights: {F_weights}")
        self.reG = xe.Regridder(G_u, G_t,
                                method            = "bilinear",
                                periodic          = True,
                                ignore_degenerate = True,
                                extrap_method     = "nearest_s2d",
                                reuse_weights     = weights_exist,
                                filename          = F_weights)
        self.reG_weights_defined = True

    def reG_bgrid_to_tgrid_xesmf(self, da, coord_names=None):
        """
        Regrid a B-grid DataArray to the T-grid using the pre-defined xESMF regridder.

        Parameters
        ----------
        da : xr.DataArray
            B-grid variable carrying longitude/latitude coordinate variables.
        coord_names : list[str], optional
            Names of the lon/lat coordinates on `da`. Defaults to
            `self.CICE_dict["bcoord_names"]`. A name containing "LAT"
            (case-insensitive) is treated as latitude, otherwise one containing
            "LON" as longitude.

        Returns
        -------
        xr.DataArray or None
            Regridded DataArray on the T-grid, or None if the coordinates are
            missing/unidentifiable or the regrid call itself fails (the error
            is logged so workflows can skip problematic variables).

        Raises
        ------
        RuntimeError
            If `self.reG` is not defined. Call `_ensure_reG_defined()` or
            `define_reG_weights()` first.
        """
        # Fail loudly on a missing operator: previously the AttributeError was
        # swallowed by the broad except below and reported as a regrid failure.
        if getattr(self, "reG", None) is None:
            raise RuntimeError("reG_bgrid_to_tgrid_xesmf: self.reG is not defined; "
                               "call _ensure_reG_defined() or define_reG_weights() first.")
        coord_names = coord_names if coord_names is not None else self.CICE_dict["bcoord_names"]
        if not set(coord_names).issubset(set(da.coords)):
            self.logger.error(f"Cannot regrid: as {coord_names} not found in coordinates.")
            return None
        coord_map = {}
        for name in coord_names:
            if "LAT" in name.upper():
                coord_map[name] = "lat"
            elif "LON" in name.upper():
                coord_map[name] = "lon"
        if set(coord_map.values()) != {"lat", "lon"}:
            self.logger.error(f"Could not identify lat/lon from coord_names: {coord_names}")
            return None
        # xESMF expects the conventional "lon"/"lat" coordinate names.
        da_tmp = da.rename(coord_map)
        try:
            da_reG = self.reG(da_tmp)
        except Exception as e:
            # Deliberate best-effort: log and let the caller skip this variable.
            self.logger.error(f"Regridding failed: {e}")
            return None
        return da_reG

    def define_reG_regular_weights(self, da,
                                   G_res             = 0.15,
                                   region            = None,
                                   AF2020            = False,
                                   variable_name     = None,
                                   lon_coord_name    = None,
                                   lat_coord_name    = None,
                                   time_coord_name   = None,
                                   spatial_dim_names = None,
                                   reG_method        = "bilinear",
                                   periodic          = False,
                                   reuse_weights     = True,
                                   P_weights         = None):
        """
        Define an xESMF regridder from a curvilinear source grid to a regular lat/lon grid.

        Parameters
        ----------
        da : xr.Dataset or xr.DataArray
            Source object from which lat/lon coordinates are extracted via
            `self.define_fast_ice_coordinates`.
        G_res : float, default 0.15
            Regular grid resolution (degrees).
        region : list[float], optional
            Regular grid bounding box [lon_min, lon_max, lat_min, lat_max].
            Defaults to [0, 360, -90, 0] (a fresh list per call).
        AF2020 : bool, default False
            If True, the default weights path comes from `self.AF_FI_dict`,
            otherwise from `self.CICE_dict`.
        variable_name, lon_coord_name, lat_coord_name, time_coord_name, spatial_dim_names
            Forwarded unchanged to `define_fast_ice_coordinates`.
        reG_method : str, default "bilinear"
            xESMF method (e.g. "bilinear", "nearest_s2d", "conservative").
        periodic : bool, default False
            Set True for global periodic longitude grids.
        reuse_weights : bool, default True
            Passed to xESMF; reuse an existing weights file if present.
        P_weights : str or Path, optional
            Override the weights filename.

        Returns
        -------
        xesmf.Regridder
            Configured regridder from the source to the regular grid.

        Notes
        -----
        Only centre coordinates are built; conservative regridding would also
        require corner information.
        """
        import xesmf as xe

        # None sentinel instead of a shared mutable list default.
        region = [0, 360, -90, 0] if region is None else region
        coords = self.define_fast_ice_coordinates(da,
                                                  AF2020            = AF2020,
                                                  variable_name     = variable_name,
                                                  lon_coord_name    = lon_coord_name,
                                                  lat_coord_name    = lat_coord_name,
                                                  time_coord_name   = time_coord_name,
                                                  spatial_dim_names = spatial_dim_names)
        if AF2020:
            P_weights = P_weights if P_weights is not None else self.AF_FI_dict["P_reG_reg_weights"]
        else:
            P_weights = P_weights if P_weights is not None else self.CICE_dict["P_reG_reg_weights"]
        G_src = xr.Dataset({self.CICE_dict['lat_coord_name']: (coords['names']['spatial_dims'], coords['latitudes']),
                            self.CICE_dict['lon_coord_name']: (coords['names']['spatial_dims'], coords['longitudes'])})
        G_dst = self.define_regular_G(G_res, region=region, spatial_dim_names=coords['names']['spatial_dims'])
        return xe.Regridder(G_src, G_dst,
                            method            = reG_method,
                            periodic          = periodic,
                            ignore_degenerate = True,
                            reuse_weights     = reuse_weights,
                            filename          = P_weights)

    def simple_spatial_averaging_bgrid_to_tgrid(self, var):
        """
        Compute an unweighted 2x2 average from B-grid corner points to T-grid centres.

        Four corner-shifted views (v00, v01, v10, v11) are averaged, padded
        with NaN back to the configured grid size, and the last x column is set
        equal to the first (cyclic wrap).

        Parameters
        ----------
        var : xr.DataArray
            Input array on the B-grid. Spatial dimension names come from
            `self.CICE_dict["y_dim"]` and `self.CICE_dict["x_dim"]`.

        Returns
        -------
        xr.DataArray
            Averaged field with sizes (`y_dim_length`, `x_dim_length`) along the
            spatial dims; NaN where padding was required. Spatial-dim indexes
            are dropped.

        Raises
        ------
        KeyError
            If required dimension names/lengths are missing from `self.CICE_dict`.
        AssertionError
            If the final spatial sizes do not match the configured target sizes
            (raised explicitly, so it also fires under `python -O`).

        Notes
        -----
        This is not a conservative remap; it is a local arithmetic average.
        """
        x_dim = self.CICE_dict["x_dim"]
        y_dim = self.CICE_dict["y_dim"]
        x_len = self.CICE_dict["x_dim_length"]
        y_len = self.CICE_dict["y_dim_length"]
        self.logger.info(f"input shape to spatial averaging: {var.shape}")
        # Drop any index on the spatial dims *before* shifting. With an index
        # present, xarray aligns operands by label, so v00 + v01 would add each
        # point to itself rather than to its neighbour.
        spatial_indexes = [dim for dim in (y_dim, x_dim) if dim in var.indexes]
        var_pos = var.drop_indexes(spatial_indexes) if spatial_indexes else var
        self.logger.info(" → Slicing corner points for averaging...")
        v00 = var_pos.isel({y_dim: slice(None, -1), x_dim: slice(None, -1)})
        v01 = var_pos.isel({y_dim: slice(None, -1), x_dim: slice(1, None)})
        v10 = var_pos.isel({y_dim: slice(1, None), x_dim: slice(None, -1)})
        v11 = var_pos.isel({y_dim: slice(1, None), x_dim: slice(1, None)})
        self.logger.info(" → Computing mean of four corners...")
        avg = (v00 + v01 + v10 + v11) / 4.0
        self.logger.info(" → Padding with NaNs to restore original grid size...")
        pad_y = max(y_len - avg.sizes.get(y_dim, 0), 0)
        pad_x = max(x_len - avg.sizes.get(x_dim, 0), 0)
        avg = avg.pad({y_dim: (0, pad_y), x_dim: (0, pad_x)}, constant_values=np.nan)
        self.logger.info(" → Applying cyclic wrap for last column...")
        if avg.sizes.get(x_dim, 0) > 1:
            avg[{x_dim: -1}] = avg.isel({x_dim: 0})
        # Force re-slicing to the expected grid size to ensure consistency.
        avg = avg.isel({y_dim: slice(0, y_len), x_dim: slice(0, x_len)})
        if "time" in var.coords:
            avg = avg.assign_coords(time=var["time"])
            self.logger.info(" → Time coordinate restored.")
        if avg.sizes[y_dim] != y_len:
            raise AssertionError(f"{y_dim} mismatch: got {avg.sizes[y_dim]}, expected {y_len}")
        if avg.sizes[x_dim] != x_len:
            raise AssertionError(f"{x_dim} mismatch: got {avg.sizes[x_dim]}, expected {x_len}")
        for dim in [y_dim, x_dim]:
            if dim in avg.indexes:
                avg = avg.drop_indexes(dim)
        return avg

    def _finite_scatter_frame(self, da, lon, lat):
        """Flatten (lon, lat, z) triplets into a DataFrame, keeping only fully finite points."""
        z = np.asarray(da)
        lon = np.asarray(lon)
        lat = np.asarray(lat)
        keep = np.isfinite(z) & np.isfinite(lon) & np.isfinite(lat)
        return pd.DataFrame({"longitude": lon[keep],
                             "latitude" : lat[keep],
                             "z"        : z[keep]})

    def pygmt_regrid(self, da, lon, lat, grid_res=None, region=None, search_radius="200k"):
        """
        Regrid scattered 2D data onto a regular grid with PyGMT `nearneighbor`.

        Parameters
        ----------
        da : xarray.DataArray or np.ndarray
            2D array of data values to interpolate.
        lon, lat : xarray.DataArray or np.ndarray
            Longitudes/latitudes corresponding to `da`, same shape.
        grid_res : str or float, optional
            Output grid spacing (e.g. "0.5", "10m"); required by PyGMT.
        region : list or tuple, optional
            Output bounding box [west, east, south, north].
        search_radius : str or float, default "200k"
            Search radius for `nearneighbor`.

        Returns
        -------
        xarray.DataArray
            Gridded output returned by `pygmt.nearneighbor`.

        Notes
        -----
        Points where the value, longitude or latitude is non-finite are excluded
        before interpolation. Inputs are converted with `np.asarray`, so both
        DataArrays and plain NumPy arrays are accepted.
        """
        import pygmt

        df = self._finite_scatter_frame(da, lon, lat)
        return pygmt.nearneighbor(data          = df,
                                  spacing       = grid_res,
                                  region        = region,
                                  search_radius = search_radius)

    ##############################################################
    #                        PYRESAMPLE                          #
    ##############################################################

    def to_3031_extent(self, lat2d, lon2d, buffer_m=20_000):
        """
        Project lat/lon coordinates to EPSG:3031 and return a buffered bounding box.

        Parameters
        ----------
        lat2d, lon2d : array-like
            Latitude/longitude arrays (degrees); broadcast against each other.
        buffer_m : float, default 20000
            Buffer added to each side of the extent (metres).

        Returns
        -------
        list[float]
            [xmin, ymin, xmax, ymax] in EPSG:3031 metres, including the buffer.

        Raises
        ------
        ValueError
            If no projected point is finite.

        Notes
        -----
        Longitudes are wrapped with `normalise_longitudes(..., to="-180-180")`
        before projection. Non-finite projected points are ignored.
        """
        from pyproj import Transformer

        lon2d = self.normalise_longitudes(lon2d, to="-180-180")
        lon_b, lat_b = np.broadcast_arrays(np.asarray(lon2d, dtype=float),
                                           np.asarray(lat2d, dtype=float))
        transformer = Transformer.from_crs("EPSG:4326", "EPSG:3031", always_xy=True)
        x, y = transformer.transform(lon_b.ravel(), lat_b.ravel())
        x = np.asarray(x)
        y = np.asarray(y)
        finite = np.isfinite(x) & np.isfinite(y)
        if not finite.any():
            raise ValueError("to_3031_extent: no finite projected points; cannot compute an extent.")
        xs, ys = x[finite], y[finite]
        return [xs.min() - buffer_m, ys.min() - buffer_m, xs.max() + buffer_m, ys.max() + buffer_m]

    def union_extents(self, extents):
        """
        Compute the bounding extent that covers several EPSG:3031 extents.

        Parameters
        ----------
        extents : iterable of [xmin, ymin, xmax, ymax]
            Extents in metres.

        Returns
        -------
        list[float]
            Union extent [xmin, ymin, xmax, ymax].

        Raises
        ------
        ValueError
            If `extents` is empty.
        """
        extents = list(extents)
        if not extents:
            raise ValueError("union_extents: at least one extent is required.")
        xs = [e[0] for e in extents] + [e[2] for e in extents]
        ys = [e[1] for e in extents] + [e[3] for e in extents]
        return [min(xs), min(ys), max(xs), max(ys)]

    def snap_extent_to_grid(self, extent, pixel_size):
        """
        Expand an extent outward so its edges are multiples of `pixel_size`.

        xmin/ymin are floored and xmax/ymax are ceiled to the nearest multiple
        of `pixel_size`, so the result never shrinks.

        Parameters
        ----------
        extent : list[float]
            [xmin, ymin, xmax, ymax] in metres.
        pixel_size : float
            Grid resolution (metres).

        Returns
        -------
        list[float]
            Snapped extent [xmin, ymin, xmax, ymax].
        """
        xmin, ymin, xmax, ymax = extent
        xmin = np.floor(xmin / pixel_size) * pixel_size
        ymin = np.floor(ymin / pixel_size) * pixel_size
        xmax = np.ceil(xmax / pixel_size) * pixel_size
        ymax = np.ceil(ymax / pixel_size) * pixel_size
        return [xmin, ymin, xmax, ymax]

    def make_area_definition(self, extent, pixel_size=5_000, area_id="epsg3031_5km_union"):
        """
        Create a PyResample AreaDefinition for a regular EPSG:3031 grid.

        Parameters
        ----------
        extent : list[float]
            [xmin, ymin, xmax, ymax] in EPSG:3031 metres.
        pixel_size : float, default 5000
            Grid resolution (metres).
        area_id : str, default "epsg3031_5km_union"
            Identifier for the AreaDefinition.

        Returns
        -------
        pyresample.geometry.AreaDefinition

        Raises
        ------
        ValueError
            If the extent is smaller than one pixel in either direction.

        Notes
        -----
        Width/height are rounded to whole pixels and xmax/ymax are then moved
        so the extent spans exactly width x height pixels. The description
        string reflects the actual `pixel_size`.
        """
        from pyresample.geometry import AreaDefinition

        xmin, ymin, xmax, ymax = extent
        width = int(round((xmax - xmin) / pixel_size))
        height = int(round((ymax - ymin) / pixel_size))
        if width < 1 or height < 1:
            raise ValueError(f"make_area_definition: extent {extent} is smaller than one "
                             f"{pixel_size} m pixel (width={width}, height={height}).")
        xmax = xmin + width * pixel_size
        ymax = ymin + height * pixel_size
        return AreaDefinition(area_id     = area_id,
                              description = f"Common {pixel_size / 1000:g} km EPSG:3031 grid (union of inputs)",
                              proj_id     = "epsg3031",
                              projection  = "EPSG:3031",
                              width       = width,
                              height      = height,
                              area_extent = (xmin, ymin, xmax, ymax))

    def grid_coords_from_area(self, area_def, pixel_size=None):
        """
        Build 1D cell-centre x/y coordinates for a PyResample AreaDefinition.

        Parameters
        ----------
        area_def : pyresample.geometry.AreaDefinition
            Target grid; only `area_extent`, `width` and `height` are read.
        pixel_size : float, optional
            Grid resolution (metres) used for both axes. By default it is derived
            from `area_def` per axis, (xmax - xmin) / width and
            (ymax - ymin) / height, so the coordinates always match the area.

        Returns
        -------
        (x, y) : tuple[np.ndarray, np.ndarray]
            `x` increases eastward; `y` decreases from top to bottom.
        """
        xmin, ymin, xmax, ymax = area_def.area_extent
        width, height = area_def.width, area_def.height
        if pixel_size is None:
            dx = (xmax - xmin) / width
            dy = (ymax - ymin) / height
        else:
            dx = dy = pixel_size
        x = xmin + (np.arange(width) + 0.5) * dx
        y = ymax - (np.arange(height) + 0.5) * dy  # top -> bottom
        return x, y

    def resample_swath_to_area(self, src_da, lat2d, lon2d, area_def, radius=10_000, fill_value=np.nan, pixel_size=None):
        """
        Nearest-neighbour resample a 2D swath onto an EPSG:3031 AreaDefinition grid.

        Parameters
        ----------
        src_da : xr.DataArray
            2D swath data aligned in shape with `lat2d`/`lon2d`.
        lat2d, lon2d : array-like
            2D swath latitude/longitude (degrees).
        area_def : pyresample.geometry.AreaDefinition
            Target grid.
        radius : float, default 10000
            Radius of influence (metres).
        fill_value : scalar, default np.nan
            Value for target cells with no source point within `radius`.
        pixel_size : float, optional
            Target grid resolution used for the output coordinates and the
            "res" attribute. Defaults to the resolution of `area_def`.

        Returns
        -------
        xr.DataArray
            Field with dims ("y", "x") and projected coordinates in metres.
            Attributes are those of `src_da` with "crs", "grid_mapping" and
            "res" set for the target grid. The new values take precedence over
            any same-named source attributes.
        """
        from pyresample.geometry import SwathDefinition
        from pyresample.kd_tree import resample_nearest

        # Wrap longitudes before building the swath to avoid seam artefacts.
        lon2d = self.normalise_longitudes(lon2d, to="-180-180")
        swath = SwathDefinition(lons=lon2d, lats=lat2d)
        out2d = resample_nearest(source_geo_def      = swath,
                                 data                = src_da.values,
                                 target_geo_def      = area_def,
                                 radius_of_influence = radius,
                                 fill_value          = fill_value,
                                 nprocs              = 0,  # set >0 to parallelise
                                 reduce_data         = True)
        x, y = self.grid_coords_from_area(area_def, pixel_size=pixel_size)
        if pixel_size is None:
            xmin, _, xmax, _ = area_def.area_extent
            res = (xmax - xmin) / area_def.width
        else:
            res = pixel_size
        return xr.DataArray(out2d,
                            dims   = ("y", "x"),
                            coords = {"x": ("x", x, {"units": "m", "standard_name": "projection_x_coordinate"}),
                                      "y": ("y", y, {"units": "m", "standard_name": "projection_y_coordinate"})},
                            name   = src_da.name,
                            # Target-grid metadata last so the swath's own CRS
                            # attributes cannot mislabel the projected output.
                            attrs  = {**src_da.attrs,
                                      "crs"          : "EPSG:3031",
                                      "grid_mapping" : "spstereo",
                                      "res"          : float(res)})

    def add_lonlat_from_epsg3031(self, ds,
                                 x_name    = "x",
                                 y_name    = "y",
                                 wrap      = "0..360",  # or "-180..180"
                                 out_dtype = "float32"):
        """
        Add 2D lon/lat coordinates to a dataset defined on an EPSG:3031 grid.

        Parameters
        ----------
        ds : xr.Dataset
            Dataset with 1D projected coordinate dimensions `x_name`, `y_name`.
        x_name, y_name : str, default ("x", "y")
            Names of the projected coordinate dimensions.
        wrap : {"0..360", "-180..180"}, default "0..360"
            Longitude convention of the output. The dash spellings "0-360" and
            "-180-180" used elsewhere in this class are also accepted.
        out_dtype : str or None, default "float32"
            Output dtype for lon/lat; falsy keeps float64.

        Returns
        -------
        xr.Dataset
            `ds` with added coordinates lon(y, x) and lat(y, x).

        Raises
        ------
        ValueError
            If `x_name`/`y_name` are not dataset dimensions, or `wrap` is not
            a recognised convention.
        """
        if x_name not in ds.dims or y_name not in ds.dims:
            raise ValueError(f"Expected dims '{y_name}', '{x_name}' in dataset.")
        wrap_key = str(wrap).replace("..", "-")
        if wrap_key not in ("0-360", "-180-180"):
            raise ValueError(f"wrap must be '0..360' or '-180..180' (got {wrap!r}).")
        # Broadcast the 1-D x/y coordinates to 2-D; input_core_dims below puts the
        # core dims in (y, x) order regardless of the broadcast dim order.
        X2D, Y2D = xr.broadcast(ds[x_name], ds[y_name])
        lon, lat = xr.apply_ufunc(self._xy_to_lonlat, X2D, Y2D,
                                  input_core_dims  = [[y_name, x_name], [y_name, x_name]],
                                  output_core_dims = [[y_name, x_name], [y_name, x_name]],
                                  dask             = "parallelized",
                                  vectorize        = False,
                                  output_dtypes    = [np.float64, np.float64])
        if wrap_key == "0-360":
            lon = lon % 360.0
        else:
            lon = ((lon + 180.0) % 360.0) - 180.0
        if out_dtype:
            lon = lon.astype(out_dtype)
            lat = lat.astype(out_dtype)
        return ds.assign_coords(lon=lon, lat=lat)

    def subset_by_lonlat_box(self, da: "xr.DataArray", lon_range, lat_range,
                             lon_name="TLON", lat_name="TLAT",
                             jdim="nj", idim="ni",
                             wrap="-180-180", crop=True):
        """
        Subset a curvilinear-grid DataArray by a geographic lon/lat box.

        Parameters
        ----------
        da : xr.DataArray
            Data on a curvilinear grid with 2D coordinates `lon_name`/`lat_name`.
        lon_range : tuple[float, float]
            (lon_min, lon_max) in degrees. lon_min > lon_max (after wrapping)
            is interpreted as a box crossing the seam, e.g. (350, 20) in the
            0-360 convention. A span of 360 degrees or more, e.g. (-180, 180)
            or (0, 360), selects all longitudes.
        lat_range : tuple[float, float]
            (lat_min, lat_max) in degrees.
        lon_name, lat_name : str, default ("TLON", "TLAT")
            Names of the 2D coordinate variables in `da.coords`.
        jdim, idim : str, default ("nj", "ni")
            Names of the two spatial dimensions.
        wrap : {"0-360", "-180-180"}, default "-180-180"
            Longitude convention applied (via `normalise_longitudes`) before masking.
        crop : bool, default True
            If True, crop to the tightest index box containing any selected cell.

        Returns
        -------
        xr.DataArray
            `da.where(mask)`, i.e. NaN outside the box, optionally index-cropped.

        Notes
        -----
        Cropping reads `.values` of 1D any-reductions of the mask, which
        triggers computation for Dask-backed coordinates.
        """
        TLON = self.normalise_longitudes(da.coords[lon_name], wrap)
        TLAT = da.coords[lat_name]
        lon_min, lon_max = lon_range
        lat_min, lat_max = lat_range
        if (lon_max - lon_min) >= 360:
            # Full circle: wrapping both ends would collapse them onto one meridian.
            mask_lon = np.isfinite(TLON)
        else:
            if wrap == "0-360":
                lon_min, lon_max = lon_min % 360, lon_max % 360
            else:
                lon_min = ((lon_min + 180) % 360) - 180
                lon_max = ((lon_max + 180) % 360) - 180
            if lon_min <= lon_max:
                mask_lon = (TLON >= lon_min) & (TLON <= lon_max)
            else:  # crosses dateline / wrap seam
                mask_lon = (TLON >= lon_min) | (TLON <= lon_max)
        mask_lat = (TLAT >= lat_min) & (TLAT <= lat_max)
        mask = mask_lon & mask_lat
        out = da.where(mask)  # broadcasts over any extra (e.g. time) dims
        if crop:
            j_idx = np.where(mask.any(dim=idim).values)[0]
            i_idx = np.where(mask.any(dim=jdim).values)[0]
            if j_idx.size and i_idx.size:
                out = out.isel({jdim: slice(j_idx[0], j_idx[-1] + 1),
                                idim: slice(i_idx[0], i_idx[-1] + 1)})
        return out

    #------------------------------------------------------------------------------------
    #                                     CAWCR
    #------------------------------------------------------------------------------------

    def _lonlat_to_unit_sphere_xyz(self, lon, lat):
        """
        Convert lon/lat (degrees) to unit-sphere Cartesian points.

        `lon` and `lat` are broadcast against each other and flattened.
        Returns an array of shape (n_points, 3).
        """
        lon, lat = np.broadcast_arrays(np.asarray(lon, dtype=float), np.asarray(lat, dtype=float))
        lonr = np.deg2rad(((lon + 180.0) % 360.0) - 180.0)
        latr = np.deg2rad(lat)
        x = np.cos(latr) * np.cos(lonr)
        y = np.cos(latr) * np.sin(lonr)
        z = np.sin(latr)
        return np.column_stack([x.ravel(), y.ravel(), z.ravel()])

    def _angular_radius_to_chord(self, radius_km, earth_radius_km=6371.0):
        """
        Convert a great-circle radius (km) to a chord length on the unit sphere.

        Returns None when `radius_km` is None. Radii beyond half the
        circumference are clamped to the antipodal chord (2.0), since
        2*sin(ang/2) would otherwise shrink again past ang = pi.
        """
        if radius_km is None:
            return None
        ang = min(radius_km / earth_radius_km, np.pi)
        return 2.0 * np.sin(0.5 * ang)

    def build_station_to_curvilinear_sparse_weights(self, src_lon, src_lat, tgt_lon, tgt_lat, p_weights,
                                                    target_mask = None,
                                                    k           = 1,
                                                    power       = 1.0,
                                                    radius_km   = None,
                                                    overwrite   = False,
                                                    chunk_size  = 100_000):
        """
        Build or load a sparse IDW remapping operator from stations to a curvilinear grid.

        Parameters
        ----------
        src_lon, src_lat : 1D arrays
            Station coordinates (degrees).
        tgt_lon, tgt_lat : 2D arrays
            Target grid coordinates (degrees); flattened in C order.
        p_weights : str or Path
            Weight file. It is written at exactly this path, even without a
            ".npz" suffix, and reused on later calls unless `overwrite`.
        target_mask : array-like, optional
            Same number of elements as the target grid. Truthy cells get
            weights; other rows stay empty.
        k : int, default 1
            Number of nearest stations used (capped at the number of stations).
        power : float, default 1.0
            IDW exponent.
        radius_km : float, optional
            Cutoff radius. If no station lies within it, the nearest station
            is still used.
        overwrite : bool, default False
            Rebuild even if the weight file exists.
        chunk_size : int, default 100000
            Target points per KD-tree query chunk.

        Returns
        -------
        scipy.sparse.csr_matrix
            Float32 matrix of shape (n_target, n_station); rows sum to 1 for
            active cells.

        Raises
        ------
        ValueError
            If there are no stations, `k < 1`, or `target_mask` has the wrong size.
        """
        p_weights = Path(p_weights)
        if p_weights.exists() and not overwrite:
            self.logger.info(f"Reusing wave sparse weights: {p_weights}")
            return load_npz(p_weights)
        self.logger.info(f"Creating wave sparse weights: {p_weights}")
        src_xyz = self._lonlat_to_unit_sphere_xyz(src_lon, src_lat)
        tgt_xyz = self._lonlat_to_unit_sphere_xyz(tgt_lon, tgt_lat)
        n_station = src_xyz.shape[0]
        n_target = tgt_xyz.shape[0]
        if n_station == 0:
            raise ValueError("Cannot build station weights: no source stations supplied.")
        k = int(min(k, n_station))
        if k < 1:
            raise ValueError(f"k must be >= 1 (got {k}).")
        radius_chord = self._angular_radius_to_chord(radius_km)
        tree = cKDTree(src_xyz)
        if target_mask is None:
            active = np.ones(n_target, dtype=bool)
        else:
            active = np.asarray(target_mask).ravel().astype(bool)
            if active.size != n_target:
                raise ValueError(f"target_mask has {active.size} elements; expected {n_target}.")
        row_parts, col_parts, data_parts = [], [], []
        active_idx = np.flatnonzero(active)
        chunk_size = int(chunk_size)
        n_chunks = (active_idx.size + chunk_size - 1) // chunk_size
        for ichunk, i0 in enumerate(range(0, active_idx.size, chunk_size), start=1):
            ii = active_idx[i0:i0 + chunk_size]
            self.logger.info(f"Wave weight chunk {ichunk}/{n_chunks}: {ii.size} target cells")
            d, ind = tree.query(tgt_xyz[ii], k=k)
            if k == 1:  # cKDTree returns 1-D results for k == 1
                d = d[:, None]
                ind = ind[:, None]
            if radius_chord is not None:
                within = d <= radius_chord
            else:
                within = np.ones_like(d, dtype=bool)
            # Always keep at least the nearest neighbour.
            within[~within.any(axis=1), 0] = True
            d_safe = np.maximum(d, 1.0e-12)
            w = np.where(within, d_safe ** (-power), 0.0)
            # Exact hit -> full weight to the coincident (nearest) station.
            exact = d[:, 0] < 1.0e-12
            if np.any(exact):
                w[exact, :] = 0.0
                w[exact, 0] = 1.0
            wsum = w.sum(axis=1, keepdims=True)
            w = np.where(wsum > 0.0, w / wsum, 0.0)
            rr = np.repeat(ii, k)
            cc = ind.reshape(-1)
            vv = w.reshape(-1)
            keep = vv > 0.0
            row_parts.append(rr[keep])
            col_parts.append(cc[keep])
            data_parts.append(vv[keep])
        rows = np.concatenate(row_parts) if row_parts else np.array([], dtype=np.int64)
        cols = np.concatenate(col_parts) if col_parts else np.array([], dtype=np.int64)
        vals = np.concatenate(data_parts) if data_parts else np.array([], dtype=float)
        vals = vals.astype(np.float32, copy=False)
        W = sparse.csr_matrix((vals, (rows, cols)), shape=(n_target, n_station), dtype=np.float32)
        p_weights.parent.mkdir(parents=True, exist_ok=True)
        # save_npz appends ".npz" to path *names* lacking it, so the exists()
        # check above would never find the file. Writing through an open handle
        # stores it at exactly p_weights.
        with open(p_weights, "wb") as fh:
            save_npz(fh, W)
        self.logger.info(f"weight file saved to {p_weights}")
        return W

    def build_or_load_station_to_curvilinear_sparse_weights(self, src_lon, src_lat, tgt_lon, tgt_lat, p_weights,
                                                            target_mask = None,
                                                            k           = 8,
                                                            power       = 2.0,
                                                            radius_km   = None,
                                                            overwrite   = False,
                                                            chunk_size  = 100_000):
        """
        Thin wrapper for a more readable calling path from SeaIceWaves.

        Forwards every argument to `build_station_to_curvilinear_sparse_weights`,
        but with defaults k=8 and power=2.0 (instead of 1 and 1.0).
        """
        return self.build_station_to_curvilinear_sparse_weights(src_lon     = src_lon,
                                                                src_lat     = src_lat,
                                                                tgt_lon     = tgt_lon,
                                                                tgt_lat     = tgt_lat,
                                                                p_weights   = p_weights,
                                                                target_mask = target_mask,
                                                                k           = k,
                                                                power       = power,
                                                                radius_km   = radius_km,
                                                                overwrite   = overwrite,
                                                                chunk_size  = chunk_size)

    def apply_sparse_station_regridder(self, values, weights, target_shape, fill_value=np.nan, time_chunk=24):
        """
        Apply a sparse station->grid operator to (time, station, frequency) values.

        Parameters
        ----------
        values : np.ndarray, array-like, or xr.DataArray
            Input with shape (time, station, frequency). A DataArray is first
            transposed to ("time", "station", "frequency").
        weights : sparse matrix
            Operator of shape (nj * ni, n_station).
        target_shape : tuple
            (nj, ni) of the target grid; only the first two entries are used.
        fill_value : float, default np.nan
            Replacement for non-finite outputs.
        time_chunk : int, default 24
            Number of time steps converted to NumPy per iteration, so lazy
            inputs are only loaded one block at a time.

        Returns
        -------
        np.ndarray
            Float32 array of shape (time, nj, ni, frequency).

        Raises
        ------
        ValueError
            If `values` is not 3-D, `weights` has the wrong shape, or
            `time_chunk < 1`.
        """
        if isinstance(values, np.ndarray):
            arr = values
        elif isinstance(values, xr.DataArray):
            arr = values.transpose("time", "station", "frequency").data
        else:
            arr = values
        shape = getattr(arr, "shape", None)
        if shape is None or len(shape) != 3:
            raise ValueError(f"Expected values with ndim=3 (time,station,frequency), got {shape}")
        n_time, n_station, n_freq = shape
        nj, ni = int(target_shape[0]), int(target_shape[1])
        n_target = nj * ni
        if weights.shape != (n_target, n_station):
            raise ValueError(f"Weight matrix shape {weights.shape} incompatible with "
                             f"target/station sizes {(n_target, n_station)}")
        time_chunk = int(time_chunk)
        if time_chunk < 1:
            raise ValueError(f"time_chunk must be >= 1 (got {time_chunk}).")
        self.logger.info(f"Applying sparse station regridder in time chunks: "
                         f"n_time={n_time}, n_station={n_station}, n_freq={n_freq}, "
                         f"time_chunk={time_chunk}")
        if n_time == 0:
            # np.concatenate would fail on an empty list of chunks.
            return np.empty((0, nj, ni, n_freq), dtype=np.float32)
        out_chunks = []
        for t0 in range(0, n_time, time_chunk):
            t1 = min(n_time, t0 + time_chunk)
            self.logger.info(f"Regridding wave spectra time chunk {t0}:{t1}")
            # Only this block is materialised (matters for lazy/remote inputs).
            block = np.asarray(arr[t0:t1, :, :], dtype=np.float32)  # (tb, station, freq)
            tb = block.shape[0]
            # (station, tb*freq) so one sparse product covers the whole block.
            vals2 = np.moveaxis(block, 1, 0).reshape(n_station, tb * n_freq)
            out2 = weights @ vals2  # (n_target, tb*freq)
            out = out2.reshape(nj, ni, tb, n_freq)
            out = np.moveaxis(out, 2, 0)  # (tb, nj, ni, freq)
            out = np.where(np.isfinite(out), out, fill_value).astype(np.float32, copy=False)
            out_chunks.append(out)
        return np.concatenate(out_chunks, axis=0)