Source code for sea_ice_toolbox

from __future__ import annotations
import os, sys, logging, warnings, shutil
import xarray as xr
import pandas as pd
import numpy  as np
from contextlib import contextmanager
from pathlib    import Path
_dev_path = "/home/581/da1339/AFIM/src/AFIM/src"
if os.path.isdir(_dev_path) and _dev_path not in sys.path:
    sys.path.insert(0, _dev_path)
from sea_ice_classification import SeaIceClassification
from sea_ice_metrics        import SeaIceMetrics
from sea_ice_plotter        import SeaIcePlotter
from sea_ice_icebergs       import SeaIceIcebergs
from sea_ice_observations   import SeaIceObservations
from sea_ice_ACCESS         import SeaIceACCESS
from sea_ice_gridwork       import SeaIceGridWork
from sea_ice_regridder      import SeaIceRegridder
from sea_ice_fast           import SeaIceFast
from sea_ice_cice           import SeaIceCICE
from sea_ice_waves          import SeaIceWaves

__all__ = ["SeaIceToolbox", "SeaIceToolboxManager"]

##################################################################################################

class SeaIceToolboxManager:
    """
    Factory for creating `SeaIceToolbox` instances backed by a shared Dask client.

    `SeaIceToolboxManager` maintains a **class-level** Dask client so that multiple
    toolboxes (e.g., multiple simulations or multiple analyses) reuse a single
    scheduler and worker pool. This is particularly useful on HPC login/compute
    nodes where repeatedly creating local clusters is slow and can fragment
    resources.

    Parameters
    ----------
    P_log : str or pathlib.Path
        Path to a log file. It is forwarded to every toolbox created by
        `get_toolbox`.
    n_workers : int, default 4
        Number of Dask workers for the shared LocalCluster.
    n_threads : int, default 1
        Threads per worker.
    mem_lim : str, default "16GB"
        Worker memory limit passed to Dask (e.g., "8GB", "16GB").
    process : bool, default True
        Whether to use separate worker processes (True) or threads-only
        workers (False).
    D_dask : str, optional
        Local directory for Dask worker spill and temporary files. If None, uses
        `DASK_TEMPORARY_DIRECTORY` or the system temp directory.

    Notes
    -----
    - The shared client is cached on the class (`SeaIceToolboxManager._shared_client`).
    - `shutdown()` closes the shared client *and* the cluster it is connected to,
      then closes the toolbox file log handlers.
    - `get_toolbox(sim_name, **kwargs)` returns a `SeaIceToolbox` attached to the
      shared client and log file.

    See Also
    --------
    SeaIceToolbox
        Unified AFIM sea-ice analysis class composed of multiple SeaIce* mixins.
    """

    _shared_client = None

    def __init__(self,
                 P_log,
                 n_workers : int  = 4,
                 n_threads : int  = 1,
                 mem_lim   : str  = "16GB",
                 process   : bool = True,
                 D_dask    : str  = None):
        """
        Create (or reuse) a shared Dask client for AFIM sea-ice analysis.

        If a shared client has not yet been created, this initialiser creates a
        `dask.distributed.LocalCluster` and a corresponding `Client`, storing the
        client at the class level so that subsequent instances reuse it. When a
        shared client already exists, the cluster-sizing arguments are ignored.

        Parameters
        ----------
        P_log : str or pathlib.Path
            Log file path to pass through to created `SeaIceToolbox` instances.
        n_workers : int, default 4
            Number of Dask workers.
        n_threads : int, default 1
            Threads per worker.
        mem_lim : str, default "16GB"
            Worker memory limit (Dask "memory_limit").
        process : bool, default True
            If True, use multiple processes. If False, use threads-only workers.
        D_dask : str, optional
            Local directory for worker temporary files and spill. If None, uses
            `DASK_TEMPORARY_DIRECTORY` or the system temp directory.

        Notes
        -----
        - This method does not return a `Client`; access it via
          `SeaIceToolboxManager._shared_client`.
        - The shared client is created once per Python process (until `shutdown`).
        """
        # Imported lazily so that importing this module does not require Dask.
        import tempfile
        from dask.distributed import Client, LocalCluster
        if D_dask is None:
            D_dask = os.environ.get("DASK_TEMPORARY_DIRECTORY", tempfile.gettempdir())
        self.P_log = P_log
        if SeaIceToolboxManager._shared_client is None:
            cluster = LocalCluster(n_workers          = n_workers,  # number of worker processes/threads groups
                                   threads_per_worker = n_threads,  # threads inside each worker
                                   processes          = process,    # True -> one process per worker
                                   memory_limit       = mem_lim,    # per-worker memory limit
                                   local_directory    = D_dask,     # spill / scratch directory
                                   dashboard_address  = None)       # do not start the dashboard server
            SeaIceToolboxManager._shared_client = Client(cluster)

    def shutdown(self):
        """
        Shut down the shared Dask client/cluster and close file log handlers.

        Closes the class-level Dask client if it exists, closes the cluster that
        client was attached to (a client created from an explicit cluster object
        does not close that cluster on its own), and resets the cached reference.
        File handlers are then closed and detached from the toolbox logger.

        Notes
        -----
        - Intended for interactive sessions to avoid orphaned local clusters.
        - File handlers are closed on the logger name used by
          `SeaIceToolbox.setup_logging` ("sea_ice_classification") and, for
          backward compatibility, on the legacy name "SeaIceToolbox".
        """
        client = SeaIceToolboxManager._shared_client
        if client is not None:
            # Grab the cluster before closing: Client.close() leaves an explicitly
            # supplied cluster running.
            cluster = getattr(client, "cluster", None)
            try:
                client.close()
                if cluster is not None:
                    cluster.close()
            finally:
                # Always drop the cached reference so a new manager can start afresh.
                SeaIceToolboxManager._shared_client = None
            print("Dask client shut down.")
        for logger_name in ("sea_ice_classification", "SeaIceToolbox"):
            logger = logging.getLogger(logger_name)
            for h in logger.handlers[:]:
                if isinstance(h, logging.FileHandler):
                    h.flush()
                    h.close()
                    logger.removeHandler(h)
                    print(f"Closed log file handler: {getattr(h, 'baseFilename', h)}")

    def get_toolbox(self, sim_name, **kwargs):
        """
        Create a `SeaIceToolbox` for a given simulation using the shared Dask client.

        Parameters
        ----------
        sim_name : str
            Simulation name (must correspond to a configured simulation directory).
        **kwargs
            Additional keyword arguments forwarded to `SeaIceToolbox(...)`. Use this
            to override date ranges, thresholds, plotting options, etc.

        Returns
        -------
        SeaIceToolbox
            Toolbox instance bound to the shared Dask client and configured log file.

        Raises
        ------
        ValueError
            Raised by `SeaIceToolbox` when no Dask client is available (for example
            after `shutdown()` has been called).
        """
        return SeaIceToolbox(sim_name = sim_name,
                             client   = SeaIceToolboxManager._shared_client,  # explicit
                             P_log    = self.P_log,
                             **kwargs)
####################################################################################################################
class SeaIceToolbox(SeaIceClassification, SeaIceMetrics, SeaIcePlotter, SeaIceIcebergs, SeaIceObservations, SeaIceACCESS, SeaIceGridWork, SeaIceRegridder, SeaIceFast, SeaIceCICE, SeaIceWaves):
    """
    Unified AFIM toolbox for processing and analysing Antarctic sea ice from CICE.

    `SeaIceToolbox` composes several functional mixins into a single interface for
    reading CICE output, classifying fast/pack ice, computing metrics, regridding,
    working with grounded-iceberg masks, and producing plots.

    Inheritance / Composition
    -------------------------
    SeaIceToolbox inherits from the following classes:
    - SeaIceClassification : classification masks and simulation I/O helpers
    - SeaIceMetrics        : time-series and spatial metrics, skill statistics
    - SeaIcePlotter        : PyGMT maps and time series plotting utilities
    - SeaIceIcebergs       : grounded-iceberg thinning/masking and GI datasets
    - SeaIceObservations   : Fraser et al. (2020) and NSIDC observational utilities
    - SeaIceACCESS         : ACCESS-OM related helpers (where applicable)
    - SeaIceGridWork       : grid geometry, landmask application, hemisphere slicing
    - SeaIceRegridder      : B-grid → T-grid and swath/gridded regridding utilities
    - SeaIceFast           : imported from `sea_ice_fast` (see that module)
    - SeaIceCICE           : imported from `sea_ice_cice` (see that module)
    - SeaIceWaves          : imported from `sea_ice_waves` (see that module)

    Configuration
    -------------
    The toolbox is configured primarily by an AFIM JSON file. The constructor loads
    that JSON, sets commonly used directories (simulation output, zarr, metrics,
    figures), defines hemisphere behaviour, and stores thresholds used throughout
    the pipeline.

    Parameters
    ----------
    P_json : str or pathlib.Path, optional
        Path to the AFIM configuration JSON. If None, a project default path is used.
    P_CICE_grid : str or pathlib.Path, optional
        Override path to the CICE grid file; otherwise uses config entry
        `CICE_dict['P_G']`.
    sim_name : str
        Simulation name (must exist under the configured AFIM output root).
    dt0_str, dtN_str : str, optional
        Inclusive analysis window bounds in ``YYYY-MM-DD``.
    list_of_BorC2T : list[str], optional
        Speed/vector products to use (e.g., ["Tb"], ["Ta","Tx"], or ["B"]).
    iceh_frequency : {"hourly","daily","monthly","yearly"}, optional
        Which CICE history cadence to use for iceh inputs.
    ice_concentration_threshold : float, optional
        Concentration threshold used for masking / metrics (default from config;
        often 0.15).
    ice_speed_threshold : float, optional
        Speed threshold (m/s) below which ice is treated as “fast” (default from
        config).
    ice_type : str or list[str], optional
        Ice classification type(s) to process (e.g., "FI", "PI", "SI").
    mean_period : int, optional
        Rolling mean window length (days) used in some classification products.
    bin_win_days, bin_min_days : int, optional
        Binary-days window length and minimum count used for persistence
        classification.
    extra_cice_vars : list[str] or bool, optional
        Additional variables to include when loading/processing beyond
        `cice_vars_reqd`.
    hemisphere : str, optional
        Hemisphere selector. Common aliases are accepted (e.g., "south", "SH", "nh").
    P_log : str or pathlib.Path, optional
        Log file path to attach a `FileHandler` to.
    log_level : int or str, optional
        Python logging level for the toolbox logger.
    dask_memory_limit : str, optional
        Memory limit (informational here unless you create the client externally).
    overwrite_zarr : bool, optional
        If True, overwrite Zarr groups when writing classification/metrics products.
    overwrite_saved_figs : bool, optional
        If True, overwrite existing figure files.
    save_new_figs, show_figs : bool, optional
        Figure saving and interactive display toggles.
    delete_original_cice_iceh_nc : bool, optional
        If True, delete original NetCDF inputs after conversion to monthly Zarr
        (where implemented).
    client : dask.distributed.Client, optional
        Dask client to use. In this implementation, a client must already exist and
        be passed in (the constructor raises `ValueError` otherwise).
    force_recompile_ice_in : bool, default False
        Force regeneration of derived simulation metadata (e.g., parsing/rebuilding
        ice_in JSON).
    **kwargs
        Additional keyword arguments are forwarded to the mixin initialisers so
        mixins can accept specialised tunables without changing the constructor
        signature.

    Attributes Set (high-level)
    ---------------------------
    - Configuration dicts: `self.CICE_dict`, `self.GI_dict`, `self.NSIDC_dict`, etc.
    - Directory paths: `self.D_sim`, `self.D_zarr`, `self.D_metrics`, etc.
      (NOTE(review): these are presumably set by `define_toolbox_paths` — confirm.)
    - Hemisphere metadata: `self.hemispheres_dict` (all hemispheres, from config),
      `self.hemisphere_dict` (the selected one), `self.hemisphere`
    - Thresholds/settings: `self.ispd_thresh`, `self.icon_thresh`,
      `self.mean_period`, etc.
    - State flags: `self.grid_loaded`, `self.reG_weights_defined`, etc.

    Notes
    -----
    - The constructor performs substantial configuration and path normalisation and
      therefore has side effects (file IO, logger configuration).
    - Grids and regridders are not loaded/built until needed (`load_cice_grid`,
      `define_reG_weights`, etc.).
    """
