from __future__ import annotations
import xarray as xr
import pandas as pd
import numpy as np
from pathlib import Path
from contextlib import nullcontext
__all__ = ["SeaIcePlotter"]
[docs]
class SeaIcePlotter:
"""
Plotting utilities for AFIM / CICE sea-ice diagnostics.
This class groups helper methods that prepare gridded xarray.DataArray fields
for plotting (primarily with PyGMT), including:
- Writing custom CPT (colour palette table) files for continuous or categorical
difference fields (e.g., obs–model agreement maps).
- Building “alpha/transparency” layers that modulate opacity by confidence or
persistence strength.
- Resolving 2D lon/lat coordinates from a DataArray via explicit coordinate names,
inferred coordinate names, or fallback class grid objects (B-grid / T-grid).
- Converting a 2D DataArray into 1D lon/lat/z vectors + masks suitable for
`pygmt.Figure.plot` / `pygmt.Figure.grdimage` workflows.
Notes
-----
- Several methods assume the class has access to grid datasets such as `self.G_t`
and `self.G_u`, and a method `self.load_bgrid(slice_hem=True)` that populates them.
- The methods here are designed to be safe with Dask-backed DataArrays: they will
compute only when necessary and will avoid materialising coordinate grids unless
required.
- CPT writing methods create plain text files on disk and return the resolved path.
Expected Attributes / Methods
-----------------------------
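_hex_to_rgb : callable
Helper (defined on this class) that converts hex colour strings (e.g. "#2CA25F")
to integer RGB triplets.
load_bgrid : callable
Loads grid coordinate datasets (at least `self.G_t` and/or `self.G_u`).
load_cice_grid : callable
Called by the regional annotation statistics when `self.G_t` is missing or lacks
an `"area"` field; expected to populate it.
G_t, G_u : xarray.Dataset
Grid datasets holding 2D lon/lat fields under keys like `"lon"` and `"lat"`
(and, for `G_t`, cell areas under `"area"`).
config : dict
Must provide `config['pygmt_dict']` with the file paths `P_coast_shape`,
`P_IBCSO_bed` and `P_IBCSO_bath` used by the ice-shelf and bathymetry helpers.
icon_thresh : float
Concentration threshold above which a cell counts as ice in regional statistics.
normalise_longitudes : callable
Converts longitudes to a target convention ("0-360" or "-180-180").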
"""
def __init__(self, **kwargs):
    """
    Create a plotter instance.

    Parameters
    ----------
    **kwargs
        Accepted so that callers can forward configuration keywords, but this
        constructor neither reads nor stores any of them.

    Notes
    -----
    - No attributes are set here. Grid data (e.g. `self.G_t`, `self.G_u`) are
      obtained later by the methods that need them, such as
      `_resolve_lonlat_2d()`, which calls `self.load_bgrid(slice_hem=True)`.
    """
[docs]
def load_ice_shelves(self):
    """
    Read the Antarctic coastline shapefile and return cleaned ice-shelf geometries.

    Steps performed:
    1. Read the shapefile at `self.config['pygmt_dict']['P_coast_shape']`.
    2. Keep rows whose `POLY_TYPE` column equals `'S'` (the ice-shelf class in
       this dataset's coding).
    3. Drop rows with missing or empty geometries.
    4. Reproject to WGS84 (`EPSG:4326`).
    5. Apply `buffer(0)` to every geometry, a common way to repair invalid
       polygon topology.

    Returns
    -------
    geopandas.GeoSeries
        The processed ice-shelf geometries, e.g. for overlay with
        `pygmt.Figure.plot()`.

    See Also
    --------
    - self.plot_FIA_FIP_faceted : Uses this method to overlay ice shelves in map panels.
    - geopandas.read_file : For reading shapefiles.
    """
    import geopandas as gpd

    coast_path = self.config['pygmt_dict']['P_coast_shape']
    coast = gpd.read_file(coast_path)

    # Ice-shelf polygons only.
    shelves = coast.loc[coast['POLY_TYPE'] == 'S']

    # Discard rows whose geometry is null or empty before reprojecting.
    geom = shelves.geometry
    has_geometry = geom.notnull() & ~geom.is_empty
    shelves = shelves.loc[has_geometry].to_crs("EPSG:4326")

    # Zero-width buffer to clean up polygon topology.
    shelves.geometry = shelves.geometry.buffer(0)
    return shelves.geometry
[docs]
def create_IBCSO_bath(self):
    """
    Extract and save a masked IBCSO bathymetry layer for Antarctic plotting.

    Reads the raster at `self.config['pygmt_dict']['P_IBCSO_bed']`, takes
    `band_data` at `band=0`, keeps only negative elevations (seafloor), and sets
    everything else (>= 0, i.e. land / sea level) to NaN. The result is written
    as variable `bath` to `self.config['pygmt_dict']['P_IBCSO_bath']`.

    Returns
    -------
    None

    Notes
    -----
    - The input file must contain a variable `band_data` with a `band` dimension.
    - Attributes and encodings of the output variable and dataset are cleared,
      so the file is written with xarray's default encoding.
    - The source dataset is closed once the output has been written.
    - Typically run once; `load_IBCSO_bath()` reads the result.

    See Also
    --------
    - self.load_IBCSO_bath : Loads the pre-saved masked bathymetry layer.
    - https://www.ibcso.org/ : International Bathymetric Chart of the Southern Ocean.
    """
    paths = self.config['pygmt_dict']
    # Open as a context manager so the source file handle is released. The write
    # must happen inside the block because xarray reads the data lazily.
    with xr.open_dataset(paths['P_IBCSO_bed']) as ds:
        bed = ds.band_data.isel(band=0)
        bed_masked = bed.where(bed < 0)
        bed_masked.name = "bath"
        bed_masked.attrs = {}
        # Drop any variable-level encoding too, not just the dataset-level one.
        bed_masked.encoding = {}
        ds_out = bed_masked.to_dataset()
        ds_out.attrs = {}
        ds_out.encoding = {}
        ds_out.to_netcdf(paths["P_IBCSO_bath"])
[docs]
def load_IBCSO_bath(self):
    """
    Return the `bath` variable written by `create_IBCSO_bath()`.

    The file path is taken from `self.config['pygmt_dict']["P_IBCSO_bath"]`.

    Returns
    -------
    xarray.DataArray
        Bathymetry with negative values over the seafloor and NaN elsewhere
        (as produced by `create_IBCSO_bath()`).

    Notes
    -----
    - `create_IBCSO_bath()` must have been run first.
    - The dataset is opened with `xarray.open_dataset` (lazy by default) and is
      not explicitly closed here.

    See Also
    --------
    - self.create_IBCSO_bath : Creates this file if it doesn't exist.
    - xarray.open_dataset : Loads the bathymetry layer.
    """
    bath_path = self.config['pygmt_dict']["P_IBCSO_bath"]
    bath_ds = xr.open_dataset(bath_path)
    return bath_ds["bath"]
[docs]
def create_cbar_frame(self, series, label, units=None, extend_cbar=False, max_ann_steps=10):
    """
    Construct a GMT-style colorbar annotation frame string for PyGMT plotting.

    The annotation interval is snapped to a "nice" 1/2/5/10 x 10**k value so that
    roughly `max_ann_steps` intervals span the data range. Minor ticks are placed
    at one fifth of the annotation interval.

    Parameters
    ----------
    series : sequence of float
        Colourbar range; only the first two elements (vmin, vmax) are used.
        Reversed ranges (vmin > vmax) are accepted.
    label : str
        Label text appended with `+l` (e.g. "Fast Ice Persistence").
    units : str, optional
        If given, a second frame entry `"y+l <units>"` is returned.
    extend_cbar : bool, default False
        If True, the literal modifier `"+e"` is appended to the frame string.
        NOTE(review): GMT colourbar extension triangles are normally requested
        via the `position` (-D) `+e` modifier, and `+e` inside `-B` may be
        interpreted differently. Confirm it has the intended effect here.
    max_ann_steps : int, default 10
        Target number of annotation intervals across the range; must be > 0.

    Returns
    -------
    str or list of str
        Frame string such as "a0.1f0.02+lLabel", or `[frame, "y+l <units>"]`
        when `units` is provided.

    Raises
    ------
    ValueError
        If `max_ann_steps` is not positive, or the range is zero or non-finite.

    Notes
    -----
    - Numbers are printed with at least three decimals (trailing zeros
      stripped), and more when the step is smaller, so tiny intervals never
      collapse to "0".
    - Compatible with `pygmt.Figure.colorbar(frame=...)`.

    Examples
    --------
    >>> self.create_cbar_frame([0.01, 1.0], "Persistence", units="1/100")
    ['a0.1f0.02+lPersistence', 'y+l 1/100']
    """
    vmin, vmax = series[0], series[1]
    if max_ann_steps <= 0:
        raise ValueError(f"max_ann_steps must be positive; got {max_ann_steps!r}")
    vrange = abs(vmax - vmin)
    if not np.isfinite(vrange) or vrange == 0:
        raise ValueError(
            f"series must span a finite, non-zero range; got vmin={vmin!r}, vmax={vmax!r}"
        )
    raw_step = vrange / max_ann_steps
    exp = np.floor(np.log10(raw_step))
    base = 10 ** exp
    mult = raw_step / base
    # Snap to a "nice" multiple of the decade.
    if mult < 1.5:
        ann_step = base * 1
    elif mult < 3:
        ann_step = base * 2
    elif mult < 7:
        ann_step = base * 5
    else:
        ann_step = base * 10
    tick_step = ann_step / 5
    # The minor tick (ann_step / 5) needs one decimal more than the decade 10**exp.
    # Keep at least three decimals so coarser ranges format as before.
    ndp = max(3, int(1 - exp))
    ann_str = f"{ann_step:.{ndp}f}".rstrip("0").rstrip(".")
    tick_str = f"{tick_step:.{ndp}f}".rstrip("0").rstrip(".")
    frame = f"a{ann_str}f{tick_str}+l{label}"
    if extend_cbar:
        frame += "+e"
    if units is not None:
        return [frame, f"y+l {units}"]
    return frame
[docs]
def get_meridian_center_from_geographic_extent(self, geographic_extent):
    """
    Determine the central meridian for PyGMT polar stereographic projections.

    The extent is read as the arc going EAST from `min_lon` to `max_lon`. The
    returned meridian is the midpoint of that arc, so dateline-crossing extents
    (e.g. [170, -170]) and 0-360 extents (e.g. [300, 350]) are handled the same
    way as ordinary -180..180 extents.

    Parameters
    ----------
    geographic_extent : sequence of float
        Geographic region as [min_lon, max_lon, min_lat, max_lat]. Only the two
        longitude entries are used; they may be in [-180, 180] or [0, 360].

    Returns
    -------
    float
        Central meridian in the range (-180, 180].

    Notes
    -----
    - For a full-circle extent (e.g. [-180, 180] or [0, 360]) or a zero-width
      extent, the midpoint is ambiguous; the western bound (normalised) is
      returned, e.g. 180 for [-180, 180] and 0 for [0, 360].
    - The result is also stored in `self.plot_meridian_center`.
    - Intended for PyGMT projections such as `'S{lon}/-90/30c'`.

    See Also
    --------
    - https://docs.generic-mapping-tools.org/latest/cookbook/proj.html
    - pygmt.Figure.basemap : For setting map projections using meridian centers.
    """
    west = float(geographic_extent[0])
    east = float(geographic_extent[1])
    # Eastward angular width in [0, 360). The modulo makes dateline crossings
    # (west > east) and mixed conventions work without special cases.
    span = (east - west) % 360.0
    if span == 0.0:
        # Full circle or degenerate extent: keep the western bound.
        center = west
    else:
        center = west + span / 2.0
    center %= 360.0
    if center > 180.0:
        center -= 360.0
    self.plot_meridian_center = center
    return center
[docs]
def generate_regional_annotation_stats(self, da,
                                       region,
                                       lon_coord_name,
                                       lat_coord_name,
                                       var_name,
                                       *, aice_da=None, u_da=None, v_da=None,
                                       area_unit="1e6km2",  # "km2" or "1e6km2"
                                       vol_unit="1e3km3",   # "km3" or "1e3km3"
                                       vel_unit="cm/s",
                                       loc_ndp=3):          # decimals for lat/lon in locations
    """
    Build short text lines summarising a field inside a lon/lat box, for map annotations.

    Parameters
    ----------
    da : xarray.DataArray
        Field to summarise. It must carry `lon_coord_name` / `lat_coord_name`
        coordinates. Statistics stack over dims `("nj", "ni")`, so those dims
        are assumed to exist.
    region : sequence of 4 float
        `(lon0, lon1, lat0, lat1)`. If `lon0 > lon1`, the box is treated as
        crossing the dateline (`lon >= lon0 OR lon <= lon1`).
    lon_coord_name, lat_coord_name : str
        Names of the longitude / latitude coordinates on `da`.
    var_name : str
        Selects which statistics are produced:
        - "aice": SIE, SIA, and mean/std of aice over cells with aice > `self.icon_thresh`.
        - "hi": with `aice_da`, SIE, SIA, SIV, mean thickness (SIV/SIA), and the
          median, P90 and max of hi over ice cells. Without `aice_da`, it falls
          through to the generic statistics.
        - "ispd": requires `aice_da`; SIE, SIA, and mean/median/P90/max speed over
          ice cells with finite speed.
        - anything else: mean, std, cell count, and min/max with locations.
    aice_da : xarray.DataArray, optional
        Concentration field used to define ice cells for "hi" and "ispd".
    u_da, v_da : optional
        Currently unused.
    area_unit : {"km2", "1e6km2"}, default "1e6km2"
        Display unit for areas; any other value reports m^2.
    vol_unit : {"km3", "1e3km3"}, default "1e3km3"
        Display unit for volumes; any other value reports m^3.
    vel_unit : str, default "cm/s"
        Currently unused: ispd values are reported as-is with an "m/s" label.
    loc_ndp : int, default 3
        Decimal places for the (lat, lon) location of extrema.

    Returns
    -------
    list of str
        Annotation lines. A single explanatory line is returned when no valid
        coordinates or data are available.

    Notes
    -----
    - Cell areas come from `self.G_t["area"]`; `self.load_cice_grid()` is called
      first if `self.G_t` is missing or lacks "area". The unit conversions treat
      these areas as m^2.
    - Ice cells are cells with aice > `self.icon_thresh`.
    - NOTE(review): in the "aice" branch, SIA sums aice*area over all region
      cells, while the "hi" and "ispd" branches restrict SIA to ice cells.
      Confirm which definition is intended.
    """
    def _area_scale(area_m2):
        if area_unit == "km2":
            return area_m2 / 1e6, "km^2"
        elif area_unit == "1e6km2":
            return area_m2 / 1e12, "10^6 km^2"
        else:
            return area_m2, "m^2"

    def _vol_scale(vol_m3):
        if vol_unit == "km3":
            return vol_m3 / 1e9, "km^3"
        elif vol_unit == "1e3km3":
            return vol_m3 / 1e12, "10^3 km^3"
        else:
            return vol_m3, "m^3"

    def _fmt_loc(latv, lonv):
        return f"({latv:.{loc_ndp}f}, {lonv:.{loc_ndp}f})"

    lon = da[lon_coord_name]
    lat = da[lat_coord_name]
    # Basic coordinate sanity
    if (not np.isfinite(lon).any()) or (not np.isfinite(lat).any()):
        return ["(no valid lon/lat coords)"]
    lon0, lon1, lat0, lat1 = region
    # Dateline-safe lon mask
    if lon0 <= lon1:
        m_lon = (lon >= lon0) & (lon <= lon1)
    else:
        m_lon = (lon >= lon0) | (lon <= lon1)
    m_reg = m_lon & (lat >= lat0) & (lat <= lat1)
    # Ensure grid area available
    if not hasattr(self, "G_t") or ("area" not in self.G_t):
        self.load_cice_grid()
    area = self.G_t["area"]
    # Align core fields (inner join keeps only shared coordinate labels)
    da2, lon2, lat2, area2, m_reg2 = xr.align(da, lon, lat, area, m_reg, join="inner")

    # -------------------------
    # aice special case
    # -------------------------
    if var_name == "aice":
        da_reg = da2.where(m_reg2)
        thr = float(self.icon_thresh)
        ice_cells = da_reg > thr
        if int(ice_cells.sum().values) == 0:
            sie_val, sie_unit = _area_scale(0.0)
            sia_val, sia_unit = _area_scale(0.0)
            return [f"SIE: {sie_val:.2f} {sie_unit}",
                    f"SIA: {sia_val:.2f} {sia_unit}",
                    "Mean aice (ice cells): NaN",
                    "Std aice (ice cells): NaN",
                    f"Cells (aice>{thr:g}): 0"]
        sie_m2 = area2.where(ice_cells).sum(skipna=True)
        sia_m2 = (da_reg.clip(0, 1) * area2).sum(skipna=True)
        aice_mean = da_reg.where(ice_cells).mean(skipna=True)
        aice_std = da_reg.where(ice_cells).std(skipna=True)
        n_cells = int(ice_cells.sum().values)
        sie_val, sie_unit = _area_scale(float(sie_m2.values))
        sia_val, sia_unit = _area_scale(float(sia_m2.values))
        return [
            f"SIE: {sie_val:.2f} {sie_unit}",
            f"SIA: {sia_val:.2f} {sia_unit}",
            f"Mean aice (ice cells): {float(aice_mean.values):.2f}",
            f"Std aice (ice cells): {float(aice_std.values):.2f}",
            f"Cells (aice>{float(self.icon_thresh):g}): {n_cells}",
        ]

    # -------------------------
    # hi special case (only when aice_da is given; otherwise falls through)
    # -------------------------
    if var_name == "hi":
        hi_reg = da2.where(m_reg2)
        # Prefer defining "ice cells" using aice>icon_thresh for consistency with SIE/SIA/SIV
        if aice_da is not None:
            aice2, = xr.align(aice_da, hi_reg, join="inner")[:1]  # align to same grid/time
            aice_reg = aice2.where(m_reg2)
            thr = float(self.icon_thresh)
            ice_cells = aice_reg > thr
            if int(ice_cells.sum().values) == 0:
                sie_val, sie_unit = _area_scale(0.0)
                sia_val, sia_unit = _area_scale(0.0)
                siv_val, siv_unit = _vol_scale(0.0)
                return [f"SIE: {sie_val:.2f} {sie_unit}",
                        f"SIA: {sia_val:.2f} {sia_unit}",
                        f"SIV: {siv_val:.2f} {siv_unit}",
                        "Mean SIT (SIV/SIA): NaN",
                        "Median hi (ice cells): NaN",
                        "P90 hi (ice cells): NaN",
                        "Max hi: NaN at (NaN, NaN)"]
            sie_m2 = area2.where(ice_cells).sum(skipna=True)
            sia_m2 = (aice_reg.clip(0, 1) * area2).where(ice_cells).sum(skipna=True)
            # Ice volume: hi * aice * area
            siv_m3 = (hi_reg.clip(min=0) * aice_reg.clip(0, 1) * area2).where(ice_cells).sum(skipna=True)
            sie_val, sie_unit = _area_scale(float(sie_m2.values))
            sia_val, sia_unit = _area_scale(float(sia_m2.values))
            siv_val, siv_unit = _vol_scale(float(siv_m3.values))
            mean_sit = float((siv_m3 / (sia_m2 * 1.0)).values)  # m (siv in m^3, sia in m^2)
            # Distribution over ice cells (unweighted)
            v = hi_reg.where(ice_cells).stack(z=("nj", "ni")).dropna("z")
            if v.size == 0:
                # aice marks ice cells but hi is missing there; idxmax would
                # raise on an empty sequence.
                return [
                    f"SIE: {sie_val:.2f} {sie_unit}",
                    f"SIA: {sia_val:.2f} {sia_unit}",
                    f"SIV: {siv_val:.2f} {siv_unit}",
                    f"Mean SIT (SIV/SIA): {mean_sit:.2f} m",
                    "Median hi (ice cells): NaN",
                    "P90 hi (ice cells): NaN",
                    "Max hi: NaN at (NaN, NaN)",
                ]
            med = float(v.quantile(0.50).values)
            p90 = float(v.quantile(0.90).values)
            # Max thickness + location
            zmax = v.idxmax("z").values
            lon_s = lon2.stack(z=("nj", "ni"))
            lat_s = lat2.stack(z=("nj", "ni"))
            vmax = float(v.sel(z=zmax).values)
            max_loc = (float(lat_s.sel(z=zmax).values), float(lon_s.sel(z=zmax).values))
            return [
                f"SIE: {sie_val:.2f} {sie_unit}",
                f"SIA: {sia_val:.2f} {sia_unit}",
                f"SIV: {siv_val:.2f} {siv_unit}",
                f"Mean SIT (SIV/SIA): {mean_sit:.2f} m",
                f"Median hi (ice cells): {med:.2f} m",
                f"P90 hi (ice cells): {p90:.2f} m",
                f"Max hi: {vmax:.2f} m at {_fmt_loc(*max_loc)}",
            ]

    # -------------------------
    # ice speed magnitude (ispd)
    # -------------------------
    if var_name == "ispd":
        if aice_da is None:
            return ["(need aice_da for ispd stats: mask ice cells)"]
        # Reuse lon/lat, the region mask and the cell area computed above.
        # Align da (speed), aice, area, and region mask.
        da2, a2, area2, m2 = xr.align(da, aice_da, area, m_reg, join="inner")
        da_reg = da2.where(m2)
        a_reg = a2.where(m2)
        thr = float(self.icon_thresh)
        ice_cells = (a_reg > thr) & np.isfinite(da_reg)
        n_cells = int(ice_cells.sum().values)
        if n_cells == 0:
            return [f"Cells (aice>{thr:g}): 0"]
        # area metrics
        sie_m2 = area2.where(ice_cells).sum(skipna=True)
        sia_m2 = (a_reg.clip(0, 1) * area2).where(ice_cells).sum(skipna=True)
        sie_val, sie_unit = _area_scale(float(sie_m2.values))
        sia_val, sia_unit = _area_scale(float(sia_m2.values))
        # Speed statistics. NOTE(review): vel_unit is not applied; values are
        # reported as-is with an "m/s" label.
        mean_spd = float(da_reg.where(ice_cells).mean(skipna=True).values)
        med_spd = float(da_reg.where(ice_cells).quantile(0.5, skipna=True).values)
        p90_spd = float(da_reg.where(ice_cells).quantile(0.9, skipna=True).values)
        # max speed + location
        v = da_reg.where(ice_cells).stack(z=("nj", "ni")).dropna("z")
        zmax = v.idxmax("z").values
        lon_s = lon.where(ice_cells).stack(z=("nj", "ni")).dropna("z")
        lat_s = lat.where(ice_cells).stack(z=("nj", "ni")).dropna("z")
        vmax = float(v.sel(z=zmax).values)
        max_loc = (float(lat_s.sel(z=zmax).values), float(lon_s.sel(z=zmax).values))
        return [
            f"SIE: {sie_val:.2f} {sie_unit}",
            f"SIA: {sia_val:.2f} {sia_unit}",
            f"Mean speed (ice): {mean_spd:.2f} m/s",
            f"Median speed (ice): {med_spd:.2f} m/s",
            f"P90 speed (ice): {p90_spd:.2f} m/s",
            f"Max speed (ice): {vmax:.2f} m/s at {_fmt_loc(*max_loc)}",
            f"Cells (aice>{thr:g}): {n_cells}",
        ]

    # -------------------------
    # Default (other vars): mean/std/count and min/max with locations
    # -------------------------
    da_reg = da2.where(m_reg2)
    v = da_reg.stack(z=("nj", "ni")).dropna("z")
    if v.size == 0:
        return ["(no valid data in region)"]
    mean_v = float(v.mean().values)
    std_v = float(v.std().values)
    n = int(v.size)
    zmin = v.idxmin("z").values
    zmax = v.idxmax("z").values
    lon_s = lon2.stack(z=("nj", "ni"))
    lat_s = lat2.stack(z=("nj", "ni"))
    vmin = float(v.sel(z=zmin).values)
    vmax = float(v.sel(z=zmax).values)
    min_loc = (float(lat_s.sel(z=zmin).values), float(lon_s.sel(z=zmin).values))
    max_loc = (float(lat_s.sel(z=zmax).values), float(lon_s.sel(z=zmax).values))
    return [
        f"Mean: {mean_v:.2f}",
        f"Std: {std_v:.2f}",
        f"Cells: {n}",
        f"Min: {vmin:.2f} at {_fmt_loc(*min_loc)}",
        f"Max: {vmax:.2f} at {_fmt_loc(*max_loc)}",
    ]
def _hex_to_rgb(self, hexstr: str) -> tuple[int, int, int]:
    """
    Convert a hex colour string to an `(r, g, b)` tuple of ints in [0, 255].

    Surrounding whitespace and a leading `#` are ignored. Three-digit shorthand
    (`"#FA0"`) is expanded to six digits (`"#FFAA00"`). For longer strings only
    the first six hex digits are used, so an alpha suffix in `"#RRGGBBAA"` is
    ignored.

    Raises
    ------
    ValueError
        If the string has fewer than six hex digits (and is not 3-digit
        shorthand), or contains non-hex characters.
    """
    h = hexstr.strip().lstrip("#")
    if len(h) == 3:
        h = "".join(ch * 2 for ch in h)
    if len(h) < 6:
        raise ValueError(f"Expected a hex colour like '#RRGGBB' or '#RGB'; got {hexstr!r}")
    return (int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16))
[docs]
def write_tricolour_cpt(
        self,
        P_cpt,
        *,
        vmin: float = -1.0,
        vmid: float = 0.0,
        vmax: float = 1.0,
        cmin: str = "#2CA25F",  # green
        cmid: str = "#FDAE61",  # orange
        cmax: str = "#2171B5",  # blue
        background_rgb: str = "255/255/255",
        foreground_rgb: str = "0/0/0",
        nan_rgb: str = "255/255/255") -> str:
    """
    Write a continuous three-colour GMT CPT with two linear segments.

    Produces:
    - a segment from `vmin` (colour `cmin`) to `vmid` (colour `cmid`);
    - a segment from `vmid` (colour `cmid`) to `vmax` (colour `cmax`);
    - `B`, `F` and `N` records for background, foreground and NaN colours.

    Parameters
    ----------
    P_cpt : str or pathlib.Path
        Destination file. Missing parent directories are created.
    vmin, vmid, vmax : float, optional
        Segment breakpoints, written with Python's default float formatting.
    cmin, cmid, cmax : str, optional
        Hex colours at the breakpoints, converted with `self._hex_to_rgb()`.
    background_rgb, foreground_rgb, nan_rgb : str, optional
        "R/G/B" strings written verbatim into the `B`, `F` and `N` records.

    Returns
    -------
    str
        The path of the written CPT file, as a string.
    """
    def _rgb_triplet(hex_colour):
        red, green, blue = self._hex_to_rgb(hex_colour)
        return f"{red}/{green}/{blue}"

    rgb_lo = _rgb_triplet(cmin)
    rgb_mid = _rgb_triplet(cmid)
    rgb_hi = _rgb_triplet(cmax)

    records = [
        f"{vmin} {rgb_lo} {vmid} {rgb_mid}",
        f"{vmid} {rgb_mid} {vmax} {rgb_hi}",
        f"B {background_rgb}",
        f"F {foreground_rgb}",
        f"N {nan_rgb}",
    ]

    out_path = Path(P_cpt)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text("\n".join(records) + "\n")
    return str(out_path)
[docs]
def tricolor_cpt_and_alpha(
        self,
        P_cpt,
        obs,
        mod,
        *,
        # CPT
        vmin: float = -1.0,
        vmid: float = 0.0,
        vmax: float = 1.0,
        cmin: str = "#2CA25F",
        cmid: str = "#FDAE61",
        cmax: str = "#2171B5",
        # difference sign
        diff_mode: str = "obs-mod",  # recommended for green=model-dom, blue=obs-dom
        # alpha rule
        alpha_mode: str = "max",  # "max" or "mean"
        alpha_power: float = 1.0,
        alpha_floor: float = 0.0,
        nan_transparency: float = 100.0):  # percent (0 opaque, 100 transparent)
    """
    Build a tricolour CPT plus a transparency (alpha) layer from obs/model fields.

    Intended for "difference + confidence" maps:
    - colour encodes the signed difference (obs - model or model - obs),
      clipped to [vmin, vmax];
    - transparency encodes signal strength, derived from the obs/model magnitudes.

    Parameters
    ----------
    P_cpt : str or pathlib.Path
        Output path for the CPT written by `self.write_tricolour_cpt()`.
    obs, mod : array-like
        Observed and modelled fields on a common grid; converted to float
        numpy arrays and must be broadcast-compatible.
    vmin, vmid, vmax : float, optional
        Colour-scale limits and midpoint; `z` is clipped to [vmin, vmax].
    cmin, cmid, cmax : str, optional
        Hex colours at vmin / vmid / vmax.
    diff_mode : {"obs-mod", "mod-obs"}, optional
        Sign convention (the aliases "obs_minus_mod" / "mod_minus_obs" are also
        accepted; matching is case-insensitive).
    alpha_mode : {"max", "mean"}, optional
        Strength `a = max(obs, mod)` or `a = 0.5 * (obs + mod)`, then clipped to [0, 1].
    alpha_power : float, optional
        Exponent applied to `a` after clipping (> 1 sharpens, < 1 broadens).
    alpha_floor : float, optional
        If > 0, rescales `a <- alpha_floor + (1 - alpha_floor) * a`.
    nan_transparency : float, optional
        Transparency used wherever `z` or `t` is non-finite (default 100).

    Returns
    -------
    (str, np.ndarray, np.ndarray)
        `cpt_path`, `z` (clipped difference, NaN where invalid) and `t`
        (transparency percent, `(1 - a) * 100`, 0 = opaque).

    Raises
    ------
    ValueError
        If `diff_mode` or `alpha_mode` is not recognised. Both are validated
        before the CPT file is written, so an invalid call writes nothing.

    Notes
    -----
    - No regridding or alignment is performed; inputs must already be comparable.
    """
    obs = np.asarray(obs, dtype=float)
    mod = np.asarray(mod, dtype=float)

    dm = diff_mode.lower()
    if dm in ("obs-mod", "obs_minus_mod"):
        z = obs - mod
    elif dm in ("mod-obs", "mod_minus_obs"):
        z = mod - obs
    else:
        raise ValueError("diff_mode must be 'obs-mod' or 'mod-obs'.")
    z = np.clip(z, vmin, vmax)

    am = alpha_mode.lower()
    if am == "max":
        a = np.maximum(obs, mod)
    elif am == "mean":
        a = 0.5 * (obs + mod)
    else:
        raise ValueError("alpha_mode must be 'max' or 'mean'.")
    a = np.clip(a, 0.0, 1.0)
    if alpha_power != 1.0:
        a = a ** alpha_power
    if alpha_floor > 0.0:
        a = alpha_floor + (1.0 - alpha_floor) * a
    t = (1.0 - a) * 100.0

    # Invalid pixels: fully hide (by default) and blank the colour value.
    bad = ~np.isfinite(z) | ~np.isfinite(t)
    if np.any(bad):
        t = np.where(bad, nan_transparency, t)
        z = np.where(bad, np.nan, z)

    # Write the CPT only after all arguments have been validated.
    cpt_path = self.write_tricolour_cpt(
        P_cpt, vmin=vmin, vmid=vmid, vmax=vmax, cmin=cmin, cmid=cmid, cmax=cmax
    )
    return cpt_path, z, t
[docs]
def write_tricolor_category_cpt(self, P_cpt, *,  # codes: 0=agreement, 1=model-dom, 2=obs-dom
                                int0=(0.0, 1.0),
                                int1=(1.0, 2.0),
                                int2=(2.0, 3.0),
                                hex0="#FDAE61",  # agreement = orange
                                hex1="#2CA25F",  # model-dom = green
                                hex2="#2171B5",  # obs-dom = blue
                                background_rgb="255/255/255",
                                foreground_rgb="0/0/0",
                                nan_rgb="255/255/255") -> str:
    """
    Write a three-class categorical GMT CPT (one constant colour per interval).

    Designed for integer-coded class rasters; the defaults correspond to
    0 = agreement (orange), 1 = model-dominant (green), 2 = obs-dominant (blue).

    Parameters
    ----------
    P_cpt : str or pathlib.Path
        Destination file. Missing parent directories are created.
    int0, int1, int2 : tuple[float, float], optional
        `(low, high)` interval for classes 0, 1 and 2; only the first two
        elements are used.
    hex0, hex1, hex2 : str, optional
        Hex colour per class, converted with `self._hex_to_rgb()`.
    background_rgb, foreground_rgb, nan_rgb : str, optional
        "R/G/B" strings written verbatim into the `B`, `F` and `N` records.

    Returns
    -------
    str
        The path of the written CPT file, as a string.
    """
    def _rgb_triplet(hex_colour):
        red, green, blue = self._hex_to_rgb(hex_colour)
        return f"{red}/{green}/{blue}"

    # Convert every colour first, then format the interval records.
    colours = [_rgb_triplet(h) for h in (hex0, hex1, hex2)]
    intervals = (int0, int1, int2)

    records = [
        f"{span[0]} {rgb} {span[1]} {rgb}"
        for span, rgb in zip(intervals, colours)
    ]
    records += [
        f"B {background_rgb}",
        f"F {foreground_rgb}",
        f"N {nan_rgb}",
    ]

    out_path = Path(P_cpt)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text("\n".join(records) + "\n")
    return str(out_path)
def _is_dask_array(self, x) -> bool:
    """
    Heuristically report whether `x` looks like a Dask-backed array.

    Parameters
    ----------
    x : object
        Any object.

    Returns
    -------
    bool
        True only when `x` has a `compute` attribute AND the module defining
        its type contains "dask" (case-insensitive); False otherwise.

    Notes
    -----
    - Avoids importing dask. An unrelated object could match if it defines
      `compute` and lives in a module whose name contains "dask".
    """
    if not hasattr(x, "compute"):
        return False
    defining_module = type(x).__module__
    return "dask" in defining_module.lower()
def _auto_mask_zero(self, da: xr.DataArray) -> bool:
    """
    Decide whether zero values should be masked when preparing data for plotting.

    Continuous diagnostics often contain large regions of structural zeros that
    are best hidden. Categorical rasters, however, use 0 as a valid class label,
    and masking it would erase a class.

    Parameters
    ----------
    da : xr.DataArray
        Field intended for plotting; only `da.name` and `da.attrs` are read.

    Returns
    -------
    bool
        False (keep zeros) when the field looks categorical, True otherwise.

    Decision Logic
    --------------
    The field is treated as categorical if either:
    - its name (lower-cased) contains "diff_cat"; or
    - `da.attrs["flag_values"]` parses to exactly the codes {0, 1, 2}.
      Accepted forms include the strings "0 1 2" and "0,1,2" / "0, 1, 2", and
      list/tuple/numpy-array attributes (the usual CF storage).

    Notes
    -----
    - Callers can override this explicitly via `pygmt_da_prep(mask_zero=...)`.
    """
    name = str(da.name or "").lower()
    if "diff_cat" in name:
        return False

    raw = da.attrs.get("flag_values", None)
    codes = None
    if raw is not None:
        if isinstance(raw, str):
            tokens = raw.replace(",", " ").split()
        else:
            # list / tuple / numpy array / scalar -> flat Python list
            tokens = np.atleast_1d(np.asarray(raw)).ravel().tolist()
        try:
            codes = sorted(float(tok) for tok in tokens)
        except (TypeError, ValueError):
            codes = None  # non-numeric flag_values: not a {0,1,2} encoding

    is_categorical = codes == [0.0, 1.0, 2.0]
    return not is_categorical
def _resolve_lonlat_2d(self, da2: xr.DataArray, *,
                       bcoords: bool = False,
                       tcoords: bool = True,
                       lon_coord_name: str | None = None,
                       lat_coord_name: str | None = None,
                       infer_if_missing: bool = True):
    """
    Resolve 2D longitude/latitude arrays laid out like a 2D DataArray.

    Returns `(lon2d, lat2d)` as NumPy arrays whose axes follow `da2.dims`, so
    `lon2d.shape == da2.shape` and element `[i, j]` corresponds to
    `da2.values[i, j]`. Strategies, in priority order:

    1) Explicit coordinate names (`lon_coord_name` / `lat_coord_name`):
       - 2D lon/lat coordinates are transposed to `da2.dims` if needed;
       - 1D lon/lat vectors are meshed into 2D grids. The mesh is (lon, lat)
         when lon lies along `da2.dims[0]` and lat along `da2.dims[1]`, and
         (lat, lon) otherwise.
    2) Inference (if `infer_if_missing=True`) from the first matching pair in
       `da2.coords` among ("lon","lat"), ("longitude","latitude"),
       ("TLON","TLAT"), ("ULON","ULAT"), handled as in (1). Pairs with
       unsupported/mixed dimensionality are skipped.
    3) Fallback to class grids after calling `self.load_bgrid(slice_hem=True)`:
       `self.G_u` (B-grid, `bcoords=True`) or `self.G_t` (T-grid, `tcoords=True`).
       If the grid shape differs from `da2.shape`, it is subset with the integer
       index coordinates `da2["nj"]` and `da2["ni"]`.

    Parameters
    ----------
    da2 : xr.DataArray
        2D DataArray whose spatial coordinates are required.
    bcoords : bool, default False
        Use B-grid coordinates (`self.G_u["lon"]`, `self.G_u["lat"]`) in the fallback.
    tcoords : bool, default True
        Use T-grid coordinates (`self.G_t["lon"]`, `self.G_t["lat"]`) in the fallback.
    lon_coord_name, lat_coord_name : str or None, optional
        Explicit coordinate names on `da2`; both must be given together.
    infer_if_missing : bool, default True
        Whether to try the common coordinate-name pairs listed above.

    Returns
    -------
    (np.ndarray, np.ndarray)
        lon2d, lat2d.

    Raises
    ------
    ValueError
        If `da2` is not 2D; if only one explicit coordinate name is given; if the
        explicit coordinates have mixed/unsupported dimensionality; if the
        fallback is needed and `bcoords`/`tcoords` are both True or both False;
        or if the grid cannot be matched/subset to `da2.shape`.

    Notes
    -----
    - Dask-backed coordinates are materialised by `np.asarray`.
    """
    if da2.ndim != 2:
        raise ValueError(f"Expected 2D DataArray; got dims={da2.dims}, shape={da2.shape}")

    def _pair_to_2d(lon_da, lat_da):
        """Return (lon2d, lat2d) laid out like da2.dims, or None if unsupported."""
        if lon_da.ndim == 2 and lat_da.ndim == 2:
            # A 2D coordinate of a 2D DataArray spans the same two dims. Transpose
            # the COORDINATES (not the data) so the result follows the caller's layout.
            lon_da = lon_da.transpose(*da2.dims)
            lat_da = lat_da.transpose(*da2.dims)
            return np.asarray(lon_da.data), np.asarray(lat_da.data)
        if lon_da.ndim == 1 and lat_da.ndim == 1:
            lon_v = np.asarray(lon_da.data)
            lat_v = np.asarray(lat_da.data)
            # "ij" gives shape (len(lon), len(lat)) for data laid out (lon, lat);
            # "xy" gives (len(lat), len(lon)) for the conventional (lat, lon) layout.
            lon_first = lon_da.dims[0] == da2.dims[0] and lat_da.dims[0] == da2.dims[1]
            lon2d, lat2d = np.meshgrid(lon_v, lat_v, indexing="ij" if lon_first else "xy")
            return lon2d, lat2d
        return None

    # ---- 1) Explicit coords ----
    if (lon_coord_name is not None) or (lat_coord_name is not None):
        if lon_coord_name is None or lat_coord_name is None:
            raise ValueError("Provide both lon_coord_name and lat_coord_name.")
        lon_da = da2[lon_coord_name]
        lat_da = da2[lat_coord_name]
        resolved = _pair_to_2d(lon_da, lat_da)
        if resolved is None:
            raise ValueError(f"Mixed/unsupported explicit coord dims: lon={lon_da.ndim}, lat={lat_da.ndim}")
        return resolved

    # ---- 2) Infer coords from da2 ----
    if infer_if_missing:
        cand_pairs = [("lon", "lat"),
                      ("longitude", "latitude"),
                      ("TLON", "TLAT"),
                      ("ULON", "ULAT")]
        for xnm, ynm in cand_pairs:
            if (xnm in da2.coords) and (ynm in da2.coords):
                resolved = _pair_to_2d(da2.coords[xnm], da2.coords[ynm])
                if resolved is not None:
                    return resolved

    # ---- 3) B/T grid coords ----
    if bcoords and tcoords:
        raise ValueError("Cannot set both bcoords and tcoords True.")
    if not (bcoords or tcoords):
        raise ValueError("Must set bcoords or tcoords, or provide lon/lat coords on da.")
    # NOTE(review): load_bgrid is called on every fallback; it is assumed to
    # populate both self.G_u and self.G_t. Confirm it is cheap to repeat.
    self.load_bgrid(slice_hem=True)
    if bcoords:
        lon2d_full = np.asarray(self.G_u["lon"].values)
        lat2d_full = np.asarray(self.G_u["lat"].values)
    else:
        lon2d_full = np.asarray(self.G_t["lon"].values)
        lat2d_full = np.asarray(self.G_t["lat"].values)
    if lon2d_full.shape == da2.shape:
        return lon2d_full, lat2d_full
    # subset lon/lat to match da2 using nj/ni index coords
    if ("nj" in da2.coords) and ("ni" in da2.coords):
        nj_idx = np.asarray(da2["nj"].values, dtype=int)
        ni_idx = np.asarray(da2["ni"].values, dtype=int)
        return lon2d_full[np.ix_(nj_idx, ni_idx)], lat2d_full[np.ix_(nj_idx, ni_idx)]
    raise ValueError(f"Grid lon/lat shape {lon2d_full.shape} does not match da shape {da2.shape}, "
                     "and da has no nj/ni coords for subsetting.")
[docs]
def pygmt_da_prep(self, da: xr.DataArray, *,
                  bcoords           : bool = False,
                  tcoords           : bool = True,
                  lon_coord_name    : str | None = None,
                  lat_coord_name    : str | None = None,
                  region            : tuple[float, float, float, float] | None = None,
                  lon_wrap          : str = "auto",  # "auto" | "0-360" | "-180-180"
                  extra_mask        : xr.DataArray | np.ndarray | None = None,  # or a callable(da2) -> mask
                  mask_zero         : bool | None = None,
                  z_clip            : tuple[float, float] | None = None,
                  z_range_mask      : tuple[float, float] | None = None,
                  dtype             : str = "float32",
                  infer_coords      : bool = True,
                  return_mask       : bool = True,
                  return_flat_index : bool = True):
    """
    Prepare a 2D DataArray for PyGMT plotting by returning 1D lon/lat/z vectors.

    This routine standardises a gridded field into a structure that PyGMT commonly
    expects for point plotting or scattered gridding. It also constructs masks for:
    - finite values,
    - optional masking of structural zeros,
    - optional clipping/masking by z-range,
    - optional geographic subsetting by `region` (dateline-safe),
    - optional additional user-provided mask.

    Parameters
    ----------
    da : xr.DataArray
        Input data field. It may contain singleton dimensions; these will be squeezed.
        After squeeze, the data must be 2D.
    bcoords : bool, default False
        Use B-grid coordinates if lon/lat are not directly available on `da`.
    tcoords : bool, default True
        Use T-grid coordinates if lon/lat are not directly available on `da`.
    lon_coord_name, lat_coord_name : str or None, optional
        Explicit coordinate names to use for longitude and latitude in `da`.
        If provided, both must be provided.
    region : tuple[float, float, float, float], optional
        Geographic bounding box `(xmin, xmax, ymin, ymax)` applied as a mask.
        This masking is dateline-safe: if `xmin > xmax`, the region is interpreted
        as crossing the dateline and uses `(lon >= xmin) OR (lon <= xmax)`.
    lon_wrap : {"auto", "0-360", "-180-180"}, default "auto"
        Longitude convention to enforce on the resolved lon grid. If "auto" and
        `region` is provided, the convention is chosen from the region bounds
        (any negative bound -> "-180-180"; non-negative with xmax > 180 -> "0-360");
        otherwise longitudes are left unchanged.
    extra_mask : array-like, xr.DataArray, callable, optional
        Additional mask to apply. If callable, it is invoked as `extra_mask(da2)`.
        An xr.DataArray mask is squeezed and, when it carries the same dimension
        names as the data, transposed to the data's dimension order before use.
        Any other array-like is converted to a boolean array as-is. The resulting
        mask must have the same shape as the 2D data.
    mask_zero : bool or None, optional
        Whether to mask zeros (treat as missing) before flattening.
        - If None, uses `_auto_mask_zero(da2)` to decide.
        - If True, masks values close to zero (|z| <= 1e-8).
        - If False, preserves zeros.
    z_clip : tuple[float, float], optional
        If provided, clip z values into `[z_clip[0], z_clip[1]]` prior to masking.
        This is a hard clip (values outside become boundary values).
    z_range_mask : tuple[float, float], optional
        If provided, apply an additional mask retaining only values within `[lo, hi]`.
    dtype : str, default "float32"
        Target dtype used when materialising arrays (especially helpful for memory).
    infer_coords : bool, default True
        If True, allow `_resolve_lonlat_2d()` to infer lon/lat from common coordinate
        names on `da` when explicit names are not provided.
    return_mask : bool, default True
        If True, include `mask2d` (boolean 2D mask) in the returned dict.
    return_flat_index : bool, default True
        If True, include `flat_idx` (indices into `z2d.ravel()` where mask is True)
        in the returned dict.

    Returns
    -------
    dict
        Always includes:
        - "lon"   : 1D np.ndarray of longitudes (masked + flattened)
        - "lat"   : 1D np.ndarray of latitudes (masked + flattened)
        - "z"     : 1D np.ndarray of values (masked + flattened)
        - "shape" : tuple of original 2D shape for sanity checks
        Optionally includes:
        - "mask2d"   : 2D boolean mask (if `return_mask=True`)
        - "flat_idx" : 1D integer indices into `z2d.ravel()` (if `return_flat_index=True`)

    Raises
    ------
    ValueError
        If `da` cannot be reduced to a 2D field (after squeeze).
    ValueError
        If the resolved lon/lat grids do not match the data shape, coordinate
        resolution fails, or `extra_mask` has a shape mismatch.

    Notes
    -----
    - If `da` is Dask-backed, the data are computed only once, and coerced to `dtype`.
    - If `da` has dims ("nj","ni") in a different order, the method transposes to
      ("nj","ni") to keep indexing consistent with AFIM grid conventions.
    """
    da2 = da.squeeze(drop=True)
    if da2.ndim != 2:
        raise ValueError(f"Expected 2D DataArray after squeeze; got dims={da2.dims}, shape={da2.shape}")
    # prefer canonical order if present
    if ("nj" in da2.dims) and ("ni" in da2.dims) and (da2.dims != ("nj", "ni")):
        da2 = da2.transpose("nj", "ni")
    lon2d, lat2d = self._resolve_lonlat_2d(da2,
                                           bcoords          = bcoords,
                                           tcoords          = tcoords,
                                           lon_coord_name   = lon_coord_name,
                                           lat_coord_name   = lat_coord_name,
                                           infer_if_missing = infer_coords)
    # ---- harmonise the longitude convention with the requested region ----
    if lon_wrap == "auto":
        if region is not None:
            xmin, xmax = region[0], region[1]
            if (xmin < 0.0) or (xmax < 0.0):
                # region uses negative longitudes -> compare in -180..180
                lon2d = self.normalise_longitudes(lon2d, to="-180-180")
            elif xmax > 180.0:
                # region is clearly 0..360-ish (xmin >= 0 here) -> compare in 0..360
                lon2d = self.normalise_longitudes(lon2d, to="0-360")
            # else: leave as-is
    else:
        lon2d = self.normalise_longitudes(lon2d, to=lon_wrap)
    lon2d = np.asarray(lon2d)
    lat2d = np.asarray(lat2d)
    # ---- materialise Z (single compute for Dask-backed data) ----
    if self._is_dask_array(da2.data):
        z2d = da2.astype(dtype).compute().values
    else:
        z2d = np.asarray(da2.data, dtype=dtype)
    # Boolean indexing below requires identical shapes; fail with a clear message
    # instead of numpy's opaque IndexError.
    if lon2d.shape != z2d.shape or lat2d.shape != z2d.shape:
        raise ValueError(f"Resolved lon/lat shapes {lon2d.shape}/{lat2d.shape} do not match "
                         f"data shape {z2d.shape} (dims={da2.dims})")
    if z_clip is not None:
        z2d = np.clip(z2d, z_clip[0], z_clip[1])
    # ---- base mask ----
    m = np.isfinite(z2d)
    if mask_zero is None:
        mask_zero = self._auto_mask_zero(da2)
    if mask_zero:
        m &= ~np.isclose(z2d, 0.0, atol=1e-8)
    if z_range_mask is not None:
        lo, hi = z_range_mask
        m &= (z2d >= lo) & (z2d <= hi)
    # ---- region mask (DATELINE-SAFE) ----
    if region is not None:
        xmin, xmax, ymin, ymax = region
        lat_ok = (lat2d >= ymin) & (lat2d <= ymax)
        # regions crossing the dateline are expressed with xmin > xmax
        if xmin <= xmax:
            lon_ok = (lon2d >= xmin) & (lon2d <= xmax)
        else:
            lon_ok = (lon2d >= xmin) | (lon2d <= xmax)
        m &= lon_ok & lat_ok
    # ---- user mask ----
    if extra_mask is not None:
        em = extra_mask(da2) if callable(extra_mask) else extra_mask
        if isinstance(em, xr.DataArray):
            em = em.squeeze(drop=True)
            # da2 may have been transposed above; re-order a labelled mask to match so it
            # is never applied in the wrong orientation (silently so on square grids).
            if set(em.dims) == set(da2.dims) and em.dims != da2.dims:
                em = em.transpose(*da2.dims)
            em = np.asarray(em.values, dtype=bool)
        else:
            em = np.asarray(em, dtype=bool)
        if em.shape != z2d.shape:
            raise ValueError(f"extra_mask shape mismatch: {em.shape} vs {z2d.shape}")
        m &= em
    lon1 = np.asarray(lon2d, dtype=dtype)[m].ravel()
    lat1 = np.asarray(lat2d, dtype=dtype)[m].ravel()
    z1   = np.asarray(z2d, dtype=dtype)[m].ravel()
    out = {"lon"   : lon1,
           "lat"   : lat1,
           "z"     : z1,
           "shape" : z2d.shape}
    if return_mask:
        out["mask2d"] = m
    if return_flat_index:
        out["flat_idx"] = np.flatnonzero(m.ravel())
    return out
[docs]
def pygmt_map_plot_one_var(self, da, var_name,
                           aux_ds          = None,
                           sim_name        = None,
                           plot_regions    = None,
                           regional_dict   = None,
                           hemisphere      = "south",
                           time_stamp      = None,
                           tit_str         = None,
                           plot_GI         = False,
                           cmap            = None,
                           series          = None,
                           reverse         = None,
                           cbar_label      = None,
                           cbar_units      = None,
                           extend_cbar     = False,
                           cbar_position   = None,
                           lon_name        = None,
                           lat_name        = None,
                           fig_size        = None,
                           var_sq_size     = 0.2,
                           GI_sq_size      = 0.1,
                           GI_fill_color   = "red",
                           plot_iceshelves = True,
                           plot_bathymetry = True,
                           add_stat_annot  = False,
                           land_color      = None,
                           water_color     = None,
                           P_png           = None,
                           var_out         = None,
                           overwrite_fig   = None,
                           show_fig        = None):
    """
    Generate a PyGMT figure showing a variable (e.g., FIA, FIP, differences) as a spatial map
    over Antarctic regions or hemispheric view, optionally including grounded icebergs,
    ice shelf outlines, and bathymetry.

    This flexible mapping function supports multiple types of visualizations:
    - Scalar field plots (e.g., fast ice persistence, concentration)
    - Binary or categorical masks (e.g., agreement maps, simulation masks)
    - Difference plots (e.g., observation minus simulation)

    Parameters
    ----------
    da : xarray.DataArray or xarray.Dataset
        Input 2D (after squeeze) field to plot. If a Dataset is given, `da[var_name]` is
        plotted and the Dataset also serves as the auxiliary source when `aux_ds` is None.
    var_name : str
        Name of the variable to be plotted. Used to look up color map settings and labels
        in `self.plot_var_dict`.
    aux_ds : xarray.Dataset, optional
        Auxiliary fields aligned with `da`; if it contains "aice", that field is passed to
        the statistics annotation.
    sim_name : str, optional
        Simulation name used for file naming. Defaults to `self.sim_name`.
    plot_regions : int or None, optional
        - 8    : Antarctic 8-sector view (`self.Ant_8sectors`)
        - 2    : East/West sectors (`self.Ant_2sectors`)
        - other non-None value : use `regional_dict` (required in that case)
        - None : full hemisphere plot (default)
    regional_dict : dict, optional
        Custom dictionary of plotting regions, used when `plot_regions` is neither 8, 2
        nor None.
    hemisphere : str, default="south"
        Key of `self.hemispheres_dict` plotted when `plot_regions` is None.
    time_stamp : str, optional
        Timestamp string for file naming. Defaults to `self.dt0_str`.
    tit_str : str, optional
        Title string to display on the figure.
    plot_GI : bool, default=False
        Whether to overlay grounded iceberg locations using `load_GI_lon_lats()`.
    cmap, series, reverse : optional
        Colour map, `[min, max, inc]` and reversal flag for `pygmt.makecpt`.
        Default to `self.plot_var_dict[var_name]`.
    cbar_label, cbar_units : str, optional
        Colorbar label and units. Default to `self.plot_var_dict[var_name]`.
    extend_cbar : bool, default=False
        Whether to add extension arrows to colorbar.
    cbar_position : str, optional
        PyGMT-compatible position string for placing the colorbar.
    lon_name, lat_name : str, optional
        Longitude/latitude coordinate names in `da`. Default to
        `self.pygmt_dict["lon_coord_name"/"lat_coord_name"]` or "TLON"/"TLAT".
    fig_size : float, optional
        Figure width used to format the projection strings.
    var_sq_size : float, default=0.2
        Marker size (in cm) for the plotted variable.
    GI_sq_size : float, default=0.1
        Marker size (in cm) for grounded iceberg overlay.
    GI_fill_color : str, default="red"
        Color to use for grounded iceberg markers.
    plot_iceshelves : bool, default=True
        Whether to overlay Antarctic ice shelf outlines.
    plot_bathymetry : bool, default=True
        Whether to plot IBCSO bathymetry (`grdimage`) under the data.
    add_stat_annot : bool, default=False
        Whether to annotate figure with basic regional statistics.
    land_color, water_color : str, optional
        Land/ocean colours (used when bathymetry is off). Default to `self.pygmt_dict`.
    P_png : str or pathlib.Path, optional
        Output path for the first plotted region. Later regions (and calls without a path)
        use an auto-generated path under `self.D_graph` when `self.save_fig` is True.
    var_out : str, optional
        Output variable name used in file naming. Defaults to `var_name`.
    overwrite_fig : bool, optional
        Whether to overwrite an existing figure. Defaults to `self.ow_fig`.
    show_fig : bool, optional
        Whether to display the figure. Defaults to `self.show_fig`.

    Returns
    -------
    None
        The method generates and optionally saves or displays a PyGMT figure.

    Notes
    -----
    - Plotting is skipped (with a logged warning) when `pygmt_da_prep` raises
      ValueError/KeyError, when it yields no/empty lon/lat/z vectors, or when
      `plot_regions` is invalid and no `regional_dict` is given.

    See Also
    --------
    - self.pygmt_da_prep : Prepares data dictionary for plotting
    - self.create_cbar_frame : Builds formatted colorbar strings
    - self.load_GI_lon_lats : Loads grounded iceberg locations
    - self.load_ice_shelves : Loads Antarctic ice shelf polygons
    - self.load_IBCSO_bath : Loads bathymetry grid from IBCSO

    Examples
    --------
    >>> self.pygmt_map_plot_one_var(FIP_DA, "FIP", plot_regions=8, show_fig=True)
    """
    import pygmt
    sim_name    = sim_name      if sim_name      is not None else self.sim_name
    show_fig    = show_fig      if show_fig      is not None else self.show_fig
    ow_fig      = overwrite_fig if overwrite_fig is not None else self.ow_fig
    time_stamp  = time_stamp    if time_stamp    is not None else self.dt0_str
    lon_name    = lon_name      if lon_name      is not None else self.pygmt_dict.get("lon_coord_name", "TLON")
    lat_name    = lat_name      if lat_name      is not None else self.pygmt_dict.get("lat_coord_name", "TLAT")
    cmap        = cmap          if cmap          is not None else self.plot_var_dict[var_name]['cmap']
    series      = series        if series        is not None else self.plot_var_dict[var_name]['series']
    reverse     = reverse       if reverse       is not None else self.plot_var_dict[var_name]['reverse']
    cbar_lab    = cbar_label    if cbar_label    is not None else self.plot_var_dict[var_name]['name']
    cbar_units  = cbar_units    if cbar_units    is not None else self.plot_var_dict[var_name]['units']
    fig_size    = fig_size      if fig_size      is not None else self.pygmt_dict['fig_size']
    cbar_pos    = cbar_position if cbar_position is not None else self.pygmt_dict['cbar_pos'].format(width=fig_size*0.8, height=0.75)
    land_color  = land_color    if land_color    is not None else self.pygmt_dict['land_color']
    water_color = water_color   if water_color   is not None else self.pygmt_dict['water_color']
    # Accept a Dataset directly; it doubles as the auxiliary source unless aux_ds is given.
    if isinstance(da, xr.Dataset):
        ds = da
        da = ds[var_name]
    else:
        ds = None
    aux_src = aux_ds if aux_ds is not None else ds
    aice_da = aux_src["aice"] if (aux_src is not None and "aice" in aux_src) else None
    if var_out is None:
        var_out = var_name
    # ---- prepare plotting vectors; skip gracefully when coords/data are unusable ----
    try:
        plot_data_dict = self.pygmt_da_prep(da,
                                            bcoords        = False,
                                            tcoords        = False,
                                            lon_coord_name = lon_name,
                                            lat_coord_name = lat_name)
    except (ValueError, KeyError) as e:
        self.logger.warning(f"Skipping plot due to error: {e}")
        return
    if not isinstance(plot_data_dict, dict):
        self.logger.warning("plot_data_dict is not a dictionary — skipping plot.")
        return
    for k in ('lon', 'lat', 'z'):
        if k not in plot_data_dict:
            self.logger.warning(f"Missing key '{k}' in plot_data_dict — skipping plot.")
            return
        v = plot_data_dict[k]
        if v is None:
            self.logger.warning(f"plot_data_dict['{k}'] is None — skipping plot.")
            return
        if hasattr(v, "size") and v.size == 0:
            self.logger.warning(f"plot_data_dict['{k}'] is empty — skipping plot.")
            return
    # ---- choose the set of regions to draw ----
    hem_plot = False
    if plot_regions == 8:
        self.logger.info("method will plot eight Antarctic sectors regional dictionary")
        reg_dict = self.Ant_8sectors
    elif plot_regions == 2:
        self.logger.info("method will plot two Antarctic sectors regional dictionary")
        reg_dict = self.Ant_2sectors
    elif plot_regions is not None and regional_dict is not None:
        self.logger.info("method will plot regional dictionary passed to this method")
        reg_dict = regional_dict
    elif plot_regions is not None:
        # previously fell through with reg_dict unbound (UnboundLocalError)
        self.logger.warning(f"plot_regions={plot_regions!r} is not valid and no regional_dict was given — skipping plot.")
        return
    else:
        self.logger.info("method will plot hemispheric data")
        hem_plot = True
        reg_dict = self.hemispheres_dict
    # ---- static overlays (cached on the instance across calls) ----
    if plot_iceshelves:
        if not hasattr(self, "_ANT_IS"):
            self._ANT_IS = self.load_ice_shelves()
        ANT_IS = self._ANT_IS
    if plot_bathymetry:
        if not hasattr(self, "_SO_BATH"):
            self._SO_BATH = self.load_IBCSO_bath()
        SO_BATH = self._SO_BATH
    if plot_GI:
        plot_GI_dict = self.load_GI_lon_lats()
    cbar_frame    = self.create_cbar_frame(series, cbar_lab, units=cbar_units, extend_cbar=extend_cbar)
    basemap_frame = ["af", f"+t{tit_str}"] if tit_str is not None else ["af"]
    P_user        = Path(P_png) if P_png is not None else None
    for reg_name, reg_vals in reg_dict.items():
        if hem_plot and reg_name != hemisphere:
            continue
        # A caller-supplied path is used for the first plotted region only; subsequent
        # regions fall back to the auto-generated path when saving is enabled.
        if P_user is not None:
            P_out, P_user = P_user, None
        elif self.save_fig:
            P_out = Path(self.D_graph, sim_name, reg_name, var_out, f"{time_stamp}_{sim_name}_{reg_name}_{var_out}.png")
        else:
            P_out = None
        region     = reg_vals['plot_region']
        projection = reg_vals['projection']
        if hem_plot:
            projection = projection.format(fig_size=fig_size)
        elif reg_name in self.Ant_8sectors:
            MC = self.get_meridian_center_from_geographic_extent(region)
            projection = projection.format(MC=MC, fig_size=fig_size)
        elif reg_name in self.Ant_2sectors:
            projection = projection.format(fig_width=fig_size)
        fig = pygmt.Figure()
        with pygmt.config(FONT_TITLE         = "16p,Courier-Bold",
                          FONT_ANNOT_PRIMARY = "14p,Helvetica",
                          COLOR_FOREGROUND   = 'black'):
            fig.basemap(region=region, projection=projection, frame=basemap_frame)
            if plot_bathymetry:
                fig.grdimage(grid=SO_BATH, cmap='geo')
            else:
                fig.coast(region=region, projection=projection, shorelines="1/0.5p,gray30", land=land_color, water=water_color)
            pygmt.makecpt(cmap=cmap, reverse=reverse, series=series)
            fig.plot(x=plot_data_dict['lon'], y=plot_data_dict['lat'], fill=plot_data_dict['z'], style=f"s{var_sq_size}c", cmap=True)
            if plot_bathymetry:
                fig.coast(region=region, projection=projection, shorelines="1/0.5p,gray30")
            if plot_GI:
                fig.plot(x=plot_GI_dict['lon'], y=plot_GI_dict['lat'], fill=GI_fill_color, style=f"c{GI_sq_size}c")
            if plot_iceshelves:
                fig.plot(data=ANT_IS, pen="0.2p,gray", fill="lightgray")
            if add_stat_annot:
                annot_text = self.generate_regional_annotation_stats(da, region, lon_name, lat_name, var_name,
                                                                     area_unit = "1e6km2",
                                                                     aice_da   = aice_da)  # optional; used by hi
                for j, line in enumerate(annot_text):
                    try:
                        fig.text(position="TR", text=line, font="12p,Helvetica-Bold,black", justify="LM", no_clip=True, offset=f"-1/{-0.5*j}")
                    except pygmt.exceptions.GMTCLibError as e:
                        self.logger.warning(f"Error in plotting anotation text {e} -- skipping annotation")
            fig.colorbar(position=cbar_pos, frame=cbar_frame)
        if P_out is not None:
            if (not P_out.exists()) or ow_fig:
                P_out.parent.mkdir(parents=True, exist_ok=True)
                fig.savefig(P_out)
                self.logger.info(f"Saved figure to {P_out}")
            else:
                self.logger.info(f"{P_out} already exists and not overwriting")
        if show_fig:
            fig.show()
[docs]
def pygmt_fastice_panel(self,
                        fast_ice_variable : str = "FIA",  # FIA|FIT|FIS|FIMAR|FIMVR|FITAR|FITVR (case-insensitive)
                        sim_name          : str = None,
                        roll_days         : int = 0,
                        # Generic (can be overridden)
                        fig_width         : str = None,
                        fig_height        : str = None,
                        ylim              : tuple = None,
                        frame_bndy        : str = None,
                        yaxis_pri         : str = None,
                        xaxis_pri         : str = "a1Of15Dg",
                        leg_pos           : str = None,
                        leg_box           : str = "+gwhite+p.5p",
                        spat_var_style    : str = None,
                        GI_plot_style     : str = "c0.05c",
                        GI_fill_color     : str = "#BA561A",
                        plot_GI           : bool = False,
                        min_max_trans_val : int = 80,
                        yshift_top        : str = None,
                        yshift_bot        : str = None,
                        bottom_frame_bndy : str = "WSne",
                        bottom_yaxis      : str = None,
                        bottom_xaxis      : str = None,
                        land_clr          : str = '#D1DDE0',
                        water_clr         : str = "#EDF2F5",
                        coast_pen         : str = "1/0.5p,black",
                        cbar_pos          : str = None,
                        lon_coord_name    : str = None,
                        lat_coord_name    : str = None,
                        cmap              : str = None,
                        series            : list = None,
                        cbar_frame        : str = None,
                        ANT_IS_pen        : str = "0.2p,black",
                        ANT_IS_color      : str = "#C1CED6",
                        font_annot_pri    : str = "24p,Times-Roman",
                        font_annot_sec    : str = "16p,Times-Roman",
                        font_lab          : str = "22p,Times-Bold",
                        line_pen          : str = "2p",
                        grid_pen_pri      : str = ".5p",
                        grid_pen_sec      : str = ".25p",
                        fmt_geo_map       : str = "D:mm",
                        P_png             : str = None,
                        save_fig          : bool = None,
                        overwrite_fig     : bool = None,
                        show_fig          : bool = None):
    """
    Unified PyGMT fast-ice panel plotter.

    Draws a day-of-year climatology time-series panel (min/max envelope plus mean line
    per series) on top, and the two Antarctic sectors (`self.Ant_2sectors`) of a
    spatial metric underneath, sharing one colorbar.

    Set `fast_ice_variable` to one of (case-insensitive):
    "FIA", "FIT", "FIS", "FIMAR", "FIMVR", "FITAR", "FITVR".
    E.g. "FIA" plots the FIA time series + FIP maps, "FIT" the FIT time series
    + FIHI maps. The per-variable defaults come from `self.pygmt_FI_panel[var]`,
    whose "top_name"/"bottom_name" entries select the time-series and spatial
    metrics from `self.load_computed_metrics(class_method=...)`. For "FIA" the
    AF2020 observational series is added to the top panel.

    The remaining arguments override layout/styling defaults. Notably:
    - `line_pen` is the fallback pen for series without an entry in
      `self.pygmt_FIA_dict` (entries there supply their own pen/colour/label).
    - `P_png` (str or Path) overrides the auto-generated output path under
      `self.D_graph/sim_name`.
    - `roll_days`, `bottom_frame_bndy`, `lon_coord_name` and `lat_coord_name`
      are accepted but not used by this method.

    Raises
    ------
    ValueError
        If `fast_ice_variable` is not one of the supported names.
    """
    import pygmt
    var = fast_ice_variable.lower()
    if var not in ("fia", "fit", "fis", "fimar", "fimvr", "fitar", "fitvr"):
        raise ValueError(f"`fast_ice_variable` must be one of ['FIA','FIT','FIS','FIMAR','FIMVR','FITAR','FITVR']; got {fast_ice_variable}")
    # -------------------------------------------------------------------------
    # Per-variable defaults
    # -------------------------------------------------------------------------
    cfg = self.pygmt_FI_panel[var]
    # -------------------------------------------------------------------------
    # Resolve user overrides or fall back to defaults
    # -------------------------------------------------------------------------
    sim_name       = sim_name       if sim_name       is not None else self.sim_name
    show_fig       = show_fig       if show_fig       is not None else self.show_fig
    save_fig       = save_fig       if save_fig       is not None else self.save_fig
    ow_fig         = overwrite_fig  if overwrite_fig  is not None else self.ow_fig
    frame_bndy     = frame_bndy     if frame_bndy     is not None else "WS"
    fig_width      = fig_width      if fig_width      is not None else "30c"
    fig_height     = fig_height     if fig_height     is not None else "25c"
    spat_var_style = spat_var_style if spat_var_style is not None else "s0.2c"
    yshift_top     = yshift_top     if yshift_top     is not None else "-6.25c"
    yshift_bot     = yshift_bot     if yshift_bot     is not None else "-10.5c"
    bottom_yaxis   = bottom_yaxis   if bottom_yaxis   is not None else "a5f1g"
    bottom_xaxis   = bottom_xaxis   if bottom_xaxis   is not None else "a30f10g"
    cbar_pos       = cbar_pos       if cbar_pos       is not None else "JBC+w25c/1c+mc+h"
    ylim           = ylim           if ylim           is not None else cfg["panel_ylim"]
    yaxis_pri      = yaxis_pri      if yaxis_pri      is not None else cfg["yaxis_pri"]
    leg_pos        = leg_pos        if leg_pos        is not None else cfg["leg_pos"]
    cbar_frame     = cbar_frame     if cbar_frame     is not None else cfg["cbar_frame"]
    cmap           = cmap           if cmap           is not None else cfg["cmap"]
    series         = series         if series         is not None else cfg["series"]
    # -------------------------------------------------------------------------
    # Loads
    # -------------------------------------------------------------------------
    ANT_IS      = self.load_ice_shelves()
    ice_methods = ["rolling-mean", "binary-days"]
    ts_dict     = {}
    if var == 'fia':
        ts_dict["AF2020"] = xr.open_dataset(self.AF_FI_dict['P_AF2020_FIA'])["AF2020"]
    for imeth in ice_methods:
        # load each method's metrics once and reuse for both panels
        mets = self.load_computed_metrics(class_method=imeth)
        if imeth == "binary-days":
            da_spat = mets[cfg["bottom_name"]]
        ts_dict[imeth] = mets[cfg["top_name"]]
    tmin, tmax = self.extract_min_max_dates(ts_dict)
    df_spat = self.pygmt_da_prep(da_spat)
    if plot_GI:
        plot_GI_dict = self.load_GI_lon_lats()
    # -------------------------------------------------------------------------
    # Plot
    # -------------------------------------------------------------------------
    plot_region     = [f"{self.leap_year}-01-01", f"{self.leap_year}-12-31", ylim[0], ylim[1]]
    plot_projection = f"X{fig_width}/{fig_height}"
    frame           = [frame_bndy, f"px{xaxis_pri}", f"py{yaxis_pri}"]
    fig = pygmt.Figure()
    with pygmt.config(FONT_ANNOT_PRIMARY      = font_annot_pri,
                      FONT_ANNOT_SECONDARY    = font_annot_sec,
                      FONT_LABEL              = font_lab,
                      MAP_GRID_PEN_PRIMARY    = grid_pen_pri,
                      MAP_GRID_PEN_SECONDARY  = grid_pen_sec,
                      FORMAT_GEO_MAP          = fmt_geo_map,
                      FORMAT_DATE_MAP         = "o",
                      FORMAT_TIME_PRIMARY_MAP = "Abbreviated"):
        # ---- time-series top panel ----
        fig.basemap(region=plot_region, projection=plot_projection, frame=frame)
        fia_styles = getattr(self, "pygmt_FIA_dict", None) or {}
        for k, da in ts_dict.items():
            if k in fia_styles:
                leg_lab   = fia_styles[k]["label"]
                pen_k     = fia_styles[k]["line_pen"]
                color_k   = fia_styles[k]["line_color"]
            else:
                # previously hard-coded "1.5p", which silently ignored the `line_pen` argument
                leg_lab, pen_k, color_k = k, line_pen, "black"
            clim = self.compute_doy_climatology(da)
            fig.plot(x            = np.concatenate([clim["min"].index, clim["max"].index[::-1]]),
                     y            = np.concatenate([clim["min"].values, clim["max"].values[::-1]]),
                     fill         = f"{color_k}@{min_max_trans_val}",
                     close        = True,
                     transparency = min_max_trans_val)
            fig.plot(x     = clim["mean"].index,
                     y     = clim["mean"].values,
                     pen   = f"{pen_k},{color_k}",
                     label = leg_lab)
        fig.legend(position=leg_pos, box=leg_box)
        # ---- bottom panel(s) ----
        fig.shift_origin(yshift=yshift_top)
        pygmt.makecpt(cmap=cmap, series=series)
        for i, reg_vals in enumerate(self.Ant_2sectors.values()):
            b_region     = reg_vals["plot_region"]
            b_projection = reg_vals["projection"].format(fig_width=fig_width)
            if i > 0:
                fig.shift_origin(yshift=yshift_bot)
            fig.basemap(region     = b_region,
                        projection = b_projection,
                        frame      = [f"x{bottom_xaxis}", f"y{bottom_yaxis}"])
            fig.coast(water=water_clr)
            fig.plot(x     = df_spat["lon"],
                     y     = df_spat["lat"],
                     fill  = df_spat["z"],
                     style = spat_var_style,
                     cmap  = True)
            if plot_GI:
                fig.plot(x     = plot_GI_dict["lon"],
                         y     = plot_GI_dict["lat"],
                         fill  = GI_fill_color,
                         style = GI_plot_style)
            fig.coast(land=land_clr, shorelines=coast_pen)
            fig.plot(data=ANT_IS, pen=ANT_IS_pen, fill=ANT_IS_color)
        fig.colorbar(position=cbar_pos, frame=cbar_frame)
    # ---- Save / Show ----
    if save_fig:
        if P_png is None:
            F_pre = f"{cfg['top_name']}_{cfg['bottom_name']}_{sim_name}"
            F_suf = f"ispd_thresh{self.ispd_thresh_str}_{tmin.strftime('%Y')}-{tmax.strftime('%Y')}"
            P_png = Path(self.D_graph, sim_name, f"{F_pre}_{F_suf}.png")
        else:
            # accept str paths as well as Path objects
            P_png = Path(P_png)
        P_png.parent.mkdir(parents=True, exist_ok=True)
        if not P_png.exists() or ow_fig:
            fig.savefig(P_png, dpi=300)
            self.logger.info(f"saved figure to {P_png}")
        else:
            self.logger.info(f"{P_png} already exists and not overwriting")
    if show_fig:
        fig.show()
[docs]
def add_fip_shared_colorbar(self, fig, cmap_use, *,
                            cbar_bframe  = 'n',
                            fig_width    = "24c",
                            cbar_width   = "12c",
                            cbar_yoffset = "0.8c",
                            label_yshift = "-0.45c",
                            left_label   = "model-dominant",
                            mid_label    = "agreement",
                            right_label  = "obs-dominant"):
    """
    Draw a horizontal, bottom-centred colorbar plus three category labels beneath it.

    The labels are written on a full-width invisible strip (`fig_width` wide,
    region [0, 1] x [0, 1]) placed `label_yshift` below the current origin.
    Left/right labels sit under the bar ends and the middle label under its centre.
    The positions are derived from `cbar_width / fig_width`, assuming the bar is
    centred on the strip (for a 12c bar in a 24c figure: x = 0.25, 0.50, 0.75).

    Parameters
    ----------
    fig : pygmt.Figure
        Figure to draw on. Its origin is shifted by `label_yshift`.
    cmap_use : str
        CPT passed to `fig.colorbar`.
    cbar_bframe : str
        Frame setting for the invisible label strip.
    fig_width, cbar_width : str
        GMT lengths (e.g. "24c", "12c") of the strip and the colorbar.
    cbar_yoffset : str
        Vertical offset of the colorbar below the plot frame.
    label_yshift : str
        Origin shift applied before drawing the label strip.
    left_label, mid_label, right_label : str
        Text written under the left end, centre and right end of the bar.
    """
    # centred shared colorbar
    fig.colorbar(cmap=cmap_use, position=f"JBC+w{cbar_width}/0.45c+h+o0c/{cbar_yoffset}", frame=["xa1f0.25"])
    # full-width invisible strip under the bar
    fig.shift_origin(yshift=label_yshift)
    fig.basemap(region=[0, 1, 0, 1], projection=f"X{fig_width}/0.8c", frame=cbar_bframe)
    fig.text(x       = self._fip_label_x_positions(fig_width, cbar_width),
             y       = [0.18, 0.18, 0.18],
             text    = [left_label, mid_label, right_label],
             justify = "TC",
             no_clip = True,
             font    = "12p,NewCenturySchlbk-Bold")

def _fip_label_x_positions(self, fig_width, cbar_width):
    """
    Return [left, mid, right] x positions (fractions of `fig_width`) of the ends and
    centre of a bar `cbar_width` long centred within `fig_width`.

    Widths may be numbers (taken as cm) or GMT length strings with an optional "c",
    "i" or "p" suffix (no suffix -> cm). If either width cannot be parsed or the
    figure width is not positive, the historical positions [0.25, 0.50, 0.75]
    are returned.
    """
    cm_per_unit = {"c": 1.0, "i": 2.54, "p": 2.54 / 72.0}

    def _to_cm(length):
        # numbers are interpreted as cm; strings may carry a unit suffix
        if isinstance(length, (int, float)):
            return float(length)
        text = str(length).strip().lower()
        unit = text[-1:]
        if unit in cm_per_unit:
            text = text[:-1]
        else:
            unit = "c"
        try:
            return float(text) * cm_per_unit[unit]
        except ValueError:
            return None

    fig_cm = _to_cm(fig_width)
    bar_cm = _to_cm(cbar_width)
    if fig_cm is None or bar_cm is None or fig_cm <= 0:
        return [0.25, 0.50, 0.75]
    half = min(max(bar_cm / fig_cm / 2.0, 0.0), 0.5)
    return [0.5 - half, 0.5, 0.5 + half]
def _format_subplot_projection(self, proj_template, MC):
    """
    Fill a projection template for use inside a PyGMT subplot panel.

    `{MC}` is substituted with the meridian centre, and `{fig_size}` with GMT's
    "?" placeholder so the panel width is taken from the subplot layout. A unit
    suffix ("i" or "c") glued directly after the placeholder in the template is
    removed, because "?" must stand on its own.
    """
    formatted = proj_template.format(MC=MC, fig_size="?")
    for glued_unit in ("?i", "?c"):
        formatted = formatted.replace(glued_unit, "?")
    return formatted
[docs]
def pygmt_FIP_diff_four_panel(self, da_diff, *, sector_names,
                              weight_da     = None,
                              fig_size      = ("24c", "17c"),
                              margins       = None,
                              basemap_frame = "af",
                              G_pt_size     = "0.08",
                              plot_GI       = True,
                              GI_color      = "#0072B2",
                              GI_size       = "0.08",
                              P_png         = None,
                              show_fig      = False,
                              cbar_width    = "12c",
                              cbar_bframe   = 'n',
                              left_label    = "model-dominant",  # correct for diff = mod - obs
                              mid_label     = "agreement",
                              right_label   = "obs-dominant"):
    """
    Plot a fast-ice-persistence difference field in up to four sector panels (2x2)
    with one shared colorbar.

    Parameters
    ----------
    da_diff : xr.DataArray
        Difference field passed to `self.pygmt_FIP_figure(mode="diff")`.
    sector_names : str or sequence of str
        One to four keys of `self.Ant_8sectors`, laid out row-major in the 2x2 grid.
        A single string is treated as one sector.
    weight_da : xr.DataArray, optional
        Passed through to `pygmt_FIP_figure` as `weight_da`.
    fig_size : tuple of str, default ("24c", "17c")
        Subplot figure size; `fig_size[0]` is also the width of the label strip.
    margins : list of str, optional
        Subplot margins; defaults to ["0.15c", "0.15c"].
    basemap_frame, G_pt_size, plot_GI, GI_color, GI_size :
        Passed through to `pygmt_FIP_figure`.
    P_png : str or Path, optional
        If given, the figure is saved there at 300 dpi (parent dirs created).
    show_fig : bool, default False
        Whether to display the figure.
    cbar_width, cbar_bframe, left_label, mid_label, right_label :
        Passed to `add_fip_shared_colorbar`.

    Returns
    -------
    pygmt.Figure

    Raises
    ------
    ValueError
        If `sector_names` is empty or has more than four entries.
    """
    # Validate before importing pygmt so bad input fails fast and cheaply.
    if isinstance(sector_names, str):
        sector_names = [sector_names]
    sector_names = list(sector_names)
    if not sector_names:
        raise ValueError("sector_names must contain at least one sector name.")
    if len(sector_names) > 4:
        raise ValueError(f"sector_names must contain at most 4 names for a 2x2 layout; got {len(sector_names)}")
    margins = ["0.15c", "0.15c"] if margins is None else margins
    import pygmt
    fig = pygmt.Figure()
    cmap_use = None
    with fig.subplot(nrows=2, ncols=2, figsize=fig_size, margins=margins, frame="lrtb", autolabel=False):
        for i, reg_name in enumerate(sector_names):
            reg_vals = self.Ant_8sectors[reg_name]
            region   = reg_vals["plot_region"]
            MC       = self.get_meridian_center_from_geographic_extent(region)
            # Use the subplot-sized width ("?"), not a fixed figure width.
            projection = self._format_subplot_projection(reg_vals["projection"], MC)
            fig, cmap_use = self.pygmt_FIP_figure(plot_data       = da_diff,
                                                  mode            = "diff",
                                                  color_mode      = "continuous",
                                                  weight_da       = weight_da,
                                                  region          = region,
                                                  projection      = projection,
                                                  G_pt_size       = G_pt_size,
                                                  plot_GI         = plot_GI,
                                                  plot_bathymetry = False,
                                                  GI_color        = GI_color,
                                                  GI_size         = GI_size,
                                                  basemap_frame   = (basemap_frame,),
                                                  fig             = fig,
                                                  panel           = [i // 2, i % 2],
                                                  panel_title     = reg_name,
                                                  add_colorbar    = False,
                                                  return_cmap     = True)
    self.add_fip_shared_colorbar(fig, cmap_use,
                                 cbar_bframe = cbar_bframe,
                                 fig_width   = fig_size[0],
                                 cbar_width  = cbar_width,
                                 left_label  = left_label,
                                 mid_label   = mid_label,
                                 right_label = right_label)
    if P_png is not None:
        P_png = Path(P_png)
        P_png.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(P_png, dpi=300)
    if show_fig:
        fig.show()
    return fig
def _smooth_df_time(self, df: pd.DataFrame, window, center=True, min_periods=1):
    """
    Return a copy of `df` (columns 'time' and 'data') with 'data' replaced by its
    rolling mean.

    `window` is either a sample count (int) or a pandas time-offset string such as
    '15D' (evaluated on the 'time' index). `center` and `min_periods` are forwarded
    to `Series.rolling`. With `window=None` the input frame itself is returned unchanged.
    """
    if window is None:
        return df
    roller = df.set_index("time")["data"].rolling(window=window, center=center, min_periods=min_periods)
    smoothed_values = roller.mean().values
    # positional assignment: the rolling result follows df's row order
    smoothed_df = df.copy()
    smoothed_df["data"] = smoothed_values
    return smoothed_df
def _repeat_doy_over_range(self, df_time_data, full_range, time_coord="time"):
    """
    Tile a day-of-year mean climatology of `df_time_data['data']` over `full_range`.

    The climatology is the mean of 'data' grouped by the day-of-year of
    `df_time_data[time_coord]`. When the source contains no day 366, day 366 in
    `full_range` is looked up as day 365. Days missing from the climatology map to NaN.

    Returns a DataFrame with columns ['time', 'rep']. The input frame is not modified.
    """
    doy_means = (df_time_data
                 .assign(doy=df_time_data[time_coord].dt.dayofyear)
                 .groupby("doy")["data"]
                 .mean())
    tiled = pd.DataFrame({"time": full_range})
    tiled["doy"] = tiled["time"].dt.dayofyear
    if 366 in doy_means.index:
        tiled["doy_eff"] = tiled["doy"]
    else:
        # no leap-day climatology available: reuse day 365
        tiled["doy_eff"] = tiled["doy"].where(tiled["doy"] != 366, 365)
    tiled["rep"] = tiled["doy_eff"].map(doy_means)
    return tiled[["time", "rep"]]
[docs]
def pygmt_timeseries(self, ts_dict,
comp_name : str = "test",
primary_key : str = "FIA", # "FIA"
smooth : str|int | None = None, # e.g., 15, "15D"
clim_smooth : int | None = None, # 15
climatology : bool = False,
ylabel : str = None, # "@[Fast Ice Area (1\\times10^3 km^2)@[",
ylim : tuple = [0,1000],
yaxis_pri : int = None,
ytick_pri : int = 100,
ytick_sec : int = 50,
projection : str = None,
fig_width : str = None,
fig_height : str = None,
xaxis_pri : str = None,
xaxis_sec : str = None,
frame_bndy : str = "WS",
line_colors : dict[str, str] | None = None, # {"sim": "#RRGGBB", ...}
legend_labels : dict[str, str] | None = None, # {"sim": "nice label", ...}
legend_pos : str = None,
legend_box : str = "+gwhite+p0.5p",
fmt_dt_pri : str = None,
fmt_dt_sec : str = None,
fmt_dt_map : str = None,
fnt_type : str = "Helvetica",
fnt_wght_lab : str = "20p",
fnt_wght_ax : str = "18p",
line_pen : str = "1p",
grid_wght_pri : str = ".25p",
grid_wght_sec : str = ".1p",
P_png : str = None,
time_coord : str = "time",
time_coord_alt : str = "date",
keys2plot : list = None,
repeat_keys : list[str] | None = None,
repeat_policy : str = "inside_others", # "inside_others" | "outside_others" | "fill_gaps" | "always"
repeat_ref_keys : list[str] | None = None, # which keys define the "others" window; default: all except current
clip_x_axis : bool = False,
zero_line : bool = False,
zero_line_level : float = 0.0,
zero_line_pen : str = "2p,black",
save_fig : bool = None,
show_fig : bool = None):
"""
Plot time series of a primary variable (e.g., FIA) for a set of simulations or observations.
Parameters
----------
ts_dict : dict
Dictionary of xarray DataArrays keyed by simulation or dataset name.
comp_name : str
Name for the comparison (used in figure title and filename).
primary_key : str
Key used to extract the variable from each dataset (except 'AF2020').
climatology : bool
If True, plot daily climatology with fill and mean lines.
ylim : tuple or None
Y-axis limits. If None, inferred from data with 5% padding.
ylabel : str
Label for the Y-axis.
ytick_inc : int
Interval between Y-axis ticks.
xaxis_pri, xaxis_sec : str
GMT frame settings for primary and secondary axes (when climatology is False).
P_png : str or None
Optional full path to save figure. If None, default filename is constructed.
legend_box : str
GMT legend box styling.
line_pen : str
Line thickness for plotted time series.
time_coord : str
Name of time coordinate in each DataArray.
keys2plot : list or None
If provided, only datasets with keys in this list are plotted.
show_fig : bool or None
If True, show figure interactively. Defaults to self.show_fig.
repeat_keys : list[str] or None
Keys in `ts_dict` to plot as a repeated day-of-year climatology when
`climatology=False`.
repeat_policy : {"outside_others","fill_gaps","always"}
- "outside_others": keep original values where *other* series exist in time,
but replace values outside that union with day-of-year climatology of
the nominated series (good for fair comparison).
- "fill_gaps": use climatology only where the nominated series has no data,
keep original where it does.
- "always": ignore original values and plot the repeated climatology across
the full x-range.
repeat_ref_keys : list[str] or None
If provided, these keys define the "others" time span for the
"outside_others" policy. By default, it's all plotted series except the
current one.
"""
import pygmt
show_fig = show_fig if show_fig is not None else self.show_fig
save_fig = save_fig if save_fig is not None else self.save_fig
fmt_dt_pri = fmt_dt_pri if fmt_dt_pri is not None else "Character"
fmt_dt_sec = fmt_dt_sec if fmt_dt_sec is not None else "Abbreviated"
fmt_dt_map = fmt_dt_map if fmt_dt_map is not None else "o"
line_colors = line_colors or {}
legend_labels = legend_labels or {}
# primary_key = primary_key if primary_key is not None else keys2plot[0]
# if keys2plot is None:
# raise("either 'primary_key' or 'keys2plot' must be defined")
# need to get out the maximum times for plot boundaries
tmin, tmax = self.extract_min_max_dates(ts_dict, keys2plot=keys2plot, primary_key=primary_key, time_coord=time_coord)
# --- normalize new params ---
if isinstance(repeat_keys, (str, bytes)):
repeat_keys = [repeat_keys]
repeat_keys = set(repeat_keys or [])
if isinstance(repeat_ref_keys, (str, bytes)):
repeat_ref_keys = [repeat_ref_keys]
x_start, x_end = tmin, tmax
if (not climatology) and clip_x_axis and repeat_keys:
# reference keys = all plotted except the repeated ones, unless user specified
if repeat_ref_keys:
ref_keys = set(repeat_ref_keys)
else:
ref_keys = {k for k in ts_dict.keys() if (keys2plot is None or k in (keys2plot or [])) and k not in repeat_keys}
other_mins, other_maxs = [], []
for k in ref_keys:
da2 = ts_dict[k][primary_key]
tt2 = pd.to_datetime(da2[time_coord].values)
if len(tt2) > 0:
other_mins.append(tt2.min()); other_maxs.append(tt2.max())
if other_mins: # only override if we actually have refs
x_start = pd.to_datetime(min(other_mins)).normalize()
x_end = pd.to_datetime(max(other_maxs)).normalize()
# there are differences in the projection and x-axis for the two types of figures
if climatology:
fake_year = 1996
fig_width = fig_width if fig_width is not None else "20c"
fig_height = fig_height if fig_height is not None else "15c"
xaxis_sec = xaxis_sec if xaxis_sec is not None else None
xaxis_pri = xaxis_pri if xaxis_pri is not None else "a1Og"
region = [f"{fake_year}-01-01", f"{fake_year}-12-31", ylim[0], ylim[1]]
else:
fig_width = fig_width if fig_width is not None else "50c"
fig_height = fig_height if fig_height is not None else "15c"
xaxis_pri = xaxis_pri if xaxis_pri is not None else "a2Of30D"#g30D"
xaxis_sec = xaxis_sec if xaxis_sec is not None else "a1Y"
region = [x_start.strftime("%Y-%m-%d"), x_end.strftime("%Y-%m-%d"), ylim[0], ylim[1]]
# define the projection and frame of the figure
legend_pos = legend_pos if legend_pos is not None else f"JTL+jTL+o0.2c+w{fig_width}"
projection = f"X{fig_width}/{fig_height}"
yaxis_pri = yaxis_pri if yaxis_pri is not None else f"a{ytick_pri}f{ytick_sec}+l{ylabel}"
if xaxis_sec is not None:
frame = [frame_bndy, f"sx{xaxis_sec}", f"px{xaxis_pri}", f"py{yaxis_pri}"]
else:
frame = [frame_bndy, f"px{xaxis_pri}", f"py{yaxis_pri}"]
# make sure the figure has the same configuration by wrapping it in a with condition
fig = pygmt.Figure()
with pygmt.config(FONT_LABEL = f"{fnt_wght_lab},{fnt_type}",
FONT = f"{fnt_wght_ax},{fnt_type}",
MAP_GRID_PEN_PRIMARY = grid_wght_pri,
MAP_GRID_PEN_SECONDARY = grid_wght_sec,
FORMAT_TIME_PRIMARY_MAP = fmt_dt_pri,
FORMAT_TIME_SECONDARY_MAP = fmt_dt_sec,
FORMAT_DATE_MAP = fmt_dt_map):
fig.basemap(projection=projection, region=region, frame=frame)
# loop over each key in the dictionary and if keys2plot is defined only plot those dictionaries
cnt=0
for i, (dict_key, data) in enumerate(ts_dict.items()):
self.logger.info(f"extracting data array from ts_dict[{dict_key} : {primary_key}]")
da = data[primary_key]
self.logger.info(f"matching {dict_key} with JSON-toolbox-config dictionary for line color and legend label")
# removed option to pass 'line_clr' and 'leg_lab' with ts_dict
line_color = self.plot_var_dict.get(dict_key, {}).get("line_clr", f"C{i}")
leg_lab = self.plot_var_dict.get(dict_key, {}).get("leg_lab", dict_key )
# skip if keys2plot specified
if keys2plot is not None and dict_key not in keys2plot:
continue
# choose color/label: explicit override > plot_var_dict > default cycle
line_color = (line_colors.get(dict_key)
or self.plot_var_dict.get(dict_key, {}).get("line_clr")
or f"C{cnt}")
leg_lab = (legend_labels.get(dict_key)
or self.plot_var_dict.get(dict_key, {}).get("leg_lab")
or dict_key)
cnt += 1
self.logger.info(f" legend label: {leg_lab}")
self.logger.info(f" line color : {line_color}")
if dict_key=="AF2020" and primary_key=="FIA":
df = pd.DataFrame({"time": pd.to_datetime(da[time_coord_alt].values), "data": da.values})
else:
df = pd.DataFrame({"time": pd.to_datetime(da[time_coord].values), "data": da.values})
if climatology:
if dict_key=="AF2020" and primary_key=="FIA":
clim = self.compute_doy_climatology(da, time_coord=time_coord_alt)
else:
clim = self.compute_doy_climatology(da)
mean_x = clim['mean'].index
mean_y = clim['mean'].values
if clim_smooth is not None and clim_smooth > 1:
mean_y = (pd.Series(mean_y, index=mean_x)
.rolling(window=clim_smooth, center=True, min_periods=1)
.mean()
.values)
fig.plot(x = np.concatenate([clim['min'].index, clim['max'].index[::-1]]),
y = np.concatenate([clim['min'].values, clim['max'].values[::-1]]),
fill = f"{line_color}@80",
close = True,
transparency = 80)
fig.plot(x = mean_x,
y = mean_y,
pen = f"{line_pen},{line_color}",
label = leg_lab)
else:
if dict_key in repeat_keys:
# target x-range (union for figure)
t_start, t_end = tmin.normalize(), tmax.normalize()
full_range = pd.date_range(t_start, t_end, freq="D")
# repeated climatology over the full range
rep_df = self._repeat_doy_over_range(df, full_range) # cols: time, rep
# original reindexed over full range
base = pd.DataFrame({"time": full_range}).merge(df.rename(columns={"data": "orig"}), on="time", how="left")
mix = base.merge(rep_df, on="time", how="left")
if repeat_policy == "always":
out_df = mix[["time", "rep"]].rename(columns={"rep": "data"})
elif repeat_policy == "fill_gaps":
mix["data"] = mix["orig"].fillna(mix["rep"])
out_df = mix[["time", "data"]]
elif repeat_policy in ("outside_others", "inside_others"):
# Determine the "others" window
if repeat_ref_keys is not None:
ref_keys = set(repeat_ref_keys)
else:
ref_keys = set(k for k in ts_dict.keys()
if k != dict_key and (keys2plot is None or k in keys2plot))
if len(ref_keys) == 0:
# no refs -> behave like fill_gaps (but still allow inside_others to clip nothing)
mix["data"] = mix["orig"].fillna(mix["rep"])
if repeat_policy == "inside_others":
mix.loc[:, "data"] = np.nan # nothing to show if no reference window
out_df = mix[["time", "data"]]
else:
# compute min/max over "others"
other_mins, other_maxs = [], []
for k in ref_keys:
d2 = ts_dict[k][primary_key]
tt = pd.to_datetime(d2[time_coord].values)
if len(tt) > 0:
other_mins.append(tt.min()); other_maxs.append(tt.max())
other_min = pd.to_datetime(min(other_mins)).normalize()
other_max = pd.to_datetime(max(other_maxs)).normalize()
inside = (mix["time"] >= other_min) & (mix["time"] <= other_max)
# inside the others’ window: use original where present, else repeated
mix.loc[inside, "data"] = mix.loc[inside, "orig"].fillna(mix.loc[inside, "rep"])
if repeat_policy == "inside_others":
# outside -> hide (clip)
mix.loc[~inside, "data"] = np.nan
else:
# outside -> use repeated climatology
mix.loc[~inside, "data"] = mix.loc[~inside, "rep"]
out_df = mix[["time", "data"]]
else:
raise ValueError(f"Unknown repeat_policy: {repeat_policy}")
# optional smoothing after composition
out_df = self._smooth_df_time(out_df, window=smooth)
fig.plot(x=out_df["time"], y=out_df["data"], pen=f"{line_pen},{line_color}", label=leg_lab)
else:
df_sm = self._smooth_df_time(df, window=smooth)
fig.plot(x=df_sm["time"], y=df_sm["data"], pen=f"{line_pen},{line_color}", label=leg_lab)
if zero_line:
y0 = float(zero_line_level)
y_min, y_max = ylim
if (y_min <= y0 <= y_max):
x0, x1 = region[0], region[1]
fig.plot(x=[x0, x1], y=[y0, y0], pen=zero_line_pen)
fig.legend(position=legend_pos, box=legend_box)
if save_fig:
F_png = f"{primary_key}_{comp_name}_{'climatology_' if climatology else ''}{tmin.strftime('%Y')}-{tmax.strftime('%Y')}.png"
P_png = P_png if P_png is not None else Path(self.D_graph, "timeseries", F_png)
fig.savefig(P_png, dpi=300)
self.logger.info(f"saved figure to {P_png}")
if show_fig:
fig.show()
[docs]
def pygmt_map_plot_multi_var_8sectors(
    self,
    das,
    var_names,
    sim_name=None,
    time_stamp=None,
    tit_str=None,
    panel_titles=None,
    plot_GI=False,
    diff_plots=None,
    cmaps=None,
    series_list=None,
    reverse_list=None,
    cbar_labels=None,
    cbar_units_list=None,
    extend_cbars=False,
    cbar_positions=None,
    lon_coord_names=None,
    lat_coord_names=None,
    use_bcoords=False,
    use_tcoords=False,
    fig_size=None,
    var_sq_size=0.2,
    GI_sq_size=0.1,
    GI_fill_color="red",
    plot_iceshelves=True,
    plot_bathymetry=True,
    land_color=None,
    water_color=None,
    P_png=None,
    var_out=None,
    overwrite_fig=None,
    show_fig=None,
    xshift="w+1c"):
    """
    Plot 2–3 fields side by side for each of the eight Antarctic sectors.

    One PyGMT figure is produced per sector in ``self.Ant_8sectors``. Within a
    figure the panels are laid out left-to-right using ``Figure.shift_origin``.

    Performance: each panel's point cloud is prepared ONCE via
    ``self.pygmt_da_prep`` and converted to flat NumPy arrays. Each sector then
    only applies a cheap lon/lat boolean mask to those arrays, so Dask-backed
    inputs are not recomputed per sector. Masks are built per panel, so panels
    whose point clouds differ in length/order are still filtered correctly.

    Parameters
    ----------
    das : list or tuple of xarray.DataArray
        Two or three 2D fields, one per panel. An array with a ``time`` dimension
        is rejected; pass a single time slice.
    var_names : list or tuple of str
        One name per panel. Each is used as a key into ``self.plot_var_dict`` for
        per-panel defaults (``cmap``, ``series``, ``reverse``, ``name``,
        ``units``). Joined with ``"_"``, they also form the default ``var_out``.
    sim_name : str, optional
        Used in the default output path. Defaults to ``self.sim_name``.
    time_stamp : str, optional
        Prefix of the default output filename. Defaults to ``self.dt0_str``.
    tit_str : str, optional
        If given, replaces the title of the first (left-most) panel.
    panel_titles : str or sequence of str, optional
        Per-panel titles; a single string is used for every panel. Defaults to
        ``var_names``. ``None`` entries fall back to the panel's var name.
    plot_GI : bool
        Overlay the points from ``self.load_GI_lon_lats()``. They are drawn as
        circles of size ``GI_sq_size`` (cm) filled with ``GI_fill_color``.
    diff_plots : bool or sequence of bool, optional
        Per-panel ``diff_plot`` flag forwarded to ``self.pygmt_da_prep``.
    cmaps, series_list, reverse_list : scalar or sequence, optional
        Per-panel ``cmap`` / ``series`` / ``reverse`` arguments for
        ``pygmt.makecpt``. ``None`` entries fall back to ``self.plot_var_dict``.
    cbar_labels, cbar_units_list : scalar or sequence, optional
        Per-panel colour-bar label / units for ``self.create_cbar_frame``.
        ``None`` entries fall back to ``plot_var_dict[var]["name"]`` /
        ``["units"]``.
    extend_cbars : bool or sequence of bool
        Per-panel ``extend_cbar`` flag for ``self.create_cbar_frame``.
    cbar_positions : str or sequence of str, optional
        Per-panel colour-bar position. ``None`` entries use
        ``self.pygmt_dict["cbar_pos"]`` formatted with ``width=0.8*fig_size``
        and ``height=0.75``.
    lon_coord_names, lat_coord_names : str or sequence of str, optional
        Per-panel explicit coordinate names. When BOTH are given for a panel,
        they take precedence over ``use_bcoords`` / ``use_tcoords``.
    use_bcoords, use_tcoords : bool or sequence of bool
        Per-panel fallback-grid choice passed to ``self.pygmt_da_prep``. Setting
        both for a panel without explicit coordinate names is an error.
    fig_size : float, optional
        Substituted into each sector's ``projection`` template. It must be
        numeric because the default colour-bar width is ``0.8 * fig_size``.
        Defaults to ``self.pygmt_dict["fig_size"]``.
    var_sq_size : float
        Size (cm) of the square markers used for the data points.
    GI_sq_size : float
        Size (cm) of the circle markers used for the ``plot_GI`` overlay.
    GI_fill_color : str
        Fill colour of the ``plot_GI`` overlay.
    plot_iceshelves : bool
        Overlay the polygons returned by ``self.load_ice_shelves()``.
    plot_bathymetry : bool
        If True, draw ``self.load_IBCSO_bath()`` with the ``geo`` CPT as the
        background. Otherwise draw a coastline filled with ``land_color`` /
        ``water_color``.
    land_color, water_color : str, optional
        Default to ``self.pygmt_dict["land_color"]`` /
        ``self.pygmt_dict["water_color"]``.
    P_png : str or pathlib.Path, optional
        Output path. An explicit ``P_png`` is always saved.
        - With a file suffix, the sector name is appended to the stem
          (``<stem>_<sector><suffix>``).
        - Without a suffix, it is treated as a directory
          (``<P_png>/<sector>.png``).
        - If ``None`` and ``self.save_fig`` is truthy, the default path is
          ``<D_graph>/<sim_name>/<sector>/<var_out>/<time_stamp>_<sim_name>_<sector>_<var_out>.png``.
    var_out : str, optional
        Variable tag used in the default output path.
    overwrite_fig : bool, optional
        Overwrite existing PNGs. Defaults to ``self.ow_fig``.
    show_fig : bool, optional
        Display each sector figure. Defaults to ``self.show_fig``.
    xshift : str
        ``xshift`` passed to ``Figure.shift_origin`` between panels.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If ``das`` or ``var_names`` is not a list/tuple.
    ValueError
        Raised for any of the following:
        - a panel count other than 2 or 3;
        - a per-panel argument whose length does not match the panel count;
        - a ``time`` dimension on any input;
        - ``use_bcoords`` and ``use_tcoords`` both set for a panel without
          explicit coordinate names.
    """

    def _as_list(x, n, name):
        """Broadcast a scalar (or None) to ``n`` items, or length-check a sequence."""
        if x is None:
            return [None] * n
        if isinstance(x, (list, tuple)):
            if len(x) != n:
                raise ValueError(f"{name} must have length {n} (got {len(x)})")
            return list(x)
        return [x] * n

    def _as_bool_list(x, n, name):
        """Like ``_as_list`` but maps None to False and coerces items to bool."""
        if x is None:
            return [False] * n
        if isinstance(x, (list, tuple)):
            if len(x) != n:
                raise ValueError(f"{name} must have length {n} (got {len(x)})")
            return [bool(v) for v in x]
        return [bool(x)] * n

    def _sector_png_path(P, sector_name):
        """Derive a per-sector output path from ``P`` (file or directory), or None."""
        if P is None:
            return None
        P = Path(P)
        if P.suffix == "":
            return P / f"{sector_name}.png"
        return P.parent / f"{P.stem}_{sector_name}{P.suffix}"

    def _lon_in_range_1d(lon, lon_min, lon_max):
        """Boolean mask of ``lon`` inside [lon_min, lon_max], compared on a 0–360 circle."""
        lon = np.mod(lon, 360.0)
        # A window spanning >= 360 degrees covers every longitude; without this
        # guard e.g. [0, 360] would reduce to [0, 0] and keep a single meridian.
        if (lon_max - lon_min) >= 360.0:
            return np.ones(np.shape(lon), dtype=bool)
        lon_min = lon_min % 360.0
        lon_max = lon_max % 360.0
        if lon_min <= lon_max:
            return (lon >= lon_min) & (lon <= lon_max)
        # window wraps through the 0/360 seam (e.g. [-60, 60] -> [300, 60])
        return (lon >= lon_min) | (lon <= lon_max)

    def _sector_indices(cloud, bounds):
        """Flat indices of ``cloud`` points inside a [lonmin, lonmax, latmin, latmax] box."""
        lonmin, lonmax, latmin, latmax = bounds[:4]
        lat = cloud["lat"]
        inside = _lon_in_range_1d(cloud["lon"], lonmin, lonmax) & (lat >= latmin) & (lat <= latmax)
        return np.flatnonzero(inside)

    # --- validate (cheap checks first, before any I/O or heavy imports) ---
    if not isinstance(das, (list, tuple)) or not isinstance(var_names, (list, tuple)):
        raise TypeError("das and var_names must be lists/tuples")
    n_pan = len(das)
    if n_pan not in (2, 3):
        raise ValueError(f"das must contain 2 or 3 DataArrays (got {n_pan})")
    if len(var_names) != n_pan:
        raise ValueError("var_names must match length of das")
    for j, da in enumerate(das):
        if hasattr(da, "dims") and "time" in da.dims:
            raise ValueError(f"das[{j}] has 'time' dim; pass a 2D time-slice.")

    # --- per-panel argument lists ---
    # None titles fall back to the variable name; a single string is broadcast
    # (previously a string would have been indexed character-by-character).
    panel_titles_l = [
        title if title is not None else vn
        for title, vn in zip(_as_list(panel_titles, n_pan, "panel_titles"), var_names)
    ]
    diff_plots_l = _as_bool_list(diff_plots, n_pan, "diff_plots")
    cmaps_l = _as_list(cmaps, n_pan, "cmaps")
    series_l = _as_list(series_list, n_pan, "series_list")
    reverse_l = _as_list(reverse_list, n_pan, "reverse_list")
    cbar_labels_l = _as_list(cbar_labels, n_pan, "cbar_labels")
    cbar_units_l = _as_list(cbar_units_list, n_pan, "cbar_units_list")
    extend_cbars_l = _as_bool_list(extend_cbars, n_pan, "extend_cbars")
    cbar_positions_l = _as_list(cbar_positions, n_pan, "cbar_positions")
    lon_names_l = _as_list(lon_coord_names, n_pan, "lon_coord_names")
    lat_names_l = _as_list(lat_coord_names, n_pan, "lat_coord_names")
    use_bcoords_l = _as_bool_list(use_bcoords, n_pan, "use_bcoords")
    use_tcoords_l = _as_bool_list(use_tcoords, n_pan, "use_tcoords")
    explicit_coords_l = [
        (lon_names_l[j] is not None) and (lat_names_l[j] is not None) for j in range(n_pan)
    ]
    for j in range(n_pan):
        if not explicit_coords_l[j] and use_bcoords_l[j] and use_tcoords_l[j]:
            raise ValueError("Cannot set both use_bcoords and use_tcoords True")

    # imported lazily so invalid input fails fast without needing GMT
    import pygmt

    # --- defaults ---
    sim_name = sim_name if sim_name is not None else self.sim_name
    show_fig = show_fig if show_fig is not None else self.show_fig
    ow_fig = overwrite_fig if overwrite_fig is not None else self.ow_fig
    time_stamp = time_stamp if time_stamp is not None else self.dt0_str
    fig_size = fig_size if fig_size is not None else self.pygmt_dict["fig_size"]
    land_color = land_color if land_color is not None else self.pygmt_dict["land_color"]
    water_color = water_color if water_color is not None else self.pygmt_dict["water_color"]

    # overlays loaded once, reused for every sector
    ANT_IS = self.load_ice_shelves() if plot_iceshelves else None
    SO_BATH = self.load_IBCSO_bath() if plot_bathymetry else None
    plot_GI_dict = self.load_GI_lon_lats() if plot_GI else None
    # ensure grid loaded (for tcoords/bcoords path in pygmt_da_prep)
    self.load_cice_grid(slice_hem=True)
    reg_dict = self.Ant_8sectors
    if var_out is None:
        var_out = "_".join(var_names)

    # --- BIG SPEED WIN: compute plot point clouds ONCE per panel ---
    self.logger.info("precomputing plot point clouds (once per panel) ...")
    clouds = []
    for j in range(n_pan):
        if explicit_coords_l[j]:
            pdict = self.pygmt_da_prep(
                das[j],
                bcoords=False,
                tcoords=False,
                lon_coord_name=lon_names_l[j],
                lat_coord_name=lat_names_l[j],
                diff_plot=diff_plots_l[j],
            )
        else:
            # both-True was rejected above, so the flags can be passed straight through
            pdict = self.pygmt_da_prep(
                das[j],
                bcoords=use_bcoords_l[j],
                tcoords=use_tcoords_l[j],
                diff_plot=diff_plots_l[j],
            )
        # Materialise once as flat NumPy arrays so the per-sector fancy indexing
        # below is positional and does not trigger repeated (e.g. Dask) computes.
        clouds.append({key: np.asarray(pdict[key]).ravel() for key in ("lon", "lat", "data")})

    # NOTE(review): this attribute is initialised but never read or populated in
    # this method; kept so any other code relying on its existence still works.
    if not hasattr(self, "_sector_point_idx_cache"):
        self._sector_point_idx_cache = {}

    # Sector indices are computed per panel: panel clouds need not share the same
    # length/order (e.g. if pygmt_da_prep drops invalid points), so panel-0
    # indices cannot safely be reused on the other panels.
    sector_idx = {
        reg_name: [_sector_indices(cloud, reg_vals["plot_region"]) for cloud in clouds]
        for reg_name, reg_vals in reg_dict.items()
    }

    # --- plot ---
    for reg_name, reg_vals in reg_dict.items():
        idx_per_panel = sector_idx[reg_name]
        if all(ix.size == 0 for ix in idx_per_panel):
            self.logger.warning(f"{reg_name}: no points in this sector after masking; skipping.")
            continue

        # output path per sector
        P_png_reg = _sector_png_path(P_png, reg_name)
        if P_png_reg is None and self.save_fig:
            P_png_reg = Path(self.D_graph, sim_name, reg_name, var_out,
                             f"{time_stamp}_{sim_name}_{reg_name}_{var_out}.png")
        # Nothing would be written or shown for this sector: skip building the figure.
        if (P_png_reg is not None) and P_png_reg.exists() and (not ow_fig) and (not show_fig):
            self.logger.info(f"{P_png_reg} already exists and not overwriting")
            continue

        region = reg_vals["plot_region"]
        MC = self.get_meridian_center_from_geographic_extent(region)
        projection = reg_vals["projection"].format(MC=MC, fig_size=fig_size)
        fig = pygmt.Figure()
        with pygmt.config(
            FONT_TITLE="16p,Courier-Bold",
            FONT_ANNOT_PRIMARY="14p,Helvetica",
            COLOR_FOREGROUND="black",
        ):
            for j in range(n_pan):
                if j > 0:
                    fig.shift_origin(xshift=xshift, yshift="0c")
                vn = var_names[j]
                # per-panel settings: explicit argument wins, else plot_var_dict[vn]
                cmap_j = cmaps_l[j] if cmaps_l[j] is not None else self.plot_var_dict[vn]["cmap"]
                series_j = series_l[j] if series_l[j] is not None else self.plot_var_dict[vn]["series"]
                rev_j = reverse_l[j] if reverse_l[j] is not None else self.plot_var_dict[vn]["reverse"]
                lab_j = cbar_labels_l[j] if cbar_labels_l[j] is not None else self.plot_var_dict[vn]["name"]
                unit_j = cbar_units_l[j] if cbar_units_l[j] is not None else self.plot_var_dict[vn]["units"]
                cbar_pos_j = cbar_positions_l[j]
                if cbar_pos_j is None:
                    cbar_pos_j = self.pygmt_dict["cbar_pos"].format(width=fig_size * 0.8, height=0.75)

                # basemap frame (tit_str, when given, replaces the first panel's title)
                frame_title = tit_str if (j == 0 and tit_str is not None) else panel_titles_l[j]
                fig.basemap(region=region, projection=projection, frame=["af", f"+t{frame_title}"])
                if plot_bathymetry:
                    fig.grdimage(grid=SO_BATH, cmap="geo")
                else:
                    fig.coast(region=region, projection=projection,
                              shorelines="1/0.5p,gray30", land=land_color, water=water_color)

                # Plot points for this sector (makecpt also feeds the colorbar below)
                cloud = clouds[j]
                idx = idx_per_panel[j]
                pygmt.makecpt(cmap=cmap_j, reverse=rev_j, series=series_j)
                if idx.size > 0:
                    fig.plot(
                        x=cloud["lon"][idx],
                        y=cloud["lat"][idx],
                        fill=cloud["data"][idx],
                        style=f"s{var_sq_size}c",
                        cmap=True,
                    )
                else:
                    # avoid handing PyGMT empty x/y arrays
                    self.logger.warning(f"{reg_name}: panel {j} ({vn}) has no points in this sector.")

                # overlays
                if plot_bathymetry:
                    fig.coast(region=region, projection=projection, shorelines="1/0.5p,gray30")
                if plot_GI:
                    fig.plot(x=plot_GI_dict["lon"], y=plot_GI_dict["lat"],
                             fill=GI_fill_color, style=f"c{GI_sq_size}c")
                if plot_iceshelves:
                    fig.plot(data=ANT_IS, pen="0.2p,gray", fill="lightgray")

                # colorbar
                cbar_frame = self.create_cbar_frame(series_j, lab_j, units=unit_j,
                                                    extend_cbar=extend_cbars_l[j])
                fig.colorbar(position=cbar_pos_j, frame=cbar_frame)

        # save/show
        if P_png_reg is not None:
            P_png_reg.parent.mkdir(parents=True, exist_ok=True)
            if (not P_png_reg.exists()) or ow_fig:
                fig.savefig(P_png_reg)
                self.logger.info(f"Saved figure to {P_png_reg}")
            else:
                self.logger.info(f"{P_png_reg} already exists and not overwriting")
        if show_fig:
            fig.show()