IPSL_AID package
Submodules
IPSL_AID.dataset module
- IPSL_AID.dataset.stats(ds, logger, input_dir, norm_mapping={})[source]
Load normalization statistics and compute coordinate metadata for a NetCDF dataset.
This function loads normalization statistics from a JSON file if available. If no statistics file is found, it falls back to predefined constants for fine and coarse variables only.
- Parameters:
ds (xarray.Dataset) – NetCDF dataset to process.
logger (logging.Logger) – Logger instance for logging messages and statistics.
input_dir (str) – Directory containing a statistics.json file with precomputed normalization statistics.
norm_mapping (dict, optional) – Dictionary to store computed statistics. If empty, will be populated. Default is empty dict.
- Returns:
norm_mapping (dict) – Dictionary mapping variable names to their computed statistics. For coordinates: min, max, mean, std. For data variables: min, max, mean, std, q1, q3, iqr, median.
steps (EasyDict) – Dictionary containing coordinate step sizes and lengths.
Notes
If statistics.json is found, all statistics are loaded as-is.
If not found, fallback constants are used for fine and coarse variables only.
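The fallback logic described above can be sketched with the standard library alone. The helper name `load_norm_stats` and its arguments are hypothetical, not the package API; only the documented behavior (load statistics.json as-is if present, else use predefined constants) is taken from the description:

```python
import json
import os
import tempfile

def load_norm_stats(input_dir, fallback_constants, norm_mapping=None):
    """Sketch of the documented fallback: use statistics.json when present,
    otherwise fall back to predefined constants."""
    norm_mapping = {} if norm_mapping is None else norm_mapping
    stats_path = os.path.join(input_dir, "statistics.json")
    if os.path.isfile(stats_path):
        with open(stats_path) as f:
            norm_mapping.update(json.load(f))    # load all statistics as-is
    else:
        norm_mapping.update(fallback_constants)  # fallback constants only
    return norm_mapping

# With no statistics.json in the directory, the fallback constants are used
with tempfile.TemporaryDirectory() as d:
    stats = load_norm_stats(d, {"VAR_2T": {"mean": 278.0, "std": 21.0}})
```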
- IPSL_AID.dataset.coarse_down_up(fine_filtered, fine_batch, input_shape=(16, 32), axis=0)[source]
Downscale and then upscale fine-resolution data to compute coarse approximation.
This function performs a downscaling-upscaling operation to create a coarse resolution approximation of fine data. This is commonly used in multi-scale analysis, image processing, and super-resolution tasks.
- Parameters:
fine_filtered (torch.Tensor or np.ndarray) – Fine-resolution filtered data of shape (C, Hf, Wf) for multi-channel data or (Hf, Wf) for single-channel data, where C is the number of channels, Hf the fine height, and Wf the fine width.
fine_batch (torch.Tensor or np.ndarray) – Fine-resolution target data. Must have same spatial dimensions as fine_filtered. Shape: (C, Hf, Wf) or (Hf, Wf).
input_shape (tuple of int, optional) – Target shape (Hc, Wc) for the coarse-resolution data after downscaling. Default is (16, 32).
axis (int, optional) – Axis along which to insert batch dimension if the input lacks one. Default is 0.
- Returns:
coarse_up – Upscaled coarse approximation of the fine data. Same shape as input fine_filtered without batch dimension.
- Return type:
torch.Tensor
Notes
The function ensures that the input tensors have a batch dimension by adding one if missing.
Uses bilinear interpolation for both downscaling and upscaling operations.
The antialias parameter is set to True for better quality resampling.
Useful for creating multi-scale representations in image processing and computer vision tasks.
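Under the assumptions stated in the Notes (bilinear interpolation, antialias=True, a batch dimension added when missing), the downscale-then-upscale operation might look like this minimal torch sketch for (C, H, W) input; the helper name is illustrative, not the package function:

```python
import torch
import torch.nn.functional as F

def coarse_down_up_sketch(fine, input_shape=(16, 32)):
    """Sketch: bilinear downscale to the coarse grid, then bilinear upscale
    back to the fine grid, with antialiasing enabled for both steps."""
    x = fine.unsqueeze(0)                                       # add batch dim
    coarse = F.interpolate(x, size=input_shape,
                           mode="bilinear", antialias=True)     # fine -> coarse
    coarse_up = F.interpolate(coarse, size=x.shape[-2:],
                              mode="bilinear", antialias=True)  # coarse -> fine
    return coarse_up.squeeze(0)                                 # drop batch dim
```

A constant field passes through unchanged, since bilinear weights sum to one; structured fields lose their high-frequency content.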
- IPSL_AID.dataset.gaussian_filter(image, dW, dH, cutoff_W_phys, cutoff_H_phys, epsilon=0.01, margin=8)[source]
Apply a Gaussian low-pass filter with controlled attenuation at the cutoff frequency.
This function performs a 2D Fourier transform of the input field, applies a Gaussian weighting in the frequency domain, and inversely transforms it back to the spatial domain. Unlike the standard Gaussian filter, this version defines the Gaussian width such that the response amplitude reaches a specified attenuation factor (epsilon) at the cutoff frequency. Padding with reflection is used to minimize edge artifacts.
- Parameters:
image (np.array of shape (H, W)) – Input 2D field to be filtered (temperature, wind component).
dW (float) – Grid spacing in degrees of longitude.
dH (float) – Grid spacing in degrees of latitude.
cutoff_W_phys (float) – Longitudinal cutoff frequency in cycles per degree. Frequencies higher than this threshold are attenuated according to the Gaussian response.
cutoff_H_phys (float) – Latitudinal cutoff frequency in cycles per degree. Frequencies higher than this threshold are attenuated according to the Gaussian response.
epsilon (float, optional) – Desired amplitude response at the cutoff frequency (default: 0.01). Lower values produce sharper attenuation and stronger filtering.
margin (int, optional) – Number of pixels to pad on each side using reflection (default: 8). This reduces edge effects in the Fourier transform.
- Returns:
filtered – Real-valued filtered field after inverse Fourier transform and margin cropping.
- Return type:
ndarray of shape (H, W)
Notes
The Gaussian width parameters are computed such that: exp(-0.5 * (f_cutoff / sigma)^2) = epsilon, leading to sigma = f_cutoff / sqrt(-2 * log(epsilon)).
Padding the input with reflective boundaries minimizes spectral leakage and discontinuities at image edges.
The output field is cropped back to its original size after filtering.
This formulation provides more explicit control over filter sharpness than the standard Gaussian low-pass implementation.
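The scheme in the Notes (reflective padding, FFT, Gaussian weights reaching amplitude epsilon at the cutoff, inverse FFT, crop) can be sketched as follows; this is a minimal illustration, not the package implementation:

```python
import numpy as np

def gaussian_lowpass_sketch(image, dW, dH, cutoff_W, cutoff_H,
                            epsilon=0.01, margin=8):
    """Sketch of the documented filter: pad reflectively, weight the spectrum
    with a Gaussian whose amplitude is epsilon at the cutoff, transform back."""
    padded = np.pad(image, margin, mode="reflect")
    H, W = padded.shape
    fH = np.fft.fftfreq(H, d=dH)  # frequencies in cycles per degree (lat)
    fW = np.fft.fftfreq(W, d=dW)  # frequencies in cycles per degree (lon)
    # Width chosen so exp(-0.5 * (f_cutoff / sigma)^2) == epsilon at the cutoff
    sH = cutoff_H / np.sqrt(-2.0 * np.log(epsilon))
    sW = cutoff_W / np.sqrt(-2.0 * np.log(epsilon))
    weight = np.exp(-0.5 * ((fH[:, None] / sH) ** 2 + (fW[None, :] / sW) ** 2))
    filtered = np.real(np.fft.ifft2(np.fft.fft2(padded) * weight))
    return filtered[margin:-margin, margin:-margin]  # crop back to (H, W)
```

The DC component has weight 1, so the spatial mean is preserved; smaller epsilon narrows the Gaussian and attenuates frequencies near the cutoff more strongly.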
- class IPSL_AID.dataset.DataPreprocessor(*args: Any, **kwargs: Any)[source]
Bases: Dataset
Dataset class for preprocessing weather and climate data for machine learning.
This class handles loading, preprocessing, and sampling of multi-year NetCDF weather data with support for multi-scale processing, normalization, and spatial-temporal sampling strategies.
- Parameters:
loaded_dfs (xarray.Dataset) – Pre-loaded dataset containing the weather variables.
constants_file_path (str) – Path to NetCDF file containing constant variables (e.g., topography).
varnames_list (list of str) – List of variable names to extract from the dataset.
units_list (list of str) – Units for each variable in varnames_list.
in_shape (tuple of int, optional) – Target shape (height, width) for coarse resolution. Default is (16, 32).
batch_size_lat (int, optional) – Height of spatial batch in grid points. Default is 144.
batch_size_lon (int, optional) – Width of spatial batch in grid points. Default is 144.
steps (EasyDict, optional) – Dictionary containing grid dimension information. Should include latitude/lat (number of latitude points), longitude/lon (number of longitude points), time (number of time steps), d_latitude/d_lat (latitude spacing), and d_longitude/d_lon (longitude spacing).
tbatch (int, optional) – Number of time batches to sample. Default is 1.
sbatch (int, optional) – Number of spatial batches to sample. Default is 8.
debug (bool, optional) – Enable debug logging. Default is True.
mode (str, optional) – Operation mode: “train” or “validation”. Default is “train”.
run_type (str, optional) – Run type: “train”, “validation”, or “inference”. Default is “train”.
dynamic_covariates (list of str, optional) – List of dynamic covariate variable names. Default is None.
dynamic_covariates_dir (str, optional) – Directory containing dynamic covariate files. Default is None.
time_normalization (str, optional) – Method for time normalization: “linear” or “cos_sin”. Default is “linear”.
norm_mapping (dict, optional) – Dictionary containing normalization statistics for variables.
index_mapping (dict, optional) – Dictionary mapping variable names to indices in the data array.
normalization_type (dict, optional) – Dictionary specifying normalization type per variable.
constant_variables (list of str, optional) – List of constant variable names to load. Default is None.
epsilon (float, optional) – Small value for numerical stability in filtering. Default is 0.02.
margin (int, optional) – Margin for filtering operations. Default is 8.
dtype (tuple, optional) – Data types for torch and numpy (torch_dtype, np_dtype). Default is (torch.float32, np.float32).
apply_filter (bool, optional) – Whether to apply Gaussian filtering for multi-scale processing. Default is False.
logger (logging.Logger, optional) – Logger instance for logging messages. Default is None.
- const_vars
Array of constant variables with shape (n_constants, H, W).
- Type:
np.ndarray or None
- time
Time coordinate from dataset.
- Type:
- year
Year component of time.
- Type:
- month
Month component of time.
- Type:
- day
Day component of time.
- Type:
- hour
Hour component of time.
- Type:
- year_norm
Normalized year values.
- Type:
torch.Tensor
- doy_norm
Normalized day-of-year values (linear mode).
- Type:
torch.Tensor or None
- hour_norm
Normalized hour values (linear mode).
- Type:
torch.Tensor or None
- doy_sin, doy_cos
Sine and cosine of day-of-year (cos_sin mode).
- Type:
torch.Tensor or None
- hour_sin, hour_cos
Sine and cosine of hour (cos_sin mode).
- Type:
torch.Tensor or None
- time_batchs
Array of time indices for current epoch.
- Type:
np.ndarray
- sample_time_steps_by_doy()[source]
Sample time steps based on day-of-year (DOY) for multi-year continuity.
- generate_random_batch_centers(n_batches)[source]
Generate random spatial centers for batch sampling.
- extract_batch(data, ilat, ilon)[source]
Extract spatial batch centered at (ilat, ilon) with cyclic longitude.
- filter_batch(fine_patch, fine_block)[source]
Apply Gaussian low-pass filtering for multi-scale processing.
- normalize(data, stats, norm_type, var_name=None, data_type=None)[source]
Normalize data using specified statistics and method.
Notes
Supports both random (training) and deterministic (validation) sampling.
Handles cyclic longitude wrapping for global datasets.
Provides multi-scale processing through downscaling/upscaling.
Includes time normalization with linear or trigonometric encoding.
Can incorporate constant variables (e.g., topography, land-sea mask).
- __init__(years, loaded_dfs, constants_file_path, varnames_list, units_list, in_shape=(16, 32), batch_size_lat=144, batch_size_lon=144, steps={}, tbatch=1, sbatch=8, debug=True, mode='train', run_type='train', dynamic_covariates=None, dynamic_covariates_dir=None, time_normalization='linear', norm_mapping=None, index_mapping=None, normalization_type=None, constant_variables=None, epsilon=0.02, margin=8, dtype=(torch.float32, numpy.float32), apply_filter=False, region_center=None, region_size=None, logger=None)[source]
Initialize the DataPreprocessor.
- Parameters:
loaded_dfs (xarray.Dataset) – Pre-loaded dataset containing the weather variables.
constants_file_path (str) – Path to NetCDF file containing constant variables.
varnames_list (list of str) – List of variable names to extract.
in_shape (tuple of int, optional) – Target shape for coarse resolution.
batch_size_lat (int, optional) – Height of spatial batch.
batch_size_lon (int, optional) – Width of spatial batch.
steps (EasyDict, optional) – Grid dimension information.
tbatch (int, optional) – Number of time batches.
sbatch (int, optional) – Number of spatial batches.
debug (bool, optional) – Enable debug logging.
mode (str, optional) – Operation mode.
run_type (str, optional) – Run type.
dynamic_covariates (list of str, optional) – Dynamic covariate variable names.
dynamic_covariates_dir (str, optional) – Directory for dynamic covariates.
time_normalization (str, optional) – Time normalization method.
norm_mapping (dict, optional) – Normalization statistics.
index_mapping (dict, optional) – Variable to index mapping.
normalization_type (dict, optional) – Normalization type per variable.
constant_variables (list of str, optional) – Constant variable names.
epsilon (float, optional) – Numerical stability value.
margin (int, optional) – Filter margin.
dtype (tuple, optional) – Data types for torch and numpy.
apply_filter (bool, optional) – Apply Gaussian filtering.
region_center (tuple of float or None) – Fixed geographic center (lat, lon) for spatial sampling.
region_size (tuple of int or None) – Extent (lat, lon) of the fixed region used for regional inference.
logger (logging.Logger, optional) – Logger instance.
- new_epoch()[source]
Prepare for a new training epoch by generating new time batches.
This method is called at the start of each training epoch to refresh the temporal and spatial sampling.
- sample_time_steps_by_doy()[source]
Sample time steps based on day-of-year (DOY) for multi-year continuity.
This method selects unique DOYs from the available multi-year data and picks one random time index for each DOY.
- Raises:
ValueError – If requested tbatch exceeds number of unique DOYs.
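The documented sampling (one random time index per unique DOY, with a ValueError when tbatch exceeds the number of unique DOYs) can be sketched like this; the helper and its arguments are hypothetical, not the class method:

```python
import numpy as np

def sample_time_indices_by_doy(doys, tbatch, rng=None):
    """Sketch: draw tbatch unique DOYs without replacement, then pick one
    random time index per chosen DOY from among all years containing it."""
    rng = np.random.default_rng() if rng is None else rng
    doys = np.asarray(doys)
    unique_doys = np.unique(doys)
    if tbatch > unique_doys.size:
        raise ValueError("tbatch exceeds the number of unique DOYs")
    chosen = rng.choice(unique_doys, size=tbatch, replace=False)
    # For each chosen DOY, pick one random index among matching time steps
    return [int(rng.choice(np.flatnonzero(doys == d))) for d in chosen]
```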
- sample_random_time_indices()[source]
Generate random time indices for training.
This method samples random time indices uniformly across the available time range.
- get_center_indices_from_latlon(lat_value, lon_value)[source]
Convert geographic coordinates (latitude, longitude) to nearest grid indices.
- Parameters:
lat_value (float) – Latitude value in degrees.
lon_value (float) – Longitude value in degrees.
- Returns:
lat_idx (int) – Index of the closest latitude grid point.
lon_idx (int) – Index of the closest longitude grid point.
Notes
The dataset is defined on a discrete latitude–longitude grid.
Since spatial extraction operates on grid indices, the requested physical coordinates are mapped to the nearest available grid point.
This ensures consistency between user-defined locations and internal batch extraction logic.
- generate_random_batch_centers(n_batches)[source]
Generate random (latitude, longitude) centers for batch sampling.
- Parameters:
n_batches (int) – Number of random centers to generate.
- Returns:
centers – List of (lat_center, lon_center) tuples.
- Return type:
list of tuple
Notes
Latitude centers avoid poles to ensure full batch extraction.
Longitude centers can be any value due to cyclic wrapping.
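The two constraints in the Notes (latitude centers kept away from the poles, longitude centers unrestricted thanks to cyclic wrapping) can be sketched as follows; the helper name and its arguments are illustrative assumptions:

```python
import numpy as np

def random_batch_centers(n_batches, n_lat, n_lon, batch_size_lat, rng=None):
    """Sketch: latitude centers stay at least half a batch away from the
    poles so a full batch can be extracted; longitude centers are free."""
    rng = np.random.default_rng() if rng is None else rng
    half = batch_size_lat // 2
    lat_centers = rng.integers(half, n_lat - half, size=n_batches)
    lon_centers = rng.integers(0, n_lon, size=n_batches)  # any value; wraps
    return list(zip(lat_centers.tolist(), lon_centers.tolist()))
```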
- generate_region_slices(lat_center, lon_center, region_size_lat, region_size_lon)[source]
Generate deterministic spatial slices for regional inference. The slices cover a block centered on (lat_center, lon_center). The region is divided into non-overlapping blocks of size (batch_size_lat, batch_size_lon) used for model inference.
- Parameters:
lat_center – Latitude of the region center.
lon_center – Longitude of the region center.
region_size_lat (int) – Region extent in latitude grid points.
region_size_lon (int) – Region extent in longitude grid points.
- Returns:
slices – List of (lat_start, lat_end, lon_start, lon_end) tuples defining non-overlapping spatial blocks covering the selected region.
- Return type:
list of tuple
Notes
The latitude start index is clamped to ensure the region remains within the global latitude bounds. As a result, if the requested region is too close to the poles, the extracted region may be shifted and may no longer be centered exactly on (lat_center, lon_center). Longitude wrapping is handled later during patch extraction (in extract_batch).
- extract_batch(data, ilat, ilon)[source]
Extract spatial batch centered at (ilat, ilon) with cyclic longitude.
- Parameters:
data (torch.Tensor or np.ndarray) – Input data with trailing spatial dimensions matching the grid (…, n_lat, n_lon).
ilat (int) – Latitude index of the batch center.
ilon (int) – Longitude index of the batch center.
- Returns:
block (torch.Tensor or np.ndarray) – Extracted batch with shape (…, batch_size_lat, batch_size_lon).
indices (tuple) – Tuple of (lat_start, lat_end, lon_start, lon_end) indices.
- Raises:
AssertionError – If input tensor dimensions don’t match grid dimensions or if indices are invalid.
Notes
Longitude is treated as cyclic (wraps around 0-360°).
Latitude is non-cyclic (no wrapping at poles).
The function rolls the data to center the longitude and then extracts the appropriate slice.
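The roll-then-slice strategy in the Notes can be sketched in numpy; this is an illustration of the cyclic-longitude idea, not the class method (which also returns indices and validates dimensions):

```python
import numpy as np

def extract_cyclic_batch(data, ilat, ilon, bh, bw):
    """Sketch: roll longitude so the requested center lands at column bw // 2,
    then take a plain slice. Longitude wraps; latitude does not."""
    # After the roll, rolled[..., j] == data[..., (j - shift) % n_lon]
    rolled = np.roll(data, shift=bw // 2 - ilon, axis=-1)
    lat_start = ilat - bh // 2  # latitude: ordinary slice, no wrapping
    return rolled[..., lat_start:lat_start + bh, :bw]
```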
- filter_batch(fine_patch, fine_block)[source]
Apply Gaussian low-pass filtering for multi-scale processing.
- Parameters:
fine_patch (np.ndarray) – Fine-resolution data of shape (C, H, W).
fine_block (np.ndarray) – Reference block used to determine scaling factors.
- Returns:
fine_filtered – Filtered data of shape (C, H, W).
- Return type:
np.ndarray
Notes
Filters high-frequency components beyond the coarse grid’s Nyquist.
Uses Gaussian filtering in the frequency domain.
Processes each channel independently.
- normalize(data, stats, norm_type, var_name=None, data_type=None)[source]
Normalize data using specified statistics and method.
- Parameters:
data (torch.Tensor) – Input data to normalize.
stats (object) – Statistics object with attributes: vmin, vmax, vmean, vstd, median, iqr, q1, q3.
norm_type (str) – Normalization type: “minmax”, “minmax_11”, “standard”, “robust”, “log1p_minmax”, “log1p_standard”.
var_name (str, optional) – Variable name for logging.
data_type (str, optional) – Data type description for logging.
- Returns:
Normalized data.
- Return type:
torch.Tensor
- Raises:
ValueError – If norm_type is not supported.
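Plausible implementations of the listed normalization schemes are sketched below. The exact package formulas are assumptions (in particular the log1p variants); only the scheme names and the statistics they use come from the documentation:

```python
import numpy as np

def normalize_sketch(data, stats, norm_type):
    """Sketch of the documented scheme names; stats is a plain dict here
    rather than the package's statistics object."""
    if norm_type == "minmax":          # scale to [0, 1]
        return (data - stats["vmin"]) / (stats["vmax"] - stats["vmin"])
    if norm_type == "minmax_11":       # scale to [-1, 1]
        return 2.0 * (data - stats["vmin"]) / (stats["vmax"] - stats["vmin"]) - 1.0
    if norm_type == "standard":        # zero mean, unit variance
        return (data - stats["vmean"]) / stats["vstd"]
    if norm_type == "robust":          # median / IQR, outlier-resistant
        return (data - stats["median"]) / stats["iqr"]
    if norm_type == "log1p_minmax":    # log-compress, then min-max
        return (np.log1p(data) - np.log1p(stats["vmin"])) / (
            np.log1p(stats["vmax"]) - np.log1p(stats["vmin"]))
    if norm_type == "log1p_standard":  # log-compress, then standardize
        return (np.log1p(data) - stats["vmean"]) / stats["vstd"]
    raise ValueError(f"Unsupported norm_type: {norm_type}")
```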
- normalize_time(tindex)[source]
Return normalized time features for given time index.
- Parameters:
tindex (int) – Time index.
- Returns:
Dictionary of normalized time features.
- Return type:
dict
Notes
Features depend on the time_normalization setting: “linear” yields year_norm, doy_norm, hour_norm; “cos_sin” yields year_norm, doy_sin, doy_cos, hour_sin, hour_cos.
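The two encodings can be sketched as below. The exact formulas (period lengths, year scaling) are assumptions; only the feature names and the linear vs. trigonometric split come from the documentation:

```python
import numpy as np

def time_features(doy, hour, year, year_min, year_max, mode="linear"):
    """Sketch of the two documented time encodings."""
    year_norm = (year - year_min) / max(year_max - year_min, 1)
    if mode == "linear":
        return {"year_norm": year_norm,
                "doy_norm": doy / 366.0,
                "hour_norm": hour / 23.0}
    # "cos_sin": periodic encoding, so Dec 31 sits next to Jan 1
    # and hour 23 sits next to hour 0
    return {"year_norm": year_norm,
            "doy_sin": np.sin(2 * np.pi * doy / 366.0),
            "doy_cos": np.cos(2 * np.pi * doy / 366.0),
            "hour_sin": np.sin(2 * np.pi * hour / 24.0),
            "hour_cos": np.cos(2 * np.pi * hour / 24.0)}
```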
- IPSL_AID.dataset.create_dummy_netcdf(temp_dir, year=2020, has_constants=False)[source]
Create a dummy NetCDF file for testing.
- IPSL_AID.dataset.create_dummy_statistics_json(temp_dir)[source]
Create dummy statistics.json file for testing.
- class IPSL_AID.dataset.TestDataPreprocessor(methodName='runTest', logger=None)[source]
Bases: TestCase
Unit tests for DataPreprocessor class.
- __init__(methodName='runTest', logger=None)[source]
Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.
- test_stats_computation_without_existing_json()[source]
Test stats computation when no statistics.json exists.
- test_stats_computation_with_existing_json()[source]
Test stats computation when statistics.json exists.
- test_gaussian_filter_different_epsilon()[source]
Test Gaussian filter with different epsilon values.
- test_preprocessor_initialization_train_mode()[source]
Test DataPreprocessor initialization in train mode.
- test_preprocessor_initialization_validation_mode()[source]
Test DataPreprocessor initialization in validation mode.
- test_preprocessor_initialization_cos_sin_normalization()[source]
Test DataPreprocessor with cos_sin time normalization.
- test_preprocessor_batch_size_validation()[source]
Test that batch size validation raises appropriate errors.
IPSL_AID.diagnostics module
- class IPSL_AID.diagnostics.PlotConfig[source]
Bases: object
Central configuration for all plotting functions.
- DEFAULT_SAVE_DIR = './results'
- DEFAULT_FIGSIZE_MULTIPLIER = 4
- COLORMAPS = {'10u': 'BrBG_r', '10v': 'BrBG_r', '2t': 'rainbow', 'TP': 'Blues', 'curl': 'seismic', 'd2m': 'rainbow', 'default': 'viridis', 'dewpoint': 'rainbow', 'divergence': 'seismic', 'error': 'Reds', 'humid': 'Greens', 'humidity': 'Greens', 'mae': 'Reds', 'meridional': 'BrBG_r', 'precipitation': 'Blues', 'pres': 'viridis', 'pressure': 'viridis', 'speed': 'coolwarm', 'ssr': 'seismic', 'st': 'rainbow', 'surface temperature': 'rainbow', 'temp': 'rainbow', 'temperature': 'rainbow', 'tp': 'Blues', 'wind': 'coolwarm', 'zonal': 'BrBG_r'}
- FIXED_DIFF_RANGES = {'10u': (-5.0, 5.0), '10v': (-5.0, 5.0), '2t': (-5.0, 5.0), 'T2M': (-5.0, 5.0), 'TP': (-0.5, 0.5), 'U10': (-5.0, 5.0), 'V10': (-5.0, 5.0), 'VAR_10U': (-5.0, 5.0), 'VAR_10V': (-5.0, 5.0), 'VAR_2T': (-5.0, 5.0), 'VAR_D2M': (-5.0, 5.0), 'VAR_ST': (-5.0, 5.0), 'VAR_TP': (-0.5, 0.5), 'meridional': (-5.0, 5.0), 'temperature': (-5.0, 5.0), 'tp': (-0.5, 0.5)}
- FIXED_DIFF_RANGES_ERRORS = {'Humid': (0, 3.0), 'Press': (0, 3.0), 'Temp': (0, 3.0), 'VAR_10U': (0, 3.0), 'VAR_10V': (0, 3.0), 'VAR_2T': (0, 0.01), 'VAR_D2M': (0, 1.0), 'VAR_ST': (0, 1.0), 'VAR_TP': (0, 0.5), 'Wind': (0, 3.0)}
- FIXED_MAE_RANGES = {'10u': (0.0, 3.0), '10v': (0.0, 3.0), '2t': (0.0, 3.0), 'T2M': (0.0, 3.0), 'TP': (0.0, 1.0), 'U10': (0.0, 3.0), 'V10': (0.0, 3.0), 'VAR_10U': (0.0, 3.0), 'VAR_10V': (0.0, 3.0), 'VAR_2T': (0.0, 3.0), 'VAR_D2M': (0.0, 3.0), 'VAR_ST': (0.0, 3.0), 'VAR_TP': (0.0, 1.0), 'meridional': (0.0, 3.0), 'temperature': (0.0, 3.0), 'tp': (0.0, 1.0)}
- FIXED_SSR_RANGES = {'10u': (0.0, 3.0), '10v': (0.0, 3.0), '2t': (0.0, 3.0), 'T2M': (0.0, 3.0), 'TP': (0.0, 3.0), 'U10': (0.0, 3.0), 'V10': (0.0, 3.0), 'VAR_10U': (0.0, 3.0), 'VAR_10V': (0.0, 3.0), 'VAR_2T': (0.0, 3.0), 'VAR_D2M': (0.0, 3.0), 'VAR_ST': (0.0, 3.0), 'VAR_TP': (0.0, 1.0), 'meridional': (0.0, 3.0), 'temperature': (0.0, 3.0), 'tp': (0.0, 3.0)}
- COASTLINE_w = 0.5
- BORDER_w = 0.5
- LAKE_w = 0.5
- BORDER_STYLE = '--'
- COLORBAR_h = 0.02
- COLORBAR_PAD = 0.05
- classmethod convert_units(variable_name, data)[source]
Safe unit conversion when required. Never modifies the input; returns a new array only if conversion is needed.
- static get_fixed_diff_range(var_name)[source]
Get fixed visualization range for signed differences (Prediction − Truth).
- IPSL_AID.diagnostics.plot_validation_hexbin(predictions, targets, coarse_inputs=None, variable_names=None, filename='validation_hexbin.png', save_dir='./results', figsize_multiplier=4)[source]
Create hexbin plots comparing model predictions vs ground truth for all variables.
- Parameters:
predictions (torch.Tensor or np.array) – Model predictions of shape [batch_size, num_variables, h, w]
targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]
coarse_inputs (torch.Tensor or np.array, optional) – Coarse inputs of shape [batch_size, num_variables, h, w]
variable_names (list of str, optional) – Names of the variables for subplot titles
filename (str, optional) – Output filename
save_dir (str, optional) – Directory to save the plot
figsize_multiplier (int, optional) – Base size multiplier for subplots
- IPSL_AID.diagnostics.plot_comparison_hexbin(predictions, targets, coarse_inputs, variable_names=None, filename='comparison_hexbin.png', save_dir='./results', figsize_multiplier=4)[source]
Create hexbin comparison plots between model predictions, ground truth, and coarse inputs.
For each variable, creates two side-by-side hexbin plots: (1) model predictions vs ground truth, and (2) coarse inputs vs ground truth.
Each plot includes an identity line and R²/MAE metrics.
- Parameters:
predictions (torch.Tensor or np.array) – Model predictions of shape [batch_size, num_variables, h, w]
targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]
coarse_inputs (torch.Tensor or np.array) – Coarse inputs of shape [batch_size, num_variables, h, w]
variable_names (list of str, optional) – Names of the variables for subplot titles. If None, uses VAR_0, VAR_1, etc.
filename (str, optional) – Output filename
save_dir (str, optional) – Directory to save the plot
figsize_multiplier (int, optional) – Base size multiplier for subplots
- Returns:
save_path – Path to the saved figure
- Return type:
str
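The layout described above (hexbin density, identity line, R²/MAE in the title) can be sketched with plain matplotlib; the helper `hexbin_vs_truth` and the synthetic data are illustrative, not the package function:

```python
import os
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend for scripted plotting
import matplotlib.pyplot as plt

def hexbin_vs_truth(ax, x, y, title):
    """One panel: hexbin density of x vs truth y, identity line, metrics."""
    ax.hexbin(x, y, gridsize=40, mincnt=1, cmap="viridis")
    lo, hi = min(x.min(), y.min()), max(x.max(), y.max())
    ax.plot([lo, hi], [lo, hi], "r--", lw=1)  # identity line
    r2 = 1 - np.sum((y - x) ** 2) / np.sum((y - y.mean()) ** 2)
    mae = np.mean(np.abs(y - x))
    ax.set_title(f"{title}  R2={r2:.3f}  MAE={mae:.3f}")

# Synthetic example: predictions track truth closely, coarse inputs less so
rng = np.random.default_rng(0)
truth = rng.normal(size=5000)
pred = truth + 0.1 * rng.normal(size=5000)
coarse = truth + 0.5 * rng.normal(size=5000)
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
hexbin_vs_truth(axes[0], pred, truth, "Prediction vs truth")
hexbin_vs_truth(axes[1], coarse, truth, "Coarse vs truth")
fig.savefig("comparison_hexbin.png", dpi=100)
plt.close(fig)
```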
- IPSL_AID.diagnostics.plot_metric_histories(valid_metrics_history, variable_names, metric_names, filename='validation_metrics', save_dir='./results', figsize_multiplier=4)[source]
Creates row-based panel plots: one figure per metric, rows = variables, shared x-axis.
- Parameters:
valid_metrics_history (dict) – Dict from validation loop storing metric histories.
variable_names (list of str) – Names of the variables, one row per variable.
metric_names (list of str) – List of metric names, one figure per metric.
filename (str) – Prefix for saved figures.
save_dir (str) – Directory where images are saved.
figsize_multiplier (int) – Base size multiplier for subplots.
- IPSL_AID.diagnostics.plot_metrics_heatmap(valid_metrics_history, variable_names, metric_names, filename='validation_metrics_heatmap', save_dir='./results', figsize_multiplier=4)[source]
Plot a heatmap of validation metrics.
- Parameters:
valid_metrics_history (dict) – Dict from validation loop storing metric histories.
variable_names (list of str) – Names of the variables shown on the heatmap axes.
metric_names (list of str) – List of metric names (e.g., ["MAE", "NMAE", "RMSE", "R²"]).
filename (str) – Prefix for saved figures.
save_dir (str) – Directory where images are saved.
figsize_multiplier (float) – Controls overall figure size
- IPSL_AID.diagnostics.plot_loss_histories(train_loss_history, valid_loss_history, filename='training_validation_loss.png', save_dir='./results', figsize_multiplier=4)[source]
Plots training and validation loss in a single panel.
- Parameters:
train_loss_history (list or array) – History of training loss values.
valid_loss_history (list or array) – History of validation loss values.
filename (str) – Output image file name for the plot.
save_dir (str) – Directory to save the plot.
- IPSL_AID.diagnostics.plot_average_metrics(valid_metrics_history, metric_names, filename='average_metrics.png', save_dir='./results', figsize_multiplier=4)[source]
Plots average metrics across all variables in a row-based layout with shared x-axis.
- Each row corresponds to one metric, plotting both:
average_pred_vs_fine_<metric>
average_coarse_vs_fine_<metric>
- IPSL_AID.diagnostics.plot_spatiotemporal_histograms(steps, tindex_lim, centers, tindices, mode='train', filename='average_metrics.png', save_dir='./results', figsize_multiplier=4)[source]
Plot two 2D hexagonal bin histograms showing spatial-temporal data coverage: latitude center vs temporal index and longitude center vs temporal index.
This function visualizes the distribution of data samples across spatial (latitude/longitude) and temporal dimensions using hexagonal binning, which provides smoother density estimation compared to rectangular binning.
- Parameters:
steps (EasyDict) – Dictionary containing coordinate dimensions and limits. Expected to have attributes ‘latitude’ (or ‘lat’) and ‘longitude’ (or ‘lon’) specifying the maximum spatial indices.
tindex_lim (tuple) – Tuple of (min_time, max_time) specifying the temporal index limits.
centers (list of tuples) – List of (lat_center, lon_center) coordinates for each data sample. Each center represents the spatial location of a data point.
tindices (list or array-like) – List of temporal indices corresponding to each data sample. Should have the same length as ‘centers’.
mode (str) – Dataset mode identifier, typically “train” or “validation”. Used for plot title and filename.
save_dir (str) – Directory path where the plot will be saved. Directory will be created if it doesn’t exist.
filename (str, optional) – Optional prefix to prepend to the output filename. Default is empty string.
- Returns:
The function saves the plot to disk and does not return any value.
- Return type:
None
Notes
The function creates two side-by-side subplots: (1) latitude center index vs temporal index, and (2) longitude center index vs temporal index.
Uses hexagonal binning (hexbin) for density visualization, which reduces visual artifacts compared to rectangular histograms.
A single colorbar is shared between both plots with log10 scaling.
The color scale is normalized to the maximum count across both histograms.
Hexagons with zero count (mincnt=1) are not displayed.
Examples
>>> steps = EasyDict({'latitude': 180, 'longitude': 360})
>>> tindex_lim = (0, 1000)
>>> centers = [(10, 20), (15, 25), (10, 20), ...]  # list of (lat, lon)
>>> tindices = [0, 5, 10, 15, ...]  # corresponding temporal indices
>>> plot_spatiotemporal_histograms(steps, tindex_lim, centers,
...     tindices, mode="train", save_dir="./plots")
The function will save a plot named “spatiotemporal_train_hexbin.png” in the “./plots” directory.
- IPSL_AID.diagnostics.plot_surface(predictions, targets, coarse_inputs, lat_1d, lon_1d, timestamp=None, variable_names=None, filename='forecast_plot.png', save_dir=None, figsize_multiplier=None)[source]
Plot side-by-side forecast maps (coarse input, true target, model prediction, and difference) for one or more meteorological variables over a geographic domain.
- Parameters:
coarse_inputs (torch.Tensor or np.ndarray) – Coarse-resolution input data with shape [1, n_vars, H, W].
targets (torch.Tensor or np.ndarray) – Ground-truth high-resolution data with shape [1, n_vars, H, W].
predictions (torch.Tensor or np.ndarray) – Model predictions at targets resolution with shape [1, n_vars, H, W].
lat_1d (array-like) – 1D array of latitude coordinates with shape [H].
lon_1d (array-like) – 1D array of longitude coordinates with shape [W].
timestamp (datetime.datetime) – Forecast timestamp to include in the plot title.
variable_names (list of str, optional) – Variable names or identifiers.
filename (str, optional) – Output filename for saving the plot.
save_dir (str, optional) – Directory to save the plot.
figsize_multiplier (int, optional) – Base size multiplier for subplots.
- Return type:
None
- IPSL_AID.diagnostics.plot_ensemble_surface(predictions_ens, lat_1d, lon_1d, variable_names, timestamp=None, filename='ensemble_surface.png', save_dir='./results')[source]
Plot ensemble members, ensemble mean, and ensemble spread.
- Parameters:
predictions_ens (torch.Tensor or np.ndarray) – Ensemble predictions of shape [n_ensemble_members, n_vars, H, W]
lat_1d (array-like) – 1D array of latitude coordinates with shape [H].
lon_1d (array-like) – 1D array of longitude coordinates with shape [W].
variable_names (list of str, optional) – Variable names or identifiers.
timestamp (datetime.datetime) – Forecast timestamp to include in the plot title.
filename (str, optional) – Output filename for saving the plot.
save_dir (str, optional) – Directory to save the plot.
- Return type:
None
- IPSL_AID.diagnostics.plot_zoom_comparison(predictions, targets, lat_1d, lon_1d, variable_names=None, filename='zoom_plot.png', save_dir=None, zoom_box=None)[source]
Plot a comparison between ground truth and model predictions with a geographic zoom.
- Parameters:
targets (torch.Tensor or np.ndarray) – Ground-truth high-resolution data with shape [1, n_vars, H, W].
predictions (torch.Tensor or np.ndarray) – Model predictions at targets resolution with shape [1, n_vars, H, W].
lat_1d (array-like) – 1D array of latitude coordinates with shape [H].
lon_1d (array-like) – 1D array of longitude coordinates with shape [W].
variable_names (list of str, optional) – Variable names or identifiers.
filename (str, optional) – Output filename for saving the plot.
save_dir (str, optional) – Directory to save the plot.
zoom_box (dict, optional) – Dictionary defining the zoom region with keys.
- Return type:
None
- IPSL_AID.diagnostics.plot_global_surface_robinson(predictions, targets, coarse_inputs, lat_1d, lon_1d, timestamp=None, variable_names=None, filename='global_robinson.png', save_dir=None, figsize_multiplier=None)[source]
Plot coarse, truth, prediction and difference fields in Robinson projection.
- Parameters:
coarse_inputs (torch.Tensor or np.ndarray) – Coarse-resolution input data with shape [1, n_vars, H, W].
targets (torch.Tensor or np.ndarray) – Ground-truth high-resolution data with shape [1, n_vars, H, W].
predictions (torch.Tensor or np.ndarray) – Model predictions at targets resolution with shape [1, n_vars, H, W].
lat_1d (array-like) – 1D array of latitude coordinates with shape [H].
lon_1d (array-like) – 1D array of longitude coordinates with shape [W].
timestamp (datetime.datetime) – Forecast timestamp to include in the plot title.
variable_names (list of str, optional) – Variable names or identifiers.
filename (str, optional) – Output filename for saving the plot.
save_dir (str, optional) – Directory to save the plot.
figsize_multiplier (int, optional) – Base size multiplier for subplots.
- Return type:
None
- IPSL_AID.diagnostics.plot_MAE_map(predictions, targets, lat_1d, lon_1d, timestamp=None, variable_names=None, filename='validation_mae_map.png', save_dir=None, figsize_multiplier=None)[source]
Plot spatial MAE maps averaged over all time steps: MAE(x, y) = mean_t(abs(prediction - target))
- Parameters:
predictions (torch.Tensor or np.array) – Model predictions of shape [batch_size, num_variables, h, w]
targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]
lat_1d (array-like) – 1D array of latitude coordinates with shape [H].
lon_1d (array-like) – 1D array of longitude coordinates with shape [W].
timestamp (datetime.datetime) – Forecast timestamp to include in the plot title.
variable_names (list of str, optional) – Variable names or identifiers.
filename (str, optional) – Output filename for saving the plot.
save_dir (str, optional) – Directory to save the plot.
figsize_multiplier (int, optional) – Base size multiplier for subplots.
- Return type:
None
- IPSL_AID.diagnostics.plot_error_map(predictions, targets, lat_1d, lon_1d, timestamp=None, variable_names=None, filename='validation_error_map.png', save_dir=None, figsize_multiplier=None)[source]
Plot spatial ERROR maps averaged over all time steps.
- Parameters:
predictions (torch.Tensor or np.array) – Model predictions of shape [batch_size, num_variables, h, w]
targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]
lat_1d (array-like) – 1D array of latitude coordinates with shape [H].
lon_1d (array-like) – 1D array of longitude coordinates with shape [W].
timestamp (datetime.datetime) – Forecast timestamp to include in the plot title.
variable_names (list of str, optional) – Variable names or identifiers.
filename (str, optional) – Output filename for saving the plot.
save_dir (str, optional) – Directory to save the plot.
figsize_multiplier (int, optional) – Base size multiplier for subplots.
- Return type:
None
- IPSL_AID.diagnostics.spread_skill_ratio(predictions, targets, variable_names, pixel_wise=False)[source]
Compute spread skill ratio of predictions with respect to targets. The formula implemented is equation (15) in “Why Should Ensemble Spread Match the RMSE of the Ensemble Mean?”, Fortin et al.
- Parameters:
predictions (torch.Tensor or np.array) – Model predictions of shape [ensemble_size, batch_size, num_variables, h, w]. Do not permute this dimension order. ensemble_size must be greater than or equal to 2 for the spread-skill ratio to be computed.
targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]
variable_names (list of str, optional) – Variable names or identifiers.
pixel_wise (bool) – If True, computes and returns the SSR for each pixel independently. If False, computes and returns the SSR averaged over all pixels and all time steps. Default is False.
- Returns:
np.array of shape [num_variables, h, w] if pixel_wise == True
or of shape [num_variables,] if pixel_wise == False (default)
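As a rough illustration of the quantity being computed, here is a minimal NumPy sketch of a spread-skill ratio. The function name and the exact spread/skill estimators are assumptions for illustration only, not the package implementation, which follows Eq. (15) of Fortin et al.:

```python
import numpy as np

def spread_skill_ratio_sketch(predictions, targets):
    """Hypothetical sketch. predictions: [ensemble, time, h, w]; targets: [time, h, w]."""
    # spread: standard deviation across ensemble members, averaged over time and space
    spread = predictions.std(axis=0, ddof=1).mean()
    # skill: RMSE of the ensemble-mean prediction
    ens_mean = predictions.mean(axis=0)
    skill = np.sqrt(((ens_mean - targets) ** 2).mean())
    return spread / skill
```

A well-calibrated ensemble has a ratio close to 1: values below 1 indicate an under-dispersive ensemble, values above 1 an over-dispersive one.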
- IPSL_AID.diagnostics.plot_spread_skill_ratio_map(predictions, targets, lat_1d, lon_1d, timestamp=None, variable_names=None, filename='validation_spread_skill_ratio_map.png', save_dir=None, figsize_multiplier=None)[source]
Plot spatial spread-skill ratio maps averaged over all time steps for each individual pixel. The formula implemented is Eq. (15) of Fortin et al., “Why Should Ensemble Spread Match the RMSE of the Ensemble Mean?”.
- Parameters:
predictions (torch.Tensor or np.array) – Model predictions of shape [ensemble_size, batch_size, num_variables, h, w]. Do not permute this dimension order. ensemble_size must be greater than or equal to 2 for the spread-skill ratio to be computed.
targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]
lat_1d (array-like) – 1D array of latitude coordinates with shape [H].
lon_1d (array-like) – 1D array of longitude coordinates with shape [W].
timestamp (datetime.datetime) – Forecast timestamp to include in the plot title.
variable_names (list of str, optional) – Variable names or identifiers.
filename (str, optional) – Output filename for saving the plot.
save_dir (str, optional) – Directory to save the plot.
figsize_multiplier (int, optional) – Base size multiplier for subplots.
- Return type:
None
- IPSL_AID.diagnostics.plot_spread_skill_ratio_hexbin(predictions, targets, variable_names=None, filename='validation_spread_skill_ratio_hexbin.png', save_dir=None, figsize_multiplier=None)[source]
Plot a spread-skill ratio scatterplot, where each point represents a prediction for a single pixel and a single time step: SSR(x, y) = spread(x, y) / skill(x, y), where spread(x, y) is the temporal mean of the standard deviation of the ensemble-member predictions and skill(x, y) is the temporal mean of the RMSE of the ensemble-mean prediction.
- Parameters:
predictions (torch.Tensor or np.array) – Model predictions of shape [ensemble_size, batch_size, num_variables, h, w]. Do not permute this dimension order. ensemble_size must be greater than or equal to 2 for the spread-skill ratio to be computed.
targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]
variable_names (list of str, optional) – Variable names or identifiers.
filename (str, optional) – Output filename for saving the plot.
save_dir (str, optional) – Directory to save the plot.
figsize_multiplier (int, optional) – Base size multiplier for subplots.
- Return type:
None
- IPSL_AID.diagnostics.plot_validation_pdfs(predictions, targets, coarse_inputs=None, variable_names=None, filename='validation_pdfs.png', save_dir='./results', figsize_multiplier=4, save_npz=False)[source]
Create PDF (Probability Density Function) plots comparing distributions of model predictions vs ground truth for all variables.
- Parameters:
predictions (torch.Tensor or np.array) – Model predictions of shape [batch_size, num_variables, h, w]
targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]
coarse_inputs (torch.Tensor or np.array, optional) – Coarse inputs of shape [batch_size, num_variables, h, w]
variable_names (list of str, optional) – Names of the variables for subplot titles
filename (str, optional) – Output filename
save_dir (str, optional) – Directory to save the plot
figsize_multiplier (int, optional) – Base size multiplier for subplots
save_npz (bool, optional) – If True, saves the PDF diagnostics to a compressed .npz file.
- Returns:
The function saves the plot to disk and does not return any value.
- Return type:
None
Notes
Creates horizontal subplots (one per variable) showing PDFs
Each subplot shows up to 3 lines: Predictions, Ground Truth, and Coarse Inputs
Uses automatic color and linestyle cycling based on global matplotlib settings
Calculates and displays key statistics for each distribution
Handles both PyTorch tensors and numpy arrays
Examples
>>> predictions = np.random.randn(10, 3, 64, 64)  # 10 samples, 3 variables
>>> targets = np.random.randn(10, 3, 64, 64)
>>> plot_validation_pdfs(predictions, targets, variable_names=['Temp', 'Pres', 'Humid'])
- IPSL_AID.diagnostics.plot_power_spectra(predictions, targets, dlat, dlon, coarse_inputs=None, variable_names=None, filename='power_spectra_physical.png', save_dir='./results', figsize_multiplier=4, save_npz=False)[source]
Calculate and plot power spectra with proper physical wavenumbers.
- Parameters:
predictions (torch.Tensor or np.array) – Model predictions of shape [batch_size, num_variables, nh, nw]
targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, nh, nw]
dlat (float) – Grid spacing in latitude (degrees)
dlon (float) – Grid spacing in longitude (degrees)
coarse_inputs (torch.Tensor or np.array, optional) – Coarse inputs of shape [batch_size, num_variables, nh, nw]
variable_names (list of str, optional) – Names of the variables for subplot titles
filename (str, optional) – Output filename
save_dir (str, optional) – Directory to save the plot
figsize_multiplier (int, optional) – Base size multiplier for subplots
save_npz (bool, optional) – If True, saves the power-spectrum diagnostics to a compressed .npz file.
- Return type:
None
- IPSL_AID.diagnostics.calculate_psd2d_simple(field)[source]
Simple 2D PSD calculation without preprocessing.
- IPSL_AID.diagnostics.radial_average_psd(psd2d, k_mag, k_bins)[source]
Radially average 2D PSD using wavenumber magnitude.
- IPSL_AID.diagnostics.plot_qq_quantiles(predictions, targets, coarse_inputs, variable_names=None, units=None, quantiles=[0.9, 0.95, 0.975, 0.99, 0.995], filename='qq_quantiles.png', save_dir='./results', figsize_multiplier=4, save_npz=False)[source]
Create QQ-plots at different quantiles comparing model predictions and coarse inputs against ground truth.
For each variable, plots quantiles of predictions and coarse inputs against quantiles of ground truth with a 1:1 reference line.
- Parameters:
predictions (torch.Tensor or np.array) – Model predictions of shape [batch_size, num_variables, h, w]
targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]
coarse_inputs (torch.Tensor or np.array) – Coarse inputs of shape [batch_size, num_variables, h, w]
variable_names (list of str, optional) – Names of the variables for subplot titles. If None, uses [“VAR_0”, “VAR_1”, …]
units (list of str, optional) – Units for each variable for axis labels. If None, uses empty strings.
quantiles (list of float, optional) – Quantile values to plot (e.g., [0.90, 0.95, 0.975, 0.99, 0.995])
filename (str, optional) – Output filename
save_dir (str, optional) – Directory to save the plot
figsize_multiplier (int, optional) – Base size multiplier for subplots
save_npz (bool, optional) – If True, saves the quantile diagnostics to a compressed .npz file.
- Returns:
save_path – Path to the saved figure
- Return type:
str
- IPSL_AID.diagnostics.dry_frequency_map(array, threshold)[source]
Compute a spatial map of the proportion of dry pixels. The value of each pixel is the frequency of dry weather at that pixel.
- Parameters:
array (torch.Tensor or np.array) – Model predictions of shape [batch_size, h, w]
threshold (float) – Precipitation threshold (in mm); below it, a pixel is considered dry.
- Return type:
np.ndarray(np.float64) of shape [h,w]
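A minimal NumPy sketch of this computation (the function name is hypothetical, and whether the threshold comparison is strict is an assumption; the package implementation may differ):

```python
import numpy as np

def dry_frequency_map_sketch(array, threshold):
    """Hypothetical sketch. array: [batch, h, w] of precipitation in mm.

    For each pixel, returns the fraction of time steps where the
    precipitation falls below `threshold`, i.e. the pixel is "dry".
    """
    return (array < threshold).mean(axis=0).astype(np.float64)
```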
- IPSL_AID.diagnostics.plot_dry_frequency_map(predictions, targets, threshold, lat_1d, lon_1d, filename='validation_dry_frequency_map.png', save_dir=None, figsize_multiplier=None)[source]
Plot spatial maps of the proportion of dry pixels. The value of each pixel is the frequency of dry weather at that pixel.
- Parameters:
predictions (torch.Tensor or np.array) – Model predictions of shape [batch_size, h, w]
targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, h, w]
threshold (float) – Precipitation threshold (in mm); below it, a pixel is considered dry.
lat_1d (array-like) – 1D array of latitude coordinates with shape [H].
lon_1d (array-like) – 1D array of longitude coordinates with shape [W].
filename (str, optional) – Output filename for saving the plot.
save_dir (str, optional) – Directory to save the plot.
figsize_multiplier (int, optional) – Base size multiplier for subplots.
- Return type:
None
- IPSL_AID.diagnostics.calculate_pearsoncorr_nparray(arr1, arr2, axis=0)[source]
Calculate Pearson correlation between 2 N-dimensional numpy arrays.
- Parameters:
arr1 (numpy.ndarray) – First N-dimensional array
arr2 (numpy.ndarray) – Second N-dimensional array (must have same shape as arr1)
axis (int or tuple of int, default=0) – Axis or tuple of axes over which to compute the correlation
- Returns:
Pearson correlation coefficients. Output has N - len(axis) dimensions (input shape with the specified axis/axes removed).
- Return type:
numpy.ndarray
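A minimal NumPy sketch of a Pearson correlation reduced over a chosen axis (hypothetical name; a sketch of the formula, not the package implementation):

```python
import numpy as np

def pearson_along_axis(arr1, arr2, axis=0):
    """Hypothetical sketch: Pearson correlation collapsing `axis`.

    `axis` may be an int or a tuple of ints; the output keeps the
    remaining dimensions.
    """
    a = arr1 - arr1.mean(axis=axis, keepdims=True)  # center along axis
    b = arr2 - arr2.mean(axis=axis, keepdims=True)
    cov = (a * b).sum(axis=axis)
    return cov / np.sqrt((a ** 2).sum(axis=axis) * (b ** 2).sum(axis=axis))
```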
- IPSL_AID.diagnostics.plot_validation_mvcorr_space(predictions, targets, coarse_inputs=None, variable_names=None, filename='validation_mvcorr_space.png', save_dir='./results', figsize_multiplier=4)[source]
Compute multivariate correlation over the space dimensions and plot as time-series, comparing model predictions vs ground truth, for all combinations of variables. Uses Pearson’s correlation coefficient.
- Parameters:
predictions (torch.Tensor or np.array) – Model predictions of shape [batch_size, num_variables, h, w]
targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]
coarse_inputs (torch.Tensor or np.array, optional) – Coarse inputs of shape [batch_size, num_variables, h, w]
variable_names (list of str, optional) – Names of the variables for subplot titles
filename (str, optional) – Output filename
save_dir (str, optional) – Directory to save the plot
figsize_multiplier (int, optional) – Base size multiplier for subplots
- Returns:
save_path – Path to the saved figure
- Return type:
str
- IPSL_AID.diagnostics.plot_validation_mvcorr(predictions, targets, lat, lon, coarse_inputs=None, variable_names=None, filename='validation_mvcorr_time.png', save_dir='./results', figsize_multiplier=4)[source]
Compute multivariate correlation over the time dimension and plot as maps, comparing model predictions vs ground truth, for all combinations of variables. Uses Pearson’s correlation coefficient.
- Parameters:
predictions (torch.Tensor or np.array) – Model predictions of shape [batch_size, num_variables, h, w]
targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]
lat (array-like) – 2D array of latitude coordinates with shape [h, w].
lon (array-like) – 2D array of longitude coordinates with shape [h, w].
coarse_inputs (torch.Tensor or np.array, optional) – Coarse inputs of shape [batch_size, num_variables, h, w]
variable_names (list of str, optional) – Names of the variables for subplot titles
filename (str, optional) – Output filename
save_dir (str, optional) – Directory to save the plot
figsize_multiplier (int, optional) – Base size multiplier for subplots
- Returns:
save_path – Path to the saved figure
- Return type:
str
- IPSL_AID.diagnostics.plot_temporal_series_comparison(predictions, targets, coarse_inputs=None, variable_names=None, filename='validation_temp_series.png', save_dir='./results', figsize_multiplier=4)[source]
Plot spatially averaged temporal series for each variable.
- Parameters:
predictions (torch.Tensor or np.array) – Model predictions of shape [batch_size, num_variables, h, w]
targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]
coarse_inputs (torch.Tensor or np.array, optional) – Coarse inputs of shape [batch_size, num_variables, h, w]
variable_names (list of str, optional) – Names of the variables for subplot titles
filename (str, optional) – Output filename
save_dir (str, optional) – Directory to save the plot
figsize_multiplier (int, optional) – Base size multiplier for subplots
- Returns:
save_path – Path to the saved figure
- Return type:
str
- IPSL_AID.diagnostics.ranks(predictions, targets)[source]
Compute ranks of predictions compared to targets.
- Parameters:
predictions (torch.Tensor or np.array) – Model predictions of shape [ensemble_size, batch_size, h, w]
targets (torch.Tensor or np.array) – Targets of shape [batch_size, h, w]
- Return type:
np.ndarray(np.float64) of shape [batch_size*h*w,]
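A minimal NumPy sketch of rank computation (hypothetical name; real implementations often randomize tied ranks, which this sketch omits):

```python
import numpy as np

def ranks_sketch(predictions, targets):
    """Hypothetical sketch. predictions: [ensemble, batch, h, w]; targets: [batch, h, w].

    For every pixel, the rank is the number of ensemble members that
    fall below the observed value; flattening gives one rank per pixel.
    """
    ranks = (predictions < targets[None]).sum(axis=0)
    return ranks.reshape(-1).astype(np.float64)
```

A flat rank histogram over many cases indicates a well-calibrated ensemble; U-shaped or dome-shaped histograms indicate under- or over-dispersion.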
- IPSL_AID.diagnostics.plot_ranks(predictions, targets, variable_names=None, filename='ranks.png', save_dir='./results', figsize_multiplier=4)[source]
Create rank histograms of predictions compared to targets for each variable.
- Parameters:
predictions (torch.Tensor or np.array) – Model predictions of shape [ensemble_size, batch_size, num_variables, h, w]
targets (torch.Tensor or np.array) – Ground truth of shape [batch_size, num_variables, h, w]
variable_names (list of str, optional) – Names of the variables for subplot titles. If None, uses [“VAR_0”, “VAR_1”, …]
filename (str, optional) – Output filename
save_dir (str, optional) – Directory to save the plot
figsize_multiplier (int, optional) – Base size multiplier for subplots
- Returns:
save_path – Path to the saved figure
- Return type:
str
- IPSL_AID.diagnostics.get_divergence(u_tensor, v_tensor, spacing)[source]
Compute the horizontal divergence of a windfield.
- Parameters:
u_tensor (torch.Tensor or np.array, shape [..., h, w]) – Tensor storing the zonal component of the wind field. Can have an arbitrary number of dimensions, but the last two must correspond to longitude and latitude. u_tensor and v_tensor must have the same shape.
v_tensor (torch.Tensor or np.array) – Tensor storing the meridional component of the wind field. Can have an arbitrary number of dimensions, but the last two must correspond to longitude and latitude. u_tensor and v_tensor must have the same shape.
spacing (float) – Spatial resolution of the wind field, used to compute the gradients.
- Return type:
np.ndarray(np.float64) of same shape as u_tensor and v_tensor
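A minimal NumPy sketch of horizontal divergence, du/dx + dv/dy, on a uniform grid (hypothetical name; the package implementation may use a different finite-difference scheme):

```python
import numpy as np

def divergence_sketch(u, v, spacing):
    """Hypothetical sketch. u, v: [..., h, w] wind components.

    Central-difference gradients over the last two axes with a
    uniform grid spacing (np.gradient).
    """
    dudx = np.gradient(u, spacing, axis=-1)  # d(u)/d(lon)
    dvdy = np.gradient(v, spacing, axis=-2)  # d(v)/d(lat)
    return dudx + dvdy
```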
- IPSL_AID.diagnostics.get_curl(u_tensor, v_tensor, spacing)[source]
Compute the curl of a windfield.
- Parameters:
u_tensor (torch.Tensor or np.array, shape [..., h, w]) – Tensor storing the zonal component of the wind field. Can have an arbitrary number of dimensions, but the last two must correspond to longitude and latitude. u_tensor and v_tensor must have the same shape.
v_tensor (torch.Tensor or np.array) – Tensor storing the meridional component of the wind field. Can have an arbitrary number of dimensions, but the last two must correspond to longitude and latitude. u_tensor and v_tensor must have the same shape.
spacing (float) – Spatial resolution of the wind field, used to compute the gradients.
- Return type:
np.ndarray(np.float64) of same shape as u_tensor and v_tensor
- IPSL_AID.diagnostics.plot_mean_divergence_map(u_prediction, v_prediction, u_target, v_target, spacing, lat_1d, lon_1d, filename='mean_divergence.png', save_dir=None, figsize_multiplier=None)[source]
Plot spatial maps of the time-mean horizontal divergence of the wind field, for predictions and targets.
- Parameters:
u_prediction (torch.Tensor or np.array) – Model predictions of shape [batch_size, h, w] for the zonal wind component. The last two dims must correspond to longitude and latitude; u_prediction and v_prediction must have the same shape.
v_prediction (torch.Tensor or np.array) – Model predictions of shape [batch_size, h, w] for the meridional wind component. The last two dims must correspond to longitude and latitude; u_prediction and v_prediction must have the same shape.
u_target (torch.Tensor or np.array) – Ground truth of shape [batch_size, h, w]. The last two dims must correspond to longitude and latitude; u_target and v_target must have the same shape.
v_target (torch.Tensor or np.array) – Ground truth of shape [batch_size, h, w]. The last two dims must correspond to longitude and latitude; u_target and v_target must have the same shape.
spacing (float) – spatial resolution of the windfield. Used to compute the gradients.
lat_1d (array-like) – 1D array of latitude coordinates with shape [H].
lon_1d (array-like) – 1D array of longitude coordinates with shape [W].
filename (str, optional) – Output filename for saving the plot.
save_dir (str, optional) – Directory to save the plot.
figsize_multiplier (int, optional) – Base size multiplier for subplots.
- Return type:
None
- IPSL_AID.diagnostics.plot_mean_curl_map(u_prediction, v_prediction, u_target, v_target, spacing, lat_1d, lon_1d, filename='mean_curl.png', save_dir=None, figsize_multiplier=None)[source]
Plot spatial maps of the time-mean curl of the wind field, for predictions and targets.
- Parameters:
u_prediction (torch.Tensor or np.array) – Model predictions of shape [batch_size, h, w] for the zonal wind component. The last two dims must correspond to longitude and latitude; u_prediction and v_prediction must have the same shape.
v_prediction (torch.Tensor or np.array) – Model predictions of shape [batch_size, h, w] for the meridional wind component. The last two dims must correspond to longitude and latitude; u_prediction and v_prediction must have the same shape.
u_target (torch.Tensor or np.array) – Ground truth of shape [batch_size, h, w]. The last two dims must correspond to longitude and latitude; u_target and v_target must have the same shape.
v_target (torch.Tensor or np.array) – Ground truth of shape [batch_size, h, w]. The last two dims must correspond to longitude and latitude; u_target and v_target must have the same shape.
spacing (float) – spatial resolution of the windfield. Used to compute the gradients.
lat_1d (array-like) – 1D array of latitude coordinates with shape [H].
lon_1d (array-like) – 1D array of longitude coordinates with shape [W].
filename (str, optional) – Output filename for saving the plot.
save_dir (str, optional) – Directory to save the plot.
figsize_multiplier (int, optional) – Base size multiplier for subplots.
- Return type:
None
- class IPSL_AID.diagnostics.TestPlottingFunctions(methodName='runTest', logger=None)[source]
Bases: TestCase
Unit tests for plotting functions with visible output for styling adjustment.
- __init__(methodName='runTest', logger=None)[source]
Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.
- test_spatiotemporal_histograms_comprehensive()[source]
Comprehensive test for spatiotemporal histograms.
- test_plot_global_surface_robinson_comprehensive()[source]
Comprehensive test for global Robinson surface plots.
- test_plot_mae_map_comprehensive()[source]
Comprehensive test for time-averaged MAE spatial map plots.
- test_plot_error_map_comprehensive()[source]
Comprehensive test for time-averaged ERROR spatial map plots.
- test_plot_spread_skill_ratio_map_comprehensive()[source]
Comprehensive test for time-averaged spread-skill ratio spatial map plots.
- test_plot_spread_skill_ratio_hexbin_comprehensive()[source]
Comprehensive test for spread-skill ratio hexbin plots.
- test_plot_mean_divergence_map_comprehensive()[source]
Comprehensive test for mean divergence map plots.
- test_plot_dry_frequency_map_comprehensive()[source]
Comprehensive test for dry frequency map plots.
- test_plot_metrics_heatmap_comprehensive()[source]
Comprehensive test for validation metrics heatmap.
- test_mv_correlation()[source]
Test for correlation over the time dimension for pairs of variables. Test for correlation over the spatial dimensions.
- class IPSL_AID.diagnostics.TestSSRFunction(methodName='runTest', logger=None)[source]
Bases: TestCase
Unit tests for crps_ensemble_all function.
IPSL_AID.evaluater module
- class IPSL_AID.evaluater.MetricTracker[source]
Bases: object
A utility class for tracking and computing statistics of metric values.
This class maintains a running average of metric values and provides methods to compute mean and root mean squared values.
Examples
>>> tracker = MetricTracker()
>>> tracker.update(10.0, 5)  # value=10.0, count=5 samples
>>> tracker.update(20.0, 3)  # value=20.0, count=3 samples
>>> print(tracker.getmean())  # (10*5 + 20*3) / (5+3) = 110/8 = 13.75
13.75
>>> print(tracker.getsqrtmean())  # sqrt(13.75)
3.7080992435478315
- getmean()[source]
Calculate the mean of all tracked values.
- Returns:
Weighted mean of all values: total_value / total_count
- Return type:
float
- Raises:
ZeroDivisionError – If no values have been added (count == 0)
- getstd()[source]
Calculate the standard deviation of all tracked values.
- Returns:
Weighted standard deviation of all values: sqrt(E(x^2) - (E(x))^2)
- Return type:
float
- Raises:
ZeroDivisionError – If no values have been added (count == 0)
- getsqrtmean()[source]
Calculate the square root of the mean of all tracked values.
- Returns:
Square root of the weighted mean: sqrt(total_value / total_count)
- Return type:
float
- Raises:
ZeroDivisionError – If no values have been added (count == 0)
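The interface above can be pictured with a minimal sketch (a hypothetical re-implementation based on the documented Examples; the real class also provides getstd, which would additionally need a running sum of squared values):

```python
import math

class MetricTrackerSketch:
    """Hypothetical minimal tracker with a weighted running mean."""

    def __init__(self):
        self.total = 0.0  # running sum of value * count
        self.count = 0    # total number of samples

    def update(self, value, count):
        self.total += value * count
        self.count += count

    def getmean(self):
        # Raises ZeroDivisionError if no values have been added
        return self.total / self.count

    def getsqrtmean(self):
        return math.sqrt(self.getmean())
```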
- IPSL_AID.evaluater.mae_all(pred, true)[source]
Calculate Mean Absolute Error (MAE) between predicted and true values.
Computes the MAE metric and returns both the number of elements and the mean absolute error value.
- Parameters:
pred (torch.Tensor) – Predicted values from the model
true (torch.Tensor) – Ground truth values
- Returns:
(num_elements, mae_value) where: - num_elements (int): Total number of elements in the tensors - mae_value (torch.Tensor): Mean absolute error value
- Return type:
tuple
Examples
>>> pred = torch.tensor([1.0, 2.0, 3.0])
>>> true = torch.tensor([1.1, 1.9, 3.2])
>>> num_elements, mae = mae_all(pred, true)
>>> print(f"MAE: {mae.item():.4f}, Elements: {num_elements}")
MAE: 0.1333, Elements: 3
Notes
The MAE is calculated as: mean(abs(pred - true)) This function is useful for tracking metrics with MetricTracker
- IPSL_AID.evaluater.nmae_all(pred, true, eps=1e-08)[source]
Normalized Mean Absolute Error (NMAE). NMAE = MAE(pred, true) / mean(abs(true))
Computes the NMAE metric and returns both the number of elements and the normalized mean absolute error value.
- Parameters:
pred (torch.Tensor) – Predicted values from the model
true (torch.Tensor) – Ground truth values
eps (float) – Small value to avoid division by zero
- Returns:
(num_elements, nmae_value) where: - num_elements (int): Total number of elements in the tensors - nmae_value (torch.Tensor): Normalized mean absolute error value
- Return type:
tuple
Examples
>>> pred = torch.tensor([1.0, 2.0, 3.0])
>>> true = torch.tensor([1.1, 1.9, 3.2])
>>> num_elements, nmae = nmae_all(pred, true)
>>> print(f"NMAE: {nmae.item():.4f}, Elements: {num_elements}")
NMAE: 0.0645, Elements: 3
Notes
The NMAE is calculated as: MAE(pred, true) / mean(abs(true)) This function is useful for tracking metrics with MetricTracker
- IPSL_AID.evaluater.crps_ensemble_all(pred_ens, true)[source]
Continuous Ranked Probability Score (CRPS) for an ensemble.
Computes the CRPS metric for ensemble predictions and returns both the number of elements and the mean CRPS value.
- Parameters:
pred_ens (torch.Tensor) – Ensemble predictions, shape [N_ens, N_pixels]
true (torch.Tensor) – Ground truth values, shape [N_pixels]
- Returns:
(num_elements, crps_mean) where: - num_elements (int): Total number of elements in the tensors - crps_mean (torch.Tensor): Mean CRPS
- Return type:
tuple
Notes
The CRPS for an ensemble is computed as:
CRPS = E|X - y| - 0.5 * E|X - X’|
where X and X’ are independent ensemble members and y is the observation.
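A minimal NumPy sketch of this estimator (hypothetical name; note that it averages E|X - X'| over all member pairs including the zero diagonal, whereas the package may use a "fair" variant over distinct pairs):

```python
import numpy as np

def crps_ensemble_sketch(pred_ens, true):
    """Hypothetical sketch. pred_ens: [N_ens, N_pixels]; true: [N_pixels].

    CRPS = E|X - y| - 0.5 * E|X - X'|, averaged over pixels.
    """
    term1 = np.abs(pred_ens - true[None]).mean(axis=0)                       # E|X - y|
    term2 = np.abs(pred_ens[:, None] - pred_ens[None, :]).mean(axis=(0, 1))  # E|X - X'|
    return (term1 - 0.5 * term2).mean()
```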
- IPSL_AID.evaluater.rmse_all(pred, true)[source]
Calculate Root Mean Square Error (RMSE) between predicted and true values.
Computes the RMSE metric and returns both the number of elements and the root mean square error value.
- Parameters:
pred (torch.Tensor) – Predicted values from the model
true (torch.Tensor) – Ground truth values
- Returns:
(num_elements, rmse_value) where: - num_elements (int): Total number of elements in the tensors - rmse_value (torch.Tensor): Root mean square error value
- Return type:
tuple
Examples
>>> pred = torch.tensor([1.0, 2.0, 3.0])
>>> true = torch.tensor([1.1, 1.9, 3.2])
>>> num_elements, rmse = rmse_all(pred, true)
>>> print(f"RMSE: {rmse.item():.4f}, Elements: {num_elements}")
RMSE: 0.1414, Elements: 3
Notes
The RMSE is calculated as: sqrt(mean((pred - true)^2)) This function is useful for tracking metrics with MetricTracker
- IPSL_AID.evaluater.r2_all(pred, true)[source]
Calculate R2 (coefficient of determination) between predicted and true values.
Computes the R2 metric and returns both the number of elements and the R2 value.
- Parameters:
pred (torch.Tensor) – Predicted values from the model
true (torch.Tensor) – Ground truth values
- Returns:
(num_elements, r2_value) where: - num_elements (int): Total number of elements in the tensors - r2_value (torch.Tensor): R2 score
- Return type:
tuple
Notes
R2 is calculated as:
R2 = 1 - sum((true - pred)^2) / sum((true - mean(true))^2)
This implementation is fully torch-based and works on CPU and GPU.
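A NumPy sketch of the formula above (hypothetical name; a direct transcription of the R2 definition, not the torch-based package code):

```python
import numpy as np

def r2_sketch(pred, true):
    """Hypothetical sketch: R2 = 1 - SS_res / SS_tot."""
    ss_res = ((true - pred) ** 2).sum()          # residual sum of squares
    ss_tot = ((true - true.mean()) ** 2).sum()   # total sum of squares
    return 1.0 - ss_res / ss_tot
```

A perfect prediction gives R2 = 1; always predicting the mean of the targets gives R2 = 0, and worse-than-mean predictions give negative values.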
- IPSL_AID.evaluater.pearson_all(pred, true)[source]
Compute the Pearson correlation coefficient between predicted and ground truth values using torch.corrcoef.
- Parameters:
pred (torch.Tensor) – Predicted values from the model.
true (torch.Tensor) – Ground truth values.
- Returns:
(num_elements, pearson_value) where: - num_elements (int): Total number of elements in the tensors. - pearson_value (torch.Tensor): Pearson correlation coefficient.
- Return type:
tuple
Notes
The Pearson correlation coefficient is defined as:
rho = Cov(pred, true) / (std(pred) * std(true))
- IPSL_AID.evaluater.kl_divergence_all(pred, true)[source]
Compute the Kullback–Leibler (KL) divergence between predicted and ground truth distributions using histogram-based estimation.
- Parameters:
pred (torch.Tensor) – Predicted values from the model.
true (torch.Tensor) – Ground truth values.
- Returns:
(num_elements, kl_value) where: - num_elements (int): Total number of elements in the tensors. - kl_value (torch.Tensor): KL divergence value.
- Return type:
tuple
Notes
The KL divergence is defined as:
KL(P|Q) = sum_i P_i * log(P_i / Q_i)
- where:
P represents the true distribution
Q represents the predicted distribution
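A minimal NumPy sketch of a histogram-based KL estimate (hypothetical name; bin count, range handling, and the epsilon smoothing are assumptions, not the package's choices):

```python
import numpy as np

def kl_divergence_hist_sketch(pred, true, bins=50, eps=1e-12):
    """Hypothetical sketch: KL(P || Q) with P from `true`, Q from `pred`,
    both estimated as histograms on a shared set of bins."""
    lo = min(pred.min(), true.min())
    hi = max(pred.max(), true.max())
    p, _ = np.histogram(true, bins=bins, range=(lo, hi))
    q, _ = np.histogram(pred, bins=bins, range=(lo, hi))
    p = p / p.sum() + eps  # normalize to probabilities, avoid log(0)
    q = q / q.sum() + eps
    return float((p * np.log(p / q)).sum())
```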
- IPSL_AID.evaluater.denormalize(data, stats, norm_type, device, var_name=None, data_type=None, debug=False, logger=None)[source]
Denormalize a data tensor using the inverse of the normalization operation.
- Parameters:
data (torch.Tensor) – Normalized tensor to denormalize.
stats (object) – Object containing the required statistics.
norm_type (str) – Normalization type used originally.
device (torch.device) – Device for tensor operations.
var_name (str, optional) – Variable name for debugging.
data_type (str, optional) – Data type for debugging (e.g., “residual”, “coarse”).
debug (bool, optional) – Enable debug logging.
logger (Logger, optional) – Logger instance for debug output.
- IPSL_AID.evaluater.edm_sampler(model, image_input, class_labels=None, num_steps=40, sigma_min=0.02, sigma_max=80.0, rho=7, S_churn=40, S_min=0, S_max=inf, S_noise=1)
EDM sampler for diffusion model inference. Original work: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. Original source: https://github.com/NVlabs/edm
- Parameters:
model (torch.nn.Module) – Diffusion model
image_input (torch.Tensor) – Conditioning input (coarse + constants)
class_labels (torch.Tensor, optional) – Time conditioning labels
num_steps (int, optional) – Number of sampling steps
sigma_min (float, optional) – Minimum noise level
sigma_max (float, optional) – Maximum noise level
rho (float, optional) – Time step exponent
S_churn (int, optional) – Stochasticity parameter
S_min (float, optional) – Minimum stochasticity threshold
S_max (float, optional) – Maximum stochasticity threshold
S_noise (float, optional) – Noise scale for stochasticity
- Returns:
Generated residual predictions
- Return type:
torch.Tensor
- IPSL_AID.evaluater.sampler(epoch, batch_idx, model, image_input, class_labels=None, num_steps=18, sigma_min=None, sigma_max=None, rho=7, solver='heun', discretization='edm', schedule='linear', scaling='none', epsilon_s=0.001, C_1=0.001, C_2=0.008, M=1000, alpha=1, S_churn=40, S_min=0, S_max=inf, S_noise=1, logger=None)
General sampler for diffusion model inference with multiple configurations. Original work: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. Original source: https://github.com/NVlabs/edm
- Parameters:
model (torch.nn.Module) – Diffusion model
image_input (torch.Tensor) – Conditioning input (coarse + constants)
class_labels (torch.Tensor, optional) – Time conditioning labels
num_steps (int, optional) – Number of sampling steps
sigma_min (float, optional) – Minimum noise level
sigma_max (float, optional) – Maximum noise level
rho (float, optional) – Time step exponent for EDM discretization
solver (str, optional) – Solver type: ‘euler’ or ‘heun’
discretization (str, optional) – Discretization type: ‘vp’, ‘ve’, ‘iddpm’, or ‘edm’
schedule (str, optional) – Noise schedule: ‘vp’, ‘ve’, or ‘linear’
scaling (str, optional) – Scaling type: ‘vp’ or ‘none’
epsilon_s (float, optional) – Small epsilon for VP schedule
C_1 (float, optional) – Constant for IDDPM discretization
C_2 (float, optional) – Constant for IDDPM discretization
M (int, optional) – Number of steps for IDDPM discretization
alpha (float, optional) – Parameter for Heun’s method
S_churn (int, optional) – Stochasticity parameter
S_min (float, optional) – Minimum stochasticity threshold
S_max (float, optional) – Maximum stochasticity threshold
S_noise (float, optional) – Noise scale for stochasticity
logger (logging.Logger, optional) – Logger instance for logging sampler parameters
- Returns:
Generated residual predictions
- Return type:
torch.Tensor
- IPSL_AID.evaluater.reconstruct_original_layout(epoch, args, paths, steps, all_data, dataset, device, logger)[source]
Robust reconstruction using dataset information directly.
- Parameters:
all_data (dict) – Dictionary containing lists of batches for: 'predictions' (model predictions [B, C, H, W]), 'coarse' (coarse-resolution data [B, C, H, W]), 'fine' (fine-resolution ground truth [B, C, H, W]), 'lat' (latitude coordinates [B, H]), 'lon' (longitude coordinates [B, W])
dataset (torch.utils.data.Dataset) – The validation dataset instance
device (torch.device) – Device to store tensors on
logger (Logger) – Logger instance for logging
- Returns:
Reconstructed data with metadata
- Return type:
dict
- IPSL_AID.evaluater.generate_residuals_norm(model, features, labels, targets, loss_fn, args, device, logger, epoch=0, batch_idx=0, inference_type='sampler')[source]
Generate normalized residuals for all variables.
- Parameters:
model (torch.nn.Module) – Diffusion model
features (torch.Tensor) – Input feature tensor provided to the model
labels (torch.Tensor) – Conditioning labels provided to the model
targets (torch.Tensor) – Ground truth target tensor used for noise injection in direct inference
loss_fn (callable) – Loss function
args (argparse.Namespace) – Command line arguments
device (torch.device) – Training device
logger (Logger) – Logger instance
epoch (int, optional) – Current epoch number
batch_idx (int, optional) – Current batch index
inference_type (str, optional) – Inference mode, either “direct” (deterministic) or “sampler” (stochastic diffusion sampling)
- Returns:
[B, C, H, W] residuals in normalized space
- Return type:
torch.Tensor
- IPSL_AID.evaluater.run_validation(model, valid_dataset, valid_loader, loss_fn, norm_mapping, normalization_type, index_mapping, args, steps, device, logger, epoch, writer=None, plot_every_n_epochs=None, paths=None, compute_crps=False, crps_batch_size=2, crps_ensemble_size=10)[source]
Run validation on the model.
- Parameters:
model (torch.nn.Module) – Diffusion model
valid_dataset (torch.utils.data.Dataset) – Validation dataset instance
valid_loader (DataLoader) – Validation data loader
loss_fn (callable) – Loss function
norm_mapping (dict) – Normalization statistics
normalization_type (EasyDict) – Normalization types for each variable
index_mapping (dict) – Variable-to-index mapping
steps (EasyDict) – Grid dimension information
args (argparse.Namespace) – Command line arguments
device (torch.device) – Training device
logger (Logger) – Logger instance
epoch (int) – Current epoch number
writer (SummaryWriter, optional) – TensorBoard writer
plot_every_n_epochs (int, optional) – Frequency (in epochs) at which validation plots are generated
paths (dict, optional) – Paths used for saving reconstructions and plots
compute_crps (bool, optional) – Whether to compute CRPS using stochastic ensemble sampling
crps_batch_size (int, optional) – Number of validation batches used for CRPS computation
crps_ensemble_size (int, optional) – Number of ensemble members used to estimate CRPS
- Returns:
(avg_val_loss, val_metrics) - average validation loss and metrics dictionary
- Return type:
tuple
- class IPSL_AID.evaluater.TestMetricTracker(methodName='runTest', logger=None)[source]
Bases: TestCase
Unit tests for MetricTracker class.
- class IPSL_AID.evaluater.TestErrorMetrics(methodName='runTest', logger=None)[source]
Bases: TestCase
Unit tests for error metrics.
- __init__(methodName='runTest', logger=None)[source]
Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.
- class IPSL_AID.evaluater.TestCRPSFunction(methodName='runTest', logger=None)[source]
Bases: TestCase
Unit tests for crps_ensemble_all function.
- __init__(methodName='runTest', logger=None)[source]
Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.
- class IPSL_AID.evaluater.TestDenormalizeFunction(methodName='runTest', logger=None)[source]
Bases: TestCase
Unit tests for denormalize function.
- class IPSL_AID.evaluater.TestRunValidation(methodName='runTest', logger=None)[source]
Bases: TestCase
Unit tests for run_validation function focusing on return values verification.
- __init__(methodName='runTest', logger=None)[source]
Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.
- test_val_loss_and_metrics_across_3_batches_consistent_shape()[source]
Verify avg_val_loss and val_metrics with 3 batches of consistent shape.
- test_crps_zero_when_predictions_equal_fine()[source]
Verify that CRPS is zero when all ensemble predictions exactly match the fine target.
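The zero-CRPS property verified by this test follows from the standard empirical ensemble CRPS estimator, CRPS = E|X − y| − ½·E|X − X′|. A minimal pure-Python sketch (crps_ensemble here is an illustrative helper, not the package's crps_ensemble_all):

```python
def crps_ensemble(members, obs):
    """Empirical CRPS for one scalar observation.

    CRPS = E|X - obs| - 0.5 * E|X - X'| over ensemble members X, X'.
    Zero when every member equals the observation.
    """
    m = len(members)
    term1 = sum(abs(x - obs) for x in members) / m
    term2 = sum(abs(x - y) for x in members for y in members) / (m * m)
    return term1 - 0.5 * term2
```

When all members equal the target, both expectations vanish and the score is exactly zero, which is the behavior the test checks.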
IPSL_AID.logger module
- class IPSL_AID.logger.Logger(console_output=True, file_output=False, log_file='module_log_file.log', pretty_print=True, record=False)[source]
Bases: object
- __init__(console_output=True, file_output=False, log_file='module_log_file.log', pretty_print=True, record=False)[source]
- start_task(task_name: str, description: str = '', **meta)[source]
Display a clearly formatted ‘task start’ message with good spacing.
- class IPSL_AID.logger.TestLogger(methodName='runTest', logger=None)[source]
Bases: TestCase
Unit tests for Logger class.
- __init__(methodName='runTest', logger=None)[source]
Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.
IPSL_AID.loss module
Diffusion model loss functions and testing utilities.
This module implements various loss functions for diffusion models including:
- VPLoss: Variance Preserving loss from Score-Based Generative Modeling
- VELoss: Variance Exploding loss from Score-Based Generative Modeling
- EDMLoss: Improved loss from Elucidating the Design Space of Diffusion-Based Generative Models
- class IPSL_AID.loss.VPLoss(beta_d=19.9, beta_min=0.1, epsilon_t=1e-05)[source]
Bases: object
Loss function for Variance Preserving (VP) formulation diffusion models.
This class implements the loss function for the Variance Preserving SDE formulation of diffusion models. It follows the continuous-time training objective from score-based generative modeling through stochastic differential equations.
- Parameters:
beta_d (float, optional) – Maximum β parameter controlling the extent of the noise schedule. Larger values lead to faster noise increase. Default is 19.9.
beta_min (float, optional) – Minimum β parameter controlling the initial slope of the noise schedule. Default is 0.1.
epsilon_t (float, optional) – Minimum time value threshold to avoid numerical issues near t=0. Default is 1e-5.
- __call__(net, images, conditional_img=None, labels=None, augment_pipe=None)[source]
Compute the VP loss for a batch of images.
Notes
The loss is based on denoising score matching: E[λ(t) * ||D_θ(x_t, t) - x_0||²]
The weighting function λ(t) = 1/σ(t)² gives equal importance to all noise levels.
Time t is uniformly sampled between [epsilon_t, 1] during training.
This loss corresponds to training the model to predict the clean image x_0 from noisy input x_t = x_0 + σ(t)·ε.
References
Song et al., “Score-Based Generative Modeling through Stochastic Differential Equations”, 2020.
- sigma(t)[source]
Compute noise level sigma for given timestep t.
- Parameters:
t (torch.Tensor or float) – Timestep value(s) in [epsilon_t, 1].
- Returns:
Noise level sigma corresponding to t, with same shape as input.
- Return type:
torch.Tensor
Notes
The noise schedule follows: σ(t) = sqrt(exp(0.5*β_d*t² + β_min*t) - 1)
This ensures smooth transition from low to high noise levels, with σ(0) ≈ 0 and σ(1) determined by β_d and β_min.
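The schedule in the note above can be written as a small helper (illustrative only; the class computes this internally via its sigma method):

```python
import math

def vp_sigma(t, beta_d=19.9, beta_min=0.1):
    # sigma(t) = sqrt(exp(0.5 * beta_d * t**2 + beta_min * t) - 1)
    # so sigma(0) = 0 and sigma grows monotonically with t.
    return math.sqrt(math.exp(0.5 * beta_d * t * t + beta_min * t) - 1.0)
```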
- class IPSL_AID.loss.VELoss(sigma_min=0.02, sigma_max=100)[source]
Bases: object
Loss function for Variance Exploding (VE) formulation diffusion models.
This class implements the loss function for the Variance Exploding SDE formulation of diffusion models. It follows the continuous-time training objective from score-based generative modeling through stochastic differential equations.
- Parameters:
sigma_min (float, optional) – Minimum noise level. Controls the lower bound of the noise schedule. Smaller values allow modeling finer details. Default is 0.02.
sigma_max (float, optional) – Maximum noise level. Controls the upper bound of the noise schedule. Larger values allow modeling broader structure. Default is 100.
- __call__(net, images, conditional_img=None, labels=None, augment_pipe=None)[source]
Compute the VE loss for a batch of images.
Notes
The VE formulation uses a geometric noise schedule: σ(t) = σ_min * (σ_max/σ_min)^t
Time t is uniformly sampled between [0, 1] during training.
The weighting function λ(t) = 1/σ(t)² gives more emphasis to lower noise levels.
This corresponds to training the model to predict the clean image x_0 from noisy input x_t = x_0 + σ(t)·ε.
The geometric schedule provides a simple and effective way to span a wide range of noise levels with a single parameter.
References
Song et al., “Score-Based Generative Modeling through Stochastic Differential Equations”, 2020.
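The geometric schedule described in the notes interpolates between σ_min and σ_max in log space; a sketch with the class defaults (illustrative only):

```python
def ve_sigma(t, sigma_min=0.02, sigma_max=100.0):
    # Geometric noise schedule: sigma(0) = sigma_min, sigma(1) = sigma_max,
    # log-linear in between.
    return sigma_min * (sigma_max / sigma_min) ** t
```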
- class IPSL_AID.loss.EDMLoss(P_mean=-1.2, P_std=1.2, sigma_data=1.0)[source]
Bases: object
EDM (Elucidating Diffusion Models) loss function for diffusion models.
This class implements the improved loss function from the EDM paper, which uses a log-normal distribution for noise level sampling and an optimized weighting scheme for better training stability and sample quality.
- Parameters:
P_mean (float, optional) – Mean parameter for the log-normal distribution of sigma. Controls the center of the noise level distribution. Default is -1.2.
P_std (float, optional) – Standard deviation parameter for the log-normal distribution of sigma. Controls the spread of the noise level distribution. Default is 1.2.
sigma_data (float, optional) – Standard deviation of the training data. Used in the weighting function to balance the loss across noise levels. Default is 1.0.
- __call__(net, images, conditional_img=None, labels=None, augment_pipe=None)[source]
Compute the EDM loss for a batch of images.
Notes
The EDM loss uses a log-normal distribution for sigma: σ ~ logNormal(P_mean, P_std)
The weighting function: λ(σ) = (σ² + σ_data²) / (σ·σ_data)²
This weighting minimizes the variance of the loss gradient, leading to more stable training and faster convergence.
The loss corresponds to training the model to predict the clean image x_0 from noisy input x_t = x_0 + σ·ε.
The log-normal distribution provides a better prior for noise levels compared to uniform sampling.
References
Karras et al., “Elucidating the Design Space of Diffusion-Based Generative Models”, 2022.
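The two ingredients listed in the notes, the log-normal sigma prior and the weighting λ(σ), can be sketched in plain Python (illustrative helpers, not the class implementation):

```python
import math
import random

def edm_sample_sigma(rng, P_mean=-1.2, P_std=1.2):
    # sigma ~ logNormal(P_mean, P_std): exponentiate a Gaussian draw.
    return math.exp(rng.gauss(P_mean, P_std))

def edm_weight(sigma, sigma_data=1.0):
    # lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2
    return (sigma * sigma + sigma_data * sigma_data) / (sigma * sigma_data) ** 2
```

At σ = σ_data = 1 the weight equals 2, and it grows toward both very low and very high noise levels, balancing their gradient contributions.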
- class IPSL_AID.loss.UnetLoss(loss_type='mse', reduction='mean')[source]
Bases: object
Simple UNet loss function for direct image-to-image prediction.
This loss function works with UNet models that predict images directly, without any diffusion noise process. It’s a standard supervised loss for image generation/transformation tasks such as segmentation, denoising, super-resolution, or autoencoding.
- Parameters:
loss_type (str, optional) – Type of loss function to use:
- mse: Mean Squared Error (L2 loss)
- l1: Mean Absolute Error (L1 loss)
- smooth_l1: Smooth L1 loss (Huber loss)
Default is mse.
reduction (str, optional) – Reduction method for the loss:
- mean: Average the loss over all elements
- sum: Sum the loss over all elements
- none: Return loss per element
Default is mean.
- loss_fn
PyTorch loss function instance.
- Type:
torch.nn.Module
- Raises:
ValueError – If an unknown loss_type is provided.
Notes
The loss computes the discrepancy between the model’s output and the input image.
This is suitable for autoencoder-style tasks where the model learns to reconstruct the input.
For conditional generation, labels can be provided to the model.
Data augmentation can be applied via augment_pipe.
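A minimal sketch of the loss_type/reduction dispatch described above, operating on flat Python lists rather than tensors (illustrative only, not the package implementation, which wraps the corresponding torch.nn losses):

```python
def unet_loss(pred, target, loss_type="mse", reduction="mean"):
    """Elementwise loss with mse / l1 / smooth_l1 variants."""
    def smooth_l1(d, beta=1.0):
        # Quadratic near zero, linear beyond beta (Huber loss).
        d = abs(d)
        return 0.5 * d * d / beta if d < beta else d - 0.5 * beta

    per_elem = {
        "mse": lambda d: d * d,
        "l1": lambda d: abs(d),
        "smooth_l1": smooth_l1,
    }
    if loss_type not in per_elem:
        raise ValueError(f"unknown loss_type: {loss_type}")
    vals = [per_elem[loss_type](p - t) for p, t in zip(pred, target)]
    if reduction == "mean":
        return sum(vals) / len(vals)
    if reduction == "sum":
        return sum(vals)
    return vals  # reduction == "none"
```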
- class IPSL_AID.loss.TestLosses(methodName='runTest', logger=None)[source]
Bases: TestCase
Unit tests for diffusion models and loss functions.
IPSL_AID.main module
- IPSL_AID.main.parse_args()[source]
Parse command line arguments for diffusion model training and inference.
This function defines and parses all command line arguments required for configuring and running diffusion model training, resumption, or inference experiments. It provides comprehensive options for data loading, model architecture, training hyperparameters, and output management.
- Returns:
Parsed command line arguments as a namespace object with attributes corresponding to each argument.
- Return type:
argparse.Namespace
Notes
Arguments are organized into logical groups: execution mode, data configuration, training configuration, model architecture, and output.
Boolean arguments use string conversion with lambda functions for flexibility (accepts “true”/”false”, “True”/”False”, etc.).
Default values are provided for most parameters to allow minimal configuration for basic usage.
Some arguments have constraints or choices to ensure valid configurations.
- IPSL_AID.main.make_divisible_hw(h, w, n)[source]
Adjust height and width to be divisible by 2**n by decrementing.
This function ensures that both the height (h) and width (w) are divisible by 2 raised to the power n, which is often required for neural network architectures that use pooling or strided convolutions multiple times.
- Parameters:
h (int) – Input height.
w (int) – Input width.
n (int) – Exponent; the adjusted dimensions are divisible by 2**n.
- Returns:
h_new (int) – Adjusted height that is divisible by 2**n.
w_new (int) – Adjusted width that is divisible by 2**n.
Notes
The function decrements h and w until they become divisible by 2**n.
This is a common requirement for U-Net and other encoder-decoder architectures that use multiple downsampling and upsampling operations.
The adjustment is conservative (decrementing) to avoid adding padding, which might be important for maintaining exact spatial relationships.
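The decrementing behavior described above can be sketched as follows (an illustrative re-implementation of the documented behavior, not the package source):

```python
def make_divisible_hw(h, w, n):
    # Decrement h and w until each is divisible by 2**n.
    # Conservative: shrinks rather than pads, preserving spatial alignment.
    d = 2 ** n
    while h % d:
        h -= 1
    while w % d:
        w -= 1
    return h, w
```

For example, a 143x96 grid with n=4 (divisibility by 16) becomes 128x96.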
- IPSL_AID.main.setup_directories_and_logging(args)[source]
Set up directory structure and logging infrastructure for experiments.
This function creates a standardized directory hierarchy for organizing experiment outputs (logs, results, model checkpoints, etc.) and initializes a logging system with both console and file output.
- Parameters:
args (argparse.Namespace or EasyDict) –
Configuration object containing the following attributes:
- main_folder (str) – Main experiment folder name.
- sub_folder (str) – Sub-folder name for the current run.
- prefix (str) – Prefix for log files and outputs.
- datadir (str) – Base data directory path.
- constant_varnames_file (str) – Filename for constant variables data.
- Returns:
paths (EasyDict) – Dictionary containing paths to created directories:
- logs (str) – Path to log files directory.
- results (str) – Path to results output directory.
- runs (str) – Path to experiment run tracking directory.
- checkpoints (str) – Path to model checkpoint directory.
- stats (str) – Path to statistics and metrics directory.
- datadir (str) – Original data directory path.
- constants (str) – Full path to constant variables file.
logger (Logger) – Configured logger instance with console and file output.
Notes
Directory structure:
logs/main_folder/sub_folder/
results/main_folder/sub_folder/
runs/main_folder/sub_folder/
checkpoints/main_folder/sub_folder/
stats/main_folder/sub_folder/
Log files are named with timestamp: {prefix}_log.txt
The logger outputs to both console and file by default.
All directories are created if they don’t exist (via FileUtils.makedir).
- IPSL_AID.main.log_configuration(args, paths, logger)[source]
Log all configuration parameters to the provided logger.
This function comprehensively logs all experiment configuration parameters including execution mode, data settings, training hyperparameters, model architecture, and directory structure. It provides a clear overview of the experiment setup for reproducibility and debugging.
- Parameters:
args (argparse.Namespace or EasyDict) – Configuration object containing all experiment parameters.
paths (EasyDict) – Dictionary containing paths to various experiment directories.
logger (Logger) – Logger instance for outputting configuration information.
Notes
The function organizes parameters into logical sections for readability.
Includes both user-specified parameters and derived directory paths.
Provides warnings for important configuration choices (e.g., disabled checkpoint saving).
The output is formatted with clear section headers and indentation.
- IPSL_AID.main.setup_data_paths(args, paths, logger)[source]
Set up data file paths, load datasets, and compute normalization statistics.
This function handles the data loading pipeline for training and validation datasets. It manages per-variable data paths, concatenates multi-year data for each variable, computes normalization statistics, and sets up variable mappings and normalization types.
- Parameters:
args (argparse.Namespace or EasyDict) – Configuration object containing runtime options such as training years, execution mode, variable names, and normalization specifications.
paths (EasyDict) –
Dictionary containing directory paths.
Expected keys: datadir, stats
logger (logging.Logger) – Logger instance for output messages.
- Returns:
norm_mapping (dict) – Mapping from variable name to normalization statistics.
steps (EasyDict) – Grid dimension information (time, latitude, longitude).
normalization_type (EasyDict) – Mapping from variable name to normalization method.
index_mapping (dict) – Mapping from variable name to array index.
train_ds (xarray.Dataset or None) – Training dataset, or None if run_type is inference.
valid_ds (xarray.Dataset) – Validation dataset.
Notes
Per-variable data directories may be provided using VAR=path syntax.
Training data is only loaded when run_type is not inference.
Normalization statistics are computed on the validation dataset.
Variables from different files and years are merged into a single dataset.
- IPSL_AID.main.setup_training_environment(args, logger)[source]
Set up the training environment including device selection, random seeds, and data type configuration.
- Parameters:
args (argparse.Namespace or EasyDict) – Configuration object containing precision settings.
logger (logging.Logger) – Logger instance for output messages.
- Returns:
device (torch.device) – Selected computing device.
torch_dtype (torch.dtype) – PyTorch data type.
np_dtype (numpy.dtype) – NumPy data type.
use_fp16 (bool) – Whether half precision is enabled.
Notes
Sets global random seeds for reproducibility.
Automatically selects CUDA if available.
Enables PyTorch anomaly detection for debugging.
- IPSL_AID.main.create_data_loaders(args, paths, norm_mapping, steps, normalization_type, index_mapping, torch_dtype, np_dtype, logger, mode='train', run_type='train', train_loaded_dfs=None, valid_loaded_dfs=None)[source]
Create data loaders for training, validation, or inference.
- Parameters:
args (argparse.Namespace or EasyDict) – Runtime configuration options.
paths (EasyDict) – Directory paths including constants files.
norm_mapping (dict) – Normalization statistics per variable.
steps (EasyDict) – Grid dimension information.
normalization_type (EasyDict) – Normalization method per variable.
index_mapping (dict) – Variable-to-index mapping.
torch_dtype (torch.dtype) – PyTorch tensor dtype.
np_dtype (numpy.dtype) – NumPy array dtype.
logger (logging.Logger) – Logger instance.
mode (str, optional) – Either train or validation.
run_type (str, optional) – Execution mode (train, resume_train, or inference).
train_loaded_dfs (dict, optional) – Pre-loaded training datasets.
valid_loaded_dfs (dict, optional) – Pre-loaded validation datasets.
- Returns:
data_loader (torch.utils.data.DataLoader) – Configured data loader.
img_res (tuple of int) – Spatial resolution used by the model.
dataset (DataPreprocessor) – Underlying dataset object.
- Raises:
ValueError – If mode is invalid.
AssertionError – If required datasets are missing.
Notes
Spatial dimensions are adjusted to be divisible by powers of two.
Validation falls back to training data if validation data is unavailable.
Data is assumed to be pre-loaded into memory.
- IPSL_AID.main.setup_model(args, img_res, use_fp16, device, logger)[source]
Set up the diffusion model and its loss function.
- Parameters:
args (argparse.Namespace or EasyDict) – Model configuration options.
img_res (tuple of int) – Spatial resolution of the model input.
use_fp16 (bool) – Whether FP16 precision is enabled.
device (torch.device) – Target device.
logger (logging.Logger) – Logger instance.
- Returns:
model (torch.nn.Module) – Initialized model.
loss_fn (callable) – Loss function.
- Raises:
ValueError – If an unsupported time normalization is specified.
Notes
Label dimensionality depends on the selected time normalization.
Model creation is delegated to load_model_and_loss.
- IPSL_AID.main.resolve_region_center(args)[source]
Resolve the regional inference center coordinates.
This function enforces the logic for regional inference: the user may provide either region or region_center, but not both.
- Parameters:
args (argparse.Namespace) – Parsed command line arguments.
- Returns:
(lat, lon) if inference_regional, None otherwise.
- Return type:
tuple or None
- Raises:
ValueError – If both region and region_center are provided, if neither is provided in inference_regional mode, or if an unknown region name is specified.
Notes
Predefined regions map to fixed center coordinates.
Longitude follows the convention [0, 360].
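For the [0, 360] longitude convention noted above, a conversion from the common [-180, 180] convention can look like this (to_lon_0_360 is a hypothetical helper for illustration, not part of the package):

```python
def to_lon_0_360(lon):
    # Map a longitude from [-180, 180] to the [0, 360) convention.
    # Python's % always returns a non-negative result for a positive modulus.
    return lon % 360.0
```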
- IPSL_AID.main.main()[source]
Main training and inference pipeline for IPSL-AID diffusion models.
This function orchestrates the entire training and inference process for diffusion-based generative models on weather and climate data. It handles:
- Argument parsing and configuration
- Directory setup and logging
- Data loading and preprocessing
- Model initialization and checkpoint management
- Training loop with validation
- Inference execution
- Visualization and result saving
The pipeline supports multiple modes of operation:
- Training from scratch (run_type=’train’)
- Resuming training from a checkpoint (run_type=’resume_train’)
- Running inference with a trained model (run_type=’inference’)
The function follows a structured workflow:
1. Parse command line arguments
2. Setup directories and logging
3. Load and preprocess data
4. Initialize model, optimizer, and loss function
5. Handle checkpoint loading if required
6. Execute training loop with validation or run inference
7. Generate plots and save results
- Parameters:
None – All configuration is provided via command line arguments.
- Return type:
None
Notes
The function uses argparse for command line argument parsing.
All output (logs, checkpoints, results) is saved to organized directories.
Training includes validation at each epoch with metrics tracking.
Inference mode runs validation metrics without training.
Mixed precision training (FP16) is supported when available.
Model checkpoints include full training state for resumption.
TensorBoard integration is provided for training visualization.
- Raises:
FileNotFoundError – If required checkpoints are not found for resumption or inference.
RuntimeError – If inference mode is requested without validation data.
ValueError – If invalid configurations are provided.
IPSL_AID.model module
- IPSL_AID.model.load_model_and_loss(opts, logger=None, device='cpu')[source]
Load a diffusion model or U-Net with corresponding loss function.
This function initializes and configures a generative model (diffusion or direct U-Net) along with its corresponding loss function based on the provided options. It supports multiple architectures and preconditioning schemes.
- Parameters:
opts (dict or EasyDict) – Configuration dictionary containing model parameters. Must include:
- arch (str) – Architecture type: ‘ddpmpp’, ‘ncsnpp’, or ‘adm’.
- precond (str) – Preconditioning type: ‘vp’, ‘ve’, ‘edm’, or ‘unet’.
- img_resolution (int or tuple) – Image resolution (height, width).
- in_channels (int) – Number of input channels.
- out_channels (int) – Number of output channels.
- label_dim (int) – Dimension of label conditioning (0 for unconditional).
- use_fp16 (bool) – Whether to use mixed precision (FP16).
- model_kwargs (dict, optional) – Additional model-specific parameters to override defaults.
logger (logging.Logger, optional) – Logger instance for output messages. If None, uses print(). Default is None.
device (str or torch.device, optional) – Device to load the model onto (‘cpu’, ‘cuda’, etc.). Default is ‘cpu’.
- Returns:
model (torch.nn.Module) – Initialized model instance (preconditioner or U-Net).
loss_fn (torch.nn.Module or callable) – Corresponding loss function for the model.
- Raises:
ValueError – If an invalid architecture or preconditioner type is specified.
Notes
The function supports three main architectures:
- DDPM++ (Song et al., 2020) with VP preconditioning
- NCSN++ (Song et al., 2020) with VE preconditioning
- ADM (Dhariwal & Nichol, 2021) with EDM preconditioning
When precond=’unet’, uses a direct U-Net without diffusion preconditioning.
Model parameters are counted and logged for transparency.
Default hyperparameters are provided for each architecture but can be overridden via opts.model_kwargs.
- class IPSL_AID.model.TestModelLoader(methodName='runTest', logger=None)[source]
Bases: TestCase
Unit tests for model and loss loader.
- __init__(methodName='runTest', logger=None)[source]
Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.
IPSL_AID.model_utils module
- class IPSL_AID.model_utils.ModelUtils[source]
Bases: object
Utility class for model inspection, checkpointing, and memory profiling.
This class provides static methods for common model operations including parameter counting, memory usage analysis, checkpoint management, and model inspection.
Examples
>>> utils = ModelUtils()
>>> param_counts = ModelUtils.get_parameter_number(model)
>>> ModelUtils.save_checkpoint(state, "checkpoint.pth.tar", logger)
- static get_parameter_number(model, logger=None)[source]
Calculate the total and trainable number of parameters in a model.
- Parameters:
model (torch.nn.Module) – PyTorch model to inspect
logger (Logger, optional) – Logger instance for output, by default None
- Returns:
Dictionary containing: - ‘Total’: Total number of parameters - ‘Trainable’: Number of trainable parameters
- Return type:
dict
Examples
>>> model = torch.nn.Linear(10, 5)
>>> counts = ModelUtils.get_parameter_number(model, logger)
- static print_model_layers(model, logger=None)[source]
Print model parameter names along with their gradient requirements.
- Parameters:
model (torch.nn.Module) – PyTorch model to inspect
logger (Logger, optional) – Logger instance for output, by default None
Examples
>>> model = torch.nn.Sequential(
...     torch.nn.Linear(10, 5),
...     torch.nn.ReLU(),
...     torch.nn.Linear(5, 1)
... )
>>> ModelUtils.print_model_layers(model, logger)
- static save_checkpoint(state, filename='checkpoint.pth.tar', logger=None)[source]
Save model and optimizer state to a file.
- Parameters:
state (dict) – Dictionary containing model state_dict and other training information. Typically includes: - ‘state_dict’: Model parameters - ‘optimizer’: Optimizer state - ‘epoch’: Current epoch - ‘loss’: Current loss value
filename (str, optional) – File path to save the checkpoint, by default “checkpoint.pth.tar”
logger (Logger, optional) – Logger instance for output, by default None
Examples
>>> state = {
...     'state_dict': model.state_dict(),
...     'optimizer': optimizer.state_dict(),
...     'epoch': epoch,
...     'loss': loss
... }
>>> ModelUtils.save_checkpoint(state, 'model_checkpoint.pth.tar', logger)
- static load_checkpoint(checkpoint, model, optimizer=None, logger=None)[source]
Load model and optimizer state from a checkpoint file.
- Parameters:
checkpoint (dict) – Loaded checkpoint dictionary containing the model state_dict and, optionally, optimizer state.
model (torch.nn.Module) – Model to load parameters into.
optimizer (torch.optim.Optimizer, optional) – Optimizer to restore state into, by default None.
logger (Logger, optional) – Logger instance for output, by default None.
Examples
>>> checkpoint = torch.load('model_checkpoint.pth.tar')
>>> ModelUtils.load_checkpoint(checkpoint, model, optimizer, logger)
- static load_training_checkpoint(checkpoint_path, model, optimizer, device, logger=None)[source]
Load comprehensive training checkpoint.
- Parameters:
checkpoint_path (str) – Path to the checkpoint file.
model (torch.nn.Module) – Model to load parameters into.
optimizer (torch.optim.Optimizer) – Optimizer to restore state into.
device (torch.device) – Device to map the checkpoint onto.
logger (Logger, optional) – Logger instance for output, by default None.
- Returns:
(epoch, samples_processed, batches_processed, best_val_loss, best_epoch, checkpoint)
- Return type:
tuple
- static count_parameters_by_layer(model, logger=None)[source]
Count parameters for each layer in the model.
- Parameters:
model (torch.nn.Module) – PyTorch model to analyze
logger (Logger, optional) – Logger instance for output, by default None
- Returns:
Dictionary with layer names as keys and parameter counts as values
- Return type:
dict
Examples
>>> layer_params = ModelUtils.count_parameters_by_layer(model, logger)
- static log_model_summary(model, input_shape=None, logger=None)[source]
Log comprehensive model summary including parameters and architecture.
- static save_training_checkpoint(model, optimizer, epoch, samples_processed, batches_processed, train_loss_history, valid_loss_history, valid_metrics_history, best_val_loss, best_epoch, avg_val_loss, avg_epoch_loss, args, paths, logger, checkpoint_type='epoch', save_full_model=True)[source]
Save comprehensive training checkpoint with consistent formatting.
- Parameters:
model (torch.nn.Module) – Model to save
optimizer (torch.optim.Optimizer) – Optimizer to save
epoch (int) – Current epoch
samples_processed (int) – Number of samples processed so far
batches_processed (int) – Number of batches processed so far
train_loss_history (list) – History of training losses
valid_loss_history (list) – History of validation losses
valid_metrics_history (dict) – History of validation metrics
best_val_loss (float) – Best validation loss so far
best_epoch (int) – Epoch with best validation loss
avg_val_loss (float) – Current epoch validation loss
avg_epoch_loss (float) – Current epoch training loss
args (argparse.Namespace) – Command line arguments
paths (EasyDict) – Directory paths
logger (Logger) – Logger instance
checkpoint_type (str) – Type of checkpoint: “samples”, “epoch”, “best”, “final”
save_full_model (bool) – Whether to also save the full model separately
- Returns:
(checkpoint_filename, full_model_filename)
- Return type:
tuple
Examples
>>> checkpoint_file, full_model_file = ModelUtils.save_training_checkpoint(
...     model, optimizer, epoch, samples_processed, batches_processed,
...     train_loss_history, valid_loss_history, valid_metrics_history,
...     best_val_loss, best_epoch, avg_val_loss, avg_epoch_loss,
...     args, paths, logger, checkpoint_type="best"
... )
- class IPSL_AID.model_utils.TestModel(*args: Any, **kwargs: Any)[source]
Bases: Module
A model for testing purposes that includes a mix of convolutional and linear layers, as well as some non-trainable parameters (buffers). This model is designed to have a reasonable number of parameters for testing the ModelUtils methods without being too large. It includes batch normalization layers to add complexity and a dropout layer to demonstrate non-trainable parameters.
- class IPSL_AID.model_utils.TestModelUtils(methodName='runTest', logger=None)[source]
Bases: TestCase
Unit tests for ModelUtils class.
- __init__(methodName='runTest', logger=None)[source]
Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.
- test_log_model_summary_without_input_shape()[source]
Test model summary logging without input shape.
- test_load_training_checkpoint_nonexistent()[source]
Test loading a nonexistent training checkpoint.
- test_save_training_checkpoint_epoch_type()[source]
Test saving training checkpoint with epoch type.
- test_save_training_checkpoint_final_type()[source]
Test saving training checkpoint with final type.
IPSL_AID.networks module
- IPSL_AID.networks.weight_init(shape, mode, fan_in, fan_out)[source]
Initialize weights using various initialization methods.
- Parameters:
shape (tuple of ints) – The shape of the weight tensor to initialize.
mode (str) – The initialization method to use. Options are: - ‘xavier_uniform’: Xavier uniform initialization - ‘xavier_normal’: Xavier normal initialization - ‘kaiming_uniform’: Kaiming uniform initialization (also known as He initialization) - ‘kaiming_normal’: Kaiming normal initialization (also known as He initialization)
fan_in (int) – Number of input units in the weight tensor.
fan_out (int) – Number of output units in the weight tensor.
- Returns:
A tensor of the specified shape with values initialized according to the chosen method.
- Return type:
torch.Tensor
- Raises:
ValueError – If an invalid initialization mode is provided.
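The four modes can be sketched with their usual fan-based scale factors (a pure-Python illustration on flat lists; the package may use different gains or return tensors, so treat the exact scale values as assumptions):

```python
import math
import random

def weight_init(shape, mode, fan_in, fan_out, rng=random):
    """Fill a flat list of len(prod(shape)) values per the chosen mode."""
    n = 1
    for s in shape:
        n *= s
    if mode == "xavier_uniform":
        a = math.sqrt(6.0 / (fan_in + fan_out))
        return [rng.uniform(-a, a) for _ in range(n)]
    if mode == "xavier_normal":
        return [rng.gauss(0.0, math.sqrt(2.0 / (fan_in + fan_out))) for _ in range(n)]
    if mode == "kaiming_uniform":
        a = math.sqrt(3.0 / fan_in)
        return [rng.uniform(-a, a) for _ in range(n)]
    if mode == "kaiming_normal":
        return [rng.gauss(0.0, math.sqrt(1.0 / fan_in)) for _ in range(n)]
    raise ValueError(f"invalid init mode: {mode}")
```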
- class IPSL_AID.networks.Linear(*args: Any, **kwargs: Any)[source]
Bases: Module
A linear (fully connected) layer with customizable weight initialization.
This layer applies a linear transformation to the incoming data: y = x W^T + b.
- Parameters:
in_features (int) – Size of each input sample.
out_features (int) – Size of each output sample.
bias (bool, optional) – If set to False, the layer will not learn an additive bias. Default is True.
init_mode (str, optional) – Weight initialization method. Options are:
- xavier_uniform: Xavier uniform initialization
- xavier_normal: Xavier normal initialization
- kaiming_uniform: Kaiming uniform initialization (He initialization)
- kaiming_normal: Kaiming normal initialization (He initialization)
Default is kaiming_normal.
init_weight (float or int, optional) – Scaling factor for the initialized weights. Default is 1.
init_bias (float or int, optional) – Scaling factor for the initialized bias. Default is 0.
- weight
The learnable weights of the layer of shape (out_features, in_features).
- Type:
torch.nn.Parameter
- bias
The learnable bias of the layer of shape (out_features,). If bias=False, this attribute is set to None.
- Type:
torch.nn.Parameter or None
- __init__(in_features, out_features, bias=True, init_mode='kaiming_normal', init_weight=1, init_bias=0)[source]
Initialize the Linear layer.
- Parameters:
in_features (int) – Size of each input sample.
out_features (int) – Size of each output sample.
bias (bool, optional) – If set to False, the layer will not learn an additive bias. Default is True.
init_mode (str, optional) – Weight initialization method. Default is ‘kaiming_normal’.
init_weight (float or int, optional) – Scaling factor for the initialized weights. Default is 1.
init_bias (float or int, optional) – Scaling factor for the initialized bias. Default is 0.
- forward(x)[source]
Forward pass of the linear layer.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_size, in_features) or (batch_size, *, in_features), where * means any number of additional dimensions.
- Returns:
Output tensor of shape (batch_size, out_features) or (batch_size, *, out_features).
- Return type:
torch.Tensor
Notes
The operation performed is:
output = x @ weight^T + bias. The bias is added in-place for efficiency when possible.
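The operation in the note amounts to the following quick self-contained check with plain tensors (illustrative values, not tied to the layer's actual weights):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 8)     # (batch_size, in_features)
W = torch.randn(16, 8)    # (out_features, in_features)
b = torch.randn(16)       # (out_features,)

# output = x @ weight^T + bias, as described in the notes
out = x @ W.t() + b
# Equivalent to PyTorch's built-in linear transformation.
assert torch.allclose(out, F.linear(x, W, b), atol=1e-6)
```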
- class IPSL_AID.networks.Conv2d(*args: Any, **kwargs: Any)[source]
Bases:
Module
2D convolutional layer with optional upsampling, downsampling, and fused resampling.
This layer implements a 2D convolution that can optionally include upsampling or downsampling operations with configurable resampling filters. It supports both separate and fused resampling modes for efficiency.
- Parameters:
in_channels (int) – Number of input channels.
out_channels (int) – Number of output channels.
kernel (int) – Size of the convolutional kernel (square kernel).
bias (bool, optional) – If set to False, the layer will not learn an additive bias. Default is True.
up (bool, optional) – If True, upsample the input by a factor of 2 before convolution. Cannot be True if down is also True. Default is False.
down (bool, optional) – If True, downsample the output by a factor of 2 after convolution. Cannot be True if up is also True. Default is False.
resample_filter (list, optional) – Coefficients of the 1D resampling filter that will be turned into a 2D filter. Default is [1, 1] (bilinear filter).
fused_resample (bool, optional) – If True, fuse the resampling operation with the convolution for efficiency. Default is False.
init_mode (str, optional) – Weight initialization method. Options are: - ‘xavier_uniform’: Xavier uniform initialization - ‘xavier_normal’: Xavier normal initialization - ‘kaiming_uniform’: Kaiming uniform initialization (He initialization) - ‘kaiming_normal’: Kaiming normal initialization (He initialization) Default is ‘kaiming_normal’.
init_weight (float or int, optional) – Scaling factor for the initialized weights. Default is 1.
init_bias (float or int, optional) – Scaling factor for the initialized bias. Default is 0.
- weight
The learnable weights of the convolution of shape (out_channels, in_channels, kernel, kernel). If kernel is 0, this is None.
- Type:
torch.nn.Parameter or None
- bias
The learnable bias of the convolution of shape (out_channels,). If kernel is 0 or bias is False, this is None.
- Type:
torch.nn.Parameter or None
- resample_filter
The 2D resampling filter used for upsampling or downsampling. Registered as a buffer (non-learnable parameter).
- Type:
torch.Tensor or None
- Raises:
AssertionError – If both up and down are set to True.
Notes
When kernel is 0, no convolution is performed, only resampling if enabled.
The resampling filter is created by taking the outer product of the 1D filter with itself to create a separable 2D filter, then normalized.
Fused resampling combines the resampling and convolution operations into single operations for better performance.
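The separable filter construction described in the notes can be reproduced in a few lines. Normalizing to unit sum is an assumption consistent with "then normalized":

```python
import torch

f1d = torch.tensor([1., 1.])                   # default [1, 1] (bilinear)
# Outer product of the 1D filter with itself, normalized so the 2D filter
# sums to 1 (assumed normalization convention).
f2d = torch.outer(f1d, f1d) / f1d.sum() ** 2
# f2d is [[0.25, 0.25], [0.25, 0.25]]
```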
Examples
>>> # Standard convolution
>>> conv = Conv2d(3, 16, kernel=3)
>>> x = torch.randn(4, 3, 32, 32)
>>> out = conv(x)
>>> out.shape
torch.Size([4, 16, 32, 32])
>>> # Convolution with downsampling
>>> conv_down = Conv2d(3, 16, kernel=3, down=True)
>>> out = conv_down(x)
>>> out.shape
torch.Size([4, 16, 16, 16])
>>> # Convolution with upsampling
>>> conv_up = Conv2d(3, 16, kernel=3, up=True)
>>> out = conv_up(x)
>>> out.shape
torch.Size([4, 16, 64, 64])
- __init__(in_channels, out_channels, kernel, bias=True, up=False, down=False, resample_filter=[1, 1], fused_resample=False, init_mode='kaiming_normal', init_weight=1, init_bias=0)[source]
Initialize the Conv2d layer.
- Parameters:
in_channels (int) – Number of input channels.
out_channels (int) – Number of output channels.
kernel (int) – Size of the convolutional kernel.
bias (bool, optional) – Whether to include a bias term. Default is True.
up (bool, optional) – Whether to upsample the input. Default is False.
down (bool, optional) – Whether to downsample the output. Default is False.
resample_filter (list, optional) – Coefficients of the 1D resampling filter. Default is [1, 1].
fused_resample (bool, optional) – Whether to fuse resampling with convolution. Default is False.
init_mode (str, optional) – Weight initialization method. Default is ‘kaiming_normal’.
init_weight (float or int, optional) – Scaling factor for weight initialization. Default is 1.
init_bias (float or int, optional) – Scaling factor for bias initialization. Default is 0.
- forward(x)[source]
Forward pass of the Conv2d layer.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_size, in_channels, height, width).
- Returns:
Output tensor of shape (batch_size, out_channels, out_height, out_width). If up is True, spatial dimensions are doubled. If down is True, spatial dimensions are halved.
- Return type:
torch.Tensor
Notes
The method handles four main cases: 1. Fused upsampling + convolution 2. Fused convolution + downsampling 3. Separate up/down sampling followed by convolution 4. Standard convolution only
- class IPSL_AID.networks.GroupNorm(*args: Any, **kwargs: Any)[source]
Bases:
Module
Group Normalization layer.
This layer implements Group Normalization, which divides the channels into groups and computes the mean and variance within each group for normalization. It is particularly effective for small batch sizes and is often used as an alternative to Batch Normalization.
- Parameters:
num_channels (int) – Number of input channels.
num_groups (int, optional) – Number of groups to divide the channels into. Must be a divisor of the number of channels. The actual number of groups may be reduced to satisfy min_channels_per_group. Default is 32.
min_channels_per_group (int, optional) – Minimum number of channels per group. If the division would result in fewer channels per group, the number of groups is reduced. Default is 4.
eps (float, optional) – A small constant added to the denominator for numerical stability. Default is 1e-5.
- weight
Learnable scale parameter of shape (num_channels,). Initialized to ones.
- Type:
torch.nn.Parameter
- bias
Learnable bias parameter of shape (num_channels,). Initialized to zeros.
- Type:
torch.nn.Parameter
Notes
Group Normalization is independent of batch size, making it suitable for variable batch sizes and small batch training.
The number of groups is automatically adjusted to ensure each group has at least min_channels_per_group channels.
This layer uses PyTorch’s built-in torch.nn.functional.group_norm.
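One plausible reading of the group-count adjustment rule, together with the underlying group_norm call (a sketch, not the package's actual code):

```python
import torch
import torch.nn.functional as F

def effective_groups(num_channels, num_groups=32, min_channels_per_group=4):
    # Reduce the group count so that each group keeps at least
    # min_channels_per_group channels (the adjustment described above).
    return min(num_groups, num_channels // min_channels_per_group)

assert effective_groups(128) == 32   # 128 // 4 = 32: no reduction needed
assert effective_groups(64) == 16    # reduced so each group keeps 4 channels

# Normalize with PyTorch's built-in group_norm, as the notes state.
x = torch.randn(2, 64, 8, 8)
y = F.group_norm(x, num_groups=effective_groups(64),
                 weight=torch.ones(64), bias=torch.zeros(64), eps=1e-5)
assert y.shape == x.shape
```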
- __init__(num_channels, num_groups=32, min_channels_per_group=4, eps=1e-05)[source]
Initialize the GroupNorm layer.
- forward(x)[source]
Forward pass of the GroupNorm layer.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_size, num_channels, height, width).
- Returns:
Normalized tensor of same shape as input.
- Return type:
torch.Tensor
Notes
The normalization is performed across spatial dimensions and within each group of channels, maintaining the original mean and variance statistics per group.
- class IPSL_AID.networks.AttentionOp(*args: Any, **kwargs: Any)[source]
Bases:
Function
Custom autograd function for scaled dot-product attention weight computation.
This function computes attention weights using scaled dot-product attention: w = softmax(Q·K^T / √d_k), where d_k is the dimension of the key vectors. It implements both forward and backward passes for gradient computation.
Notes
This is a stateless operation that uses torch.autograd.Function for custom backward.
The forward pass computes attention weights in float32 for numerical stability.
The backward pass computes gradients using the chain rule for softmax and matrix multiplication.
This implementation is optimized for memory efficiency during backward pass.
- static forward(ctx, q, k)[source]
Forward pass for attention weight computation.
- Parameters:
ctx (torch.autograd.function.BackwardCFunction) – Context object to save tensors for backward pass.
q (torch.Tensor) – Query tensor of shape (batch_size, channels, query_length).
k (torch.Tensor) – Key tensor of shape (batch_size, channels, key_length).
- Returns:
Attention weights of shape (batch_size, query_length, key_length). Each row represents attention distribution for a query position.
- Return type:
torch.Tensor
Notes
Computes w = softmax(Q·K^T / √d_k) where d_k = k.shape[1] (channel dimension).
Uses float32 for computation to maintain numerical stability.
Saves q, k, and w in context for backward pass.
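The forward computation can be sketched with a plain einsum, matching the stated shapes and the float32 upcast. This sketch omits the custom backward (autograd would handle gradients here), so it is illustrative rather than the memory-optimized implementation:

```python
import torch

def attention_weights(q, k):
    # w = softmax(Q·K^T / sqrt(d_k)), with d_k = k.shape[1] (channel dim),
    # computed in float32 for numerical stability.
    d_k = k.shape[1]
    scores = torch.einsum('ncq,nck->nqk',
                          q.to(torch.float32),
                          k.to(torch.float32) / d_k ** 0.5)
    return scores.softmax(dim=2)

q = torch.randn(2, 64, 10)   # (batch, channels, query_length)
k = torch.randn(2, 64, 12)   # (batch, channels, key_length)
w = attention_weights(q, k)
assert w.shape == (2, 10, 12)
# Each row is a probability distribution over key positions.
assert torch.allclose(w.sum(dim=2), torch.ones(2, 10), atol=1e-5)
```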
- static backward(ctx, dw)[source]
Backward pass for attention weight computation.
- Parameters:
ctx (torch.autograd.function.BackwardCFunction) – Context object containing saved tensors from forward pass.
dw (torch.Tensor) – Gradient of loss with respect to attention weights. Shape: (batch_size, query_length, key_length).
- Returns:
dq (torch.Tensor) – Gradient with respect to query tensor. Shape: (batch_size, channels, query_length).
dk (torch.Tensor) – Gradient with respect to key tensor. Shape: (batch_size, channels, key_length).
Notes
Uses the saved tensors q, k, w from forward pass.
Computes gradient of softmax using PyTorch’s internal softmax_backward.
Applies chain rule for the scaled dot-product operation.
Maintains original dtypes of input tensors.
- class IPSL_AID.networks.UNetBlock(*args: Any, **kwargs: Any)[source]
Bases:
Module
U-Net block with optional attention, up/down sampling, and adaptive scaling.
This block implements a residual block commonly used in U-Net architectures for diffusion models and image-to-image translation. It supports:
- Residual connections with optional skip scaling
- Adaptive scaling/shifting via conditioning embeddings
- Multi-head self-attention mechanisms
- Upsampling or downsampling operations
- Dropout for regularization
- Parameters:
in_channels (int) – Number of input channels.
out_channels (int) – Number of output channels.
emb_channels (int) – Number of embedding (conditioning) channels.
up (bool, optional) – If True, upsample the input by a factor of 2. Default is False.
down (bool, optional) – If True, downsample the output by a factor of 2. Default is False.
attention (bool, optional) – If True, include multi-head self-attention in the block. Default is False.
num_heads (int, optional) – Number of attention heads. If None, computed as out_channels // channels_per_head. Default is None.
channels_per_head (int, optional) – Number of channels per attention head when num_heads is None. Default is 64.
dropout (float, optional) – Dropout probability applied after the first activation. Default is 0.
skip_scale (float, optional) – Scaling factor applied to the residual connection. Default is 1.
eps (float, optional) – Epsilon value for GroupNorm layers for numerical stability. Default is 1e-5.
resample_filter (list, optional) – Coefficients for the resampling filter used in up/down sampling. Default is [1, 1].
resample_proj (bool, optional) – If True, use a 1x1 convolution in the skip connection when resampling. Default is False.
adaptive_scale (bool, optional) – If True, use both scale and shift parameters from the embedding. If False, use only shift parameters. Default is True.
init (dict, optional) – Initialization parameters for most convolutional layers. Default is empty dict.
init_zero (dict, optional) – Initialization parameters for final convolutional layers (zero initialization). Default is {‘init_weight’: 0}.
init_attn (dict, optional) – Initialization parameters for attention layers. If None, uses the same as init. Default is None.
- norm0, norm1, norm2
Group normalization layers.
- Type:
GroupNorm
- conv0, conv1
Convolutional layers.
- Type:
Conv2d
- skip
Skip connection projection (1x1 conv), used when input and output channels differ or when resampling.
- Type:
Conv2d or None
- qkv, proj
Attention query-key-value and projection layers (if attention is enabled).
- Type:
Conv2d
Notes
The block follows a pre-activation residual structure: norm -> activation -> conv.
When adaptive_scale=True, the conditioning embedding provides both scale and shift parameters.
The attention mechanism uses multi-head self-attention within the spatial dimensions.
The skip connection is automatically added when input/output channels differ or when resampling.
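The residual path described in the notes can be condensed into a simplified, runnable sketch (no attention or resampling; layer names and the affine conditioning layer are illustrative, not the package's actual attributes):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNetBlockSketch(nn.Module):
    """Simplified sketch of the pre-activation residual structure above."""
    def __init__(self, in_ch, out_ch, emb_ch, skip_scale=1.0, dropout=0.0):
        super().__init__()
        self.norm0 = nn.GroupNorm(8, in_ch, eps=1e-5)
        self.conv0 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.affine = nn.Linear(emb_ch, out_ch * 2)   # scale and shift
        self.norm1 = nn.GroupNorm(8, out_ch, eps=1e-5)
        self.conv1 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else None
        self.skip_scale, self.dropout = skip_scale, dropout

    def forward(self, x, emb):
        h = self.conv0(F.silu(self.norm0(x)))              # norm -> act -> conv
        scale, shift = self.affine(emb)[:, :, None, None].chunk(2, dim=1)
        h = F.silu(self.norm1(h) * (scale + 1) + shift)    # adaptive scale/shift
        h = self.conv1(F.dropout(h, self.dropout, self.training))
        res = x if self.skip is None else self.skip(x)     # channel projection
        return (res + h) * self.skip_scale                 # scaled residual

block = UNetBlockSketch(32, 64, 128)
y = block(torch.randn(2, 32, 16, 16), torch.randn(2, 128))
assert y.shape == (2, 64, 16, 16)
```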
- __init__(in_channels, out_channels, emb_channels, up=False, down=False, attention=False, num_heads=None, channels_per_head=64, dropout=0, skip_scale=1, eps=1e-05, resample_filter=[1, 1], resample_proj=False, adaptive_scale=True, init={}, init_zero={'init_weight': 0}, init_attn=None)[source]
Initialize the UNetBlock.
- Parameters:
in_channels (int) – Number of input channels.
out_channels (int) – Number of output channels.
emb_channels (int) – Number of embedding channels.
up (bool, optional) – Enable upsampling.
down (bool, optional) – Enable downsampling.
attention (bool, optional) – Enable attention mechanism.
num_heads (int, optional) – Number of attention heads.
channels_per_head (int, optional) – Channels per attention head.
dropout (float, optional) – Dropout probability.
skip_scale (float, optional) – Scaling factor for skip connection.
eps (float, optional) – Epsilon for GroupNorm.
resample_filter (list, optional) – Filter for resampling.
resample_proj (bool, optional) – Use projection in skip connection when resampling.
adaptive_scale (bool, optional) – Use adaptive scaling from embedding.
init (dict, optional) – Initialization parameters.
init_zero (dict, optional) – Zero initialization parameters.
init_attn (dict, optional) – Attention initialization parameters.
- forward(x, emb)[source]
Forward pass of the UNetBlock.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_size, in_channels, height, width).
emb (torch.Tensor) – Conditioning embedding of shape (batch_size, emb_channels).
- Returns:
Output tensor of shape (batch_size, out_channels, out_height, out_width).
- Return type:
torch.Tensor
Notes
The forward pass consists of: 1. Initial normalization and convolution (with optional up/down sampling) 2. Adaptive scaling/shifting from conditioning embedding 3. Second normalization, dropout, and convolution 4. Skip connection with scaling 5. Optional multi-head self-attention
- class IPSL_AID.networks.PositionalEmbedding(*args: Any, **kwargs: Any)[source]
Bases:
Module
Sinusoidal positional embedding for sequences or timesteps.
This module generates sinusoidal embeddings for input positions, commonly used in transformer architectures and diffusion models to provide temporal or positional information to the model.
- Parameters:
num_channels (int) – Dimensionality of the embedding vectors. Must be even.
max_positions (int, optional) – Maximum number of positions (or timesteps) for which embeddings are generated. Determines the frequency scaling. Default is 10000.
endpoint (bool, optional) – If True, scales frequencies such that the last frequency is 1/2 of the first. If False, uses the standard scaling. Default is False.
Notes
The embedding uses sine and cosine functions of different frequencies to create a unique encoding for each position.
The frequencies are computed as: freqs = (1 / max_positions) ** (2i / num_channels) for i in range(num_channels//2) or with endpoint adjustment.
The output embedding is the concatenation of [cos(x*freqs), sin(x*freqs)].
This implementation is based on the original Transformer positional encoding and the diffusion model timestep embedding.
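The endpoint=False formula from the notes can be written out directly (a functional sketch, not the module's actual code):

```python
import torch

def positional_embedding(x, num_channels, max_positions=10000):
    # freqs_i = (1 / max_positions) ** (2i / num_channels), i = 0 .. num_channels//2 - 1
    i = torch.arange(num_channels // 2, dtype=torch.float32)
    freqs = (1 / max_positions) ** (2 * i / num_channels)
    angles = x.to(torch.float32)[:, None] * freqs[None, :]   # outer product
    # Concatenate [cos(x·freqs), sin(x·freqs)] as stated in the notes.
    return torch.cat([angles.cos(), angles.sin()], dim=1)

t = torch.arange(4)                      # position/timestep indices
emb = positional_embedding(t, num_channels=64)
assert emb.shape == (4, 64)
```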
- __init__(num_channels, max_positions=10000, endpoint=False)[source]
Initialize the PositionalEmbedding module.
- forward(x)[source]
Forward pass to generate positional embeddings.
- Parameters:
x (torch.Tensor) – Input tensor of positions (or timesteps) of shape (batch_size,) or (n,). Values are typically integers in [0, max_positions-1].
- Returns:
Positional embeddings of shape (len(x), num_channels).
- Return type:
torch.Tensor
Notes
The input tensor x is typically a 1D tensor of position indices.
The output is a 2D tensor where each row corresponds to the embedding of the respective position.
The embedding uses the device and dtype of the input tensor x.
- class IPSL_AID.networks.FourierEmbedding(*args: Any, **kwargs: Any)[source]
Bases:
Module
Random Fourier feature embedding for positional encoding.
This module generates random Fourier features (RFF) for input positions or coordinates, mapping low-dimensional inputs to a higher-dimensional space using random frequency sampling. Commonly used in neural fields, kernel methods, and coordinate-based neural networks.
- Parameters:
num_channels (int) – Dimensionality of the embedding vectors. Must be even.
scale (float, optional) – Standard deviation of the sampled random frequencies.
- freqs
Random frequencies sampled from a normal distribution with mean 0 and standard deviation scale. Shape: (num_channels // 2,).
- Type:
torch.Tensor (buffer)
Notes
The frequencies are randomly initialized and fixed (non-learnable).
The embedding uses sine and cosine projections of the input multiplied by 2π times the random frequencies.
This technique approximates shift-invariant kernels via Bochner’s theorem.
Unlike learned embeddings, this provides a fixed, deterministic mapping from input space to embedding space.
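The mapping can be sketched as follows; the frequency standard deviation `scale` and the value 16.0 below are illustrative assumptions:

```python
import math
import torch

def fourier_embedding(x, freqs):
    # x ↦ [cos(2π·freqs·x), sin(2π·freqs·x)] with fixed random frequencies
    angles = 2 * math.pi * x.to(torch.float32)[:, None] * freqs[None, :]
    return torch.cat([angles.cos(), angles.sin()], dim=1)

num_channels, scale = 64, 16.0              # `scale` is an assumed frequency std
freqs = torch.randn(num_channels // 2) * scale   # fixed, non-learnable draws
emb = fourier_embedding(torch.rand(8), freqs)
assert emb.shape == (8, 64)
```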
- forward(x)[source]
Forward pass to generate Fourier feature embeddings.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_size,) or (n,). Typically continuous values representing positions or coordinates.
- Returns:
Fourier feature embeddings of shape (len(x), num_channels).
- Return type:
torch.Tensor
Notes
The transformation is: x ↦ [cos(2π * freqs * x), sin(2π * freqs * x)].
The output dimension is twice the number of frequencies (num_channels).
This embedding is deterministic given the fixed random frequencies.
- class IPSL_AID.networks.SongUNet(*args: Any, **kwargs: Any)[source]
Bases:
Module
U-Net architecture for diffusion models based on Song et al. (2020).
This implementation supports both DDPM++ and NCSN++ architectures with configurable encoder/decoder types, attention mechanisms, and conditioning. It handles both square and rectangular input resolutions.
- Parameters:
img_resolution (int or tuple) – Input image resolution. If int, assumes square images (img_resolution x img_resolution). If tuple, should be (height, width).
in_channels (int) – Number of input color channels.
out_channels (int) – Number of output color channels.
label_dim (int, optional) – Number of class labels. Set to 0 for unconditional generation. Default is 0.
augment_dim (int, optional) – Dimensionality of augmentation labels (e.g., time-dependent augmentation). Set to 0 for no augmentation. Default is 0.
model_channels (int, optional) – Base channel multiplier for the network. Default is 128.
channel_mult (list of int, optional) – Channel multipliers for each resolution level. Default is [1, 2, 2, 2].
channel_mult_emb (int, optional) – Multiplier for embedding dimensionality relative to model_channels. Default is 4.
num_blocks (int, optional) – Number of residual blocks per resolution. Default is 4.
attn_resolutions (list of int, optional) – List of resolutions (minimum dimension) to apply self-attention. Default is [16].
dropout (float, optional) – Dropout probability for intermediate activations. Default is 0.10.
label_dropout (float, optional) – Dropout probability for class labels (classifier-free guidance). Default is 0.
embedding_type (str, optional) – Type of timestep embedding: ‘positional’ for DDPM++, ‘fourier’ for NCSN++. Default is ‘positional’.
channel_mult_noise (int, optional) – Multiplier for noise embedding dimensionality: 1 for DDPM++, 2 for NCSN++. Default is 1.
encoder_type (str, optional) – Encoder architecture: ‘standard’ for DDPM++, ‘skip’ or ‘residual’ for NCSN++. Default is ‘standard’.
decoder_type (str, optional) – Decoder architecture: ‘standard’ for both, ‘skip’ for NCSN++. Default is ‘standard’.
resample_filter (list, optional) – Resampling filter coefficients: [1,1] for DDPM++, [1,3,3,1] for NCSN++. Default is [1,1].
- map_noise
Noise/timestep embedding module.
- Type:
PositionalEmbedding or FourierEmbedding
- map_layer0, map_layer1
Embedding transformation layers.
- Type:
Linear
- enc
Encoder modules organized by resolution.
- Type:
torch.nn.ModuleDict
- dec
Decoder modules organized by resolution.
- Type:
torch.nn.ModuleDict
- Raises:
AssertionError – If embedding_type is not ‘fourier’ or ‘positional’. If encoder_type is not ‘standard’, ‘skip’, or ‘residual’. If decoder_type is not ‘standard’ or ‘skip’. If img_resolution tuple doesn’t have exactly 2 elements.
Notes
The architecture follows a U-Net structure with skip connections.
Supports multiple conditioning types: noise (timestep), class labels, augmentations.
Attention is applied at specified resolutions to capture long-range dependencies.
Different encoder/decoder types and embedding methods allow emulating DDPM++ or NCSN++.
Rectangular resolutions are supported by tracking height and width separately.
References
Ho et al., “Denoising Diffusion Probabilistic Models” (DDPM)
Song et al., “Score-Based Generative Modeling through Stochastic Differential Equations” (NCSN++)
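With the default configuration, the per-level channel widths and attention placement can be tabulated directly. The 64x64 starting resolution and halving-per-level bookkeeping are an illustrative sketch of the scheme described in the notes:

```python
# Defaults from the signature: base width, per-level multipliers, attention levels.
model_channels, channel_mult, attn_resolutions = 128, [1, 2, 2, 2], [16]

levels, res = [], 64  # assume an illustrative 64x64 input
for mult in channel_mult:
    # Channels at this level; attention is applied where the (minimum)
    # spatial dimension matches an entry of attn_resolutions.
    levels.append((res, model_channels * mult, res in attn_resolutions))
    res //= 2  # each encoder level halves the spatial resolution

# levels == [(64, 128, False), (32, 256, False), (16, 256, True), (8, 256, False)]
```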
- __init__(img_resolution, in_channels, out_channels, label_dim=0, augment_dim=0, model_channels=128, channel_mult=[1, 2, 2, 2], channel_mult_emb=4, num_blocks=4, attn_resolutions=[16], dropout=0.1, label_dropout=0, embedding_type='positional', channel_mult_noise=1, encoder_type='standard', decoder_type='standard', resample_filter=[1, 1])[source]
Initialize the SongUNet.
- Parameters:
img_resolution (int or tuple) – Input image resolution.
in_channels (int) – Input channels.
out_channels (int) – Output channels.
label_dim (int, optional) – Class label dimension.
augment_dim (int, optional) – Augmentation label dimension.
model_channels (int, optional) – Base channel multiplier.
channel_mult (list, optional) – Channel multipliers per resolution.
channel_mult_emb (int, optional) – Embedding channel multiplier.
num_blocks (int, optional) – Blocks per resolution.
attn_resolutions (list, optional) – Resolutions for attention.
dropout (float, optional) – Dropout probability.
label_dropout (float, optional) – Label dropout probability.
embedding_type (str, optional) – Embedding type.
channel_mult_noise (int, optional) – Noise embedding multiplier.
encoder_type (str, optional) – Encoder type.
decoder_type (str, optional) – Decoder type.
resample_filter (list, optional) – Resampling filter coefficients.
- forward(x, noise_labels, class_labels, augment_labels=None)[source]
Forward pass through the U-Net.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_size, in_channels, height, width).
noise_labels (torch.Tensor) – Noise/timestep labels of shape (batch_size,).
class_labels (torch.Tensor or None) – Class labels of shape (batch_size,) or (batch_size, label_dim). Can be None if label_dim is 0.
augment_labels (torch.Tensor or None, optional) – Augmentation labels of shape (batch_size, augment_dim). Can be None if augment_dim is 0.
- Returns:
Output tensor of shape (batch_size, out_channels, height, width).
- Return type:
torch.Tensor
Notes
The forward pass consists of three main stages: 1. Embedding mapping: combines noise, class, and augmentation embeddings. 2. Encoder: extracts hierarchical features with optional skip connections. 3. Decoder: reconstructs output with skip connections from encoder.
Classifier-free guidance is supported via label_dropout.
The noise embedding uses sinusoidal (positional) or Fourier features.
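Label dropout for classifier-free guidance is commonly implemented by zeroing whole label vectors at random during training; a minimal sketch of that pattern (an assumption about this implementation):

```python
import torch

def drop_labels(class_labels, label_dropout, training=True):
    # Randomly zero entire label vectors during training so the model also
    # learns the unconditional distribution (classifier-free guidance).
    if training and label_dropout > 0:
        keep = torch.rand(class_labels.shape[0], 1) >= label_dropout
        class_labels = class_labels * keep.to(class_labels.dtype)
    return class_labels

labels = torch.ones(8, 10)
out = drop_labels(labels, label_dropout=0.5)
assert out.shape == (8, 10)
# Each row is either fully kept (all ones) or fully dropped (all zeros).
assert set(out.sum(dim=1).tolist()) <= {0.0, 10.0}
```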
- class IPSL_AID.networks.DhariwalUNet(*args: Any, **kwargs: Any)[source]
Bases:
Module
U-Net architecture based on Dhariwal & Nichol (2021) for diffusion models.
This implementation follows the ADM (Ablated Diffusion Model) architecture with configurable attention mechanisms, conditioning, and rectangular resolution support. It features a U-Net structure with skip connections, group normalization, and optional conditioning on class labels and augmentation.
- Parameters:
img_resolution (int or tuple) – Input image resolution. If int, assumes square images (img_resolution x img_resolution). If tuple, should be (height, width).
in_channels (int) – Number of input color channels.
out_channels (int) – Number of output color channels.
label_dim (int, optional) – Number of class labels. Set to 0 for unconditional generation. Default is 0.
augment_dim (int, optional) – Dimensionality of augmentation labels (e.g., time-dependent augmentation). Set to 0 for no augmentation. Default is 0.
model_channels (int, optional) – Base channel multiplier for the network. Default is 128.
channel_mult (list of int, optional) – Channel multipliers for each resolution level. Default is [1, 2, 3, 4].
channel_mult_emb (int, optional) – Multiplier for embedding dimensionality relative to model_channels. Default is 4.
num_blocks (int, optional) – Number of residual blocks per resolution. Default is 3.
attn_resolutions (list of int, optional) – List of resolutions (minimum dimension) to apply self-attention. Default is [32, 16, 8].
dropout (float, optional) – Dropout probability for intermediate activations. Default is 0.10.
label_dropout (float, optional) – Dropout probability for class labels (classifier-free guidance). Default is 0.
diffusion_model (bool, optional) – Whether to configure the network for diffusion models. If True, includes timestep embedding; if False, only uses label conditioning. Default is True.
- map_noise
Noise/timestep embedding module (if diffusion_model=True).
- Type:
PositionalEmbedding or None
- map_layer0, map_layer1
Embedding transformation layers.
- Type:
Linear
- enc
Encoder modules organized by resolution.
- Type:
torch.nn.ModuleDict
- dec
Decoder modules organized by resolution.
- Type:
torch.nn.ModuleDict
- Raises:
AssertionError – If img_resolution tuple doesn’t have exactly 2 elements.
Notes
The architecture is based on the U-Net from “Diffusion Models Beat GANs on Image Synthesis”.
Features group normalization throughout and attention at multiple resolutions.
Supports classifier-free guidance via label_dropout.
Can be used for both diffusion models and other conditional generation tasks.
Rectangular resolutions are supported by tracking height and width separately.
References
Dhariwal & Nichol, “Diffusion Models Beat GANs on Image Synthesis”, 2021.
- __init__(img_resolution, in_channels, out_channels, label_dim=0, augment_dim=0, model_channels=128, channel_mult=[1, 2, 3, 4], channel_mult_emb=4, num_blocks=3, attn_resolutions=[32, 16, 8], dropout=0.1, label_dropout=0, diffusion_model=True)[source]
Initialize the DhariwalUNet.
- Parameters:
img_resolution (int or tuple) – Input image resolution.
in_channels (int) – Input channels.
out_channels (int) – Output channels.
label_dim (int, optional) – Class label dimension.
augment_dim (int, optional) – Augmentation label dimension.
model_channels (int, optional) – Base channel multiplier.
channel_mult (list, optional) – Channel multipliers per resolution.
channel_mult_emb (int, optional) – Embedding channel multiplier.
num_blocks (int, optional) – Blocks per resolution.
attn_resolutions (list, optional) – Resolutions for attention.
dropout (float, optional) – Dropout probability.
label_dropout (float, optional) – Label dropout probability.
diffusion_model (bool, optional) – Whether to configure for diffusion models.
- forward(x, noise_labels=None, class_labels=None, augment_labels=None)[source]
Forward pass through the Dhariwal U-Net.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_size, in_channels, height, width).
noise_labels (torch.Tensor or None) – Noise/timestep labels of shape (batch_size,). Required if diffusion_model=True, otherwise optional.
class_labels (torch.Tensor or None) – Class labels of shape (batch_size,) or (batch_size, label_dim). Can be None if label_dim is 0.
augment_labels (torch.Tensor or None, optional) – Augmentation labels of shape (batch_size, augment_dim). Can be None if augment_dim is 0.
- Returns:
Output tensor of shape (batch_size, out_channels, height, width).
- Return type:
torch.Tensor
Notes
The forward pass combines conditioning embeddings (noise, class, augmentation) and processes through encoder-decoder with skip connections.
When diffusion_model=False, the noise_labels can be omitted.
Classifier-free guidance is implemented via label_dropout during training.
- class IPSL_AID.networks.VPPrecond(*args: Any, **kwargs: Any)[source]
Bases:
Module
Variance Preserving (VP) preconditioning for diffusion models.
This class implements preconditioning for the Variance Preserving formulation of diffusion models, as described in Song et al. (2020). It wraps a base U-Net model and applies the appropriate scaling and conditioning for VP SDEs.
- Parameters:
img_resolution (int or tuple) – Input image resolution. If int, assumes square images. If tuple, should be (height, width).
in_channels (int) – Number of input color channels.
out_channels (int) – Number of output color channels.
label_dim (int, optional) – Number of class labels. Set to 0 for unconditional generation. Default is 0.
use_fp16 (bool, optional) – Whether to execute the underlying model at FP16 precision for speed. Default is False.
beta_d (float, optional) – Extent of the noise level schedule. Controls the rate of noise increase. Default is 19.9.
beta_min (float, optional) – Initial slope of the noise level schedule. Default is 0.1.
M (int, optional) – Original number of timesteps in the DDPM formulation. Default is 1000.
epsilon_t (float, optional) – Minimum t-value used during training. Prevents numerical issues. Default is 1e-5.
model_type (str, optional) – Class name of the underlying U-Net model (‘SongUNet’ or ‘DhariwalUNet’). Default is ‘SongUNet’.
**model_kwargs (dict) – Additional keyword arguments passed to the underlying model.
- model
The underlying U-Net model.
- Type:
torch.nn.Module
Notes
The VP formulation maintains unit variance throughout the diffusion process.
The noise schedule follows: σ(t) = sqrt(exp(0.5*β_d*t² + β_min*t) - 1)
The preconditioning applies scaling factors: c_skip, c_out, c_in, c_noise
Supports conditional generation via class labels and condition images.
Implements the continuous-time formulation of diffusion models.
References
Song et al., “Score-Based Generative Modeling through Stochastic Differential Equations”, 2020.
- __init__(img_resolution, in_channels, out_channels, label_dim=0, use_fp16=False, beta_d=19.9, beta_min=0.1, M=1000, epsilon_t=1e-05, model_type='SongUNet', **model_kwargs)[source]
Initialize the VPPrecond module.
- Parameters:
img_resolution (int or tuple) – Input image resolution.
in_channels (int) – Input channels.
out_channels (int) – Output channels.
label_dim (int, optional) – Class label dimension.
use_fp16 (bool, optional) – Use FP16 precision.
beta_d (float, optional) – Noise schedule extent.
beta_min (float, optional) – Initial noise schedule slope.
M (int, optional) – Number of timesteps.
epsilon_t (float, optional) – Minimum t-value.
model_type (str, optional) – Underlying model class name.
**model_kwargs (dict) – Additional model arguments.
- forward(x, sigma, condition_img=None, class_labels=None, force_fp32=False, **model_kwargs)[source]
Forward pass with VP preconditioning.
- Parameters:
x (torch.Tensor) – Input noisy tensor of shape (batch_size, in_channels, height, width).
sigma (torch.Tensor) – Noise level(s) of shape (batch_size,) or scalar.
condition_img (torch.Tensor, optional) – Condition image tensor of same spatial dimensions as x. Default is None.
class_labels (torch.Tensor, optional) – Class labels for conditioning of shape (batch_size,) or (batch_size, label_dim). Default is None.
force_fp32 (bool, optional) – Force FP32 precision even if use_fp16 is True. Default is False.
**model_kwargs (dict) – Additional keyword arguments passed to the underlying model.
- Returns:
Denoised output of shape (batch_size, out_channels, height, width).
- Return type:
torch.Tensor
Notes
Applies the preconditioning: D(x) = c_skip * x + c_out * F(c_in * x, c_noise)
Where F is the underlying U-Net model.
c_in, c_out, c_skip, c_noise are computed from sigma according to VP formulation.
Condition images are concatenated along the channel dimension.
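The preconditioning formula in the notes is easy to check in isolation. A minimal sketch with scalar stand-ins for tensors, where F is a placeholder for the U-Net rather than any package API:

```python
def precondition(F, x, c_skip, c_out, c_in, c_noise):
    """Generic preconditioned denoiser: D(x) = c_skip * x + c_out * F(c_in * x, c_noise)."""
    return c_skip * x + c_out * F(c_in * x, c_noise)

# With an identity "network" F(z, c_noise) = z, the denoiser collapses to
# (c_skip + c_out * c_in) * x, which makes the role of each scaling factor visible.
d = precondition(lambda z, c_noise: z, 2.0, 0.5, 0.7, 0.7, 0.0)
```

Each Precond class in this module differs only in how c_skip, c_out, c_in, and c_noise are derived from sigma.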
- sigma(t)[source]
Compute noise level sigma for given time t.
- Parameters:
t (float or torch.Tensor) – Time value(s) in [epsilon_t, 1].
- Returns:
Noise level sigma corresponding to t.
- Return type:
torch.Tensor
Notes
Formula: σ(t) = sqrt(exp(0.5*β_d*t² + β_min*t) - 1)
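For illustration, the schedule can be evaluated with plain math using the defaults from the class signature (a sketch, not the package implementation):

```python
import math

def vp_sigma(t, beta_d=19.9, beta_min=0.1):
    # sigma(t) = sqrt(exp(0.5 * beta_d * t**2 + beta_min * t) - 1)
    return math.sqrt(math.exp(0.5 * beta_d * t ** 2 + beta_min * t) - 1.0)

# Noise vanishes as t -> 0 (hence the epsilon_t lower bound to avoid
# numerically tiny sigmas) and grows rapidly toward t = 1.
low = vp_sigma(1e-5)
high = vp_sigma(1.0)  # on the order of 1e2 with the default betas
```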
- class IPSL_AID.networks.VEPrecond(*args: Any, **kwargs: Any)[source]
Bases:
Module
Variance Exploding (VE) preconditioning for diffusion models.
This class implements preconditioning for the Variance Exploding formulation of diffusion models, as described in Song et al. (2020). It wraps a base U-Net model and applies the appropriate scaling and conditioning for VE SDEs.
- Parameters:
img_resolution (int or tuple) – Input image resolution. If int, assumes square images. If tuple, should be (height, width).
in_channels (int) – Number of input color channels.
out_channels (int) – Number of output color channels.
label_dim (int, optional) – Number of class labels. Set to 0 for unconditional generation. Default is 0.
use_fp16 (bool, optional) – Whether to execute the underlying model at FP16 precision for speed. Default is False.
sigma_min (float, optional) – Minimum supported noise level. Default is 0.02.
sigma_max (float, optional) – Maximum supported noise level. Default is 100.
model_type (str, optional) – Class name of the underlying U-Net model (‘SongUNet’ or ‘DhariwalUNet’). Default is ‘SongUNet’.
**model_kwargs (dict) – Additional keyword arguments passed to the underlying model.
- model
The underlying U-Net model.
- Type:
torch.nn.Module
Notes
The VE formulation uses a simple geometric noise schedule.
The preconditioning applies scaling factors: c_skip, c_out, c_in, c_noise
c_noise = 0.5 * log(sigma) maps noise levels to conditioning inputs.
Supports conditional generation via class labels and condition images.
References
Song et al., “Score-Based Generative Modeling through Stochastic Differential Equations”, 2020.
- __init__(img_resolution, in_channels, out_channels, label_dim=0, use_fp16=False, sigma_min=0.02, sigma_max=100, model_type='SongUNet', **model_kwargs)[source]
Initialize the VEPrecond module.
- Parameters:
img_resolution (int or tuple) – Input image resolution.
in_channels (int) – Input channels.
out_channels (int) – Output channels.
label_dim (int, optional) – Class label dimension.
use_fp16 (bool, optional) – Use FP16 precision.
sigma_min (float, optional) – Minimum noise level.
sigma_max (float, optional) – Maximum noise level.
model_type (str, optional) – Underlying model class name.
**model_kwargs (dict) – Additional model arguments.
- forward(x, sigma, condition_img=None, class_labels=None, force_fp32=False, **model_kwargs)[source]
Forward pass with VE preconditioning.
- Parameters:
x (torch.Tensor) – Input noisy tensor of shape (batch_size, in_channels, height, width).
sigma (torch.Tensor) – Noise level(s) of shape (batch_size,) or scalar.
condition_img (torch.Tensor, optional) – Condition image tensor of same spatial dimensions as x. Default is None.
class_labels (torch.Tensor, optional) – Class labels for conditioning of shape (batch_size,) or (batch_size, label_dim). Default is None.
force_fp32 (bool, optional) – Force FP32 precision even if use_fp16 is True. Default is False.
**model_kwargs (dict) – Additional keyword arguments passed to the underlying model.
- Returns:
Denoised output of shape (batch_size, out_channels, height, width).
- Return type:
torch.Tensor
Notes
Applies the preconditioning: D(x) = c_skip * x + c_out * F(c_in * x, c_noise)
Where F is the underlying U-Net model.
For VE: c_skip = 1, c_out = sigma, c_in = 1, c_noise = 0.5 * log(sigma)
Condition images are concatenated along the channel dimension.
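The VE coefficients listed above are simple enough to state directly; a scalar sketch (illustrative, not the package API):

```python
import math

def ve_coefficients(sigma):
    # For VE: c_skip = 1, c_out = sigma, c_in = 1, c_noise = 0.5 * log(sigma)
    c_skip = 1.0
    c_out = sigma
    c_in = 1.0
    c_noise = 0.5 * math.log(sigma)
    return c_skip, c_out, c_in, c_noise

# Because the model output is scaled by sigma, the network only has to
# predict a roughly unit-magnitude quantity at every noise level.
```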
- class IPSL_AID.networks.iDDPMPrecond(*args: Any, **kwargs: Any)[source]
Bases:
Module
Improved DDPM (iDDPM) preconditioning for diffusion models.
This class implements the improved preconditioning scheme from the iDDPM paper, which refines the noise schedule and preconditioning for better sample quality. It provides a bridge between discrete-time DDPM formulations and continuous-time SDE formulations.
- Parameters:
img_resolution (int or tuple) – Input image resolution. If int, assumes square images. If tuple, should be (height, width).
in_channels (int) – Number of input color channels.
out_channels (int) – Number of output color channels.
label_dim (int, optional) – Number of class labels. Set to 0 for unconditional generation. Default is 0.
use_fp16 (bool, optional) – Whether to execute the underlying model at FP16 precision for speed. Default is False.
C_1 (float, optional) – Timestep adjustment parameter for low noise levels. Default is 0.001.
C_2 (float, optional) – Timestep adjustment parameter for high noise levels. Default is 0.008.
M (int, optional) – Original number of timesteps in the DDPM formulation. Default is 1000.
model_type (str, optional) – Class name of the underlying U-Net model (‘SongUNet’ or ‘DhariwalUNet’). Default is ‘DhariwalUNet’.
**model_kwargs (dict) – Additional keyword arguments passed to the underlying model.
- u
Learned noise schedule of length M+1.
- Type:
torch.Tensor (buffer)
- model
The underlying U-Net model.
- Type:
torch.nn.Module
Notes
The iDDPM formulation improves upon DDPM with a refined noise schedule.
The noise schedule is learned during initialization via backward recursion.
Uses alpha_bar schedule: ᾱ(j) = sin(π/2 * j/M/(C₂+1))²
Implements discrete-time preconditioning with improved numerical stability.
References
Nichol & Dhariwal, “Improved Denoising Diffusion Probabilistic Models”, 2021.
- __init__(img_resolution, in_channels, out_channels, label_dim=0, use_fp16=False, C_1=0.001, C_2=0.008, M=1000, model_type='DhariwalUNet', **model_kwargs)[source]
Initialize the iDDPMPrecond module.
- Parameters:
img_resolution (int or tuple) – Input image resolution.
in_channels (int) – Input channels.
out_channels (int) – Output channels.
label_dim (int, optional) – Class label dimension.
use_fp16 (bool, optional) – Use FP16 precision.
C_1 (float, optional) – Low noise adjustment.
C_2 (float, optional) – High noise adjustment.
M (int, optional) – Number of timesteps.
model_type (str, optional) – Underlying model class name.
**model_kwargs (dict) – Additional model arguments.
- forward(x, sigma, condition_img=None, class_labels=None, force_fp32=False, **model_kwargs)[source]
Forward pass with iDDPM preconditioning.
- Parameters:
x (torch.Tensor) – Input noisy tensor of shape (batch_size, in_channels, height, width).
sigma (torch.Tensor) – Noise level(s) of shape (batch_size,) or scalar.
condition_img (torch.Tensor, optional) – Condition image tensor of same spatial dimensions as x. Default is None.
class_labels (torch.Tensor, optional) – Class labels for conditioning of shape (batch_size,) or (batch_size, label_dim). Default is None.
force_fp32 (bool, optional) – Force FP32 precision even if use_fp16 is True. Default is False.
**model_kwargs (dict) – Additional keyword arguments passed to the underlying model.
- Returns:
Denoised output of shape (batch_size, out_channels, height, width).
- Return type:
torch.Tensor
Notes
Applies the preconditioning: D(x) = c_skip * x + c_out * F(c_in * x, c_noise)
Where F is the underlying U-Net model.
For iDDPM: c_skip = 1, c_out = -σ, c_in = 1/√(σ²+1)
Condition images are concatenated along the channel dimension.
c_noise maps sigma to discrete timesteps for the underlying model.
- alpha_bar(j)[source]
Compute alpha_bar for timestep j in the improved schedule.
- Parameters:
j (int or torch.Tensor) – Timestep index (0 <= j <= M).
- Returns:
ᾱ(j) = sin(π/2 * j/M/(C₂+1))²
- Return type:
torch.Tensor
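The alpha_bar schedule can be reproduced directly with the defaults from the class signature (plain math, not the package implementation):

```python
import math

def alpha_bar(j, M=1000, C_2=0.008):
    # alpha_bar(j) = sin(pi/2 * j / M / (C_2 + 1)) ** 2
    return math.sin(0.5 * math.pi * j / (M * (C_2 + 1))) ** 2

# alpha_bar rises monotonically from 0 at j = 0 toward (but not quite
# reaching) 1 at j = M, because of the C_2 offset in the denominator.
```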
- round_sigma(sigma, return_index=False)[source]
Round noise level(s) to the nearest value in the learned schedule.
- Parameters:
sigma (torch.Tensor) – Noise level(s).
return_index (bool, optional) – If True, return the index in the schedule instead of the value. Default is False.
- Returns:
Rounded noise level(s) or indices.
- Return type:
torch.Tensor
- class IPSL_AID.networks.EDMPrecond(*args: Any, **kwargs: Any)[source]
Bases:
Module
EDM preconditioning for diffusion models.
This class implements the EDM (Elucidating Diffusion Models) preconditioning scheme, which provides a unified framework for various diffusion formulations with optimized preconditioning coefficients.
- Parameters:
img_resolution (int or tuple) – Input image resolution. If int, assumes square images. If tuple, should be (height, width).
in_channels (int) – Number of input color channels.
out_channels (int) – Number of output color channels.
label_dim (int, optional) – Number of class labels. Set to 0 for unconditional generation. Default is 0.
use_fp16 (bool, optional) – Whether to execute the underlying model at FP16 precision for speed. Default is False.
sigma_min (float, optional) – Minimum supported noise level. Default is 0.
sigma_max (float, optional) – Maximum supported noise level. Default is float(‘inf’).
sigma_data (float, optional) – Standard deviation of the training data. Default is 1.0.
model_type (str, optional) – Class name of the underlying U-Net model (‘SongUNet’ or ‘DhariwalUNet’). Default is ‘DhariwalUNet’.
**model_kwargs (dict) – Additional keyword arguments passed to the underlying model.
- model
The underlying U-Net model.
- Type:
torch.nn.Module
Notes
The EDM formulation provides a unified preconditioning scheme that generalizes VP, VE, and other diffusion formulations.
Preconditioning coefficients: c_skip = σ_data²/(σ²+σ_data²), c_out = σ·σ_data/√(σ²+σ_data²), c_in = 1/√(σ_data²+σ²)
c_noise = log(σ)/4 provides the noise conditioning input.
This formulation often yields better sample quality and training stability.
References
Karras et al., “Elucidating the Design Space of Diffusion-Based Generative Models”, 2022.
- __init__(img_resolution, in_channels, out_channels, label_dim=0, use_fp16=False, sigma_min=0, sigma_max=inf, sigma_data=1.0, model_type='DhariwalUNet', **model_kwargs)[source]
Initialize the EDMPrecond module.
- Parameters:
img_resolution (int or tuple) – Input image resolution.
in_channels (int) – Input channels.
out_channels (int) – Output channels.
label_dim (int, optional) – Class label dimension.
use_fp16 (bool, optional) – Use FP16 precision.
sigma_min (float, optional) – Minimum noise level.
sigma_max (float, optional) – Maximum noise level.
sigma_data (float, optional) – Training data standard deviation.
model_type (str, optional) – Underlying model class name.
**model_kwargs (dict) – Additional model arguments.
- forward(x, sigma, condition_img=None, class_labels=None, force_fp32=True, **model_kwargs)[source]
Forward pass with EDM preconditioning.
- Parameters:
x (torch.Tensor) – Input noisy tensor of shape (batch_size, in_channels, height, width).
sigma (torch.Tensor) – Noise level(s) of shape (batch_size,) or scalar.
condition_img (torch.Tensor, optional) – Condition image tensor of same spatial dimensions as x. Default is None.
class_labels (torch.Tensor, optional) – Class labels for conditioning of shape (batch_size,) or (batch_size, label_dim). Default is None.
force_fp32 (bool, optional) – Force FP32 precision even if use_fp16 is True. Default is True.
**model_kwargs (dict) – Additional keyword arguments passed to the underlying model.
- Returns:
Denoised output of shape (batch_size, out_channels, height, width).
- Return type:
torch.Tensor
Notes
Applies the EDM preconditioning: D(x) = c_skip * x + c_out * F(c_in * x, c_noise)
Where F is the underlying U-Net model.
EDM coefficients: c_skip = σ_data²/(σ²+σ_data²), c_out = σ·σ_data/√(σ²+σ_data²), c_in = 1/√(σ_data²+σ²)
Condition images are concatenated along the channel dimension.
c_noise = log(σ)/4 provides the noise conditioning.
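The coefficient formulas in the notes translate directly to code; a scalar sketch (illustrative, not the package API):

```python
import math

def edm_coefficients(sigma, sigma_data=1.0):
    c_skip = sigma_data ** 2 / (sigma ** 2 + sigma_data ** 2)
    c_out = sigma * sigma_data / math.sqrt(sigma ** 2 + sigma_data ** 2)
    c_in = 1.0 / math.sqrt(sigma_data ** 2 + sigma ** 2)
    c_noise = math.log(sigma) / 4.0
    return c_skip, c_out, c_in, c_noise

# At low sigma the skip connection dominates (c_skip -> 1); at high sigma the
# network output dominates (c_skip -> 0, c_out -> sigma_data).
```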
- round_sigma(sigma)[source]
Round noise level(s) for compatibility with discrete schedules.
- Parameters:
sigma (float or torch.Tensor) – Noise level(s).
- Returns:
Rounded noise level(s).
- Return type:
torch.Tensor
Notes
In EDM, sigma is continuous, so rounding is typically a no-op unless implementing a discrete schedule variant.
- class IPSL_AID.networks.TestDiffusionNetworks(methodName='runTest', logger=None)[source]
Bases:
TestCase
Unit tests for diffusion network architectures.
IPSL_AID.utils module
- class IPSL_AID.utils.EasyDict[source]
Bases:
dict
A dictionary subclass that allows attribute-style access to its items. It extends the built-in dict and overrides the __getattr__, __setattr__, and __delattr__ methods so that dictionary keys can be accessed as attributes.
Original work: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. Original source: https://github.com/NVlabs/edm
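The behavior described above can be sketched in a few lines (modeled on the NVlabs/edm original, not copied from the package source):

```python
class EasyDict(dict):
    """dict subclass whose keys are also readable/writable as attributes."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]

# Keys and attributes refer to the same underlying storage.
cfg = EasyDict(lr=1e-4)
cfg.batch_size = 32
```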
- class IPSL_AID.utils.FileUtils[source]
Bases:
object
Utility class for file and directory operations.
- __init__()[source]
Initialize the FileUtils class. This class does not maintain any state, so the constructor is empty.
- class IPSL_AID.utils.TestEasyDict(methodName='runTest', logger=None)[source]
Bases:
TestCase
Unit tests for EasyDict class.
- __init__(methodName='runTest', logger=None)[source]
Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.
- class IPSL_AID.utils.TestFileUtils(methodName='runTest', logger=None)[source]
Bases:
TestCase
Unit tests for FileUtils class.
- __init__(methodName='runTest', logger=None)[source]
Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.
- test_makedir_multiple_nested_directories()[source]
Test creating multiple nested directories at once.
IPSL_AID.download_ERA5_cds module
- IPSL_AID.download_ERA5_cds.parse_args()[source]
Parse command-line arguments.
- Returns:
Parsed command line arguments as a namespace object with attributes corresponding to each argument.
- Return type:
argparse.Namespace
Notes
ERA5 variable names must match those defined in the Copernicus Climate Data Store catalogue.
- IPSL_AID.download_ERA5_cds.main(logger)[source]
Download ERA5 data from the Copernicus Climate Data Store.
The function follows a structured workflow:
1. Parse command-line arguments.
2. Create output directories.
3. Loop over requested years, variables, and months.
4. Submit download requests to the CDS API.
5. Save results as NetCDF files.
Files are skipped if they already exist.
Notes
Data are downloaded at hourly resolution for all days of each month.
The CDS API client requires a valid configuration file. Visit: https://cds.climate.copernicus.eu/
The dataset used depends on whether pressure levels are requested:
reanalysis-era5-single-levels
reanalysis-era5-pressure-levels
IPSL_AID.generate_all_data_ERA5 module
- IPSL_AID.generate_all_data_ERA5.parse_args()[source]
Parse command-line arguments.
- Returns:
Parsed command line arguments as a namespace object with attributes corresponding to each argument.
- Return type:
argparse.Namespace
- Raises:
ValueError – If the number of variables does not match the number of rename variables.
- IPSL_AID.generate_all_data_ERA5.main(logger)[source]
Generate yearly ERA5 datasets from monthly NetCDF files.
The function follows a structured workflow:
1. Parse command-line arguments.
2. Load a CSV file containing timestamps to extract.
3. Loop over years and variables.
4. Open monthly ERA5 NetCDF files using xarray.
5. Extract requested timestamps.
6. Concatenate monthly subsets into yearly datasets.
7. Rename variables and write compressed NetCDF files.
Notes
ERA5 data is stored monthly, so timestamps are grouped by month.
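Because ERA5 files are monthly, the grouping step mentioned above might look like the following stdlib-only sketch (the script's actual helpers are not shown in this documentation):

```python
from collections import defaultdict
from datetime import datetime

def group_by_month(timestamps):
    """Group requested timestamps by (year, month) so each group maps to one monthly file."""
    groups = defaultdict(list)
    for ts in timestamps:
        groups[(ts.year, ts.month)].append(ts)
    return dict(groups)

stamps = [datetime(2000, 1, 1, 6), datetime(2000, 1, 15, 12), datetime(2000, 2, 1, 0)]
groups = group_by_month(stamps)  # two groups: (2000, 1) and (2000, 2)
```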