IPSL_AID package

Submodules

IPSL_AID.dataset module

IPSL_AID.dataset.stats(ds, logger, input_dir, norm_mapping={})[source]

Load normalization statistics and compute coordinate metadata for a NetCDF dataset.

This function loads normalization statistics from a JSON file if available. If no statistics file is found, it falls back to predefined constants for fine and coarse variables only.

Parameters:
  • ds (xarray.Dataset) – NetCDF dataset to process.

  • logger (logging.Logger) – Logger instance for logging messages and statistics.

  • input_dir (str) – Directory containing a statistics.json file with precomputed normalization statistics.

  • norm_mapping (dict, optional) – Dictionary in which to store the statistics. If empty, it is populated by this function. Default is an empty dict.

Returns:

  • norm_mapping (dict) – Dictionary mapping variable names to their computed statistics. For coordinates: min, max, mean, std. For data variables: min, max, mean, std, q1, q3, iqr, median.

  • steps (EasyDict) – Dictionary containing coordinate step sizes and lengths.

Notes

  • If statistics.json is found, all statistics are loaded as-is.

  • If not found, fallback constants are used for fine and coarse variables only.
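The load-or-fall-back logic described above can be sketched in a few lines. This is a minimal sketch, not the package's implementation: the helper name load_norm_stats and the FALLBACK_STATS keys and values are hypothetical placeholders, and the coordinate metadata (steps) computation is omitted.

```python
import json
import os

# Hypothetical fallback constants; the package defines its own names and values.
FALLBACK_STATS = {
    "fine": {"min": 0.0, "max": 1.0, "mean": 0.5, "std": 0.25},
    "coarse": {"min": 0.0, "max": 1.0, "mean": 0.5, "std": 0.25},
}


def load_norm_stats(input_dir, norm_mapping=None):
    """Load input_dir/statistics.json if present, else use fallback constants."""
    # Use None rather than a mutable {} default to avoid state shared across calls.
    norm_mapping = {} if norm_mapping is None else norm_mapping
    path = os.path.join(input_dir, "statistics.json")
    if os.path.isfile(path):
        with open(path) as f:
            # All statistics found in the file are loaded as-is.
            norm_mapping.update(json.load(f))
    else:
        # Copy the inner dicts so callers cannot mutate the module-level constants.
        norm_mapping.update({k: dict(v) for k, v in FALLBACK_STATS.items()})
    return norm_mapping
```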

IPSL_AID.dataset.coarse_down_up(fine_filtered, fine_batch, input_shape=(16, 32), axis=0)[source]

Downscale and then upscale fine-resolution data to compute a coarse approximation.

This function performs a downscaling-upscaling round trip to create a coarse-resolution approximation of fine data on the original fine grid. This is commonly used in multi-scale analysis, image processing, and super-resolution tasks.

Parameters:
  • fine_filtered (torch.Tensor or np.ndarray) – Fine-resolution filtered data of shape (C, Hf, Wf) for multi-channel data or (Hf, Wf) for single-channel data, where C is the number of channels, Hf is the fine height, and Wf is the fine width.

  • fine_batch (torch.Tensor or np.ndarray) – Fine-resolution target data. Must have the same spatial dimensions as fine_filtered. Shape: (C, Hf, Wf) or (Hf, Wf).

  • input_shape (tuple of int, optional) – Target shape (Hc, Wc) for the coarse-resolution data after downscaling. Default is (16, 32).

  • axis (int, optional) – Axis along which to insert batch dimension if the input lacks one. Default is 0.

Returns:

coarse_up – Upscaled coarse approximation of the fine data, with the same shape as the input fine_filtered (without a batch dimension).

Return type:

torch.Tensor

Notes

  • The function ensures that the input tensors have a batch dimension by adding one if missing.

  • Uses bilinear interpolation for both downscaling and upscaling operations.

  • The antialias parameter is set to True for better quality resampling.

  • Useful for creating multi-scale representations in image processing and computer vision tasks.
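A dependency-light sketch of the same down-then-up round trip, assuming NumPy input whose fine size is an integer multiple of input_shape. It deliberately swaps the package's antialiased bilinear downscaling (in PyTorch, torch.nn.functional.interpolate with mode='bilinear' and antialias=True) for block averaging, and upscales with separable linear interpolation, so values will not match the package exactly. All function names here are hypothetical.

```python
import numpy as np


def block_mean_downscale(x, factor_h, factor_w):
    """Area-average (C, Hf, Wf) -> (C, Hf // factor_h, Wf // factor_w)."""
    c, h, w = x.shape
    return x.reshape(c, h // factor_h, factor_h, w // factor_w, factor_w).mean(axis=(2, 4))


def bilinear_upscale(x, out_h, out_w):
    """Separable linear interpolation (C, Hc, Wc) -> (C, out_h, out_w)."""
    c, h, w = x.shape
    # Target sample positions in source pixel units (pixel-center convention).
    ys = (np.arange(out_h) + 0.5) * h / out_h - 0.5
    xs = (np.arange(out_w) + 0.5) * w / out_w - 0.5
    # Interpolate along width for every row (np.interp clamps at the edges) ...
    tmp = np.stack([[np.interp(xs, np.arange(w), row) for row in ch] for ch in x])
    # ... then along height for every column of the intermediate result.
    out = np.stack([[np.interp(ys, np.arange(h), col) for col in ch.T] for ch in tmp])
    return out.transpose(0, 2, 1)


def coarse_down_up_sketch(fine, input_shape=(16, 32)):
    squeeze = fine.ndim == 2
    if squeeze:
        fine = fine[None]  # treat (Hf, Wf) as a single channel
    _, hf, wf = fine.shape
    hc, wc = input_shape
    coarse = block_mean_downscale(fine, hf // hc, wf // wc)
    up = bilinear_upscale(coarse, hf, wf)
    return up[0] if squeeze else up
```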

IPSL_AID.dataset.gaussian_filter(image, dW, dH, cutoff_W_phys, cutoff_H_phys, epsilon=0.01, margin=8)[source]

Apply a Gaussian low-pass filter with controlled attenuation at the cutoff frequency.

This function performs a 2D Fourier transform of the input field, applies a Gaussian weighting in the frequency domain, and inversely transforms it back to the spatial domain. Unlike the standard Gaussian filter, this version defines the Gaussian width such that the response amplitude reaches a specified attenuation factor (epsilon) at the cutoff frequency. Padding with reflection is used to minimize edge artifacts.

Parameters:
  • image (np.array of shape (H, W)) – Input 2D field to be filtered (e.g., temperature or a wind component).

  • dW (float) – Grid spacing in degrees of longitude.

  • dH (float) – Grid spacing in degrees of latitude.

  • cutoff_W_phys (float) – Longitudinal cutoff frequency in cycles per degree. The Gaussian response equals epsilon at this frequency, and higher frequencies are attenuated more strongly.

  • cutoff_H_phys (float) – Latitudinal cutoff frequency in cycles per degree. The Gaussian response equals epsilon at this frequency, and higher frequencies are attenuated more strongly.

  • epsilon (float, optional) – Desired amplitude response at the cutoff frequency (default: 0.01). Lower values produce sharper attenuation and stronger filtering.

  • margin (int, optional) – Number of pixels to pad on each side using reflection (default: 8). This reduces edge effects in the Fourier transform.

Returns:

filtered – Real-valued filtered field after inverse Fourier transform and margin cropping.

Return type:

ndarray of shape (H, W)

Notes

  • The Gaussian width parameters are computed such that: exp(-0.5 * (f_cutoff / sigma)^2) = epsilon, leading to sigma = f_cutoff / sqrt(-2 * log(epsilon)).

  • Padding the input with reflective boundaries minimizes spectral leakage and discontinuities at image edges.

  • The output field is cropped back to its original size after filtering.

  • This formulation provides more explicit control over filter sharpness than the standard Gaussian low-pass implementation.
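The procedure above (reflect padding, FFT, epsilon-calibrated Gaussian weighting, inverse FFT, cropping) can be reconstructed as follows. This is a sketch inferred from the description; the name gaussian_filter_sketch is hypothetical and the package may build its frequency grid differently.

```python
import numpy as np


def gaussian_filter_sketch(image, dW, dH, cutoff_W_phys, cutoff_H_phys,
                           epsilon=0.01, margin=8):
    # Reflect-pad to reduce edge discontinuities before the FFT.
    padded = np.pad(image, margin, mode="reflect")
    H, W = padded.shape
    # Frequency axes in cycles per degree (grid spacing given in degrees).
    fH = np.fft.fftfreq(H, d=dH)
    fW = np.fft.fftfreq(W, d=dW)
    # sigma = f_cutoff / sqrt(-2 log(epsilon)) so the response is epsilon at the cutoff.
    k = np.sqrt(-2.0 * np.log(epsilon))
    sigma_H, sigma_W = cutoff_H_phys / k, cutoff_W_phys / k
    FH, FW = np.meshgrid(fH, fW, indexing="ij")
    response = np.exp(-0.5 * ((FH / sigma_H) ** 2 + (FW / sigma_W) ** 2))
    filtered = np.real(np.fft.ifft2(np.fft.fft2(padded) * response))
    # Crop the padding to recover the original (H, W) shape.
    return filtered[margin:-margin, margin:-margin] if margin > 0 else filtered
```

Because the response is exactly 1 at zero frequency, a constant field passes through unchanged, which makes a quick sanity check.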

class IPSL_AID.dataset.DataPreprocessor(*args: Any, **kwargs: Any)[source]

Bases: Dataset

Dataset class for preprocessing weather and climate data for machine learning.

This class handles loading, preprocessing, and sampling of multi-year NetCDF weather data with support for multi-scale processing, normalization, and spatial-temporal sampling strategies.

Parameters:
  • years (list of int) – Years of data to include.

  • loaded_dfs (xarray.Dataset) – Pre-loaded dataset containing the weather variables.

  • constants_file_path (str) – Path to NetCDF file containing constant variables (e.g., topography).

  • varnames_list (list of str) – List of variable names to extract from the dataset.

  • units_list (list of str) – Units for each variable in varnames_list.

  • in_shape (tuple of int, optional) – Target shape (height, width) for coarse resolution. Default is (16, 32).

  • batch_size_lat (int, optional) – Height of spatial batch in grid points. Default is 144.

  • batch_size_lon (int, optional) – Width of spatial batch in grid points. Default is 144.

  • steps (EasyDict, optional) – Dictionary containing grid dimension information. Should include:
      - latitude/lat: number of latitude points
      - longitude/lon: number of longitude points
      - time: number of time steps
      - d_latitude/d_lat: latitude spacing
      - d_longitude/d_lon: longitude spacing

  • tbatch (int, optional) – Number of time batches to sample. Default is 1.

  • sbatch (int, optional) – Number of spatial batches to sample. Default is 8.

  • debug (bool, optional) – Enable debug logging. Default is True.

  • mode (str, optional) – Operation mode: “train” or “validation”. Default is “train”.

  • run_type (str, optional) – Run type: “train”, “validation”, or “inference”. Default is “train”.

  • dynamic_covariates (list of str, optional) – List of dynamic covariate variable names. Default is None.

  • dynamic_covariates_dir (str, optional) – Directory containing dynamic covariate files. Default is None.

  • time_normalization (str, optional) – Method for time normalization: “linear” or “cos_sin”. Default is “linear”.

  • norm_mapping (dict, optional) – Dictionary containing normalization statistics for variables.

  • index_mapping (dict, optional) – Dictionary mapping variable names to indices in the data array.

  • normalization_type (dict, optional) – Dictionary specifying normalization type per variable.

  • constant_variables (list of str, optional) – List of constant variable names to load. Default is None.

  • epsilon (float, optional) – Amplitude response of the Gaussian filter at the cutoff frequency (see gaussian_filter). Default is 0.02.

  • margin (int, optional) – Number of pixels of reflection padding used by the Gaussian filter. Default is 8.

  • dtype (tuple, optional) – Data types for torch and numpy (torch_dtype, np_dtype). Default is (torch.float32, np.float32).

  • apply_filter (bool, optional) – Whether to apply Gaussian filtering for multi-scale processing. Default is False.

  • logger (logging.Logger, optional) – Logger instance for logging messages. Default is None.

const_vars

Array of constant variables with shape (n_constants, H, W).

Type:

np.ndarray or None

time

Time coordinate from dataset.

Type:

xarray.DataArray

year

Year component of time.

Type:

xarray.DataArray

month

Month component of time.

Type:

xarray.DataArray

day

Day component of time.

Type:

xarray.DataArray

hour

Hour component of time.

Type:

xarray.DataArray

year_norm

Normalized year values.

Type:

torch.Tensor

doy_norm

Normalized day-of-year values (linear mode).

Type:

torch.Tensor or None

hour_norm

Normalized hour values (linear mode).

Type:

torch.Tensor or None

doy_sin, doy_cos

Sine and cosine of day-of-year (cos_sin mode).

Type:

torch.Tensor or None

hour_sin, hour_cos

Sine and cosine of hour (cos_sin mode).

Type:

torch.Tensor or None

time_batchs

Array of time indices for current epoch.

Type:

np.ndarray

eval_slices

List of spatial slices for evaluation mode.

Type:

list of tuple or None

random_centers

List of random spatial centers for training mode.

Type:

list of tuple or None

center_tracker

Tracks spatial centers for debugging.

Type:

list

tindex_tracker

Tracks temporal indices for debugging.

Type:

list

new_epoch()[source]

Reset time batches and random centers for new training epoch.

sample_time_steps_by_doy()[source]

Sample time steps based on day-of-year (DOY) for multi-year continuity.

sample_random_time_indices()[source]

Randomly sample time indices for training.

load_dynamic_covariates()[source]

Load dynamic covariate data (not fully implemented).

generate_random_batch_centers(n_batches)[source]

Generate random spatial centers for batch sampling.

generate_evaluation_slices()[source]

Generate deterministic spatial slices for evaluation.

extract_batch(data, ilat, ilon)[source]

Extract spatial batch centered at (ilat, ilon) with cyclic longitude.

filter_batch(fine_patch, fine_block)[source]

Apply Gaussian low-pass filtering for multi-scale processing.

normalize(data, stats, norm_type, var_name=None, data_type=None)[source]

Normalize data using specified statistics and method.

normalize_time(tindex)[source]

Return normalized time features for given time index.

__len__()[source]

Return total number of samples.

__getitem__(index)[source]

Get a single sample with appropriate spatial-temporal sampling.

Notes

  • Supports both random (training) and deterministic (validation) sampling.

  • Handles cyclic longitude wrapping for global datasets.

  • Provides multi-scale processing through downscaling/upscaling.

  • Includes time normalization with linear or trigonometric encoding.

  • Can incorporate constant variables (e.g., topography, land-sea mask).

__init__(years, loaded_dfs, constants_file_path, varnames_list, units_list, in_shape=(16, 32), batch_size_lat=144, batch_size_lon=144, steps={}, tbatch=1, sbatch=8, debug=True, mode='train', run_type='train', dynamic_covariates=None, dynamic_covariates_dir=None, time_normalization='linear', norm_mapping=None, index_mapping=None, normalization_type=None, constant_variables=None, epsilon=0.02, margin=8, dtype=(torch.float32, numpy.float32), apply_filter=False, region_center=None, region_size=None, logger=None)[source]

Initialize the DataPreprocessor.

Parameters:
  • years (list of int) – Years of data to include.

  • loaded_dfs (xarray.Dataset) – Pre-loaded dataset containing the weather variables.

  • constants_file_path (str) – Path to NetCDF file containing constant variables.

  • varnames_list (list of str) – List of variable names to extract.

  • units_list (list of str) – Units for each variable.

  • in_shape (tuple of int, optional) – Target shape for coarse resolution.

  • batch_size_lat (int, optional) – Height of spatial batch.

  • batch_size_lon (int, optional) – Width of spatial batch.

  • steps (EasyDict, optional) – Grid dimension information.

  • tbatch (int, optional) – Number of time batches.

  • sbatch (int, optional) – Number of spatial batches.

  • debug (bool, optional) – Enable debug logging.

  • mode (str, optional) – Operation mode.

  • run_type (str, optional) – Run type.

  • dynamic_covariates (list of str, optional) – Dynamic covariate variable names.

  • dynamic_covariates_dir (str, optional) – Directory for dynamic covariates.

  • time_normalization (str, optional) – Time normalization method.

  • norm_mapping (dict, optional) – Normalization statistics.

  • index_mapping (dict, optional) – Variable to index mapping.

  • normalization_type (dict, optional) – Normalization type per variable.

  • constant_variables (list of str, optional) – Constant variable names.

  • epsilon (float, optional) – Gaussian filter response at the cutoff frequency.

  • margin (int, optional) – Filter reflection-padding margin in pixels.

  • dtype (tuple, optional) – Data types for torch and numpy.

  • apply_filter (bool, optional) – Apply Gaussian filtering.

  • region_center (tuple of float or None) – Fixed geographic center (lat, lon) for spatial sampling.

  • logger (logging.Logger, optional) – Logger instance.

new_epoch()[source]

Prepare for a new training epoch by generating new time batches.

This method is called at the start of each training epoch to refresh the temporal and spatial sampling.

sample_time_steps_by_doy()[source]

Sample time steps based on day-of-year (DOY) for multi-year continuity.

This method selects unique DOYs from the available multi-year data and picks one random time index for each DOY.

Raises:

ValueError – If requested tbatch exceeds number of unique DOYs.
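A sketch of DOY-based sampling, assuming the day of year of every time step is available as an integer array (for an xarray time coordinate, ds.time.dt.dayofyear provides it). The standalone signature and the explicit rng argument are illustrative, not the method's actual interface.

```python
import numpy as np


def sample_time_steps_by_doy(doy, tbatch, rng=None):
    """Pick one random time index for each of `tbatch` distinct days of year."""
    rng = np.random.default_rng() if rng is None else rng
    doy = np.asarray(doy)
    unique_doys = np.unique(doy)
    if tbatch > unique_doys.size:
        raise ValueError(
            f"tbatch={tbatch} exceeds the {unique_doys.size} unique DOYs available"
        )
    chosen = rng.choice(unique_doys, size=tbatch, replace=False)
    # For each chosen DOY, draw one time index among all years sharing that DOY.
    return np.array([rng.choice(np.flatnonzero(doy == d)) for d in chosen])
```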

sample_random_time_indices()[source]

Generate random time indices for training.

This method samples random time indices uniformly across the available time range.

load_dynamic_covariates()[source]

Load dynamic covariate data.

get_center_indices_from_latlon(lat_value, lon_value)[source]

Convert geographic coordinates (latitude, longitude) to nearest grid indices.

Parameters:
  • lat_value (float) – Latitude in degrees.

  • lon_value (float) – Longitude in degrees.

Returns:

  • lat_idx (int) – Index of the closest latitude grid point.

  • lon_idx (int) – Index of the closest longitude grid point.

Notes

  • The dataset is defined on a discrete latitude–longitude grid.

  • Since spatial extraction operates on grid indices, the requested physical coordinates are mapped to the nearest available grid point.

  • This ensures consistency between user-defined locations and internal batch extraction logic.
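Nearest-grid-point lookup reduces to an argmin over coordinate distances. In this sketch, measuring longitude distance on the circle (so that -0.1° matches 359.9°) is an added assumption; the notes above do not say whether the method does this. The helper name is hypothetical.

```python
import numpy as np


def nearest_grid_indices(lats, lons, lat_value, lon_value):
    lat_idx = int(np.argmin(np.abs(np.asarray(lats) - lat_value)))
    # Signed longitude difference wrapped into [-180, 180) degrees.
    dlon = (np.asarray(lons) - lon_value + 180.0) % 360.0 - 180.0
    lon_idx = int(np.argmin(np.abs(dlon)))
    return lat_idx, lon_idx
```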

generate_random_batch_centers(n_batches)[source]

Generate random (latitude, longitude) centers for batch sampling.

Parameters:

n_batches (int) – Number of random centers to generate.

Returns:

centers – List of (lat_center, lon_center) tuples.

Return type:

list of tuple

Notes

  • Latitude centers are kept away from the poles so that a full batch can always be extracted.

  • Longitude centers can be any value due to cyclic wrapping.
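One way to keep every latitude window inside the grid while letting longitude wrap freely is sketched below. The standalone signature (grid sizes passed explicitly) and the half-window convention are assumptions.

```python
import numpy as np


def generate_random_batch_centers(n_batches, n_lat, n_lon, batch_size_lat, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    half = batch_size_lat // 2
    # The window [c - half, c - half + batch_size_lat) must lie inside [0, n_lat);
    # rng.integers excludes its upper bound, hence the + 1.
    lat_centers = rng.integers(half, n_lat - batch_size_lat + half + 1, size=n_batches)
    # Longitude is cyclic, so any column can be a center.
    lon_centers = rng.integers(0, n_lon, size=n_batches)
    return [(int(a), int(b)) for a, b in zip(lat_centers, lon_centers)]
```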

generate_evaluation_slices()[source]

Generate deterministic spatial slices for evaluation mode.

Returns:

slices – List of (lat_start, lat_end, lon_start, lon_end) tuples defining non-overlapping spatial blocks covering the entire domain.

Return type:

list of tuple
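A sketch of deterministic, non-overlapping tiling, assuming the grid dimensions are exact multiples of the batch size (otherwise trailing rows or columns would be left uncovered). The standalone signature is hypothetical.

```python
def generate_evaluation_slices(n_lat, n_lon, batch_size_lat, batch_size_lon):
    slices = []
    # Step through the grid in whole batches, row-major order.
    for lat_start in range(0, n_lat - batch_size_lat + 1, batch_size_lat):
        for lon_start in range(0, n_lon - batch_size_lon + 1, batch_size_lon):
            slices.append((lat_start, lat_start + batch_size_lat,
                           lon_start, lon_start + batch_size_lon))
    return slices
```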

generate_region_slices(lat_center, lon_center, region_size_lat, region_size_lon)[source]

Generate deterministic spatial slices for regional inference. The slices cover a block centered on (lat_center, lon_center). The region is divided into non-overlapping blocks of size (batch_size_lat, batch_size_lon) used for model inference.

Parameters:
  • lat_center (int) – Latitude index of the region center in the global grid.

  • lon_center (int) – Longitude index of the region center in the global grid.

  • region_size_lat (int) – Height of the region (in grid points).

  • region_size_lon (int) – Width of the region (in grid points).

Returns:

slices – List of (lat_start, lat_end, lon_start, lon_end) tuples defining non-overlapping spatial blocks covering the selected region.

Return type:

list of tuple

Notes

The latitude start index is clamped to ensure the region remains within the global latitude bounds. As a result, if the requested region is too close to the poles, the extracted region may be shifted and may no longer be centered exactly on (lat_center, lon_center). Longitude wrapping is handled later during patch extraction (in extract_batch).
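The clamping behavior described in the note can be sketched as below. The extra arguments (n_lat and the batch sizes, which the method reads from the instance) and leaving longitude indices unwrapped for extract_batch to handle are assumptions of this sketch.

```python
def generate_region_slices(n_lat, lat_center, lon_center,
                           region_size_lat, region_size_lon,
                           batch_size_lat, batch_size_lon):
    lat_start = lat_center - region_size_lat // 2
    # Clamp so the region stays inside [0, n_lat); near a pole this shifts it off-center.
    lat_start = max(0, min(lat_start, n_lat - region_size_lat))
    # Longitude is not clamped: indices may be negative or exceed the grid width
    # and are wrapped later, at extraction time.
    lon_start = lon_center - region_size_lon // 2
    slices = []
    for i in range(region_size_lat // batch_size_lat):
        for j in range(region_size_lon // batch_size_lon):
            a = lat_start + i * batch_size_lat
            b = lon_start + j * batch_size_lon
            slices.append((a, a + batch_size_lat, b, b + batch_size_lon))
    return slices
```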

extract_batch(data, ilat, ilon)[source]

Extract spatial batch centered at (ilat, ilon) with cyclic longitude.

Parameters:
  • data (torch.Tensor or np.ndarray) – Input data with shape (…, H, W) where last two dimensions are latitude and longitude.

  • ilat (int) – Latitude center index.

  • ilon (int) – Longitude center index.

Returns:

  • block (torch.Tensor or np.ndarray) – Extracted batch with shape (…, batch_size_lat, batch_size_lon).

  • indices (tuple) – Tuple of (lat_start, lat_end, lon_start, lon_end) indices.

Raises:

AssertionError – If input tensor dimensions don’t match grid dimensions or if indices are invalid.

Notes

  • Longitude is treated as cyclic (wraps around 0-360°).

  • Latitude is non-cyclic (no wrapping at poles).

  • The function rolls the data to center the longitude and then extracts the appropriate slice.
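A NumPy sketch of roll-then-slice extraction with cyclic longitude. Returning the longitude indices in the original, unwrapped frame (so they may be negative or exceed the grid width) is an assumption; the method may report them differently. The explicit batch-size arguments are also hypothetical.

```python
import numpy as np


def extract_batch(data, ilat, ilon, batch_size_lat, batch_size_lon):
    n_lat, n_lon = data.shape[-2:]
    lat_start = ilat - batch_size_lat // 2
    lat_end = lat_start + batch_size_lat
    assert 0 <= lat_start and lat_end <= n_lat, "latitude window leaves the grid"
    assert batch_size_lon <= n_lon, "longitude window wider than the grid"
    # Roll along longitude so that column ilon moves to column n_lon // 2.
    rolled = np.roll(data, shift=n_lon // 2 - ilon, axis=-1)
    r_start = n_lon // 2 - batch_size_lon // 2
    block = rolled[..., lat_start:lat_end, r_start:r_start + batch_size_lon]
    # Report longitude indices in the original (unwrapped) frame.
    lon_start = ilon - batch_size_lon // 2
    return block, (lat_start, lat_end, lon_start, lon_start + batch_size_lon)
```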

filter_batch(fine_patch, fine_block)[source]

Apply Gaussian low-pass filtering for multi-scale processing.

Parameters:
  • fine_patch (np.ndarray) – Fine-resolution data of shape (C, H, W).

  • fine_block (np.ndarray) – Reference block used to determine scaling factors.

Returns:

fine_filtered – Filtered data of shape (C, H, W).

Return type:

np.ndarray

Notes

  • Filters out high-frequency components beyond the Nyquist frequency of the coarse grid.

  • Uses Gaussian filtering in the frequency domain.

  • Processes each channel independently.
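The cutoff that removes content beyond the coarse grid's Nyquist frequency follows from the coarse spacing: Nyquist = 1 / (2 × coarse spacing), in cycles per degree. The sketch below assumes the coarse spacing is the fine spacing scaled by the ratio of the fine block size to in_shape; the helper name is hypothetical.

```python
def coarse_nyquist_cutoffs(d_lat, d_lon, fine_shape, coarse_shape):
    """Return (cutoff_lat, cutoff_lon) in cycles per degree."""
    hf, wf = fine_shape
    hc, wc = coarse_shape
    # Effective grid spacing (degrees) after downscaling to the coarse shape.
    coarse_d_lat = d_lat * hf / hc
    coarse_d_lon = d_lon * wf / wc
    return 1.0 / (2.0 * coarse_d_lat), 1.0 / (2.0 * coarse_d_lon)
```

Each channel would then be filtered independently, for example with np.stack([gaussian_filter(ch, d_lon, d_lat, cut_lon, cut_lat) for ch in fine_patch]).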

normalize(data, stats, norm_type, var_name=None, data_type=None)[source]

Normalize data using specified statistics and method.

Parameters:
  • data (torch.Tensor) – Input data to normalize.

  • stats (object) – Statistics object with attributes: vmin, vmax, vmean, vstd, median, iqr, q1, q3.

  • norm_type (str) – Normalization type: “minmax”, “minmax_11”, “standard”, “robust”, “log1p_minmax”, “log1p_standard”.

  • var_name (str, optional) – Variable name for logging.

  • data_type (str, optional) – Data type description for logging.

Returns:

Normalized data.

Return type:

torch.Tensor

Raises:

ValueError – If norm_type is not supported.
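The six normalization types can be written as closed-form formulas. This sketch uses NumPy rather than torch, assumes robust scaling means (x - median) / iqr, and assumes the stats for the log1p variants were computed on log1p-transformed data; none of these details is confirmed by the reference above.

```python
import numpy as np
from types import SimpleNamespace


def normalize(data, stats, norm_type):
    x = np.asarray(data, dtype=np.float64)
    if norm_type == "minmax":  # maps [vmin, vmax] to [0, 1]
        return (x - stats.vmin) / (stats.vmax - stats.vmin)
    if norm_type == "minmax_11":  # maps [vmin, vmax] to [-1, 1]
        return 2.0 * (x - stats.vmin) / (stats.vmax - stats.vmin) - 1.0
    if norm_type == "standard":
        return (x - stats.vmean) / stats.vstd
    if norm_type == "robust":
        return (x - stats.median) / stats.iqr
    if norm_type == "log1p_minmax":
        # Assumption: stats describe the log1p-transformed variable.
        y = np.log1p(x)
        return (y - stats.vmin) / (stats.vmax - stats.vmin)
    if norm_type == "log1p_standard":
        y = np.log1p(x)
        return (y - stats.vmean) / stats.vstd
    raise ValueError(f"Unsupported norm_type: {norm_type!r}")


# Usage with a lightweight stand-in for the statistics object.
stats = SimpleNamespace(vmin=0.0, vmax=10.0, vmean=5.0, vstd=2.0,
                        median=4.0, iqr=2.0, q1=3.0, q3=5.0)
print(normalize(np.array([0.0, 5.0, 10.0]), stats, "minmax_11"))
```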

normalize_time(tindex)[source]

Return normalized time features for given time index.

Parameters:

tindex (int) – Time index.

Returns:

Dictionary of normalized time features.

Return type:

dict

Notes

Features depend on the time_normalization setting:
  • “linear”: year_norm, doy_norm, hour_norm
  • “cos_sin”: year_norm, doy_sin, doy_cos, hour_sin, hour_cos
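A sketch of the two encodings. The scaling constants (365 days, 24 hours) and mapping years linearly onto [0, 1] over the dataset's year range are assumptions; the package's exact constants are not documented here. The standalone signature is hypothetical.

```python
import numpy as np


def normalize_time(year, doy, hour, year_min, year_max, mode="linear"):
    # Assumption: years scaled linearly to [0, 1]; max(..., 1) guards a single-year range.
    year_norm = (year - year_min) / max(year_max - year_min, 1)
    if mode == "linear":
        return {"year_norm": year_norm,
                "doy_norm": (doy - 1) / 365.0,
                "hour_norm": hour / 24.0}
    if mode == "cos_sin":
        # Cyclic encoding: Dec 31 and Jan 1 (or 23h and 0h) end up close together.
        doy_angle = 2.0 * np.pi * (doy - 1) / 365.0
        hour_angle = 2.0 * np.pi * hour / 24.0
        return {"year_norm": year_norm,
                "doy_sin": np.sin(doy_angle), "doy_cos": np.cos(doy_angle),
                "hour_sin": np.sin(hour_angle), "hour_cos": np.cos(hour_angle)}
    raise ValueError(f"Unknown time_normalization: {mode!r}")
```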

IPSL_AID.dataset.create_dummy_netcdf(temp_dir, year=2020, has_constants=False)[source]

Create a dummy NetCDF file for testing.

IPSL_AID.dataset.create_dummy_statistics_json(temp_dir)[source]

Create dummy statistics.json file for testing.

class IPSL_AID.dataset.TestDataPreprocessor(methodName='runTest', logger=None)[source]

Bases: TestCase

Unit tests for DataPreprocessor class.

__init__(methodName='runTest', logger=None)[source]

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

setUp()[source]

Set up test fixtures.

test_stats_computation_without_existing_json()[source]

Test stats computation when no statistics.json exists.

test_stats_computation_with_existing_json()[source]

Test stats computation when statistics.json exists.

test_coarse_down_up_with_torch_tensors()[source]

Test coarse_down_up with torch tensor inputs.

test_coarse_down_up_with_numpy_arrays()[source]

Test coarse_down_up with numpy array inputs.

test_gaussian_filter_basic()[source]

Test basic Gaussian filtering operation.

test_gaussian_filter_different_epsilon()[source]

Test Gaussian filter with different epsilon values.

test_preprocessor_initialization_train_mode()[source]

Test DataPreprocessor initialization in train mode.

test_preprocessor_initialization_validation_mode()[source]

Test DataPreprocessor initialization in validation mode.

test_preprocessor_initialization_cos_sin_normalization()[source]

Test DataPreprocessor with cos_sin time normalization.

test_preprocessor_batch_size_validation()[source]

Test that batch size validation raises appropriate errors.

test_new_epoch_method()[source]

Test new_epoch method resets random centers.

test_get_center_indices_from_latlon()[source]

Test conversion from geographic coordinates to grid indices.

test_generate_random_batch_centers()[source]

Test random batch centers generation.

test_extract_batch()[source]

Test spatial batch extraction.

test_generate_evaluation_slices()[source]

Test evaluation slices generation.

test_generate_region_slices()[source]

Test regional slice generation.

test_normalize_methods()[source]

Test different normalization methods.

test_normalize_time_linear()[source]

Test time normalization with linear method.

test_normalize_time_cos_sin()[source]

Test time normalization with cos_sin method.

test_dataset_len_method()[source]

Test __len__ method.

test_getitem_train_mode()[source]

Test __getitem__ method in train mode.

test_getitem_validation_mode()[source]

Test __getitem__ method in validation mode.

test_getitem_with_filter_enabled()[source]

Test __getitem__ with apply_filter=True.

test_invalid_coordinate_handling()[source]

Test handling of invalid coordinate specifications.

test_preprocessor_without_constants()[source]

Test DataPreprocessor without constant variables.

test_invalid_time_normalization()[source]

Test invalid time normalization method.

tearDown()[source]

Clean up after tests.

IPSL_AID.diagnostics module

class IPSL_AID.diagnostics.PlotConfig[source]

Bases: object

Central configuration for all plotting functions.

DEFAULT_SAVE_DIR = './results'
DEFAULT_FIGSIZE_MULTIPLIER = 4
COLORMAPS = {'10u': 'BrBG_r', '10v': 'BrBG_r', '2t': 'rainbow', 'TP': 'Blues', 'curl': 'seismic', 'd2m': 'rainbow', 'default': 'viridis', 'dewpoint': 'rainbow', 'divergence': 'seismic', 'error': 'Reds', 'humid': 'Greens', 'humidity': 'Greens', 'mae': 'Reds', 'meridional': 'BrBG_r', 'precipitation': 'Blues', 'pres': 'viridis', 'pressure': 'viridis', 'speed': 'coolwarm', 'ssr': 'seismic', 'st': 'rainbow', 'surface temperature': 'rainbow', 'temp': 'rainbow', 'temperature': 'rainbow', 'tp': 'Blues', 'wind': 'coolwarm', 'zonal': 'BrBG_r'}
FIXED_DIFF_RANGES = {'10u': (-5.0, 5.0), '10v': (-5.0, 5.0), '2t': (-5.0, 5.0), 'T2M': (-5.0, 5.0), 'TP': (-0.5, 0.5), 'U10': (-5.0, 5.0), 'V10': (-5.0, 5.0), 'VAR_10U': (-5.0, 5.0), 'VAR_10V': (-5.0, 5.0), 'VAR_2T': (-5.0, 5.0), 'VAR_D2M': (-5.0, 5.0), 'VAR_ST': (-5.0, 5.0), 'VAR_TP': (-0.5, 0.5), 'meridional': (-5.0, 5.0), 'temperature': (-5.0, 5.0), 'tp': (-0.5, 0.5)}
FIXED_DIFF_RANGES_ERRORS = {'Humid': (0, 3.0), 'Press': (0, 3.0), 'Temp': (0, 3.0), 'VAR_10U': (0, 3.0), 'VAR_10V': (0, 3.0), 'VAR_2T': (0, 0.01), 'VAR_D2M': (0, 1.0), 'VAR_ST': (0, 1.0), 'VAR_TP': (0, 0.5), 'Wind': (0, 3.0)}
FIXED_MAE_RANGES = {'10u': (0.0, 3.0), '10v': (0.0, 3.0), '2t': (0.0, 3.0), 'T2M': (0.0, 3.0), 'TP': (0.0, 1.0), 'U10': (0.0, 3.0), 'V10': (0.0, 3.0), 'VAR_10U': (0.0, 3.0), 'VAR_10V': (0.0, 3.0), 'VAR_2T': (0.0, 3.0), 'VAR_D2M': (0.0, 3.0), 'VAR_ST': (0.0, 3.0), 'VAR_TP': (0.0, 1.0), 'meridional': (0.0, 3.0), 'temperature': (0.0, 3.0), 'tp': (0.0, 1.0)}
FIXED_SSR_RANGES = {'10u': (0.0, 3.0), '10v': (0.0, 3.0), '2t': (0.0, 3.0), 'T2M': (0.0, 3.0), 'TP': (0.0, 3.0), 'U10': (0.0, 3.0), 'V10': (0.0, 3.0), 'VAR_10U': (0.0, 3.0), 'VAR_10V': (0.0, 3.0), 'VAR_2T': (0.0, 3.0), 'VAR_D2M': (0.0, 3.0), 'VAR_ST': (0.0, 3.0), 'VAR_TP': (0.0, 1.0), 'meridional': (0.0, 3.0), 'temperature': (0.0, 3.0), 'tp': (0.0, 3.0)}
COASTLINE_w = 0.5
BORDER_w = 0.5
LAKE_w = 0.5
BORDER_STYLE = '--'
COLORBAR_h = 0.02
COLORBAR_PAD = 0.05
classmethod get_colormap(variable_name)[source]

Get appropriate colormap for a variable.

classmethod get_plot_name(variable_name)[source]

Convert variable name to readable plot name.

classmethod convert_units(variable_name, data)[source]

Safe unit conversion when required.
  • NEVER modifies the input.
  • Returns a new array only if conversion is needed.

static get_fixed_diff_range(var_name)[source]

Get fixed visualization range for signed differences (Prediction − Truth).

static get_fixed_diff_range_errors(var_name)[source]

Get fixed visualization range for error map.

static get_fixed_mae_range(var_name)[source]

Get fixed visualization range for Mean Absolute Error (MAE).

static get_fixed_ssr_range(var_name)[source]

Get fixed visualization range for Spread Skill Ratio (SSR).

IPSL_AID.diagnostics.plot_validation_hexbin(predictions, targets, coarse_inputs=None, variable_names=None, filename='validation_hexbin.png', save_dir='./results', figsize_multiplier=4)[source]

Create hexbin plots comparing model predictions vs ground truth for all variables.

Parameters:
  • predictions (torch.Tensor or np.array) – Model predictions of shape [batch_size, num_variables, h, w]

  • targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]

  • coarse_inputs (torch.Tensor or np.array, optional) – Coarse inputs of shape [batch_size, num_variables, h, w]

  • variable_names (list of str, optional) – Names of the variables for subplot titles

  • filename (str, optional) – Output filename

  • save_dir (str, optional) – Directory to save the plot

  • figsize_multiplier (int, optional) – Base size multiplier for subplots

IPSL_AID.diagnostics.plot_comparison_hexbin(predictions, targets, coarse_inputs, variable_names=None, filename='comparison_hexbin.png', save_dir='./results', figsize_multiplier=4)[source]

Create hexbin comparison plots between model predictions, ground truth, and coarse inputs.

For each variable, creates two side-by-side hexbin plots:
  1. Model predictions vs ground truth
  2. Coarse inputs vs ground truth

Each plot includes an identity line and R²/MAE metrics.

Parameters:
  • predictions (torch.Tensor or np.array) – Model predictions of shape [batch_size, num_variables, h, w]

  • targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]

  • coarse_inputs (torch.Tensor or np.array) – Coarse inputs of shape [batch_size, num_variables, h, w]

  • variable_names (list of str, optional) – Names of the variables for subplot titles. If None, uses VAR_0, VAR_1, etc.

  • filename (str, optional) – Output filename

  • save_dir (str, optional) – Directory to save the plot

  • figsize_multiplier (int, optional) – Base size multiplier for subplots

Returns:

save_path – Path to the saved figure

Return type:

str
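The per-variable metrics shown in each panel can be computed as below. This sketch assumes R² is the coefficient of determination with the ground truth as reference (1 - SS_res / SS_tot) rather than a squared correlation; the helper name is hypothetical.

```python
import numpy as np


def r2_and_mae(pred, target):
    pred = np.asarray(pred, dtype=np.float64).ravel()
    target = np.asarray(target, dtype=np.float64).ravel()
    mae = np.mean(np.abs(pred - target))
    ss_res = np.sum((target - pred) ** 2)           # residual sum of squares
    ss_tot = np.sum((target - target.mean()) ** 2)  # total sum of squares
    return 1.0 - ss_res / ss_tot, mae
```

For a [batch_size, num_variables, h, w] array, call it once per variable, e.g. r2_and_mae(predictions[:, v], targets[:, v]).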

IPSL_AID.diagnostics.plot_metric_histories(valid_metrics_history, variable_names, metric_names, filename='validation_metrics', save_dir='./results', figsize_multiplier=4)[source]

Creates row-based panel plots: one figure per metric, rows = variables, shared x-axis.

Parameters:
  • valid_metrics_history (dict) – Dict from training loop storing metric histories.

  • variable_names (list of str) – Names of variables.

  • metric_names (list of str) – List of metric names (e.g. [“MAE”]).

  • filename (str) – Prefix for saved figures.

  • save_dir (str) – Directory where images are saved.

IPSL_AID.diagnostics.plot_metrics_heatmap(valid_metrics_history, variable_names, metric_names, filename='validation_metrics_heatmap', save_dir='./results', figsize_multiplier=4)[source]

Plot a heatmap of validation metrics.

Parameters:
  • valid_metrics_history (dict) – Dict from validation loop storing metric histories.

  • variable_names (list of str) – Names of variables.

  • metric_names (list of str) – List of metric names (e.g. [“MAE”, “NMAE”, “RMSE”, “R²”]).

  • filename (str) – Prefix for saved figures.

  • save_dir (str) – Directory where images are saved.

  • figsize_multiplier (float) – Controls overall figure size

IPSL_AID.diagnostics.plot_loss_histories(train_loss_history, valid_loss_history, filename='training_validation_loss.png', save_dir='./results', figsize_multiplier=4)[source]

Plots training and validation loss in a single panel.

Parameters:
  • train_loss_history (list or array) – History of training loss values.

  • valid_loss_history (list or array) – History of validation loss values.

  • filename (str) – Output image file name for the plot.

  • save_dir (str) – Directory to save the plot.

IPSL_AID.diagnostics.plot_average_metrics(valid_metrics_history, metric_names, filename='average_metrics.png', save_dir='./results', figsize_multiplier=4)[source]

Plots average metrics across all variables in a row-based layout with shared x-axis.

Each row corresponds to one metric, plotting both:
  • average_pred_vs_fine_<metric>

  • average_coarse_vs_fine_<metric>

Parameters:
  • valid_metrics_history (dict) – Dictionary containing validation metrics history.

  • metric_names (list of str) – Names of metrics to plot.

  • filename (str) – Output image file name for the plot.

  • save_dir (str) – Directory to save the plot.

IPSL_AID.diagnostics.plot_spatiotemporal_histograms(steps, tindex_lim, centers, tindices, mode='train', filename='average_metrics.png', save_dir='./results', figsize_multiplier=4)[source]

Plot two 2D hexagonal bin histograms showing spatial-temporal data coverage: latitude center vs temporal index and longitude center vs temporal index.

This function visualizes the distribution of data samples across spatial (latitude/longitude) and temporal dimensions using hexagonal binning, which provides smoother density estimation compared to rectangular binning.

Parameters:
  • steps (EasyDict) – Dictionary containing coordinate dimensions and limits. Expected to have attributes ‘latitude’ (or ‘lat’) and ‘longitude’ (or ‘lon’) specifying the maximum spatial indices.

  • tindex_lim (tuple) – Tuple of (min_time, max_time) specifying the temporal index limits.

  • centers (list of tuples) – List of (lat_center, lon_center) coordinates for each data sample. Each center represents the spatial location of a data point.

  • tindices (list or array-like) – List of temporal indices corresponding to each data sample. Should have the same length as ‘centers’.

  • mode (str) – Dataset mode identifier, typically “train” or “validation”. Used for plot title and filename.

  • filename (str, optional) – Prefix prepended to the output filename. Default is 'average_metrics.png'.

  • save_dir (str) – Directory path where the plot will be saved. Directory will be created if it doesn’t exist.

  • figsize_multiplier (int, optional) – Base size multiplier for subplots. Default is 4.

Returns:

The function saves the plot to disk and does not return any value.

Return type:

None

Notes

  • The function creates two side-by-side subplots:
      1. Latitude center index vs temporal index
      2. Longitude center index vs temporal index

  • Uses hexagonal binning (hexbin) for density visualization, which reduces visual artifacts compared to rectangular histograms.

  • A single colorbar is shared between both plots with log10 scaling.

  • The color scale is normalized to the maximum count across both histograms.

  • Hexagons with zero count (mincnt=1) are not displayed.

Examples

>>> steps = EasyDict({'latitude': 180, 'longitude': 360})
>>> tindex_lim = (0, 1000)
>>> centers = [(10, 20), (15, 25), (10, 20), (40, 300)]  # (lat, lon) per sample
>>> tindices = [0, 5, 10, 15]  # temporal index per sample
>>> plot_spatiotemporal_histograms(steps, tindex_lim, centers, tindices,
...                                mode="train", save_dir="./plots")

The function will save a plot named “spatiotemporal_train_hexbin.png” in the “./plots” directory.

IPSL_AID.diagnostics.plot_surface(predictions, targets, coarse_inputs, lat_1d, lon_1d, timestamp=None, variable_names=None, filename='forecast_plot.png', save_dir=None, figsize_multiplier=None)[source]

Plot side-by-side forecast maps (coarse input, true target, model prediction, and difference) for one or more meteorological variables over a geographic domain.

Parameters:
  • coarse_inputs (torch.Tensor or np.ndarray) – Coarse-resolution input data with shape [1, n_vars, H, W].

  • targets (torch.Tensor or np.ndarray) – Ground-truth high-resolution data with shape [1, n_vars, H, W].

  • predictions (torch.Tensor or np.ndarray) – Model predictions at target resolution with shape [1, n_vars, H, W].

  • lat_1d (array-like) – 1D array of latitude coordinates with shape [H].

  • lon_1d (array-like) – 1D array of longitude coordinates with shape [W].

  • timestamp (datetime.datetime, optional) – Forecast timestamp to include in the plot title.

  • variable_names (list of str, optional) – Variable names or identifiers.

  • filename (str, optional) – Output filename for saving the plot.

  • save_dir (str, optional) – Directory to save the plot.

  • figsize_multiplier (int, optional) – Base size multiplier for subplots.

Return type:

None

IPSL_AID.diagnostics.plot_ensemble_surface(predictions_ens, lat_1d, lon_1d, variable_names, timestamp=None, filename='ensemble_surface.png', save_dir='./results')[source]

Plot ensemble members, ensemble mean, and ensemble spread.

Parameters:
  • predictions_ens (torch.Tensor or np.ndarray) – Ensemble predictions of shape [n_ensemble_members, n_vars, H, W].

  • lat_1d (array-like) – 1D array of latitude coordinates with shape [H].

  • lon_1d (array-like) – 1D array of longitude coordinates with shape [W].

  • variable_names (list of str) – Variable names or identifiers.

  • timestamp (datetime.datetime, optional) – Forecast timestamp to include in the plot title.

  • filename (str, optional) – Output filename for saving the plot.

  • save_dir (str, optional) – Directory to save the plot.

Return type:

None

IPSL_AID.diagnostics.plot_zoom_comparison(predictions, targets, lat_1d, lon_1d, variable_names=None, filename='zoom_plot.png', save_dir=None, zoom_box=None)[source]

Plot a comparison between ground truth and model predictions with a geographic zoom.

Parameters:
  • targets (torch.Tensor or np.ndarray) – Ground-truth high-resolution data with shape [1, n_vars, H, W].

  • predictions (torch.Tensor or np.ndarray) – Model predictions at targets resolution with shape [1, n_vars, H, W].

  • lat_1d (array-like) – 1D array of latitude coordinates with shape [H].

  • lon_1d (array-like) – 1D array of longitude coordinates with shape [W].

  • variable_names (list of str, optional) – Variable names or identifiers.

  • filename (str, optional) – Output filename for saving the plot.

  • save_dir (str, optional) – Directory to save the plot.

  • zoom_box (dict, optional) – Dictionary whose keys define the geographic zoom region.

Return type:

None

IPSL_AID.diagnostics.plot_global_surface_robinson(predictions, targets, coarse_inputs, lat_1d, lon_1d, timestamp=None, variable_names=None, filename='global_robinson.png', save_dir=None, figsize_multiplier=None)[source]

Plot coarse, truth, prediction and difference fields in Robinson projection.

Parameters:
  • coarse_inputs (torch.Tensor or np.ndarray) – Coarse-resolution input data with shape [1, n_vars, H, W].

  • targets (torch.Tensor or np.ndarray) – Ground-truth high-resolution data with shape [1, n_vars, H, W].

  • predictions (torch.Tensor or np.ndarray) – Model predictions at targets resolution with shape [1, n_vars, H, W].

  • lat_1d (array-like) – 1D array of latitude coordinates with shape [H].

  • lon_1d (array-like) – 1D array of longitude coordinates with shape [W].

  • timestamp (datetime.datetime, optional) – Forecast timestamp to include in the plot title.

  • variable_names (list of str, optional) – Variable names or identifiers.

  • filename (str, optional) – Output filename for saving the plot.

  • save_dir (str, optional) – Directory to save the plot.

  • figsize_multiplier (int, optional) – Base size multiplier for subplots.

Return type:

None

IPSL_AID.diagnostics.plot_MAE_map(predictions, targets, lat_1d, lon_1d, timestamp=None, variable_names=None, filename='validation_mae_map.png', save_dir=None, figsize_multiplier=None)[source]

Plot spatial MAE maps averaged over all time steps: MAE(x, y) = mean_t(abs(prediction - target))

Parameters:
  • predictions (torch.Tensor or np.array) – Model predictions of shape [batch_size, num_variables, h, w]

  • targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]

  • lat_1d (array-like) – 1D array of latitude coordinates with shape [H].

  • lon_1d (array-like) – 1D array of longitude coordinates with shape [W].

  • timestamp (datetime.datetime, optional) – Forecast timestamp to include in the plot title.

  • variable_names (list of str, optional) – Variable names or identifiers.

  • filename (str, optional) – Output filename for saving the plot.

  • save_dir (str, optional) – Directory to save the plot.

  • figsize_multiplier (int, optional) – Base size multiplier for subplots.

Return type:

None
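
The MAE map formula above reduces over the time (batch) axis only, leaving one map per variable. A minimal NumPy sketch of that reduction, assuming the inputs are already NumPy arrays of shape [batch_size, num_variables, h, w]; mae_map_sketch is a hypothetical helper, not part of IPSL_AID, and the plotting step is omitted:

```python
import numpy as np

def mae_map_sketch(predictions, targets):
    """Hypothetical helper: MAE(x, y) = mean_t |prediction - target|.

    Inputs: arrays of shape [batch_size, num_variables, h, w].
    Returns: array of shape [num_variables, h, w].
    """
    predictions = np.asarray(predictions, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.float64)
    # Average the absolute error over the time/batch axis only,
    # keeping one spatial map per variable.
    return np.abs(predictions - targets).mean(axis=0)

# Usage: 4 time steps, 2 variables on an 8 x 16 grid, constant offset of 0.5
rng = np.random.default_rng(0)
preds = rng.standard_normal((4, 2, 8, 16))
mae = mae_map_sketch(preds, preds + 0.5)
print(mae.shape)  # (2, 8, 16)
```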

IPSL_AID.diagnostics.plot_error_map(predictions, targets, lat_1d, lon_1d, timestamp=None, variable_names=None, filename='validation_error_map.png', save_dir=None, figsize_multiplier=None)[source]

Plot spatial ERROR maps averaged over all time steps.

Parameters:
  • predictions (torch.Tensor or np.array) – Model predictions of shape [batch_size, num_variables, h, w]

  • targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]

  • lat_1d (array-like) – 1D array of latitude coordinates with shape [H].

  • lon_1d (array-like) – 1D array of longitude coordinates with shape [W].

  • timestamp (datetime.datetime, optional) – Forecast timestamp to include in the plot title.

  • variable_names (list of str, optional) – Variable names or identifiers.

  • filename (str, optional) – Output filename for saving the plot.

  • save_dir (str, optional) – Directory to save the plot.

  • figsize_multiplier (int, optional) – Base size multiplier for subplots.

Return type:

None

IPSL_AID.diagnostics.spread_skill_ratio(predictions, targets, variable_names, pixel_wise=False)[source]

Compute the spread-skill ratio (SSR) of predictions with respect to targets. The formula implemented is equation (15) in “Why Should Ensemble Spread Match the RMSE of the Ensemble Mean?”, Fortin et al.

Parameters:
  • predictions (torch.Tensor or np.array) – Model predictions of shape [ensemble_size, batch_size, num_variables, h, w]. The dimension order must not be changed. ensemble_size must be greater than or equal to 2 for the spread-skill ratio to be computed.

  • targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]

  • variable_names (list of str) – Variable names or identifiers.

  • pixel_wise (bool) – If True, computes and returns the SSR for each pixel independently. If False, computes and returns the SSR averaged over all pixels and all time steps. Defaults to False.

Returns:

  • np.ndarray of shape [num_variables, h, w] if pixel_wise is True

  • np.ndarray of shape [num_variables] if pixel_wise is False (default)
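
A minimal NumPy sketch of the ratio, written as a hypothetical re-implementation rather than the package's code. It assumes the finite-ensemble correction factor sqrt((M + 1) / M) discussed by Fortin et al. for an ensemble of M members, with spread taken as the square root of the mean unbiased ensemble variance and skill as the RMSE of the ensemble mean; the exact form of equation (15) as implemented in IPSL_AID may differ. Under this form, an ensemble drawn from the same distribution as the truth yields an SSR close to 1:

```python
import numpy as np

def spread_skill_ratio_sketch(predictions, targets, pixel_wise=False):
    """Hypothetical SSR helper (not the IPSL_AID implementation).

    predictions: [ensemble_size, batch_size, num_variables, h, w]
    targets:     [batch_size, num_variables, h, w]
    Returns [num_variables, h, w] if pixel_wise else [num_variables].
    """
    predictions = np.asarray(predictions, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.float64)
    n_members = predictions.shape[0]
    if n_members < 2:
        raise ValueError("ensemble_size must be >= 2")
    ens_var = predictions.var(axis=0, ddof=1)            # unbiased member variance
    sq_err = (predictions.mean(axis=0) - targets) ** 2   # squared error of ensemble mean
    # Average over time only (pixel-wise) or over time and both spatial axes
    axes = (0,) if pixel_wise else (0, 2, 3)
    spread = np.sqrt(ens_var.mean(axis=axes))
    skill = np.sqrt(sq_err.mean(axis=axes))
    # Assumed finite-ensemble correction factor sqrt((M + 1) / M)
    return np.sqrt((n_members + 1) / n_members) * spread / skill

# Usage: 5 members drawn from the same distribution as the truth
rng = np.random.default_rng(0)
truth = rng.standard_normal((200, 2, 32, 32))
ens = rng.standard_normal((5, 200, 2, 32, 32))
ssr = spread_skill_ratio_sketch(ens, truth)
print(ssr)  # both values close to 1.0
```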

IPSL_AID.diagnostics.plot_spread_skill_ratio_map(predictions, targets, lat_1d, lon_1d, timestamp=None, variable_names=None, filename='validation_spread_skill_ratio_map.png', save_dir=None, figsize_multiplier=None)[source]

Plot spatial spread-skill ratio maps, where the SSR of each pixel is computed over all time steps. The formula implemented is equation (15) in the article “Why Should Ensemble Spread Match the RMSE of the Ensemble Mean?”, Fortin et al.

Parameters:
  • predictions (torch.Tensor or np.array) – Model predictions of shape [ensemble_size, batch_size, num_variables, h, w]. The dimension order must not be changed. ensemble_size must be greater than or equal to 2 for the spread-skill ratio to be computed.

  • targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]

  • lat_1d (array-like) – 1D array of latitude coordinates with shape [H].

  • lon_1d (array-like) – 1D array of longitude coordinates with shape [W].

  • timestamp (datetime.datetime, optional) – Forecast timestamp to include in the plot title.

  • variable_names (list of str, optional) – Variable names or identifiers.

  • filename (str, optional) – Output filename for saving the plot.

  • save_dir (str, optional) – Directory to save the plot.

  • figsize_multiplier (int, optional) – Base size multiplier for subplots.

Return type:

None

IPSL_AID.diagnostics.plot_spread_skill_ratio_hexbin(predictions, targets, variable_names=None, filename='validation_spread_skill_ratio_hexbin.png', save_dir=None, figsize_multiplier=None)[source]

Plot a spread-skill ratio hexbin scatterplot, where each point represents the prediction for a single pixel at a single time step: SSR(x, y) = spread(x, y) / skill(x, y), where spread(x, y) is the temporal mean of the standard deviation of the ensemble members and skill(x, y) is the temporal mean of the RMSE of the ensemble mean.

Parameters:
  • predictions (torch.Tensor or np.array) – Model predictions of shape [ensemble_size, batch_size, num_variables, h, w]. The dimension order must not be changed. ensemble_size must be greater than or equal to 2 for the spread-skill ratio to be computed.

  • targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]

  • variable_names (list of str, optional) – Variable names or identifiers.

  • filename (str, optional) – Output filename for saving the plot.

  • save_dir (str, optional) – Directory to save the plot.

  • figsize_multiplier (int, optional) – Base size multiplier for subplots.

Return type:

None

IPSL_AID.diagnostics.plot_validation_pdfs(predictions, targets, coarse_inputs=None, variable_names=None, filename='validation_pdfs.png', save_dir='./results', figsize_multiplier=4, save_npz=False)[source]

Create PDF (Probability Density Function) plots comparing distributions of model predictions vs ground truth for all variables.

Parameters:
  • predictions (torch.Tensor or np.array) – Model predictions of shape [batch_size, num_variables, h, w]

  • targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]

  • coarse_inputs (torch.Tensor or np.array, optional) – Coarse inputs of shape [batch_size, num_variables, h, w]

  • variable_names (list of str, optional) – Names of the variables for subplot titles

  • filename (str, optional) – Output filename

  • save_dir (str, optional) – Directory to save the plot

  • figsize_multiplier (int, optional) – Base size multiplier for subplots

  • save_npz (bool, optional) – If True, saves the PDF diagnostics to a compressed .npz file.

Returns:

The function saves the plot to disk and does not return any value.

Return type:

None

Notes

  • Creates horizontal subplots (one per variable) showing PDFs

  • Each subplot shows up to 3 lines: Predictions, Ground Truth, and Coarse Inputs

  • Uses automatic color and linestyle cycling based on global matplotlib settings

  • Calculates and displays key statistics for each distribution

  • Handles both PyTorch tensors and numpy arrays

Examples

>>> predictions = np.random.randn(10, 3, 64, 64)  # 10 samples, 3 variables
>>> targets = np.random.randn(10, 3, 64, 64)
>>> plot_validation_pdfs(predictions, targets, variable_names=['Temp', 'Pres', 'Humid'])

IPSL_AID.diagnostics.plot_power_spectra(predictions, targets, dlat, dlon, coarse_inputs=None, variable_names=None, filename='power_spectra_physical.png', save_dir='./results', figsize_multiplier=4, save_npz=False)[source]

Calculate and plot power spectra using physical wavenumbers derived from the grid spacings dlat and dlon.

Parameters:
  • predictions (torch.Tensor or np.array) – Model predictions of shape [batch_size, num_variables, nh, nw]

  • targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, nh, nw]

  • dlat (float) – Grid spacing in latitude (degrees)

  • dlon (float) – Grid spacing in longitude (degrees)

  • coarse_inputs (torch.Tensor or np.array, optional) – Coarse inputs of shape [batch_size, num_variables, nh, nw]

  • variable_names (list of str, optional) – Names of the variables for subplot titles

  • filename (str, optional) – Output filename

  • save_dir (str, optional) – Directory to save the plot

  • figsize_multiplier (int, optional) – Base size multiplier for subplots

  • save_npz (bool, optional) – If True, saves the power spectra diagnostics to a compressed .npz file.

Return type:

None

IPSL_AID.diagnostics.calculate_psd2d_simple(field)[source]

Simple 2D PSD calculation without preprocessing.

IPSL_AID.diagnostics.radial_average_psd(psd2d, k_mag, k_bins)[source]

Radially average a 2D PSD by grouping its values into bins (k_bins) of wavenumber magnitude (k_mag).
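
The two helpers above can be illustrated with a minimal NumPy sketch: an unnormalized 2D PSD |FFT|² with the zero frequency shifted to the centre, followed by averaging over annuli of wavenumber magnitude. The helper names are hypothetical, wavenumbers here are in cycles per grid step from np.fft.fftfreq (passing d=dlon and d=dlat would give cycles per degree), and the package's normalization, windowing, and conversion to physical units may differ:

```python
import numpy as np

def psd2d_sketch(field):
    """Hypothetical helper: unnormalized 2D PSD |FFT|^2, zero frequency centred."""
    spectrum = np.fft.fftshift(np.fft.fft2(np.asarray(field, dtype=np.float64)))
    return np.abs(spectrum) ** 2

def radial_average_sketch(psd2d, k_mag, k_bins):
    """Hypothetical helper: mean of psd2d over annuli k_bins[i] <= |k| < k_bins[i+1].

    Returns one value per bin (len(k_bins) - 1); empty bins are NaN.
    """
    n_bins = len(k_bins) - 1
    idx = np.digitize(k_mag.ravel(), k_bins) - 1
    valid = (idx >= 0) & (idx < n_bins)  # drop |k| outside the bin range
    sums = np.bincount(idx[valid], weights=psd2d.ravel()[valid], minlength=n_bins)
    counts = np.bincount(idx[valid], minlength=n_bins)
    with np.errstate(invalid="ignore", divide="ignore"):
        return sums / counts

# Usage: a pure zonal wave with 4 cycles across a 32 x 64 grid
nh, nw = 32, 64
field = np.tile(np.sin(2 * np.pi * 4 * np.arange(nw) / nw), (nh, 1))
psd = psd2d_sketch(field)
# Shift the frequency axes the same way as the PSD so both layouts match
kx = np.fft.fftshift(np.fft.fftfreq(nw, d=1.0))
ky = np.fft.fftshift(np.fft.fftfreq(nh, d=1.0))
KX, KY = np.meshgrid(kx, ky)  # shape (nh, nw), matching psd
k_mag = np.hypot(KX, KY)
k_bins = np.linspace(0.0, 0.5, 33)
spectrum_1d = radial_average_sketch(psd, k_mag, k_bins)
```

The radially averaged spectrum peaks in the bin containing |k| = 4/64, the wave's frequency.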

IPSL_AID.diagnostics.plot_qq_quantiles(predictions, targets, coarse_inputs, variable_names=None, units=None, quantiles=[0.9, 0.95, 0.975, 0.99, 0.995], filename='qq_quantiles.png', save_dir='./results', figsize_multiplier=4, save_npz=False)[source]

Create QQ plots at selected quantiles comparing model predictions and coarse inputs against ground truth.

For each variable, plots quantiles of predictions and coarse inputs against quantiles of ground truth with a 1:1 reference line.

Parameters:
  • predictions (torch.Tensor or np.array) – Model predictions of shape [batch_size, num_variables, h, w]

  • targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]

  • coarse_inputs (torch.Tensor or np.array) – Coarse inputs of shape [batch_size, num_variables, h, w]

  • variable_names (list of str, optional) – Names of the variables for subplot titles. If None, uses [“VAR_0”, “VAR_1”, …]

  • units (list of str, optional) – Units for each variable for axis labels. If None, uses empty strings.

  • quantiles (list of float, optional) – Quantile values to plot (e.g., [0.90, 0.95, 0.975, 0.99, 0.995])

  • filename (str, optional) – Output filename

  • save_dir (str, optional) – Directory to save the plot

  • figsize_multiplier (int, optional) – Base size multiplier for subplots

  • save_npz (bool, optional) – If True, saves the QQ diagnostics to a compressed .npz file.

Returns:

save_path – Path to the saved figure

Return type:

str
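
The quantile values behind such a QQ comparison can be sketched with np.quantile. This sketch assumes that, for each variable, all time steps and grid points are pooled before the quantiles are taken; qq_quantiles_sketch is a hypothetical helper and the package may pool or weight the data differently:

```python
import numpy as np

def qq_quantiles_sketch(predictions, targets, coarse_inputs,
                        quantiles=(0.9, 0.95, 0.975, 0.99, 0.995)):
    """Hypothetical helper: per-variable quantiles for a QQ comparison.

    Inputs: arrays of shape [batch_size, num_variables, h, w].
    Returns a dict of arrays of shape [num_variables, len(quantiles)].
    """
    q = np.asarray(quantiles, dtype=np.float64)
    out = {}
    for name, arr in (("predictions", predictions),
                      ("targets", targets),
                      ("coarse_inputs", coarse_inputs)):
        arr = np.asarray(arr, dtype=np.float64)
        # Pool batch and spatial dims: [num_variables, batch*h*w]
        flat = arr.transpose(1, 0, 2, 3).reshape(arr.shape[1], -1)
        # np.quantile gives [len(q), num_variables]; transpose to [num_variables, len(q)]
        out[name] = np.quantile(flat, q, axis=1).T
    return out

# Usage: plot out["predictions"][v] and out["coarse_inputs"][v]
# against out["targets"][v], together with a 1:1 line.
rng = np.random.default_rng(0)
preds = rng.standard_normal((10, 3, 16, 16))
out = qq_quantiles_sketch(preds, preds.copy(), preds + 0.1 * rng.standard_normal(preds.shape))
```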

IPSL_AID.diagnostics.dry_frequency_map(array, threshold)[source]

Compute a spatial map of dry-pixel frequency. The value of each pixel is the fraction of time steps with dry weather at that pixel.

Parameters:
  • array (torch.Tensor or np.array) – Precipitation data (e.g., model predictions or ground truth) of shape [batch_size, h, w].

  • threshold (float) – Precipitation threshold (in mm); pixels with values below it are considered dry.

Return type:

np.ndarray (np.float64) of shape [h, w]
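
The computation described above amounts to averaging a boolean dry mask over the time axis. A minimal sketch, assuming a strict "below threshold" comparison (the package's handling of values exactly equal to the threshold is not documented here); dry_frequency_map_sketch is a hypothetical helper:

```python
import numpy as np

def dry_frequency_map_sketch(array, threshold):
    """Hypothetical helper: fraction of time steps with precipitation < threshold.

    array: [batch_size, h, w] precipitation in mm; returns float64 [h, w].
    """
    array = np.asarray(array, dtype=np.float64)
    # Boolean dry mask per time step, averaged over the batch (time) axis
    return (array < threshold).mean(axis=0)

# Usage: 4 time steps on a 2 x 2 grid; pixel (0, 0) is wet in 2 of them
arr = np.zeros((4, 2, 2))
arr[:2, 0, 0] = 5.0
freq = dry_frequency_map_sketch(arr, threshold=0.1)
print(freq[0, 0])  # 0.5
```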

IPSL_AID.diagnostics.plot_dry_frequency_map(predictions, targets, threshold, lat_1d, lon_1d, filename='validation_dry_frequency_map.png', save_dir=None, figsize_multiplier=None)[source]

Plot spatial maps of dry-pixel frequency for predictions and targets. The value of each pixel is the frequency of dry weather at that pixel.

Parameters:
  • predictions (torch.Tensor or np.array) – Model predictions of shape [batch_size, h, w]

  • targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, h, w]

  • threshold (float) – Precipitation threshold (in mm); pixels with values below it are considered dry.

  • lat_1d (array-like) – 1D array of latitude coordinates with shape [H].

  • lon_1d (array-like) – 1D array of longitude coordinates with shape [W].

  • filename (str, optional) – Output filename for saving the plot.

  • save_dir (str, optional) – Directory to save the plot.

  • figsize_multiplier (int, optional) – Base size multiplier for subplots.

Return type:

None

IPSL_AID.diagnostics.calculate_pearsoncorr_nparray(arr1, arr2, axis=0)[source]

Calculate the Pearson correlation between two N-dimensional NumPy arrays.

Parameters:
  • arr1 (numpy.ndarray) – First N-dimensional array.

  • arr2 (numpy.ndarray) – Second N-dimensional array (must have the same shape as arr1).

  • axis (int or tuple of int, default=0) – Axis or tuple of axes over which to compute the correlation.

Returns:

Pearson correlation coefficients. The output has N - len(axis) dimensions (the input shape with the specified axis or axes removed).

Return type:

numpy.ndarray
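
The reduction described above can be written directly from the definition r = sum(a·b) / sqrt(sum(a²)·sum(b²)) on mean-removed arrays, which works for a single axis or a tuple of axes. This is a hypothetical sketch, not the package's implementation; with axis=0 on [time, ...] data it gives a correlation map over time, and with the spatial axes it gives a time series of spatial correlations:

```python
import numpy as np

def pearson_corr_sketch(arr1, arr2, axis=0):
    """Hypothetical helper: Pearson correlation reduced over `axis` (int or tuple)."""
    arr1 = np.asarray(arr1, dtype=np.float64)
    arr2 = np.asarray(arr2, dtype=np.float64)
    # Remove the mean along the reduction axes
    a = arr1 - arr1.mean(axis=axis, keepdims=True)
    b = arr2 - arr2.mean(axis=axis, keepdims=True)
    num = (a * b).sum(axis=axis)
    den = np.sqrt((a ** 2).sum(axis=axis) * (b ** 2).sum(axis=axis))
    return num / den

# Usage: 50 time steps on a 3 x 4 grid
rng = np.random.default_rng(1)
x = rng.standard_normal((50, 3, 4))
y = 0.5 * x + rng.standard_normal((50, 3, 4))
r_time = pearson_corr_sketch(x, y, axis=0)            # map, shape (3, 4)
r_space = pearson_corr_sketch(x, y, axis=(1, 2))      # series, shape (50,)
```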

IPSL_AID.diagnostics.plot_validation_mvcorr_space(predictions, targets, coarse_inputs=None, variable_names=None, filename='validation_mvcorr_space.png', save_dir='./results', figsize_multiplier=4)[source]

Compute the correlation between each pair of variables over the spatial dimensions and plot it as a time series, comparing model predictions with ground truth. Uses Pearson’s correlation coefficient.

Parameters:
  • predictions (torch.Tensor or np.array) – Model predictions of shape [batch_size, num_variables, h, w]

  • targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]

  • coarse_inputs (torch.Tensor or np.array, optional) – Coarse inputs of shape [batch_size, num_variables, h, w]

  • variable_names (list of str, optional) – Names of the variables for subplot titles

  • filename (str, optional) – Output filename

  • save_dir (str, optional) – Directory to save the plot

  • figsize_multiplier (int, optional) – Base size multiplier for subplots

Returns:

save_path – Path to the saved figure

Return type:

str

IPSL_AID.diagnostics.plot_validation_mvcorr(predictions, targets, lat, lon, coarse_inputs=None, variable_names=None, filename='validation_mvcorr_time.png', save_dir='./results', figsize_multiplier=4)[source]

Compute the correlation between each pair of variables over the time dimension and plot it as maps, comparing model predictions with ground truth. Uses Pearson’s correlation coefficient.

Parameters:
  • predictions (torch.Tensor or np.array) – Model predictions of shape [batch_size, num_variables, h, w]

  • targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]

  • lat (array-like) – 2D array of latitude coordinates with shape [h, w].

  • lon (array-like) – 2D array of longitude coordinates with shape [h, w].

  • coarse_inputs (torch.Tensor or np.array, optional) – Coarse inputs of shape [batch_size, num_variables, h, w]

  • variable_names (list of str, optional) – Names of the variables for subplot titles

  • filename (str, optional) – Output filename

  • save_dir (str, optional) – Directory to save the plot

  • figsize_multiplier (int, optional) – Base size multiplier for subplots

Returns:

save_path – Path to the saved figure

Return type:

str

IPSL_AID.diagnostics.plot_temporal_series_comparison(predictions, targets, coarse_inputs=None, variable_names=None, filename='validation_temp_series.png', save_dir='./results', figsize_multiplier=4)[source]

Plot spatially averaged temporal series for each variable.

Parameters:
  • predictions (torch.Tensor or np.array) – Model predictions of shape [batch_size, num_variables, h, w]

  • targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]

  • coarse_inputs (torch.Tensor or np.array, optional) – Coarse inputs of shape [batch_size, num_variables, h, w]

  • variable_names (list of str, optional) – Names of the variables for subplot titles

  • filename (str, optional) – Output filename

  • save_dir (str, optional) – Directory to save the plot

  • figsize_multiplier (int, optional) – Base size multiplier for subplots

Returns:

save_path – Path to the saved figure

Return type:

str

IPSL_AID.diagnostics.ranks(predictions, targets)[source]

Compute the rank of each target value within the corresponding ensemble of predictions.

Parameters:
  • predictions (torch.Tensor or np.array) – Model predictions of shape [ensemble_size, batch_size, h, w]

  • targets (torch.Tensor or np.array) – Targets of shape [batch_size, h, w]

Return type:

np.ndarray (np.float64) of shape [batch_size*h*w]
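
A rank of this kind is commonly defined as the number of ensemble members strictly below the target, giving values from 0 to ensemble_size; a histogram of these ranks is the rank histogram drawn by plot_ranks. The sketch below assumes that definition and does not randomize ties, which the package may handle differently; ranks_sketch is a hypothetical helper:

```python
import numpy as np

def ranks_sketch(predictions, targets):
    """Hypothetical helper: rank of each target within its ensemble.

    predictions: [ensemble_size, batch_size, h, w]; targets: [batch_size, h, w].
    Returns a flat float64 array of length batch_size*h*w with values in
    0..ensemble_size (number of members strictly below the target).
    """
    predictions = np.asarray(predictions, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.float64)
    below = (predictions < targets[None]).sum(axis=0)
    return below.astype(np.float64).ravel()

# Usage: 3 constant members (0, 1, 2) and a target of 1.5 everywhere
ens = np.stack([np.full((2, 2, 2), float(k)) for k in range(3)])
tgt = np.full((2, 2, 2), 1.5)
r = ranks_sketch(ens, tgt)
hist = np.bincount(r.astype(int), minlength=ens.shape[0] + 1)  # rank histogram
```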

IPSL_AID.diagnostics.plot_ranks(predictions, targets, variable_names=None, filename='ranks.png', save_dir='./results', figsize_multiplier=4)[source]

Create rank histograms of predictions compared to targets for each variable.

Parameters:
  • predictions (torch.Tensor or np.array) – Model predictions of shape [ensemble_size, batch_size, num_variables, h, w]

  • targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]

  • variable_names (list of str, optional) – Names of the variables for subplot titles. If None, uses [“VAR_0”, “VAR_1”, …]

  • filename (str, optional) – Output filename

  • save_dir (str, optional) – Directory to save the plot

  • figsize_multiplier (int, optional) – Base size multiplier for subplots

Returns:

save_path – Path to the saved figure

Return type:

str

IPSL_AID.diagnostics.get_divergence(u_tensor, v_tensor, spacing)[source]

Compute the horizontal divergence of a wind field.

Parameters:
  • u_tensor (torch.Tensor or np.array, shape [..., h, w]) – Zonal component of the wind field. It can have an arbitrary number of leading dimensions, but the last two dimensions must be the spatial dimensions (latitude, longitude). u_tensor and v_tensor must have the same shape.

  • v_tensor (torch.Tensor or np.array, shape [..., h, w]) – Meridional component of the wind field, with the same layout and shape as u_tensor.

  • spacing (float) – Grid spacing of the wind field, used to compute the gradients.

Return type:

np.ndarray (np.float64) with the same shape as u_tensor and v_tensor

IPSL_AID.diagnostics.get_curl(u_tensor, v_tensor, spacing)[source]

Compute the curl (vertical component) of a horizontal wind field.

Parameters:
  • u_tensor (torch.Tensor or np.array, shape [..., h, w]) – Zonal component of the wind field. It can have an arbitrary number of leading dimensions, but the last two dimensions must be the spatial dimensions (latitude, longitude). u_tensor and v_tensor must have the same shape.

  • v_tensor (torch.Tensor or np.array, shape [..., h, w]) – Meridional component of the wind field, with the same layout and shape as u_tensor.

  • spacing (float) – Grid spacing of the wind field, used to compute the gradients.

Return type:

np.ndarray (np.float64) with the same shape as u_tensor and v_tensor
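
For a 2D wind field (u, v), the divergence is ∂u/∂x + ∂v/∂y and the vertical curl is ∂v/∂x − ∂u/∂y. A minimal sketch of both with centred differences (np.gradient), assuming the last two axes are (y, x) = (h, w), a uniform spacing, and coordinates increasing along both axes. It ignores the spherical metric (cos φ) terms of a latitude-longitude grid, and the package's exact treatment is not documented here; the helper names are hypothetical:

```python
import numpy as np

def divergence_sketch(u, v, spacing):
    """Hypothetical helper: du/dx + dv/dy on the last two axes (y, x)."""
    u = np.asarray(u, dtype=np.float64)
    v = np.asarray(v, dtype=np.float64)
    du_dx = np.gradient(u, spacing, axis=-1)
    dv_dy = np.gradient(v, spacing, axis=-2)
    return du_dx + dv_dy

def curl_sketch(u, v, spacing):
    """Hypothetical helper: vertical curl dv/dx - du/dy on the last two axes (y, x)."""
    u = np.asarray(u, dtype=np.float64)
    v = np.asarray(v, dtype=np.float64)
    dv_dx = np.gradient(v, spacing, axis=-1)
    du_dy = np.gradient(u, spacing, axis=-2)
    return dv_dx - du_dy

# Usage: a radial outflow (u, v) = (x, y) has divergence 2 and zero curl;
# a solid-body rotation (u, v) = (-y, x) has curl 2 and zero divergence.
spacing = 0.5
y, x = np.meshgrid(np.arange(10) * spacing, np.arange(12) * spacing, indexing="ij")
div_outflow = divergence_sketch(x, y, spacing)
curl_rotation = curl_sketch(-y, x, spacing)
```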

IPSL_AID.diagnostics.plot_mean_divergence_map(u_prediction, v_prediction, u_target, v_target, spacing, lat_1d, lon_1d, filename='mean_divergence.png', save_dir=None, figsize_multiplier=None)[source]

Plot maps of the time-averaged horizontal divergence of the predicted and target wind fields.

Parameters:
  • u_prediction (torch.Tensor or np.array) – Model predictions of the zonal wind component, of shape [batch_size, h, w]. The last two dimensions must be the spatial dimensions (latitude, longitude). u_prediction and v_prediction must have the same shape.

  • v_prediction (torch.Tensor or np.array) – Model predictions of the meridional wind component, of shape [batch_size, h, w], with the same layout and shape as u_prediction.

  • u_target (torch.Tensor or np.array) – Ground truth of the zonal wind component, of shape [batch_size, h, w]. The last two dimensions must be the spatial dimensions (latitude, longitude). u_target and v_target must have the same shape.

  • v_target (torch.Tensor or np.array) – Ground truth of the meridional wind component, of shape [batch_size, h, w], with the same layout and shape as u_target.

  • spacing (float) – Grid spacing of the wind field, used to compute the gradients.

  • lat_1d (array-like) – 1D array of latitude coordinates with shape [H].

  • lon_1d (array-like) – 1D array of longitude coordinates with shape [W].

  • filename (str, optional) – Output filename for saving the plot.

  • save_dir (str, optional) – Directory to save the plot.

  • figsize_multiplier (int, optional) – Base size multiplier for subplots.

Return type:

None

IPSL_AID.diagnostics.plot_mean_curl_map(u_prediction, v_prediction, u_target, v_target, spacing, lat_1d, lon_1d, filename='mean_curl.png', save_dir=None, figsize_multiplier=None)[source]

Plot maps of the time-averaged curl of the predicted and target wind fields.

Parameters:
  • u_prediction (torch.Tensor or np.array) – Model predictions of the zonal wind component, of shape [batch_size, h, w]. The last two dimensions must be the spatial dimensions (latitude, longitude). u_prediction and v_prediction must have the same shape.

  • v_prediction (torch.Tensor or np.array) – Model predictions of the meridional wind component, of shape [batch_size, h, w], with the same layout and shape as u_prediction.

  • u_target (torch.Tensor or np.array) – Ground truth of the zonal wind component, of shape [batch_size, h, w]. The last two dimensions must be the spatial dimensions (latitude, longitude). u_target and v_target must have the same shape.

  • v_target (torch.Tensor or np.array) – Ground truth of the meridional wind component, of shape [batch_size, h, w], with the same layout and shape as u_target.

  • spacing (float) – Grid spacing of the wind field, used to compute the gradients.

  • lat_1d (array-like) – 1D array of latitude coordinates with shape [H].

  • lon_1d (array-like) – 1D array of longitude coordinates with shape [W].

  • filename (str, optional) – Output filename for saving the plot.

  • save_dir (str, optional) – Directory to save the plot.

  • figsize_multiplier (int, optional) – Base size multiplier for subplots.

Return type:

None

class IPSL_AID.diagnostics.TestPlottingFunctions(methodName='runTest', logger=None)[source]

Bases: TestCase

Unit tests for plotting functions with visible output for styling adjustment.

__init__(methodName='runTest', logger=None)[source]

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

setUp()[source]

Set up test fixtures.

test_validation_hexbin_comprehensive()[source]

Comprehensive test for validation hexbin plots.

test_validation_pdfs_comprehensive()[source]

Comprehensive test for validation PDF plots.

test_power_spectra_comprehensive()[source]

Comprehensive test for power spectra plots.

test_spatiotemporal_histograms_comprehensive()[source]

Comprehensive test for spatiotemporal histograms.

test_plot_surface_comprehensive()[source]

Comprehensive test for surface plots.

test_plot_ensemble_surface_comprehensive()[source]

Comprehensive test for ensemble surface plots.

test_plot_zoom_comparison_comprehensive()[source]

Comprehensive test for zoom comparison plots.

test_plot_global_surface_robinson_comprehensive()[source]

Comprehensive test for global Robinson surface plots.

test_plot_mae_map_comprehensive()[source]

Comprehensive test for time-averaged MAE spatial map plots.

test_plot_error_map_comprehensive()[source]

Comprehensive test for time-averaged ERROR spatial map plots.

test_plot_spread_skill_ratio_map_comprehensive()[source]

Comprehensive test for time-averaged spread-skill ratio spatial map plots.

test_plot_spread_skill_ratio_hexbin_comprehensive()[source]

Comprehensive test for spread-skill ratio hexbin plots.

test_plot_mean_divergence_map_comprehensive()[source]

Comprehensive test for mean divergence map plots.

test_plot_mean_curl_map_comprehensive()[source]

Comprehensive test for mean curl map plots.

test_plot_dry_frequency_map_comprehensive()[source]

Comprehensive test for dry frequency map plots.

test_dry_frequency_map()[source]

Comprehensive test for the dry frequency map compute function.

test_divergence()[source]

Comprehensive test for the divergence compute function.

test_curl()[source]

Comprehensive test for the curl compute function.

test_metric_plots_comprehensive()[source]

Comprehensive test for metric plots.

test_plot_metrics_heatmap_comprehensive()[source]

Comprehensive test for validation metrics heatmap.

test_qq_quantiles_comprehensive()[source]

Comprehensive test for QQ-quantiles plots.

test_mv_correlation()[source]

Tests the correlation between pairs of variables over the time dimension and over the spatial dimensions.

test_temporal_series_comparison_comprehensive()[source]

Comprehensive test for spatially averaged temporal series comparison.

test_ranks()[source]

Comprehensive test for the ranks compute function.

test_plot_ranks()[source]

Comprehensive test for the ranks plot function.

tearDown()[source]

Clean up after tests.

class IPSL_AID.diagnostics.TestSSRFunction(methodName='runTest', logger=None)[source]

Bases: TestCase

Unit tests for the spread_skill_ratio function.

__init__(methodName='runTest', logger=None)[source]

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

setUp()[source]

Set up test fixtures.

test_ssr_basic()[source]

Test SSR with simple known values.

test_ssr_one_when_perfect_prediction()[source]

Test that the SSR is close to 1 when the predictions follow the same distribution as the truth.

IPSL_AID.evaluater module

class IPSL_AID.evaluater.MetricTracker[source]

Bases: object

A utility class for tracking and computing statistics of metric values.

This class maintains a running weighted sum of metric values together with the total sample count, and provides methods to compute the weighted mean, the standard deviation, and the square root of the mean.

value

Cumulative weighted sum of metric values

Type:

float

count

Total number of samples processed

Type:

int

Examples

>>> tracker = MetricTracker()
>>> tracker.update(10.0, 5)  # value=10.0, count=5 samples
>>> tracker.update(20.0, 3)  # value=20.0, count=3 samples
>>> print(tracker.getmean())  # (10*5 + 20*3) / (5+3) = 110/8 = 13.75
13.75
>>> print(tracker.getsqrtmean())  # sqrt(13.75)
3.7080992435478315

__init__()[source]

Initialize MetricTracker with zero values.

reset()[source]

Reset all tracked values to zero.

Return type:

None

update(value, count)[source]

Update the tracker with new metric values.

Parameters:
  • value (float) – The metric value to add

  • count (int) – Number of samples this value represents (weight)

Return type:

None

getmean()[source]

Calculate the mean of all tracked values.

Returns:

Weighted mean of all values: total_value / total_count

Return type:

float

Raises:

ZeroDivisionError – If no values have been added (count == 0)

getstd()[source]

Calculate the standard deviation of all tracked values.

Returns:

Weighted standard deviation of all values: sqrt(E(x^2) - (E(x))^2)

Return type:

float

Raises:

ZeroDivisionError – If no values have been added (count == 0)

getsqrtmean()[source]

Calculate the square root of the mean of all tracked values.

Returns:

Square root of the weighted mean: sqrt(total_value / total_count)

Return type:

float

Raises:

ZeroDivisionError – If no values have been added (count == 0)
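
The documented behavior of MetricTracker can be reproduced with a small sketch that accumulates weighted sums. This is a hypothetical re-implementation, not the package's class: the documented attributes are only value and count, so the value_sq attribute used here for getstd is an assumption about how E(x²) could be tracked:

```python
import math

class MetricTrackerSketch:
    """Hypothetical re-implementation of the documented MetricTracker API."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.value = 0.0     # cumulative weighted sum of values
        self.value_sq = 0.0  # hypothetical: weighted sum of squared values (for getstd)
        self.count = 0       # total number of samples

    def update(self, value, count):
        # `value` represents `count` samples, so it is weighted by `count`
        self.value += value * count
        self.value_sq += value ** 2 * count
        self.count += count

    def getmean(self):
        return self.value / self.count  # ZeroDivisionError if count == 0

    def getstd(self):
        mean = self.getmean()
        # sqrt(E(x^2) - E(x)^2), clamped against tiny negative round-off
        return math.sqrt(max(self.value_sq / self.count - mean ** 2, 0.0))

    def getsqrtmean(self):
        return math.sqrt(self.getmean())

# Usage, mirroring the example above
tracker = MetricTrackerSketch()
tracker.update(10.0, 5)
tracker.update(20.0, 3)
print(tracker.getmean())  # 13.75
```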

IPSL_AID.evaluater.mae_all(pred, true)[source]

Calculate Mean Absolute Error (MAE) between predicted and true values.

Computes the MAE metric and returns both the number of elements and the mean absolute error value.

Parameters:
  • pred (torch.Tensor) – Predicted values from the model

  • true (torch.Tensor) – Ground truth values

Returns:

(num_elements, mae_value), where num_elements (int) is the total number of elements in the tensors and mae_value (torch.Tensor) is the mean absolute error.

Return type:

tuple

Examples

>>> pred = torch.tensor([1.0, 2.0, 3.0])
>>> true = torch.tensor([1.1, 1.9, 3.2])
>>> num_elements, mae = mae_all(pred, true)
>>> print(f"MAE: {mae.item():.4f}, Elements: {num_elements}")
MAE: 0.1333, Elements: 3

Notes

The MAE is calculated as mean(abs(pred - true)). This function is useful for tracking metrics with MetricTracker.

IPSL_AID.evaluater.nmae_all(pred, true, eps=1e-08)[source]

Normalized Mean Absolute Error (NMAE). NMAE = MAE(pred, true) / mean(abs(true))

Computes the NMAE metric and returns both the number of elements and the normalized mean absolute error value.

Parameters:
  • pred (torch.Tensor) – Predicted values from the model

  • true (torch.Tensor) – Ground truth values

  • eps (float) – Small value to avoid division by zero

Returns:

(num_elements, nmae_value) where: - num_elements (int): Total number of elements in the tensors - nmae_value (torch.Tensor): Normalized mean absolute error value

Return type:

tuple

Examples

>>> pred = torch.tensor([1.0, 2.0, 3.0])
>>> true = torch.tensor([1.1, 1.9, 3.2])
>>> num_elements, nmae = nmae_all(pred, true)
>>> print(f"NMAE: {nmae.item():.4f}, Elements: {num_elements}")
NMAE: 0.0645, Elements: 3

Notes

The NMAE is calculated as MAE(pred, true) / mean(abs(true)). This function is useful for tracking metrics with MetricTracker.

IPSL_AID.evaluater.crps_ensemble_all(pred_ens, true)[source]

Continuous Ranked Probability Score (CRPS) for an ensemble.

Computes the CRPS metric for ensemble predictions and returns both the number of elements and the mean CRPS value.

Parameters:
  • pred_ens (torch.Tensor) – Ensemble predictions, shape [N_ens, N_pixels]

  • true (torch.Tensor) – Ground truth values, shape [N_pixels]

Returns:

(num_elements, crps_mean) where: - num_elements (int): Total number of elements in the tensors - crps_mean (torch.Tensor): Mean CRPS

Return type:

tuple

Notes

The CRPS for an ensemble is computed as:

CRPS = E|X - y| - 0.5 * E|X - X’|

where X and X’ are independent ensemble members and y is the observation.
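The ensemble CRPS formula above can be sketched in plain Python for a small ensemble. This sketch assumes the pair term E|X - X’| is averaged over all N_ens² ordered member pairs (each member paired with itself included), which is one common estimator; whether crps_ensemble_all uses this or a "fair" variant with N_ens(N_ens - 1) pairs is not stated here. The function name crps_ensemble is hypothetical, and num_elements is taken to be N_pixels.

```python
def crps_ensemble(pred_ens, true):
    """Ensemble CRPS averaged over pixels (hypothetical sketch).

    pred_ens: list of N_ens members, each a list of N_pixels floats.
    true: list of N_pixels floats.
    Returns (num_elements, crps_mean), with num_elements = N_pixels (assumption).
    """
    n_ens = len(pred_ens)
    n_pix = len(true)
    total = 0.0
    for p in range(n_pix):
        members = [member[p] for member in pred_ens]
        y = true[p]
        # E|X - y|: mean absolute error of the members against the observation.
        skill = sum(abs(x - y) for x in members) / n_ens
        # E|X - X'|: mean absolute difference over all ordered member pairs.
        spread = sum(abs(a - b) for a in members for b in members) / (n_ens * n_ens)
        total += skill - 0.5 * spread
    return n_pix, total / n_pix


pred_ens = [[0.0, 1.0], [2.0, 1.0]]  # 2 members, 2 pixels
true = [1.0, 1.0]
print(crps_ensemble(pred_ens, true))  # (2, 0.25)
```

With a single member the spread term vanishes and the score reduces to the MAE, which is the property checked by test_crps_equals_mae_for_single_member.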

IPSL_AID.evaluater.rmse_all(pred, true)[source]

Calculate Root Mean Square Error (RMSE) between predicted and true values.

Computes the RMSE metric and returns both the number of elements and the root mean square error value.

Parameters:
  • pred (torch.Tensor) – Predicted values from the model

  • true (torch.Tensor) – Ground truth values

Returns:

(num_elements, rmse_value) where: - num_elements (int): Total number of elements in the tensors - rmse_value (torch.Tensor): Root mean square error value

Return type:

tuple

Examples

>>> pred = torch.tensor([1.0, 2.0, 3.0])
>>> true = torch.tensor([1.1, 1.9, 3.2])
>>> num_elements, rmse = rmse_all(pred, true)
>>> print(f"RMSE: {rmse.item():.4f}, Elements: {num_elements}")
RMSE: 0.1414, Elements: 3

Notes

The RMSE is calculated as sqrt(mean((pred - true)^2)). This function is useful for tracking metrics with MetricTracker.

IPSL_AID.evaluater.r2_all(pred, true)[source]

Calculate R2 (coefficient of determination) between predicted and true values.

Computes the R2 metric and returns both the number of elements and the R2 value.

Parameters:
  • pred (torch.Tensor) – Predicted values from the model

  • true (torch.Tensor) – Ground truth values

Returns:

(num_elements, r2_value) where: - num_elements (int): Total number of elements in the tensors - r2_value (torch.Tensor): R2 score

Return type:

tuple

Notes

R2 is calculated as:

R2 = 1 - sum((true - pred)^2) / sum((true - mean(true))^2)

This implementation is fully torch-based and works on CPU and GPU.
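The R2 formula above can be written out in plain Python as a quick reference. r2_score is a hypothetical name, and the real function operates on torch tensors rather than lists.

```python
def r2_score(pred, true):
    """Return (num_elements, R2) following the formula above (hypothetical sketch)."""
    mean_true = sum(true) / len(true)
    ss_res = sum((t - p) ** 2 for p, t in zip(pred, true))  # residual sum of squares
    ss_tot = sum((t - mean_true) ** 2 for t in true)        # total sum of squares
    return len(true), 1.0 - ss_res / ss_tot


print(r2_score([1.0, 2.0, 3.0], [1.1, 1.9, 3.2]))
```

A prediction equal to mean(true) everywhere scores 0, a perfect prediction scores 1, and a constant true field makes the denominator zero, so callers may need to guard against that case.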

IPSL_AID.evaluater.pearson_all(pred, true)[source]

Compute the Pearson correlation coefficient between predicted and ground truth values using torch.corrcoef.

Parameters:
  • pred (torch.Tensor) – Predicted values from the model.

  • true (torch.Tensor) – Ground truth values.

Returns:

(num_elements, pearson_value) where: - num_elements (int): Total number of elements in the tensors. - pearson_value (torch.Tensor): Pearson correlation coefficient.

Return type:

tuple

Notes

The Pearson correlation coefficient is defined as:

rho = Cov(pred, true) / (std(pred) * std(true))
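The definition above can be written out directly in plain Python; pearson is a hypothetical name. For torch tensors, the same quantity is the off-diagonal entry of torch.corrcoef(torch.stack([pred.flatten(), true.flatten()])); whether pearson_all flattens its inputs exactly this way is an assumption.

```python
import math


def pearson(pred, true):
    """Return (num_elements, rho) for two equal-length sequences (hypothetical sketch)."""
    n = len(pred)
    mp = sum(pred) / n
    mt = sum(true) / n
    # The 1/n factors of the covariance and both standard deviations cancel.
    cov = sum((p - mp) * (t - mt) for p, t in zip(pred, true))
    sp = math.sqrt(sum((p - mp) ** 2 for p in pred))
    st = math.sqrt(sum((t - mt) ** 2 for t in true))
    return n, cov / (sp * st)


print(pearson([1.0, 2.0, 3.0], [1.1, 1.9, 3.2]))
```

Because the coefficient is invariant to shifting and positively scaling either input, it measures pattern agreement rather than amplitude error, which complements MAE and RMSE.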

IPSL_AID.evaluater.kl_divergence_all(pred, true)[source]

Compute the Kullback–Leibler (KL) divergence between predicted and ground truth distributions using histogram-based estimation.

Parameters:
  • pred (torch.Tensor) – Predicted values from the model.

  • true (torch.Tensor) – Ground truth values.

Returns:

(num_elements, kl_value) where: - num_elements (int): Total number of elements in the tensors. - kl_value (torch.Tensor): KL divergence value.

Return type:

tuple

Notes

The KL divergence is defined as:

KL(P || Q) = sum_i P_i * log(P_i / Q_i)

where:
  • P represents the true distribution

  • Q represents the predicted distribution
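The docstring does not say how the histograms are built, so the following is only a sketch under stated assumptions: both inputs share one set of equal-width bins spanning their combined range, bins defaults to 10, and a small eps is added inside the logarithm to avoid log(0). The name kl_histogram and both defaults are hypothetical.

```python
import math


def kl_histogram(pred, true, bins=10, eps=1e-10):
    """Histogram-based KL(P || Q) with P from `true` and Q from `pred` (hypothetical sketch)."""
    lo = min(min(pred), min(true))
    hi = max(max(pred), max(true))
    width = (hi - lo) / bins or 1.0  # guard against all-equal inputs

    def histogram(values):
        counts = [0] * bins
        for v in values:
            idx = min(int((v - lo) / width), bins - 1)  # v == hi goes into the last bin
            counts[idx] += 1
        return [c / len(values) for c in counts]

    p = histogram(true)  # P: true distribution
    q = histogram(pred)  # Q: predicted distribution
    # Empty P bins contribute 0; eps keeps the log finite when Q is empty.
    kl = sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))
    return len(true), kl


print(kl_histogram([0.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 1.0]))
```

Because the measure is asymmetric, swapping pred and true generally changes the result.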

IPSL_AID.evaluater.denormalize(data, stats, norm_type, device, var_name=None, data_type=None, debug=False, logger=None)[source]

Denormalize a data tensor using the inverse of the normalization operation.

Parameters:
  • data (torch.Tensor) – Normalized tensor to denormalize.

  • stats (object) – Object containing the required statistics.

  • norm_type (str) – Normalization type used originally.

  • device (torch.device) – Device for tensor operations.

  • var_name (str, optional) – Variable name for debugging.

  • data_type (str, optional) – Data type for debugging (e.g., “residual”, “coarse”).

  • debug (bool, optional) – Enable debug logging.

  • logger (Logger, optional) – Logger instance for debug output.
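Inverse transforms for the normalization types exercised by TestDenormalizeFunction (minmax, minmax_11, standard, robust, log1p_minmax, log1p_standard) can be sketched for scalar values. The forward formulas in the comments, the assumption that log1p_* statistics were computed on log1p(x), and raising ValueError for unsupported types are all assumptions. The stats keys follow the names produced by IPSL_AID.dataset.stats (min, max, mean, std, median, iqr), and denormalize_value is a hypothetical name.

```python
import math


def denormalize_value(y, stats, norm_type):
    """Invert a normalization for a single float (hypothetical sketch)."""
    if norm_type == "minmax":        # assumed forward: (x - min) / (max - min)
        return y * (stats["max"] - stats["min"]) + stats["min"]
    if norm_type == "minmax_11":     # assumed forward: 2 * (x - min) / (max - min) - 1
        return (y + 1.0) / 2.0 * (stats["max"] - stats["min"]) + stats["min"]
    if norm_type == "standard":      # assumed forward: (x - mean) / std
        return y * stats["std"] + stats["mean"]
    if norm_type == "robust":        # assumed forward: (x - median) / iqr
        return y * stats["iqr"] + stats["median"]
    if norm_type == "log1p_minmax":  # assumes stats were computed on log1p(x)
        return math.expm1(denormalize_value(y, stats, "minmax"))
    if norm_type == "log1p_standard":
        return math.expm1(denormalize_value(y, stats, "standard"))
    raise ValueError(f"Unsupported normalization type: {norm_type}")


stats = {"min": 0.0, "max": 10.0, "mean": 5.0, "std": 2.0, "median": 4.0, "iqr": 3.0}
print(denormalize_value(0.5, stats, "minmax"))  # 5.0
```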

IPSL_AID.evaluater.edm_sampler(model, image_input, class_labels=None, num_steps=40, sigma_min=0.02, sigma_max=80.0, rho=7, S_churn=40, S_min=0, S_max=inf, S_noise=1)

EDM sampler for diffusion model inference. Original work: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. Original source: https://github.com/NVlabs/edm

Parameters:
  • model (torch.nn.Module) – Diffusion model

  • image_input (torch.Tensor) – Conditioning input (coarse + constants)

  • class_labels (torch.Tensor, optional) – Time conditioning labels

  • num_steps (int, optional) – Number of sampling steps

  • sigma_min (float, optional) – Minimum noise level

  • sigma_max (float, optional) – Maximum noise level

  • rho (float, optional) – Time step exponent

  • S_churn (int, optional) – Stochasticity parameter

  • S_min (float, optional) – Minimum stochasticity threshold

  • S_max (float, optional) – Maximum stochasticity threshold

  • S_noise (float, optional) – Noise scale for stochasticity

Returns:

Generated residual predictions

Return type:

torch.Tensor
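The time-step discretization and churn amount of the published EDM sampler (Karras et al., 2022; NVlabs/edm) can be sketched with the defaults from the signature above: t_i = (σ_max^(1/ρ) + i/(N-1) · (σ_min^(1/ρ) - σ_max^(1/ρ)))^ρ for i = 0..N-1, followed by t_N = 0. That edm_sampler follows these exact formulas is an assumption based on its stated origin; the helper names edm_time_steps and churn_gamma are hypothetical.

```python
def edm_time_steps(num_steps=40, sigma_min=0.02, sigma_max=80.0, rho=7.0):
    """Return t_0 > t_1 > ... > t_{N-1} followed by t_N = 0 (needs num_steps >= 2)."""
    a = sigma_max ** (1.0 / rho)
    b = sigma_min ** (1.0 / rho)
    steps = [(a + i / (num_steps - 1) * (b - a)) ** rho for i in range(num_steps)]
    steps.append(0.0)  # the last step lands on the noise-free sample
    return steps


def churn_gamma(t, num_steps=40, S_churn=40, S_min=0.0, S_max=float("inf")):
    """Relative amount of extra noise injected before the step at level t."""
    if S_min <= t <= S_max:
        # Capped at sqrt(2) - 1, so t_hat = t * (1 + gamma) is at most sqrt(2) * t.
        return min(S_churn / num_steps, 2 ** 0.5 - 1)
    return 0.0


steps = edm_time_steps()
print(len(steps), steps[0], steps[-2], steps[-1])
```

A larger rho packs more steps near sigma_min, where fine detail is resolved.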

IPSL_AID.evaluater.sampler(epoch, batch_idx, model, image_input, class_labels=None, num_steps=18, sigma_min=None, sigma_max=None, rho=7, solver='heun', discretization='edm', schedule='linear', scaling='none', epsilon_s=0.001, C_1=0.001, C_2=0.008, M=1000, alpha=1, S_churn=40, S_min=0, S_max=inf, S_noise=1, logger=None)

General sampler for diffusion model inference with multiple configurations. Original work: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. Original source: https://github.com/NVlabs/edm

Parameters:
  • model (torch.nn.Module) – Diffusion model

  • image_input (torch.Tensor) – Conditioning input (coarse + constants)

  • class_labels (torch.Tensor, optional) – Time conditioning labels

  • num_steps (int, optional) – Number of sampling steps

  • sigma_min (float, optional) – Minimum noise level

  • sigma_max (float, optional) – Maximum noise level

  • rho (float, optional) – Time step exponent for EDM discretization

  • solver (str, optional) – Solver type: ‘euler’ or ‘heun’

  • discretization (str, optional) – Discretization type: ‘vp’, ‘ve’, ‘iddpm’, or ‘edm’

  • schedule (str, optional) – Noise schedule: ‘vp’, ‘ve’, or ‘linear’

  • scaling (str, optional) – Scaling type: ‘vp’ or ‘none’

  • epsilon_s (float, optional) – Small epsilon for VP schedule

  • C_1 (float, optional) – Constant for IDDPM discretization

  • C_2 (float, optional) – Constant for IDDPM discretization

  • M (int, optional) – Number of steps for IDDPM discretization

  • alpha (float, optional) – Parameter of the generalized second-order step; alpha = 1 corresponds to Heun’s method

  • S_churn (int, optional) – Stochasticity parameter

  • S_min (float, optional) – Minimum stochasticity threshold

  • S_max (float, optional) – Maximum stochasticity threshold

  • S_noise (float, optional) – Noise scale for stochasticity

  • logger (logging.Logger, optional) – Logger instance for logging sampler parameters

Returns:

Generated residual predictions

Return type:

torch.Tensor
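The difference between solver='euler' and solver='heun' can be seen on a toy problem with a known answer. This sketch assumes scalar data x_0 ~ N(0, σ_data²). For this data the ideal denoiser is D(x, σ) = x·σ_data²/(σ_data² + σ²), and the probability-flow ODE dx/dσ = (x - D(x, σ))/σ has the exact solution x(0) = x(σ_max)·σ_data/sqrt(σ_data² + σ_max²). It uses the 'edm' discretization, alpha = 1 (plain Heun), and no churn. All function names are hypothetical.

```python
import math


def gaussian_denoiser(x, sigma, sigma_data=1.0):
    # Ideal denoiser E[x_0 | x_0 + sigma * noise = x] when x_0 ~ N(0, sigma_data**2).
    return x * sigma_data ** 2 / (sigma_data ** 2 + sigma ** 2)


def edm_steps(num_steps, sigma_min=0.02, sigma_max=80.0, rho=7.0):
    a, b = sigma_max ** (1.0 / rho), sigma_min ** (1.0 / rho)
    return [(a + i / (num_steps - 1) * (b - a)) ** rho for i in range(num_steps)] + [0.0]


def deterministic_sample(x, denoiser, num_steps=18, solver="heun"):
    """Integrate dx/dsigma = (x - D(x, sigma)) / sigma from sigma_max down to 0."""
    t = edm_steps(num_steps)
    for t_cur, t_next in zip(t[:-1], t[1:]):
        d_cur = (x - denoiser(x, t_cur)) / t_cur
        x_next = x + (t_next - t_cur) * d_cur  # Euler step
        if solver == "heun" and t_next > 0:
            # Second-order correction: average the slopes at both ends of the step.
            d_next = (x_next - denoiser(x_next, t_next)) / t_next
            x_next = x + (t_next - t_cur) * 0.5 * (d_cur + d_next)
        x = x_next
    return x


x_start = 80.0  # a sample at sigma_max = 80
exact = x_start / math.sqrt(1.0 + 80.0 ** 2)
for solver in ("euler", "heun"):
    print(solver, deterministic_sample(x_start, gaussian_denoiser, solver=solver), exact)
```

Heun spends two denoiser evaluations per step (except the last) but is second-order accurate, so with the same number of steps it lands much closer to the exact answer than Euler on this problem.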

IPSL_AID.evaluater.reconstruct_original_layout(epoch, args, paths, steps, all_data, dataset, device, logger)[source]

Reconstruct validation outputs in their original spatial layout, using layout information taken directly from the dataset.

Parameters:
  • all_data (dict) – Dictionary containing lists of batches for:

    • ‘predictions’: model predictions [B, C, H, W]

    • ‘coarse’: coarse resolution data [B, C, H, W]

    • ‘fine’: fine resolution ground truth [B, C, H, W]

    • ‘lat’: latitude coordinates [B, H]

    • ‘lon’: longitude coordinates [B, W]

  • dataset (torch.utils.data.Dataset) – The validation dataset instance.

  • device (torch.device) – Device to store tensors on.

  • logger (Logger) – Logger instance for logging.

Returns:

Reconstructed data with metadata.

Return type:

dict

IPSL_AID.evaluater.generate_residuals_norm(model, features, labels, targets, loss_fn, args, device, logger, epoch=0, batch_idx=0, inference_type='sampler')[source]

Generate normalized residuals for all variables.

Parameters:
  • model (torch.nn.Module) – Diffusion model

  • features (torch.Tensor) – Input feature tensor provided to the model

  • labels (torch.Tensor) – Conditioning labels provided to the model

  • targets (torch.Tensor) – Ground truth target tensor used for noise injection in direct inference

  • loss_fn (callable) – Loss function

  • args (argparse.Namespace) – Command line arguments

  • device (torch.device) – Training device

  • logger (Logger) – Logger instance

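  • epoch (int, optional) – Current epoch number. Default is 0.

  • batch_idx (int, optional) – Index of the current batch. Default is 0.

  • inference_type (str, optional) – Inference mode, either “direct” (deterministic) or “sampler” (stochastic diffusion sampling). Default is “sampler”.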

Returns:

[B, C, H, W] residuals in normalized space

Return type:

torch.Tensor

IPSL_AID.evaluater.run_validation(model, valid_dataset, valid_loader, loss_fn, norm_mapping, normalization_type, index_mapping, args, steps, device, logger, epoch, writer=None, plot_every_n_epochs=None, paths=None, compute_crps=False, crps_batch_size=2, crps_ensemble_size=10)[source]

Run validation on the model.

Parameters:
  • model (torch.nn.Module) – Diffusion model

  • valid_dataset (torch.utils.data.Dataset) – Validation dataset

  • valid_loader (DataLoader) – Validation data loader

  • loss_fn (callable) – Loss function

  • norm_mapping (dict) – Normalization statistics

  • normalization_type (EasyDict) – Normalization types for each variable

  • args (argparse.Namespace) – Command line arguments

  • device (torch.device) – Training device

  • logger (Logger) – Logger instance

  • epoch (int) – Current epoch number

  • writer (SummaryWriter, optional) – TensorBoard writer

  • plot_every_n_epochs (int, optional) – Frequency (in epochs) at which validation plots are generated

  • paths (dict, optional) – Paths used for saving reconstructions and plots

  • compute_crps (bool, optional) – Whether to compute CRPS using stochastic ensemble sampling

  • crps_batch_size (int, optional) – Number of validation batches used for CRPS computation

  • crps_ensemble_size (int, optional) – Number of ensemble members used to estimate CRPS

Returns:

(avg_val_loss, val_metrics) - average validation loss and metrics dictionary

Return type:

tuple

class IPSL_AID.evaluater.TestMetricTracker(methodName='runTest', logger=None)[source]

Bases: TestCase

Unit tests for MetricTracker class.

__init__(methodName='runTest', logger=None)[source]

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

setUp()[source]

Set up test fixtures.

test_metric_tracker_init()[source]

Test MetricTracker initialization.

test_metric_tracker_reset()[source]

Test MetricTracker reset method.

test_metric_tracker_update()[source]

Test MetricTracker update method.

test_metric_tracker_getmean()[source]

Test MetricTracker getmean method.

test_metric_tracker_getstd()[source]

Test MetricTracker getstd method.

test_metric_tracker_getsqrtmean()[source]

Test MetricTracker getsqrtmean method.

test_metric_tracker_example_from_docstring()[source]

Test the example provided in the docstring.

class IPSL_AID.evaluater.TestErrorMetrics(methodName='runTest', logger=None)[source]

Bases: TestCase

Unit tests for error metrics.

__init__(methodName='runTest', logger=None)[source]

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

setUp()[source]

Set up test fixtures.

test_basic()[source]

Test error metrics with simple tensors.

test_exact_match()[source]

Test error metrics with identical tensors.

test_multi_dimensional()[source]

Test error metrics with multi-dimensional tensors.

test_different_shapes()[source]

Test error metrics with tensors of different shapes.

test_dtype_preservation()[source]

Test that error metrics preserve data types.

test_example_from_docstring()[source]

Test error metrics examples from their docstrings.

test_kl_divergence_basic()[source]

Test KL divergence properties.

test_kl_different_shapes()[source]

KL divergence should raise RuntimeError if tensor shapes differ.

test_kl_dtype_preservation()[source]

Ensure KL divergence preserves the input tensor dtype.

test_kl_multi_dimensional()[source]

KL divergence should correctly handle multi-dimensional tensors by flattening them internally.

class IPSL_AID.evaluater.TestCRPSFunction(methodName='runTest', logger=None)[source]

Bases: TestCase

Unit tests for crps_ensemble_all function.

__init__(methodName='runTest', logger=None)[source]

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

setUp()[source]

Set up test fixtures.

test_crps_basic()[source]

Test CRPS with simple known values.

test_crps_zero_when_perfect_prediction()[source]

Test CRPS is zero when all ensemble members equal truth.

test_crps_equals_mae_for_single_member()[source]

Test CRPS reduces to MAE when N_ens = 1.

test_crps_multi_dimensional_flatten()[source]

Test CRPS with flattened multi-dimensional data.

test_crps_dtype_preservation()[source]

Test CRPS preserves floating point dtype.

class IPSL_AID.evaluater.TestDenormalizeFunction(methodName='runTest', logger=None)[source]

Bases: TestCase

Unit tests for denormalize function.

__init__(methodName='runTest', logger=None)[source]

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

setUp()[source]

Set up test fixtures.

test_denormalize_minmax()[source]

Test denormalize with minmax normalization.

test_denormalize_minmax_11()[source]

Test denormalize with minmax_11 normalization.

test_denormalize_standard()[source]

Test denormalize with standard normalization.

test_denormalize_robust()[source]

Test denormalize with robust normalization.

test_denormalize_log1p_minmax()[source]

Test denormalize with log1p_minmax normalization.

test_denormalize_log1p_standard()[source]

Test denormalize with log1p_standard normalization.

test_denormalize_zero_denominator()[source]

Test denormalize with zero denominator.

test_denormalize_unsupported_type()[source]

Test denormalize with unsupported normalization type.

class IPSL_AID.evaluater.TestRunValidation(methodName='runTest', logger=None)[source]

Bases: TestCase

Unit tests for run_validation function focusing on return values verification.

__init__(methodName='runTest', logger=None)[source]

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

setUp()[source]

Set up test fixtures.

test_val_loss_and_metrics_across_3_batches_consistent_shape()[source]

Verify avg_val_loss and val_metrics with 3 batches of consistent shape.

test_crps_zero_when_predictions_equal_fine()[source]

Verify that CRPS is zero when all ensemble predictions exactly match the fine target.

test_generate_residuals_matches_fine()[source]

Final prediction (coarse + residuals) should exactly match fine when residuals = fine - coarse.

tearDown()[source]

Clean up after tests.

IPSL_AID.logger module

class IPSL_AID.logger.Logger(console_output=True, file_output=False, log_file='module_log_file.log', pretty_print=True, record=False)[source]

Bases: object

__init__(console_output=True, file_output=False, log_file='module_log_file.log', pretty_print=True, record=False)[source]

clear_logs()[source]

Clear the stored Rich logs if record=True.

show_header(module_name)[source]

Display startup banner.

start_task(task_name: str, description: str = '', **meta)[source]

Display a clearly formatted ‘task start’ message with good spacing.

log_metrics()[source]

Log pipeline metrics.

info(message)[source]

Log a formatted info message.

warning(message)[source]

Log a formatted warning message.

success(message)[source]

Log a message at a custom success level (not a standard logging level).

step(step_name, message)[source]

Highlight pipeline step events.

exception(message, exception=None)[source]

Display a formatted exception message with visual stack trace.

error(message, exception=None)[source]

Display a formatted error log, optionally including exception trace.

class IPSL_AID.logger.TestLogger(methodName='runTest', logger=None)[source]

Bases: TestCase

Unit tests for Logger class.

__init__(methodName='runTest', logger=None)[source]

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

setUp()[source]

Set up test fixtures.

test_initialization_default()[source]

Test Logger initialization with default parameters.

test_initialization_with_file_output()[source]

Test Logger initialization with file output enabled.

test_initialization_with_record()[source]

Test Logger initialization with record mode enabled.

test_initialization_console_only()[source]

Test Logger initialization with console only.

test_initialization_file_only()[source]

Test Logger initialization with file only.

test_clear_logs_with_record()[source]

Test clear_logs method when record is True.

test_clear_logs_without_record()[source]

Test clear_logs method when record is False.

test_show_header_console()[source]

Test show_header method with console output.

test_show_header_file()[source]

Test show_header method with file output.

test_show_header_both()[source]

Test show_header method with both console and file output.

test_start_task_minimal()[source]

Test start_task method with minimal parameters.

test_start_task_with_description()[source]

Test start_task method with description.

test_start_task_with_metadata()[source]

Test start_task method with metadata.

test_start_task_file_output()[source]

Test start_task method with file output.

test_info_console()[source]

Test info method with console output.

test_info_file()[source]

Test info method with file output.

test_warning_console()[source]

Test warning method with console output.

test_warning_file()[source]

Test warning method with file output.

test_success_console()[source]

Test success method with console output.

test_success_file()[source]

Test success method with file output.

test_step_console()[source]

Test step method with console output.

test_step_file()[source]

Test step method with file output.

test_error_without_exception()[source]

Test error method without exception.

test_error_with_exception()[source]

Test error method with exception.

test_error_file_output()[source]

Test error method with file output.

test_exception_without_exception()[source]

Test exception method without exception object.

test_exception_with_exception()[source]

Test exception method with exception object.

test_exception_file_output()[source]

Test exception method with file output.

test_log_metrics_empty()[source]

Test log_metrics with empty metrics.

test_log_metrics_with_data()[source]

Test log_metrics with populated metrics.

test_log_metrics_file_output()[source]

Test log_metrics with file output.

test_full_logging_cycle()[source]

Test a complete logging cycle with all methods.

test_progress_bar()[source]

Test progress bar functionality.

test_logger_with_unicode()[source]

Test logger with unicode characters.

tearDown()[source]

Clean up after tests.

IPSL_AID.loss module

Diffusion model loss functions and testing utilities.

This module implements various loss functions for diffusion models, including:

  • VPLoss: Variance Preserving loss from Score-Based Generative Modeling

  • VELoss: Variance Exploding loss from Score-Based Generative Modeling

  • EDMLoss: Improved loss from Elucidating the Design Space of Diffusion-Based Generative Models

class IPSL_AID.loss.VPLoss(beta_d=19.9, beta_min=0.1, epsilon_t=1e-05)[source]

Bases: object

Loss function for Variance Preserving (VP) formulation diffusion models.

This class implements the loss function for the Variance Preserving SDE formulation of diffusion models. It follows the continuous-time training objective from score-based generative modeling through stochastic differential equations.

Parameters:
  • beta_d (float, optional) – Maximum β parameter controlling the extent of the noise schedule. Larger values lead to faster noise increase. Default is 19.9.

  • beta_min (float, optional) – Minimum β parameter controlling the initial slope of the noise schedule. Default is 0.1.

  • epsilon_t (float, optional) – Minimum time value threshold to avoid numerical issues near t=0. Default is 1e-5.

beta_d

Maximum β parameter for noise schedule.

Type:

float

beta_min

Minimum β parameter for noise schedule.

Type:

float

epsilon_t

Minimum time threshold.

Type:

float

__call__(net, images, conditional_img=None, labels=None, augment_pipe=None)[source]

Compute the VP loss for a batch of images.

sigma(t)[source]

Compute noise level sigma for given timestep t.

Notes

  • The loss is based on denoising score matching: E[λ(t) * ||D_θ(x_t, t) - x_0||²]

  • The weighting function λ(t) = 1/σ(t)² gives equal importance to all noise levels.

  • Time t is sampled uniformly from [epsilon_t, 1] during training.

  • This loss corresponds to training the model to predict the clean image x_0 from noisy input x_t = x_0 + σ(t)·ε.

References

  • Song et al., “Score-Based Generative Modeling through Stochastic Differential Equations”, 2020.

__init__(beta_d=19.9, beta_min=0.1, epsilon_t=1e-05)[source]

Initialize the VPLoss function.

Parameters:
  • beta_d (float, optional) – Maximum β parameter for noise schedule. Default is 19.9.

  • beta_min (float, optional) – Minimum β parameter for noise schedule. Default is 0.1.

  • epsilon_t (float, optional) – Minimum time threshold. Default is 1e-5.

sigma(t)[source]

Compute noise level sigma for given timestep t.

Parameters:

t (torch.Tensor or float) – Timestep value(s) in [epsilon_t, 1].

Returns:

Noise level sigma corresponding to t, with same shape as input.

Return type:

torch.Tensor

Notes

The noise schedule follows: σ(t) = sqrt(exp(0.5*β_d*t² + β_min*t) - 1)

This gives a smooth transition from low to high noise levels, with σ(0) = 0 and σ(1) determined by β_d and β_min.
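The VP noise schedule and the training-time sampling of t can be sketched as follows. The helper names are hypothetical, and using math.expm1 rather than exp(...) - 1 is a choice made here for accuracy near t = 0, not necessarily what VPLoss does.

```python
import math
import random


def vp_sigma(t, beta_d=19.9, beta_min=0.1):
    # sigma(t) = sqrt(exp(0.5 * beta_d * t**2 + beta_min * t) - 1); expm1 keeps
    # precision when the exponent is tiny (t close to 0).
    return math.sqrt(math.expm1(0.5 * beta_d * t ** 2 + beta_min * t))


def sample_vp_training_level(epsilon_t=1e-5, rng=random):
    t = epsilon_t + (1.0 - epsilon_t) * rng.random()  # t uniform in [epsilon_t, 1)
    sigma = vp_sigma(t)
    weight = 1.0 / sigma ** 2  # lambda(t) = 1 / sigma(t)**2
    return t, sigma, weight


print(vp_sigma(0.0), vp_sigma(1e-5), vp_sigma(1.0))
```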

class IPSL_AID.loss.VELoss(sigma_min=0.02, sigma_max=100)[source]

Bases: object

Loss function for Variance Exploding (VE) formulation diffusion models.

This class implements the loss function for the Variance Exploding SDE formulation of diffusion models. It follows the continuous-time training objective from score-based generative modeling through stochastic differential equations.

Parameters:
  • sigma_min (float, optional) – Minimum noise level. Controls the lower bound of the noise schedule. Smaller values allow modeling finer details. Default is 0.02.

  • sigma_max (float, optional) – Maximum noise level. Controls the upper bound of the noise schedule. Larger values allow modeling broader structure. Default is 100.

sigma_min

Minimum noise level for the geometric schedule.

Type:

float

sigma_max

Maximum noise level for the geometric schedule.

Type:

float

__call__(net, images, conditional_img=None, labels=None, augment_pipe=None)[source]

Compute the VE loss for a batch of images.

Notes

  • The VE formulation uses a geometric noise schedule: σ(t) = σ_min * (σ_max/σ_min)^t

  • Time t is sampled uniformly from [0, 1] during training.

  • The weighting function λ(t) = 1/σ(t)² gives more emphasis to lower noise levels.

  • This corresponds to training the model to predict the clean image x_0 from noisy input x_t = x_0 + σ(t)·ε.

  • The geometric schedule spans a wide range of noise levels (several orders of magnitude with the defaults) while being controlled only by σ_min and σ_max.

References

  • Song et al., “Score-Based Generative Modeling through Stochastic Differential Equations”, 2020.

__init__(sigma_min=0.02, sigma_max=100)[source]

Initialize the VELoss function.

Parameters:
  • sigma_min (float, optional) – Minimum noise level for the geometric schedule. Default is 0.02.

  • sigma_max (float, optional) – Maximum noise level for the geometric schedule. Default is 100.
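Under the geometric VE schedule, log σ is uniform between log σ_min and log σ_max whenever t is uniform. The sketch below uses the VELoss defaults; the function names are hypothetical.

```python
import math
import random


def ve_sigma(t, sigma_min=0.02, sigma_max=100.0):
    # Geometric interpolation: sigma_min at t = 0, sigma_max at t = 1.
    return sigma_min * (sigma_max / sigma_min) ** t


def sample_ve_training_level(sigma_min=0.02, sigma_max=100.0, rng=random):
    sigma = ve_sigma(rng.random(), sigma_min, sigma_max)  # t ~ U[0, 1)
    return sigma, 1.0 / sigma ** 2  # (sigma, lambda = 1 / sigma**2)


# Equal steps in t are equal steps in log(sigma).
print([round(ve_sigma(t), 4) for t in (0.0, 0.25, 0.5, 0.75, 1.0)])
```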

class IPSL_AID.loss.EDMLoss(P_mean=-1.2, P_std=1.2, sigma_data=1.0)[source]

Bases: object

EDM (Elucidating Diffusion Models) loss function for diffusion models.

This class implements the improved loss function from the EDM paper, which uses a log-normal distribution for noise level sampling and an optimized weighting scheme for better training stability and sample quality.

Parameters:
  • P_mean (float, optional) – Mean parameter for the log-normal distribution of sigma. Controls the center of the noise level distribution. Default is -1.2.

  • P_std (float, optional) – Standard deviation parameter for the log-normal distribution of sigma. Controls the spread of the noise level distribution. Default is 1.2.

  • sigma_data (float, optional) – Standard deviation of the training data. Used in the weighting function to balance the loss across noise levels. Default is 1.0.

P_mean

Mean of log-normal distribution for sigma sampling.

Type:

float

P_std

Standard deviation of log-normal distribution for sigma sampling.

Type:

float

sigma_data

Training data standard deviation.

Type:

float

__call__(net, images, conditional_img=None, labels=None, augment_pipe=None)[source]

Compute the EDM loss for a batch of images.

Notes

  • The EDM loss uses a log-normal distribution for sigma: σ ~ logNormal(P_mean, P_std)

  • The weighting function: λ(σ) = (σ² + σ_data²) / (σ·σ_data)²

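  • This weighting equals 1/c_out(σ)², the inverse square of the output scaling c_out(σ) = σ·σ_data / sqrt(σ² + σ_data²) used in EDM preconditioning, so it equalizes the magnitude of the initial training loss across noise levels and helps stabilize training.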

  • The loss corresponds to training the model to predict the clean image x_0 from noisy input x_t = x_0 + σ·ε.

  • Sampling σ from a log-normal distribution concentrates training on intermediate noise levels instead of spreading it uniformly over the range.

References

  • Karras et al., “Elucidating the Design Space of Diffusion-Based Generative Models”, 2022.

__init__(P_mean=-1.2, P_std=1.2, sigma_data=1.0)[source]

Initialize the EDMLoss function.

Parameters:
  • P_mean (float, optional) – Mean parameter for log-normal distribution. Default is -1.2.

  • P_std (float, optional) – Standard deviation parameter for log-normal distribution. Default is 1.2.

  • sigma_data (float, optional) – Standard deviation of training data. Default is 1.0.
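One EDM training draw for a scalar sample can be sketched as below, using the EDMLoss defaults. The scalar setting, the denoiser callable, and the function names are hypothetical; the real loss works on image batches through the network. With these defaults the median sampled σ is exp(P_mean) = exp(-1.2) ≈ 0.30.

```python
import math
import random


def edm_weight(sigma, sigma_data=1.0):
    # lambda(sigma) = (sigma**2 + sigma_data**2) / (sigma * sigma_data)**2
    return (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2


def edm_loss_draw(denoiser, x0, P_mean=-1.2, P_std=1.2, sigma_data=1.0, rng=random):
    sigma = math.exp(rng.gauss(P_mean, P_std))  # ln(sigma) ~ N(P_mean, P_std**2)
    x_t = x0 + sigma * rng.gauss(0.0, 1.0)      # noisy input x_0 + sigma * eps
    return edm_weight(sigma, sigma_data) * (denoiser(x_t, sigma) - x0) ** 2


rng = random.Random(0)
# Loss for a denoiser that always predicts 0 when the clean value is 1.
print(edm_loss_draw(lambda x, s: 0.0, 1.0, rng=rng))
```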

class IPSL_AID.loss.UnetLoss(loss_type='mse', reduction='mean')[source]

Bases: object

Simple UNet loss function for direct image-to-image prediction.

This loss function works with UNet models that predict images directly, without any diffusion noise process. It’s a standard supervised loss for image generation/transformation tasks such as segmentation, denoising, super-resolution, or autoencoding.

Parameters:
  • loss_type (str, optional) – Type of loss function to use. Default is mse.

    • mse: Mean Squared Error (L2 loss)

    • l1: Mean Absolute Error (L1 loss)

    • smooth_l1: Smooth L1 loss (Huber loss)

  • reduction (str, optional) – Reduction method for the loss. Default is mean.

    • mean: Average the loss over all elements

    • sum: Sum the loss over all elements

    • none: Return the loss per element

loss_type

Type of loss function.

Type:

str

reduction

Reduction method.

Type:

str

loss_fn

PyTorch loss function instance.

Type:

torch.nn.Module

Raises:

ValueError – If an unknown loss_type is provided.

Notes

  • The loss computes the discrepancy between the model’s output and the input image.

  • This is suitable for autoencoder-style tasks where the model learns to reconstruct the input.

  • For conditional generation, labels can be provided to the model.

  • Data augmentation can be applied via augment_pipe.

__init__(loss_type='mse', reduction='mean')[source]

Initialize the UnetLoss function.

Parameters:
  • loss_type (str, optional) – Type of loss function. Default is mse.

  • reduction (str, optional) – Reduction method. Default is mean.
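The loss_type and reduction options map naturally onto torch.nn.MSELoss, torch.nn.L1Loss, and torch.nn.SmoothL1Loss, each of which accepts a reduction argument. Whether UnetLoss builds its loss_fn exactly this way is an assumption. The pure-Python sketch below only mirrors the element-wise formulas (smooth L1 with the PyTorch default beta = 1.0) and the three reductions, under a hypothetical name.

```python
def _mse(d):
    return d * d


def _l1(d):
    return abs(d)


def _smooth_l1(d, beta=1.0):
    # Quadratic near zero, linear beyond |d| = beta (Huber-style).
    ad = abs(d)
    return 0.5 * ad * ad / beta if ad < beta else ad - 0.5 * beta


_ELEMENTWISE = {"mse": _mse, "l1": _l1, "smooth_l1": _smooth_l1}


def unet_style_loss(pred, target, loss_type="mse", reduction="mean"):
    if loss_type not in _ELEMENTWISE:
        raise ValueError(f"Unknown loss_type: {loss_type}")
    per_elem = [_ELEMENTWISE[loss_type](p - t) for p, t in zip(pred, target)]
    if reduction == "mean":
        return sum(per_elem) / len(per_elem)
    if reduction == "sum":
        return sum(per_elem)
    if reduction == "none":
        return per_elem
    raise ValueError(f"Unknown reduction: {reduction}")


print(unet_style_loss([0.0, 2.0], [0.0, 0.0], "smooth_l1", "none"))  # [0.0, 1.5]
```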

class IPSL_AID.loss.TestLosses(methodName='runTest', logger=None)[source]

Bases: TestCase

Unit tests for diffusion models and loss functions.

__init__(methodName='runTest', logger=None)[source]

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

setUp()[source]

Set up test fixtures.

test_vp_loss()[source]

Test VP loss function.

test_ve_loss()[source]

Test VE loss function.

test_edm_loss()[source]

Test EDM loss function.

test_unet_loss()[source]

Test UnetLoss function.

test_loss_comparison()[source]

Compare different loss functions on the same model.

test_loss_with_augmentation()[source]

Test loss functions with data augmentation.

test_loss_gradients()[source]

Test that loss computation supports gradient computation.

tearDown()[source]

Clean up after tests.

IPSL_AID.main module

IPSL_AID.main.parse_args()[source]

Parse command line arguments for diffusion model training and inference.

This function defines and parses all command line arguments required for configuring and running diffusion model training, resumption, or inference experiments. It provides comprehensive options for data loading, model architecture, training hyperparameters, and output management.

Returns:

Parsed command line arguments as a namespace object with attributes corresponding to each argument.

Return type:

argparse.Namespace

Notes

  • Arguments are organized into logical groups: execution mode, data configuration, training configuration, model architecture, and output.

  • Boolean arguments use string conversion with lambda functions for flexibility (accepts “true”/“false”, “True”/“False”, etc.).

  • Default values are provided for most parameters to allow minimal configuration for basic usage.

  • Some arguments have constraints or choices to ensure valid configurations.
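A boolean option of the kind described above can be sketched with a small converter function. The flag name --use_amp and the helper str2bool are hypothetical and not taken from parse_args. A named function is used here instead of a bare lambda so that unrecognized strings produce an argparse error rather than silently becoming False.

```python
import argparse


def str2bool(value):
    """Hypothetical converter accepting true/false (any letter case), yes/no, 1/0."""
    lowered = value.strip().lower()
    if lowered in ("true", "yes", "1"):
        return True
    if lowered in ("false", "no", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--use_amp", type=str2bool, default=False)  # hypothetical flag
args = parser.parse_args(["--use_amp", "True"])
print(args.use_amp)  # True
```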

IPSL_AID.main.make_divisible_hw(h, w, n)[source]

Adjust height and width to be divisible by 2**n by decrementing.

This function ensures that both the height (h) and width (w) are divisible by 2 raised to the power n, which is often required for neural network architectures that use pooling or strided convolutions multiple times.

Parameters:
  • h (int) – Original height value.

  • w (int) – Original width value.

  • n (int) – Exponent for divisor calculation. The divisor is 2**n.

Returns:

  • h_new (int) – Adjusted height that is divisible by 2**n.

  • w_new (int) – Adjusted width that is divisible by 2**n.

Notes

  • The function decrements h and w until they become divisible by 2**n.

  • This is a common requirement for U-Net and other encoder-decoder architectures that use multiple downsampling and upsampling operations.

  • The adjustment is conservative (decrementing) to avoid adding padding, which might be important for maintaining exact spatial relationships.
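Decrementing until divisible is equivalent to subtracting the remainder, which gives a one-line sketch. make_divisible_hw_sketch is a hypothetical name, not the package function.

```python
def make_divisible_hw_sketch(h, w, n):
    divisor = 2 ** n
    # Subtracting the remainder gives the same result as decrementing until divisible.
    return h - h % divisor, w - w % divisor


print(make_divisible_hw_sketch(721, 1440, 4))  # (720, 1440)
```

Dimensions smaller than 2**n collapse to 0 under this rule, so n should be chosen with the smallest grid in mind.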

IPSL_AID.main.setup_directories_and_logging(args)[source]

Set up directory structure and logging infrastructure for experiments.

This function creates a standardized directory hierarchy for organizing experiment outputs (logs, results, model checkpoints, etc.) and initializes a logging system with both console and file output.

Parameters:

args (argparse.Namespace or EasyDict) –

Configuration object containing the following attributes:

  • main_folder (str) – Main experiment folder name.

  • sub_folder (str) – Sub-folder name for the current run.

  • prefix (str) – Prefix for log files and outputs.

  • datadir (str) – Base data directory path.

  • constant_varnames_file (str) – Filename for constant variables data.

Returns:

  • paths (EasyDict) – Dictionary containing paths to created directories:

    • logs (str) – Path to the log file directory.

    • results (str) – Path to the results output directory.

    • runs (str) – Path to the experiment run tracking directory.

    • checkpoints (str) – Path to the model checkpoint directory.

    • stats (str) – Path to the statistics and metrics directory.

    • datadir (str) – Original data directory path.

    • constants (str) – Full path to the constant variables file.

  • logger (Logger) – Configured logger instance with console and file output.

Notes

  • Directory structure:

    logs/main_folder/sub_folder/
    results/main_folder/sub_folder/
    runs/main_folder/sub_folder/
    checkpoints/main_folder/sub_folder/
    stats/main_folder/sub_folder/

  • Log files are named with timestamp: {prefix}_log.txt

  • The logger outputs to both console and file by default.

  • All directories are created if they don’t exist (via FileUtils.makedir).
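The directory layout and dual console/file logging described above can be approximated with the standard library alone. This sketch is not the package's implementation: the base argument, the logger name, and the log format are assumptions, os.makedirs stands in for FileUtils.makedir, and a plain dict stands in for EasyDict.

```python
import logging
import os
import sys

SUBDIRS = ("logs", "results", "runs", "checkpoints", "stats")


def setup_dirs_and_logger(base, main_folder, sub_folder, prefix):
    """Create <base>/<kind>/<main_folder>/<sub_folder>/ and a console+file logger."""
    paths = {}
    for kind in SUBDIRS:
        path = os.path.join(base, kind, main_folder, sub_folder)
        os.makedirs(path, exist_ok=True)  # no error if the directory exists
        paths[kind] = path

    logger = logging.getLogger(f"{prefix}.{main_folder}.{sub_folder}")
    logger.setLevel(logging.INFO)
    # Close and drop old handlers so repeated calls do not duplicate output.
    for old in list(logger.handlers):
        old.close()
        logger.removeHandler(old)
    formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
    log_file = os.path.join(paths["logs"], f"{prefix}_log.txt")
    for handler in (logging.StreamHandler(sys.stdout), logging.FileHandler(log_file)):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return paths, logger
```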

IPSL_AID.main.log_configuration(args, paths, logger)[source]

Log all configuration parameters to the provided logger.

This function comprehensively logs all experiment configuration parameters including execution mode, data settings, training hyperparameters, model architecture, and directory structure. It provides a clear overview of the experiment setup for reproducibility and debugging.

Parameters:
  • args (argparse.Namespace or EasyDict) – Configuration object containing all experiment parameters.

  • paths (EasyDict) – Dictionary containing paths to various experiment directories.

  • logger (Logger) – Logger instance for outputting configuration information.

Notes

  • The function organizes parameters into logical sections for readability.

  • Includes both user-specified parameters and derived directory paths.

  • Provides warnings for important configuration choices (e.g., disabled checkpoint saving).

  • The output is formatted with clear section headers and indentation.

IPSL_AID.main.setup_data_paths(args, paths, logger)[source]

Set up data file paths, load datasets, and compute normalization statistics.

This function handles the data loading pipeline for training and validation datasets. It manages per-variable data paths, concatenates multi-year data for each variable, computes normalization statistics, and sets up variable mappings and normalization types.

Parameters:
  • args (argparse.Namespace or EasyDict) – Configuration object containing runtime options such as training years, execution mode, variable names, and normalization specifications.

  • paths (EasyDict) –

    Dictionary containing directory paths.

    Expected keys: datadir and stats.

  • logger (logging.Logger) – Logger instance for output messages.

Returns:

  • norm_mapping (dict) – Mapping from variable name to normalization statistics.

  • steps (EasyDict) – Grid dimension information (time, latitude, longitude).

  • normalization_type (EasyDict) – Mapping from variable name to normalization method.

  • index_mapping (dict) – Mapping from variable name to array index.

  • train_ds (xarray.Dataset or None) – Training dataset, or None if run_type is inference.

  • valid_ds (xarray.Dataset) – Validation dataset.

Notes

  • Per-variable data directories may be provided using VAR=path syntax.

  • Training data is only loaded when run_type is not inference.

  • Normalization statistics are computed on the validation dataset.

  • Variables from different files and years are merged into a single dataset.
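The VAR=path convention can be parsed with a few lines of Python. This is a sketch under assumptions: the input is taken to be a list of strings in which a bare variable name falls back to a default directory, and the variable names and paths in the example are placeholders.

```python
def parse_var_paths(entries, default_dir):
    """Map variable names to data directories from 'VAR=path' or 'VAR' entries."""
    mapping = {}
    for entry in entries:
        var, sep, path = entry.partition("=")  # split on the first '=' only
        var = var.strip()
        if not var:
            raise ValueError(f"missing variable name in {entry!r}")
        mapping[var] = path.strip() if sep else default_dir
    return mapping


if __name__ == "__main__":
    print(parse_var_paths(["t2m=/data/t2m", "u10"], "/data/default"))
    # {'t2m': '/data/t2m', 'u10': '/data/default'}
```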

IPSL_AID.main.setup_training_environment(args, logger)[source]

Set up the training environment including device selection, random seeds, and data type configuration.

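Parameters:
  • args (argparse.Namespace or EasyDict) – Configuration object containing runtime options.

  • logger (logging.Logger) – Logger instance for output messages.
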
Returns:

  • device (torch.device) – Selected computing device.

  • torch_dtype (torch.dtype) – PyTorch data type.

  • np_dtype (numpy.dtype) – NumPy data type.

  • use_fp16 (bool) – Whether half precision is enabled.

Notes

  • Sets global random seeds for reproducibility.

  • Automatically selects CUDA if available.

  • Enables PyTorch anomaly detection for debugging.

IPSL_AID.main.create_data_loaders(args, paths, norm_mapping, steps, normalization_type, index_mapping, torch_dtype, np_dtype, logger, mode='train', run_type='train', train_loaded_dfs=None, valid_loaded_dfs=None)[source]

Create data loaders for training, validation, or inference.

Parameters:
  • args (argparse.Namespace or EasyDict) – Runtime configuration options.

  • paths (EasyDict) – Directory paths including constants files.

  • norm_mapping (dict) – Normalization statistics per variable.

  • steps (EasyDict) – Grid dimension information.

  • normalization_type (EasyDict) – Normalization method per variable.

  • index_mapping (dict) – Variable-to-index mapping.

  • torch_dtype (torch.dtype) – PyTorch tensor dtype.

  • np_dtype (numpy.dtype) – NumPy array dtype.

  • logger (logging.Logger) – Logger instance.

  • mode (str, optional) – Either ‘train’ or ‘validation’. Default is ‘train’.

  • run_type (str, optional) – Execution mode (‘train’, ‘resume_train’, or ‘inference’). Default is ‘train’.

  • train_loaded_dfs (dict, optional) – Pre-loaded training datasets.

  • valid_loaded_dfs (dict, optional) – Pre-loaded validation datasets.

Returns:

  • data_loader (torch.utils.data.DataLoader) – Configured data loader.

  • img_res (tuple of int) – Spatial resolution used by the model.

  • dataset (DataPreprocessor) – Underlying dataset object.

Raises:

Notes

  • Spatial dimensions are adjusted to be divisible by powers of two.

  • Validation falls back to training data if validation data is unavailable.

  • Data is assumed to be pre-loaded into memory.

IPSL_AID.main.setup_model(args, img_res, use_fp16, device, logger)[source]

Set up the diffusion model and its loss function.

Parameters:
  • args (argparse.Namespace or EasyDict) – Model configuration options.

  • img_res (tuple of int) – Image resolution (height, width).

  • use_fp16 (bool) – Whether FP16 precision is enabled.

  • device (torch.device) – Target device.

  • logger (logging.Logger) – Logger instance.

Returns:

  • model (torch.nn.Module) – Initialized model.

  • loss_fn (callable) – Loss function.

Raises:

ValueError – If an unsupported time normalization is specified.

Notes

  • Label dimensionality depends on the selected time normalization.

  • Model creation is delegated to load_model_and_loss.

IPSL_AID.main.resolve_region_center(args)[source]

Resolve the regional inference center coordinates.

This function enforces the rules for regional inference: the user may provide either region (a predefined region name) or region_center (explicit center coordinates), but not both.

Parameters:

args (argparse.Namespace) – Parsed command line arguments.

Returns:

The (lat, lon) center coordinates in inference_regional mode; None otherwise.

Return type:

tuple or None

Raises:

ValueError – If both region and region_center are provided, if neither is provided in inference_regional mode, or if an unknown region name is specified.

Notes

  • Predefined regions map to fixed center coordinates.

  • Longitude follows the convention [0, 360].
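The validation rules above can be sketched as follows. Everything concrete here is hypothetical: the names and coordinates in REGION_CENTERS are placeholders rather than the package's predefined regions, and regional mode is assumed to be signaled by a boolean inference_regional attribute. Longitudes are wrapped with % 360, which maps 360 to 0.

```python
from types import SimpleNamespace

# Placeholder table: NOT the package's real region names or coordinates.
REGION_CENTERS = {"region_a": (45.0, 5.0), "region_b": (-10.0, 300.0)}


def resolve_region_center(args):
    """Return (lat, lon) for regional inference, or None otherwise (sketch)."""
    region = getattr(args, "region", None)
    center = getattr(args, "region_center", None)
    if region is not None and center is not None:
        raise ValueError("provide either region or region_center, not both")
    if not getattr(args, "inference_regional", False):  # assumed flag name
        return None
    if region is None and center is None:
        raise ValueError("regional inference needs region or region_center")
    if region is not None:
        if region not in REGION_CENTERS:
            raise ValueError(f"unknown region name: {region!r}")
        lat, lon = REGION_CENTERS[region]
    else:
        lat, lon = center
    # Wrap longitude into [0, 360), e.g. -30 -> 330.
    return float(lat), float(lon) % 360.0


if __name__ == "__main__":
    demo = SimpleNamespace(inference_regional=True, region=None, region_center=(40.0, -30.0))
    print(resolve_region_center(demo))  # (40.0, 330.0)
```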

IPSL_AID.main.main()[source]

Main training and inference pipeline for IPSL-AID diffusion models.

This function orchestrates the entire training and inference process for diffusion-based generative models on weather and climate data. It handles:

  • Argument parsing and configuration

  • Directory setup and logging

  • Data loading and preprocessing

  • Model initialization and checkpoint management

  • Training loop with validation

  • Inference execution

  • Visualization and result saving

The pipeline supports multiple modes of operation:

  • Training from scratch (run_type=‘train’)

  • Resuming training from a checkpoint (run_type=‘resume_train’)

  • Running inference with a trained model (run_type=‘inference’)

The function follows a structured workflow:

  1. Parse command line arguments

  2. Set up directories and logging

  3. Load and preprocess data

  4. Initialize the model, optimizer, and loss function

  5. Handle checkpoint loading if required

  6. Execute the training loop with validation, or run inference

  7. Generate plots and save results

Parameters:

None – All configuration is provided via command line arguments.

Return type:

None

Notes

  • The function uses argparse for command line argument parsing.

  • All output (logs, checkpoints, results) is saved to organized directories.

  • Training includes validation at each epoch with metrics tracking.

  • Inference mode runs validation metrics without training.

  • Mixed precision training (FP16) is supported when available.

  • Model checkpoints include full training state for resumption.

  • TensorBoard integration is provided for training visualization.

Raises:
  • FileNotFoundError – If required checkpoints are not found for resumption or inference.

  • RuntimeError – If inference mode is requested without validation data.

  • ValueError – If invalid configurations are provided.

IPSL_AID.model module

IPSL_AID.model.load_model_and_loss(opts, logger=None, device='cpu')[source]

Load a diffusion model or U-Net with corresponding loss function.

This function initializes and configures a generative model (diffusion or direct U-Net) along with its corresponding loss function based on the provided options. It supports multiple architectures and preconditioning schemes.

Parameters:
  • opts (EasyDict or dict) –

    Configuration dictionary containing model parameters. Must include:

    • arch (str) – Architecture type: ‘ddpmpp’, ‘ncsnpp’, or ‘adm’.

    • precond (str) – Preconditioning type: ‘vp’, ‘ve’, ‘edm’, or ‘unet’.

    • img_resolution (int or tuple) – Image resolution (height, width).

    • in_channels (int) – Number of input channels.

    • out_channels (int) – Number of output channels.

    • label_dim (int) – Dimension of label conditioning (0 for unconditional).

    • use_fp16 (bool) – Whether to use mixed precision (FP16).

    • model_kwargs (dict, optional) – Additional model-specific parameters to override defaults.

  • logger (logging.Logger, optional) – Logger instance for output messages. If None, uses print(). Default is None.

  • device (str or torch.device, optional) – Device to load the model onto (‘cpu’, ‘cuda’, etc.). Default is ‘cpu’.

Returns:

  • model (torch.nn.Module) – Initialized model instance (preconditioner or U-Net).

  • loss_fn (torch.nn.Module or callable) – Corresponding loss function for the model.

Raises:

ValueError – If an invalid architecture or preconditioner type is specified.

Notes

  • The function supports three main architectures:
    • DDPM++ (Song et al., 2020) with VP preconditioning

    • NCSN++ (Song et al., 2020) with VE preconditioning

    • ADM (Dhariwal & Nichol, 2021) with EDM preconditioning

  • When precond=’unet’, uses a direct U-Net without diffusion preconditioning.

  • Model parameters are counted and logged for transparency.

  • Default hyperparameters are provided for each architecture but can be overridden via opts.model_kwargs.
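As an illustration of the opts structure, the hypothetical helper below checks the documented keys and allowed values before the dictionary is passed to load_model_and_loss(opts, logger, device). The example channel counts and resolution are placeholders, treating a single int as a square resolution is an assumption, and the real function may reject further arch/precond pairings (see test_invalid_combinations) that this sketch does not model.

```python
# Hypothetical pre-flight check; key names and allowed values mirror the docs above.
VALID_ARCHS = {"ddpmpp", "ncsnpp", "adm"}
VALID_PRECONDS = {"vp", "ve", "edm", "unet"}
REQUIRED_KEYS = ("arch", "precond", "img_resolution", "in_channels",
                 "out_channels", "label_dim", "use_fp16")


def check_model_opts(opts):
    """Validate documented keys and return the resolution as (height, width)."""
    missing = [key for key in REQUIRED_KEYS if key not in opts]
    if missing:
        raise KeyError(f"missing opts keys: {missing}")
    if opts["arch"] not in VALID_ARCHS:
        raise ValueError(f"invalid arch: {opts['arch']!r}")
    if opts["precond"] not in VALID_PRECONDS:
        raise ValueError(f"invalid precond: {opts['precond']!r}")
    res = opts["img_resolution"]
    # A single int is treated as a square image (an assumption).
    return (res, res) if isinstance(res, int) else tuple(res)


# Placeholder values for illustration only.
opts = {
    "arch": "adm", "precond": "edm", "img_resolution": (64, 128),
    "in_channels": 4, "out_channels": 2, "label_dim": 0, "use_fp16": False,
    "model_kwargs": {},  # optional overrides of architecture defaults
}
```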

class IPSL_AID.model.TestModelLoader(methodName='runTest', logger=None)[source]

Bases: TestCase

Unit tests for model and loss loader.

__init__(methodName='runTest', logger=None)[source]

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

setUp()[source]

Set up test fixtures.

test_ddpmpp_vp_combination()[source]

Test DDPM++ architecture with VP preconditioner.

test_ncsnpp_ve_combination()[source]

Test NCSN++ architecture with VE preconditioner.

test_adm_edm_combination()[source]

Test ADM architecture with EDM preconditioner.

test_adm_unet_combination()[source]

Test ADM architecture used as a direct U-Net without preconditioning.

test_rectangular_resolution()[source]

Test loader with rectangular resolution.

test_model_kwargs_override()[source]

Test that model_kwargs can override default settings.

test_no_conditional_channels()[source]

Test loader without conditional channels.

test_invalid_combinations()[source]

Test that invalid combinations raise appropriate errors.

tearDown()[source]

Clean up after tests.

IPSL_AID.model_utils module

class IPSL_AID.model_utils.ModelUtils[source]

Bases: object

Utility class for model inspection, checkpointing, and memory profiling.

This class provides static methods for common model operations including parameter counting, memory usage analysis, checkpoint management, and model inspection.

Examples

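>>> import logging, torch
>>> from IPSL_AID.model_utils import ModelUtils
>>> model, logger = torch.nn.Linear(10, 5), logging.getLogger(__name__)
>>> utils = ModelUtils()
>>> param_counts = ModelUtils.get_parameter_number(model)
>>> ModelUtils.save_checkpoint({'state_dict': model.state_dict()}, "checkpoint.pth.tar", logger)
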
__init__()[source]

Initialize ModelUtils instance.

static get_parameter_number(model, logger=None)[source]

Calculate the total and trainable number of parameters in a model.

Parameters:
  • model (torch.nn.Module) – PyTorch model to inspect

  • logger (Logger, optional) – Logger instance for output, by default None

Returns:

Dictionary containing:

  • ‘Total’: Total number of parameters

  • ‘Trainable’: Number of trainable parameters

Return type:

dict

Examples

>>> model = torch.nn.Linear(10, 5)
>>> counts = ModelUtils.get_parameter_number(model, logger)
static print_model_layers(model, logger=None)[source]

Print model parameter names along with their gradient requirements.

Parameters:
  • model (torch.nn.Module) – PyTorch model to inspect

  • logger (Logger, optional) – Logger instance for output, by default None

Examples

>>> model = torch.nn.Sequential(
...     torch.nn.Linear(10, 5),
...     torch.nn.ReLU(),
...     torch.nn.Linear(5, 1)
... )
>>> ModelUtils.print_model_layers(model, logger)
static save_checkpoint(state, filename='checkpoint.pth.tar', logger=None)[source]

Save model and optimizer state to a file.

Parameters:
  • state (dict) – Dictionary containing the model state_dict and other training information. Typically includes:

    • ‘state_dict’: Model parameters

    • ‘optimizer’: Optimizer state

    • ‘epoch’: Current epoch

    • ‘loss’: Current loss value

  • filename (str, optional) – File path to save the checkpoint, by default “checkpoint.pth.tar”

  • logger (Logger, optional) – Logger instance for output, by default None

Examples

>>> state = {
...     'state_dict': model.state_dict(),
...     'optimizer': optimizer.state_dict(),
...     'epoch': epoch,
...     'loss': loss
... }
>>> ModelUtils.save_checkpoint(state, 'model_checkpoint.pth.tar', logger)
static load_checkpoint(checkpoint, model, optimizer=None, logger=None)[source]

Load model and optimizer state from a checkpoint file.

Parameters:
  • checkpoint (dict) – Loaded checkpoint dictionary

  • model (torch.nn.Module) – Model to load weights into

  • optimizer (torch.optim.Optimizer, optional) – Optimizer to restore state, by default None

  • logger (Logger, optional) – Logger instance for output, by default None

Examples

>>> checkpoint = torch.load('model_checkpoint.pth.tar')
>>> ModelUtils.load_checkpoint(checkpoint, model, optimizer, logger)
static load_training_checkpoint(checkpoint_path, model, optimizer, device, logger=None)[source]

Load a comprehensive training checkpoint.

Parameters:
  • checkpoint_path (str) – Path to checkpoint file

  • model (torch.nn.Module) – Model to load weights into

  • optimizer (torch.optim.Optimizer) – Optimizer to restore state

  • device (torch.device) – Device to load checkpoint to

  • logger (Logger, optional) – Logger instance for output

Returns:

(epoch, samples_processed, batches_processed, best_val_loss, best_epoch, checkpoint)

Return type:

tuple

static count_parameters_by_layer(model, logger=None)[source]

Count parameters for each layer in the model.

Parameters:
  • model (torch.nn.Module) – PyTorch model to analyze

  • logger (Logger, optional) – Logger instance for output, by default None

Returns:

Dictionary with layer names as keys and parameter counts as values

Return type:

dict

Examples

>>> layer_params = ModelUtils.count_parameters_by_layer(model, logger)
static log_model_summary(model, input_shape=None, logger=None)[source]

Log a comprehensive model summary, including parameters and architecture.

Parameters:
  • model (torch.nn.Module) – PyTorch model to summarize

  • input_shape (tuple, optional) – Input shape for memory analysis, by default None

  • logger (Logger, optional) – Logger instance for output, by default None

static save_training_checkpoint(model, optimizer, epoch, samples_processed, batches_processed, train_loss_history, valid_loss_history, valid_metrics_history, best_val_loss, best_epoch, avg_val_loss, avg_epoch_loss, args, paths, logger, checkpoint_type='epoch', save_full_model=True)[source]

Save a comprehensive training checkpoint with consistent formatting.

Parameters:
  • model (torch.nn.Module) – Model to save

  • optimizer (torch.optim.Optimizer) – Optimizer to save

  • epoch (int) – Current epoch

  • samples_processed (int) – Number of samples processed so far

  • batches_processed (int) – Number of batches processed so far

  • train_loss_history (list) – History of training losses

  • valid_loss_history (list) – History of validation losses

  • valid_metrics_history (dict) – History of validation metrics

  • best_val_loss (float) – Best validation loss so far

  • best_epoch (int) – Epoch with best validation loss

  • avg_val_loss (float) – Current epoch validation loss

  • avg_epoch_loss (float) – Current epoch training loss

  • args (argparse.Namespace) – Command line arguments

  • paths (EasyDict) – Directory paths

  • logger (Logger) – Logger instance

  • checkpoint_type (str, optional) – Type of checkpoint: “samples”, “epoch”, “best”, or “final”. Default is “epoch”.

  • save_full_model (bool, optional) – Whether to also save the full model separately. Default is True.

Returns:

(checkpoint_filename, full_model_filename)

Return type:

tuple

Examples

>>> checkpoint_file, full_model_file = ModelUtils.save_training_checkpoint(
...     model, optimizer, epoch, samples_processed, batches_processed,
...     train_loss_history, valid_loss_history, valid_metrics_history,
...     best_val_loss, best_epoch, avg_val_loss, avg_epoch_loss,
...     args, paths, logger, checkpoint_type="best"
... )
static save_emergency_checkpoint(model, optimizer, epoch, samples_processed, batches_processed, train_loss_history, valid_loss_history, valid_metrics_history, args, paths, logger, reason='emergency')[source]

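Save an emergency checkpoint for recovery.

Parameters:

reason (str, optional) – Reason for the emergency save (e.g., “crash”, “interrupt”, “error”). Default is “emergency”. All other parameters have the same meaning as in save_training_checkpoint.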

Returns:

(checkpoint_filename, full_model_filename)

Return type:

tuple
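A common way to use an emergency checkpoint is to wrap the training loop so that an interrupt or crash triggers a save before the error propagates. The sketch below is an assumption about usage, not the package's code: train_epoch and save_emergency are hypothetical callables standing in for the real training step and for ModelUtils.save_emergency_checkpoint.

```python
def run_with_emergency_save(train_epoch, save_emergency, num_epochs):
    """Call train_epoch(epoch) for each epoch; save an emergency checkpoint on failure."""
    epoch = 0
    try:
        for epoch in range(num_epochs):
            train_epoch(epoch)
    except KeyboardInterrupt:
        save_emergency(epoch, reason="interrupt")
        raise  # re-raise so the caller still sees the interruption
    except Exception:
        save_emergency(epoch, reason="crash")
        raise
```

Re-raising after the save keeps failures visible instead of letting a crashed run look like a completed one.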

class IPSL_AID.model_utils.TestModel(*args: Any, **kwargs: Any)[source]

Bases: Module

A model for testing purposes that mixes convolutional and linear layers with some non-trainable state (buffers). This model is designed to have a reasonable number of parameters for testing the ModelUtils methods without being too large. It includes batch normalization layers, whose running mean and variance are stored as non-trainable buffers, and a dropout layer, which has no parameters of its own.

__init__()[source]
forward(x)[source]
class IPSL_AID.model_utils.TestModelUtils(methodName='runTest', logger=None)[source]

Bases: TestCase

Unit tests for ModelUtils class.

__init__(methodName='runTest', logger=None)[source]

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

setUp()[source]

Set up test fixtures.

test_get_parameter_number()[source]

Test parameter counting functionality.

test_get_parameter_number_with_frozen_layers()[source]

Test parameter counting with frozen layers.

test_count_parameters_by_layer()[source]

Test layer-wise parameter counting.

test_print_model_layers()[source]

Test model layer printing functionality.

test_log_model_summary_without_input_shape()[source]

Test model summary logging without input shape.

test_log_model_summary_with_input_shape()[source]

Test model summary logging with input shape.

test_save_checkpoint()[source]

Test saving a checkpoint.

test_load_checkpoint()[source]

Test loading a checkpoint.

test_load_checkpoint_without_optimizer()[source]

Test loading a checkpoint without optimizer.

test_load_training_checkpoint()[source]

Test loading a comprehensive training checkpoint.

test_load_training_checkpoint_nonexistent()[source]

Test loading a nonexistent training checkpoint.

test_save_training_checkpoint_epoch_type()[source]

Test saving training checkpoint with epoch type.

test_save_training_checkpoint_best_type()[source]

Test saving training checkpoint with best type.

test_save_training_checkpoint_final_type()[source]

Test saving training checkpoint with final type.

test_save_training_checkpoint_samples_type()[source]

Test saving training checkpoint with samples type.

test_save_emergency_checkpoint()[source]

Test saving emergency checkpoint.

test_full_checkpoint_cycle()[source]

Test complete checkpoint save-load cycle with training state.

tearDown()[source]

Clean up after tests.

IPSL_AID.networks module

IPSL_AID.networks.weight_init(shape, mode, fan_in, fan_out)[source]

Initialize weights using various initialization methods.

Parameters:
  • shape (tuple of ints) – The shape of the weight tensor to initialize.

  • mode (str) – The initialization method to use. Options are:

    • ‘xavier_uniform’: Xavier uniform initialization

    • ‘xavier_normal’: Xavier normal initialization

    • ‘kaiming_uniform’: Kaiming uniform initialization (also known as He initialization)

    • ‘kaiming_normal’: Kaiming normal initialization (also known as He initialization)

  • fan_in (int) – Number of input units in the weight tensor.

  • fan_out (int) – Number of output units in the weight tensor.

Returns:

A tensor of the specified shape with values initialized according to the chosen method.

Return type:

torch.Tensor

Raises:

ValueError – If an invalid initialization mode is provided.
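The four modes can be sketched in NumPy. The scaling constants are an assumption modeled on the gain-1 formulas in NVIDIA's EDM reference code: uniform draws in [-a, a] with a = sqrt(6 / (fan_in + fan_out)) or sqrt(3 / fan_in), and normal draws with std sqrt(2 / (fan_in + fan_out)) or sqrt(1 / fan_in). PyTorch's own torch.nn.init.kaiming_* functions default to a ReLU gain of sqrt(2), so check the package source before relying on exact values.

```python
import numpy as np


def weight_init(shape, mode, fan_in, fan_out, rng=None):
    """NumPy sketch of four weight initializers (gain-1 scaling assumed)."""
    rng = np.random.default_rng() if rng is None else rng
    if mode == "xavier_uniform":
        return np.sqrt(6 / (fan_in + fan_out)) * (rng.random(shape) * 2 - 1)
    if mode == "xavier_normal":
        return np.sqrt(2 / (fan_in + fan_out)) * rng.standard_normal(shape)
    if mode == "kaiming_uniform":
        return np.sqrt(3 / fan_in) * (rng.random(shape) * 2 - 1)
    if mode == "kaiming_normal":
        return np.sqrt(1 / fan_in) * rng.standard_normal(shape)
    raise ValueError(f"invalid init mode: {mode!r}")
```

For a convolution weight of shape (out_channels, in_channels, k, k), the usual convention is fan_in = in_channels * k * k and fan_out = out_channels * k * k.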

class IPSL_AID.networks.Linear(*args: Any, **kwargs: Any)[source]

Bases: Module

A linear (fully connected) layer with customizable weight initialization.

This layer applies a linear transformation to the incoming data: y = x W^T + b.

Parameters:
  • in_features (int) – Size of each input sample.

  • out_features (int) – Size of each output sample.

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. Default is True.

  • init_mode (str, optional) – Weight initialization method. Options are:

    • xavier_uniform: Xavier uniform initialization

    • xavier_normal: Xavier normal initialization

    • kaiming_uniform: Kaiming uniform initialization (He initialization)

    • kaiming_normal: Kaiming normal initialization (He initialization)

    Default is kaiming_normal.

  • init_weight (float or int, optional) – Scaling factor for the initialized weights. Default is 1.

  • init_bias (float or int, optional) – Scaling factor for the initialized bias. Default is 0.

weight

The learnable weights of the layer of shape (out_features, in_features).

Type:

torch.nn.Parameter

bias

The learnable bias of the layer of shape (out_features,). If bias=False, this attribute is set to None.

Type:

torch.nn.Parameter or None

__init__(in_features, out_features, bias=True, init_mode='kaiming_normal', init_weight=1, init_bias=0)[source]

Initialize the Linear layer.

Parameters:
  • in_features (int) – Size of each input sample.

  • out_features (int) – Size of each output sample.

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. Default is True.

  • init_mode (str, optional) – Weight initialization method. Default is ‘kaiming_normal’.

  • init_weight (float or int, optional) – Scaling factor for the initialized weights. Default is 1.

  • init_bias (float or int, optional) – Scaling factor for the initialized bias. Default is 0.

forward(x)[source]

Forward pass of the linear layer.

Parameters:

x (torch.Tensor) – Input tensor of shape (batch_size, in_features) or (batch_size, *, in_features) where * means any number of additional dimensions.

Returns:

Output tensor of shape (batch_size, out_features) or (batch_size, *, out_features).

Return type:

torch.Tensor

Notes

The operation performed is: output = x @ weight^T + bias. The bias is added in-place for efficiency when possible.
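In array terms, the forward pass is a matrix product over the last axis plus a broadcast bias. The NumPy sketch below illustrates this together with the role of init_weight and init_bias as multiplicative scales; the normal draw with std sqrt(1 / in_features) stands in for kaiming_normal and is an assumption. Because the bias draw is multiplied by init_bias, the default init_bias=0 starts the bias at zero.

```python
import numpy as np


def init_linear(in_features, out_features, init_weight=1.0, init_bias=0.0, rng=None):
    """Draw weight (out, in) and bias (out,), each scaled by its init factor."""
    rng = np.random.default_rng() if rng is None else rng
    std = np.sqrt(1 / in_features)  # assumed kaiming_normal-style scale
    weight = std * rng.standard_normal((out_features, in_features)) * init_weight
    bias = std * rng.standard_normal(out_features) * init_bias
    return weight, bias


def linear_forward(x, weight, bias=None):
    """y = x @ weight.T + bias, applied over the last axis of x."""
    y = x @ weight.T  # leading axes (batch_size, *) are carried through
    if bias is not None:
        y = y + bias  # (out_features,) broadcasts across all leading axes
    return y
```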

class IPSL_AID.networks.Conv2d(*args: Any, **kwargs: Any)[source]

Bases: Module

2D convolutional layer with optional upsampling, downsampling, and fused resampling.

This layer implements a 2D convolution that can optionally include upsampling or downsampling operations with configurable resampling filters. It supports both separate and fused resampling modes for efficiency.

Parameters:
  • in_channels (int) – Number of input channels.

  • out_channels (int) – Number of output channels.

  • kernel (int) – Size of the convolutional kernel (square kernel).

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. Default is True.

  • up (bool, optional) – If True, upsample the input by a factor of 2 before convolution. Cannot be True if down is also True. Default is False.

  • down (bool, optional) – If True, downsample the output by a factor of 2 after convolution. Cannot be True if up is also True. Default is False.

  • resample_filter (list, optional) – Coefficients of the 1D resampling filter that will be turned into a 2D filter. Default is [1, 1], which is a box filter; [1, 3, 3, 1] is the usual choice for bilinear 2x resampling.

  • fused_resample (bool, optional) – If True, fuse the resampling operation with the convolution for efficiency. Default is False.

  • init_mode (str, optional) – Weight initialization method. Options are:

    • ‘xavier_uniform’: Xavier uniform initialization

    • ‘xavier_normal’: Xavier normal initialization

    • ‘kaiming_uniform’: Kaiming uniform initialization (He initialization)

    • ‘kaiming_normal’: Kaiming normal initialization (He initialization)

    Default is ‘kaiming_normal’.

  • init_weight (float or int, optional) – Scaling factor for the initialized weights. Default is 1.

  • init_bias (float or int, optional) – Scaling factor for the initialized bias. Default is 0.

weight

The learnable weights of the convolution of shape (out_channels, in_channels, kernel, kernel). If kernel is 0, this is None.

Type:

torch.nn.Parameter or None

bias

The learnable bias of the convolution of shape (out_channels,). If kernel is 0 or bias is False, this is None.

Type:

torch.nn.Parameter or None

resample_filter

The 2D resampling filter used for upsampling or downsampling. Registered as a buffer (non-learnable parameter).

Type:

torch.Tensor or None

Raises:

AssertionError – If both up and down are set to True.

Notes

  • When kernel is 0, no convolution is performed, only resampling if enabled.

  • The resampling filter is created by taking the outer product of the 1D filter with itself to create a separable 2D filter, then normalized.

  • Fused resampling combines the resampling and convolution operations into single operations for better performance.
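The single-channel NumPy sketch below shows how a 1D coefficient list becomes a normalized 2D filter and what that filter does in stride-2 resampling. It leaves out the convolution itself, and the zero padding of (len - 1) // 2 and the factor of 4 applied when upsampling are assumptions modeled on common practice, not read from the package source. With the default [1, 1], downsampling becomes 2x2 averaging and upsampling becomes nearest-neighbor replication.

```python
import numpy as np


def make_resample_filter(coeffs=(1, 1)):
    """Outer product of a 1D filter with itself, normalized to sum to 1."""
    f = np.asarray(coeffs, dtype=np.float64)
    f2d = np.outer(f, f)
    return f2d / f2d.sum()


def downsample2x(x, f2d):
    """Filter a 2D array with f2d at stride 2 (halves each dimension)."""
    k = f2d.shape[0]
    xp = np.pad(x, (k - 1) // 2)  # zero padding (assumption)
    h_out = (xp.shape[0] - k) // 2 + 1
    w_out = (xp.shape[1] - k) // 2 + 1
    out = np.empty((h_out, w_out))
    for i in range(h_out):
        for j in range(w_out):
            out[i, j] = np.sum(f2d * xp[2 * i:2 * i + k, 2 * j:2 * j + k])
    return out


def upsample2x(x, f2d):
    """Transposed stride-2 filtering (doubles each dimension)."""
    k = f2d.shape[0]
    pad = (k - 1) // 2
    h, w = x.shape
    full = np.zeros(((h - 1) * 2 + k, (w - 1) * 2 + k))
    for i in range(h):
        for j in range(w):
            # Scale by 4 so a constant input keeps its value away from the borders.
            full[2 * i:2 * i + k, 2 * j:2 * j + k] += 4 * f2d * x[i, j]
    return full[pad:full.shape[0] - pad, pad:full.shape[1] - pad]
```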

Examples

>>> import torch
>>> from IPSL_AID.networks import Conv2d
>>> # Standard convolution
>>> conv = Conv2d(3, 16, kernel=3)
>>> x = torch.randn(4, 3, 32, 32)
>>> out = conv(x)
>>> out.shape
torch.Size([4, 16, 32, 32])
>>> # Convolution with downsampling
>>> conv_down = Conv2d(3, 16, kernel=3, down=True)
>>> out = conv_down(x)
>>> out.shape
torch.Size([4, 16, 16, 16])
>>> # Convolution with upsampling
>>> conv_up = Conv2d(3, 16, kernel=3, up=True)
>>> out = conv_up(x)
>>> out.shape
torch.Size([4, 16, 64, 64])
__init__(in_channels, out_channels, kernel, bias=True, up=False, down=False, resample_filter=[1, 1], fused_resample=False, init_mode='kaiming_normal', init_weight=1, init_bias=0)[source]

Initialize the Conv2d layer.

Parameters:
  • in_channels (int) – Number of input channels.

  • out_channels (int) – Number of output channels.

  • kernel (int) – Size of the convolutional kernel.

  • bias (bool, optional) – Whether to include a bias term. Default is True.

  • up (bool, optional) – Whether to upsample the input. Default is False.

  • down (bool, optional) – Whether to downsample the output. Default is False.

  • resample_filter (list, optional) – Coefficients of the 1D resampling filter. Default is [1, 1].

  • fused_resample (bool, optional) – Whether to fuse resampling with convolution. Default is False.

  • init_mode (str, optional) – Weight initialization method. Default is ‘kaiming_normal’.

  • init_weight (float or int, optional) – Scaling factor for weight initialization. Default is 1.

  • init_bias (float or int, optional) – Scaling factor for bias initialization. Default is 0.

forward(x)[source]

Forward pass of the Conv2d layer.

Parameters:

x (torch.Tensor) – Input tensor of shape (batch_size, in_channels, height, width).

Returns:

Output tensor of shape (batch_size, out_channels, out_height, out_width). If up is True, spatial dimensions are doubled. If down is True, spatial dimensions are halved.

Return type:

torch.Tensor

Notes

The method handles four main cases:

  1. Fused upsampling + convolution

  2. Fused convolution + downsampling

  3. Separate up/down sampling followed by convolution

  4. Standard convolution only

class IPSL_AID.networks.GroupNorm(*args: Any, **kwargs: Any)[source]

Bases: Module

Group Normalization layer.

This layer implements Group Normalization, which divides channels into groups and computes within each group the mean and variance for normalization. It is particularly effective for small batch sizes and often used as an alternative to Batch Normalization.

Parameters:
  • num_channels (int) – Number of input channels.

  • num_groups (int, optional) – Number of groups to divide the channels into. Must be a divisor of the number of channels. The actual number of groups may be reduced to satisfy min_channels_per_group. Default is 32.

  • min_channels_per_group (int, optional) – Minimum number of channels per group. If the division would result in fewer channels per group, the number of groups is reduced. Default is 4.

  • eps (float, optional) – A small constant added to the denominator for numerical stability. Default is 1e-5.

weight

Learnable scale parameter of shape (num_channels,). Initialized to ones.

Type:

torch.nn.Parameter

bias

Learnable bias parameter of shape (num_channels,). Initialized to zeros.

Type:

torch.nn.Parameter

Notes

  • Group Normalization is independent of batch size, making it suitable for variable batch sizes and small batch training.

  • The number of groups is automatically adjusted to ensure each group has at least min_channels_per_group channels.

  • This layer uses PyTorch’s built-in torch.nn.functional.group_norm.

__init__(num_channels, num_groups=32, min_channels_per_group=4, eps=1e-05)[source]

Initialize the GroupNorm layer.

Parameters:
  • num_channels (int) – Number of input channels.

  • num_groups (int, optional) – Desired number of groups. Default is 32.

  • min_channels_per_group (int, optional) – Minimum channels per group. Default is 4.

  • eps (float, optional) – Small constant for numerical stability. Default is 1e-5.

forward(x)[source]

Forward pass of the GroupNorm layer.

Parameters:

x (torch.Tensor) – Input tensor of shape (batch_size, num_channels, height, width).

Returns:

Normalized tensor of same shape as input.

Return type:

torch.Tensor

Notes

The mean and variance are computed over the spatial dimensions and over the channels within each group. Each group is normalized to zero mean and unit variance, after which the learnable per-channel weight and bias are applied.
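A NumPy sketch of the group-count adjustment and of the normalization itself. The rule min(num_groups, num_channels // min_channels_per_group) is an assumption consistent with the parameter descriptions above (it is the rule used in NVIDIA's EDM code); the package's layer calls torch.nn.functional.group_norm rather than code like this.

```python
import numpy as np


def effective_num_groups(num_channels, num_groups=32, min_channels_per_group=4):
    """Cap the group count so each group has at least min_channels_per_group channels."""
    return min(num_groups, num_channels // min_channels_per_group)


def group_norm(x, num_groups, weight, bias, eps=1e-5):
    """Normalize (N, C, H, W) per sample and group, then apply per-channel affine."""
    n, c, h, w = x.shape
    g = x.reshape(n, num_groups, c // num_groups, h, w)
    mean = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    g = (g - mean) / np.sqrt(var + eps)
    return g.reshape(n, c, h, w) * weight[None, :, None, None] + bias[None, :, None, None]
```

For example, 64 channels with the defaults give min(32, 64 // 4) = 16 groups of 4 channels each.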

class IPSL_AID.networks.AttentionOp(*args: Any, **kwargs: Any)[source]

Bases: Function

Custom autograd function for scaled dot-product attention weight computation.

This function computes attention weights using scaled dot-product attention: w = softmax(Q·K^T / √d_k), where d_k is the dimension of the key vectors. It implements both forward and backward passes for gradient computation.

Notes

  • This is a stateless operation that uses torch.autograd.Function for custom backward.

  • The forward pass computes attention weights in float32 for numerical stability.

  • The backward pass computes gradients using the chain rule for softmax and matrix multiplication.

  • This implementation is optimized for memory efficiency during backward pass.

static forward(ctx, q, k)[source]

Forward pass for attention weight computation.

Parameters:
  • ctx (torch.autograd.function.BackwardCFunction) – Context object to save tensors for backward pass.

  • q (torch.Tensor) – Query tensor of shape (batch_size, channels, query_length).

  • k (torch.Tensor) – Key tensor of shape (batch_size, channels, key_length).

Returns:

Attention weights of shape (batch_size, query_length, key_length). Each row represents attention distribution for a query position.

Return type:

torch.Tensor

Notes

  • Computes w = softmax(Q·K^T / √d_k) where d_k = k.shape[1] (channel dimension).

  • Uses float32 for computation to maintain numerical stability.

  • Saves q, k, and w in context for backward pass.

static backward(ctx, dw)[source]

Backward pass for attention weight computation.

Parameters:
  • ctx (torch.autograd.function.BackwardCFunction) – Context object containing saved tensors from forward pass.

  • dw (torch.Tensor) – Gradient of loss with respect to attention weights. Shape: (batch_size, query_length, key_length).

Returns:

  • dq (torch.Tensor) – Gradient with respect to query tensor. Shape: (batch_size, channels, query_length).

  • dk (torch.Tensor) – Gradient with respect to key tensor. Shape: (batch_size, channels, key_length).

Notes

  • Uses the saved tensors q, k, w from forward pass.

  • Computes gradient of softmax using PyTorch’s internal softmax_backward.

  • Applies chain rule for the scaled dot-product operation.

  • Maintains original dtypes of input tensors.
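The forward and backward passes described above can be sketched as a custom autograd function. AttentionOpSketch is a hypothetical name. Instead of PyTorch's internal softmax_backward helper, it uses the explicit softmax gradient ds = w ⊙ (dw − Σ_k dw ⊙ w), which avoids private APIs. The tests compare its gradients against ordinary autograd applied to the same formula.

```python
import math

import torch


class AttentionOpSketch(torch.autograd.Function):
    """Hypothetical re-implementation of w = softmax(Q·K^T / sqrt(d_k)) with a manual backward."""

    @staticmethod
    def forward(ctx, q, k):
        scale = 1.0 / math.sqrt(k.shape[1])  # d_k = channel dimension
        # (n, c, q) x (n, c, k) -> (n, q, k), computed in float32 for stability
        logits = torch.einsum('ncq,nck->nqk', q.to(torch.float32), k.to(torch.float32)) * scale
        w = logits.softmax(dim=2).to(q.dtype)
        ctx.save_for_backward(q, k, w)
        return w

    @staticmethod
    def backward(ctx, dw):
        q, k, w = ctx.saved_tensors
        scale = 1.0 / math.sqrt(k.shape[1])
        w32, dw32 = w.to(torch.float32), dw.to(torch.float32)
        # Softmax vector-Jacobian product along the key axis
        ds = w32 * (dw32 - (dw32 * w32).sum(dim=2, keepdim=True))
        # Chain rule through the scaled dot product
        dq = torch.einsum('nck,nqk->ncq', k.to(torch.float32), ds) * scale
        dk = torch.einsum('ncq,nqk->nck', q.to(torch.float32), ds) * scale
        return dq.to(q.dtype), dk.to(k.dtype)  # restore the input dtypes


q = torch.randn(2, 8, 5, requires_grad=True)  # (batch, channels, query_length)
k = torch.randn(2, 8, 7, requires_grad=True)  # (batch, channels, key_length)
w = AttentionOpSketch.apply(q, k)             # (2, 5, 7); each row sums to 1
```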

class IPSL_AID.networks.UNetBlock(*args: Any, **kwargs: Any)[source]

Bases: Module

U-Net block with optional attention, up/down sampling, and adaptive scaling.

This block implements a residual block commonly used in U-Net architectures for diffusion models and image-to-image translation. It supports:

  • Residual connections with optional skip scaling

  • Adaptive scaling/shifting via conditioning embeddings

  • Multi-head self-attention mechanisms

  • Upsampling or downsampling operations

  • Dropout for regularization

Parameters:
  • in_channels (int) – Number of input channels.

  • out_channels (int) – Number of output channels.

  • emb_channels (int) – Number of embedding (conditioning) channels.

  • up (bool, optional) – If True, upsample the input by a factor of 2. Default is False.

  • down (bool, optional) – If True, downsample the output by a factor of 2. Default is False.

  • attention (bool, optional) – If True, include multi-head self-attention in the block. Default is False.

  • num_heads (int, optional) – Number of attention heads. If None, computed as out_channels // channels_per_head. Default is None.

  • channels_per_head (int, optional) – Number of channels per attention head when num_heads is None. Default is 64.

  • dropout (float, optional) – Dropout probability applied after the second normalization and activation, immediately before the second convolution. Default is 0.

  • skip_scale (float, optional) – Scaling factor applied to the residual connection. Default is 1.

  • eps (float, optional) – Epsilon value for GroupNorm layers for numerical stability. Default is 1e-5.

  • resample_filter (list, optional) – Coefficients for the resampling filter used in up/down sampling. Default is [1, 1].

  • resample_proj (bool, optional) – If True, use a 1x1 convolution in the skip connection when resampling. Default is False.

  • adaptive_scale (bool, optional) – If True, use both scale and shift parameters from the embedding. If False, use only shift parameters. Default is True.

  • init (dict, optional) – Initialization parameters for most convolutional layers. Default is empty dict.

  • init_zero (dict, optional) – Initialization parameters for final convolutional layers (zero initialization). Default is {‘init_weight’: 0}.

  • init_attn (dict, optional) – Initialization parameters for attention layers. If None, uses the same as init. Default is None.

norm0, norm1, norm2

Group normalization layers.

Type:

GroupNorm

conv0, conv1

Convolutional layers.

Type:

Conv2d

affine

Linear layer for conditioning embedding.

Type:

Linear

skip

Skip-connection layer (a 1x1 convolution when a projection is needed), created when the input and output channel counts differ or when resampling; None otherwise.

Type:

Conv2d or None

qkv, proj

Attention query-key-value and projection layers (if attention is enabled).

Type:

Conv2d

Notes

  • The block follows a pre-activation residual structure: norm -> activation -> conv.

  • When adaptive_scale=True, the conditioning embedding provides both scale and shift parameters.

  • The attention mechanism uses multi-head self-attention within the spatial dimensions.

  • A skip-connection layer is created automatically when the input and output channel counts differ or when resampling; otherwise the input is added to the residual branch unchanged.

__init__(in_channels, out_channels, emb_channels, up=False, down=False, attention=False, num_heads=None, channels_per_head=64, dropout=0, skip_scale=1, eps=1e-05, resample_filter=[1, 1], resample_proj=False, adaptive_scale=True, init={}, init_zero={'init_weight': 0}, init_attn=None)[source]

Initialize the UNetBlock.

Parameters:
  • in_channels (int) – Number of input channels.

  • out_channels (int) – Number of output channels.

  • emb_channels (int) – Number of embedding channels.

  • up (bool, optional) – Enable upsampling.

  • down (bool, optional) – Enable downsampling.

  • attention (bool, optional) – Enable attention mechanism.

  • num_heads (int, optional) – Number of attention heads.

  • channels_per_head (int, optional) – Channels per attention head.

  • dropout (float, optional) – Dropout probability.

  • skip_scale (float, optional) – Scaling factor for skip connection.

  • eps (float, optional) – Epsilon for GroupNorm.

  • resample_filter (list, optional) – Filter for resampling.

  • resample_proj (bool, optional) – Use projection in skip connection when resampling.

  • adaptive_scale (bool, optional) – Use adaptive scaling from embedding.

  • init (dict, optional) – Initialization parameters.

  • init_zero (dict, optional) – Zero initialization parameters.

  • init_attn (dict, optional) – Attention initialization parameters.

forward(x, emb)[source]

Forward pass of the UNetBlock.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (batch_size, in_channels, height, width).

  • emb (torch.Tensor) – Conditioning embedding of shape (batch_size, emb_channels).

Returns:

Output tensor of shape (batch_size, out_channels, out_height, out_width).

Return type:

torch.Tensor

Notes

The forward pass consists of:

  1. Initial normalization and convolution (with optional up/down sampling)

  2. Adaptive scaling/shifting from the conditioning embedding

  3. Second normalization, dropout, and convolution

  4. Skip connection with scaling

  5. Optional multi-head self-attention
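Steps 1 to 4 can be sketched as a simplified residual block. ResBlockSketch is hypothetical. It omits attention and up/down sampling, and it uses torch.nn.GroupNorm and torch.nn.Conv2d in place of the package's own layers. Two details are assumptions based on the EDM-style reference design and may differ in this package: the second normalization is applied before the scale/shift, and the scale is applied as (scale + 1).

```python
import torch
import torch.nn.functional as F


class ResBlockSketch(torch.nn.Module):
    """Hypothetical, simplified UNetBlock: no attention and no up/down sampling."""

    def __init__(self, in_channels, out_channels, emb_channels, dropout=0.0,
                 skip_scale=1.0, adaptive_scale=True, eps=1e-5):
        super().__init__()
        self.dropout = dropout
        self.skip_scale = skip_scale
        self.adaptive_scale = adaptive_scale
        self.norm0 = torch.nn.GroupNorm(min(32, in_channels // 4), in_channels, eps=eps)
        self.conv0 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # Embedding -> (scale, shift) if adaptive_scale, else shift only
        self.affine = torch.nn.Linear(emb_channels, out_channels * (2 if adaptive_scale else 1))
        self.norm1 = torch.nn.GroupNorm(min(32, out_channels // 4), out_channels, eps=eps)
        self.conv1 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        # 1x1 projection only when the channel count changes
        self.skip = (torch.nn.Conv2d(in_channels, out_channels, kernel_size=1)
                     if in_channels != out_channels else None)

    def forward(self, x, emb):
        orig = x
        x = self.conv0(F.silu(self.norm0(x)))                            # 1. norm -> act -> conv
        params = self.affine(emb).unsqueeze(2).unsqueeze(3).to(x.dtype)  # (N, C', 1, 1)
        if self.adaptive_scale:                                          # 2. scale and shift
            scale, shift = params.chunk(chunks=2, dim=1)
            x = F.silu(torch.addcmul(shift, self.norm1(x), scale + 1))
        else:                                                            # 2. shift only
            x = F.silu(self.norm1(x + params))
        x = self.conv1(F.dropout(x, p=self.dropout, training=self.training))  # 3. dropout -> conv
        x = x + (self.skip(orig) if self.skip is not None else orig)          # 4. residual add
        return x * self.skip_scale


block = ResBlockSketch(in_channels=16, out_channels=32, emb_channels=64)
y = block(torch.randn(2, 16, 8, 12), torch.randn(2, 64))
print(tuple(y.shape))  # (2, 32, 8, 12)
```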

class IPSL_AID.networks.PositionalEmbedding(*args: Any, **kwargs: Any)[source]

Bases: Module

Sinusoidal positional embedding for sequences or timesteps.

This module generates sinusoidal embeddings for input positions, commonly used in transformer architectures and diffusion models to provide temporal or positional information to the model.

Parameters:
  • num_channels (int) – Dimensionality of the embedding vectors. Must be even.

  • max_positions (int, optional) – Maximum number of positions (or timesteps) for which embeddings are generated. Determines the frequency scaling. Default is 10000.

  • endpoint (bool, optional) – If True, scales the frequency exponents so that the last frequency is exactly 1/max_positions. If False, uses the standard scaling. Default is False.

num_channels

Dimensionality of the embedding vectors.

Type:

int

max_positions

Maximum positions for frequency scaling.

Type:

int

endpoint

Whether to use endpoint scaling.

Type:

bool

Notes

  • The embedding uses sine and cosine functions of different frequencies to create a unique encoding for each position.

  • The frequencies are computed as freqs = (1 / max_positions) ** (2i / num_channels) for i in range(num_channels // 2). With endpoint=True, the exponent becomes i / (num_channels // 2 - 1).

  • The output embedding is the concatenation of [cos(x*freqs), sin(x*freqs)].

  • This implementation is based on the original Transformer positional encoding and the diffusion model timestep embedding.

__init__(num_channels, max_positions=10000, endpoint=False)[source]

Initialize the PositionalEmbedding module.

Parameters:
  • num_channels (int) – Dimensionality of the embedding vectors.

  • max_positions (int, optional) – Maximum number of positions for frequency scaling. Default is 10000.

  • endpoint (bool, optional) – Whether to use endpoint scaling. Default is False.

forward(x)[source]

Forward pass to generate positional embeddings.

Parameters:

x (torch.Tensor) – Input tensor of positions (or timesteps) of shape (batch_size,) or (n,). Values are typically integers in [0, max_positions-1].

Returns:

Positional embeddings of shape (len(x), num_channels).

Return type:

torch.Tensor

Notes

  • The input tensor x is typically a 1D tensor of position indices.

  • The output is a 2D tensor where each row corresponds to the embedding of the respective position.

  • The embedding uses the device and dtype of the input tensor x.
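The embedding described above can be written as a short functional sketch. positional_embedding is a hypothetical name. The endpoint handling, which divides the exponent by num_channels // 2 - 1, is an assumption based on the EDM-style reference code. The package's own PositionalEmbedding may differ in detail.

```python
import torch


def positional_embedding(x, num_channels, max_positions=10000, endpoint=False):
    """Hypothetical sinusoidal embedding returning [cos(x*freqs), sin(x*freqs)]."""
    half = num_channels // 2
    exponents = torch.arange(half, dtype=torch.float32, device=x.device)
    # Standard: exponent i / half; endpoint=True: i / (half - 1), so the
    # last frequency is exactly 1 / max_positions.
    exponents = exponents / (half - (1 if endpoint else 0))
    freqs = (1.0 / max_positions) ** exponents
    angles = torch.outer(x, freqs.to(x.dtype))             # (len(x), half)
    return torch.cat([angles.cos(), angles.sin()], dim=1)  # (len(x), num_channels)


t = torch.tensor([0.0, 1.0, 250.0])
emb = positional_embedding(t, num_channels=8)
print(emb.shape)  # torch.Size([3, 8])
```

Position 0 always maps to a row of ones (cosine half) followed by zeros (sine half), so distinct positions are distinguished only through the sine/cosine phases.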

class IPSL_AID.networks.FourierEmbedding(*args: Any, **kwargs: Any)[source]

Bases: Module

Random Fourier feature embedding for positional encoding.

This module generates random Fourier features (RFF) for input positions or coordinates, mapping low-dimensional inputs to a higher-dimensional space using random frequency sampling. Commonly used in neural fields, kernel methods, and coordinate-based neural networks.

Parameters:
  • num_channels (int) – Dimensionality of the embedding vectors. Must be even.

  • scale (float, optional) – Standard deviation for sampling the random frequencies. Determines the frequency distribution. Default is 16.

freqs

Random frequencies sampled from a normal distribution with mean 0 and standard deviation scale. Shape: (num_channels // 2,).

Type:

torch.Tensor (buffer)

Notes

  • The frequencies are randomly initialized and fixed (non-learnable).

  • The embedding uses sine and cosine projections of the input multiplied by 2π times the random frequencies.

  • This technique approximates shift-invariant kernels via Bochner’s theorem.

  • Unlike learned embeddings, this provides a fixed, deterministic mapping from input space to embedding space.

__init__(num_channels, scale=16)[source]

Initialize the FourierEmbedding module.

Parameters:
  • num_channels (int) – Dimensionality of the embedding vectors.

  • scale (float, optional) – Standard deviation for frequency sampling. Default is 16.

forward(x)[source]

Forward pass to generate Fourier feature embeddings.

Parameters:

x (torch.Tensor) – Input tensor of shape (batch_size,) or (n,). Typically continuous values representing positions or coordinates.

Returns:

Fourier feature embeddings of shape (len(x), num_channels).

Return type:

torch.Tensor

Notes

  • The transformation is: x ↦ [cos(2π * freqs * x), sin(2π * freqs * x)].

  • The output dimension, num_channels, is twice the number of frequencies.

  • This embedding is deterministic given the fixed random frequencies.
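The following sketch shows the mapping x ↦ [cos(2π·freqs·x), sin(2π·freqs·x)]. FourierEmbeddingSketch is a hypothetical name. It stores the frequencies as a registered buffer, so they move with the module across devices but are not trained.

```python
import math

import torch


class FourierEmbeddingSketch(torch.nn.Module):
    """Hypothetical random Fourier feature embedding (illustration only)."""

    def __init__(self, num_channels, scale=16):
        super().__init__()
        # Fixed, non-learnable frequencies drawn once from N(0, scale^2)
        self.register_buffer('freqs', torch.randn(num_channels // 2) * scale)

    def forward(self, x):
        angles = torch.outer(x, (2 * math.pi * self.freqs).to(x.dtype))  # (len(x), num_channels // 2)
        return torch.cat([angles.cos(), angles.sin()], dim=1)             # (len(x), num_channels)


torch.manual_seed(0)
embed = FourierEmbeddingSketch(num_channels=16)
out = embed(torch.rand(5))
print(tuple(out.shape))  # (5, 16)
```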

class IPSL_AID.networks.SongUNet(*args: Any, **kwargs: Any)[source]

Bases: Module

U-Net architecture for diffusion models based on Song et al. (2020).

This implementation supports both DDPM++ and NCSN++ architectures with configurable encoder/decoder types, attention mechanisms, and conditioning. It handles both square and rectangular input resolutions.

Parameters:
  • img_resolution (int or tuple) – Input image resolution. If int, assumes square images (img_resolution x img_resolution). If tuple, should be (height, width).

  • in_channels (int) – Number of input color channels.

  • out_channels (int) – Number of output color channels.

  • label_dim (int, optional) – Number of class labels. Set to 0 for unconditional generation. Default is 0.

  • augment_dim (int, optional) – Dimensionality of augmentation labels (e.g., time-dependent augmentation). Set to 0 for no augmentation. Default is 0.

  • model_channels (int, optional) – Base channel multiplier for the network. Default is 128.

  • channel_mult (list of int, optional) – Channel multipliers for each resolution level. Default is [1, 2, 2, 2].

  • channel_mult_emb (int, optional) – Multiplier for embedding dimensionality relative to model_channels. Default is 4.

  • num_blocks (int, optional) – Number of residual blocks per resolution. Default is 4.

  • attn_resolutions (list of int, optional) – List of resolutions (minimum dimension) to apply self-attention. Default is [16].

  • dropout (float, optional) – Dropout probability for intermediate activations. Default is 0.10.

  • label_dropout (float, optional) – Dropout probability for class labels (classifier-free guidance). Default is 0.

  • embedding_type (str, optional) – Type of timestep embedding: ‘positional’ for DDPM++, ‘fourier’ for NCSN++. Default is ‘positional’.

  • channel_mult_noise (int, optional) – Multiplier for noise embedding dimensionality: 1 for DDPM++, 2 for NCSN++. Default is 1.

  • encoder_type (str, optional) – Encoder architecture: ‘standard’ for DDPM++, ‘skip’ or ‘residual’ for NCSN++. Default is ‘standard’.

  • decoder_type (str, optional) – Decoder architecture: ‘standard’ (used by both DDPM++ and NCSN++) or ‘skip’. Default is ‘standard’.

  • resample_filter (list, optional) – Resampling filter coefficients: [1,1] for DDPM++, [1,3,3,1] for NCSN++. Default is [1,1].

img_resolution

Input image resolution as (height, width).

Type:

tuple

img_height

Input image height.

Type:

int

img_width

Input image width.

Type:

int

label_dropout

Class label dropout probability.

Type:

float

map_noise

Noise/timestep embedding module.

Type:

PositionalEmbedding or FourierEmbedding

map_label

Class label embedding module.

Type:

Linear or None

map_augment

Augmentation label embedding module.

Type:

Linear or None

map_layer0, map_layer1

Embedding transformation layers.

Type:

Linear

enc

Encoder modules organized by resolution.

Type:

torch.nn.ModuleDict

dec

Decoder modules organized by resolution.

Type:

torch.nn.ModuleDict

Raises:

AssertionError – If embedding_type is not ‘fourier’ or ‘positional’, if encoder_type is not ‘standard’, ‘skip’, or ‘residual’, if decoder_type is not ‘standard’ or ‘skip’, or if an img_resolution tuple does not have exactly 2 elements.

Notes

  • The architecture follows a U-Net structure with skip connections.

  • Supports multiple conditioning types: noise (timestep), class labels, augmentations.

  • Attention is applied at specified resolutions to capture long-range dependencies.

  • Different encoder/decoder types and embedding methods allow emulating DDPM++ or NCSN++.

  • Rectangular resolutions are supported by tracking height and width separately.
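To illustrate how attn_resolutions applies to a rectangular input, the hypothetical helper below lists the spatial size at each level and whether attention would be used there. It assumes that each level after the first halves both height and width. As stated above, the "resolution" compared against attn_resolutions is min(height, width). The helper does not build the network.

```python
def attention_levels(img_resolution, channel_mult, attn_resolutions):
    """Hypothetical helper: (level, height, width, uses_attention) for each U-Net level."""
    if isinstance(img_resolution, int):
        img_resolution = (img_resolution, img_resolution)  # square input
    assert len(img_resolution) == 2, "img_resolution must be an int or (height, width)"
    height, width = img_resolution
    levels = []
    for level in range(len(channel_mult)):
        if level > 0:
            height, width = height // 2, width // 2  # assumed 2x downsampling per level
        # Attention is keyed on the smaller spatial dimension
        levels.append((level, height, width, min(height, width) in attn_resolutions))
    return levels


for row in attention_levels((64, 128), channel_mult=[1, 2, 2, 2], attn_resolutions=[16]):
    print(row)
# (0, 64, 128, False)
# (1, 32, 64, False)
# (2, 16, 32, True)
# (3, 8, 16, False)
```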

References

  • Ho et al., “Denoising Diffusion Probabilistic Models”, 2020 (DDPM).

  • Song et al., “Score-Based Generative Modeling through Stochastic Differential Equations”, 2020 (NCSN++).

__init__(img_resolution, in_channels, out_channels, label_dim=0, augment_dim=0, model_channels=128, channel_mult=[1, 2, 2, 2], channel_mult_emb=4, num_blocks=4, attn_resolutions=[16], dropout=0.1, label_dropout=0, embedding_type='positional', channel_mult_noise=1, encoder_type='standard', decoder_type='standard', resample_filter=[1, 1])[source]

Initialize the SongUNet.

Parameters:
  • img_resolution (int or tuple) – Image resolution.

  • in_channels (int) – Input channels.

  • out_channels (int) – Output channels.

  • label_dim (int, optional) – Class label dimension.

  • augment_dim (int, optional) – Augmentation label dimension.

  • model_channels (int, optional) – Base channel multiplier.

  • channel_mult (list, optional) – Channel multipliers per resolution.

  • channel_mult_emb (int, optional) – Embedding channel multiplier.

  • num_blocks (int, optional) – Blocks per resolution.

  • attn_resolutions (list, optional) – Resolutions for attention.

  • dropout (float, optional) – Dropout probability.

  • label_dropout (float, optional) – Label dropout probability.

  • embedding_type (str, optional) – Embedding type.

  • channel_mult_noise (int, optional) – Noise embedding multiplier.

  • encoder_type (str, optional) – Encoder type.

  • decoder_type (str, optional) – Decoder type.

  • resample_filter (list, optional) – Resampling filter coefficients.

forward(x, noise_labels, class_labels, augment_labels=None)[source]

Forward pass through the U-Net.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (batch_size, in_channels, height, width).

  • noise_labels (torch.Tensor) – Noise/timestep labels of shape (batch_size,).

  • class_labels (torch.Tensor or None) – Class labels of shape (batch_size,) or (batch_size, label_dim). Can be None if label_dim is 0.

  • augment_labels (torch.Tensor or None, optional) – Augmentation labels of shape (batch_size, augment_dim). Can be None if augment_dim is 0.

Returns:

Output tensor of shape (batch_size, out_channels, height, width).

Return type:

torch.Tensor

Notes

  • The forward pass consists of three main stages:

      1. Embedding mapping: combines noise, class, and augmentation embeddings.

      2. Encoder: extracts hierarchical features with optional skip connections.

      3. Decoder: reconstructs the output using skip connections from the encoder.

  • Classifier-free guidance is supported via label_dropout.

  • The noise embedding uses sinusoidal (positional) or Fourier features.

class IPSL_AID.networks.DhariwalUNet(*args: Any, **kwargs: Any)[source]

Bases: Module

U-Net architecture based on Dhariwal & Nichol (2021) for diffusion models.

This implementation follows the ADM (Ablated Diffusion Model) architecture with configurable attention mechanisms, conditioning, and rectangular resolution support. It features a U-Net structure with skip connections, group normalization, and optional conditioning on class labels and augmentation.

Parameters:
  • img_resolution (int or tuple) – Input image resolution. If int, assumes square images (img_resolution x img_resolution). If tuple, should be (height, width).

  • in_channels (int) – Number of input color channels.

  • out_channels (int) – Number of output color channels.

  • label_dim (int, optional) – Number of class labels. Set to 0 for unconditional generation. Default is 0.

  • augment_dim (int, optional) – Dimensionality of augmentation labels (e.g., time-dependent augmentation). Set to 0 for no augmentation. Default is 0.

  • model_channels (int, optional) – Base channel multiplier for the network. Default is 128.

  • channel_mult (list of int, optional) – Channel multipliers for each resolution level. Default is [1, 2, 3, 4].

  • channel_mult_emb (int, optional) – Multiplier for embedding dimensionality relative to model_channels. Default is 4.

  • num_blocks (int, optional) – Number of residual blocks per resolution. Default is 3.

  • attn_resolutions (list of int, optional) – List of resolutions (minimum dimension) to apply self-attention. Default is [32, 16, 8].

  • dropout (float, optional) – Dropout probability for intermediate activations. Default is 0.10.

  • label_dropout (float, optional) – Dropout probability for class labels (classifier-free guidance). Default is 0.

  • diffusion_model (bool, optional) – Whether to configure the network for diffusion models. If True, includes timestep embedding; if False, only uses label conditioning. Default is True.

img_resolution

Input image resolution as (height, width).

Type:

tuple

img_height

Input image height.

Type:

int

img_width

Input image width.

Type:

int

label_dropout

Class label dropout probability.

Type:

float

map_noise

Noise/timestep embedding module (if diffusion_model=True).

Type:

PositionalEmbedding or None

map_label

Class label embedding module.

Type:

Linear or None

map_augment

Augmentation label embedding module.

Type:

Linear or None

map_layer0, map_layer1

Embedding transformation layers.

Type:

Linear

enc

Encoder modules organized by resolution.

Type:

torch.nn.ModuleDict

dec

Decoder modules organized by resolution.

Type:

torch.nn.ModuleDict

out_norm

Final group normalization layer.

Type:

GroupNorm

out_conv

Final convolutional output layer.

Type:

Conv2d

Raises:

AssertionError – If img_resolution tuple doesn’t have exactly 2 elements.

Notes

  • The architecture is based on the U-Net from “Diffusion Models Beat GANs on Image Synthesis”.

  • Features group normalization throughout and attention at multiple resolutions.

  • Supports classifier-free guidance via label_dropout.

  • Can be used for both diffusion models and other conditional generation tasks.

  • Rectangular resolutions are supported by tracking height and width separately.

References

  • Dhariwal & Nichol, “Diffusion Models Beat GANs on Image Synthesis”, 2021.

__init__(img_resolution, in_channels, out_channels, label_dim=0, augment_dim=0, model_channels=128, channel_mult=[1, 2, 3, 4], channel_mult_emb=4, num_blocks=3, attn_resolutions=[32, 16, 8], dropout=0.1, label_dropout=0, diffusion_model=True)[source]

Initialize the DhariwalUNet.

Parameters:
  • img_resolution (int or tuple) – Image resolution.

  • in_channels (int) – Input channels.

  • out_channels (int) – Output channels.

  • label_dim (int, optional) – Class label dimension.

  • augment_dim (int, optional) – Augmentation label dimension.

  • model_channels (int, optional) – Base channel multiplier.

  • channel_mult (list, optional) – Channel multipliers per resolution.

  • channel_mult_emb (int, optional) – Embedding channel multiplier.

  • num_blocks (int, optional) – Blocks per resolution.

  • attn_resolutions (list, optional) – Resolutions for attention.

  • dropout (float, optional) – Dropout probability.

  • label_dropout (float, optional) – Label dropout probability.

  • diffusion_model (bool, optional) – Whether to configure for diffusion models.

forward(x, noise_labels=None, class_labels=None, augment_labels=None)[source]

Forward pass through the Dhariwal U-Net.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (batch_size, in_channels, height, width).

  • noise_labels (torch.Tensor or None) – Noise/timestep labels of shape (batch_size,). Required if diffusion_model=True, otherwise optional.

  • class_labels (torch.Tensor or None) – Class labels of shape (batch_size,) or (batch_size, label_dim). Can be None if label_dim is 0.

  • augment_labels (torch.Tensor or None, optional) – Augmentation labels of shape (batch_size, augment_dim). Can be None if augment_dim is 0.

Returns:

Output tensor of shape (batch_size, out_channels, height, width).

Return type:

torch.Tensor

Notes

  • The forward pass combines conditioning embeddings (noise, class, augmentation) and processes through encoder-decoder with skip connections.

  • When diffusion_model=False, the noise_labels can be omitted.

  • Classifier-free guidance is implemented via label_dropout during training.
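A minimal sketch of a common way to implement label dropout for classifier-free guidance. During training, each sample's entire label vector is zeroed with probability label_dropout, so the network also learns an unconditional mode. drop_labels is a hypothetical helper. Applying the mask to the raw label vector, rather than to its embedding, is an assumption.

```python
import torch


def drop_labels(class_labels, label_dropout, training=True, generator=None):
    """Hypothetical helper: zero each sample's label vector with probability label_dropout."""
    if not training or label_dropout == 0:
        return class_labels
    # One uniform draw per sample; rows whose draw is < label_dropout become all-zero
    keep = torch.rand([class_labels.shape[0], 1], generator=generator,
                      device=class_labels.device) >= label_dropout
    return class_labels * keep.to(class_labels.dtype)


labels = torch.nn.functional.one_hot(torch.arange(1000) % 10, num_classes=10).float()
dropped = drop_labels(labels, label_dropout=0.3, generator=torch.Generator().manual_seed(0))
print(dropped.shape)  # torch.Size([1000, 10])
```

At sampling time the same network can then be evaluated with and without labels, and the two outputs combined for guidance.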

class IPSL_AID.networks.VPPrecond(*args: Any, **kwargs: Any)[source]

Bases: Module

Variance Preserving (VP) preconditioning for diffusion models.

This class implements preconditioning for the Variance Preserving formulation of diffusion models, as described in Song et al. (2020). It wraps a base U-Net model and applies the appropriate scaling and conditioning for VP SDEs.

Parameters:
  • img_resolution (int or tuple) – Input image resolution. If int, assumes square images. If tuple, should be (height, width).

  • in_channels (int) – Number of input color channels.

  • out_channels (int) – Number of output color channels.

  • label_dim (int, optional) – Number of class labels. Set to 0 for unconditional generation. Default is 0.

  • use_fp16 (bool, optional) – Whether to execute the underlying model at FP16 precision for speed. Default is False.

  • beta_d (float, optional) – Extent of the noise level schedule. Controls the rate of noise increase. Default is 19.9.

  • beta_min (float, optional) – Initial slope of the noise level schedule. Default is 0.1.

  • M (int, optional) – Original number of timesteps in the DDPM formulation. Default is 1000.

  • epsilon_t (float, optional) – Minimum t-value used during training. Prevents numerical issues. Default is 1e-5.

  • model_type (str, optional) – Class name of the underlying U-Net model (‘SongUNet’ or ‘DhariwalUNet’). Default is ‘SongUNet’.

  • **model_kwargs (dict) – Additional keyword arguments passed to the underlying model.

img_resolution

Input image resolution as (height, width).

Type:

tuple

in_channels

Number of input channels.

Type:

int

out_channels

Number of output channels.

Type:

int

label_dim

Number of class labels.

Type:

int

use_fp16

Whether to use FP16 precision.

Type:

bool

sigma_min

Minimum noise level (sigma) corresponding to epsilon_t.

Type:

float

sigma_max

Maximum noise level (sigma) corresponding to t=1.

Type:

float

model

The underlying U-Net model.

Type:

torch.nn.Module

Notes

  • In the underlying VP SDE, the perturbed data keep unit variance throughout the diffusion process when the clean data have unit variance.

  • The noise schedule follows: σ(t) = sqrt(exp(0.5*β_d*t² + β_min*t) - 1)

  • The preconditioning applies scaling factors: c_skip, c_out, c_in, c_noise

  • Supports conditional generation via class labels and condition images.

  • Implements the continuous-time formulation of diffusion models.

References

  • Song et al., “Score-Based Generative Modeling through Stochastic Differential Equations”, 2020.

__init__(img_resolution, in_channels, out_channels, label_dim=0, use_fp16=False, beta_d=19.9, beta_min=0.1, M=1000, epsilon_t=1e-05, model_type='SongUNet', **model_kwargs)[source]

Initialize the VPPrecond module.

Parameters:
  • img_resolution (int or tuple) – Image resolution.

  • in_channels (int) – Input channels.

  • out_channels (int) – Output channels.

  • label_dim (int, optional) – Class label dimension.

  • use_fp16 (bool, optional) – Use FP16 precision.

  • beta_d (float, optional) – Noise schedule extent.

  • beta_min (float, optional) – Initial noise schedule slope.

  • M (int, optional) – Number of timesteps.

  • epsilon_t (float, optional) – Minimum t-value.

  • model_type (str, optional) – Underlying model class name.

  • **model_kwargs (dict) – Additional model arguments.

forward(x, sigma, condition_img=None, class_labels=None, force_fp32=False, **model_kwargs)[source]

Forward pass with VP preconditioning.

Parameters:
  • x (torch.Tensor) – Input noisy tensor of shape (batch_size, in_channels, height, width).

  • sigma (torch.Tensor) – Noise level(s) of shape (batch_size,) or scalar.

  • condition_img (torch.Tensor, optional) – Condition image tensor of same spatial dimensions as x. Default is None.

  • class_labels (torch.Tensor, optional) – Class labels for conditioning of shape (batch_size,) or (batch_size, label_dim). Default is None.

  • force_fp32 (bool, optional) – Force FP32 precision even if use_fp16 is True. Default is False.

  • **model_kwargs (dict) – Additional keyword arguments passed to the underlying model.

Returns:

Denoised output of shape (batch_size, out_channels, height, width).

Return type:

torch.Tensor

Notes

  • Applies the preconditioning D(x) = c_skip * x + c_out * F(c_in * x, c_noise), where F is the underlying U-Net model.

  • c_in, c_out, c_skip, c_noise are computed from sigma according to VP formulation.

  • Condition images are concatenated along the channel dimension.

sigma(t)[source]

Compute noise level sigma for given time t.

Parameters:

t (float or torch.Tensor) – Time value(s) in [epsilon_t, 1].

Returns:

Noise level sigma corresponding to t.

Return type:

torch.Tensor

Notes

Formula: σ(t) = sqrt(exp(0.5*β_d*t² + β_min*t) - 1)

sigma_inv(sigma)[source]

Inverse function: compute time t for given noise level sigma.

Parameters:

sigma (float or torch.Tensor) – Noise level(s).

Returns:

Time t corresponding to sigma.

Return type:

torch.Tensor

Notes

Formula: t = (sqrt(β_min² + 2*β_d*log(1+σ²)) - β_min) / β_d
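The two formulas are inverses of each other, and a numerical check confirms this. The sketch below uses hypothetical standalone functions, vp_sigma and vp_sigma_inv, with the default beta_d and beta_min. It evaluates in float64 and uses expm1/log1p to keep σ accurate near t = epsilon_t. This is an implementation choice and not necessarily the package's. The sketch also shows how sigma_min and sigma_max follow from t = epsilon_t and t = 1.

```python
import torch


def vp_sigma(t, beta_d=19.9, beta_min=0.1):
    """sigma(t) = sqrt(exp(0.5*beta_d*t^2 + beta_min*t) - 1), using expm1 for accuracy."""
    t = torch.as_tensor(t, dtype=torch.float64)
    return torch.expm1(0.5 * beta_d * t ** 2 + beta_min * t).sqrt()


def vp_sigma_inv(sigma, beta_d=19.9, beta_min=0.1):
    """t = (sqrt(beta_min^2 + 2*beta_d*log(1 + sigma^2)) - beta_min) / beta_d."""
    sigma = torch.as_tensor(sigma, dtype=torch.float64)
    return ((beta_min ** 2 + 2 * beta_d * torch.log1p(sigma ** 2)).sqrt() - beta_min) / beta_d


epsilon_t = 1e-5
sigma_min = float(vp_sigma(epsilon_t))  # noise level at t = epsilon_t (about 1e-3)
sigma_max = float(vp_sigma(1.0))        # noise level at t = 1 (about 152)
t = torch.linspace(epsilon_t, 1.0, 11, dtype=torch.float64)
print(torch.allclose(vp_sigma_inv(vp_sigma(t)), t))  # True
```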

round_sigma(sigma)[source]

Round noise level(s) for compatibility with discrete schedules.

Parameters:

sigma (float or torch.Tensor) – Noise level(s).

Returns:

Rounded noise level(s).

Return type:

torch.Tensor

class IPSL_AID.networks.VEPrecond(*args: Any, **kwargs: Any)[source]

Bases: Module

Variance Exploding (VE) preconditioning for diffusion models.

This class implements preconditioning for the Variance Exploding formulation of diffusion models, as described in Song et al. (2020). It wraps a base U-Net model and applies the appropriate scaling and conditioning for VE SDEs.

Parameters:
  • img_resolution (int or tuple) – Input image resolution. If int, assumes square images. If tuple, should be (height, width).

  • in_channels (int) – Number of input color channels.

  • out_channels (int) – Number of output color channels.

  • label_dim (int, optional) – Number of class labels. Set to 0 for unconditional generation. Default is 0.

  • use_fp16 (bool, optional) – Whether to execute the underlying model at FP16 precision for speed. Default is False.

  • sigma_min (float, optional) – Minimum supported noise level. Default is 0.02.

  • sigma_max (float, optional) – Maximum supported noise level. Default is 100.

  • model_type (str, optional) – Class name of the underlying U-Net model (‘SongUNet’ or ‘DhariwalUNet’). Default is ‘SongUNet’.

  • **model_kwargs (dict) – Additional keyword arguments passed to the underlying model.

img_resolution

Input image resolution as (height, width).

Type:

tuple

in_channels

Number of input channels.

Type:

int

out_channels

Number of output channels.

Type:

int

label_dim

Number of class labels.

Type:

int

use_fp16

Whether to use FP16 precision.

Type:

bool

sigma_min

Minimum noise level.

Type:

float

sigma_max

Maximum noise level.

Type:

float

model

The underlying U-Net model.

Type:

torch.nn.Module

Notes

  • The VE formulation uses a simple geometric noise schedule.

  • The preconditioning applies scaling factors: c_skip, c_out, c_in, c_noise

  • c_noise = 0.5 * log(sigma) maps noise levels to conditioning inputs.

  • Supports conditional generation via class labels and condition images.

References

  • Song et al., “Score-Based Generative Modeling through Stochastic Differential Equations”, 2020.

__init__(img_resolution, in_channels, out_channels, label_dim=0, use_fp16=False, sigma_min=0.02, sigma_max=100, model_type='SongUNet', **model_kwargs)[source]

Initialize the VEPrecond module.

Parameters:
  • img_resolution (int or tuple) – Image resolution.

  • in_channels (int) – Input channels.

  • out_channels (int) – Output channels.

  • label_dim (int, optional) – Class label dimension.

  • use_fp16 (bool, optional) – Use FP16 precision.

  • sigma_min (float, optional) – Minimum noise level.

  • sigma_max (float, optional) – Maximum noise level.

  • model_type (str, optional) – Underlying model class name.

  • **model_kwargs (dict) – Additional model arguments.

forward(x, sigma, condition_img=None, class_labels=None, force_fp32=False, **model_kwargs)[source]

Forward pass with VE preconditioning.

Parameters:
  • x (torch.Tensor) – Input noisy tensor of shape (batch_size, in_channels, height, width).

  • sigma (torch.Tensor) – Noise level(s) of shape (batch_size,) or scalar.

  • condition_img (torch.Tensor, optional) – Condition image tensor of same spatial dimensions as x. Default is None.

  • class_labels (torch.Tensor, optional) – Class labels for conditioning of shape (batch_size,) or (batch_size, label_dim). Default is None.

  • force_fp32 (bool, optional) – Force FP32 precision even if use_fp16 is True. Default is False.

  • **model_kwargs (dict) – Additional keyword arguments passed to the underlying model.

Returns:

Denoised output of shape (batch_size, out_channels, height, width).

Return type:

torch.Tensor

Notes

  • Applies the preconditioning: D(x) = c_skip * x + c_out * F(c_in * x, c_noise)

  • Where F is the underlying U-Net model.

  • For VE: c_skip = 1, c_out = sigma, c_in = 1, c_noise = log(0.5 * sigma)

  • Condition images are concatenated along the channel dimension.
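A minimal pure-Python sketch of the VE preconditioning arithmetic, with tensors replaced by flat lists of floats. `zero_net` is a hypothetical stand-in for the U-Net F, and c_noise uses log(σ/2), the VE form given in Karras et al. (2022, Table 1); check it against IPSL_AID.networks.VEPrecond before relying on it.

```python
import math
from typing import Callable, List, Tuple


def ve_coefficients(sigma: float) -> Tuple[float, float, float, float]:
    """Return (c_skip, c_out, c_in, c_noise) for VE preconditioning."""
    c_skip = 1.0
    c_out = sigma
    c_in = 1.0
    c_noise = math.log(0.5 * sigma)  # log(sigma / 2), as in Karras et al. (2022)
    return c_skip, c_out, c_in, c_noise


def precondition(x: List[float], sigma: float,
                 F: Callable[[List[float], float], List[float]]) -> List[float]:
    """D(x) = c_skip * x + c_out * F(c_in * x, c_noise), applied elementwise."""
    c_skip, c_out, c_in, c_noise = ve_coefficients(sigma)
    f_out = F([c_in * v for v in x], c_noise)
    return [c_skip * v + c_out * f for v, f in zip(x, f_out)]


# Hypothetical stand-in for the U-Net: always predicts zero.
def zero_net(x_in: List[float], c_noise: float) -> List[float]:
    return [0.0] * len(x_in)


print(precondition([1.0, -2.0, 0.5], sigma=10.0, F=zero_net))  # [1.0, -2.0, 0.5]
```

Because c_skip = 1 in VE, a network that outputs zero leaves the input unchanged; everything the denoiser removes has to come from c_out * F, which scales with sigma.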

round_sigma(sigma)[source]

Round noise level(s) for compatibility with discrete schedules.

Parameters:

sigma (float or torch.Tensor) – Noise level(s).

Returns:

Rounded noise level(s).

Return type:

torch.Tensor

class IPSL_AID.networks.iDDPMPrecond(*args: Any, **kwargs: Any)[source]

Bases: Module

Improved DDPM (iDDPM) preconditioning for diffusion models.

This class implements the improved preconditioning scheme from the iDDPM paper, which refines the noise schedule and preconditioning for better sample quality. It provides a bridge between discrete-time DDPM formulations and continuous-time SDE formulations.

Parameters:
  • img_resolution (int or tuple) – Input image resolution. If int, assumes square images. If tuple, should be (height, width).

  • in_channels (int) – Number of input color channels.

  • out_channels (int) – Number of output color channels.

  • label_dim (int, optional) – Number of class labels. Set to 0 for unconditional generation. Default is 0.

  • use_fp16 (bool, optional) – Whether to execute the underlying model at FP16 precision for speed. Default is False.

  • C_1 (float, optional) – Timestep adjustment parameter for high noise levels. Default is 0.001.

  • C_2 (float, optional) – Timestep adjustment parameter for low noise levels; it is the offset that appears in the ᾱ(j) schedule. Default is 0.008.

  • M (int, optional) – Original number of timesteps in the DDPM formulation. Default is 1000.

  • model_type (str, optional) – Class name of the underlying U-Net model (‘SongUNet’ or ‘DhariwalUNet’). Default is ‘DhariwalUNet’.

  • **model_kwargs (dict) – Additional keyword arguments passed to the underlying model.

img_resolution

Input image resolution as (height, width).

Type:

tuple

in_channels

Number of input channels.

Type:

int

out_channels

Number of output channels.

Type:

int

label_dim

Number of class labels.

Type:

int

use_fp16

Whether to use FP16 precision.

Type:

bool

sigma_min

Minimum noise level, taken from the precomputed noise schedule u.

Type:

float

sigma_max

Maximum noise level, taken from the precomputed noise schedule u.

Type:

float

u

Noise schedule of length M+1, precomputed at initialization.

Type:

torch.Tensor (buffer)

model

The underlying U-Net model.

Type:

torch.nn.Module

Notes

  • The iDDPM formulation improves upon DDPM with a refined noise schedule.

  • The noise schedule is computed once during initialization by backward recursion; it is not trained.

  • Uses alpha_bar schedule: ᾱ(j) = sin(π/2 * j/M/(C₂+1))²

  • Implements discrete-time preconditioning with improved numerical stability.
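A minimal pure-Python sketch of the schedule construction, assuming it follows the NVlabs EDM reference implementation: starting from u[M] = 0, each step computes u[j−1] = √((u[j]² + 1) / max(ᾱ(j−1)/ᾱ(j), C_1) − 1). In that reference, sigma_max = u[0] and sigma_min = u[M−1]; treat both, and the role of C_1 as a clip bound, as assumptions about IPSL_AID.

```python
import math
from typing import List


def alpha_bar(j: float, M: int = 1000, C_2: float = 0.008) -> float:
    """alpha_bar(j) = sin(pi/2 * j / M / (C_2 + 1)) ** 2."""
    return math.sin(0.5 * math.pi * j / M / (C_2 + 1)) ** 2


def iddpm_schedule(M: int = 1000, C_1: float = 0.001, C_2: float = 0.008) -> List[float]:
    """Build u[0..M] by backward recursion from u[M] = 0 (u decreases with j)."""
    u = [0.0] * (M + 1)
    for j in range(M, 0, -1):  # j = M, ..., 1
        ratio = alpha_bar(j - 1, M, C_2) / alpha_bar(j, M, C_2)
        ratio = max(ratio, C_1)  # clip: alpha_bar(0) = 0 would otherwise divide by zero
        u[j - 1] = math.sqrt((u[j] ** 2 + 1) / ratio - 1)
    return u


u = iddpm_schedule()
sigma_min, sigma_max = u[-2], u[0]  # u[M-1] and u[0]
print(f"sigma_min={sigma_min:.4f}, sigma_max={sigma_max:.1f}")
```

With the default parameters, u[M−1] comes out at roughly 0.0064. C_2 sets how fast ᾱ saturates near j = M and therefore fixes that smallest level, while C_1 only takes effect at j = 1 and so bounds u[0].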

References

  • Nichol & Dhariwal, “Improved Denoising Diffusion Probabilistic Models”, 2021.

__init__(img_resolution, in_channels, out_channels, label_dim=0, use_fp16=False, C_1=0.001, C_2=0.008, M=1000, model_type='DhariwalUNet', **model_kwargs)[source]

Initialize the iDDPMPrecond module.

Parameters:
  • img_resolution (int or tuple) – Image resolution.

  • in_channels (int) – Input channels.

  • out_channels (int) – Output channels.

  • label_dim (int, optional) – Class label dimension.

  • use_fp16 (bool, optional) – Use FP16 precision.

  • C_1 (float, optional) – High-noise adjustment.

  • C_2 (float, optional) – Low-noise adjustment.

  • M (int, optional) – Number of timesteps.

  • model_type (str, optional) – Underlying model class name.

  • **model_kwargs (dict) – Additional model arguments.

forward(x, sigma, condition_img=None, class_labels=None, force_fp32=False, **model_kwargs)[source]

Forward pass with iDDPM preconditioning.

Parameters:
  • x (torch.Tensor) – Input noisy tensor of shape (batch_size, in_channels, height, width).

  • sigma (torch.Tensor) – Noise level(s) of shape (batch_size,) or scalar.

  • condition_img (torch.Tensor, optional) – Condition image tensor of same spatial dimensions as x. Default is None.

  • class_labels (torch.Tensor, optional) – Class labels for conditioning of shape (batch_size,) or (batch_size, label_dim). Default is None.

  • force_fp32 (bool, optional) – Force FP32 precision even if use_fp16 is True. Default is False.

  • **model_kwargs (dict) – Additional keyword arguments passed to the underlying model.

Returns:

Denoised output of shape (batch_size, out_channels, height, width).

Return type:

torch.Tensor

Notes

  • Applies the preconditioning: D(x) = c_skip * x + c_out * F(c_in * x, c_noise)

  • Where F is the underlying U-Net model.

  • For iDDPM: c_skip = 1, c_out = -σ, c_in = 1/√(σ²+1)

  • Condition images are concatenated along the channel dimension.

  • c_noise maps sigma to discrete timesteps for the underlying model.
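The negative c_out is what makes this the DDPM ε-prediction form: F estimates the noise, and D(x) = x − σ·F. The pure-Python sketch below illustrates that with a hypothetical "perfect" noise prediction passed in directly in place of F(c_in * x, c_noise); c_noise is left out here.

```python
import math
from typing import List, Tuple


def iddpm_coefficients(sigma: float) -> Tuple[float, float, float]:
    """Return (c_skip, c_out, c_in) for iDDPM preconditioning."""
    return 1.0, -sigma, 1.0 / math.sqrt(sigma ** 2 + 1.0)


def denoise(x: List[float], sigma: float, eps_pred: List[float]) -> List[float]:
    """D(x) = c_skip * x + c_out * F, with F given as a precomputed noise estimate."""
    c_skip, c_out, _ = iddpm_coefficients(sigma)
    return [c_skip * v + c_out * e for v, e in zip(x, eps_pred)]


clean, noise, sigma = [0.3, -1.2], [1.0, 0.5], 2.0
noisy = [c + sigma * n for c, n in zip(clean, noise)]
# If the network predicted the noise exactly, D(x) recovers the clean signal.
print(denoise(noisy, sigma, noise))  # approximately [0.3, -1.2]
```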

alpha_bar(j)[source]

Compute alpha_bar for timestep j in the improved schedule.

Parameters:

j (int or torch.Tensor) – Timestep index (0 <= j <= M).

Returns:

ᾱ(j) = sin(π/2 * j/M/(C₂+1))²

Return type:

torch.Tensor

round_sigma(sigma, return_index=False)[source]

Round noise level(s) to the nearest value in the precomputed schedule u.

Parameters:
  • sigma (torch.Tensor) – Noise level(s).

  • return_index (bool, optional) – If True, return the index in the schedule instead of the value. Default is False.

Returns:

Rounded noise level(s) or indices.

Return type:

torch.Tensor
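A pure-Python sketch of nearest-value rounding against a schedule, plus the discrete timestep it implies. The formula c_noise = M − 1 − index is the one used for iDDPM in Karras et al. (2022, Table 1); that IPSL_AID computes c_noise this way is an assumption. The toy schedule below is hypothetical and only mimics the decreasing shape of the real u.

```python
from typing import List, Union


def round_sigma(sigma: float, u: List[float],
                return_index: bool = False) -> Union[float, int]:
    """Return the entry of u closest to sigma, or its index if return_index is True."""
    index = min(range(len(u)), key=lambda j: abs(u[j] - sigma))
    return index if return_index else u[index]


def iddpm_c_noise(sigma: float, u: List[float]) -> float:
    """Map sigma to a DDPM timestep: c_noise = M - 1 - nearest index."""
    M = len(u) - 1
    return float(M - 1 - round_sigma(sigma, u, return_index=True))


u_toy = [8.0, 4.0, 2.0, 1.0, 0.0]  # hypothetical schedule, M = 4
print(round_sigma(3.1, u_toy), iddpm_c_noise(3.1, u_toy))  # 4.0 2.0
```

Large sigma maps to a large timestep (u[0] gives M − 1) and sigma_min = u[M−1] maps to timestep 0, matching the DDPM convention.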

class IPSL_AID.networks.EDMPrecond(*args: Any, **kwargs: Any)[source]

Bases: Module

EDM preconditioning for diffusion models.

This class implements the EDM preconditioning scheme of Karras et al. (“Elucidating the Design Space of Diffusion-Based Generative Models”), which provides a unified framework for various diffusion formulations with optimized preconditioning coefficients.

Parameters:
  • img_resolution (int or tuple) – Input image resolution. If int, assumes square images. If tuple, should be (height, width).

  • in_channels (int) – Number of input color channels.

  • out_channels (int) – Number of output color channels.

  • label_dim (int, optional) – Number of class labels. Set to 0 for unconditional generation. Default is 0.

  • use_fp16 (bool, optional) – Whether to execute the underlying model at FP16 precision for speed. Default is False.

  • sigma_min (float, optional) – Minimum supported noise level. Default is 0.

  • sigma_max (float, optional) – Maximum supported noise level. Default is float(‘inf’).

  • sigma_data (float, optional) – Standard deviation of the training data. Default is 1.0.

  • model_type (str, optional) – Class name of the underlying U-Net model (‘SongUNet’ or ‘DhariwalUNet’). Default is ‘DhariwalUNet’.

  • **model_kwargs (dict) – Additional keyword arguments passed to the underlying model.

img_resolution

Input image resolution as (height, width).

Type:

tuple

in_channels

Number of input channels.

Type:

int

out_channels

Number of output channels.

Type:

int

label_dim

Number of class labels.

Type:

int

use_fp16

Whether to use FP16 precision.

Type:

bool

sigma_min

Minimum noise level.

Type:

float

sigma_max

Maximum noise level.

Type:

float

sigma_data

Training data standard deviation.

Type:

float

model

The underlying U-Net model.

Type:

torch.nn.Module

Notes

  • The EDM formulation provides a unified preconditioning scheme that generalizes VP, VE, and other diffusion formulations.

  • Preconditioning coefficients: c_skip = σ_data²/(σ²+σ_data²), c_out = σ·σ_data/√(σ²+σ_data²), c_in = 1/√(σ²+σ_data²).

  • c_noise = log(σ)/4 provides the noise conditioning input.

  • These coefficients keep the network's inputs and training targets at roughly unit variance; Karras et al. report that this improves sample quality and training stability.
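The coefficients above can be checked numerically with a small pure-Python sketch. It shows the two limits that motivate EDM: at low noise the denoiser is close to the identity, and at high noise the network output dominates. sigma_data defaults to 1.0 here to match this class.

```python
import math
from typing import Tuple


def edm_coefficients(sigma: float, sigma_data: float = 1.0) -> Tuple[float, float, float, float]:
    """Return (c_skip, c_out, c_in, c_noise) for EDM preconditioning."""
    denom = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / denom
    c_out = sigma * sigma_data / math.sqrt(denom)
    c_in = 1.0 / math.sqrt(denom)
    c_noise = math.log(sigma) / 4
    return c_skip, c_out, c_in, c_noise


# Low noise: c_skip -> 1, c_out -> 0. High noise: c_skip -> 0, c_out -> sigma_data.
for s in (1e-3, 1.0, 1e3):
    c_skip, c_out, c_in, c_noise = edm_coefficients(s)
    print(f"sigma={s:g}: c_skip={c_skip:.4f} c_out={c_out:.4f} "
          f"c_in={c_in:.4f} c_noise={c_noise:.4f}")
```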

References

  • Karras et al., “Elucidating the Design Space of Diffusion-Based Generative Models”, 2022.

__init__(img_resolution, in_channels, out_channels, label_dim=0, use_fp16=False, sigma_min=0, sigma_max=inf, sigma_data=1.0, model_type='DhariwalUNet', **model_kwargs)[source]

Initialize the EDMPrecond module.

Parameters:
  • img_resolution (int or tuple) – Image resolution.

  • in_channels (int) – Input channels.

  • out_channels (int) – Output channels.

  • label_dim (int, optional) – Class label dimension.

  • use_fp16 (bool, optional) – Use FP16 precision.

  • sigma_min (float, optional) – Minimum noise level.

  • sigma_max (float, optional) – Maximum noise level.

  • sigma_data (float, optional) – Training data standard deviation.

  • model_type (str, optional) – Underlying model class name.

  • **model_kwargs (dict) – Additional model arguments.

forward(x, sigma, condition_img=None, class_labels=None, force_fp32=True, **model_kwargs)[source]

Forward pass with EDM preconditioning.

Parameters:
  • x (torch.Tensor) – Input noisy tensor of shape (batch_size, in_channels, height, width).

  • sigma (torch.Tensor) – Noise level(s) of shape (batch_size,) or scalar.

  • condition_img (torch.Tensor, optional) – Condition image tensor of same spatial dimensions as x. Default is None.

  • class_labels (torch.Tensor, optional) – Class labels for conditioning of shape (batch_size,) or (batch_size, label_dim). Default is None.

  • force_fp32 (bool, optional) – Force FP32 precision even if use_fp16 is True. Default is True.

  • **model_kwargs (dict) – Additional keyword arguments passed to the underlying model.

Returns:

Denoised output of shape (batch_size, out_channels, height, width).

Return type:

torch.Tensor

Notes

  • Applies the EDM preconditioning: D(x) = c_skip * x + c_out * F(c_in * x, c_noise)

  • Where F is the underlying U-Net model.

  • EDM coefficients: c_skip = σ_data²/(σ²+σ_data²), c_out = σ·σ_data/√(σ²+σ_data²), c_in = 1/√(σ²+σ_data²).

  • Condition images are concatenated along the channel dimension.

  • c_noise = log(σ)/4 provides the noise conditioning.
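A pure-Python sketch of the forward pass with a condition image, where one sample is a list of channels and each channel a flat list of pixels. Two points are assumptions, not confirmed by this page: the condition image is appended after c_in * x without being scaled, and the underlying network therefore receives in_channels plus the condition channels. `stub_net` is a hypothetical stand-in for the U-Net.

```python
import math
from typing import Callable, List, Optional

# One sample: a list of channels, each channel a flat list of pixel values.
Channels = List[List[float]]


def edm_forward(x: Channels, sigma: float,
                F: Callable[[Channels, float], Channels],
                condition_img: Optional[Channels] = None,
                sigma_data: float = 1.0) -> Channels:
    """D(x) = c_skip * x + c_out * F(concat(c_in * x, condition_img), c_noise)."""
    denom = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / denom
    c_out = sigma * sigma_data / math.sqrt(denom)
    c_in = 1.0 / math.sqrt(denom)
    c_noise = math.log(sigma) / 4

    net_in = [[c_in * v for v in ch] for ch in x]
    if condition_img is not None:
        n_pix = len(x[0])
        if any(len(ch) != n_pix for ch in condition_img):
            raise ValueError("condition_img must have the same spatial size as x")
        # Concatenate along the channel axis; the condition stays unscaled (assumption).
        net_in = net_in + condition_img
    f_out = F(net_in, c_noise)
    return [[c_skip * v + c_out * f for v, f in zip(xc, fc)]
            for xc, fc in zip(x, f_out)]


# Hypothetical network: reports its input channel count, predicts one zero channel.
def stub_net(net_in: Channels, c_noise: float) -> Channels:
    print(f"network input channels: {len(net_in)}")
    return [[0.0] * len(net_in[0])]


x = [[2.0, -4.0]]                # 1 noisy channel, 2 pixels
cond = [[0.1, 0.2], [0.3, 0.4]]  # 2 conditioning channels
print(edm_forward(x, sigma=1.0, F=stub_net, condition_img=cond))
# network input channels: 3
# [[1.0, -2.0]]
```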

round_sigma(sigma)[source]

Round noise level(s) for compatibility with discrete schedules.

Parameters:

sigma (float or torch.Tensor) – Noise level(s).

Returns:

Rounded noise level(s).

Return type:

torch.Tensor

Notes

In EDM, sigma is continuous, so rounding is typically a no-op unless implementing a discrete schedule variant.

class IPSL_AID.networks.TestDiffusionNetworks(methodName='runTest', logger=None)[source]

Bases: TestCase

Unit tests for diffusion network architectures.

__init__(methodName='runTest', logger=None)[source]

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

setUp()[source]

Set up test fixtures.

test_song_unet_square_resolution()[source]

Test SongUNet with square resolution.

test_song_unet_rectangular_resolution()[source]

Test SongUNet with rectangular resolution.

test_dhariwal_unet()[source]

Test DhariwalUNet architecture.

test_vp_preconditioner()[source]

Test VPPrecond with conditional images.

test_ve_preconditioner()[source]

Test VEPrecond with conditional images.

test_edm_preconditioner()[source]

Test EDMPrecond with conditional images.

test_parameter_counts()[source]

Test that all models have reasonable parameter counts.

tearDown()[source]

Clean up after tests.

IPSL_AID.utils module

class IPSL_AID.utils.EasyDict[source]

Bases: dict

A dictionary subclass that allows for attribute-style access to its items. This class extends the built-in dict and overrides the __getattr__, __setattr__, and __delattr__ methods to enable accessing dictionary keys as attributes. Original work: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. Original source: https://github.com/NVlabs/edm
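A minimal sketch of such a class, mirroring the NVlabs/edm utility credited above; whether IPSL_AID changes it in any way is not stated here. The `lat_step`/`lon_step` keys in the usage lines are hypothetical.

```python
from typing import Any


class EasyDict(dict):
    """dict subclass that also supports attribute-style access to its keys."""

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails, so dict methods still work.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]


steps = EasyDict(lat_step=0.25)
steps.lon_step = 0.25  # same as steps["lon_step"] = 0.25
print(steps["lon_step"], steps.lat_step)  # 0.25 0.25
```

Raising AttributeError rather than KeyError keeps `hasattr` and `getattr(obj, name, default)` working as expected. Nested plain dicts are not converted, which is the behaviour `test_nested_easydict` describes.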

class IPSL_AID.utils.FileUtils[source]

Bases: object

Utility class for file and directory operations.

__init__()[source]

Initialize the FileUtils class. This class does not maintain any state, so the constructor is empty.

static makedir(dirs)[source]

Create a directory if it does not exist.

Parameters:

dirs (str) – The path of the directory to be created.

static makefile(dirs, filename)[source]

Create an empty file in the specified directory.

Parameters:
  • dirs (str) – The path of the directory where the file will be created.

  • filename (str) – The name of the file to be created.
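A standard-library sketch of the two helpers. The test names in TestFileUtils (for example test_makefile_in_nonexistent_directory and test_makefile_existing_file) show that those cases are exercised, but not how they behave; this sketch assumes makefile creates a missing parent directory and leaves an existing file's contents untouched.

```python
import os
from pathlib import Path


class FileUtils:
    """Stateless helpers for creating directories and empty files (sketch)."""

    @staticmethod
    def makedir(dirs: str) -> None:
        # Create dirs and any missing parents; no error if it already exists.
        os.makedirs(dirs, exist_ok=True)

    @staticmethod
    def makefile(dirs: str, filename: str) -> None:
        # Assumption: create the parent directory if needed, and do not
        # truncate a file that already exists.
        FileUtils.makedir(dirs)
        Path(dirs, filename).touch(exist_ok=True)
```

`Path.touch(exist_ok=True)` only updates the modification time of an existing file, whereas `open(path, "w")` would truncate it.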

class IPSL_AID.utils.TestEasyDict(methodName='runTest', logger=None)[source]

Bases: TestCase

Unit tests for EasyDict class.

__init__(methodName='runTest', logger=None)[source]

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

setUp()[source]

Set up test fixtures.

test_empty_initialization()[source]

Test initializing an empty EasyDict.

test_initialization_with_dict()[source]

Test initializing EasyDict with a dictionary.

test_initialization_with_kwargs()[source]

Test initializing EasyDict with keyword arguments.

test_attribute_get_set()[source]

Test getting and setting attributes with dot notation.

test_dict_get_set()[source]

Test getting and setting with dictionary notation.

test_mixed_access()[source]

Test mixing dot and bracket notation.

test_attribute_error_for_nonexistent()[source]

Test that accessing nonexistent attribute raises AttributeError.

test_key_error_for_nonexistent()[source]

Test that accessing nonexistent key raises KeyError.

test_delattr()[source]

Test deleting attributes with delattr.

test_delitem()[source]

Test deleting items with del.

test_dict_methods()[source]

Test that standard dictionary methods work.

test_nested_easydict()[source]

Test that nested dictionaries are not automatically converted.

test_easydict_with_easydict()[source]

Test nesting EasyDict inside EasyDict.

tearDown()[source]

Clean up after tests.

class IPSL_AID.utils.TestFileUtils(methodName='runTest', logger=None)[source]

Bases: TestCase

Unit tests for FileUtils class.

__init__(methodName='runTest', logger=None)[source]

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

setUp()[source]

Set up test fixtures.

test_makedir_new_directory()[source]

Test creating a new directory that doesn’t exist.

test_makedir_existing_directory()[source]

Test calling makedir on an existing directory.

test_makedir_multiple_nested_directories()[source]

Test creating multiple nested directories at once.

test_makefile_new_file()[source]

Test creating a new file in an existing directory.

test_makefile_in_nonexistent_directory()[source]

Test creating a file in a directory that doesn’t exist.

test_makefile_existing_file()[source]

Test creating a file that already exists.

test_makefile_multiple_files()[source]

Test creating multiple files.

test_makedir_then_makefile()[source]

Test creating a directory then a file inside it.

tearDown()[source]

Clean up after tests.

IPSL_AID.download_ERA5_cds module

IPSL_AID.download_ERA5_cds.parse_args()[source]

Parse command-line arguments.

Returns:

Parsed command-line arguments as a namespace object with attributes corresponding to each argument.

Return type:

argparse.Namespace

Notes

ERA5 variable names must match those defined in the Copernicus Climate Data Store catalogue.

IPSL_AID.download_ERA5_cds.main(logger)[source]

Download ERA5 data from the Copernicus Climate Data Store.

The function follows a structured workflow:

  1. Parse command-line arguments.

  2. Create output directories.

  3. Loop over requested years, variables, and months.

  4. Submit download requests to the CDS API.

  5. Save results as NetCDF files.

Files are skipped if they already exist.

Notes

Data are downloaded at hourly resolution for all days of each month.

The CDS API client requires a valid configuration file (~/.cdsapirc) containing the API URL and key; see https://cds.climate.copernicus.eu/ for setup instructions.

The dataset used depends on whether pressure levels are requested:

  • reanalysis-era5-single-levels, when no pressure levels are requested

  • reanalysis-era5-pressure-levels, when pressure levels are requested
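The workflow above can be sketched with the standard library only, with the CDS call injected as a `retrieve(dataset, request, target)` callable so the loop can be tested offline. The request keys follow the current CDS ERA5 form (older CDS versions used "format" instead of "data_format"). The output layout `<out_dir>/<variable>/<year>/<variable>_<year>_<MM>.nc` and the function names are hypothetical, not taken from download_ERA5_cds.

```python
import calendar
import os
from typing import Callable, Dict, Iterable, List, Optional, Tuple


def build_request(variable: str, year: int, month: int,
                  pressure_levels: Optional[List[str]] = None) -> Tuple[str, Dict]:
    """Return (dataset_name, request) for one variable and one month of hourly data."""
    n_days = calendar.monthrange(year, month)[1]  # handles leap years
    request = {
        "product_type": ["reanalysis"],
        "variable": [variable],
        "year": [str(year)],
        "month": [f"{month:02d}"],
        "day": [f"{d:02d}" for d in range(1, n_days + 1)],
        "time": [f"{h:02d}:00" for h in range(24)],
        "data_format": "netcdf",
    }
    if pressure_levels:
        request["pressure_level"] = list(pressure_levels)
        return "reanalysis-era5-pressure-levels", request
    return "reanalysis-era5-single-levels", request


def download_all(retrieve: Callable[[str, Dict, str], object], out_dir: str,
                 years: Iterable[int], variables: Iterable[str],
                 months: Iterable[int] = range(1, 13),
                 pressure_levels: Optional[List[str]] = None) -> List[str]:
    """Loop over years, variables and months; skip files that already exist."""
    written = []
    for year in years:
        for variable in variables:
            var_dir = os.path.join(out_dir, variable, str(year))
            os.makedirs(var_dir, exist_ok=True)
            for month in months:
                target = os.path.join(var_dir, f"{variable}_{year}_{month:02d}.nc")
                if os.path.exists(target):
                    continue  # already downloaded
                dataset, request = build_request(variable, year, month, pressure_levels)
                retrieve(dataset, request, target)
                written.append(target)
    return written
```

With the real client this would be called as `download_all(cdsapi.Client().retrieve, "era5_raw", [2020], ["2m_temperature"])`, since `cdsapi.Client.retrieve` takes the dataset name, the request dictionary and the target path.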

IPSL_AID.generate_all_data_ERA5 module

IPSL_AID.generate_all_data_ERA5.parse_args()[source]

Parse command-line arguments.

Returns:

Parsed command-line arguments as a namespace object with attributes corresponding to each argument.

Return type:

argparse.Namespace

Raises:

ValueError – If the number of variables does not match the number of new names supplied for renaming them.
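A minimal argparse sketch of that check. The flag names `--variables` and `--rename` are hypothetical; the real module's arguments are not listed on this page.

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Build yearly ERA5 files (sketch).")
    # Hypothetical flags: one new name per input variable, in the same order.
    parser.add_argument("--variables", nargs="+", required=True,
                        help="variable names as they appear in the monthly files")
    parser.add_argument("--rename", nargs="+", required=True,
                        help="new name for each entry of --variables")
    args = parser.parse_args(argv)
    if len(args.variables) != len(args.rename):
        raise ValueError(
            f"Got {len(args.variables)} variables but {len(args.rename)} rename entries"
        )
    return args
```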

IPSL_AID.generate_all_data_ERA5.main(logger)[source]

Generate yearly ERA5 datasets from monthly NetCDF files.

The function follows a structured workflow:

  1. Parse command-line arguments.

  2. Load a CSV file containing timestamps to extract.

  3. Loop over years and variables.

  4. Open monthly ERA5 NetCDF files using xarray.

  5. Extract requested timestamps.

  6. Concatenate monthly subsets into yearly datasets.

  7. Rename variables and write compressed NetCDF files.

Notes

ERA5 data is stored monthly, so timestamps are grouped by month.
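A standard-library sketch of the grouping step, assuming the CSV holds ISO-format timestamps in a column named "time" (the column name is hypothetical). Each (year, month) key then corresponds to one monthly file to open.

```python
import csv
from collections import defaultdict
from datetime import datetime
from typing import Dict, Iterable, List, TextIO, Tuple


def read_timestamps(fileobj: TextIO, column: str = "time") -> List[str]:
    """Read one timestamp string per CSV row from the given column."""
    return [row[column] for row in csv.DictReader(fileobj)]


def group_by_month(timestamps: Iterable[str]) -> Dict[Tuple[int, int], List[datetime]]:
    """Group timestamps by (year, month), sorted by key and by time within each month."""
    groups: Dict[Tuple[int, int], List[datetime]] = defaultdict(list)
    for ts in timestamps:
        t = datetime.fromisoformat(ts)
        groups[(t.year, t.month)].append(t)
    return {key: sorted(vals) for key, vals in sorted(groups.items())}
```

In an xarray pipeline, each group would typically feed `xr.open_dataset(path).sel(time=times)`, the monthly subsets would be joined with `xr.concat(subsets, dim="time")`, and the result written with `Dataset.rename` and `Dataset.to_netcdf(..., encoding={name: {"zlib": True, "complevel": 4}})`; the compression level here is illustrative.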