def summary(self):
    """
    Write a short overview of the active configuration to the toolbox logger.

    Intended as a sanity check for interactive sessions and batch jobs: every
    key setting is emitted as one INFO record on `self.logger`.

    Notes
    -----
    - Returns nothing.
    - Requires `self.logger` and the core attributes (simulation name, date
      bounds, thresholds, toggles) to exist already.
    - The modified-landmask line is only emitted when `self.use_gi` is truthy.
    """
    def _report_lines():
        # Lazily yield each record so attributes are read in the same order,
        # and at the same moment, as they are logged.
        yield "--- SeaIceToolbox Summary ---"
        yield f"Simulation Name : {self.sim_name}"
        yield f"Analysis Start Date : {self.dt0_str}"
        yield f"Analysis End Date : {self.dtN_str}"
        yield f"grid file : {self.CICE_dict['P_G']}"
        yield f"landmask file : {self.P_KMT_org}"
        yield f"Using GI? : {self.use_gi}"
        if self.use_gi:
            yield f"modified landmask file: {self.P_KMT_mod}"
        yield f"Speed Threshold : {self.ispd_thresh:.1e} m/s"
        yield f"BorC-regrid Type(s) : {self.BorC2T_type}"
        yield f"Ice Type(s) : {self.ice_type}"
        yield f"Mean Period : {self.mean_period} days"
        yield f"Binary-days Window : {self.bin_win_days} days"
        yield f"Binary-days Min-Days: {self.bin_min_days}"
        yield f"Overwrite Zarr : {self.overwrite_zarr_group}"
        yield f"Save Figures : {self.save_fig}"
        yield f"Show Figures : {self.show_fig}"
        yield f"Hemisphere : {self.hemisphere}"
        yield "------------------------------"

    emit = self.logger.info
    for record in _report_lines():
        emit(record)
