# Source code for sea_ice_plotter

from __future__ import annotations
import xarray as xr
import pandas as pd
import numpy  as np
from pathlib    import Path
from contextlib import nullcontext

__all__ = ["SeaIcePlotter"]

class SeaIcePlotter:
    """
    Plotting utilities for AFIM / CICE sea-ice diagnostics.

    The methods prepare gridded ``xarray.DataArray`` fields for plotting
    (primarily with PyGMT):

    - writing CPT (colour palette table) files for continuous or categorical
      difference fields (e.g. obs–model agreement maps);
    - building transparency ("alpha") layers whose opacity follows the
      magnitude of the compared fields;
    - resolving 2D lon/lat arrays for a DataArray from explicit coordinate
      names, common coordinate names, or the class grid datasets
      (B-grid ``self.G_u`` / T-grid ``self.G_t``);
    - flattening a 2D DataArray into 1D lon/lat/z vectors plus masks for
      ``pygmt.Figure.plot`` style workflows;
    - resolving plot regions / projection strings and drawing base layers.

    Dask-backed data are computed once, inside ``pygmt_da_prep``. CPT writers
    create plain-text files on disk and return the resolved path as ``str``.

    Notes
    -----
    Several methods read attributes that this class does not create itself.
    They are expected to be supplied elsewhere (subclass, mixin or caller —
    TODO confirm):

    config : dict
        ``config['pygmt_dict']`` must provide the paths ``'P_coast_shape'``,
        ``'P_IBCSO_bed'`` and ``'P_IBCSO_bath'``.
    logger : logging.Logger-like
        Needs ``info`` / ``warning`` (used by ``extract_min_max_dates``).
    _method_name : callable
        Returns a method name for log messages.
    load_cice_grid : callable
        Called as ``load_cice_grid(slice_hem=True)``; assumed to populate
        ``G_t`` and ``G_u``.
    G_t, G_u : xarray.Dataset
        Grid datasets holding 2D ``"lon"`` and ``"lat"`` variables.
    normalise_longitudes : callable
        Called as ``normalise_longitudes(lon, to="0-360" | "-180-180")``.
    """

    def __init__(self, **kwargs):
        """
        Create a plotter instance.

        Parameters
        ----------
        **kwargs
            Accepted for interface compatibility; currently ignored.

        Notes
        -----
        Nothing is loaded or stored here. Grid coordinates are fetched lazily
        when first required (see ``_resolve_lonlat_2d``).
        """
