Source code for sea_ice_regridder

from __future__ import annotations
import os
import xarray as xr
import numpy  as np
import pandas as pd
__all__ = ["SeaIceRegridder"]
[docs] class SeaIceRegridder: """ Regridding and geometric utilities for Antarctic sea-ice analysis workflows. This class provides two broad groups of functionality: 1) CICE grid-to-grid remapping (B-grid → T-grid) - xESMF-based remapping using persistent weight files (recommended when you need a consistent, reproducible mapping operator). - a lightweight, Dask-safe 2×2 corner averaging operator as a fast alternative when an exact xESMF regrid is unnecessary. 2) Swath-to-grid remapping and EPSG:3031 helpers - projection of lat/lon swaths to Antarctic polar stereographic (EPSG:3031), - extent unioning/snap-to-grid utilities for building a common analysis grid, - nearest-neighbour resampling of swath data onto an AreaDefinition grid, - convenience functions for adding lon/lat coordinates to EPSG:3031 grids, - geographic subsetting of curvilinear model grids using lon/lat masks that handle dateline/seam crossing. Expected external configuration ------------------------------ This class is designed to sit inside the broader AFIM stack. The following attributes/methods are expected to exist on `self` (typically injected by kwargs or provided by a base class/toolbox manager): Attributes - logger : logging.Logger For progress/debug logging. - CICE_dict : dict Must contain keys used by the B→T routines, typically: * "x_dim", "y_dim" * "x_dim_length", "y_dim_length" * "bcoord_names" (e.g., ["ULON","ULAT"]) * "P_reG_u2t_weights" (path to xESMF weight file) - G_u, G_t : dict-like or xr.Dataset CICE source and target grid definitions used by xESMF. - kmt_org : xr.Dataset or dict-like Must include "kmt_org" used as the target grid mask for xESMF weight generation. Methods - load_cice_grid(...) Must populate `self.G_u` and `self.G_t` (and typically `self.kmt_org`). - define_reG_weights() Creates `self.reG` (xESMF regridder) and sets `self.reG_weights_defined`. - normalise_longitudes(lon, wrap=...) Used by EPSG:3031 and geographic subsetting utilities. Notes ----- - xESMF weight generation is usually expensive; this class is structured to reuse persistent weight files where possible. - Longitudes are explicitly wrapped before building swath definitions and before geographic masking to avoid seam artefacts. """ def __init__(self, **kwargs): """ Initialise the regridding helper. Parameters ---------- **kwargs Optional configuration injected into `self` by the surrounding workflow. Common keys include `logger`, `CICE_dict`, and file paths. Notes ----- - The shown implementation is a no-op; in production you typically assign kwargs onto `self` (as you do in your other classes) or inherit from a shared base class. """ return def _ensure_reG_defined(self): """ Ensure that an xESMF regridder is available on `self`. This method checks for an existing `self.reG` operator and, if missing, calls `define_reG_weights()` to create (or reuse) xESMF weights and instantiate the regridder. Returns ------- None Notes ----- - Intended as a lightweight guard at the top of functions that depend on `self.reG`. - Weight reuse behaviour is controlled inside `define_reG_weights()`. """ if getattr(self, "reG", None) is not None and getattr(self, "reG", None) is not None: return self.define_reG_weights()
[docs] def define_reG_weights(self): """ Define and store an xESMF regridder to remap CICE B-grid (U-point) data to the T-grid. This method constructs/loads the source (U-point) and target (T-point) CICE grids, attaches a land/sea mask to the target grid, and then instantiates an xESMF `Regridder` object. If a weight file already exists, it is reused; otherwise, new weights are generated and written to disk. Regridding configuration ------------------------ - Source grid : `self.G_u` (B-grid / U-point style grid definition) - Target grid : `self.G_t` (T-grid / cell-center grid definition) - Target mask : `self.kmt_org["kmt_org"]` stored as `G_t["mask"]` - Method : "bilinear" - periodic : True (assumes global periodicity in longitude) - ignore_degenerate : True (skip degenerate cells) - extrap_method : "nearest_s2d" (nearest-neighbour extrapolation source→dest) - Weight file : `self.CICE_dict["P_reG_u2t_weights"]` - reuse_weights: True if the file exists, else False (weights created) Side effects ------------ - Sets `self.reG` to the instantiated xESMF regridder. - Sets `self.reG_weights_defined = True` upon success. Returns ------- None Raises ------ KeyError If required keys are missing from `self.CICE_dict` (e.g., weight path). AttributeError If `load_cice_grid()` does not populate `self.G_u`, `self.G_t`, or `self.kmt_org`. ImportError If `xesmf` is not available in the runtime environment. Exception Propagates xESMF errors arising from grid definitions or weight I/O. Notes ----- - The grid definitions must include valid "lon"/"lat" (and ideally corner) information for the selected method. - For conservative methods you would typically require corner arrays; for bilinear, centers are sufficient, but corner metadata may still improve robustness. """ import xesmf as xe self.load_cice_grid() G_u = self.G_u G_t = self.G_t G_t['mask'] = self.kmt_org['kmt_org'] F_weights = self.CICE_dict["P_reG_u2t_weights"] weights_exist = os.path.exists(F_weights) self.logger.info(f"{'Reusing' if weights_exist else 'Creating'} regrid weights: {F_weights}") self.reG = xe.Regridder(G_u, G_t, method = "bilinear", periodic = True, ignore_degenerate = True, extrap_method = "nearest_s2d", reuse_weights = weights_exist, filename = F_weights) self.reG_weights_defined = True
[docs] def reG_bgrid_to_tgrid_xesmf(self, da, coord_names=None): """ Regrid a B-grid DataArray to the T-grid using the pre-defined xESMF regridder. The input DataArray must carry longitude/latitude coordinate variables. If `coord_names` is not provided, this method uses the configured coordinate names in `self.CICE_dict["bcoord_names"]` (commonly ["ULON","ULAT"]). Parameters ---------- da : xr.DataArray B-grid variable to regrid. Must include coordinate variables corresponding to longitude and latitude (either 1D or 2D, depending on your grid definition). coord_names : list[str], optional Names of the coordinates on `da` that represent lon/lat. If omitted, defaults to `self.CICE_dict["bcoord_names"]`. Returns ------- xr.DataArray or None Regridded DataArray on the T-grid. Returns None if required coordinates are missing or if regridding fails. Raises ------ RuntimeError If `self.reG` is not defined. Call `_ensure_reG_defined()` or `define_reG_weights()` first (or ensure your workflow does so). Notes ----- - The method renames the provided coordinate variables to the xESMF-conventional names "lon" and "lat" before applying `self.reG(...)`. - Error handling is intentionally conservative: failures are logged and `None` is returned to allow calling workflows to skip problematic variables gracefully. """ coord_names = coord_names if coord_names is not None else self.CICE_dict["bcoord_names"] if not set(coord_names).issubset(set(da.coords)): self.logger.error(f"Cannot regrid: as {coord_names} not found in coordinates.") return None coord_map = {} for name in coord_names: if "LAT" in name.upper(): coord_map[name] = "lat" elif "LON" in name.upper(): coord_map[name] = "lon" if set(coord_map.values()) != {"lat", "lon"}: self.logger.error(f"Could not identify lat/lon from coord_names: {coord_names}") return None da_tmp = da.rename(coord_map) try: da_reG = self.reG(da_tmp) except Exception as e: self.logger.error(f"Regridding failed: {e}") return None return da_reG
[docs] def define_reG_regular_weights(self, da, G_res = 0.15, region = [0,360,-90,0], AF2020 = False, variable_name = None, lon_coord_name = None, lat_coord_name = None, time_coord_name = None, spatial_dim_names = None, reG_method = "bilinear", periodic = False, reuse_weights = True, P_weights = None): """ Define an xESMF regridder from a curvilinear source grid to a regular lat/lon grid. Parameters ---------- da : xr.Dataset or xr.DataArray Source object from which lat/lon coordinates are extracted. G_res : float, default 0.15 Regular grid resolution (degrees). region : list[float], default [0,360,-90,0] Regular grid bounding box [lon_min, lon_max, lat_min, lat_max]. AF2020 : bool, default False If True, use AF2020 naming and default weights path; else CICE/model. reG_method : str, default "bilinear" xESMF method (e.g., "bilinear", "nearest_s2d", "conservative"). periodic : bool, default False Set True for global periodic longitude grids. reuse_weights : bool, default True If True, reuse an existing weights file if present. P_weights : str or Path, optional Override weights filename. Defaults to AF2020/CICE configured weights path. Returns ------- xesmf.Regridder Configured regridder from source to regular grid. Notes ----- - Source/target grids are constructed from 2D lat/lon arrays. - Conservative regridding requires grid corner information; this helper builds only center coordinates unless you extend it. """ import xesmf as xe coords = self.define_fast_ice_coordinates(da, AF2020 = AF2020, variable_name = variable_name, lon_coord_name = lon_coord_name, lat_coord_name = lat_coord_name, time_coord_name = time_coord_name, spatial_dim_names = spatial_dim_names) if AF2020: P_weights = P_weights if P_weights is not None else self.AF_FI_dict["P_reG_reg_weights"] else: P_weights = P_weights if P_weights is not None else self.CICE_dict["P_reG_reg_weights"] G_src = xr.Dataset({self.CICE_dict['lat_coord_name'] : (coords['names']['spatial_dims'], coords['latitudes']), self.CICE_dict['lon_coord_name'] : (coords['names']['spatial_dims'], coords['longitudes'])}) G_dst = self.define_regular_G(G_res, region=region, spatial_dim_names=coords['names']['spatial_dims']) return xe.Regridder(G_src, G_dst, method = reG_method, periodic = periodic, ignore_degenerate = True, reuse_weights = reuse_weights, filename = P_weights)
[docs] def simple_spatial_averaging_bgrid_to_tgrid(self, var): """ Compute a Dask-safe 2×2 unweighted average from B-grid corner points to T-grid centers. This provides a lightweight alternative to xESMF remapping when a simple local averaging is sufficient. The operation is performed by slicing four corner-shifted views (v00, v01, v10, v11), averaging them, padding back to the original grid size, and applying a cyclic wrap to the last x column (last equals first). Parameters ---------- var : xr.DataArray Input array on the B-grid. May be 2D (nj, ni) or 3D (time, nj, ni) depending on your configuration. Dimension names are taken from `self.CICE_dict["y_dim"]` and `self.CICE_dict["x_dim"]`. Returns ------- xr.DataArray Averaged field on the T-grid with the same nominal shape as the configured target sizes (`y_dim_length`, `x_dim_length`). NaNs are used where padding is required. Raises ------ KeyError If required dimension names/lengths are missing from `self.CICE_dict`. AssertionError If the final array shape does not match the configured target sizes. Notes ----- - This is not a conservative remap; it is a local arithmetic average. - The last x column is set equal to the first to maintain periodicity. - Indexes for spatial dims are dropped to avoid downstream surprises when mixing integer and coordinate-based selection. """ x_dim = self.CICE_dict["x_dim"] y_dim = self.CICE_dict["y_dim"] x_len = self.CICE_dict["x_dim_length"] y_len = self.CICE_dict["y_dim_length"] self.logger.info(f"input shape to spatial averaging: {var.shape}") self.logger.info(" → Slicing corner points for averaging...") v00 = var.isel({y_dim: slice(None, -1), x_dim: slice(None, -1)}) v01 = var.isel({y_dim: slice(None, -1), x_dim: slice(1, None)}) v10 = var.isel({y_dim: slice(1, None), x_dim: slice(None, -1)}) v11 = var.isel({y_dim: slice(1, None), x_dim: slice(1, None)}) self.logger.info(" → Computing mean of four corners...") avg = (v00 + v01 + v10 + v11) / 4.0 self.logger.info(" → Padding with NaNs to restore original grid size...") pad_y = max(y_len - avg.sizes.get(y_dim, 0), 0) pad_x = max(x_len - avg.sizes.get(x_dim, 0), 0) avg = avg.pad({y_dim: (0, pad_y), x_dim: (0, pad_x)}, constant_values=np.nan) self.logger.info(" → Applying cyclic wrap for last column...") if avg.sizes.get(x_dim, 0) > 1: avg[{x_dim: -1}] = avg.isel({x_dim: 0}) # Force re-slicing to expected grid size to ensure consistency avg = avg.isel({y_dim: slice(0, y_len), x_dim: slice(0, x_len)}) if "time" in var.coords: avg = avg.assign_coords(time=var["time"]) self.logger.info(" → Time coordinate restored.") assert avg.sizes[y_dim] == y_len, f"{y_dim} mismatch: got {avg.sizes[y_dim]}, expected {y_len}" assert avg.sizes[x_dim] == x_len, f"{x_dim} mismatch: got {avg.sizes[x_dim]}, expected {x_len}" for dim in [y_dim, x_dim]: if dim in avg.indexes: avg = avg.drop_indexes(dim) return avg
[docs] def pygmt_regrid(self, da, lon, lat, grid_res=None, region=None, search_radius="200k"): """ Regrid a 2D data array using PyGMT's nearneighbor interpolation. This method applies PyGMT's `nearneighbor` algorithm to interpolate scattered data values (`da`) onto a regular grid based on specified longitude and latitude arrays. The input is masked to ignore NaNs or non-finite values. Parameters ---------- da : xarray.DataArray 2D array of data values to interpolate (e.g., sea ice thickness). lon : xarray.DataArray or np.ndarray Longitude values corresponding to `da`, same shape. lat : xarray.DataArray or np.ndarray Latitude values corresponding to `da`, same shape. grid_res : str or float, optional Grid spacing for the output grid (e.g., "0.5", "10m"). Required by PyGMT. region : list or tuple, optional Bounding box for the output grid in the form [west, east, south, north]. search_radius : str or float, default "200k" Search radius for PyGMT's nearneighbor (e.g., "100k" for 100 km). Returns ------- gridded : xarray.DataArray Gridded output with interpolated values over the defined region. Notes ----- - All non-finite values in `da` are excluded prior to interpolation. - PyGMT must be properly installed and configured with GMT for this to work. """ import pygmt mask = np.isfinite(da) df = pd.DataFrame({"longitude": lon.values[mask].ravel(), "latitude" : lat.values[mask].ravel(), "z" : da.values[mask].ravel()}) return pygmt.nearneighbor(data = df, spacing = grid_res, region = region, search_radius = search_radius)
############################################################## # PYRESAMPLE # ##############################################################
[docs] def to_3031_extent(self, lat2d, lon2d, buffer_m=20_000): """ Project a swath's latitude/longitude coordinates to EPSG:3031 and return an extent. This helper converts 2D lat/lon arrays to Antarctic polar stereographic coordinates (EPSG:3031), then returns an axis-aligned bounding box: [xmin, ymin, xmax, ymax] expanded by an optional buffer (meters). Parameters ---------- lat2d, lon2d : array-like 2D arrays of latitude and longitude (degrees). Must be broadcast-compatible. buffer_m : float, default 20000 Buffer added to each side of the extent (meters). Returns ------- list[float] Extent [xmin, ymin, xmax, ymax] in EPSG:3031 meters (including buffer). Notes ----- - Longitudes are wrapped to [-180, 180) prior to projection to avoid dateline issues. - Non-finite projected points are excluded when computing min/max. """ from pyproj import Transformer lon2d = self.normalise_longitudes(lon2d, to="-180-180") transformer = Transformer.from_crs("EPSG:4326", "EPSG:3031", always_xy=True) x, y = transformer.transform(lon2d.ravel(), lat2d.ravel()) x = np.asarray(x); y = np.asarray(y) finite = np.isfinite(x) & np.isfinite(y) xmin, xmax = x[finite].min(), x[finite].max() ymin, ymax = y[finite].min(), y[finite].max() return [xmin - buffer_m, ymin - buffer_m, xmax + buffer_m, ymax + buffer_m]
[docs] def union_extents(self, extents): """ Compute the union (bounding) extent of multiple EPSG:3031 extents. Parameters ---------- extents : list of extents Each extent must be [xmin, ymin, xmax, ymax] in meters. Returns ------- list[float] Union extent [xmin, ymin, xmax, ymax] spanning all input extents. """ xs = [e[0] for e in extents] + [e[2] for e in extents] ys = [e[1] for e in extents] + [e[3] for e in extents] return [min(xs), min(ys), max(xs), max(ys)]
[docs] def snap_extent_to_grid(self, extent, pixel_size): """ Snap an extent to a regular grid defined by `pixel_size`. The snapped extent is expanded (never shrunk) such that: - xmin, ymin are floored to the nearest multiple of pixel_size - xmax, ymax are ceiled to the nearest multiple of pixel_size Parameters ---------- extent : list[float] [xmin, ymin, xmax, ymax] in meters. pixel_size : float Grid resolution (meters). Returns ------- list[float] Snapped extent [xmin, ymin, xmax, ymax] in meters. """ xmin, ymin, xmax, ymax = extent xmin = np.floor(xmin / pixel_size) * pixel_size ymin = np.floor(ymin / pixel_size) * pixel_size xmax = np.ceil (xmax / pixel_size) * pixel_size ymax = np.ceil (ymax / pixel_size) * pixel_size return [xmin, ymin, xmax, ymax]
[docs] def make_area_definition(self, extent, pixel_size=5_000, area_id="epsg3031_5km_union"): """ Create a PyResample AreaDefinition for a regular EPSG:3031 grid. Parameters ---------- extent : list[float] [xmin, ymin, xmax, ymax] in EPSG:3031 meters. Typically produced by `to_3031_extent()`, `union_extents()`, and `snap_extent_to_grid()`. pixel_size : float, default 5000 Grid resolution (meters). area_id : str, default "epsg3031_5km_union" Identifier string for the AreaDefinition. Returns ------- pyresample.geometry.AreaDefinition Regular projected grid definition suitable for PyResample resampling. Notes ----- - Width/height are computed from the extent and rounded to integer pixel counts. - The returned extent is adjusted so that xmax/ymax align exactly with width/height. """ from pyresample.geometry import AreaDefinition xmin, ymin, xmax, ymax = extent width = int(round((xmax - xmin) / pixel_size)) height = int(round((ymax - ymin) / pixel_size)) xmax = xmin + width * pixel_size ymax = ymin + height * pixel_size area_def = AreaDefinition( area_id=area_id, description="Common 5 km EPSG:3031 grid (union of inputs)", proj_id="epsg3031", projection="EPSG:3031", width=width, height=height, area_extent=(xmin, ymin, xmax, ymax), ) return area_def
[docs] def grid_coords_from_area(self, area_def, pixel_size=5_000): """ Construct 1D x/y coordinate arrays (cell centers) from a PyResample AreaDefinition. Parameters ---------- area_def : pyresample.geometry.AreaDefinition Target grid definition in EPSG:3031. pixel_size : float, default 5000 Grid resolution (meters). Returns ------- (x, y) : tuple[np.ndarray, np.ndarray] 1D arrays of cell-center coordinates (meters). `x` increases eastward. `y` decreases from top to bottom (north → south) consistent with the area extent definition. """ xmin, ymin, xmax, ymax = area_def.area_extent width, height = area_def.width, area_def.height # Cell centers x = xmin + (np.arange(width) + 0.5) * pixel_size y = ymax - (np.arange(height) + 0.5) * pixel_size # top->down (north->south) return x, y
[docs] def resample_swath_to_area(self, src_da, lat2d, lon2d, area_def, radius=10_000, fill_value=np.nan, pixel_size=5_000): """ Nearest-neighbour resample a 2D swath to an EPSG:3031 AreaDefinition grid. Parameters ---------- src_da : xr.DataArray 2D swath data array (e.g., satellite swath variable). Must be aligned with `lat2d`/`lon2d` in shape. lat2d, lon2d : array-like 2D latitude/longitude arrays (degrees) describing the swath geolocation. area_def : pyresample.geometry.AreaDefinition Target grid definition (EPSG:3031). radius : float, default 10000 Radius of influence in meters used for nearest-neighbour resampling. fill_value : scalar, default np.nan Fill value assigned where no source points fall within `radius`. pixel_size : float, default 5000 Target grid resolution used to construct the output x/y coordinates. Returns ------- xr.DataArray Resampled 2D field on the target grid with dims ("y","x") and coordinates "x" and "y" in meters, plus metadata indicating EPSG:3031. Notes ----- - Longitudes are wrapped to [-180, 180) before constructing the SwathDefinition. This is critical to avoid dateline discontinuities in the KD-tree search. - `nprocs=0` uses serial execution; set >0 to parallelise if appropriate. """ from pyresample.geometry import SwathDefinition from pyresample.kd_tree import resample_nearest lon2d = self.normalise_longitudes(lon2d, to="-180-180") # << key fix: wrap before building the swath swath = SwathDefinition(lons=lon2d, lats=lat2d) out2d = resample_nearest(source_geo_def = swath, data = src_da.values, target_geo_def = area_def, radius_of_influence= radius, fill_value = fill_value, nprocs = 0, # set >0 to parallelise reduce_data = True) x, y = self.grid_coords_from_area(area_def, pixel_size=pixel_size) da_out = xr.DataArray(out2d, dims = ("y", "x"), coords = {"x": ("x", x, {"units": "m", "standard_name": "projection_x_coordinate"}), "y": ("y", y, {"units": "m", "standard_name": "projection_y_coordinate"})}, name = src_da.name, attrs = {"crs": "EPSG:3031", "grid_mapping": "spstereo", "res": float(pixel_size), **src_da.attrs}) return da_out
[docs] def add_lonlat_from_epsg3031(self, ds, x_name = "x", y_name = "y", wrap = "0..360", # or "-180..180" out_dtype = "float32"): """ Add 2D lon/lat coordinate fields to a dataset defined on an EPSG:3031 grid. This function broadcasts the 1D projected x/y coordinates to 2D, converts them to lon/lat using EPSG:3031 → EPSG:4326, wraps longitudes to the requested convention, and attaches the results as dataset coordinates. Parameters ---------- ds : xr.Dataset Dataset with 1D projected coordinate dimensions `x_name` and `y_name`. x_name, y_name : str, default ("x","y") Names of the projected coordinate dimensions in `ds`. wrap : {"0..360","-180..180"}, default "0..360" Longitude wrapping convention for the output coordinate. out_dtype : str, default "float32" Output dtype for lon/lat coordinates. Use None to keep float64. Returns ------- xr.Dataset Dataset with added coordinates: - lon(y, x) - lat(y, x) Raises ------ ValueError If `x_name` or `y_name` are not present as dataset dimensions. Notes ----- - Uses `xr.apply_ufunc(..., dask="parallelized")` so it can operate on Dask-backed x/y coordinates without materialising large intermediate arrays. """ if x_name not in ds.dims or y_name not in ds.dims: raise ValueError(f"Expected dims '{y_name}', '{x_name}' in dataset.") # broadcast 1-D x/y -> 2-D (y,x) X2D, Y2D = xr.broadcast(ds[x_name], ds[y_name]) # shapes (y,x) lon, lat = xr.apply_ufunc(self._xy_to_lonlat, X2D, Y2D, input_core_dims=[[y_name, x_name], [y_name, x_name]], output_core_dims=[[y_name, x_name], [y_name, x_name]], dask="parallelized", vectorize=False, output_dtypes=[np.float64, np.float64],) # wrap longitudes & cast if wrap == "0..360": lon = lon % 360.0 else: lon = ((lon + 180.0) % 360.0) - 180.0 if out_dtype: lon = lon.astype(out_dtype) lat = lat.astype(out_dtype) # attach as coordinates (on same (y,x) dims) return ds.assign_coords(lon=lon, lat=lat)
[docs] def subset_by_lonlat_box(self, da: xr.DataArray, lon_range, lat_range, lon_name="TLON", lat_name="TLAT", jdim="nj", idim="ni", wrap="-180-180", crop=True): """ Subset a curvilinear grid DataArray by a geographic lon/lat bounding box. This method constructs a boolean mask from 2D lon/lat coordinates and applies it to `da` using `.where(mask)`. It supports bounding boxes that cross the longitude seam/dateline by interpreting a range where lon_min > lon_max as a wrapped interval. Parameters ---------- da : xr.DataArray DataArray on a curvilinear grid, typically with dims (time, nj, ni) or (nj, ni). Must include 2D coordinate fields `lon_name` and `lat_name`. lon_range : tuple[float, float] (lon_min, lon_max) in degrees. If the range crosses the seam, use a wrapped specification (e.g., (350, 20) for 0..360 convention). lat_range : tuple[float, float] (lat_min, lat_max) in degrees. lon_name, lat_name : str, default ("TLON","TLAT") Names of 2D longitude/latitude coordinate variables in `da.coords`. jdim, idim : str, default ("nj","ni") Names of the two spatial dimensions. wrap : {"0-360","-180-180"}, default "-180-180" Longitude wrap convention to apply prior to masking. crop : bool, default True If True, crops the output to the minimal bounding index box that contains at least one valid cell (reduces storage/plotting cost). Cropping is done by finding any-True rows/columns in the 2D mask. Returns ------- xr.DataArray Subsetted DataArray. Values outside the lon/lat box are masked (NaN). If `crop=True`, spatial dims are also index-cropped to a tight bounding box. Notes ----- - The masking step (`da.where(mask)`) is Dask-friendly. - The cropping step currently uses `.values` on 1D any-masks; for very large grids this may trigger compute. In typical Antarctic subsetting use-cases, this is acceptable because the mask reduction is cheap relative to the full field. - Requires `self.normalise_longitudes(...)` to exist and to accept the `wrap` argument. """ # coords (2-D) TLON = self.normalise_longitudes(da.coords[lon_name], wrap) TLAT = da.coords[lat_name] lon_min, lon_max = lon_range lat_min, lat_max = lat_range # Longitude mask (handle seam) if wrap == "0-360": lon_min, lon_max = lon_min % 360, lon_max % 360 else: lon_min = ((lon_min + 180) % 360) - 180 lon_max = ((lon_max + 180) % 360) - 180 if lon_min <= lon_max: mask_lon = (TLON >= lon_min) & (TLON <= lon_max) else: # crosses dateline / wrap seam mask_lon = (TLON >= lon_min) | (TLON <= lon_max) mask_lat = (TLAT >= lat_min) & (TLAT <= lat_max) mask = mask_lon & mask_lat # (nj, ni) out = da.where(mask) # broadcast over time if crop: j_any = mask.any(dim=idim) i_any = mask.any(dim=jdim) j_idx = np.where(j_any.values)[0] i_idx = np.where(i_any.values)[0] if j_idx.size and i_idx.size: j0, j1 = j_idx.min(), j_idx.max() + 1 i0, i1 = i_idx.min(), i_idx.max() + 1 out = out.isel({jdim: slice(j0, j1), idim: slice(i0, i1)}) return out