def __init__(self,
             P_json                       = None,   # AFIM JSON configuration file
             P_CICE_grid                  = None,   # CICE grid file; default comes from the JSON file
             sim_name                     = None,   # simulation name under the configured 'AFIM_out' directory
             dt0_str                      = None,   # analysis start, YYYY-MM-DD; default 1993-01-01
             dtN_str                      = None,   # analysis end, YYYY-MM-DD; default 1999-12-31
             list_of_BorC2T               = None,   # list drawn from ["B", "Ta", "Tb", "Tc", "Tx"]; default ["Tb"]
             iceh_frequency               = None,   # 'hourly', 'daily', 'monthly' or 'yearly'
             ice_concentration_threshold  = None,   # grid-cell concentration threshold; default 0.15
             ice_speed_threshold          = None,   # fast-ice speed threshold in m/s; default 5e-4
             ice_type                     = None,   # a valid ice_type or list thereof
             mean_period                  = None,   # rolling average, N days
             bin_win_days                 = None,   # binary-days window length
             bin_min_days                 = None,   # binary-days minimum number of days
             extra_cice_vars              = None,   # True -> config CICE_dict['cice_vars_ext']; list -> those variables
             hemisphere                   = None,   # 'south' or 'north' (aliases accepted); default 'south'
             P_log                        = None,   # log file path
             log_level                    = None,   # python logging level
             dask_memory_limit            = None,   # informational only; the client is created externally
             overwrite_zarr               = False,  # overwrite existing zarr groups
             overwrite_saved_figs         = False,  # overwrite saved figures
             save_new_figs                = True,   # write new figures to disk
             show_figs                    = False,  # show figures on screen
             delete_original_cice_iceh_nc = False,  # delete the original CICE ice history NetCDF
             client                       = None,   # dask distributed client, passed in externally
             force_recompile_ice_in       = False,  # re-parse simulation metadata; see parse_simulation_metadata()
             **kwargs):
    """
    Initialise the unified AFIM sea-ice toolbox for a given simulation/config.

    The constructor loads the AFIM JSON configuration file, normalises
    user-provided overrides (simulation name, date bounds, thresholds,
    hemisphere), defines key directories, parses simulation metadata, and
    initialises the composed mixins.

    Parameters
    ----------
    P_json : str or pathlib.Path, optional
        Path to the AFIM configuration JSON. If None, uses a project default path.
    P_CICE_grid : str or pathlib.Path, optional
        Optional override for the configured CICE grid file (`CICE_dict['P_G']`).
    sim_name : str, optional
        Simulation name; defaults to ``'test'`` when None.
    dt0_str, dtN_str : str, optional
        Inclusive analysis window bounds in ``YYYY-MM-DD``.
    list_of_BorC2T : list[str], optional
        Speed/vector product identifiers (e.g., ["Tb"], ["Ta","Tx"], ["B"]).
    iceh_frequency : {"hourly","daily","monthly","yearly"}, optional
        CICE history cadence to use.
    ice_concentration_threshold : float, optional
        Concentration threshold used in masking/metrics (config default 0.15).
    ice_speed_threshold : float, optional
        Speed threshold (m/s) used to classify fast ice (config default 5e-4).
    ice_type : str or list[str], optional
        Ice classification product(s) to work with ("FI", "PI", "SI", etc.).
    mean_period : int, optional
        Rolling mean window length (days).
    bin_win_days, bin_min_days : int, optional
        Binary-days window length and minimum count.
    extra_cice_vars : list[str], str or bool, optional
        Additional CICE variables beyond `cice_vars_reqd`. ``True`` appends the
        config list `CICE_dict['cice_vars_ext']`; a list (or single name) is
        appended as given; ``None``/``False`` uses only the required variables.
    hemisphere : str, optional
        Hemisphere selector (accepts common aliases).
    P_log : str or pathlib.Path, optional
        Log file path. Defaults to ``SeaIceToolbox_<sim_name>.log`` inside the
        configured `D_dict['logs']` directory.
    log_level : int or str, optional
        Logging verbosity level (default `logging.INFO`).
    dask_memory_limit : str, optional
        Accepted for backward compatibility; not used by this constructor.
    overwrite_zarr : bool, optional
        Overwrite Zarr groups when writing outputs.
    overwrite_saved_figs : bool, optional
        Overwrite existing figure files.
    save_new_figs, show_figs : bool, optional
        Figure saving and interactive display toggles.
    delete_original_cice_iceh_nc : bool, optional
        Delete original NetCDF after conversion to Zarr (where implemented).
    client : dask.distributed.Client, optional
        Dask client for computation. A client must be supplied here or already
        be present as `self.client`.
    force_recompile_ice_in : bool, default False
        Force regeneration/reparse of simulation metadata.
    **kwargs
        Forwarded unchanged to the mixin initialisers.

    Raises
    ------
    FileNotFoundError
        If the JSON configuration cannot be opened.
    ValueError
        If no Dask client is available, or if hemisphere is invalid.
    KeyError
        If required keys are missing from the JSON configuration.

    Side Effects
    ------------
    - Reads the JSON configuration and creates the log file (and its parent
      directory) when missing.
    - Configures `self.logger` via `setup_logging` unless one already exists.
    - Defines many path attributes and state flags.
    - Initialises mixin classes and logs a configuration summary.

    Notes
    -----
    - Numeric overrides are compared against ``None`` explicitly, so a
      deliberate ``0`` / ``0.0`` is honoured rather than replaced by the default.
    - When `sim_name` is the sentinel ``"__SI-toolbox-mgr__"`` the constructor
      returns right after the client has been attached.
    """
    import json

    def _pick(user_value, cfg_key, fallback):
        # Explicit None test: a falsy-but-valid override such as 0 must win.
        return user_value if user_value is not None else self.config.get(cfg_key, fallback)

    def _is_positive_number(value):
        # Cached metadata may hold None or a placeholder string such as "?".
        return isinstance(value, (int, float)) and value > 0

    # --- high-level administrative work ---------------------------------------
    self.methods_init_executed = {}
    self.sim_name = sim_name if sim_name is not None else 'test'
    if P_json is None:
        P_json = '/home/581/da1339/AFIM/src/AFIM/src/JSONs/sea_ice_config.json'
    with open(P_json, 'r') as f:
        self.config = json.load(f)
    hemisphere  = hemisphere or self.config.get('hemisphere', 'south')
    self.D_dict = self.config.get('D_dict', {})
    if P_log is None:
        # Only require the configured log directory when no path was supplied.
        P_log = Path(self.D_dict['logs'], f'SeaIceToolbox_{self.sim_name}.log')
    log_level = log_level if log_level is not None else logging.INFO
    P_log_path = Path(P_log)
    if not P_log_path.exists():
        # Create the file without going through a shell.
        P_log_path.parent.mkdir(parents=True, exist_ok=True)
        P_log_path.touch()
    if not hasattr(self, 'logger'):
        self.setup_logging(logfile=P_log, log_level=log_level)

    # --- dask client ------------------------------------------------------------
    if client is not None:
        self.client = client
    if getattr(self, "client", None) is None:
        raise ValueError("Dask client must be provided explicitly or via manager.")
    # Query the attached client (not the possibly-None argument) exactly once.
    workers            = list(self.client.scheduler_info()['workers'].values())
    threads_per_worker = [w['nthreads'] for w in workers]
    total_memory_gb    = sum(w['memory_limit'] for w in workers) / 1e9
    self.logger.info(f"Dask Client Connected\n"
                     f" Dashboard : {self.client.dashboard_link}\n"
                     f" Threads : {sum(threads_per_worker)}\n"
                     f" Threads/Worker : {threads_per_worker}\n"
                     f" Total Memory : {total_memory_gb:.2f} GB\n")
    if sim_name == "__SI-toolbox-mgr__":
        return

    # --- technical and sim-specific configuration -------------------------------
    self.class_types_dict = self.config.get("class_types_dict", {})
    self.CICE_dict        = self.config.get("CICE_dict"       , {})
    self.GI_dict          = self.config.get('GI_dict'         , {})
    self.Waves_dict       = self.config.get('Waves_dict'      , {})
    self.NSIDC_dict       = self.config.get('NSIDC_dict'      , {})
    self.BAS_dict         = self.config.get('BAS_dict'        , {})
    self.AF_FI_dict       = self.config.get("AF_FI_dict"      , {})
    self.Sea_Ice_Obs_dict = self.config.get("Sea_Ice_Obs_dict", {})
    self.AOM2_dict        = self.config.get("AOM2_dict"       , {})
    self.MOM_dict         = self.config.get("MOM_dict"        , {})
    self.ERA5_dict        = self.config.get("ERA5_dict"       , {})
    self.ORAS_dict        = self.config.get("ORAS_dict"       , {})
    self.plot_var_dict    = self.config.get("plot_var_dict"   , {})
    self.hemispheres_dict = self.config.get("hemispheres_dict", {})
    self.Ant_8sectors     = self.config.get('Ant_8sectors'    , {})
    self.Ant_2sectors     = self.config.get('Ant_2sectors'    , {})
    self.pygmt_dict       = self.config.get("pygmt_dict"      , {})
    self.pygmt_FIA_dict   = self.config.get('pygmt_FIA_dict'  , {})
    self.pygmt_FI_panel   = self.config.get('pygmt_FI_panel'  , {})
    # Strings / lists: an empty value falls back to the configured default.
    self.dt0_str          = dt0_str        or self.config.get('dt0_str'    , '1993-01-01')
    self.dtN_str          = dtN_str        or self.config.get('dtN_str'    , '1999-12-31')
    self.BorC2T_type      = list_of_BorC2T or self.config.get('BorC2T_type', ['Tb'])
    self.ice_type         = ice_type       or self.config.get('ice_type'   , 'FI')
    self.iceh_freq        = iceh_frequency or self.config.get('iceh_freq'  , 'daily')
    # Numbers: only None falls back to the configured default.
    self.ispd_thresh      = _pick(ice_speed_threshold        , 'ice_speed_thresh_hi', 5.0e-4)
    self.mean_period      = _pick(mean_period                , 'mean_period'        , 15)
    self.bin_win_days     = _pick(bin_win_days               , 'bin_win_days'       , 11)
    self.bin_min_days     = _pick(bin_min_days               , 'bin_min_days'       , 9)
    self.icon_thresh      = _pick(ice_concentration_threshold, 'ice_conc_thresh'    , 0.15)
    self.CICE_dict['P_G'] = P_CICE_grid or self.CICE_dict['P_G']
    self.overwrite_zarr_group = overwrite_zarr
    self.ow_fig               = overwrite_saved_figs
    self.save_fig             = save_new_figs
    self.show_fig             = show_figs
    self.del_org_cice_iceh_nc = delete_original_cice_iceh_nc
    self.leap_year            = self.config.get("leap_year"         , 1996)
    self.metrics_name         = self.config.get("metrics_name"      , "mets")
    self.valid_BorC2T_types   = self.config.get("valid_BorC2T_types", [])
    self.valid_ice_types      = self.config.get("valid_ice_types"   , [])
    self.cice_vars_reqd       = self.CICE_dict["cice_vars_reqd"]
    self.spatial_dims         = self.CICE_dict["spatial_dims"]
    self.define_hemisphere(hemisphere)
    self.define_toolbox_paths()
    self.parse_simulation_metadata(force_recompile=force_recompile_ice_in)

    # --- CICE variable list -----------------------------------------------------
    if extra_cice_vars is None or extra_cice_vars is False:
        self.cice_var_list = self.cice_vars_reqd
    elif extra_cice_vars is True:
        self.cice_var_list = self.cice_vars_reqd + self.CICE_dict["cice_vars_ext"]
    elif isinstance(extra_cice_vars, str):
        self.cice_var_list = self.cice_vars_reqd + [extra_cice_vars]
    else:
        self.cice_var_list = self.cice_vars_reqd + list(extra_cice_vars)
    self.FIC_scale = self.config.get('FIC_scale', 1e9)
    self.SIC_scale = self.config.get('SIC_scale', 1e12)

    # --- landmask / grounded-iceberg configuration ------------------------------
    if self.CICE_dict['coupled']:
        self.P_KMT_org = Path(self.CICE_dict['P_KMT'])
    else:
        self.P_KMT_org = Path(self.GI_dict["D_GI_thin"], self.GI_dict['KMT_org_fmt'])
    if self.sim_config is not None:
        self.GI_thin      = self.sim_config.get('GI_thin_fact')
        self.GI_version   = self.sim_config.get('GI_version')
        self.GI_iteration = self.sim_config.get("GI_iter")
        if _is_positive_number(self.GI_thin) and _is_positive_number(self.GI_version):
            self.use_gi = True
            GI_thin_str = f"{self.GI_thin:0.2f}".replace('.', 'p')
            GI_vers_str = f"{self.GI_version:0.2f}".replace('.', 'p')
            self.P_KMT_mod = os.path.join(self.GI_dict['D_GI_thin'],
                                          self.GI_dict['KMT_mod_fmt'].format(GI_thin   = GI_thin_str,
                                                                             version   = GI_vers_str,
                                                                             iteration = self.GI_iteration))
            self.P_GI_thin = os.path.join(self.GI_dict['D_GI_thin'],
                                          self.GI_dict['GI_thin_fmt'].format(GI_thin   = GI_thin_str,
                                                                             version   = GI_vers_str,
                                                                             iteration = self.GI_iteration))
        else:
            # Missing, non-numeric or non-positive GI metadata: run without GI.
            self.P_KMT_mod = self.P_KMT_org
            self.use_gi    = False
    else:
        self.GI_thin      = None
        self.GI_version   = None
        self.GI_iteration = None
        self.use_gi       = None
    self.reG_weights_defined       = False
    self.modified_landmask_aligned = False
    self.grid_loaded               = False

    # --- mixin initialisation ---------------------------------------------------
    # NOTE(review): the raw `sim_name` argument (possibly None) is passed here,
    # not `self.sim_name` — confirm this is what SeaIceClassification expects.
    SeaIceClassification.__init__(self, sim_name, **kwargs)
    SeaIceMetrics.__init__(self, **kwargs)
    SeaIcePlotter.__init__(self, **kwargs)
    SeaIceIcebergs.__init__(self, **kwargs)
    SeaIceObservations.__init__(self, **kwargs)
    SeaIceACCESS.__init__(self, **kwargs)
    self.summary()

#######################################################################################################
#########################################       LOGGING       #########################################
#######################################################################################################
def setup_logging(self, logfile=None, log_level=logging.INFO):
    """
    Configure the toolbox logger with a console handler and optional file handler.

    This method fetches the named logger ``"sea_ice_classification"``, stores it
    on `self.logger`, sets its level, and ensures that (a) one console
    `StreamHandler` exists and (b) exactly one `FileHandler` is attached when
    `logfile` is provided.

    Parameters
    ----------
    logfile : str or pathlib.Path, optional
        File to write log messages to (opened in append mode). If None, only
        console logging is configured and existing file handlers are left alone.
    log_level : int, default logging.INFO
        Logging level for the logger and for the handlers created here.

    Notes
    -----
    - When `logfile` is given, existing FileHandlers are removed and closed
      before the new one is attached.
    - `logging.FileHandler` is a subclass of `logging.StreamHandler`; file
      handlers are therefore excluded explicitly when checking whether a
      console handler already exists.
    - `self.logger.propagate` is set False to avoid duplicate logs when parent
      loggers are configured elsewhere.
    """
    logger_name = "sea_ice_classification"
    self.logger = logging.getLogger(logger_name)
    self.logger.setLevel(log_level)
    self.logger.propagate = False
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    def _is_console_handler(handler):
        return isinstance(handler, logging.StreamHandler) and not isinstance(handler, logging.FileHandler)

    # === Remove any old file handlers pointing to other files ===
    if logfile:
        for h in list(self.logger.handlers):
            if isinstance(h, logging.FileHandler):
                self.logger.removeHandler(h)
                h.close()
    # === Add console handler if none exists ===
    if not any(_is_console_handler(h) for h in self.logger.handlers):
        ch = logging.StreamHandler()
        ch.setFormatter(formatter)
        ch.setLevel(log_level)
        self.logger.addHandler(ch)
    # === Add (new) file handler ===
    if logfile:
        fh = logging.FileHandler(logfile, mode='a')  # always attach new one
        fh.setFormatter(formatter)
        fh.setLevel(log_level)
        self.logger.addHandler(fh)
        self.logger.info(f"log file connected: {logfile}")