[docs] def load_ice_shelves(self): """ Load and preprocess Antarctic ice shelf polygons for PyGMT plotting. This method reads a shapefile containing Antarctic coastal geometries, filters for polygons classified as ice shelves (`POLY_TYPE == 'S'`), ensures valid geometry, reprojects them to WGS84 (EPSG:4326), and applies a zero-width buffer to clean topology issues. Returns ------- geopandas.GeoSeries Cleaned and reprojected geometries of Antarctic ice shelves. Notes ----- - The input shapefile path is read from `self.config['pygmt_dict']['P_coast_shape']`. - This method is typically used to overlay ice shelf boundaries in PyGMT plots. - The returned geometries can be passed directly to `pygmt.Figure.plot()`. See Also -------- - self.plot_FIA_FIP_faceted : Uses this method to overlay ice shelves in map panels. - geopandas.read_file : For reading shapefiles. """ import geopandas as gpd gdf = gpd.read_file(self.config['pygmt_dict']['P_coast_shape']) shelves = gdf[gdf['POLY_TYPE'] == 'S'] shelves = shelves[~shelves.geometry.is_empty & shelves.geometry.notnull()] shelves = shelves.to_crs("EPSG:4326") shelves.geometry = shelves.geometry.buffer(0) return shelves.geometry
[docs] def create_IBCSO_bath(self): """ Extract and save a masked IBCSO bathymetry layer for Antarctic plotting. This method loads the IBCSO v2.0 dataset (as a NetCDF raster), extracts the seafloor depth (negative elevations), masks out land areas (positive or zero), and saves a cleaned version to NetCDF for use in plotting. The result is saved at the path specified in `self.config['pygmt_dict']['P_IBCSO_bath']`. Returns ------- None Notes ----- - Input file is assumed to be a NetCDF raster with variable `band_data`. - Only `band=0` is used (i.e., the main bathymetric band). - Output variable is named `bath` and excludes attributes and encodings for simplicity. - This method is typically called once to prepare the bathymetry layer before reuse. See Also -------- - self.load_IBCSO_bath : Loads the pre-saved masked bathymetry layer. - https://www.ibcso.org/ : International Bathymetric Chart of the Southern Ocean. """ ds = xr.open_dataset(self.config['pygmt_dict']['P_IBCSO_bed']) bed = ds.band_data.isel(band=0) bed_masked = bed.where(bed < 0) bed_masked.name = "bath" bed_masked.attrs = {} ds_out = bed_masked.to_dataset() ds_out.attrs = {} ds_out.encoding = {} ds_out.to_netcdf(self.config['pygmt_dict']["P_IBCSO_bath"])
[docs] def load_IBCSO_bath(self): """ Load masked IBCSO bathymetry dataset prepared by `create_IBCSO_bath()`. Returns the `bath` variable from the NetCDF file specified in `self.config['pygmt_dict']["P_IBCSO_bath"]`. Returns ------- xarray.DataArray 2D array of bathymetry values (only ocean depths, in meters). Values are negative below sea level, NaN over land. Notes ----- - This method assumes that `create_IBCSO_bath()` has already been run. - Designed for use in PyGMT background plotting or masking. See Also -------- - self.create_IBCSO_bath : Creates this file if it doesn't exist. - xarray.open_dataset : Loads the bathymetry layer. """ return xr.open_dataset(self.config['pygmt_dict']["P_IBCSO_bath"]).bath
[docs] def create_cbar_frame(self, series, label, units=None, extend_cbar=False, max_ann_steps=10): """ Construct a GMT-style colorbar annotation frame string for PyGMT plotting. This utility generates a clean, readable colorbar frame string using adaptive step sizing based on the data range. An optional second axis label (e.g., for units) and extension arrows can be included. Parameters ---------- series : list or tuple of float Data range for the colorbar as [vmin, vmax]. label : str Label text for the colorbar (e.g., "Fast Ice Persistence"). units : str, optional Units label to be shown along the secondary (y) axis (e.g., "1/100"). extend_cbar : bool, optional Whether to append extension arrows to the colorbar (+e). Default is False. max_ann_steps : int, default=10 Desired maximum number of major annotations (controls tick spacing). Returns ------- str or list of str GMT-format colorbar annotation string (e.g., "a0.1f0.02+lLabel"), or a list of two strings if `units` is provided. Notes ----- - Tick spacing is determined using a scaled logarithmic rounding to ensure clean steps (e.g., 0.1, 0.2, 0.5). - If `extend_cbar` is True, `+e` is added to indicate out-of-bounds extension arrows. - Compatible with `pygmt.Figure.colorbar(frame=...)`. 
Examples -------- >>> self.create_cbar_frame([0.01, 1.0], "Persistence", units="1/100") ['a0.1f0.02+lPersistence', 'y+l 1/100'] """ vmin, vmax = series[0], series[1] vrange = vmax - vmin raw_step = vrange / max_ann_steps exp = np.floor(np.log10(raw_step)) base = 10 ** exp mult = raw_step / base if mult < 1.5: ann_step = base * 1 elif mult < 3: ann_step = base * 2 elif mult < 7: ann_step = base * 5 else: ann_step = base * 10 tick_step = ann_step / 5 ann_str = f"{ann_step:.3f}".rstrip("0").rstrip(".") tick_str = f"{tick_step:.3f}".rstrip("0").rstrip(".") # Build annotation string frame = f"a{ann_str}f{tick_str}+l{label}" if extend_cbar: frame += "+e" # or use "+eU" / "+eL" for one-sided arrows if units is not None: return [frame, f"y+l {units}"] else: return frame
[docs] def get_meridian_center_from_geographic_extent(self, geographic_extent): """ Determine the optimal central meridian for PyGMT polar stereographic projections. Given a geographic extent in longitude/latitude format, this method calculates the central meridian (longitude) to use in 'S<lon>/<lat>/<width>' PyGMT projection strings. It accounts for dateline wrapping and ensures the plot is centered visually. Parameters ---------- geographic_extent : list of float Geographic region as [min_lon, max_lon, min_lat, max_lat]. Accepts longitudes in either [-180, 180] or [0, 360] and gracefully handles dateline crossing. Returns ------- float Central meridian (longitude) in the range [-180, 180]. Notes ----- - If the computed center falls outside the intended range (e.g., due to dateline wrapping), the method rotates the meridian 180° to better align the figure. - The result is saved to `self.plot_meridian_center` for reuse. - Used in PyGMT stereographic projections like `'S{lon}/-90/30c'`. See Also -------- - https://docs.generic-mapping-tools.org/latest/cookbook/proj.html - pygmt.Figure.basemap : For setting map projections using meridian centers. """ lon_min, lon_max = geographic_extent[0], geographic_extent[1] lon_min_360 = lon_min % 360 lon_max_360 = lon_max % 360 if (lon_max_360 - lon_min_360) % 360 > 180: center = ((lon_min_360 + lon_max_360 + 360) / 2) % 360 else: center = (lon_min_360 + lon_max_360) / 2 if center > 180: center -= 360 # Edge case fix: ensure center is visually aligned with geographic_extent # If the computed center is 180° out of phase (i.e., upside-down plots) if not (geographic_extent[0] <= center <= geographic_extent[1]): # Flip 180° center = (center + 180) % 360 if center > 180: center -= 360 #print(f"meridian center computed as {center:.2f}°") self.plot_meridian_center = center return center
[docs] def extract_min_max_dates(self, ts_dict, keys2plot=None, primary_key='FIA', time_coord='time'): """ Extract the minimum and maximum datetime values from time series data. This method scans through a dictionary of time series objects (either xarray.DataArrays, xarray.Datasets, or nested dictionaries containing a DataArray under `primary_key`), and returns the earliest and latest valid time values found across all selected entries. Parameters ---------- ts_dict : dict A dictionary where each value is either: - an xarray.DataArray with a `time_coord` coordinate, or - a dictionary containing a DataArray under the key `primary_key`. Keys typically represent simulation or experiment identifiers. keys2plot : list of str, optional If provided, restricts the operation to keys within this list. Keys not in `keys2plot` will be skipped. primary_key : str, default 'FIA' The key to use when accessing nested dictionaries within `ts_dict`. Ignored if the value is already a DataArray or Dataset. time_coord : str, default 'time' The name of the time coordinate to extract from each DataArray. Returns ------- tmin : pandas.Timestamp The earliest datetime found across all valid time series. tmax : pandas.Timestamp The latest datetime found across all valid time series. Notes ----- - Entries in `ts_dict` with missing or malformed time coordinates are skipped. - The key "AF2020" is always excluded. - If no valid entries remain after filtering, a warning is logged and `None` is returned. 
""" df_dts = [] for dict_key, data in ts_dict.items(): if (keys2plot is not None and dict_key not in keys2plot) or (dict_key == "AF2020"): continue self.logger.info(f"{dict_key} simulation will be included in {self._method_name()}()") # Handle both nested dict and direct DataArray if isinstance(data, dict) and primary_key in data: da = data[primary_key] else: da = data # assume data is already a DataArray df_dt = pd.DataFrame({"time": pd.to_datetime(da[time_coord].values)}) df_dts.append(df_dt) if not df_dts: self.logger.warning("No data to plot after filtering with keys2plot.") return df_all = pd.concat(df_dts, ignore_index=True).dropna() all_times = df_all["time"] return all_times.min(), all_times.max()
def _hex_to_rgb(self, hexstr: str) -> tuple[int, int, int]: h = hexstr.lstrip("#") return (int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16))
[docs] def write_tricolour_cpt(self, P_cpt, *, vmin : float = -1.0, vmid : float = 0.0, vmax : float = 1.0, cmin : str = "#2CA25F", # green cmid : str = "#FDAE61", # orange cmax : str = "#2171B5", # blue background_rgb : str = "255/255/255", foreground_rgb : str = "0/0/0", nan_rgb : str = "255/255/255") -> str: """ Create and write a 3-colour continuous CPT file for difference fields. This writes a simple two-segment CPT suitable for continuous data where values below/above a midpoint should transition through three anchor colours: - [vmin, vmid] mapped cmin → cmid - [vmid, vmax] mapped cmid → cmax The output is a GMT CPT text file, usable directly by PyGMT (e.g. `cmap=P_cpt`). Parameters ---------- P_cpt : str or pathlib.Path Target output path for the CPT file. Parent directories are created if needed. vmin, vmid, vmax : float, optional Data range breakpoints. `vmid` defines the central colour transition point. cmin, cmid, cmax : str, optional Hex colours used at `vmin`, `vmid`, and `vmax` respectively. Example: "#2CA25F". background_rgb : str, optional GMT background colour specification used for values below `vmin` (B record), formatted as "R/G/B". foreground_rgb : str, optional GMT foreground colour specification used for values above `vmax` (F record), formatted as "R/G/B". nan_rgb : str, optional GMT NaN colour specification (N record), formatted as "R/G/B". Returns ------- str String path to the written CPT file. Notes ----- - This function assumes `self._hex_to_rgb()` returns (r, g, b) integers in [0, 255]. - GMT expects CPT lines of the form: z0 r/g/b z1 r/g/b plus optional B/F/N records for background/foreground/NaN. 
""" r1, g1, b1 = self._hex_to_rgb(cmin) r2, g2, b2 = self._hex_to_rgb(cmid) r3, g3, b3 = self._hex_to_rgb(cmax) cpt_text = (f"{vmin} {r1}/{g1}/{b1} {vmid} {r2}/{g2}/{b2}\n" f"{vmid} {r2}/{g2}/{b2} {vmax} {r3}/{g3}/{b3}\n" f"B {background_rgb}\n" f"F {foreground_rgb}\n" f"N {nan_rgb}\n") P_cpt = Path(P_cpt) P_cpt.parent.mkdir(parents=True, exist_ok=True) P_cpt.write_text(cpt_text) return str(P_cpt)
[docs] def tricolor_cpt_and_alpha(self, P_cpt, obs, mod, *, # CPT vmin : float = -1.0, vmid : float = 0.0, vmax : float = 1.0, cmin : str = "#2CA25F", cmid : str = "#FDAE61", cmax : str = "#2171B5", # difference sign diff_mode : str = "obs-mod", # recommended for green=model-dom, blue=obs-dom # alpha rule alpha_mode : str = "max", # "max" or "mean" alpha_power : float = 1.0, alpha_floor : float = 0.0, nan_transparency : float = 100.0): # percent (0 opaque, 100 transparent) """ Build a tricolour CPT plus a transparency (alpha) layer from obs/model fields. This helper is designed for “difference + confidence” map products where: - colour encodes the signed difference (obs − model or model − obs), clipped into a fixed range [vmin, vmax], and - transparency encodes the strength/credibility of the signal (e.g., based on persistence or agreement magnitude), computed from the obs/model magnitudes. The function returns: - `cpt_path`: path to the written tricolour CPT - `z`: the clipped difference array - `t`: transparency percent array (0 = opaque, 100 = fully transparent) Parameters ---------- P_cpt : str or pathlib.Path Output path for the CPT file written by `write_tricolour_cpt()`. obs, mod : array-like Observed and modelled fields on a common grid. These are converted to `numpy.ndarray(dtype=float)` internally and must be broadcast-compatible. vmin, vmid, vmax : float, optional Colour scale limits and midpoint. The returned difference `z` is clipped into [vmin, vmax]. cmin, cmid, cmax : str, optional Hex colours for the CPT anchor points at vmin/vmid/vmax. diff_mode : {"obs-mod", "mod-obs"}, optional Sign convention for the difference: - "obs-mod": z = obs - mod - "mod-obs": z = mod - obs Choose a convention consistent with your map interpretation (e.g. green as model-dominant vs obs-dominant). 
alpha_mode : {"max", "mean"}, optional How to compute the “alpha strength” `a` from `obs` and `mod`: - "max": a = max(obs, mod) - "mean": a = 0.5 * (obs + mod) In both cases, `a` is clipped to [0, 1]. alpha_power : float, optional Exponent applied to `a` after clipping. Values > 1 sharpen opacity toward high-confidence regions; values < 1 broaden opacity. alpha_floor : float, optional Minimum opacity floor expressed in alpha-space. When > 0, rescales: a <- alpha_floor + (1 - alpha_floor) * a This prevents regions from becoming too transparent. nan_transparency : float, optional Transparency percentage used where either `z` or `t` becomes non-finite (NaN/inf). Default 100 = fully transparent. Returns ------- (str, np.ndarray, np.ndarray) cpt_path : str Path to the written CPT file. z : np.ndarray Difference field, clipped to [vmin, vmax]. t : np.ndarray Transparency percent in [0, 100], where 0 is opaque. Raises ------ ValueError If `diff_mode` is not one of {"obs-mod", "mod-obs"}. ValueError If `alpha_mode` is not one of {"max", "mean"}. Notes ----- - Transparency is computed as: t = (1 - a) * 100, so larger `a` yields more opaque pixels (smaller transparency percentage). - This method does not attempt to regrid or align obs/model; inputs must already be on a common grid and comparable. 
""" cpt_path = self.write_tricolour_cpt(P_cpt, vmin=vmin, vmid=vmid, vmax=vmax, cmin=cmin, cmid=cmid, cmax=cmax) obs = np.asarray(obs, dtype=float) mod = np.asarray(mod, dtype=float) dm = diff_mode.lower() if dm in ("obs-mod", "obs_minus_mod"): z = obs - mod elif dm in ("mod-obs", "mod_minus_obs"): z = mod - obs else: raise ValueError("diff_mode must be 'obs-mod' or 'mod-obs'.") z = np.clip(z, vmin, vmax) am = alpha_mode.lower() if am == "max": a = np.maximum(obs, mod) elif am == "mean": a = 0.5 * (obs + mod) else: raise ValueError("alpha_mode must be 'max' or 'mean'.") a = np.clip(a, 0.0, 1.0) if alpha_power != 1.0: a = a ** alpha_power if alpha_floor > 0.0: a = alpha_floor + (1.0 - alpha_floor) * a t = (1.0 - a) * 100.0 bad = ~np.isfinite(z) | ~np.isfinite(t) if np.any(bad): t = np.where(bad, nan_transparency, t) z = np.where(bad, np.nan, z) return cpt_path, z, t
[docs] def write_tricolor_category_cpt(self, P_cpt, *, # codes: 0=agreement, 1=model-dom, 2=obs-dom int0 = (0.0, 1.0), int1 = (1.0, 2.0), int2 = (2.0, 3.0), hex0 = "#FDAE61", # agreement = orange hex1 = "#2CA25F", # model-dom = green hex2 = "#2171B5", # obs-dom = blue background_rgb = "255/255/255", foreground_rgb = "0/0/0", nan_rgb = "255/255/255") -> str: """ Create and write a 3-class categorical CPT file. This is intended for integer-coded classification maps, commonly: - 0 : agreement - 1 : model-dominant - 2 : observation-dominant Each class is given a constant colour across its interval. The default colour scheme matches the continuous tricolour palette: - agreement: orange - model-dominant: green - obs-dominant: blue Parameters ---------- P_cpt : str or pathlib.Path Target output path for the categorical CPT file. Parent directories are created if needed. int0, int1, int2 : tuple[float, float], optional Numeric intervals for categories 0, 1, and 2 respectively. Defaults are (0,1), (1,2), (2,3), which works well with integer-coded rasters. hex0, hex1, hex2 : str, optional Hex colours for category 0, 1, and 2. background_rgb : str, optional GMT background (B record) colour as "R/G/B". foreground_rgb : str, optional GMT foreground (F record) colour as "R/G/B". nan_rgb : str, optional GMT NaN (N record) colour as "R/G/B". Returns ------- str String path to the written CPT file. Notes ----- - The CPT uses constant colours per interval, which is typically what you want for categorical class rasters. - This function assumes `self._hex_to_rgb()` returns integer RGB triplets. 
""" r0, g0, b0 = self._hex_to_rgb(hex0) r1, g1, b1 = self._hex_to_rgb(hex1) r2, g2, b2 = self._hex_to_rgb(hex2) cpt_text = (f"{int0[0]} {r0}/{g0}/{b0} {int0[1]} {r0}/{g0}/{b0}\n" f"{int1[0]} {r1}/{g1}/{b1} {int1[1]} {r1}/{g1}/{b1}\n" f"{int2[0]} {r2}/{g2}/{b2} {int2[1]} {r2}/{g2}/{b2}\n" f"B {background_rgb}\n" f"F {foreground_rgb}\n" f"N {nan_rgb}\n") P_cpt = Path(P_cpt) P_cpt.parent.mkdir(parents=True, exist_ok=True) P_cpt.write_text(cpt_text) return str(P_cpt)
def _is_dask_array(self, x) -> bool: """ Return True if `x` appears to be a Dask-backed array-like. Parameters ---------- x : object Any object. The check is heuristic and designed to work for common Dask array types attached to xarray objects. Returns ------- bool True if `x` has a `.compute()` method and its module path suggests a Dask type; False otherwise. Notes ----- - This is intentionally lightweight and avoids importing dask explicitly. - False positives are unlikely but possible if an object mimics the Dask API. """ return hasattr(x, "compute") and ("dask" in type(x).__module__.lower()) def _auto_mask_zero(self, da: xr.DataArray) -> bool: """ Decide whether zero values should be masked when preparing data for plotting. Many continuous diagnostics use zeros as “no data” placeholders (or contain large regions of structural zeros that would dominate plotting). However, categorical fields frequently use 0 as a valid class label and must not be masked. Parameters ---------- da : xr.DataArray 2D data field intended for plotting. Returns ------- bool True if zeros should be masked (typical for continuous fields), False if zeros should be retained (typical for categorical class rasters). Decision Logic -------------- - If the variable name suggests a categorical difference field (e.g. contains "diff_cat"), or if `da.attrs["flag_values"]` indicates a {0,1,2} category encoding, then zeros are treated as valid and are NOT masked. - Otherwise, zeros are masked by default. Notes ----- - You can override this behaviour explicitly in `pygmt_da_prep(mask_zero=...)`. 
""" name = (da.name or "").lower() fv = str(da.attrs.get("flag_values", "")).strip().replace(",", " ") is_categorical = ("diff_cat" in name) or (fv in {"0 1 2", "0 1 2 "}) return not is_categorical def _resolve_lonlat_2d(self, da2: xr.DataArray, *, bcoords : bool = False, tcoords : bool = True, lon_coord_name : str | None = None, lat_coord_name : str | None = None, infer_if_missing : bool = True): """ Resolve 2D longitude/latitude arrays matching a 2D DataArray. This method returns `(lon2d, lat2d)` as NumPy arrays matching `da2.shape`. It supports multiple coordinate discovery strategies (in priority order): 1) Explicit coordinate names passed via `lon_coord_name` / `lat_coord_name`. - Supports either 2D lon/lat coords aligned to the data, or 1D lon/lat vectors that can be meshed into 2D grids. 2) Inference from common coordinate naming conventions present in `da2.coords` (if `infer_if_missing=True`), e.g.: ("lon","lat"), ("longitude","latitude"), ("TLON","TLAT"), ("ULON","ULAT") 3) Fallback to class grid datasets (`self.G_u` for B-grid or `self.G_t` for T-grid), loaded on demand via `self.load_bgrid(slice_hem=True)`, including optional subsetting via `da2["nj"]` and `da2["ni"]` coordinate indices. Parameters ---------- da2 : xr.DataArray 2D DataArray whose spatial coordinates are required. bcoords : bool, default False If True, use B-grid coordinates (typically `self.G_u["lon"]`, `self.G_u["lat"]`). tcoords : bool, default True If True, use T-grid coordinates (typically `self.G_t["lon"]`, `self.G_t["lat"]`). lon_coord_name, lat_coord_name : str or None, optional Explicit coordinate names to use from `da2`. Both must be provided together. infer_if_missing : bool, default True If True, attempt to infer coordinate names from common patterns in `da2.coords` when explicit names are not provided. Returns ------- (np.ndarray, np.ndarray) lon2d, lat2d arrays with shape equal to `da2.shape`. Raises ------ ValueError If `da2` is not 2D. 
ValueError If only one of `lon_coord_name` or `lat_coord_name` is provided. ValueError If both `bcoords` and `tcoords` are True, or both are False, when fallback grid coordinates are required. ValueError If lon/lat cannot be resolved or cannot be matched/subset to `da2.shape`. Notes ----- - If `da2` lacks lon/lat coordinates and the grid arrays do not match in shape, `da2` must provide integer index coordinates `nj` and `ni` so the grid can be subset consistently. """ # Ensure clean 2D view, and standard ordering if possible if da2.ndim != 2: raise ValueError(f"Expected 2D DataArray; got dims={da2.dims}, shape={da2.shape}") # ---- 1) Explicit coords ---- if (lon_coord_name is not None) or (lat_coord_name is not None): if lon_coord_name is None or lat_coord_name is None: raise ValueError("Provide both lon_coord_name and lat_coord_name.") lon_da = da2[lon_coord_name] lat_da = da2[lat_coord_name] if lon_da.ndim == 2 and lat_da.ndim == 2: # align dims if needed if lon_da.dims != da2.dims: da2 = da2.transpose(*lon_da.dims) lon_da = da2[lon_coord_name] lat_da = da2[lat_coord_name] return np.asarray(lon_da.data), np.asarray(lat_da.data) if lon_da.ndim == 1 and lat_da.ndim == 1: lon2d, lat2d = np.meshgrid(np.asarray(lon_da.data), np.asarray(lat_da.data), indexing="xy") return lon2d, lat2d raise ValueError(f"Mixed/unsupported explicit coord dims: lon={lon_da.ndim}, lat={lat_da.ndim}") # ---- 2) Infer coords from da2 ---- if infer_if_missing: # common patterns cand_pairs = [("lon", "lat"), ("longitude", "latitude"), ("TLON", "TLAT"), ("ULON", "ULAT")] for xnm, ynm in cand_pairs: if (xnm in da2.coords) and (ynm in da2.coords): lon_da = da2.coords[xnm] lat_da = da2.coords[ynm] if lon_da.ndim == 2 and lat_da.ndim == 2: if lon_da.dims != da2.dims: da2 = da2.transpose(*lon_da.dims) lon_da = da2.coords[xnm] lat_da = da2.coords[ynm] return np.asarray(lon_da.data), np.asarray(lat_da.data) if lon_da.ndim == 1 and lat_da.ndim == 1: lon2d, lat2d = np.meshgrid(np.asarray(lon_da.data), 
np.asarray(lat_da.data), indexing="xy") return lon2d, lat2d # ---- 3) B/T grid coords ---- if bcoords and tcoords: raise ValueError("Cannot set both bcoords and tcoords True.") if not (bcoords or tcoords): raise ValueError("Must set bcoords or tcoords, or provide lon/lat coords on da.") # Only load grid if we need it # (adjust to your class: load_bgrid may load both; otherwise add load_tgrid, etc.) self.load_cice_grid(slice_hem=True) if bcoords: lon2d_full = np.asarray(self.G_u["lon"].values) lat2d_full = np.asarray(self.G_u["lat"].values) else: lon2d_full = np.asarray(self.G_t["lon"].values) lat2d_full = np.asarray(self.G_t["lat"].values) if lon2d_full.shape == da2.shape: return lon2d_full, lat2d_full # subset lon/lat to match da2 using nj/ni index coords if ("nj" in da2.coords) and ("ni" in da2.coords): nj_idx = np.asarray(da2["nj"].values, dtype=int) ni_idx = np.asarray(da2["ni"].values, dtype=int) return lon2d_full[np.ix_(nj_idx, ni_idx)], lat2d_full[np.ix_(nj_idx, ni_idx)] raise ValueError(f"Grid lon/lat shape {lon2d_full.shape} does not match da shape {da2.shape}, " "and da has no nj/ni coords for subsetting.")
[docs] def pygmt_da_prep(self, da: xr.DataArray, *, bcoords : bool = False, tcoords : bool = True, lon_coord_name : str | None = None, lat_coord_name : str | None = None, region : tuple[float, float, float, float] | None = None, lon_wrap : str = "auto", # "auto" | "0-360" | "-180-180" extra_mask : xr.DataArray = None, mask_zero : bool | None = None, z_clip : tuple[float, float] | None = None, z_range_mask : tuple[float, float] | None = None, dtype : str = "float32", infer_coords : bool = True, return_mask : bool = True, return_flat_index : bool = True): """ Prepare a 2D DataArray for PyGMT plotting by returning 1D lon/lat/z vectors. This routine standardises a gridded field into a structure that PyGMT commonly expects for point plotting or scattered gridding. It also constructs masks for: - finite values, - optional masking of structural zeros, - optional clipping/masking by z-range, - optional geographic subsetting by `region` (dateline-safe), - optional additional user-provided mask. Parameters ---------- da : xr.DataArray Input data field. It may contain singleton dimensions; these will be squeezed. After squeeze, the data must be 2D. bcoords : bool, default False Use B-grid coordinates if lon/lat are not directly available on `da`. tcoords : bool, default True Use T-grid coordinates if lon/lat are not directly available on `da`. lon_coord_name, lat_coord_name : str or None, optional Explicit coordinate names to use for longitude and latitude in `da`. If provided, both must be provided. region : tuple[float, float, float, float], optional Geographic bounding box `(xmin, xmax, ymin, ymax)` applied as a mask. This masking is dateline-safe: if `xmin > xmax`, the region is interpreted as crossing the dateline and uses `(lon >= xmin) OR (lon <= xmax)`. lon_wrap : {"auto", "0-360", "-180-180"}, default "auto" Longitude convention to enforce on the resolved lon grid. 
If "auto" and `region` is provided, the method will choose a convention consistent with the region bounds; otherwise it leaves longitudes unchanged. extra_mask : array-like, xr.DataArray, callable, optional Additional mask to apply. If callable, it is invoked as `extra_mask(da2)` and must return a boolean mask with the same shape as the 2D data. If array-like, it is converted to a boolean array and combined with the existing mask. mask_zero : bool or None, optional Whether to mask zeros (treat as missing) before flattening. - If None, uses `_auto_mask_zero(da2)` to decide. - If True, masks values close to zero (|z| <= 1e-8). - If False, preserves zeros. z_clip : tuple[float, float], optional If provided, clip z values into `[z_clip[0], z_clip[1]]` prior to masking. This is a hard clip (values outside become boundary values). z_range_mask : tuple[float, float], optional If provided, apply an additional mask retaining only values within `[lo, hi]`. dtype : str, default "float32" Target dtype used when materialising arrays (especially helpful for memory). infer_coords : bool, default True If True, allow `_resolve_lonlat_2d()` to infer lon/lat from common coordinate names on `da` when explicit names are not provided. return_mask : bool, default True If True, include `mask2d` (boolean 2D mask) in the returned dict. return_flat_index : bool, default True If True, include `flat_idx` (indices into `z2d.ravel()` where mask is True) in the returned dict. Returns ------- dict Always includes: - "lon" : 1D np.ndarray of longitudes (masked + flattened) - "lat" : 1D np.ndarray of latitudes (masked + flattened) - "z" : 1D np.ndarray of values (masked + flattened) - "shape" : tuple of original 2D shape for sanity checks Optionally includes: - "mask2d" : 2D boolean mask (if `return_mask=True`) - "flat_idx" : 1D integer indices into `z2d.ravel()` (if `return_flat_index=True`) Raises ------ ValueError If `da` cannot be reduced to a 2D field (after squeeze). 
ValueError If coordinate resolution fails or `extra_mask` has a shape mismatch. Notes ----- - If `da` is Dask-backed, the data are computed only once, and coerced to `dtype`. - If `da` has dims ("nj","ni") in a different order, the method transposes to ("nj","ni") to keep indexing consistent with AFIM grid conventions. """ da2 = da.squeeze(drop=True) if da2.ndim != 2: raise ValueError(f"Expected 2D DataArray after squeeze; got dims={da2.dims}, shape={da2.shape}") # prefer canonical order if present if ("nj" in da2.dims) and ("ni" in da2.dims) and (da2.dims != ("nj", "ni")): da2 = da2.transpose("nj", "ni") lon2d, lat2d = self._resolve_lonlat_2d(da2, bcoords = bcoords, tcoords = tcoords, lon_coord_name = lon_coord_name, lat_coord_name = lat_coord_name, infer_if_missing = infer_coords) # ---- NEW: harmonise lon convention with region ---- if lon_wrap == "auto": if region is not None: xmin, xmax, _, _ = region # If region uses negative longitudes, force lon to -180..180 if (xmin < 0.0) or (xmax < 0.0): lon2d = self.normalise_longitudes(lon2d, to="-180-180") # If region is clearly 0..360-ish, force lon to 0..360 elif (xmin >= 0.0) and (xmax > 180.0): lon2d = self.normalise_longitudes(lon2d, to="0-360") # else: leave as-is else: lon2d = self.normalise_longitudes(lon2d, to=lon_wrap) # materialize Z z_data = da2.data if self._is_dask_array(z_data): z2d = da2.astype(dtype).compute().values else: z2d = np.asarray(z_data, dtype=dtype) if z_clip is not None: z2d = np.clip(z2d, z_clip[0], z_clip[1]) # base mask m = np.isfinite(z2d) if mask_zero is None: mask_zero = self._auto_mask_zero(da2) if mask_zero: m &= ~np.isclose(z2d, 0.0, atol=1e-8) if z_range_mask is not None: lo, hi = z_range_mask m &= (z2d >= lo) & (z2d <= hi) # ---- region mask (DATELINE-SAFE) ---- if region is not None: xmin, xmax, ymin, ymax = region lat_ok = (lat2d >= ymin) & (lat2d <= ymax) # handle regions that cross the dateline by allowing xmin > xmax if xmin <= xmax: lon_ok = (lon2d >= xmin) & (lon2d <= 
xmax) else: lon_ok = (lon2d >= xmin) | (lon2d <= xmax) m &= lon_ok & lat_ok if extra_mask is not None: em = extra_mask(da2) if callable(extra_mask) else extra_mask if isinstance(em, xr.DataArray): em = np.asarray(em.squeeze(drop=True).values, dtype=bool) else: em = np.asarray(em, dtype=bool) if em.shape != z2d.shape: raise ValueError(f"extra_mask shape mismatch: {em.shape} vs {z2d.shape}") m &= em lon1 = np.asarray(lon2d, dtype=dtype)[m].ravel() lat1 = np.asarray(lat2d, dtype=dtype)[m].ravel() z1 = np.asarray(z2d, dtype=dtype)[m].ravel() out = {"lon" : lon1, "lat" : lat1, "z" : z1, "shape" : z2d.shape} if return_mask: out["mask2d"] = m if return_flat_index: out["flat_idx"] = np.flatnonzero(m.ravel()) return out
[docs] def pygmt_base_layer(self, fig, region, projection, title = None, frame = "af", coast = True, land = "gray85", water = "white", shorelines = "0.35p,black"): """ Create a PyGMT figure with a basemap/coastline. """ frame_use = [frame] if isinstance(frame, str) else list(frame) if title is not None: frame_use = frame_use + [f"+t{title}"] fig.basemap(region=region, projection=projection, frame=frame_use) if coast: fig.coast(region=region, projection=projection, land=land, water=water, shorelines=shorelines) return fig
def _format_projection(self, proj_str, reg, fig_size, fig_width):
    """
    Turn a projection template into a concrete GMT projection string.

    Parameters
    ----------
    proj_str : str or None
        Projection template. May contain ``{MC}`` (meridian centre),
        ``{fig_size}`` and ``{fig_width}`` placeholders. If None, a default
        projection is derived from ``reg``:

        - entirely south of the equator -> south polar stereographic,
        - entirely north of the equator -> north polar stereographic,
        - otherwise Mercator.
    reg : sequence of float
        ``[lon_min, lon_max, lat_min, lat_max]``.
    fig_size, fig_width : float
        Values substituted for ``{fig_size}`` and ``{fig_width}``.

    Returns
    -------
    str
        The formatted projection string. Templates that cannot be formatted
        (unknown placeholder, malformed braces) and non-string inputs are
        returned unchanged.
    """
    if proj_str is None:
        if reg[3] <= 0:
            mc = self.get_meridian_center_from_geographic_extent(reg)
            return f"S{mc}/-90/{fig_size}c"
        if reg[2] >= 0:
            mc = self.get_meridian_center_from_geographic_extent(reg)
            return f"S{mc}/90/{fig_size}c"
        return f"M{fig_width}c"
    if not isinstance(proj_str, str):
        # preserve the historical "return unchanged" behaviour for non-strings
        return proj_str
    fmt = {"fig_size": fig_size, "fig_width": fig_width}
    # Only compute the meridian centre when the template actually uses it
    # ("{MC" also matches format specs such as "{MC:.1f}").
    if "{MC" in proj_str:
        fmt["MC"] = self.get_meridian_center_from_geographic_extent(reg)
    try:
        return proj_str.format(**fmt)
    except (KeyError, IndexError, ValueError):
        # unknown/positional placeholder or malformed braces: leave as-is
        return proj_str

def _resolve_plot_region_projection(self, *,
                                    regions_dict       : dict | None = None,
                                    region_name        : str | None = None,
                                    region             : tuple | list | None = None,
                                    projection         : str | None = None,
                                    fig_size           : float = 20.0,
                                    fig_width          : float | None = None,
                                    default_hemisphere : str = "south"):
    """
    Resolve a plotting region and projection string.

    Priority
    --------
    1. Explicit `region` + explicit/derived `projection`
    2. `regions_dict[region_name]`
    3. Search built-in region dictionaries on `self`
       (`specific_regions`, `Ant_8sectors`, `Ant_2sectors`, `hemispheres_dict`)
    4. Default hemisphere: `self.hemispheres_dict[default_hemisphere]` if
       available, else a hard-coded polar stereographic view
       (lat <= -55 for "south", lat >= 55 for "north").

    Supported placeholders in projection strings
    --------------------------------------------
    {MC}       : meridian centre inferred from plot_region
    {fig_size} : nominal figure size
    {fig_width}: nominal figure width (defaults to `fig_size`)

    Parameters
    ----------
    default_hemisphere : str
        "south" or "north" (case-insensitive, surrounding whitespace ignored).
        Any other value falls back to "south".

    Returns
    -------
    (tuple, str)
        ``(region, projection)``.

    Raises
    ------
    KeyError
        If `region_name` is not found, or the resolved spec lacks
        ``"plot_region"``.
    """
    fig_width = fig_size if fig_width is None else fig_width

    # 1. Explicit region supplied
    if region is not None:
        reg = tuple(region)
        return reg, self._format_projection(projection, reg, fig_size, fig_width)

    # 2/3. Search provided and built-in dictionaries
    search_dicts = []
    if isinstance(regions_dict, dict):
        search_dicts.append(regions_dict)
    for attr in ("specific_regions", "Ant_8sectors", "Ant_2sectors", "hemispheres_dict"):
        d = getattr(self, attr, None)
        if isinstance(d, dict):
            search_dicts.append(d)

    if region_name is not None:
        spec = next((d[region_name] for d in search_dicts if region_name in d), None)
        if spec is None:
            raise KeyError(f"Region name {region_name!r} not found in supplied/built-in region dictionaries.")
    else:
        # 4. Default hemisphere; normalise case so "North"/" north " are honoured
        default_name = str(default_hemisphere).strip().lower()
        if default_name not in ("south", "north"):
            default_name = "south"
        hemi_dict = getattr(self, "hemispheres_dict", {})
        if isinstance(hemi_dict, dict) and default_name in hemi_dict:
            spec = hemi_dict[default_name]
        elif default_name == "south":
            spec = {"plot_region": [-180, 180, -90, -55],
                    "projection": "S0.0/-90.0/50/{fig_size}c"}
        else:
            spec = {"plot_region": [-180, 180, 55, 90],
                    "projection": "S0.0/90.0/50/{fig_size}c"}

    if "plot_region" not in spec:
        raise KeyError("Resolved region spec does not contain 'plot_region'.")
    reg = tuple(spec["plot_region"])
    proj = self._format_projection(projection or spec.get("projection", None), reg, fig_size, fig_width)
    return reg, proj
def pygmt_2D_array(self, da: xr.DataArray, *,
                   fig = None, panel = None, panel_title = None, show_fig = False, P_png = None, return_cmap = False,
                   # region / projection handling
                   regions_dict       : dict | None = None,
                   region_name        : str | None = None,
                   region             : tuple | list | None = None,
                   projection         : str | None = None,
                   default_hemisphere : str = "south",
                   fig_size           : float = 20.0,
                   fig_width          : float | None = None,
                   # coordinates
                   lon_coord_name     : str | None = None,
                   lat_coord_name     : str | None = None,
                   bcoords            : bool = False,
                   tcoords            : bool = True,
                   infer_coords       : bool = True,
                   lon_wrap           : str = "auto",
                   # data prep
                   mask_zero          : bool | None = None,
                   extra_mask         = None,
                   z_clip             : tuple[float, float] | None = None,
                   z_range_mask       : tuple[float, float] | None = None,
                   dtype              : str = "float32",
                   # colour / CPT
                   cmap               : str | None = None,
                   series             = "auto",   # "auto" or [min,max] or [min,max,inc] or None
                   reverse_cmap       : bool = False,
                   background         : bool = True,
                   P_cpt              : str | Path | None = None,
                   add_colorbar       : bool = True,
                   cbar_pos           : str = "JBC+w15c/0.45c+mc+h",
                   cbar_frame         = ("af",),
                   # plot appearance
                   basemap_frame      = ("af",),
                   land_color         : str = "#666666",
                   water_color        : str = "#BABCDE",
                   shoreline_pen      : str = "1/0.25p",
                   plot_bathymetry    : bool = False,
                   bath_cmap          : str = "geo",
                   point_marker       : str = "c",
                   point_size         : str = "0.05",
                   point_unit         : str = "c",
                   point_pen          : str = "none",
                   transparency       : float | None = None,
                   # extras
                   plot_GI            : bool = False,
                   GI_color           : str = "#E349D0",
                   GI_marker          : str = "s",
                   GI_size            : str = "0.01",
                   GI_pt_unit         : str = "c"):
    """
    General 2D PyGMT map for curvilinear sea-ice style grids.

    The field is flattened to scattered lon/lat/z points with
    ``self.pygmt_da_prep`` (masked to the resolved region) and drawn as
    CPT-coloured symbols over a coastline or bathymetry base layer.

    Parameters
    ----------
    da : xr.DataArray
        2D field to plot. After squeeze, it must be 2D.
    fig, panel : optional
        Existing figure / subplot panel to draw into. When `fig` is None a
        new figure is created, and only then are `P_png` / `show_fig` honoured.
    regions_dict : dict, optional
        Dictionary of the form:
        { "REGION_NAME": { "plot_region": [lon_min, lon_max, lat_min, lat_max],
                           "projection": "PyGMT projection string" } }
    region_name : str, optional
        Key into `regions_dict` or into built-in dictionaries on `self`
        (`hemispheres_dict`, `Ant_8sectors`, `Ant_2sectors`, `specific_regions`).
    region, projection : optional
        Explicit overrides.
    series : "auto", sequence, or None
        If "auto", infer the CPT bounds from the plotted data.
        If sequence, use [zmin, zmax] or [zmin, zmax, dz].
        If None, use `cmap` as-is without rebuilding a CPT.
    basemap_frame, cbar_frame : str or sequence of str
        Frame specification(s). A single string is treated as one frame
        argument (it is not split into characters).
    point_pen : str or None
        Outline pen for data symbols. ``"none"`` (default) or None draws no
        outline.
    transparency : float, optional
        Symbol transparency in percent (0-100). None leaves symbols opaque.

    Returns
    -------
    pygmt.Figure, or (pygmt.Figure, str) when `return_cmap` is True
        The figure and, optionally, the CPT (name or file path) used.

    Raises
    ------
    TypeError
        If `da` is not an ``xr.DataArray``.
    ValueError
        If `da` is not 2D after squeeze, or no points survive masking.
    """
    import pygmt

    if not isinstance(da, xr.DataArray):
        raise TypeError("da must be an xarray.DataArray.")
    da2 = da.squeeze(drop=True)
    if da2.ndim != 2:
        raise ValueError(f"Expected a 2D DataArray after squeeze; got dims={da2.dims}, shape={da2.shape}")

    # ---- resolve region / projection ----
    region_use, projection_use = self._resolve_plot_region_projection(regions_dict       = regions_dict,
                                                                      region_name        = region_name,
                                                                      region             = region,
                                                                      projection         = projection,
                                                                      fig_size           = fig_size,
                                                                      fig_width          = fig_width,
                                                                      default_hemisphere = default_hemisphere)

    # ---- optional ancillary data ----
    if plot_GI:
        self.load_cice_grid(slice_hem=True)
    SO_BATH = self.load_IBCSO_bath() if plot_bathymetry else None

    # ---- prepare lon / lat / z for plotting ----
    prep = self.pygmt_da_prep(da2,
                              bcoords           = bcoords,
                              tcoords           = tcoords,
                              lon_coord_name    = lon_coord_name,
                              lat_coord_name    = lat_coord_name,
                              region            = region_use,
                              lon_wrap          = lon_wrap,
                              extra_mask        = extra_mask,
                              mask_zero         = mask_zero,
                              z_clip            = z_clip,
                              z_range_mask      = z_range_mask,
                              dtype             = dtype,
                              infer_coords      = infer_coords,
                              return_mask       = True,
                              return_flat_index = True)
    x, y, z = prep["lon"], prep["lat"], prep["z"]
    if z.size == 0:
        raise ValueError("No plottable points remain after masking / regional subsetting.")

    # ---- build CPT if requested ----
    cmap_use = cmap or getattr(self, "pygmt_dict", {}).get("default_cmap", "viridis")
    # `series` may be array-like; only compare against "auto" when it is a string
    if isinstance(series, str) and series == "auto":
        zmin = float(np.nanmin(z))
        zmax = float(np.nanmax(z))
        if not np.isfinite(zmin) or not np.isfinite(zmax):
            raise ValueError("Plotted data contain no finite values.")
        if np.isclose(zmin, zmax):
            # constant field: widen the range so makecpt gets a non-degenerate series
            pad = 1.0 if np.isclose(zmin, 0.0) else abs(zmin) * 0.01
            zmin -= pad
            zmax += pad
        series_use = [zmin, zmax]
    else:
        series_use = series

    if series_use is not None:
        if P_cpt is not None:
            P_cpt = Path(P_cpt)
        else:
            name_stub = str(da2.name or "map").replace(" ", "_")
            P_cpt = Path(self.D_graph) / "CPTs" / f"{name_stub}_auto.cpt"
        P_cpt.parent.mkdir(parents=True, exist_ok=True)
        pygmt.makecpt(cmap       = cmap_use,
                      series     = series_use,
                      reverse    = reverse_cmap,
                      background = background,
                      output     = str(P_cpt))
        cmap_use = str(P_cpt)

    # ---- figure setup ----
    created_here = fig is None
    if created_here:
        fig = pygmt.Figure()

    # a single frame string must stay one argument (list("af") would give ['a', 'f'])
    frame = [basemap_frame] if isinstance(basemap_frame, str) else list(basemap_frame)
    if panel_title is not None:
        frame.append(f"+t{panel_title}")
    cbar_frame_use = [cbar_frame] if isinstance(cbar_frame, str) else list(cbar_frame)

    panel_ctx = fig.set_panel(panel=panel) if panel is not None else nullcontext()

    self.logger.info(f"pygmt_2D_array: name={da2.name}, shape={da2.shape}, "
                     f"npts={z.size}, zmin={float(np.nanmin(z)):.4g}, zmax={float(np.nanmax(z)):.4g}, "
                     f"region={region_use}, projection={projection_use}")

    # symbols are coloured by the CPT through fill=z + cmap
    plot_kwargs = {"x": x, "y": y, "fill": z,
                   "style": f"{point_marker}{point_size}{point_unit}",
                   "cmap": cmap_use}
    if point_pen not in (None, "none"):
        plot_kwargs["pen"] = point_pen
    if transparency is not None:
        plot_kwargs["transparency"] = transparency

    with pygmt.config(FONT_TITLE           = "18p,Bookman-Demi",
                      FONT_ANNOT_PRIMARY   = "16p,NewCenturySchlbk-Roman",
                      FONT_ANNOT_SECONDARY = "16p,NewCenturySchlbk-Bold",
                      FONT_LABEL           = "16p,NewCenturySchlbk-Bold",
                      COLOR_FOREGROUND     = "black"):
        with panel_ctx:
            fig.basemap(region=region_use, projection=projection_use, frame=frame)
            if plot_bathymetry:
                fig.grdimage(grid=SO_BATH, cmap=bath_cmap)
            else:
                fig.coast(region     = region_use,
                          projection = projection_use,
                          land       = land_color,
                          water      = water_color)
            fig.plot(**plot_kwargs)
            if plot_GI:
                fig.plot(x     = self.G_GI["lon"].values.ravel(),
                         y     = self.G_GI["lat"].values.ravel(),
                         fill  = GI_color,
                         style = f"{GI_marker}{GI_size}{GI_pt_unit}",
                         pen   = "none")
            # shorelines drawn last so they sit on top of the data symbols
            fig.coast(region     = region_use,
                      projection = projection_use,
                      shorelines = shoreline_pen)
            if add_colorbar:
                fig.colorbar(position = cbar_pos,
                             frame    = cbar_frame_use,
                             cmap     = cmap_use)

    if created_here and P_png is not None:
        P_png = Path(P_png)
        P_png.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(P_png)
    if created_here and show_fig:
        fig.show()
    if return_cmap:
        return fig, cmap_use
    return fig
def pygmt_map_plot_one_var(self, da, var_name,
                           aux_ds          = None,
                           sim_name        = None,
                           plot_regions    = None,
                           regional_dict   = None,
                           hemisphere      = "south",
                           time_stamp      = None,
                           tit_str         = None,
                           plot_GI         = False,
                           cmap            = None,
                           series          = None,
                           reverse         = None,
                           cbar_label      = None,
                           cbar_units      = None,
                           extend_cbar     = False,
                           cbar_position   = None,
                           lon_name        = None,
                           lat_name        = None,
                           fig_size        = None,
                           var_sq_size     = 0.2,
                           GI_sq_size      = 0.1,
                           GI_fill_color   = "red",
                           plot_iceshelves = True,
                           plot_bathymetry = True,
                           add_stat_annot  = False,
                           land_color      = None,
                           water_color     = None,
                           P_png           = None,
                           var_out         = None,
                           overwrite_fig   = None,
                           show_fig        = None):
    """
    Generate a PyGMT figure showing a variable (e.g., FIA, FIP, differences) as a
    spatial map over Antarctic regions or a hemispheric view, optionally including
    grounded icebergs, ice shelf outlines, and bathymetry.

    Parameters
    ----------
    da : xarray.DataArray or xarray.Dataset
        2D (or squeezable) field to plot. If a Dataset is given, ``da[var_name]``
        is plotted and the Dataset also serves as the auxiliary dataset unless
        `aux_ds` is supplied.
    var_name : str
        Variable name; used to look up colour-map settings and labels in
        ``self.plot_var_dict``.
    aux_ds : xarray.Dataset, optional
        Auxiliary fields aligned with `da` (``"aice"`` is passed to the
        regional statistics annotation).
    sim_name : str, optional
        Simulation name for file naming. Defaults to `self.sim_name`.
    plot_regions : int or None, optional
        8 -> Antarctic 8-sector view; 2 -> East/West sectors; any other
        non-None value requires `regional_dict`; None -> hemispheric plot.
    regional_dict : dict, optional
        Custom dictionary of plotting regions (used when `plot_regions` is not
        8, 2 or None). ``{MC}``, ``{fig_size}`` and ``{fig_width}``
        placeholders in its projection strings are filled in.
    hemisphere : str, default="south"
        Key into ``self.hemispheres_dict`` for the hemispheric plot.
    time_stamp : str, optional
        Timestamp for file naming. Defaults to `self.dt0_str`.
    tit_str : str, optional
        Figure title.
    plot_GI : bool, default=False
        Overlay grounded iceberg locations from `load_GI_lon_lats()`.
    cmap, series, reverse : optional
        CPT settings; default to ``self.plot_var_dict[var_name]``.
    cbar_label, cbar_units : str, optional
        Colourbar label/units; default to ``self.plot_var_dict[var_name]``.
    extend_cbar : bool, default=False
        Passed to `self.create_cbar_frame`.
    cbar_position : str, optional
        PyGMT colourbar position. Defaults to ``self.pygmt_dict['cbar_pos']``.
    lon_name, lat_name : str, optional
        Coordinate names in `da`. Default to ``self.pygmt_dict`` entries
        (falling back to "TLON"/"TLAT").
    fig_size : float, optional
        Figure size substituted into projection templates.
    var_sq_size : float, default=0.2
        Square marker size (cm) for the main variable.
    GI_sq_size : float, default=0.1
        Circle marker size (cm) for grounded icebergs.
    GI_fill_color : str, default="red"
        Grounded iceberg marker colour.
    plot_iceshelves : bool, default=True
        Overlay Antarctic ice shelf polygons.
    plot_bathymetry : bool, default=True
        Plot IBCSO bathymetry with `grdimage` instead of filled coastlines.
    add_stat_annot : bool, default=False
        Annotate with `self.generate_regional_annotation_stats`.
    land_color, water_color : str, optional
        Default to ``self.pygmt_dict`` entries.
    P_png : str or pathlib.Path, optional
        Output path. If None and `self.save_fig` is True, a separate path is
        generated for EACH region plotted.
    var_out : str, optional
        Variable name used in file naming. Defaults to `var_name`.
    overwrite_fig : bool, optional
        Overwrite existing figures. Defaults to `self.ow_fig`.
    show_fig : bool, optional
        Display each figure. Defaults to `self.show_fig`.

    Returns
    -------
    None
        Figures are saved and/or shown as a side effect. The method returns
        early (with a log message) when the prepared data are missing/empty
        or `plot_regions` cannot be resolved to a region dictionary.
    """
    import pygmt

    sim_name    = sim_name      if sim_name      is not None else self.sim_name
    show_fig    = show_fig      if show_fig      is not None else self.show_fig
    ow_fig      = overwrite_fig if overwrite_fig is not None else self.ow_fig
    time_stamp  = time_stamp    if time_stamp    is not None else self.dt0_str
    lon_name    = lon_name      if lon_name      is not None else self.pygmt_dict.get("lon_coord_name", "TLON")
    lat_name    = lat_name      if lat_name      is not None else self.pygmt_dict.get("lat_coord_name", "TLAT")
    cmap        = cmap          if cmap          is not None else self.plot_var_dict[var_name]['cmap']
    series      = series        if series        is not None else self.plot_var_dict[var_name]['series']
    reverse     = reverse       if reverse       is not None else self.plot_var_dict[var_name]['reverse']
    cbar_lab    = cbar_label    if cbar_label    is not None else self.plot_var_dict[var_name]['name']
    cbar_units  = cbar_units    if cbar_units    is not None else self.plot_var_dict[var_name]['units']
    fig_size    = fig_size      if fig_size      is not None else self.pygmt_dict['fig_size']
    cbar_pos    = cbar_position if cbar_position is not None else self.pygmt_dict['cbar_pos'].format(width=fig_size*0.8, height=0.75)
    land_color  = land_color    if land_color    is not None else self.pygmt_dict['land_color']
    water_color = water_color   if water_color   is not None else self.pygmt_dict['water_color']
    if var_out is None:
        var_out = var_name

    # Accept a Dataset directly; it doubles as the auxiliary dataset unless
    # aux_ds was given explicitly.
    if isinstance(da, xr.Dataset):
        ds = aux_ds if aux_ds is not None else da
        da = da[var_name]
    else:
        ds = aux_ds
    aice_da = ds["aice"] if (ds is not None and "aice" in ds) else None

    # ---- choose the region dictionary (before any expensive loading) ----
    hem_plot = False
    if plot_regions is None:
        self.logger.info("method will plot hemispheric data")
        hem_plot = True
        reg_dict = self.hemispheres_dict
    elif plot_regions == 8:
        self.logger.info("method will plot eight Antarctic sectors regional dictionary")
        reg_dict = self.Ant_8sectors
    elif plot_regions == 2:
        self.logger.info("method will plot two Antarctic sectors regional dictionary")
        reg_dict = self.Ant_2sectors
    elif regional_dict is not None:
        self.logger.info("method will plot regional dictionary passed to this method")
        reg_dict = regional_dict
    else:
        # previously fell through with `reg_dict` unbound (NameError later)
        self.logger.warning("plot_regions argument not valid")
        return

    ANT_IS = None
    if plot_iceshelves:
        if not hasattr(self, "_ANT_IS"):
            self._ANT_IS = self.load_ice_shelves()
        ANT_IS = self._ANT_IS
    SO_BATH = None
    if plot_bathymetry:
        if not hasattr(self, "_SO_BATH"):
            self._SO_BATH = self.load_IBCSO_bath()
        SO_BATH = self._SO_BATH

    plot_data_dict = self.pygmt_da_prep(da,
                                        bcoords        = False,
                                        tcoords        = False,
                                        lon_coord_name = lon_name,
                                        lat_coord_name = lat_name)
    if not isinstance(plot_data_dict, dict):
        self.logger.warning("plot_data_dict is not a dictionary — skipping plot.")
        return
    for k in ('lon', 'lat', 'z'):
        if k not in plot_data_dict:
            self.logger.warning(f"Missing key '{k}' in plot_data_dict — skipping plot.")
            return
        v = plot_data_dict[k]
        if v is None:
            self.logger.warning(f"plot_data_dict['{k}'] is None — skipping plot.")
            return
        if hasattr(v, "size") and v.size == 0:
            self.logger.warning(f"plot_data_dict['{k}'] is empty — skipping plot.")
            return

    cbar_frame = self.create_cbar_frame(series, cbar_lab, units=cbar_units, extend_cbar=extend_cbar)
    plot_GI_dict = self.load_GI_lon_lats() if plot_GI else None
    basemap_frame = ["af", f"+t{tit_str}"] if tit_str is not None else ["af"]

    for reg_name, reg_vals in reg_dict.items():
        if hem_plot and reg_name != hemisphere:
            continue

        # Per-region output path: an auto-generated path must not be reused
        # for the next region (that used to overwrite/skip earlier sectors).
        if P_png is not None:
            P_out = Path(P_png)
        elif self.save_fig:
            P_out = Path(self.D_graph, sim_name, reg_name, var_out, f"{time_stamp}_{sim_name}_{reg_name}_{var_out}.png")
        else:
            P_out = None

        region     = reg_vals['plot_region']
        projection = reg_vals['projection']
        if hem_plot:
            projection = projection.format(fig_size=fig_size)
        elif reg_name in self.Ant_8sectors:
            MC = self.get_meridian_center_from_geographic_extent(region)
            projection = projection.format(MC=MC, fig_size=fig_size)
        elif reg_name in self.Ant_2sectors:
            projection = projection.format(fig_width=fig_size)
        elif "{" in projection:
            # custom regional dictionaries: fill whichever placeholders are present
            fmt = {"fig_size": fig_size, "fig_width": fig_size}
            if "{MC" in projection:
                fmt["MC"] = self.get_meridian_center_from_geographic_extent(region)
            projection = projection.format(**fmt)

        fig = pygmt.Figure()
        with pygmt.config(FONT_TITLE         = "16p,Courier-Bold",
                          FONT_ANNOT_PRIMARY = "14p,Helvetica",
                          COLOR_FOREGROUND   = 'black'):
            fig.basemap(region=region, projection=projection, frame=basemap_frame)
            if plot_bathymetry:
                fig.grdimage(grid=SO_BATH, cmap='geo')
            else:
                fig.coast(region=region, projection=projection, shorelines="1/0.5p,gray30", land=land_color, water=water_color)
            pygmt.makecpt(cmap=cmap, reverse=reverse, series=series)
            fig.plot(x=plot_data_dict['lon'], y=plot_data_dict['lat'], fill=plot_data_dict['z'], style=f"s{var_sq_size}c", cmap=True)
            if plot_bathymetry:
                fig.coast(region=region, projection=projection, shorelines="1/0.5p,gray30")
            if plot_GI:
                fig.plot(x=plot_GI_dict['lon'], y=plot_GI_dict['lat'], fill=GI_fill_color, style=f"c{GI_sq_size}c")
            if plot_iceshelves:
                fig.plot(data=ANT_IS, fill="lightgray")
            if add_stat_annot:
                annot_text = self.generate_regional_annotation_stats(da, region, lon_name, lat_name, var_name,
                                                                     area_unit="1e6km2",
                                                                     aice_da=aice_da)
                for j, line in enumerate(annot_text):
                    try:
                        fig.text(position="TR", text=line, font="12p,Helvetica-Bold,black",
                                 justify="LM", no_clip=True, offset=f"-1/{-0.5*j}")
                    except pygmt.exceptions.GMTCLibError as e:
                        self.logger.warning(f"Error in plotting anotation text {e} -- skipping annotation")
            fig.colorbar(position=cbar_pos, frame=cbar_frame)

        if P_out is not None:
            if not P_out.exists():
                P_out.parent.mkdir(parents=True, exist_ok=True)
                fig.savefig(P_out)
                self.logger.info(f"Saved figure to {P_out}")
            elif ow_fig:
                fig.savefig(P_out)
                self.logger.info(f"Saved figure to {P_out}")
            else:
                self.logger.info(f"{P_out} already exists and not overwriting")
        if show_fig:
            fig.show()
def pygmt_fastice_panel(self,
                        fast_ice_variable : str   = "FIA",   # "FIA"|"fia" or "FIT"|"fit" or "FIS|fis"
                        sim_name          : str   = None,
                        roll_days         : int   = 0,
                        # Generic (can be overridden)
                        fig_width         : str   = None,
                        fig_height        : str   = None,
                        ylim              : tuple = None,
                        frame_bndy        : str   = None,
                        yaxis_pri         : str   = None,
                        xaxis_pri         : str   = "a1Of15Dg",
                        leg_pos           : str   = None,
                        leg_box           : str   = "+gwhite+p.5p",
                        spat_var_style    : str   = None,
                        GI_plot_style     : str   = "c0.05c",
                        GI_fill_color     : str   = "#BA561A",
                        plot_GI           : bool  = False,
                        min_max_trans_val : int   = 80,
                        yshift_top        : str   = None,
                        yshift_bot        : str   = None,
                        bottom_frame_bndy : str   = "WSne",
                        bottom_yaxis      : str   = None,
                        bottom_xaxis      : str   = None,
                        land_clr          : str   = '#D1DDE0',
                        water_clr         : str   = "#EDF2F5",
                        coast_pen         : str   = "1/0.5p,black",
                        cbar_pos          : str   = None,
                        lon_coord_name    : str   = None,
                        lat_coord_name    : str   = None,
                        cmap              : str   = None,
                        series            : list  = None,
                        cbar_frame        : str   = None,
                        ANT_IS_pen        : str   = "0.2p,black",
                        ANT_IS_color      : str   = "#C1CED6",
                        font_annot_pri    : str   = "24p,Times-Roman",
                        font_annot_sec    : str   = "16p,Times-Roman",
                        font_lab          : str   = "22p,Times-Bold",
                        line_pen          : str   = "2p",
                        grid_pen_pri      : str   = ".5p",
                        grid_pen_sec      : str   = ".25p",
                        fmt_geo_map       : str   = "D:mm",
                        P_png             : str   = None,
                        save_fig          : bool  = None,
                        overwrite_fig     : bool  = None,
                        show_fig          : bool  = None):
    """
    Unified PyGMT fast-ice panel plotter.

    The top panel shows day-of-year climatologies (min/max envelope plus
    mean line) of the time-series variable for each classification method
    (and AF2020 observations for FIA). The bottom panels show the spatial
    variable from the "binary-days" method over the two Antarctic sectors.

    Parameters
    ----------
    fast_ice_variable : str
        One of FIA, FIT, FIS, FIMAR, FIMVR, FITAR, FITVR (case-insensitive).
        Per-variable defaults come from ``self.pygmt_FI_panel[var]``.
    lon_coord_name, lat_coord_name : str, optional
        Forwarded to `self.pygmt_da_prep` for the spatial panel when given.
    line_pen : str, default "2p"
        Mean-line pen for series without an entry in ``self.pygmt_FIA_dict``.
    P_png : str or pathlib.Path, optional
        Output path; auto-generated under ``self.D_graph`` when None.
    save_fig, overwrite_fig, show_fig : bool, optional
        Default to ``self.save_fig``, ``self.ow_fig``, ``self.show_fig``.

    The remaining arguments override styling defaults. `roll_days`,
    `bottom_frame_bndy` and `ANT_IS_pen` are currently accepted but unused.

    Raises
    ------
    ValueError
        If `fast_ice_variable` is not one of the supported names.
    """
    import pygmt

    var = fast_ice_variable.lower()
    if var not in ("fia", "fit", "fis", "fimar", "fimvr", "fitar", "fitvr"):
        raise ValueError(f"`fast_ice_variable` must be one of ['FIA','FIT','FIS','FIMAR','FIMVR','FITAR','FITVR']; got {fast_ice_variable}")

    # ---- per-variable defaults ----
    cfg = self.pygmt_FI_panel[var]

    # ---- resolve user overrides or fall back to defaults ----
    sim_name       = sim_name       if sim_name       is not None else self.sim_name
    show_fig       = show_fig       if show_fig       is not None else self.show_fig
    save_fig       = save_fig       if save_fig       is not None else self.save_fig
    ow_fig         = overwrite_fig  if overwrite_fig  is not None else self.ow_fig
    frame_bndy     = frame_bndy     if frame_bndy     is not None else "WS"
    fig_width      = fig_width      if fig_width      is not None else "30c"
    fig_height     = fig_height     if fig_height     is not None else "25c"
    spat_var_style = spat_var_style if spat_var_style is not None else "s0.2c"
    yshift_top     = yshift_top     if yshift_top     is not None else "-6.25c"
    yshift_bot     = yshift_bot     if yshift_bot     is not None else "-10.5c"
    bottom_yaxis   = bottom_yaxis   if bottom_yaxis   is not None else "a5f1g"
    bottom_xaxis   = bottom_xaxis   if bottom_xaxis   is not None else "a30f10g"
    cbar_pos       = cbar_pos       if cbar_pos       is not None else "JBC+w25c/1c+mc+h"
    ylim           = ylim           if ylim           is not None else cfg["panel_ylim"]
    yaxis_pri      = yaxis_pri      if yaxis_pri      is not None else cfg["yaxis_pri"]
    leg_pos        = leg_pos        if leg_pos        is not None else cfg["leg_pos"]
    cbar_frame     = cbar_frame     if cbar_frame     is not None else cfg["cbar_frame"]
    cmap           = cmap           if cmap           is not None else cfg["cmap"]
    series         = series         if series         is not None else cfg["series"]

    # ---- loads ----
    ANT_IS = self.load_ice_shelves()
    ice_methods = ["rolling-mean", "binary-days"]
    ts_dict = {}
    if var == 'fia':
        # load eagerly so the file handle can be closed immediately
        with xr.open_dataset(self.AF_FI_dict['P_AF2020_FIA']) as ds_af:
            ts_dict["AF2020"] = ds_af["AF2020"].load()
    for imeth in ice_methods:
        # load each method's metrics once (previously loaded twice)
        mets = self.load_computed_metrics(class_method=imeth)
        if imeth == "binary-days":
            da_spat = mets[cfg["bottom_name"]]
        ts_dict[imeth] = mets[cfg["top_name"]]
    tmin, tmax = self.extract_min_max_dates(ts_dict)

    prep_kwargs = {}
    if lon_coord_name is not None:
        prep_kwargs["lon_coord_name"] = lon_coord_name
    if lat_coord_name is not None:
        prep_kwargs["lat_coord_name"] = lat_coord_name
    df_spat = self.pygmt_da_prep(da_spat, **prep_kwargs)
    plot_GI_dict = self.load_GI_lon_lats() if plot_GI else None

    # ---- plot ----
    plot_region = [f"{self.leap_year}-01-01", f"{self.leap_year}-12-31", ylim[0], ylim[1]]
    plot_projection = f"X{fig_width}/{fig_height}"
    frame = [frame_bndy, f"px{xaxis_pri}", f"py{yaxis_pri}"]
    series_styles = getattr(self, "pygmt_FIA_dict", None) or {}

    fig = pygmt.Figure()
    with pygmt.config(FONT_ANNOT_PRIMARY      = font_annot_pri,
                      FONT_ANNOT_SECONDARY    = font_annot_sec,
                      FONT_LABEL              = font_lab,
                      MAP_GRID_PEN_PRIMARY    = grid_pen_pri,
                      MAP_GRID_PEN_SECONDARY  = grid_pen_sec,
                      FORMAT_GEO_MAP          = fmt_geo_map,
                      FORMAT_DATE_MAP         = "o",
                      FORMAT_TIME_PRIMARY_MAP = "Abbreviated"):
        # ---- time-series top panel ----
        fig.basemap(region=plot_region, projection=plot_projection, frame=frame)
        for k, da in ts_dict.items():
            if k in series_styles:
                leg_lab    = series_styles[k]["label"]
                pen_k      = series_styles[k]["line_pen"]
                line_color = series_styles[k]["line_color"]
            else:
                # do not shadow the `line_pen` argument; use it as the fallback
                leg_lab, pen_k, line_color = k, line_pen, "black"
            clim = self.compute_doy_climatology(da)
            # min/max envelope as a closed, semi-transparent polygon
            fig.plot(x            = np.concatenate([clim["min"].index, clim["max"].index[::-1]]),
                     y            = np.concatenate([clim["min"].values, clim["max"].values[::-1]]),
                     fill         = f"{line_color}@{min_max_trans_val}",
                     close        = True,
                     transparency = min_max_trans_val)
            fig.plot(x     = clim["mean"].index,
                     y     = clim["mean"].values,
                     pen   = f"{pen_k},{line_color}",
                     label = leg_lab)
        fig.legend(position=leg_pos, box=leg_box)

        # ---- bottom panel(s) ----
        fig.shift_origin(yshift=yshift_top)
        pygmt.makecpt(cmap=cmap, series=series)
        for i, reg_vals in enumerate(self.Ant_2sectors.values()):
            b_region = reg_vals["plot_region"]
            b_projection = reg_vals["projection"].format(fig_width=fig_width)
            if i > 0:
                fig.shift_origin(yshift=yshift_bot)
            fig.basemap(region=b_region, projection=b_projection,
                        frame=[f"x{bottom_xaxis}", f"y{bottom_yaxis}"])
            fig.coast(water=water_clr)
            fig.plot(x=df_spat["lon"], y=df_spat["lat"], fill=df_spat["z"], style=spat_var_style, cmap=True)
            if plot_GI:
                fig.plot(x=plot_GI_dict["lon"], y=plot_GI_dict["lat"], fill=GI_fill_color, style=GI_plot_style)
            fig.coast(land=land_clr, shorelines=coast_pen)
            fig.plot(data=ANT_IS, fill=ANT_IS_color)
        fig.colorbar(position=cbar_pos, frame=cbar_frame)

    # ---- save / show ----
    if save_fig:
        F_pre = f"{cfg['top_name']}_{cfg['bottom_name']}_{sim_name}"
        F_suf = f"ispd_thresh{self.ispd_thresh_str}_{tmin.strftime('%Y')}-{tmax.strftime('%Y')}"
        F_png = f"{F_pre}_{F_suf}.png"
        # accept str paths as well as Path objects
        P_out = Path(P_png) if P_png is not None else Path(self.D_graph, sim_name, F_png)
        P_out.parent.mkdir(parents=True, exist_ok=True)
        if not P_out.exists() or ow_fig:
            fig.savefig(P_out, dpi=300)
            self.logger.info(f"saved figure to {P_out}")
        else:
            self.logger.info(f"{P_out} already exists and not overwriting")
    if show_fig:
        fig.show()
def add_fip_shared_colorbar(self, fig, cmap_use, *,
                            cbar_bframe  = 'n',
                            fig_width    = "24c",
                            cbar_width   = "12c",
                            cbar_yoffset = "0.8c",
                            label_yshift = "-0.45c",
                            left_label   = "model-dominant",
                            mid_label    = "agreement",
                            right_label  = "obs-dominant",
                            cbar_frame   = ("xa1f0.25",)):
    """
    Add a centred, horizontal shared colourbar with three category labels underneath.

    The labels are written on an invisible full-width strip (region [0, 1])
    placed below the bar. They sit under the bar's left end, centre and
    right end. Those x positions are derived from the ratio
    ``cbar_width / fig_width`` when both share the same unit suffix
    (c, i or p). Otherwise the default layout (bar spanning x=0.25-0.75)
    is assumed.

    Parameters
    ----------
    fig : pygmt.Figure
        Figure to draw on. Its origin is shifted by `label_yshift`.
    cmap_use : str
        CPT for the colourbar.
    cbar_bframe : str
        Frame for the invisible label strip.
    fig_width, cbar_width : str
        Widths with GMT units, e.g. "24c" and "12c".
    cbar_yoffset : str
        Vertical offset of the bar below the current panel.
    label_yshift : str
        Origin shift applied before drawing the label strip.
    left_label, mid_label, right_label : str
        Texts drawn under the bar ends and centre.
    cbar_frame : str or sequence of str
        Colourbar frame specification.
    """
    def _split_length(length):
        # "12c" -> (12.0, "c"); "12" -> (12.0, ""); raises ValueError if not numeric
        text = str(length).strip()
        unit = text[-1] if text and text[-1] in "cip" else ""
        return float(text[:-1] if unit else text), unit

    frame_use = [cbar_frame] if isinstance(cbar_frame, str) else list(cbar_frame)
    # centred shared colourbar
    fig.colorbar(cmap=cmap_use, position=f"JBC+w{cbar_width}/0.45c+h+o0c/{cbar_yoffset}", frame=frame_use)

    # full-width invisible strip under the bar
    fig.shift_origin(yshift=label_yshift)
    fig.basemap(region=[0, 1, 0, 1], projection=f"X{fig_width}/0.8c", frame=cbar_bframe)

    # The bar is centred, so in strip coordinates it spans 0.5 +/- frac/2.
    try:
        bar_len, bar_unit = _split_length(cbar_width)
        fig_len, fig_unit = _split_length(fig_width)
        frac = bar_len / fig_len if (bar_unit == fig_unit and fig_len > 0) else 0.5
    except ValueError:
        frac = 0.5
    half = 0.5 * frac

    fig.text(x        = [0.5 - half, 0.5, 0.5 + half],
             y        = [0.18, 0.18, 0.18],
             text     = [left_label, mid_label, right_label],
             justify  = "TC",
             no_clip  = True,
             font     = "12p,NewCenturySchlbk-Bold")
def _format_subplot_projection(self, proj_template, MC): proj = proj_template.format(MC=MC, fig_size="?") # important: strip any hard-coded unit attached to the placeholder proj = proj.replace("?i", "?").replace("?c", "?") return proj
def pygmt_FIP_figure(self, plot_data=None, *,
                     obs=None,
                     mod=None,
                     fig=None,
                     panel=None,
                     add_colorbar=True,
                     return_cmap=False,
                     panel_title=None,
                     show_fig=False,
                     P_png=None,
                     region=(0, 360, -90, -62),
                     projection="S0.0/-90.0/50/?",
                     basemap_frame=("af",),
                     land_color="#666666",
                     water_color="#BABCDE",
                     shoreline_pen="1/0.25p",
                     G_pt_marker="c",
                     G_pt_size="0.05",
                     G_pt_unit="c",
                     mode="auto",
                     color_mode="auto",
                     cmap=None,
                     series=[0, 1],
                     tricolor=True,
                     P_tricolor_cpt=None,
                     tricolor_kwargs=None,
                     P_cat_cpt=None,
                     cat_labels=("Agreement", "Model-dominant", "Obs-dominant"),
                     cat_colors=("#FDAE61", "#2CA25F", "#2171B5"),
                     cbar_pos="JBC+w15c/0.45c+mc+h",
                     cbar_frame=("xa1f0.5",),
                     cat_cbar_frame=("+lFIP difference category",),
                     transparency=None,
                     weight_da=None,
                     transparency_from="auto",
                     n_transparency_bins=5,
                     plot_GI=False,
                     GI_color="#E349D0",
                     GI_marker="s",
                     GI_size="0.01",
                     GI_pt_unit="c",
                     plot_bathymetry=False,
                     lon_coord_name='lon',
                     lat_coord_name='lat'):
    """
    Plot a fast-ice persistence (FIP) map, or a FIP difference map.

    Modes
    -----
    ``mode="single"``
        ``plot_data`` is drawn with ``cmap``, or with ``self.pygmt_dict["FIP_CPT"]``.
    ``mode="diff"``
        A difference map. There are two ways to build it:

        * From ``obs`` and ``mod`` (continuous colour mode only). Both are
          aligned with ``xr.align(join="exact")`` and passed to
          ``self.tricolor_cpt_and_alpha``. That call returns the CPT path, a
          2-D difference field and a 2-D transparency field.
        * From ``plot_data``. In continuous colour mode it uses ``cmap``, or
          a generated tricolour CPT on [-1, 1] with step 0.01. In
          categorical colour mode it uses a 3-class CPT built from
          ``cat_colors`` / ``cat_labels`` on series [0, 2, 1].
    ``mode="auto"``
        Selects "diff" when both ``obs`` and ``mod`` are given, else "single".
    ``color_mode="auto"``
        Selects "categorical" only in diff mode when ``plot_data`` is a
        DataArray whose name contains "diff_cat" (case-insensitive).
        Otherwise it selects "continuous".

    Transparency
    ------------
    Per-point transparency (0-100 %) is set by, in increasing precedence:

    1. The tricolour alpha field (obs+mod mode) unless ``transparency_from``
       is "weight" or "none".
    2. ``weight_da``, as ``(1 - clip(weight, 0, 1)) * 100``. This applies when
       ``transparency_from`` is "auto" or "weight" and nothing was set yet.
    3. ``transparency``, a scalar or per-point array, if given.

    Points with transparency are grouped into ``n_transparency_bins`` equal
    bins over [0, 100]. Each bin is drawn with its bin-centre transparency.
    Non-finite and fully transparent (>= 100) points are dropped.

    Figure handling
    ---------------
    If ``fig`` is None, a new figure is created. Only in that case is it saved
    to ``P_png`` (when given) and shown (when ``show_fig``). With ``panel``,
    drawing happens inside ``fig.set_panel(panel=panel)``, i.e. in a subplot.

    Returns
    -------
    pygmt.Figure, or ``(fig, cmap_use)`` when ``return_cmap`` is True.

    Notes
    -----
    ``series`` and ``tricolor`` are accepted for backward compatibility and
    are not used.
    """
    import pygmt

    if plot_GI:
        self.load_cice_grid(slice_hem=True)
    SO_BATH = self.load_IBCSO_bath() if plot_bathymetry else None

    # ---- resolve plotting mode and colour mode ----
    if mode == "auto":
        mode = "diff" if (obs is not None and mod is not None) else "single"
    if color_mode == "auto":
        is_cat = (mode == "diff"
                  and isinstance(plot_data, xr.DataArray)
                  and "diff_cat" in (plot_data.name or "").lower())
        color_mode = "categorical" if is_cat else "continuous"

    # ---- build point cloud (x, y, z), CPT and optional transparency t ----
    t = None
    if mode == "diff" and obs is not None and mod is not None and color_mode == "continuous":
        obs2, mod2 = xr.align(obs, mod, join="exact")
        P_tricolor_cpt = P_tricolor_cpt or (Path(self.D_graph) / "CPTs" / "FIP_tricolor.cpt")
        tricolor_kwargs = tricolor_kwargs or {}
        cpt_path, z2d, t2d = self.tricolor_cpt_and_alpha(P_tricolor_cpt, obs2.values, mod2.values,
                                                         **tricolor_kwargs)
        z_da = xr.DataArray(z2d, dims=obs2.dims, coords=obs2.coords, name="FIP_diff")
        prep = self.pygmt_da_prep(z_da,
                                  lon_coord_name=lon_coord_name,
                                  lat_coord_name=lat_coord_name,
                                  region=region,
                                  mask_zero=False,
                                  return_mask=True,
                                  return_flat_index=True)
        x, y, z, mask2d = prep["lon"], prep["lat"], prep["z"], prep["mask2d"]
        cmap_use = str(cpt_path)
        # the tricolour alpha field is kept unless weight-based or no transparency was asked for
        if transparency_from not in ("weight", "none"):
            t = t2d.ravel()[prep["flat_idx"]].astype(float)
    else:
        if plot_data is None:
            raise ValueError("plot_data must be provided unless using obs+mod continuous diff mode.")
        if not isinstance(plot_data, xr.DataArray):
            raise TypeError("plot_data must be an xarray.DataArray.")
        prep = self.pygmt_da_prep(plot_data,
                                  lon_coord_name=lon_coord_name,
                                  lat_coord_name=lat_coord_name,
                                  region=region,
                                  mask_zero=False,
                                  return_mask=True,
                                  return_flat_index=True)
        x, y, z, mask2d = prep["lon"], prep["lat"], prep["z"], prep["mask2d"]
        if mode == "single":
            cmap_use = cmap or self.pygmt_dict.get("FIP_CPT", None)
            if cmap_use is None:
                raise ValueError("No CPT found for single-source mode.")
        elif color_mode == "continuous":
            cmap_use = cmap
            if cmap_use is None:
                P_master = Path(self.D_graph) / "CPTs" / "FIP_tricolor_master.cpt"
                master = self.write_tricolour_cpt(P_master, vmin=-1.0, vmid=0.0, vmax=1.0)
                P_out = Path(self.D_graph) / "CPTs" / "FIP_tricolor_0p01.cpt"
                pygmt.makecpt(cmap=str(master), series=[-1, 1, 0.01], output=str(P_out))
                cmap_use = str(P_out)
        else:
            P_cat_cpt = Path(P_cat_cpt or (Path(self.D_graph) / "CPTs" / "FIP_diff_cat.cpt"))
            P_cat_cpt.parent.mkdir(parents=True, exist_ok=True)
            pygmt.makecpt(cmap=",".join(cat_colors),
                          series=[0, 2, 1],
                          categorical=True,
                          color_model="R+c" + ",".join(cat_labels),
                          output=str(P_cat_cpt))
            cmap_use = str(P_cat_cpt)

    # weight-driven transparency: strong weight -> opaque
    if transparency_from in ("auto", "weight") and weight_da is not None and t is None:
        w2 = weight_da.squeeze(drop=True).values
        wv = np.clip(np.asarray(w2[mask2d], dtype=float), 0.0, 1.0)
        t = (1.0 - wv) * 100.0
    if transparency is not None:
        t = transparency

    # ---- figure set-up ----
    created_here = fig is None
    if created_here:
        fig = pygmt.Figure()
    # a bare string frame must not be split into characters by list()
    frame = [basemap_frame] if isinstance(basemap_frame, str) else list(basemap_frame)
    if panel_title is not None:
        frame.append(f"+t{panel_title}")
    panel_ctx = fig.set_panel(panel=panel) if panel is not None else nullcontext()
    self.logger.debug(f"panel={panel} panel_ctx={panel_ctx}")
    style = f"{G_pt_marker}{G_pt_size}{G_pt_unit}"

    def _plot_points():
        """Draw the point cloud, colour-mapped by z, binned by transparency when t is set."""
        if t is None:
            fig.plot(x=x, y=y, style=style, fill=z, cmap=cmap_use)
            return
        xa = np.asarray(x, dtype=float)
        ya = np.asarray(y, dtype=float)
        za = np.asarray(z, dtype=float)
        # a scalar transparency applies uniformly to every point
        ta = np.broadcast_to(np.asarray(t, dtype=float), xa.shape)
        good = (np.isfinite(xa) & np.isfinite(ya) & np.isfinite(za)
                & np.isfinite(ta) & (ta < 100.0))
        xg, yg, zg, tg = xa[good], ya[good], za[good], ta[good]
        edges = np.linspace(0.0, 100.0, int(n_transparency_bins) + 1)
        bin_idx = np.clip(np.digitize(tg, edges, right=False) - 1, 0, len(edges) - 2)
        for b in range(len(edges) - 1):
            sel = bin_idx == b
            if not np.any(sel):
                continue
            tau = float(0.5 * (edges[b] + edges[b + 1]))
            fig.plot(x=xg[sel], y=yg[sel], style=style, fill=zg[sel],
                     cmap=cmap_use, transparency=tau)

    with pygmt.config(FONT_TITLE="18p,Bookman-Demi",
                      FONT_ANNOT_PRIMARY="16p,NewCenturySchlbk-Roman",
                      FONT_ANNOT_SECONDARY="16p,NewCenturySchlbk-Bold",
                      FONT_LABEL="16p,NewCenturySchlbk-Bold",
                      COLOR_FOREGROUND="black"):
        with panel_ctx:
            fig.basemap(region=region, projection=projection, frame=frame)
            if plot_bathymetry:
                fig.grdimage(grid=SO_BATH, cmap="geo")
            else:
                fig.coast(region=region, projection=projection, land=land_color, water=water_color)
            _plot_points()
            if plot_GI:
                fig.plot(x=self.G_GI["lon"].values.ravel(),
                         y=self.G_GI["lat"].values.ravel(),
                         fill=GI_color,
                         style=f"{GI_marker}{GI_size}{GI_pt_unit}")
            fig.coast(region=region, projection=projection, shorelines=shoreline_pen)
            if add_colorbar:
                cb_frame = cat_cbar_frame if color_mode == "categorical" else cbar_frame
                fig.colorbar(position=cbar_pos, frame=list(cb_frame), cmap=cmap_use)

    if created_here and P_png:
        P_png = Path(P_png)
        P_png.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(P_png)
    if created_here and show_fig:
        fig.show()
    if return_cmap:
        return fig, cmap_use
    return fig
def pygmt_FIP_diff_four_panel(self, da_diff, *,
                              sector_names,
                              weight_da=None,
                              fig_size=("24c", "17c"),
                              margins=["0.15c", "0.15c"],
                              basemap_frame="af",
                              G_pt_size="0.08",
                              plot_GI=True,
                              GI_color="#0072B2",
                              GI_size="0.08",
                              P_png=None,
                              show_fig=False,
                              cbar_width="12c",
                              cbar_bframe='n',
                              left_label="model-dominant",  # correct for diff = mod - obs
                              mid_label="agreement",
                              right_label="obs-dominant"):
    """
    Plot a FIP difference field for up to four sectors in a 2x2 subplot.

    Each sector is looked up in ``self.Ant_8sectors``. Its ``plot_region`` is
    used as the panel region. Its ``projection`` template is filled with the
    sector's meridian centre and a panel-sized ``?`` width. Each panel is
    drawn by ``self.pygmt_FIP_figure`` in continuous diff mode, without its
    own colorbar. One shared colorbar with three labels is then added below
    the grid via ``self.add_fip_shared_colorbar``.

    Parameters
    ----------
    da_diff : xarray.DataArray
        Difference field passed to every panel as ``plot_data``.
    sector_names : str or sequence of str
        One to four keys of ``self.Ant_8sectors``. They fill the panels
        row-major: (0,0), (0,1), (1,0), (1,1).
    weight_da : xarray.DataArray or None
        Optional weights controlling per-point transparency.
    fig_size : tuple of str
        ``(width, height)`` of the subplot grid. The width is also used to
        place the colorbar labels.
    margins : list of str
        Subplot margins.
    basemap_frame : str
        Frame setting for each panel basemap.
    G_pt_size, plot_GI, GI_color, GI_size
        Point size and grounded-iceberg overlay options forwarded to each panel.
    P_png : str, Path or None
        When given, the figure is saved there at 300 dpi; parent directories
        are created.
    show_fig : bool
        Show the figure when True.
    cbar_width, cbar_bframe, left_label, mid_label, right_label
        Forwarded to ``add_fip_shared_colorbar``.

    Returns
    -------
    pygmt.Figure

    Raises
    ------
    ValueError
        If ``sector_names`` does not contain between one and four names.
        A 2x2 grid has no panel for a fifth sector, and with zero sectors no
        colormap exists for the shared colorbar.
    """
    if isinstance(sector_names, str):
        sector_names = [sector_names]
    sector_names = list(sector_names)
    if not 1 <= len(sector_names) <= 4:
        raise ValueError(f"sector_names must contain between 1 and 4 sector names "
                         f"(got {len(sector_names)})")

    import pygmt

    fig = pygmt.Figure()
    cmap_use = None
    with fig.subplot(nrows=2, ncols=2, figsize=fig_size, margins=margins, frame="lrtb", autolabel=False):
        for i, reg_name in enumerate(sector_names):
            reg_vals = self.Ant_8sectors[reg_name]
            region = reg_vals["plot_region"]
            MC = self.get_meridian_center_from_geographic_extent(region)
            # Use subplot-sized width ("?"), not a fixed width.
            projection = self._format_subplot_projection(reg_vals["projection"], MC)
            fig, cmap_use = self.pygmt_FIP_figure(plot_data=da_diff,
                                                  mode="diff",
                                                  color_mode="continuous",
                                                  weight_da=weight_da,
                                                  region=region,
                                                  projection=projection,
                                                  G_pt_size=G_pt_size,
                                                  plot_GI=plot_GI,
                                                  plot_bathymetry=False,
                                                  GI_color=GI_color,
                                                  GI_size=GI_size,
                                                  basemap_frame=(basemap_frame,),
                                                  fig=fig,
                                                  panel=[i // 2, i % 2],
                                                  panel_title=reg_name,
                                                  add_colorbar=False,
                                                  return_cmap=True)
    self.add_fip_shared_colorbar(fig, cmap_use,
                                 cbar_bframe=cbar_bframe,
                                 fig_width=fig_size[0],
                                 cbar_width=cbar_width,
                                 left_label=left_label,
                                 mid_label=mid_label,
                                 right_label=right_label)
    if P_png is not None:
        P_png = Path(P_png)
        P_png.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(P_png, dpi=300)
    if show_fig:
        fig.show()
    return fig
def _smooth_df_time(self, df: pd.DataFrame, window, center=True, min_periods=1): """ Smooth a time series in df with columns ['time','data']. 'window' can be an int (# of samples) or a time offset string like '15D'. Returns a new df with smoothed 'data'. """ if window is None: return df s = df.set_index("time")["data"] # Pandas rolling supports both integer and time-based windows on a DatetimeIndex s_sm = s.rolling(window=window, center=center, min_periods=min_periods).mean() out = df.copy() out["data"] = s_sm.values return out def _repeat_doy_over_range(self, df_time_data, full_range, time_coord="time"): src = df_time_data.copy() src["doy"] = src[time_coord].dt.dayofyear clim_lookup = src.groupby("doy")["data"].mean() # index = 1..365 or 366 rep = pd.DataFrame({"time": full_range}) rep["doy"] = rep["time"].dt.dayofyear # handle leap day: if 366 not in lookup, remap DOY=366 -> 365 if 366 not in clim_lookup.index: rep["doy_eff"] = rep["doy"].where(rep["doy"] != 366, 365) else: rep["doy_eff"] = rep["doy"] rep["rep"] = rep["doy_eff"].map(clim_lookup) return rep[["time", "rep"]]
def pygmt_timeseries(self, ts_dict,
                     comp_name: str = "test",
                     primary_key: str = "FIA",
                     smooth: str | int | None = None,   # e.g., 15, "15D"
                     clim_smooth: int | None = None,    # 15
                     climatology: bool = False,
                     ylabel: str | None = None,         # "@[Fast Ice Area (1\\times10^3 km^2)@["
                     ylim: tuple | list | None = (0, 1000),
                     yaxis_pri: str | None = None,
                     ytick_pri: int = 100,
                     ytick_sec: int = 50,
                     projection: str = None,
                     fig_width: str = None,
                     fig_height: str = None,
                     xaxis_pri: str = None,
                     xaxis_sec: str = None,
                     frame_bndy: str = "WS",
                     line_colors: dict[str, str] | None = None,    # {"sim": "#RRGGBB", ...}
                     legend_labels: dict[str, str] | None = None,  # {"sim": "nice label", ...}
                     legend_pos: str = None,
                     legend_box: str = "+gwhite+p0.5p",
                     fmt_dt_pri: str = None,
                     fmt_dt_sec: str = None,
                     fmt_dt_map: str = None,
                     fnt_type: str = "Helvetica",
                     fnt_wght_lab: str = "20p",
                     fnt_wght_ax: str = "18p",
                     line_pen: str = "1p",
                     grid_wght_pri: str = ".25p",
                     grid_wght_sec: str = ".1p",
                     P_png: str = None,
                     time_coord: str = "time",
                     time_coord_alt: str = "date",
                     keys2plot: list = None,
                     repeat_keys: list[str] | None = None,
                     repeat_policy: str = "inside_others",  # "inside_others" | "outside_others" | "fill_gaps" | "always"
                     repeat_ref_keys: list[str] | None = None,
                     clip_x_axis: bool = False,
                     zero_line: bool = False,
                     zero_line_level: float = 0.0,
                     zero_line_pen: str = "2p,black",
                     save_fig: bool = None,
                     show_fig: bool = None):
    """
    Plot time series of ``primary_key`` for a set of simulations/observations.

    Parameters
    ----------
    ts_dict : dict
        Mapping of series name -> container indexable by ``primary_key``
        (e.g. an xarray Dataset). The ``primary_key`` entry is a 1-D
        DataArray with a time coordinate.
    comp_name : str
        Comparison name, used in the default output filename.
    primary_key : str
        Variable extracted from every ``ts_dict`` entry.
    smooth : int, str or None
        Rolling-mean window applied to each series when ``climatology`` is
        False. Either a sample count or an offset string such as ``"15D"``.
        None disables smoothing.
    clim_smooth : int or None
        Centred rolling-mean window (samples) applied to the climatological
        mean line when ``climatology`` is True; used only if > 1.
    climatology : bool
        If True, plot each series' day-of-year climatology on a fake 1996
        axis: a shaded min-max envelope plus the mean line.
    ylabel : str or None
        Y-axis label; omitted when None.
    ylim : sequence of two floats or None
        Y-axis limits. If None, they are inferred from the finite values of
        all plotted series with 5 % padding.
    yaxis_pri : str or None
        Complete GMT y-axis frame string. It overrides ytick_pri, ytick_sec
        and ylabel.
    ytick_pri, ytick_sec : int
        Annotation and tick interval of the default y-axis frame.
    projection : str
        Ignored; the projection is always ``X{fig_width}/{fig_height}``.
    fig_width, fig_height : str or None
        Plot size. Defaults are 20c x 15c (climatology) or 50c x 15c.
    xaxis_pri, xaxis_sec : str or None
        GMT primary/secondary x-axis frame settings.
    frame_bndy : str
        Frame sides to draw/annotate, e.g. ``"WS"``.
    line_colors, legend_labels : dict or None
        Per-key overrides. Otherwise ``self.plot_var_dict[key]["line_clr"]`` /
        ``["leg_lab"]`` is used, then ``f"C{n}"`` / the key itself.
    legend_pos, legend_box : str
        GMT legend position and box style.
    fmt_dt_pri, fmt_dt_sec, fmt_dt_map : str or None
        GMT FORMAT_TIME_PRIMARY_MAP, FORMAT_TIME_SECONDARY_MAP and
        FORMAT_DATE_MAP settings.
    fnt_type, fnt_wght_lab, fnt_wght_ax : str
        Font name, label font size and annotation font size.
    line_pen : str
        Pen width of the series lines.
    grid_wght_pri, grid_wght_sec : str
        Primary and secondary grid pen widths.
    P_png : str, Path or None
        Output path. The default is
        ``{D_graph}/timeseries/{primary_key}_{comp_name}_[climatology_]{YYYY}-{YYYY}.png``.
        Parent directories are created.
    time_coord : str
        Time coordinate name of each DataArray.
    time_coord_alt : str
        Time coordinate used instead for the ``"AF2020"`` series when
        ``primary_key == "FIA"``.
    keys2plot : list or None
        Only these keys are plotted; all keys are plotted when None.
    repeat_keys : str, list[str] or None
        Keys drawn as a repeated day-of-year climatology over [tmin, tmax]
        when ``climatology`` is False. The repeat is combined with the
        original data according to ``repeat_policy``:

        - "always": only the repeated climatology.
        - "fill_gaps": original data where present, climatology elsewhere.
        - "outside_others": inside the reference window, original data where
          present, else climatology; outside the window, climatology.
        - "inside_others": as "outside_others", but nothing is drawn outside
          the reference window.

        When the reference keys hold no times at all, "outside_others"
        behaves like "fill_gaps" and "inside_others" draws nothing.
    repeat_ref_keys : str, list[str] or None
        Keys defining the reference window. By default, all other plotted keys.
    clip_x_axis : bool
        With ``repeat_keys``, limit the x-axis to the time span of the
        reference keys (``repeat_ref_keys``, or all plotted non-repeated keys).
    zero_line, zero_line_level, zero_line_pen
        Draw a horizontal line at ``zero_line_level`` when it lies within ylim.
    save_fig, show_fig : bool or None
        Default to ``self.save_fig`` / ``self.show_fig``.

    Raises
    ------
    ValueError
        Unknown ``repeat_policy``, or ``ylim=None`` with no finite data.
    """
    import pygmt

    show_fig = show_fig if show_fig is not None else self.show_fig
    save_fig = save_fig if save_fig is not None else self.save_fig
    fmt_dt_pri = fmt_dt_pri if fmt_dt_pri is not None else "Character"
    fmt_dt_sec = fmt_dt_sec if fmt_dt_sec is not None else "Abbreviated"
    fmt_dt_map = fmt_dt_map if fmt_dt_map is not None else "o"
    line_colors = line_colors or {}
    legend_labels = legend_labels or {}

    def _wanted(key):
        # honour keys2plot when given
        return keys2plot is None or key in keys2plot

    def _time_name(key):
        # AF2020 FIA is indexed by the alternative time coordinate
        return time_coord_alt if (key == "AF2020" and primary_key == "FIA") else time_coord

    def _ref_window(ref_keys):
        """Return the normalised (min, max) time over ref_keys, or None if they hold no times."""
        mins, maxs = [], []
        for k in ref_keys:
            tt = pd.to_datetime(ts_dict[k][primary_key][_time_name(k)].values)
            if len(tt) > 0:
                mins.append(tt.min())
                maxs.append(tt.max())
        if not mins:
            return None
        return pd.to_datetime(min(mins)).normalize(), pd.to_datetime(max(maxs)).normalize()

    # need to get out the maximum times for plot boundaries
    tmin, tmax = self.extract_min_max_dates(ts_dict, keys2plot=keys2plot,
                                            primary_key=primary_key, time_coord=time_coord)

    # --- normalise repeat parameters ---
    if isinstance(repeat_keys, (str, bytes)):
        repeat_keys = [repeat_keys]
    repeat_keys = set(repeat_keys or [])
    if isinstance(repeat_ref_keys, (str, bytes)):
        repeat_ref_keys = [repeat_ref_keys]

    # --- y-limits: infer from data when not given ---
    if ylim is None:
        lows, highs = [], []
        for k, data in ts_dict.items():
            if not _wanted(k):
                continue
            vals = np.asarray(data[primary_key].values, dtype=float)
            if vals.size and np.isfinite(vals).any():
                lows.append(np.nanmin(vals))
                highs.append(np.nanmax(vals))
        if not lows:
            raise ValueError("ylim=None but no finite data found to infer y-axis limits")
        y_lo, y_hi = float(min(lows)), float(max(highs))
        pad = 0.05 * (y_hi - y_lo) if y_hi > y_lo else (0.05 * abs(y_lo) or 1.0)
        ylim = (y_lo - pad, y_hi + pad)

    # --- x-limits: optionally clip to the reference series' span ---
    x_start, x_end = tmin, tmax
    if (not climatology) and clip_x_axis and repeat_keys:
        if repeat_ref_keys:
            ref_keys = set(repeat_ref_keys)
        else:
            ref_keys = {k for k in ts_dict if _wanted(k) and k not in repeat_keys}
        window = _ref_window(ref_keys)
        if window is not None:  # only override if we actually have refs
            x_start, x_end = window

    # there are differences in the projection and x-axis for the two types of figures
    if climatology:
        fake_year = 1996
        fig_width = fig_width if fig_width is not None else "20c"
        fig_height = fig_height if fig_height is not None else "15c"
        xaxis_pri = xaxis_pri if xaxis_pri is not None else "a1Og"
        region = [f"{fake_year}-01-01", f"{fake_year}-12-31", ylim[0], ylim[1]]
    else:
        fig_width = fig_width if fig_width is not None else "50c"
        fig_height = fig_height if fig_height is not None else "15c"
        xaxis_pri = xaxis_pri if xaxis_pri is not None else "a2Of30D"
        xaxis_sec = xaxis_sec if xaxis_sec is not None else "a1Y"
        region = [x_start.strftime("%Y-%m-%d"), x_end.strftime("%Y-%m-%d"), ylim[0], ylim[1]]

    # define the projection and frame of the figure
    legend_pos = legend_pos if legend_pos is not None else f"JTL+jTL+o0.2c+w{fig_width}"
    projection = f"X{fig_width}/{fig_height}"
    if yaxis_pri is None:
        yaxis_pri = f"a{ytick_pri}f{ytick_sec}"
        if ylabel is not None:  # avoid a literal "None" axis label
            yaxis_pri += f"+l{ylabel}"
    if xaxis_sec is not None:
        frame = [frame_bndy, f"sx{xaxis_sec}", f"px{xaxis_pri}", f"py{yaxis_pri}"]
    else:
        frame = [frame_bndy, f"px{xaxis_pri}", f"py{yaxis_pri}"]

    def _compose_repeat(df, dict_key):
        """Combine the original series with its repeated DOY climatology per repeat_policy."""
        full_range = pd.date_range(tmin.normalize(), tmax.normalize(), freq="D")
        rep_df = self._repeat_doy_over_range(df, full_range)  # cols: time, rep
        base = pd.DataFrame({"time": full_range}).merge(df.rename(columns={"data": "orig"}),
                                                       on="time", how="left")
        mix = base.merge(rep_df, on="time", how="left")
        if repeat_policy == "always":
            return mix[["time", "rep"]].rename(columns={"rep": "data"})
        if repeat_policy == "fill_gaps":
            mix["data"] = mix["orig"].fillna(mix["rep"])
            return mix[["time", "data"]]
        if repeat_policy not in ("outside_others", "inside_others"):
            raise ValueError(f"Unknown repeat_policy: {repeat_policy}")
        if repeat_ref_keys is not None:
            ref_keys = set(repeat_ref_keys)
        else:
            ref_keys = {k for k in ts_dict if k != dict_key and _wanted(k)}
        window = _ref_window(ref_keys) if ref_keys else None
        if window is None:
            # no reference window: behave like fill_gaps, or hide everything for inside_others
            mix["data"] = mix["orig"].fillna(mix["rep"])
            if repeat_policy == "inside_others":
                mix["data"] = np.nan
            return mix[["time", "data"]]
        other_min, other_max = window
        inside = (mix["time"] >= other_min) & (mix["time"] <= other_max)
        filled = mix["orig"].fillna(mix["rep"])
        outside_vals = np.nan if repeat_policy == "inside_others" else mix["rep"]
        mix["data"] = filled.where(inside, outside_vals)
        return mix[["time", "data"]]

    # make sure the figure has the same configuration by wrapping it in a with condition
    fig = pygmt.Figure()
    with pygmt.config(FONT_LABEL=f"{fnt_wght_lab},{fnt_type}",
                      FONT=f"{fnt_wght_ax},{fnt_type}",
                      MAP_GRID_PEN_PRIMARY=grid_wght_pri,
                      MAP_GRID_PEN_SECONDARY=grid_wght_sec,
                      FORMAT_TIME_PRIMARY_MAP=fmt_dt_pri,
                      FORMAT_TIME_SECONDARY_MAP=fmt_dt_sec,
                      FORMAT_DATE_MAP=fmt_dt_map):
        fig.basemap(projection=projection, region=region, frame=frame)
        cnt = 0
        for dict_key, data in ts_dict.items():
            # skip before touching the data so unwanted entries need not contain primary_key
            if not _wanted(dict_key):
                continue
            self.logger.info(f"extracting data array from ts_dict[{dict_key} : {primary_key}]")
            da = data[primary_key]
            self.logger.info(f"matching {dict_key} with JSON-toolbox-config dictionary for line color and legend label")
            # choose color/label: explicit override > plot_var_dict > default cycle
            var_cfg = self.plot_var_dict.get(dict_key, {})
            line_color = line_colors.get(dict_key) or var_cfg.get("line_clr") or f"C{cnt}"
            leg_lab = legend_labels.get(dict_key) or var_cfg.get("leg_lab") or dict_key
            cnt += 1
            self.logger.info(f" legend label: {leg_lab}")
            self.logger.info(f" line color : {line_color}")

            use_alt_time = dict_key == "AF2020" and primary_key == "FIA"
            df = pd.DataFrame({"time": pd.to_datetime(da[_time_name(dict_key)].values),
                               "data": da.values})

            if climatology:
                if use_alt_time:
                    clim = self.compute_doy_climatology(da, time_coord=time_coord_alt)
                else:
                    clim = self.compute_doy_climatology(da)
                mean_x = clim['mean'].index
                mean_y = clim['mean'].values
                if clim_smooth is not None and clim_smooth > 1:
                    mean_y = (pd.Series(mean_y, index=mean_x)
                              .rolling(window=clim_smooth, center=True, min_periods=1)
                              .mean()
                              .values)
                fig.plot(x=np.concatenate([clim['min'].index, clim['max'].index[::-1]]),
                         y=np.concatenate([clim['min'].values, clim['max'].values[::-1]]),
                         fill=f"{line_color}@80",
                         close=True,
                         transparency=80)
                fig.plot(x=mean_x, y=mean_y, pen=f"{line_pen},{line_color}", label=leg_lab)
            elif dict_key in repeat_keys:
                # smoothing is applied after the repeat composition
                out_df = self._smooth_df_time(_compose_repeat(df, dict_key), window=smooth)
                fig.plot(x=out_df["time"], y=out_df["data"], pen=f"{line_pen},{line_color}", label=leg_lab)
            else:
                df_sm = self._smooth_df_time(df, window=smooth)
                fig.plot(x=df_sm["time"], y=df_sm["data"], pen=f"{line_pen},{line_color}", label=leg_lab)

        if zero_line:
            y0 = float(zero_line_level)
            y_min, y_max = ylim
            if y_min <= y0 <= y_max:
                x0, x1 = region[0], region[1]
                fig.plot(x=[x0, x1], y=[y0, y0], pen=zero_line_pen)
        fig.legend(position=legend_pos, box=legend_box)

    if save_fig:
        F_png = f"{primary_key}_{comp_name}_{'climatology_' if climatology else ''}{tmin.strftime('%Y')}-{tmax.strftime('%Y')}.png"
        P_png = Path(P_png) if P_png is not None else Path(self.D_graph, "timeseries", F_png)
        P_png.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(P_png, dpi=300)
        self.logger.info(f"saved figure to {P_png}")
    if show_fig:
        fig.show()
def pygmt_map_plot_multi_var_8sectors(self, das, var_names,
                                      sim_name=None,
                                      time_stamp=None,
                                      tit_str=None,
                                      panel_titles=None,
                                      plot_GI=False,
                                      diff_plots=None,
                                      cmaps=None,
                                      series_list=None,
                                      reverse_list=None,
                                      cbar_labels=None,
                                      cbar_units_list=None,
                                      extend_cbars=False,
                                      cbar_positions=None,
                                      lon_coord_names=None,
                                      lat_coord_names=None,
                                      use_bcoords=False,
                                      use_tcoords=False,
                                      fig_size=None,
                                      var_sq_size=0.2,
                                      GI_sq_size=0.1,
                                      GI_fill_color="red",
                                      plot_iceshelves=True,
                                      plot_bathymetry=True,
                                      land_color=None,
                                      water_color=None,
                                      P_png=None,
                                      var_out=None,
                                      overwrite_fig=None,
                                      show_fig=None,
                                      xshift="w+1c"):
    """
    Make one figure per sector of ``self.Ant_8sectors``, each with 2-3 panels side by side.

    Panels are laid out left-to-right with ``fig.shift_origin(xshift=xshift)``.
    Each input DataArray is converted to a PyGMT point cloud once, via
    ``self.pygmt_da_prep``. Sector selection then only filters those point
    clouds, so there is no repeated (Dask) computation per sector. Sector
    indices are computed per panel from that panel's own lon/lat. This keeps
    them correct even when panels drop different points (e.g. different
    masking).

    Parameters
    ----------
    das : list/tuple of xarray.DataArray
        Two or three 2-D (time-sliced) fields; a ``time`` dimension is rejected.
    var_names : list/tuple of str
        One name per panel. Used as keys into ``self.plot_var_dict`` for
        defaults and in the default output name.
    Per-panel options (scalar or one entry per panel)
        ``diff_plots``, ``cmaps``, ``series_list``, ``reverse_list``,
        ``cbar_labels``, ``cbar_units_list``, ``extend_cbars``,
        ``cbar_positions``, ``lon_coord_names``, ``lat_coord_names``,
        ``use_bcoords``, ``use_tcoords``. ``None`` entries fall back to
        ``self.plot_var_dict[var_name]`` / ``self.pygmt_dict``.
    P_png : str, Path or None
        Without a suffix it is treated as a directory (``<P_png>/<sector>.png``).
        Otherwise the sector name is appended to the stem. When None and
        ``self.save_fig`` is set, the default is
        ``{D_graph}/{sim_name}/{sector}/{var_out}/{time_stamp}_{sim_name}_{sector}_{var_out}.png``.
    overwrite_fig, show_fig : bool or None
        Default to ``self.ow_fig`` / ``self.show_fig``.

    Raises
    ------
    TypeError
        If ``das`` / ``var_names`` are not lists or tuples.
    ValueError
        On mismatched lengths, a ``time`` dimension, or both
        ``use_bcoords`` and ``use_tcoords`` being True for the same panel.
    """
    import pygmt

    def _as_list(x, n, name):
        if x is None:
            return [None] * n
        if isinstance(x, (list, tuple)):
            if len(x) != n:
                raise ValueError(f"{name} must have length {n} (got {len(x)})")
            return list(x)
        return [x] * n

    def _as_bool_list(x, n, name):
        if x is None:
            return [False] * n
        if isinstance(x, (list, tuple)):
            if len(x) != n:
                raise ValueError(f"{name} must have length {n} (got {len(x)})")
            return [bool(v) for v in x]
        return [bool(x)] * n

    def _sector_png_path(P, sector_name):
        if P is None:
            return None
        P = Path(P)
        if P.suffix == "":
            return P / f"{sector_name}.png"
        return P.parent / f"{P.stem}_{sector_name}{P.suffix}"

    def _lon_in_range_1d(lon, lon_min, lon_max):
        # a span of 360 degrees or more covers every longitude
        if (lon_max - lon_min) >= 360.0:
            return np.ones(np.shape(lon), dtype=bool)
        lon = np.mod(lon, 360.0)
        lon_min = lon_min % 360.0
        lon_max = lon_max % 360.0
        if lon_min <= lon_max:
            return (lon >= lon_min) & (lon <= lon_max)
        # crosses the 0/360 meridian
        return (lon >= lon_min) | (lon <= lon_max)

    # --- validate ---
    if not isinstance(das, (list, tuple)) or not isinstance(var_names, (list, tuple)):
        raise TypeError("das and var_names must be lists/tuples")
    n_pan = len(das)
    if n_pan not in (2, 3):
        raise ValueError(f"das must contain 2 or 3 DataArrays (got {n_pan})")
    if len(var_names) != n_pan:
        raise ValueError("var_names must match length of das")
    for j, da in enumerate(das):
        if hasattr(da, "dims") and "time" in da.dims:
            raise ValueError(f"das[{j}] has 'time' dim; pass a 2D time-slice.")

    # --- defaults ---
    sim_name = sim_name if sim_name is not None else self.sim_name
    show_fig = show_fig if show_fig is not None else self.show_fig
    ow_fig = overwrite_fig if overwrite_fig is not None else self.ow_fig
    time_stamp = time_stamp if time_stamp is not None else self.dt0_str
    fig_size = fig_size if fig_size is not None else self.pygmt_dict["fig_size"]
    land_color = land_color if land_color is not None else self.pygmt_dict["land_color"]
    water_color = water_color if water_color is not None else self.pygmt_dict["water_color"]
    if panel_titles is None:
        panel_titles = list(var_names)

    diff_plots_l = _as_bool_list(diff_plots, n_pan, "diff_plots")
    cmaps_l = _as_list(cmaps, n_pan, "cmaps")
    series_l = _as_list(series_list, n_pan, "series_list")
    reverse_l = _as_list(reverse_list, n_pan, "reverse_list")
    cbar_labels_l = _as_list(cbar_labels, n_pan, "cbar_labels")
    cbar_units_l = _as_list(cbar_units_list, n_pan, "cbar_units_list")
    extend_cbars_l = _as_bool_list(extend_cbars, n_pan, "extend_cbars")
    cbar_positions_l = _as_list(cbar_positions, n_pan, "cbar_positions")
    lon_names_l = _as_list(lon_coord_names, n_pan, "lon_coord_names")
    lat_names_l = _as_list(lat_coord_names, n_pan, "lat_coord_names")
    use_bcoords_l = _as_bool_list(use_bcoords, n_pan, "use_bcoords")
    use_tcoords_l = _as_bool_list(use_tcoords, n_pan, "use_tcoords")

    # overlays loaded once
    if plot_iceshelves:
        ANT_IS = self.load_ice_shelves()
    if plot_bathymetry:
        SO_BATH = self.load_IBCSO_bath()
    if plot_GI:
        plot_GI_dict = self.load_GI_lon_lats()

    # ensure grid loaded (for tcoords/bcoords path in pygmt_da_prep)
    self.load_cice_grid(slice_hem=True)

    reg_dict = self.Ant_8sectors
    if var_out is None:
        var_out = "_".join(var_names)

    # --- compute plot point clouds ONCE per panel ---
    # ("pdict", not "pd", so the module-level pandas alias is never shadowed)
    self.logger.info("precomputing plot point clouds (once per panel) ...")
    panels = []
    for j in range(n_pan):
        if (lon_names_l[j] is not None) and (lat_names_l[j] is not None):
            pdict = self.pygmt_da_prep(das[j],
                                       bcoords=False,
                                       tcoords=False,
                                       lon_coord_name=lon_names_l[j],
                                       lat_coord_name=lat_names_l[j],
                                       diff_plot=diff_plots_l[j])
        else:
            if use_bcoords_l[j] and use_tcoords_l[j]:
                raise ValueError("Cannot set both use_bcoords and use_tcoords True")
            pdict = self.pygmt_da_prep(das[j],
                                       bcoords=use_bcoords_l[j],
                                       tcoords=use_tcoords_l[j] and not use_bcoords_l[j],
                                       diff_plot=diff_plots_l[j])
        panels.append((np.asarray(pdict["lon"], dtype=float),
                       np.asarray(pdict["lat"], dtype=float),
                       np.asarray(pdict["data"])))

    # --- sector point indices, per panel ---
    sector_idx = {}
    for reg_name, reg_vals in reg_dict.items():
        lonmin, lonmax, latmin, latmax = reg_vals["plot_region"]
        sector_idx[reg_name] = [
            np.flatnonzero(_lon_in_range_1d(lon, lonmin, lonmax) & (lat >= latmin) & (lat <= latmax))
            for lon, lat, _ in panels
        ]

    # --- plot ---
    for reg_name, reg_vals in reg_dict.items():
        idx_per_panel = sector_idx[reg_name]
        if all(idx.size == 0 for idx in idx_per_panel):
            self.logger.warning(f"{reg_name}: no points in this sector after masking; skipping.")
            continue

        # output path per sector
        P_png_reg = _sector_png_path(P_png, reg_name)
        if P_png_reg is None and self.save_fig:
            P_png_reg = Path(self.D_graph, sim_name, reg_name, var_out,
                             f"{time_stamp}_{sim_name}_{reg_name}_{var_out}.png")

        region = reg_vals["plot_region"]
        MC = self.get_meridian_center_from_geographic_extent(region)
        projection = reg_vals["projection"].format(MC=MC, fig_size=fig_size)

        fig = pygmt.Figure()
        with pygmt.config(FONT_TITLE="16p,Courier-Bold",
                          FONT_ANNOT_PRIMARY="14p,Helvetica",
                          COLOR_FOREGROUND="black"):
            for j in range(n_pan):
                if j > 0:
                    fig.shift_origin(xshift=xshift, yshift="0c")
                vn = var_names[j]

                # per-panel defaults
                var_cfg = self.plot_var_dict
                cmap_j = cmaps_l[j] if cmaps_l[j] is not None else var_cfg[vn]["cmap"]
                series_j = series_l[j] if series_l[j] is not None else var_cfg[vn]["series"]
                rev_j = reverse_l[j] if reverse_l[j] is not None else var_cfg[vn]["reverse"]
                lab_j = cbar_labels_l[j] if cbar_labels_l[j] is not None else var_cfg[vn]["name"]
                unit_j = cbar_units_l[j] if cbar_units_l[j] is not None else var_cfg[vn]["units"]
                cbar_pos_j = cbar_positions_l[j]
                if cbar_pos_j is None:
                    cbar_pos_j = self.pygmt_dict["cbar_pos"].format(width=fig_size * 0.8, height=0.75)

                # basemap frame
                title_here = panel_titles[j] if panel_titles is not None else vn
                if (j == 0) and (tit_str is not None):
                    frame = ["af", f"+t{tit_str}"]
                else:
                    frame = ["af", f"+t{title_here}"]
                fig.basemap(region=region, projection=projection, frame=frame)
                if plot_bathymetry:
                    fig.grdimage(grid=SO_BATH, cmap="geo")
                else:
                    fig.coast(region=region, projection=projection,
                              shorelines="1/0.5p,gray30", land=land_color, water=water_color)

                # points for this sector (the CPT is still made for the colorbar)
                lon_j, lat_j, val_j = panels[j]
                idx = idx_per_panel[j]
                pygmt.makecpt(cmap=cmap_j, reverse=rev_j, series=series_j)
                if idx.size:
                    fig.plot(x=lon_j[idx], y=lat_j[idx], fill=val_j[idx],
                             style=f"s{var_sq_size}c", cmap=True)

                # overlays
                if plot_bathymetry:
                    fig.coast(region=region, projection=projection, shorelines="1/0.5p,gray30")
                if plot_GI:
                    fig.plot(x=plot_GI_dict["lon"], y=plot_GI_dict["lat"],
                             fill=GI_fill_color, style=f"c{GI_sq_size}c")
                if plot_iceshelves:
                    fig.plot(data=ANT_IS, fill="lightgray")

                # colorbar
                cbar_frame = self.create_cbar_frame(series_j, lab_j, units=unit_j,
                                                    extend_cbar=extend_cbars_l[j])
                fig.colorbar(position=cbar_pos_j, frame=cbar_frame)

        # save/show
        if P_png_reg:
            P_png_reg.parent.mkdir(parents=True, exist_ok=True)
            if (not P_png_reg.exists()) or ow_fig:
                fig.savefig(P_png_reg)
                self.logger.info(f"Saved figure to {P_png_reg}")
            else:
                self.logger.info(f"{P_png_reg} already exists and not overwriting")
        if show_fig:
            fig.show()