@contextmanager
def _suppress_large_graph_warning(self):
    """
    Context manager that silences Dask's "Sending large graph of size ..." warning.

    Inside the block, `UserWarning`s whose message matches that text and that
    are issued from a module matching ``distributed\\.client`` are ignored. The
    previous warning filters are restored on exit.
    """
    ignore_spec = {"message"  : r"Sending large graph of size.*",
                   "category" : UserWarning,
                   "module"   : r"distributed\.client"}
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", **ignore_spec)
        yield

def _method_name(self):
    """Return the name of the function or method that called this helper."""
    import inspect
    caller_frame = inspect.currentframe().f_back
    return caller_frame.f_code.co_name

#######################################################################################################
#########################################     TRIGONOMETRY     ########################################
#######################################################################################################
def radians_to_degrees(self, da):
    """
    Convert an angle (or array of angles) from radians to degrees.

    Parameters
    ----------
    da : array-like
        Angle(s) in radians; anything supporting ``*`` and ``/`` with scalars
        (NumPy array, xarray DataArray, or plain number).

    Returns
    -------
    same type as `da`
        The input multiplied by 180 and then divided by pi.
    """
    half_turn_in_degrees = 180
    scaled = da * half_turn_in_degrees
    return scaled / np.pi
def cosine_vector_similarity(self, uo, vo, um, vm, eps=1e-12):
    """
    Cosine similarity between an observed and a modelled vector field.

    The result is the dot product of the two vectors divided by the product of
    their magnitudes, so it measures directional agreement only:

    - +1.0 : vectors point the same way,
    -  0.0 : vectors are orthogonal,
    - -1.0 : vectors point in opposite directions.

    Parameters
    ----------
    uo, vo : xarray.DataArray
        Components of the observed vector field.
    um, vm : xarray.DataArray
        Components of the modelled vector field (same units as `uo`, `vo`).
    eps : float, optional
        Lower bound on the product of the two magnitudes; below it the
        denominator is replaced by NaN. Default is 1e-12.

    Returns
    -------
    xarray.DataArray
        Dimensionless cosine similarity.

    Notes
    -----
    - The result is NaN wherever ``|obs| * |mod| < eps``.
    - The metric is scale-invariant: it compares direction, not speed.
    """
    numerator    = um * uo + vm * vo
    norm_obs     = np.sqrt(uo**2 + vo**2)
    norm_mod     = np.sqrt(um**2 + vm**2)
    norm_product = norm_obs * norm_mod
    # Mask (near-)zero denominators so the division yields NaN instead of inf.
    safe_denominator = xr.where(norm_product < eps, np.nan, norm_product)
    return numerator / safe_denominator
def vector_angle_diff(self, uo, vo, um, vm):
    """
    Signed angular difference (radians) between modelled and observed vectors.

    Each vector's direction is obtained with `arctan2`; the model-minus-
    observation difference is then wrapped into one full turn centred on zero:

    - 0      : directions agree,
    - +pi/2  : model rotated 90° counter-clockwise from the observation,
    - -pi/2  : model rotated 90° clockwise from the observation,
    - -pi    : vectors are anti-parallel.

    Parameters
    ----------
    uo, vo : xarray.DataArray
        Components of the observed vector field.
    um, vm : xarray.DataArray
        Components of the modelled vector field.

    Returns
    -------
    xarray.DataArray
        Wrapped difference (model - obs) in radians, in the half-open interval
        [-pi, pi): an exact difference of +pi is reported as -pi.

    Notes
    -----
    - Use `np.rad2deg()` to convert the output to degrees, if desired.
    """
    heading_obs = np.arctan2(vo, uo)
    heading_mod = np.arctan2(vm, um)
    full_turn   = 2 * np.pi
    raw_delta   = heading_mod - heading_obs
    return (raw_delta + np.pi) % full_turn - np.pi
####################################################################################################### ##################################### SEA-ICE ANALYSIS CONFIG ######################################### #######################################################################################################
def define_hemisphere(self, hemisphere):
    """
    Initialise hemisphere configuration used across grid slicing and plotting.

    Accepts common hemisphere aliases (case-insensitive) and selects the
    matching entry of `self.hemispheres_dict`. The entry's `nj_slice`, given in
    the configuration as a (start, stop) pair, is converted to a Python `slice`.

    Parameters
    ----------
    hemisphere : str
        Hemisphere selector. Accepted aliases are:
        - North: "north", "northern", "nh", "n", "no"
        - South: "south", "southern", "sh", "s", "so"

    Sets
    ----
    hemisphere_dict : dict
        The selected entry of `self.hemispheres_dict` (the same object, not a
        copy), with `nj_slice` stored as a Python `slice`.
    hemisphere : str
        The lower-cased alias that was passed in (e.g. "sh" for "SH").

    Raises
    ------
    ValueError
        If `hemisphere` is not a string or does not match any accepted alias.

    Notes
    -----
    - The method is safe to call repeatedly: an `nj_slice` that has already
      been converted to a `slice` is left untouched.
    - An INFO record with `self.hemisphere_dict['abbreviation']` is logged.
    """
    north_aliases = ('north', 'northern', 'nh', 'n', 'no')
    south_aliases = ('south', 'southern', 'sh', 's', 'so')
    key = hemisphere.lower() if isinstance(hemisphere, str) else None
    if key in north_aliases:
        self.hemisphere_dict = self.hemispheres_dict['north']
    elif key in south_aliases:
        self.hemisphere_dict = self.hemispheres_dict['south']
    else:
        raise ValueError(f"Invalid hemisphere '{hemisphere}'. Valid options are: "
                         "['north', 'south', 'northern', 'southern', 'sh', 'nh', 'SH', 'NH']")
    nj_bounds = self.hemisphere_dict['nj_slice']
    if not isinstance(nj_bounds, slice):
        # First use of this entry: the config stores a (start, stop) pair.
        start, stop = nj_bounds
        self.hemisphere_dict['nj_slice'] = slice(start, stop)
    self.hemisphere = key
    self.logger.info(f"hemisphere initialised: {self.hemisphere_dict['abbreviation']}")
def interpret_ice_speed_threshold(self, ispd_thresh=None, lat_thresh=-60):
    """
    Translate the ice speed threshold into intuitive grid-scale metrics.

    Computes:
    - metres travelled per day at the given speed,
    - the **median** grid-cell edge length (square root of `tarea`) over cells
      whose latitude is below `lat_thresh` on the grid `self.CICE_dict['P_G']`,
    - displacement per day as a fraction of that grid-cell length,
    - days required to traverse one grid cell at the threshold speed.

    Parameters
    ----------
    ispd_thresh : float, optional
        Threshold in m/s. Defaults to ``self.ispd_thresh``.
    lat_thresh : float, default -60
        Latitude (degrees); only cells with latitude below it are used.

    Returns
    -------
    dict
        Summary metrics with keys ``'ice_speed_thresh_m_per_s'``,
        ``'displacement_m_per_day'``, ``'median_grid_cell_length_m'``,
        ``'percent_displacement_per_day'`` and ``'days_per_grid_cell'``.
        Note that ``'percent_displacement_per_day'`` holds a *fraction*
        (it is multiplied by 100 only in the log message).

    Raises
    ------
    ValueError
        If no finite grid-cell length is found below `lat_thresh`.

    Notes
    -----
    - Reads `tarea` and `tlat` from the grid file; `tlat` is passed through
      `self.radians_to_degrees` before masking.
    - The grid file is opened once and closed as soon as the values are read.
    - Logs a human-readable summary via `self.logger`.
    """
    ispd_thresh = ispd_thresh if ispd_thresh is not None else self.ispd_thresh
    m_per_day   = ispd_thresh * 86400  # metres/day
    # Materialise the masked areas inside the context so the file can be closed.
    with xr.open_dataset(self.CICE_dict["P_G"]) as grid_ds:
        area_da   = grid_ds['tarea']                            # [m^2]
        lat_da    = self.radians_to_degrees(grid_ds['tlat'])    # [degrees]
        mask      = lat_da < lat_thresh
        area_vals = area_da.where(mask).values
    grid_lengths = np.sqrt(area_vals)
    grid_lengths = grid_lengths[np.isfinite(grid_lengths)]
    if len(grid_lengths) == 0:
        raise ValueError("No valid grid cells found south of the specified latitude.")
    GC_len_median = np.median(grid_lengths)
    pct_GC_disp   = m_per_day / GC_len_median
    days_per_GC   = GC_len_median / m_per_day
    self.logger.info(f"Ice speed threshold : {ispd_thresh:.1e} m/s → {m_per_day:.1f} m/day")
    self.logger.info(f"Median grid cell length below {lat_thresh}°: {GC_len_median:.1f} m")
    self.logger.info(f"→ Displacement = {pct_GC_disp*100:.2f}% of grid cell per day")
    self.logger.info(f"→ Days to fully traverse one grid cell : {days_per_GC:.2f} days")
    return {"ice_speed_thresh_m_per_s"    : ispd_thresh,
            "displacement_m_per_day"      : m_per_day,
            "median_grid_cell_length_m"   : GC_len_median,
            "percent_displacement_per_day": pct_GC_disp,
            "days_per_grid_cell"          : days_per_GC}
def _check_BorC2T_type(self,BorC2T_type): if isinstance(BorC2T_type, str): BorC2T_type = [BorC2T_type] assert all(v in self.valid_BorC2T_types for v in BorC2T_type), f"Invalid BorC2T_type: {BorC2T_type}" def _check_ice_type(self,ice_type): if isinstance(ice_type, str): ice_type = [ice_type] assert all(v in self.valid_ice_types for v in ice_type), f"Invalid ice_type: {ice_type}" ########################################################################################################## ################################## CICE MODEL CONFIGURATION/DIAGNOSTIRICS ################################ ##########################################################################################################
def parse_simulation_metadata(self, force_recompile=False):
    """
    Parse or load simulation metadata for CICE AFIM runs.

    If the cached JSON file ``self.P_sim_cfg`` exists and `force_recompile`
    is False, the metadata are loaded from it. Otherwise the first 501 lines
    of ``self.P_ice_diag`` are scanned for a fixed set of ``key = value``
    entries, grounded-iceberg (``GI_*``) fields are derived from the
    ``kmt_file`` name, and the result is written to ``self.P_sim_cfg``.

    Parameters
    ----------
    force_recompile : bool, default False
        If True, re-parse the raw diagnostics file even if a cached JSON exists.

    Returns
    -------
    bool
        True when ``self.sim_config`` was loaded or parsed; False when the
        diagnostics file could not be read (``self.sim_config`` is then None).

    Notes
    -----
    - All parsed values are stored as strings; keys not found stay ``""``.
    - For a ``kmt_file`` name containing ``"kmt_mod_thinned-"``:
      ``GI_thin_fact`` and ``GI_version`` are floats (``"?"`` if they cannot
      be parsed) and ``GI_iter`` is an int (None if absent or unparsable).
      For any other name the three fields are ``0.0``, ``0.0`` and ``0``.
    """
    import json
    import re
    from itertools import islice

    # Fast path: reuse the cached JSON unless a re-parse was requested.
    if self.P_sim_cfg.exists() and not force_recompile:
        self.logger.info(f"loading cached simulation metadata from {self.P_sim_cfg}")
        with open(self.P_sim_cfg, "r") as f:
            self.sim_config = json.load(f)
        return True

    self.logger.info(f"reading {self.P_ice_diag} to construct {self.P_sim_cfg}")
    MAX_SCAN_LINES = 501  # only the head of the file is scanned (line indices 0..500)
    PARAM_KEYS = ["dt", "ndtd", "ndte", "kdyn", "revised_evp", "e_yieldcurve", "e_plasticpot",
                  "Ktens", "kstrength", "Pstar", "Cstar", "Cf", "visc_method", "kmt_file"]
    # Most entries look like "key = value : description"; kmt_file runs to end of line.
    PATTERNS = {key: re.compile(rf"{key}\s*=\s*(.+?)\s*:") if key != "kmt_file"
                     else re.compile(rf"{key}\s*=\s*(.+)$")
                for key in PARAM_KEYS}
    result = {key: "" for key in PARAM_KEYS}
    try:
        with open(self.P_ice_diag, "r", encoding="utf-8", errors="replace") as f:
            for line in islice(f, MAX_SCAN_LINES):
                for key, pattern in PATTERNS.items():
                    if result[key] != "":
                        continue  # first non-empty occurrence wins
                    match = pattern.search(line)
                    if match is None:
                        continue
                    val = match.group(1).strip()
                    if key == "kmt_file":
                        val = Path(val).name  # keep only the file name
                    result[key] = val
    except OSError as err:
        self.logger.warning(f"could not read {self.P_ice_diag}: {err}")
        self.sim_config = None
        return False

    # Additional metadata derived from the kmt file name
    kmt_name = result["kmt_file"]
    if "kmt_mod_thinned-" in kmt_name:
        try:
            thin_str = kmt_name.split("thinned-")[1].split("_")[0]  # e.g., "0p85"
            result["GI_thin_fact"] = float(thin_str.replace("p", "."))
        except (IndexError, ValueError):
            result["GI_thin_fact"] = "?"
        try:
            version_match = re.search(r"_v(\d+p\d+)", kmt_name)
            if version_match:
                version_str = version_match.group(1)  # e.g., '1p50'
                result["GI_version"] = float(version_str.replace("p", "."))
            else:
                self.logger.info(f"[DEBUG] No version match in: {kmt_name}")
                result["GI_version"] = "?"
        except ValueError as e:
            self.logger.info(f"[ERROR] GI_version failed for {kmt_name}: {e}")
            result["GI_version"] = "?"
        try:
            if "iter" in kmt_name:
                iter_str = kmt_name.split("iter")[1].split(".")[0]  # e.g., "0"
                result["GI_iter"] = int(iter_str)
            else:
                result["GI_iter"] = None
        except (IndexError, ValueError):
            result["GI_iter"] = None
    else:
        result["GI_thin_fact"] = 0.0
        result["GI_version"] = 0.0
        result["GI_iter"] = 0

    # Save to JSON so later calls can take the fast path
    self.logger.info(f"[CHECK] Parsed GI_version = {result.get('GI_version')} from {kmt_name}")
    with open(self.P_sim_cfg, "w") as f:
        json.dump(result, f, indent=2)
    self.sim_config = result
    return True
########################################################################################################## ############################# XARRAY/NUMPY DATASET/ARRAY EXTRACTION/WORK ########################## ##########################################################################################################
def align_time_coordinate_of_three_arrays(self, ds1, ds2, ds3, time_coord="time"):
    """
    Restrict three xarray objects to the dates they have in common.

    Each object's `time_coord` is normalised to midnight, the intersection
    of the three normalised coordinates is computed, and each object is
    selected on that intersection.

    Parameters
    ----------
    ds1, ds2, ds3 : xarray.Dataset or xarray.DataArray
        Objects carrying a datetime-like coordinate named `time_coord`.
    time_coord : str, default "time"
        Name of the coordinate to normalise and align on.

    Returns
    -------
    tuple
        ``(ds1, ds2, ds3)`` restricted to the common (sorted, unique) dates.

    Notes
    -----
    The inputs are not modified; normalised copies are created with
    ``assign_coords``. Selection uses `time_coord` rather than a hard-coded
    ``"time"`` so that non-default coordinate names work.
    """
    a, b, c = (obj.assign_coords({time_coord: pd.to_datetime(obj[time_coord].values).normalize()})
               for obj in (ds1, ds2, ds3))
    t_common = np.intersect1d(np.intersect1d(a[time_coord].values, b[time_coord].values),
                              c[time_coord].values)
    return a.sel({time_coord: t_common}), b.sel({time_coord: t_common}), c.sel({time_coord: t_common})
def dict_to_ds(self, data_dict):
    """
    Build an ``xarray.Dataset`` from a name -> array mapping.

    Parameters
    ----------
    data_dict : dict
        Mapping of variable name to the object handed to ``xarray.Dataset``
        for that variable.

    Returns
    -------
    xarray.Dataset
        Dataset whose variables are the entries of `data_dict`.
    """
    variables = {}
    for name, values in data_dict.items():
        variables[name] = values
    return xr.Dataset(variables)
def create_empty_valid_DS_dictionary(self, valid_zarr_DS_list=None):
    """
    Create a nested-dictionary template for collecting objects by category.

    Parameters
    ----------
    valid_zarr_DS_list : list[str], optional
        Inner keys to initialise for every outer key. Defaults to
        ``self.valid_ice_types``.

    Returns
    -------
    collections.defaultdict
        Accessing a missing outer key creates
        ``{<inner_key_1>: [], <inner_key_2>: [], ...}`` with a fresh empty
        list per inner key.
    """
    from collections import defaultdict

    if valid_zarr_DS_list is None:
        inner_keys = self.valid_ice_types
    else:
        inner_keys = valid_zarr_DS_list

    def _fresh_entry():
        # A new list per key so that entries never share state.
        entry = {}
        for key in inner_keys:
            entry[key] = []
        return entry

    return defaultdict(_fresh_entry)
def _to_float_scalar(self, x): """ Convert a scalar-like value (NumPy/xarray) to a Python float. Intended for outputs of reductions (e.g., `.sum()`, `.max()`) or `dask.compute(...)` that yield 0-D arrays. Parameters ---------- x : Any NumPy scalar, 0-D ndarray, or 0-D xarray object. Returns ------- float The scalar as a Python float (including NaN if present). Notes ----- If a higher-dimensional array is passed accidentally, `float(np.asarray(x))` will raise; callers should only pass scalar-like values. """ # Accept numpy scalars, xarray 0-d arrays, etc. if hasattr(x, "values"): x = x.values try: return float(x.item()) # numpy scalar or 0-d array except Exception: return float(np.asarray(x)) def _get_first(self, ds: xr.Dataset, names) -> Optional[xr.DataArray]: """ Return the first matching variable or coordinate from a dataset. Parameters ---------- ds : xarray.Dataset Dataset to search. names : iterable of str Candidate names to check in order. Returns ------- xarray.DataArray or None The first match found in ds.variables or ds.coords, else None. """ for n in names: if n in ds.variables: return ds[n] if n in ds.coords: return ds.coords[n] return None def _has(self, ds, var): """ Return whether the wrapped input contains a variable/key/attribute named ``var``. This helper provides a uniform “does it exist?” check across common container types encountered in the sea-ice workflow (e.g., ``xarray.Dataset`` outputs, dict-like objects, or lightweight objects with attributes). Lookup precedence ----------------- 1. If ``ds`` is an ``xarray.Dataset``: check ``var`` in ``ds.data_vars``. 2. Otherwise attempt membership: ``var in ds`` (for dict-like / list-like containers). 3. If membership is not supported (raises ``TypeError``): fall back to ``hasattr(I_data, var)``. Parameters ---------- ds : xr.Dataset dataset to test var var : str variable name to test for. Returns ------- bool ``True`` if ``var`` is present under the rules above, otherwise ``False``. 
Notes ----- - For ``xarray.Dataset``, this checks *data variables only* (not coordinates). If you want coordinates too, consider checking ``(var in I_data.variables)``. - The function relies on an outer-scope variable ``ds`` (closure). If you refactor to store data on the instance, replace ``ds`` with ``self.ds``. """ if isinstance(ds, xr.Dataset): return var in ds.data_vars try: return var in ds except TypeError: return hasattr(ds, var) def _get(self, ds, var): """ Retrieve a variable/key/attribute named ``var`` from the wrapped input. This helper provides a uniform accessor across common container types (e.g., ``xarray.Dataset`` outputs, dict-like objects, or objects with attributes). Lookup precedence ----------------- 1. If ``ds`` is an ``xarray.Dataset``: return ``ds[var]`` (typically an ``xarray.DataArray`` or variable-like object). 2. Otherwise attempt key access: ``ds[var]``. 3. If key access fails for any reason: fall back to attribute access ``getattr(ds, var)``. Parameters ---------- var : str Variable name to retrieve. Returns ------- Any The retrieved object. For ``xarray.Dataset`` this is usually an ``xarray.DataArray``. Raises ------ KeyError If ``var`` is not a key in a mapping-like ``ds`` and attribute fallback does not exist. AttributeError If key access fails and ``ds`` does not have attribute ``var``. Exception Any exception raised by ``ds[var]`` may be swallowed and replaced by the attribute fallback attempt. Notes ----- - Because the fallback catches *all* exceptions from ``ds[var]``, genuine indexing errors (not just missing keys) will be masked. If you only want to fall back on missing keys, narrow the exception to ``(KeyError, TypeError)``. - The function relies on an outer-scope variable ``ds`` (closure). If you refactor to store data on the instance, replace ``ds`` with ``self.ds``. 
""" if isinstance(ds, xr.Dataset): return ds[var] try: return ds[var] except Exception: return getattr(ds, var) ########################################################################################################## ############################# MASKING ########################## ########################################################################################################## @staticmethod def _expand_bounds_slice(slc: slice, nb_len: int) -> slice: """ Convert a centre-row slice (length nb_len-1) to a corner-row slice (length nb_len) by expanding the stop bound by +1, clipped to nb_len. """ if not isinstance(slc, slice): raise TypeError(f"hemisphere nj_slice must be a Python slice, got {type(slc)}") start = 0 if slc.start is None else slc.start stop_nj = nb_len - 1 # last valid nj index is nb_len-2 stop = stop_nj if slc.stop is None else slc.stop step = slc.step stop_b = min(stop + 1, nb_len) # include boundary row at stop return slice(start, stop_b, step) @staticmethod def _indexers_for(obj, y_dim: str, nj_slice: slice): """ Build an `.isel()` indexer dict for an object with optional centre/corner dims. - Applies `nj_slice` to `y_dim` if present. - If a matching corner dim `'nj_b'` exists, expands the slice using `_expand_bounds_slice()` to include the boundary row. """ idx = {} dims = getattr(obj, "dims", {}) if y_dim in dims: idx[y_dim] = nj_slice if "nj_b" in dims: nb_len = obj.sizes["nj_b"] idx["nj_b"] = SeaIceToolbox._expand_bounds_slice(nj_slice, nb_len) return idx
def slice_hemisphere(self, var):
    """
    Apply the configured hemisphere slice to a Dataset/DataArray (and corners).

    ``self.hemisphere_dict['nj_slice']`` is applied along the cell-centre row
    dimension named by ``self.CICE_dict["y_dim"]``. When the object also has
    an ``'nj_b'`` dimension, that dimension is sliced with the stop bound
    expanded by one row (see ``_indexers_for``).

    Parameters
    ----------
    var : xr.Dataset | xr.DataArray | dict
        Object to slice. For a dict, Dataset/DataArray values are sliced and
        every other value is passed through unchanged.

    Returns
    -------
    xr.Dataset | xr.DataArray | dict
        Sliced object of the same kind as the input; objects with neither
        dimension are returned as-is.

    Raises
    ------
    ValueError
        If `var` is not a dict, Dataset, or DataArray.
    """
    y_dim = self.CICE_dict["y_dim"]              # typically 'nj'
    nj_slice = self.hemisphere_dict['nj_slice']  # a Python slice

    def _apply(obj):
        indexers = SeaIceToolbox._indexers_for(obj, y_dim, nj_slice)
        if not indexers:
            return obj
        return obj.isel(indexers)

    if isinstance(var, dict):
        out = {key: (_apply(member) if isinstance(member, (xr.Dataset, xr.DataArray)) else member)
               for key, member in var.items()}
        which = f"{y_dim} and/or nj_b"
        self.logger.debug(f"Hemisphere slice applied on dict members where dims matched ('{which}').")
        return out

    if not isinstance(var, (xr.Dataset, xr.DataArray)):
        raise ValueError(f"Unsupported input type: {type(var)}. Must be dict, Dataset, or DataArray.")

    sliced = _apply(var)
    var_dims = getattr(var, "dims", {})
    matched = [d for d in (y_dim, "nj_b") if d in var_dims]
    which = " & ".join(matched) or "none"
    self.logger.debug(f"Hemisphere slice applied on dims: {which}.")
    return sliced
def _region_mask(self, lon : xr.DataArray, lat : xr.DataArray, geo_reg : tuple[float, float, float, float], *, right_open: bool = True) -> xr.DataArray: """ Geographic mask for [lon_min, lon_max, lat_min, lat_max]. right_open=True makes lon_max exclusive (helps avoid double-counting at boundaries). """ lon_min, lon_max, lat_min, lat_max = geo_reg m = (lon >= lon_min) & (lat >= lat_min) & (lat <= lat_max) if right_open: m = m & (lon < lon_max) else: m = m & (lon <= lon_max) return m ########################################################################################################## #################################### PATH/DIRECTORY DEFINITIONS ######################################### ##########################################################################################################
def count_zarr_files(self, path):
    """
    Count and log the number of files under a directory tree (e.g., a Zarr store).

    Parameters
    ----------
    path : str or pathlib.Path
        Root directory walked recursively with ``os.walk``.

    Notes
    -----
    Only filesystem files are counted (directories are not), so the number
    reflects on-disk entries rather than logical Zarr arrays. The count is
    logged at INFO level; nothing is returned.
    """
    n_files = 0
    for _dirpath, _dirnames, filenames in os.walk(path):
        n_files += len(filenames)
    self.logger.info(f"{path} contains {n_files} files")
def get_dir_size(self, path):
    """
    Compute and log the total on-disk size of a directory tree.

    Parameters
    ----------
    path : str or pathlib.Path
        Directory to scan recursively. Strings are converted to ``Path``
        internally, so both forms are accepted.

    Notes
    -----
    - Sizes come from ``stat().st_size`` of every regular file found by
      ``Path.rglob("*")``; this may be slow on large trees.
    - The total is divided by ``1024**3`` (GiB) and logged at INFO level with
      a "GB" label; nothing is returned.
    """
    root = Path(path)  # accept plain strings as documented
    n_bytes = sum(f.stat().st_size for f in root.rglob("*") if f.is_file())
    size_gb = n_bytes / (1024**3)
    self.logger.info(f"Disk-usage (size) of directory {path}: {size_gb:.2f} GB")
@staticmethod def _class_method_key(class_method: str) -> str: lut = {"raw" : "raw", "daily" : "raw", "binary-days" : "bin", "binary_days" : "bin", "bin" : "bin", "rolling-mean": "roll", "rolling_mean": "roll", "roll" : "roll"} if class_method not in lut: raise ValueError(f"Unsupported class_method: {class_method}") return lut[class_method] def _resolve_product_store(self, ice_type, class_method = "raw", product = "data", D_zarr = None, BorC2T_type = None, ispd_thresh = None, bin_win_days = None, bin_min_days = None, mean_period = None): self.define_toolbox_paths(D_zarr = D_zarr, ice_type = ice_type, BorC2T_type = BorC2T_type, ispd_thresh = ispd_thresh, bin_win_days = bin_win_days, bin_min_days = bin_min_days, mean_period = mean_period) method_key = self._class_method_key(class_method) if ice_type in {"SI", "MIZ"}: if method_key != "raw": raise ValueError(f"{ice_type} only supports class_method='raw'") return self.P_zarrs[ice_type][product] return self.P_zarrs[ice_type][product][method_key]
def define_toolbox_paths(self,
                         sim_name     = None,
                         D_sim        = None,
                         D_zarr       = None,
                         D_graph      = None,
                         ice_type     = None,
                         BorC2T_type  = None,
                         ispd_thresh  = None,
                         bin_win_days = None,
                         bin_min_days = None,
                         mean_period  = None):
    """
    Define canonical paths for ice-history, classified products, and metrics.

    Every argument left as None falls back to the instance attribute of the
    same name (``D_sim``, ``D_zarr`` and ``D_graph`` fall back to
    ``<self.D_dict['AFIM_out']>/<sim_name>``, ``<D_sim>/zarr`` and
    ``<self.config['D_dict']['graph']>/AFIM`` respectively).

    Layout produced (as built by this method)::

        <D_sim>/
          ice_diag.d
          ice_in_AFIM_subset_<sim_name>.json
          history/daily/   history/monthly/
        <D_zarr>/
          iceh_daily.zarr
          iceh_monthly.zarr
          <HEM>/
            SI/   raw.zarr  mets.zarr
            MIZ/  raw.zarr  mets.zarr
            ispd_thresh_<thr>/
              FI/<BorC2T>/
                raw/                    data.zarr  mets.zarr
                bin-win-XX_bin-min-YY/  data.zarr  mets.zarr
                roll-days-ZZ/           data.zarr  mets.zarr
              PI/<BorC2T>/
                (same three sub-directories as FI)

    ``<HEM>`` is ``self.hemisphere_dict["abbreviation"]``; ``<BorC2T>`` is
    `BorC2T_type` itself when it is a string, otherwise its entries joined
    without a separator; ``<thr>`` is the threshold formatted as ``.1e`` with
    a leading exponent zero removed.

    Sets
    ----
    D_sim, D_zarr, D_graph, P_ice_diag, P_sim_cfg, reG_type, D_hem, D_ispd,
    D_iceh, D_class, P_zarrs, and ``methods_init_executed['define_toolbox_paths']``.

    Returns
    -------
    bool
        Always True.

    Raises
    ------
    AssertionError
        If `ice_type` or `BorC2T_type` fails the instance's validity checks.
    """
    def _or_attr(value, attr_name):
        # Read the instance attribute only when no override was supplied.
        return getattr(self, attr_name) if value is None else value

    sim_name = _or_attr(sim_name, "sim_name")
    self.D_sim   = Path(D_sim   if D_sim   is not None else Path(self.D_dict['AFIM_out'], sim_name))
    self.D_zarr  = Path(D_zarr  if D_zarr  is not None else Path(self.D_sim, 'zarr'))
    self.D_graph = Path(D_graph if D_graph is not None else Path(self.config['D_dict']['graph'], 'AFIM'))

    ice_type     = _or_attr(ice_type, "ice_type")
    BorC2T_type  = _or_attr(BorC2T_type, "BorC2T_type")
    ispd_thresh  = float(_or_attr(ispd_thresh, "ispd_thresh"))
    bin_win_days = int(_or_attr(bin_win_days, "bin_win_days"))
    bin_min_days = int(_or_attr(bin_min_days, "bin_min_days"))
    mean_period  = int(_or_attr(mean_period, "mean_period"))
    self._check_ice_type(ice_type)
    self._check_BorC2T_type(BorC2T_type)

    self.P_ice_diag = self.D_sim / "ice_diag.d"
    self.P_sim_cfg  = self.D_sim / f"ice_in_AFIM_subset_{sim_name}.json"
    self.reG_type   = BorC2T_type if isinstance(BorC2T_type, str) else "".join(BorC2T_type)

    # e.g. 5.0e-04 -> "5.0e-4": strip a single leading zero from the exponent
    ispd_thr_str = f"{ispd_thresh:.1e}".replace("e-0", "e-").replace("e+0", "e+")
    self.D_hem  = self.D_zarr / self.hemisphere_dict["abbreviation"]
    self.D_ispd = self.D_hem / f"ispd_thresh_{ispd_thr_str}"
    self.D_iceh = {"zarr": {"daily"  : self.D_zarr / "iceh_daily.zarr",
                            "monthly": self.D_zarr / "iceh_monthly.zarr"},
                   "nc"  : {"daily"  : self.D_sim / "history" / "daily",
                            "monthly": self.D_sim / "history" / "monthly"}}

    # Sub-directory name for each classification method (shared by FI and PI)
    method_dirs = {"raw" : "raw",
                   "bin" : f"bin-win-{bin_win_days:02d}_bin-min-{bin_min_days:02d}",
                   "roll": f"roll-days-{mean_period:02d}"}
    D_class = {"SI": self.D_hem / "SI", "MIZ": self.D_hem / "MIZ"}
    for thresholded in ("FI", "PI"):
        base = self.D_ispd / thresholded / self.reG_type
        D_class[thresholded] = {key: base / sub for key, sub in method_dirs.items()}
    self.D_class = D_class

    P_zarrs = {}
    for plain in ("SI", "MIZ"):
        P_zarrs[plain] = {"data": D_class[plain] / "raw.zarr",
                          "mets": D_class[plain] / "mets.zarr"}
    for thresholded in ("FI", "PI"):
        P_zarrs[thresholded] = {product: {key: D_class[thresholded][key] / f"{product}.zarr"
                                          for key in method_dirs}
                                for product in ("data", "mets")}
    self.P_zarrs = P_zarrs

    self.methods_init_executed['define_toolbox_paths'] = True
    return True
def define_ice_mask_name(self, ice_type=None):
    """
    Set ``self.mask_name`` to ``"<ice_type>_mask"``.

    Parameters
    ----------
    ice_type : str, optional
        Ice type used to build the name. Any falsy value (None, empty
        string) falls back to ``self.ice_type``. The value is validated with
        ``self._check_ice_type`` before the name is stored.
    """
    if not ice_type:
        ice_type = self.ice_type
    self._check_ice_type(ice_type)
    self.mask_name = f"{ice_type}_mask"
def _sanitize_for_zarr_write(self, obj: xr.Dataset | xr.DataArray, drop_coords: list | None = None, cast_float32: bool = False) -> xr.Dataset | xr.DataArray: """ Remove non-essential auxiliary coordinates before writing classified products. This is primarily to prevent year-to-year inconsistencies in saved Zarr groups, e.g. one year carrying NLON/NLAT while another does not, which later breaks xarray.concat in load_classified_ice(). Parameters ---------- obj : xr.Dataset or xr.DataArray Classified product to be written. drop_coords : list[str], optional Extra coords to remove. If None, uses a conservative default plus any entries in self.CICE_dict["drop_coords"]. cast_float32 : bool, default False If True, cast float64 data variables to float32 before write. Returns ------- xr.Dataset or xr.DataArray Sanitised object ready for Zarr write. """ default_drop = ["NLON", "NLAT", "ULON", "ULAT", "TLON", "TLAT", "elon", "elat", "nlon", "nlat", "ulon", "ulat", "tlon", "tlat"] cfg_drop = [] if hasattr(self, "CICE_dict") and isinstance(self.CICE_dict, dict): cfg_drop = list(self.CICE_dict.get("drop_coords", [])) all_drop = drop_coords or list(dict.fromkeys(default_drop + cfg_drop)) if isinstance(obj, xr.DataArray): da = obj present = [c for c in all_drop if c in da.coords] if present: self.logger.debug(f"write-sanitize DataArray: dropping coords {present}") da = da.drop_vars(present, errors="ignore") if cast_float32 and np.issubdtype(da.dtype, np.floating) and da.dtype != np.float32: da = da.astype(np.float32) return da ds = obj present = [c for c in all_drop if c in ds.coords] if present: self.logger.debug(f"write-sanitize Dataset: dropping coords {present}") ds = ds.drop_vars(present, errors="ignore") # Optional float downcast for storage consistency ds = ds.drop_vars("time_bounds", errors="ignore") if "time" in ds.coords: ds["time"].attrs.pop("bounds", None) if cast_float32: ds = ds.copy() for v in ds.data_vars: if np.issubdtype(ds[v].dtype, np.floating) and ds[v].dtype 
!= np.float32: ds[v] = ds[v].astype(np.float32) # if cast_float32: # cast_map = {} # for v in ds.data_vars: # if np.issubdtype(ds[v].dtype, np.floating) and ds[v].dtype != np.float32: # cast_map[v] = np.float32 # if cast_map: # ds = ds.astype(cast_map) return ds def _clean_var_encoding(self, da: xr.DataArray, keep_chunks : bool = None, keep_compressor : bool = None) -> xr.DataArray: enc0 = dict(getattr(da, "encoding", {}) or {}) enc1 = {} if keep_chunks and "chunks" in enc0: enc1["chunks"] = enc0["chunks"] if keep_compressor and "compressor" in enc0: enc1["compressor"] = enc0["compressor"] # keep fill value if present if "_FillValue" in enc0: enc1["_FillValue"] = enc0["_FillValue"] da.encoding = enc1 return da def _strip_unsafe_zarr_encoding(self, obj: xr.Dataset | xr.DataArray, keep_chunks : bool = True, keep_compressor: bool = True) -> xr.Dataset | xr.DataArray: """ Remove problematic / inherited variable encodings before Zarr write. This is especially useful when objects inherit encodings from upstream reads and are later written with xarray.to_zarr(). We keep only a small safe subset by default. """ if isinstance(obj, xr.DataArray): return self._clean_var_encoding(obj, keep_chunks=keep_chunks, keep_compressor=keep_compressor) ds = obj.copy() for v in ds.variables: ds[v].encoding = {} # clear first for v in ds.data_vars: ds[v] = self._clean_var_encoding(ds[v], keep_chunks=keep_chunks, keep_compressor=keep_compressor) for c in ds.coords: ds[c].encoding = {} return ds def _drop_duplicate_coords(self, ds: xr.Dataset, dim: str = "ni") -> xr.Dataset: """ Drop duplicate coordinate values along `dim` (keeping the first occurrence). Useful when concatenating yearly groups that may contain duplicated x-indices. 
""" if dim in ds.coords: _, index = np.unique(ds[dim], return_index=True) ds = ds.isel({dim: sorted(index)}) return ds def _normalise_concat_coords(self, ds: xr.Dataset, dim: str = "ni", drop_coords: list | None = None) -> xr.Dataset: """ Normalise non-essential spatial coords before concatenation of yearly groups. This protects against cases where some yearly Zarr groups carry auxiliary coordinates (for example NLON/NLAT) and others do not, which causes xarray.concat(..., coords="minimal") to raise: ValueError: coordinate 'NLON' not present in all datasets Parameters ---------- ds : xr.Dataset Dataset for a single year/group. dim : str, default "ni" Optional dimension along which duplicate coordinate values should be removed, keeping first occurrence. drop_coords : list[str], optional Explicit list of coord names to drop before concat. If None, uses a conservative default set of known non-essential geolocation coords. Returns ------- xr.Dataset Dataset with problematic auxiliary coords removed and duplicate concat coordinates cleaned. 
""" # 1) drop duplicate concat coord values, if present if dim in ds.coords: _, index = np.unique(ds[dim], return_index=True) ds = ds.isel({dim: sorted(index)}) # 2) drop auxiliary geolocation coords that are not needed for concat default_drop = ["NLON", "NLAT", "ULON", "ULAT", "TLON", "TLAT", "elon", "elat", "nlon", "nlat", "ulat", "ulon", "tlat", "tlon"] cfg_drop = [] if hasattr(self, "CICE_dict") and isinstance(self.CICE_dict, dict): cfg_drop = list(self.CICE_dict.get("drop_coords", [])) drop_now = drop_coords or list(dict.fromkeys(default_drop + cfg_drop)) present = [c for c in drop_now if c in ds.coords] if present: self.logger.debug(f"Dropping auxiliary coords before concat: {present}") ds = ds.drop_vars(present, errors="ignore") return ds def _write_grouped_zarr(self, ds, store, group, overwrite_group=False, consolidated=False): store = Path(store) P_grp = store / group if P_grp.exists(): if overwrite_group: self.logger.info(f"group {P_grp} already exists but overwrite_group = True ... OVER-WRITTING GROUP") shutil.rmtree(P_grp) else: self.logger.info(f"Group already exists, skipping: {P_grp}") return False ds = self._sanitize_for_zarr_write(ds, cast_float32=True) ds = self._strip_unsafe_zarr_encoding(ds) mode = "a" if store.exists() else "w" self.logger.info(f"writing to group {P_grp}") self.logger.info(f"writing to group with mode '{mode}' and consolidated {consolidated}") ds.to_zarr(store, group=group, mode=mode, consolidated=consolidated, zarr_format=2) return True ########################################################################################################## #################################### NORMALISATIONS ######################################### ########################################################################################################## def _as_da_mask(self, x): """ Normalize a fast-ice mask input to an xarray.DataArray. 
Accepts either: • a DataArray that already is the binary fast-ice mask, or • a Dataset that contains a variable named 'FI_mask'. Returns ------- xarray.DataArray The 'FI_mask' mask with its original coordinates/dtype preserved. Raises ------ ValueError If a Dataset is provided but it does not contain 'FI_mask'. TypeError If the input is neither a DataArray nor a Dataset with 'FI_mask'. Notes ----- This function does not cast dtype; callers may wish to `.astype('i1')` (or similar) if they want a compact integer mask. """ if isinstance(x, xr.Dataset): if "FI_mask" in x: x = x["FI_mask"] else: raise ValueError("I_mask is a Dataset but lacks variable 'FI_mask'.") if not isinstance(x, xr.DataArray): raise TypeError("I_mask must be an xarray.DataArray or Dataset containing 'FI_mask'.") return x def _as_da_area(self, x): """ Normalize a grid-cell area input to an xarray.DataArray. Accepts either: • a DataArray that already is the cell-area field, or • a Dataset containing one of {'tarea','area','TAREA'} (first match wins). Returns ------- xarray.DataArray The area field with its original coordinates/dtype preserved. Raises ------ ValueError If a Dataset is provided but no recognized area variable is present. TypeError If the input is neither a DataArray nor a Dataset with an area var. Notes ----- This function does not alter dimensions. If the area has a 'time' dim, strip it upstream or downstream (e.g., `A.isel(time=0)`). """ if isinstance(x, xr.Dataset): for k in ("tarea", "area", "TAREA"): if k in x: x = x[k] break else: raise ValueError("A is a Dataset but no area variable found (looked for 'tarea','area','TAREA').") if not isinstance(x, xr.DataArray): raise TypeError("A must be an xarray.DataArray or Dataset containing an area variable.") return x def _norm_list(self, x: Optional[Iterable[str]]) -> Optional[List[str]]: """ Normalize an optional iterable of strings into a clean list. 
Parameters ---------- x : iterable of str or None Input strings (e.g., sensors, levels, versions). Elements that are None, empty, or whitespace-only are removed. Returns ------- list[str] or None Cleaned list of non-empty strings, or None if the result is empty. """ if x is None: return None out = [str(s).strip() for s in x if s and str(s).strip()] return out or None ########################################################################################################## #################################### DATE/TIME MANIPULATIONS ######################################### ##########################################################################################################
[docs] def define_datetime_vars(self, dt0_str=None, dtN_str=None): """ Define date range attributes from start and end date strings. Parameters ---------- dt0_str : str, optional Start date string (e.g., '1994-01-01'). Defaults to `self.dt0_str`. dtN_str : str, optional End date string (e.g., '1999-12-31'). Defaults to `self.dtN_str`. Sets ---- self.dt0 : pd.Timestamp Parsed start date. self.dtN : pd.Timestamp Parsed end date. self.yrs_mos : np.ndarray Array of 'YYYY-MM' strings for each month in the range. self.ymd_strs : np.ndarray Array of 'YYYY-MM-DD' strings for each day in the range. """ from pandas.tseries.offsets import MonthEnd dt0_str = dt0_str or self.dt0_str dtN_str = dtN_str or self.dtN_str self.dt0 = pd.to_datetime(dt0_str) self.dtN = pd.to_datetime(dtN_str) self.dt_range = pd.date_range(self.dt0, self.dtN, freq="D") self.ymd_strs = self.dt_range.strftime("%Y-%m-%d") self.mos0 = pd.date_range(self.dt0, self.dtN, freq="MS") self.mosN = pd.date_range(self.dt0, self.dtN, freq="ME") self.yrs_mos0 = self.mos0.strftime("%Y-%m") self.mo0_strs = self.mos0.strftime("%Y-%m-%d").tolist() self.moN_strs = self.mosN.strftime("%Y-%m-%d").tolist()#(self.mos + MonthEnd(1)).strftime("%Y-%m-%d").tolist() self.yrs0 = pd.date_range(self.dt0, self.dtN, freq="YS") self.yrsN = pd.date_range(self.dt0, self.dtN, freq="YE") self.yr0_strs = self.yrs0.strftime("%Y-%m-%d") self.yrN_strs = self.yrsN.strftime("%Y-%m-%d")
def define_month_first_last_dates(self, year_month_str):
    """
    Return the first and last dates of the month named by a 'YYYY-MM' string.

    Parameters
    ----------
    year_month_str : str
        Year and month in 'YYYY-MM' format.

    Returns
    -------
    tuple of str
        ``(first_day, last_day)`` of that month, each as 'YYYY-MM-DD'.
    """
    first_day = f"{year_month_str}-01"
    # From the 1st, one MonthEnd offset lands on the last day of the same month.
    month_end = pd.to_datetime(first_day) + pd.offsets.MonthEnd()
    last_day = month_end.strftime("%Y-%m-%d")
    return first_day, last_day
def _days_in_month(self, y: int, m: int) -> int: """ Return the number of days in a given month of a given year (Gregorian calendar). Parameters ---------- y : int Year (e.g., 2002). m : int Month number in [1..12]. Returns ------- int Number of days in the month, accounting for leap years. """ if m in (1,3,5,7,8,10,12): return 31 if m in (4,6,9,11): return 30 return 29 if ((y % 4 == 0 and y % 100 != 0) or (y % 400 == 0)) else 28 def _month_overlap(self, y: int, m: int, t0: pd.Timestamp, t1: pd.Timestamp) -> bool: """ Check whether a given (year, month) overlaps an inclusive UTC time window. Parameters ---------- y, m : int Year and month to test. t0, t1 : pandas.Timestamp Start and end timestamps of the desired window (inclusive), interpreted in UTC. Returns ------- bool True if the month interval intersects [t0, t1], otherwise False. """ first = pd.Timestamp(year=y, month=m, day=1, tz="UTC") last = pd.Timestamp(year=y, month=m, day=self._days_in_month(y,m), tz="UTC") return not (last < t0 or first > t1) def _parse_yyyymmdd(self, name: str) -> Optional[pd.Timestamp]: """ Parse a YYYYMMDD date token from a filename-like string. This looks for an 8-digit date token in the range 2000–2099 and returns it as a UTC pandas.Timestamp. Parameters ---------- name : str Filename (or other string) to search. Returns ------- pandas.Timestamp or None Parsed UTC timestamp for the detected date, or None if no valid token exists. Notes ----- - The regex is constrained to 20xx years by design. - Invalid day/month combinations return None. """ m = re.search(r"(?<!\d)(20\d{2})(0[1-9]|1[0-2])(0[1-9]|[12]\d|3[01])(?!\d)", name) if not m: return None y, mo, d = int(m.group(1)), int(m.group(2)), int(m.group(3)) try: return pd.Timestamp(year=y, month=mo, day=d, tz="UTC") except Exception: return None def _parse_yyyymm(self, name: str) -> Optional[Tuple[int,int]]: """ Parse a YYYYMM month token from a filename-like string. 
Parameters ---------- name : str Filename (or other string) to search. Returns ------- tuple[int, int] or None (year, month) if a YYYYMM token is found (2000–2099), else None. """ m = re.search(r"(?<!\d)(20\d{2})(0[1-9]|1[0-2])(?!\d)", name) if not m: return None return int(m.group(1)), int(m.group(2))
def create_monthly_strings(self, dt0_str=None, dtN_str=None):
    """
    Return the sorted, unique 'YYYY-MM' strings covered by a date range.

    Parameters
    ----------
    dt0_str, dtN_str : str, optional
        Start/end dates in ``YYYY-MM-DD``. Each defaults to `self.dt0_str` /
        `self.dtN_str` when passed as None.

    Returns
    -------
    list[str]
        Sorted list of unique month strings, e.g., ["1993-01", "1993-02", ...].
        Empty when the end date precedes the start date.

    Notes
    -----
    A daily pandas date range is built between the two dates (inclusive),
    formatted as 'YYYY-MM', and then de-duplicated.
    """
    start = self.dt0_str if dt0_str is None else dt0_str
    end = self.dtN_str if dtN_str is None else dtN_str
    month_labels = pd.date_range(start, end, freq="D").strftime("%Y-%m")
    return sorted(set(month_labels))
########################################################################################################## ################################### BASIC STATISTICS/CLIMATOLOGY ######################################### ##########################################################################################################
def compute_rolling_mean_on_dataset(self, ds, mean_period=None):
    """
    Apply a centered rolling mean along the `time` dimension of a dataset.

    Parameters
    ----------
    ds : xarray.Dataset
        Input dataset with a `time` dimension.
    mean_period : int, optional
        Rolling window length, in time steps along `time`. Defaults to
        `self.mean_period` when None.

    Returns
    -------
    xarray.Dataset
        Result of `ds.rolling(...).mean()` over the centered window.

    Notes
    -----
    `min_periods=1` is used, so windows truncated at the start/end of the
    series are averaged over whatever samples are available rather than
    being dropped.
    """
    window = self.mean_period if mean_period is None else mean_period
    rolling_view = ds.rolling(time=window, center=True, min_periods=1)
    return rolling_view.mean()
def compute_doy_climatology(self, da, leap_year=None, time_coord=None):
    """
    Compute day-of-year (DOY) climatology statistics from a time series.

    Values are grouped by calendar day-of-year and reduced to their minimum,
    maximum, mean and standard deviation. The DOY index of each result is then
    replaced by dates in a reference year so the output can be plotted on a
    calendar axis.

    Parameters
    ----------
    da : xarray.DataArray
        Input time series carrying the coordinate named by `time_coord`.
        # NOTE(review): the DataFrame construction below appears to require
        # one-dimensional data — confirm against callers.
    leap_year : int, optional
        Reference year used to rebuild the datetime index of the output.
        Defaults to `self.leap_year` when None.
    time_coord : str, optional
        Name of the time coordinate in `da`. Defaults to
        `self.CICE_dict['time_dim']` when None.

    Returns
    -------
    dict
        Keys 'min', 'max', 'std' and 'mean', each mapping to a `pd.Series` of
        the corresponding per-DOY statistic. Every Series shares the same
        datetime index, where DOY ``d`` maps to 1 January of `leap_year`
        plus ``d - 1`` days.

    Notes
    -----
    - Only day-of-year values present in the data appear in the output.
    - The DataArray is loaded into memory (`da.load()`) before processing.
    """
    if leap_year is None:
        leap_year = self.leap_year
    if time_coord is None:
        time_coord = self.CICE_dict['time_dim']

    da = da.load()
    frame = pd.DataFrame({"time": pd.to_datetime(da[time_coord].values),
                          "data": da.values})
    frame["doy"] = frame["time"].dt.dayofyear
    by_doy = frame.groupby("doy")["data"]

    # Key order matches the returned dictionary layout: min, max, std, mean.
    stats = {name: getattr(by_doy, name)() for name in ("min", "max", "std", "mean")}

    # Convert the 1-based DOY index into dates within the reference year.
    ref_dates = pd.to_datetime(stats["mean"].index - 1,
                               unit="D",
                               origin=pd.Timestamp(f"{leap_year}-01-01"))
    for series in stats.values():
        series.index = ref_dates
    return stats