# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Kazem Ardaneh, Kishanthan Kingston
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/
import os
import sys
import time
import argparse
from IPSL_AID.logger import Logger
from IPSL_AID.utils import FileUtils, EasyDict
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from IPSL_AID.dataset import stats, DataPreprocessor
from torch.utils.data import DataLoader
from tqdm import tqdm
from IPSL_AID.model import load_model_and_loss
from IPSL_AID.model_utils import ModelUtils
import torch.optim as optim
import xarray as xr
from IPSL_AID.diagnostics import (
plot_metric_histories,
plot_loss_histories,
plot_average_metrics,
plot_spatiotemporal_histograms,
)
from IPSL_AID.evaluater import (
MetricTracker,
run_validation,
)
def parse_args():
"""
Parse command line arguments for diffusion model training and inference.
This function defines and parses all command line arguments required for
configuring and running diffusion model training, resumption, or inference
experiments. It provides comprehensive options for data loading, model
architecture, training hyperparameters, and output management.
Returns
-------
argparse.Namespace
Parsed command line arguments as a namespace object with attributes
corresponding to each argument.
Notes
-----
- Arguments are organized into logical groups: execution mode, data
configuration, training configuration, model architecture, and output.
- Boolean arguments use string conversion with lambda functions for
flexibility (accepts "true"/"false", "True"/"False", etc.).
- Default values are provided for most parameters to allow minimal
configuration for basic usage.
- Some arguments have constraints or choices to ensure valid configurations.
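Examples
--------
A minimal training invocation (the script name and data directory are
illustrative placeholders)::

    python train.py --run_type train --datadir /path/to/era5 \
        --varnames_list VAR_2T VAR_10U VAR_10V --num_epochs 50 --save_model true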
"""
parser = argparse.ArgumentParser(description="Train IPSL-AID diffusion model")
# Execution mode and region
parser.add_argument(
"--debug",
type=lambda x: x.lower() == "true",
default=False,
help="Enable or disable debug mode",
)
parser.add_argument(
"--region",
type=lambda x: x.lower(),
default=None,
choices=["us", "europe", "asia"],
help="region (only used if run_type=inference_regional)",
)
# Run configuration
parser.add_argument(
"--run_type",
type=str,
default="train",
choices=["train", "resume_train", "inference", "inference_regional"],
help="Run type: 'train', 'resume_train', 'inference' or 'inference_regional'",
)
parser.add_argument(
"--model_name",
type=str,
default="diffusion_model",
help="Model name for inference (only for run_type=inference)",
)
parser.add_argument(
"--inference_type",
type=str,
default="direct",
choices=["direct", "sampler"],
help="Inference mode: 'direct' for deterministic inference, 'sample' for stochastic sampling.",
)
# Data configuration
parser.add_argument(
"--varnames_list",
type=str,
nargs="+",
default=["VAR_2T", "VAR_10U", "VAR_10V"],
help="List of variable names to train on",
)
parser.add_argument(
"--constant_varnames_list",
type=str,
nargs="+",
default=["z", "lsm"],
help="List of constant variable names",
)
parser.add_argument(
"--constant_varnames_file",
type=str,
default="ERA5_const_sfc_variables.nc",
help="Path to NetCDF file containing constant variables",
)
parser.add_argument(
"--normalization_types",
type=str,
nargs="+",
default=["VAR_2T=standard", "VAR_10U=standard", "VAR_10V=standard"],
help="Normalization types for each variable as 'var=type' pairs",
)
parser.add_argument(
"--dynamic_covariates",
nargs="+",
type=str,
default=None,
help="List of dynamic covariates",
)
parser.add_argument(
"--dynamic_covariates_dir",
type=str,
default="../data_covariates/",
help="Directory containing NetCDF files for dynamic covariates",
)
parser.add_argument(
"--units_list",
type=str,
nargs="+",
default=["K", "m/s", "m/s"],
help="List of variable units corresponding to varnames_list",
)
# Time range configuration
parser.add_argument(
"--year_start", type=int, default=1980, help="Start year for dataset"
)
parser.add_argument(
"--year_end", type=int, default=2020, help="End year for dataset"
)
parser.add_argument(
"--year_start_test", type=int, default=2020, help="Start year for test dataset"
)
parser.add_argument(
"--year_end_test", type=int, default=2022, help="End year for test dataset"
)
# Training configuration
parser.add_argument(
"--num_epochs", type=int, default=100, help="Number of training epochs"
)
parser.add_argument(
"--batch_size", type=int, default=8, help="Batch size for training"
)
parser.add_argument(
"--tbatch", type=int, default=1, help="Temporal batch length for processing"
)
parser.add_argument(
"--sbatch",
type=int,
default=8,
help="Number of spatial batches per timestamp for the traning",
)
parser.add_argument(
"--train_temporal_batch_mode",
type=str,
default="partial", # or "full"
choices=["full", "partial"],
help="Train temporal batch mode: 'full' for whole sequence, 'partial' for batched",
)
parser.add_argument(
"--tbatch_train",
type=int,
default=1,
help="Temporal batch length for training phase (only used when train_temporal_batch_mode='partial')",
)
parser.add_argument(
"--test_temporal_batch_mode",
type=str,
default="full", # or "different"
choices=["full", "partial"],
help="Test temporal batch mode: 'same' as training, 'different' for test-specific",
)
parser.add_argument(
"--tbatch_test",
type=int,
default=None,
help="Temporal batch length for test phase (only used when test_temporal_batch_mode='partial')",
)
parser.add_argument(
"--test_spatial_batch_mode",
type=str,
default="full", # or "partial"
choices=["full", "partial"],
help="Test spatial batch mode: 'full' for whole domain, 'partial' for batched processing",
)
parser.add_argument(
"--sbatch_test",
type=int,
default=None,
help="Number of spatial batches for test phase (only used when test_spatial_batch_mode=partial)",
)
parser.add_argument(
"--batch_size_lat",
type=int,
default=145,
help="Height of spatial batch in grid points (latitude direction), must be odd",
)
parser.add_argument(
"--batch_size_lon",
type=int,
default=145,
help="Width of spatial batch in grid points (longitude direction), must be odd",
)
parser.add_argument(
"--num_workers", type=int, default=16, help="Number of DataLoader workers"
)
parser.add_argument(
"--learning_rate", type=float, default=1e-4, help="Learning rate"
)
parser.add_argument("--datadir", type=str, required=True, help="Dataset path")
parser.add_argument(
"--per_var_datadir",
type=str,
nargs="+",
default=None,
help="Per-variable data directories as VAR=path pairs",
)
# Data processing parameters
parser.add_argument(
"--time_normalization",
type=str,
default="linear",
help="Type of time normalization",
)
parser.add_argument(
"--epsilon", type=float, default=0.02, help="Epsilon parameter for filtering"
)
parser.add_argument(
"--beta", type=float, default=1.0, help="Beta parameter for loss function"
)
parser.add_argument(
"--margin", type=int, default=8, help="Margin parameter for filtering"
)
# Output configuration
parser.add_argument(
"--main_folder", type=str, default="experiment", help="Main output folder name"
)
parser.add_argument(
"--sub_folder",
type=str,
default="experiment",
help="Sub-folder name for current run",
)
parser.add_argument(
"--prefix", type=str, default="run", help="Prefix for saved files"
)
parser.add_argument(
"--dtype",
type=str,
default="fp32",
choices=["fp16", "fp32", "fp64"],
help="Floating point precision",
)
# Diffusion model configuration
parser.add_argument(
"--arch",
type=str,
default="adm",
choices=["ddpmpp", "ncsnpp", "adm"],
help="Diffusion architecture type",
)
parser.add_argument(
"--precond",
type=str,
default="edm",
choices=["vp", "ve", "edm", "unet"],
help="Diffusion preconditioner",
)
parser.add_argument(
"--in_channels", type=int, default=3, help="Number of variable channels"
)
parser.add_argument(
"--cond_channels", type=int, default=0, help="Number of conditioning channels"
)
parser.add_argument(
"--out_channels", type=int, default=3, help="Number of output channels"
)
# Checkpoint configuration
parser.add_argument(
"--save_model",
type=lambda x: x.lower() == "true",
default=False,
help="Enable model checkpoint saving",
)
parser.add_argument(
"--apply_filter",
type=lambda x: x.lower() == "true",
default=False,
help="Apply fine filtering for coarse data generation (default: True)",
)
parser.add_argument(
"--save_checkpoint_name",
type=str,
default="diffusion_model_checkpoint",
help="The name for saved checkpoints",
)
parser.add_argument(
"--save_per_samples",
type=int,
default=10000,
help="Save checkpoint every N samples",
)
parser.add_argument(
"--load_checkpoint_name",
type=str,
default="model.pth.tar",
help="Checkpoint file to load",
)
parser.add_argument(
"--region_center",
type=float,
nargs=2,
default=None,
help="Latitude and longitude center for regional inference "
"(used only when run_type=inference_regional)",
)
parser.add_argument(
"--region_size",
type=int,
nargs=2,
default=None,
help="Requested regional size in grid points (lat lon) "
"for regional inference (used only when run_type=inference_regional)",
)
# EDM sampler configuration
parser.add_argument(
"--num_steps",
type=int,
default=20,
help="Number of sampling steps used in the diffusion sampler",
)
parser.add_argument(
"--sigma_min",
type=float,
default=0.002,
help="Minimum noise level used by the sampler",
)
parser.add_argument(
"--sigma_max",
type=float,
default=80.0,
help="Maximum noise level used by the sampler",
)
parser.add_argument(
"--rho",
type=float,
default=7.0,
help="Exponent used for EDM time step discretization",
)
parser.add_argument(
"--s_churn",
type=float,
default=40,
help="Stochasticity strength parameter controlling noise injection during sampling",
)
parser.add_argument(
"--s_min",
type=float,
default=0,
help="Minimum noise level at which stochasticity is applied",
)
parser.add_argument(
"--s_max",
type=float,
default=float("inf"),
help="Maximum noise level at which stochasticity is applied",
)
parser.add_argument(
"--s_noise",
type=float,
default=1.0,
help="Noise scale applied when stochasticity is enabled",
)
parser.add_argument(
"--solver",
type=str,
default="heun",
choices=["heun", "euler"],
help="ODE solver used by the diffusion sampler",
)
parser.add_argument(
"--compute_crps",
type=lambda x: x.lower() == "true",
default=False,
help="Compute CRPS during inference (requires inference_type='sampler')",
)
parser.add_argument(
"--crps_ensemble_size", type=int, default=10, help="Ensemble size for CRPS"
)
parser.add_argument(
"--crps_batch_size", type=int, default=2, help="Batch size for CRPS ensembles"
)
return parser.parse_args()
def make_divisible_hw(h, w, n):
"""
Adjust height and width to be divisible by 2**n by decrementing.
This function ensures that both the height (h) and width (w) are divisible
by 2 raised to the power n, which is often required for neural network
architectures that use pooling or strided convolutions multiple times.
Parameters
----------
h : int
Original height value.
w : int
Original width value.
n : int
Exponent for divisor calculation. The divisor is 2**n.
Returns
-------
h_new : int
Adjusted height that is divisible by 2**n.
w_new : int
Adjusted width that is divisible by 2**n.
Notes
-----
- The function decrements h and w until they become divisible by 2**n.
- This is a common requirement for U-Net and other encoder-decoder architectures
that use multiple downsampling and upsampling operations.
- The adjustment is conservative (decrementing) to avoid adding padding,
which might be important for maintaining exact spatial relationships.
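Examples
--------
With three downsampling levels the divisor is 2**3 = 8:
>>> make_divisible_hw(145, 145, 3)
(144, 144)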
"""
div = 2**n
# Fix H
while h % div != 0:
h -= 1
# Fix W
while w % div != 0:
w -= 1
return h, w
def setup_directories_and_logging(args):
"""
Set up directory structure and logging infrastructure for experiments.
This function creates a standardized directory hierarchy for organizing
experiment outputs (logs, results, model checkpoints, etc.) and initializes
a logging system with both console and file output.
Parameters
----------
args : argparse.Namespace or EasyDict
Configuration object containing the following attributes:
- main_folder : str
Main experiment folder name.
- sub_folder : str
Sub-folder name for the current run.
- prefix : str
Prefix for log files and outputs.
- datadir : str
Base data directory path.
- constant_varnames_file : str
Filename for constant variables data.
Returns
-------
paths : EasyDict
Dictionary containing paths to created directories:
- logs : str
Path to log files directory.
- results : str
Path to results output directory.
- runs : str
Path to experiment run tracking directory.
- checkpoints : str
Path to model checkpoint directory.
- stats : str
Path to statistics and metrics directory.
- datadir : str
Original data directory path.
- constants : str
Full path to constant variables file.
logger : Logger
Configured logger instance with console and file output.
Notes
-----
- Directory structure:
logs/main_folder/sub_folder/
results/main_folder/sub_folder/
runs/main_folder/sub_folder/
checkpoints/main_folder/sub_folder/
stats/main_folder/sub_folder/
- Log files are named {prefix}_log.txt.
- The logger outputs to both console and file by default.
- All directories are created if they don't exist (via FileUtils.makedir).
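Examples
--------
A minimal sketch (attribute values are illustrative)::

    args = EasyDict(
        {"main_folder": "experiment", "sub_folder": "runA", "prefix": "run",
         "datadir": "/data/era5",
         "constant_varnames_file": "ERA5_const_sfc_variables.nc"}
    )
    paths, logger = setup_directories_and_logging(args)
    logger.info(f"Checkpoints will be written to {paths.checkpoints}")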
"""
# now = datetime.datetime.now()
# date_time_str = now.strftime("%Y%m%d_%H%M%S")
current_file = os.path.abspath(__file__)
parent_dir = os.path.dirname(current_file)
project_root = os.path.dirname(parent_dir)
paths = EasyDict()
paths.logs = os.path.join(project_root, "logs", args.main_folder, args.sub_folder)
paths.results = os.path.join(
project_root, "results", args.main_folder, args.sub_folder
)
paths.runs = os.path.join(project_root, "runs", args.main_folder, args.sub_folder)
paths.checkpoints = os.path.join(
project_root, "checkpoints", args.main_folder, args.sub_folder
)
paths.stats = os.path.join(project_root, "stats", args.main_folder, args.sub_folder)
paths.stats_dir = os.path.join(project_root, "stats")
paths.datadir = args.datadir
paths.constants = os.path.join(paths.datadir, args.constant_varnames_file)
# Create directories
for path in [paths.logs, paths.results, paths.runs, paths.checkpoints, paths.stats]:
FileUtils.makedir(path)
# Setup logger
log_file = os.path.join(paths.logs, f"{args.prefix}_log.txt")
logger = Logger(
console_output=True, file_output=True, log_file=log_file, record=True
)
logger.show_header("Main")
return paths, logger
def log_configuration(args, paths, logger):
"""
Log all configuration parameters to the provided logger.
This function comprehensively logs all experiment configuration parameters
including execution mode, data settings, training hyperparameters, model
architecture, and directory structure. It provides a clear overview of the
experiment setup for reproducibility and debugging.
Parameters
----------
args : argparse.Namespace or EasyDict
Configuration object containing all experiment parameters.
paths : EasyDict
Dictionary containing paths to various experiment directories.
logger : Logger
Logger instance for outputting configuration information.
Notes
-----
- The function organizes parameters into logical sections for readability.
- Includes both user-specified parameters and derived directory paths.
- Provides warnings for important configuration choices (e.g., disabled
checkpoint saving).
- The output is formatted with clear section headers and indentation.
"""
logger.info("=" * 60)
logger.info("CONFIGURATION PARAMETERS")
logger.info("=" * 60)
# Execution mode and system
logger.info("Execution Mode:")
logger.info(f" └── Debug: {args.debug}")
logger.info(f" └── Run type: '{args.run_type}'")
logger.info(f" └── Inference type: '{args.inference_type}'")
logger.info(f" └── Region: '{args.region}'")
logger.info(f" └── Apply filter: {args.apply_filter}")
# Checkpoint configuration
logger.info("\nCheckpoint Configuration:")
logger.info(f" └── Save model: {args.save_model}")
logger.info(f" └── Save checkpoint name: '{args.save_checkpoint_name}'")
logger.info(f" └── Load checkpoint name: '{args.load_checkpoint_name}'")
logger.info(f" └── Save per samples: {args.save_per_samples}")
if args.model_name:
logger.info(f" └── Model name: '{args.model_name}'")
else:
logger.info(" └── Model name: Not specified")
# Data configuration
logger.info("\nData Configuration:")
logger.info(f" └── Variable names: {args.varnames_list}")
logger.info(f" └── Constant variables: {args.constant_varnames_list}")
logger.info(f" └── Constant variables file: '{args.constant_varnames_file}'")
logger.info(
f" └── Dynamic covariates: {args.dynamic_covariates if args.dynamic_covariates else 'None'}"
)
logger.info(f" └── Dynamic covariates dir: '{args.dynamic_covariates_dir}'")
logger.info(f" └── Units list: {args.units_list}")
logger.info(f" └── Normalization types: {args.normalization_types}")
logger.info(f" └── Data directory: '{args.datadir}'")
# Time range configuration
logger.info("\nTime Range Configuration:")
logger.info(f" └── Training years: {args.year_start}-{args.year_end}")
logger.info(f" └── Test years: {args.year_start_test}-{args.year_end_test}")
logger.info(f" └── Time normalization: '{args.time_normalization}'")
# Training configuration
logger.info("\nTraining Configuration:")
logger.info(f" └── Number of epochs: {args.num_epochs}")
logger.info(f" └── Batch size: {args.batch_size}")
logger.info(f" └── Number of workers: {args.num_workers}")
logger.info(f" └── Learning rate: {args.learning_rate}")
logger.info(f" └── Dtype: '{args.dtype}'")
# Spatial-temporal batching
logger.info("\nSpatial-Temporal Batching:")
logger.info(f" └── Spatial batches: {args.sbatch}")
logger.info(f" └── Temporal time steps: {args.tbatch}")
logger.info(f" └── Batch size (lat): {args.batch_size_lat} grid points")
logger.info(f" └── Batch size (lon): {args.batch_size_lon} grid points")
# Data processing parameters
logger.info("\nData Processing Parameters:")
logger.info(f" └── Epsilon: {args.epsilon}")
logger.info(f" └── Beta: {args.beta}")
logger.info(f" └── Margin: {args.margin}")
# Model architecture
logger.info("\nModel Architecture:")
logger.info(f" └── Architecture: '{args.arch}'")
logger.info(f" └── Preconditioner: '{args.precond}'")
logger.info(f" └── Input channels: {args.in_channels}")
logger.info(f" └── Conditioning channels: {args.cond_channels}")
logger.info(f" └── Output channels: {args.out_channels}")
# Sampler configuration
logger.info("\nSampler Configuration:")
logger.info(f" └── solver: {args.solver}")
logger.info(f" └── num_steps: {args.num_steps}")
logger.info(f" └── sigma_min: {args.sigma_min}")
logger.info(f" └── sigma_max: {args.sigma_max}")
logger.info(f" └── rho: {args.rho}")
logger.info(f" └── s_churn: {args.s_churn}")
logger.info(f" └── s_min: {args.s_min}")
logger.info(f" └── s_max: {args.s_max}")
logger.info(f" └── s_noise: {args.s_noise}")
# Output configuration
logger.info("\nOutput Configuration:")
logger.info(f" └── Main folder: '{args.main_folder}'")
logger.info(f" └── Sub folder: '{args.sub_folder}'")
logger.info(f" └── Prefix: '{args.prefix}'")
# Directory paths (for reference)
logger.info("\nGenerated Directory Paths:")
logger.info(f" └── Logs directory: '{paths.logs}'")
logger.info(f" └── Results directory: '{paths.results}'")
logger.info(f" └── TensorBoard runs: '{paths.runs}'")
logger.info(f" └── Model checkpoints: '{paths.checkpoints}'")
logger.info(f" └── Statistics: '{paths.stats}'")
logger.info(f" └── Statistics: '{paths.stats_dir}'")
logger.info(f" └── Data directory: '{paths.datadir}'")
logger.info(f" └── Constants file: '{paths.constants}'")
# Per-variable data directories
logger.info("\nPer-variable data directories:")
if args.per_var_datadir is None:
logger.info(" └── Using default data directory for all variables")
else:
for item in args.per_var_datadir:
var, path = item.split("=")
logger.info(f" └── {var}: '{path}'")
# Special notes based on run type
logger.info("\nRun Type Notes:")
if args.run_type == "train":
logger.info(" └── Mode: Training from scratch")
elif args.run_type == "resume_train":
logger.info(" └── Mode: Resuming training from checkpoint")
elif args.run_type == "inference":
logger.info(" └── Mode: Inference only")
logger.info(f" └── Inference type: {args.inference_type}")
logger.info(f" └── Compute CRPS: {args.compute_crps}")
elif args.run_type == "inference_regional":
logger.info(" └── Mode: Regional inference")
logger.info(f" └── Region center: {args.region_center}")
logger.info(f" └── Region size: {args.region_size}")
logger.info(f" └── Compute CRPS: {args.compute_crps}")
# Checkpoint saving strategy
if args.save_model:
logger.info("\nCheckpoint Saving Strategy:")
logger.info(" └── Checkpoints enabled:")
logger.info(f" └── Saving every {args.save_per_samples:,} samples")
logger.info(" └── Saving epoch checkpoints every 10 epochs")
logger.info(" └── Saving best model based on validation MAE")
logger.info(" └── Saving final model at end of training")
logger.info(" └── Each checkpoint includes:")
logger.info(" └── Model state dict")
logger.info(" └── Optimizer state")
logger.info(" └── Training/validation history")
logger.info(" └── Metrics history")
logger.info(" └── Training state (epoch, samples processed)")
logger.info(" └── Configuration arguments")
else:
logger.info("\nCheckpoint Saving:")
logger.info(" └── DISABLED - No checkpoints will be saved!")
logger.info(
" └── Warning: Training progress cannot be resumed if interrupted"
)
logger.info("=" * 60)
def setup_data_paths(args, paths, logger):
"""
Set up data file paths, load datasets, and compute normalization statistics.
This function handles the data loading pipeline for training and validation
datasets. It manages per-variable data paths, concatenates multi-year data
for each variable, computes normalization statistics, and sets up variable
mappings and normalization types.
Parameters
----------
args : argparse.Namespace or EasyDict
Configuration object containing runtime options such as training years,
execution mode, variable names, and normalization specifications.
paths : EasyDict
Dictionary containing directory paths.
Expected keys:
- datadir
- stats
logger : logging.Logger
Logger instance for output messages.
Returns
-------
norm_mapping : dict
Mapping from variable name to normalization statistics.
steps : EasyDict
Grid dimension information (time, latitude, longitude).
normalization_type : EasyDict
Mapping from variable name to normalization method.
index_mapping : dict
Mapping from variable name to array index.
train_ds : xarray.Dataset or None
Training dataset, or None if run_type is ``inference``.
valid_ds : xarray.Dataset
Validation dataset.
Notes
-----
- Per-variable data directories may be provided using ``VAR=path`` syntax.
- Training data is only loaded when ``run_type`` is not ``inference``.
- Normalization statistics are computed on the validation dataset.
- Variables from different files and years are merged into a single dataset.
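Examples
--------
Per-variable directories can be passed on the command line as ``VAR=path``
pairs (paths below are illustrative)::

    --per_var_datadir VAR_2T=/data/era5_t2m VAR_10U=/data/era5_winds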
"""
train_years = np.arange(args.year_start, args.year_end + 1)
test_years = np.arange(args.year_start_test, args.year_end_test + 1)
logger.info(f"Training years: {train_years}")
logger.info(f"Testing years: {test_years}")
# ------------------------------------------------------------------
# Per-variable data paths configuration (using EasyDict)
# ------------------------------------------------------------------
# Default path is used as a fallback when a variable-specific path
# is not provided via the command line.
per_var_paths = EasyDict()
per_var_paths.default = paths.datadir
# logger.info(f"[DEBUG] args.per_var_datadir = {args.per_var_datadir}")
# Per-variable data directories passed as VAR=path
if args.per_var_datadir is not None:
for item in args.per_var_datadir:
var, path = item.split("=")
per_var_paths[var] = path
# logger.info(f"Per-variable path: {var} → {path}")
logger.info(f"[Data paths] default → {per_var_paths.default}")
# --------------------------
# Training datasets
# --------------------------
train_ds = None
# if args.run_type != "inference":
if args.run_type not in ["inference", "inference_regional"]:
logger.info("Pre-loading training datasets...")
train_var_datasets = []
# Load each variable independently, then concatenate along time
for var in args.varnames_list:
base_path = per_var_paths.get(var, per_var_paths.default)
train_filenames = [f"{base_path}/samples_{year}.nc" for year in train_years]
logger.info(
f"{var} training files:\n[\n"
+ "\n".join(f" {f}" for f in train_filenames)
+ "\n]"
)
# Open first dataset and sort by time
ds_var = xr.open_dataset(train_filenames[0]).sortby("time")
# Loop through remaining files and concatenate along time
for fname in train_filenames[1:]:
ds_next = xr.open_dataset(fname).sortby("time")
ds_var = xr.concat([ds_var, ds_next], dim="time")
# Keep only the current variable before merging
train_var_datasets.append(ds_var[[var]])
# Merge all variables into a single training dataset
train_ds = xr.merge(train_var_datasets).load()
logger.info(f"Training dataset concatenated: {train_ds.sizes}")
else:
logger.info("Inference mode: skipping training dataset loading")
# --------------------------
# Validation datasets
# --------------------------
logger.info("Pre-loading validation datasets...")
valid_var_datasets = []
# Load each variable independently, then concatenate along time
for var in args.varnames_list:
base_path = per_var_paths.get(var, per_var_paths.default)
valid_filenames = [f"{base_path}/samples_{year}.nc" for year in test_years]
logger.info(
f"{var} validation files:\n[\n"
+ "\n".join(f" {f}" for f in valid_filenames)
+ "\n]"
)
# Open first validation dataset
ds_var = xr.open_dataset(valid_filenames[0]).sortby("time")
# Loop through remaining validation files
for fname in valid_filenames[1:]:
ds_next = xr.open_dataset(fname).sortby("time")
ds_var = xr.concat([ds_var, ds_next], dim="time")
# Keep only the current variable before merging
valid_var_datasets.append(ds_var[[var]])
# Merge all variables into a single validation dataset
valid_ds = xr.merge(valid_var_datasets).load()
logger.info(f"Validation dataset concatenated: {valid_ds.sizes}")
# norm_mapping, steps = stats(train_ds, logger, paths.stats)
norm_mapping, steps = stats(valid_ds, logger, paths.stats_dir)
assert hasattr(steps, "time"), "steps does not contain a 'time' attribute"
# Setup normalization types
normalization_type = EasyDict()
for mapping in args.normalization_types:
if "=" in mapping:
var_name, norm_type = mapping.split("=")
normalization_type[var_name] = norm_type
else:
logger.warning(
f"Invalid normalization mapping: {mapping}. Expected 'VAR_NAME=type'"
)
# Verify all variables have normalization types
for var in args.varnames_list:
if var not in normalization_type:
logger.warning(
f"Variable '{var}' not found in normalization_types. Defaulting to 'standard'"
)
normalization_type[var] = "standard"
logger.info(f"Normalization types: {normalization_type}")
# Create index mapping
index_mapping = {var_name: i for i, var_name in enumerate(args.varnames_list)}
for var_name, idx in index_mapping.items():
logger.info(f"{var_name}: Index {idx}")
# Log normalization statistics
logger.info("------ Normalization Statistics (norm_mapping) ------")
for key, st in norm_mapping.items():
logger.info(
f"\n[{key}]\n"
f" └── vmin={getattr(st, 'vmin', None)}\n"
f" └── vmax={getattr(st, 'vmax', None)}\n"
f" └── vmean={getattr(st, 'vmean', None)}\n"
f" └── vstd={getattr(st, 'vstd', None)}\n"
f" └── median={getattr(st, 'median', None)}\n"
f" └── iqr={getattr(st, 'iqr', None)}\n"
f" └── q1={getattr(st, 'q1', None)}\n"
f" └── q3={getattr(st, 'q3', None)}"
)
logger.info("------------------------------------------------------")
return norm_mapping, steps, normalization_type, index_mapping, train_ds, valid_ds
def setup_training_environment(args, logger):
"""
Set up the training environment including device selection, random seeds,
and data type configuration.
Parameters
----------
args : argparse.Namespace or EasyDict
Configuration object containing precision settings.
logger : logging.Logger
Logger instance for output messages.
Returns
-------
device : torch.device
Selected computing device.
torch_dtype : torch.dtype
PyTorch data type.
np_dtype : numpy.dtype
NumPy data type.
use_fp16 : bool
Whether half precision is enabled.
Notes
-----
- Sets global random seeds for reproducibility.
- Automatically selects CUDA if available.
- Enables PyTorch anomaly detection for debugging.
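Examples
--------
A minimal sketch (``args`` as returned by ``parse_args``; the tensor is only
an illustration of using the returned device and dtype)::

    device, torch_dtype, np_dtype, use_fp16 = setup_training_environment(args, logger)
    sample = torch.zeros(1, 3, 144, 144, dtype=torch_dtype, device=device)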
"""
# Set random seeds for reproducibility
random_state = 0
np.random.seed(random_state)
torch.manual_seed(random_state)
torch.set_printoptions(precision=5)
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
# Enable anomaly detection for debugging
torch.autograd.set_detect_anomaly(True)
# Setup data types
torch_dtype_map = EasyDict(
{"fp16": torch.float16, "fp32": torch.float32, "fp64": torch.float64}
)
np_dtype_map = EasyDict(
{"fp16": np.float16, "fp32": np.float32, "fp64": np.float64}
)
torch_dtype = torch_dtype_map[args.dtype]
np_dtype = np_dtype_map[args.dtype]
use_fp16 = torch_dtype == torch.float16
return device, torch_dtype, np_dtype, use_fp16
def create_data_loaders(
args,
paths,
norm_mapping,
steps,
normalization_type,
index_mapping,
torch_dtype,
np_dtype,
logger,
mode="train",
run_type="train",
train_loaded_dfs=None,
valid_loaded_dfs=None,
):
"""
Create data loaders for training, validation, or inference.
Parameters
----------
args : argparse.Namespace or EasyDict
Runtime configuration options.
paths : EasyDict
Directory paths including constants files.
norm_mapping : dict
Normalization statistics per variable.
steps : EasyDict
Grid dimension information.
normalization_type : EasyDict
Normalization method per variable.
index_mapping : dict
Variable-to-index mapping.
torch_dtype : torch.dtype
PyTorch tensor dtype.
np_dtype : numpy.dtype
NumPy array dtype.
logger : logging.Logger
Logger instance.
mode : str, optional
Either ``train`` or ``validation``.
run_type : str, optional
Execution mode (train, resume_train, inference).
train_loaded_dfs : dict, optional
Pre-loaded training datasets.
valid_loaded_dfs : dict, optional
Pre-loaded validation datasets.
Returns
-------
data_loader : torch.utils.data.DataLoader
Configured data loader.
img_res : tuple of int
Spatial resolution used by the model.
dataset : DataPreprocessor
Underlying dataset object.
Raises
------
ValueError
If ``mode`` is invalid.
AssertionError
If required datasets are missing.
Notes
-----
- Spatial dimensions are adjusted to be divisible by powers of two.
- Validation falls back to training data if validation data is unavailable.
- Data is assumed to be pre-loaded into memory.
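Examples
--------
A minimal sketch of consuming the returned loader (``train_ds`` is the
pre-loaded training dataset from ``setup_data_paths``)::

    loader, img_res, dataset = create_data_loaders(
        args, paths, norm_mapping, steps, normalization_type, index_mapping,
        torch.float32, np.float32, logger, mode="train",
        train_loaded_dfs=train_ds,
    )
    for batch in loader:
        inputs, targets = batch["inputs"], batch["targets"]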
"""
# Validate mode parameter
if mode not in ["train", "validation"]:
raise ValueError(f"Invalid mode: {mode}. Must be 'train' or 'validation'")
# Adjust resolution for model compatibility
n = args.Unet_depth if hasattr(args, "Unet_depth") else 3
h, w = make_divisible_hw(args.batch_size_lat, args.batch_size_lon, n)
img_res = (h, w)
logger.info(f"Creating {mode} dataset:")
logger.info(
f" └── Original resolution: ({args.batch_size_lat}, {args.batch_size_lon})"
)
logger.info(f" └── Adjusted to divisible-by-2^{n} resolution: {img_res}")
# Determine dataset parameters based on mode with assertions
if mode == "train":
years = np.arange(args.year_start, args.year_end + 1)
assert (
train_loaded_dfs is not None
), "train_loaded_dfs must be provided for training mode"
assert len(train_loaded_dfs) > 0, "train_loaded_dfs must not be empty"
loaded_dfs = train_loaded_dfs
shuffle = True # Shuffle for training
tbatch = args.tbatch
sbatch = args.sbatch
else: # validation
years = np.arange(args.year_start_test, args.year_end_test + 1)
if valid_loaded_dfs is None or len(valid_loaded_dfs) == 0:
logger.warning(
"No validation data provided, using training data for validation"
)
assert (
train_loaded_dfs is not None
), "train_loaded_dfs must be provided as fallback for validation"
assert len(train_loaded_dfs) > 0, "train_loaded_dfs must not be empty"
loaded_dfs = train_loaded_dfs
years = np.arange(
args.year_start, args.year_start + 1
) # Use first training year
else:
loaded_dfs = valid_loaded_dfs
shuffle = False # No shuffle for validation
# Validation batching: temporal length follows the DataLoader batch size
tbatch = args.batch_size # temporal batch length for validation
sbatch = args.sbatch # same spatial batching as training
logger.info(f" └── {mode} years: {years}")
logger.info(
f" └── {mode} parameters - tbatch: {tbatch}, sbatch: {sbatch}, shuffle: {shuffle}"
)
logger.info(f" └── Number of {mode} files: {len(years)}")
# Create dataset with pre-loaded data
dataset = DataPreprocessor(
years=years, # List of years
loaded_dfs=loaded_dfs, # Pre-loaded datasets dictionary
constants_file_path=paths.constants,
varnames_list=args.varnames_list,
units_list=args.units_list,
in_shape=(16, 32),
batch_size_lat=h,
batch_size_lon=w,
steps=steps,
tbatch=tbatch,
sbatch=sbatch,
debug=args.debug,
mode=mode,
run_type=run_type,
dynamic_covariates=args.dynamic_covariates,
dynamic_covariates_dir=args.dynamic_covariates_dir,
time_normalization=args.time_normalization,
norm_mapping=norm_mapping, # Same normalization for consistency
index_mapping=index_mapping,
normalization_type=normalization_type,
constant_variables=args.constant_varnames_list,
epsilon=args.epsilon,
margin=args.margin,
dtype=(torch_dtype, np_dtype), # Same dtype for consistency
apply_filter=args.apply_filter,
region_center=args.region_center,
region_size=args.region_size,
logger=logger,
)
# Create data loader - set num_workers=0 since data is pre-loaded
data_loader = DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=shuffle,
num_workers=0, # Data is pre-loaded, no workers needed
pin_memory=True,
)
logger.info(f" {mode} dataset size: {len(dataset)}")
logger.info(f" {mode} data loader batches: {len(data_loader)}")
return data_loader, img_res, dataset
def setup_model(args, img_res, use_fp16, device, logger):
"""
Set up the diffusion model and its loss function.
Parameters
----------
args : argparse.Namespace or EasyDict
Model configuration options.
img_res : tuple of int
Image resolution (height, width).
use_fp16 : bool
Whether FP16 precision is enabled.
device : torch.device
Target device.
logger : logging.Logger
Logger instance.
Returns
-------
model : torch.nn.Module
Initialized model.
loss_fn : callable
Loss function.
Raises
------
ValueError
If an unsupported time normalization is specified.
Notes
-----
- Label dimensionality depends on the selected time normalization.
- Model creation is delegated to ``load_model_and_loss``.
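Examples
--------
A minimal sketch (``img_res`` as returned by ``create_data_loaders``)::

    model, loss_fn = setup_model(
        args, img_res=(144, 144), use_fp16=False,
        device=torch.device("cpu"), logger=logger,
    )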
"""
logger.info("------ Model and Loss info ------")
# Determine label_dim based on time_normalization
if args.time_normalization == "linear":
label_dim = 2
elif args.time_normalization == "cos_sin":
label_dim = 4
else:
raise ValueError(f"Unsupported time_normalization: {args.time_normalization}")
logger.info(
f"Label dimension: {label_dim} (time_normalization: {args.time_normalization})"
)
opts = EasyDict(
{
"arch": args.arch,
"precond": args.precond,
"img_resolution": img_res,
"in_channels": args.in_channels,
"cond_channels": args.cond_channels,
"out_channels": args.out_channels,
"label_dim": label_dim,
"use_fp16": use_fp16,
}
)
model, loss_fn = load_model_and_loss(opts, logger=logger, device=device)
return model, loss_fn
def resolve_region_center(args):
"""
Resolve the regional inference center coordinates.
This function enforces the argument logic for regional inference: the user
may provide either ``--region`` or ``--region_center``, but not both.
Parameters
----------
args : argparse.Namespace
Parsed command line arguments.
Returns
-------
tuple or None
(lat, lon) if inference_regional,
None otherwise.
Raises
------
ValueError
If both region and region_center are provided, if neither is provided
in inference_regional mode, or if an unknown region name is specified.
Notes
-----
- Predefined regions map to fixed center coordinates.
- Longitude follows the convention [0, 360].
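Examples
--------
Resolving a predefined region:
>>> ns = argparse.Namespace(run_type="inference_regional",
...                         region="europe", region_center=None)
>>> resolve_region_center(ns)
(50.0, 10.0)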
"""
if args.run_type != "inference_regional":
return None
# predefined region center coordinates (lat, lon)
# longitude convention: [0, 360]
region_coords = {
"us": (39.0, 262.0),
"europe": (50.0, 10.0),
"asia": (35.0, 100.0),
}
# case 1: both provided (error)
if args.region is not None and args.region_center is not None:
raise ValueError("provide either --region or --region_center, not both.")
# case 2: predefined region
if args.region is not None:
if args.region not in region_coords:
raise ValueError(
f"invalid region '{args.region}'. "
f"available regions: {list(region_coords.keys())}"
)
return region_coords[args.region]
# case 3: explicit coordinates
if args.region_center is not None:
if len(args.region_center) != 2:
raise ValueError("--region_center must contain exactly two values: lat lon")
return tuple(args.region_center)
# case 4: nothing provided (error)
raise ValueError(
"for run_type='inference_regional', you must provide "
"either --region (us, europe, asia) or --region_center lat lon."
)
def main():
"""
Main training and inference pipeline for IPSL-AID diffusion models.
This function orchestrates the entire training and inference process for
diffusion-based generative models on weather and climate data. It handles:
- Argument parsing and configuration
- Directory setup and logging
- Data loading and preprocessing
- Model initialization and checkpoint management
- Training loop with validation
- Inference execution
- Visualization and result saving
The pipeline supports multiple modes of operation:
- Training from scratch (run_type='train')
- Resuming training from a checkpoint (run_type='resume_train')
- Running inference with a trained model (run_type='inference')
The function follows a structured workflow:
1. Parse command line arguments
2. Setup directories and logging
3. Load and preprocess data
4. Initialize model, optimizer, and loss function
5. Handle checkpoint loading if required
6. Execute training loop with validation or run inference
7. Generate plots and save results
Parameters
----------
None
All configuration is provided via command line arguments.
Returns
-------
None
Notes
-----
- The function uses argparse for command line argument parsing.
- All output (logs, checkpoints, results) is saved to organized directories.
- Training includes validation at each epoch with metrics tracking.
- Inference mode runs validation metrics without training.
- Mixed precision training (FP16) is supported when available.
- Model checkpoints include full training state for resumption.
- TensorBoard integration is provided for training visualization.
Raises
------
FileNotFoundError
If required checkpoints are not found for resumption or inference.
RuntimeError
If inference mode is requested without validation data.
ValueError
If invalid configurations are provided.
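Examples
--------
A representative inference invocation (script name, checkpoint file and
data directory are illustrative)::

    python train.py --run_type inference --datadir /path/to/era5 \
        --main_folder experiment --sub_folder runA \
        --load_checkpoint_name model.pth.tar --inference_type sampler \
        --compute_crps true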
"""
# Check for --version flag
if "--version" in sys.argv or "-V" in sys.argv:
from IPSL_AID import __version__, __author__, __license__
print(f"IPSL-AID version {__version__}")
print(f"Copyright (c) 2026 {__author__}")
print(f"License: {__license__}")
print("Repository: https://github.com/kardaneh/IPSL-AID")
sys.exit(0)
# Parse command line arguments
args = parse_args()
args.region_center = resolve_region_center(args)
# Setup directories and logging
paths, logger = setup_directories_and_logging(args)
# Log configuration parameters
log_configuration(args, paths, logger)
# Setup data paths and normalization statistics
(
norm_mapping,
steps,
normalization_type,
index_mapping,
train_loaded_dfs,
valid_loaded_dfs,
) = setup_data_paths(args, paths, logger)
# Setup training environment (device, data types, random seeds)
device, torch_dtype, np_dtype, use_fp16 = setup_training_environment(args, logger)
# Setup TensorBoard for visualization
# if args.run_type != "inference":
if args.run_type not in ["inference", "inference_regional"]:
writer = SummaryWriter(paths.runs)
logger.info(f"TensorBoard enabled at: {paths.runs}")
else:
writer = None
logger.info("TensorBoard disabled for inference mode")
# Create data loaders
# if args.run_type != "inference":
if args.run_type not in ["inference", "inference_regional"]:
train_loader, img_res, train_dataset = create_data_loaders(
args,
paths,
norm_mapping,
steps,
normalization_type,
index_mapping,
torch_dtype,
np_dtype,
logger,
mode="train",
train_loaded_dfs=train_loaded_dfs,
)
logger.info(f"Training dataset loaded with image resolution: {img_res}")
else:
logger.info("Inference mode: Skipping training data loader creation")
train_loader, img_res, train_dataset = None, None, None
if valid_loaded_dfs is not None:
valid_loader, valid_img_res, valid_dataset = create_data_loaders(
args,
paths,
norm_mapping,
steps,
normalization_type,
index_mapping,
torch_dtype,
np_dtype,
logger,
mode="validation",
run_type=args.run_type,
valid_loaded_dfs=valid_loaded_dfs,
)
logger.info(f"Validation dataset loaded with image resolution: {valid_img_res}")
# if args.run_type == "inference":
if args.run_type in ["inference", "inference_regional"]:
img_res = valid_img_res # Use validation image resolution for inference
else:
valid_loader, valid_img_res, valid_dataset = None, None, None
logger.warning("No validation dataset created (test files not found or empty)")
if args.run_type == "inference" and valid_loader is None:
logger.error(
"Inference mode requires a validation dataset, but none was created."
)
raise RuntimeError("Cannot run inference without a validation dataset.")
# Setup model and loss function
model, loss_fn = setup_model(args, img_res, use_fp16, device, logger)
# Log model information
_ = ModelUtils.get_parameter_number(model, logger=logger)
# Setup optimizer, scheduler, and training components
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, factor=0.5, patience=5
)
if device.type == "cuda":
from torch.amp import GradScaler
scaler = GradScaler("cuda")
logger.info("GradScaler initialized for CUDA")
else:
scaler = None
logger.info("GradScaler disabled (AMP not supported on CPU)")
# Setup metrics tracking
metric_names = ["MAE", "NMAE", "RMSE", "R2", "PEARSON", "KL"]
# metric_funcs = {"MAE": mae_all, "NMAE": nmae_all, "RMSE": rmse_all, "R2": r2_all}
# Initialize validation metrics with ALL expected keys from run_validation
valid_metrics_keys = []
for k in args.varnames_list:
for m in metric_names:
valid_metrics_keys.append(f"{k}_pred_vs_fine_{m}") # Model predictions
valid_metrics_keys.append(f"{k}_coarse_vs_fine_{m}") # Coarse baselines
for m in metric_names:
valid_metrics_keys.append(f"average_pred_vs_fine_{m}") # Overall averages
valid_metrics_keys.append(
f"average_coarse_vs_fine_{m}"
) # Overall coarse averages
valid_metrics_history = {key: [] for key in valid_metrics_keys}
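# Resulting keys look like "VAR_2T_pred_vs_fine_MAE" or
# "average_coarse_vs_fine_RMSE" (one entry per variable/baseline/metric).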
train_loss_history = [0] * args.num_epochs
valid_loss_history = [0] * args.num_epochs
train_loss = MetricTracker()
logger.info(f"Tracking metrics: {metric_names}")
logger.info(f"Validation metrics: {list(valid_metrics_history.keys())}")
# Setup training state tracking
start_epoch = 0
samples_processed = 0
batches_processed = 0
avg_val_loss = float("inf")
best_val_loss = float("inf")
avg_epoch_loss = float("inf")
best_epoch = 0
# Handle checkpoint loading if needed
# if args.run_type in ["resume_train", "inference"]:
if args.run_type in ["resume_train", "inference", "inference_regional"]:
checkpoint_path = os.path.join(paths.checkpoints, args.load_checkpoint_name)
if args.debug:
logger.info("=" * 60)
logger.info("CHECKPOINT LOADING DEBUG INFO")
logger.info("=" * 60)
logger.info(f"Run type: {args.run_type}")
logger.info(f"Checkpoint path: {checkpoint_path}")
logger.info(f"Model checkpoint directory: {paths.checkpoints}")
logger.info(f"Load checkpoint name: {args.load_checkpoint_name}")
logger.info(
f"Full checkpoint path exists: {os.path.exists(checkpoint_path)}"
)
if os.path.exists(checkpoint_path):
if args.debug:
logger.info(f"Loading checkpoint from: {checkpoint_path}")
(
epoch,
samples_processed,
batches_processed,
best_val_loss,
best_epoch,
checkpoint,
) = ModelUtils.load_training_checkpoint(
checkpoint_path, model, optimizer, device, logger=logger
)
avg_val_loss = best_val_loss
if args.debug:
logger.info("Checkpoint loaded successfully")
logger.info(f" └── Epoch: {epoch}")
logger.info(f" └── Samples processed: {samples_processed:,}")
logger.info(f" └── Batches processed: {batches_processed:,}")
logger.info(f" └── Best validation loss: {best_val_loss:.6f}")
logger.info(f" └── Best epoch: {best_epoch}")
logger.info("Checkpoint keys available:")
for key in checkpoint.keys():
if isinstance(checkpoint[key], (list, dict)):
if key in ["train_loss_history", "valid_loss_history"]:
logger.info(
f" └── {key}: list with {len(checkpoint[key])} elements"
)
elif key == "valid_metrics_history":
logger.info(
f" └── {key}: dict with {len(checkpoint[key])} keys"
)
elif key == "args":
logger.info(
f" └── {key}: dict with {len(checkpoint[key])} arguments"
)
else:
logger.info(f" └── {key}: {type(checkpoint[key]).__name__}")
else:
logger.info(f" └── {key}: {checkpoint[key]}")
if args.run_type == "resume_train":
start_epoch = epoch + 1
if args.debug:
logger.info(f"Resuming training from epoch {start_epoch}")
logger.info(
f"Current train_loss_history length: {len(train_loss_history)}"
)
logger.info(
f"Current valid_loss_history length: {len(valid_loss_history)}"
)
# Load history if available
if "train_loss_history" in checkpoint:
train_loss_history[:start_epoch] = checkpoint["train_loss_history"][
:start_epoch
]
if "valid_loss_history" in checkpoint:
valid_loss_history[:start_epoch] = checkpoint["valid_loss_history"][
:start_epoch
]
if "valid_metrics_history" in checkpoint:
for key in valid_metrics_history:
if key in checkpoint["valid_metrics_history"]:
valid_metrics_history[key] = checkpoint[
"valid_metrics_history"
][key]
logger.info(f"Resuming training from epoch {start_epoch}")
else:
logger.info(f"Model loaded for {args.run_type}")
else:
logger.error(f"Checkpoint not found at: {checkpoint_path}")
logger.error(
f"Cannot do run type: {args.run_type} without checkpoint. Exiting."
)
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
# Prepare for model saving
if args.save_model:
logger.info("Model saving enabled")
# save_counter = 0
# Setup GPU support
if torch.cuda.is_available():
model.cuda()
if torch.cuda.device_count() > 1:
logger.info(f"Using {torch.cuda.device_count()} GPUs!")
model = torch.nn.DataParallel(model)
# ============================================================================
# INFERENCE MODE - Run validation directly
# ============================================================================
# if args.run_type == "inference":
if args.run_type in ["inference", "inference_regional"]:
logger.info("=" * 60)
logger.info("RUNNING INFERENCE/VALIDATION")
logger.info("=" * 60)
logger.info(
f"Validation dataset temporal range: "
f"{valid_dataset.stime} → {valid_dataset.etime} "
f"(total timesteps = {valid_dataset.etime - valid_dataset.stime})"
)
if args.compute_crps and args.inference_type != "sampler":
logger.warning(
"CRPS requested but inference_type is not 'sampler'. "
"CRPS requires probabilistic sampling. Disabling CRPS."
)
args.compute_crps = False
# Run validation (which is essentially inference on validation data)
assert (
valid_loader is not None
), "Validation data loader must be available for inference"
avg_val_loss, val_metrics = run_validation(
model,
valid_dataset,
valid_loader,
loss_fn,
norm_mapping,
normalization_type,
index_mapping,
args,
steps,
device,
logger,
epoch=0,
writer=writer,
plot_every_n_epochs=1, # Always plot for inference
paths=paths,
compute_crps=args.compute_crps, # True for diffusion models, False for unet
crps_ensemble_size=args.crps_ensemble_size,
crps_batch_size=args.crps_batch_size,
)
logger.info("Inference completed successfully!")
sys.exit(0)
logger.info("Start training...")
# Training loop with validation
for epoch in range(start_epoch, args.num_epochs):
train_dataset.new_epoch()
model.train()
train_loss.reset()
previous_time = time.time()
loop = tqdm(
enumerate(train_loader),
total=len(train_loader),
desc=f"Training Epoch {epoch}",
)
for batch_idx, batch in loop:
# Move data to device
features = batch["inputs"].to(device)
targets = batch["targets"].to(device)
# lat_batch = batch["corrdinates"]["lat"].to(device)
# lon_batch = batch["corrdinates"]["lon"].to(device)
if epoch == 0 and batch_idx == 0:
logger.info(
f"batch idx:{batch_idx}, features shape:{features.shape}, targets shape:{targets.shape}"
)
# Prepare labels based on time normalization
if args.time_normalization == "linear":
labels = torch.stack(
(batch["doy"].to(device), batch["hour"].to(device)), dim=1
)
elif args.time_normalization == "cos_sin":
labels = torch.stack(
(
batch["doy_sin"].to(device),
batch["doy_cos"].to(device),
batch["hour_sin"].to(device),
batch["hour_cos"].to(device),
),
dim=1,
)
# Zero gradients
optimizer.zero_grad()
# Mixed precision training
with torch.amp.autocast(device_type=device.type, dtype=torch_dtype):
loss = loss_fn(model, targets, features, labels)
loss = loss.mean()
# Backward pass (gradient scaling only when a GradScaler exists, e.g. CUDA)
if scaler is not None:
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
else:
    loss.backward()
    optimizer.step()
# Update loss tracker and processed-sample counters
train_loss.update(loss.item(), targets.shape[0])
samples_processed += targets.shape[0]
batches_processed += 1
# Calculate timing
current_time = time.time()
batch_time = current_time - previous_time
previous_time = current_time
# Update progress bar
loop.set_postfix(
{
"Loss": f"{loss.item():.4f}",
"Avg Loss": f"{train_loss.getmean():.4f}",
"Time": f"{batch_time:.2f}s",
}
)
# End of epoch - Run validation if validation loader exists
avg_epoch_loss = train_loss.getmean()
train_loss_history[epoch] = avg_epoch_loss
# TensorBoard logging for training
writer.add_scalar("Loss/train_epoch", avg_epoch_loss, epoch)
# Run validation
if valid_loader is not None:
avg_val_loss, val_metrics = run_validation(
model,
valid_dataset,
valid_loader,
loss_fn,
norm_mapping,
normalization_type,
index_mapping,
args,
steps,
device,
logger,
epoch,
writer,
plot_every_n_epochs=10,
paths=paths,
)
valid_loss_history[epoch] = avg_val_loss
# Update validation metrics history
for metric_name, tracker in val_metrics.items():
if metric_name in valid_metrics_history:
valid_metrics_history[metric_name].append(tracker.getmean())
else:
logger.warning(
f"Unexpected metric {metric_name} not found in valid_metrics_history"
)
# Update scheduler based on mean of all validation MAE history
if (
valid_loader is not None
and valid_metrics_history["average_pred_vs_fine_MAE"]
):
# Calculate mean of all validation MAE values so far using Python's sum/len
mae_history = valid_metrics_history["average_pred_vs_fine_MAE"]
mean_val_mae = sum(mae_history) / len(mae_history)
scheduler.step(mean_val_mae)
logger.info(
f"Scheduler step with mean validation MAE (all {len(mae_history)} epochs): {mean_val_mae:.4f}"
)
else:
scheduler.step(avg_epoch_loss)
# Log epoch results
logger.info(f"Epoch {epoch} completed - Train Loss: {avg_epoch_loss:.4f}")
if valid_loader is not None:
logger.info(f"Epoch {epoch} completed - Val Loss: {avg_val_loss:.4f}")
# Save epoch checkpoint (every 10 epochs)
if args.save_model and epoch % 10 == 0:
ModelUtils.save_training_checkpoint(
model=model,
optimizer=optimizer,
epoch=epoch,
samples_processed=samples_processed,
batches_processed=batches_processed,
train_loss_history=train_loss_history,
valid_loss_history=valid_loss_history,
valid_metrics_history=valid_metrics_history,
best_val_loss=best_val_loss,
best_epoch=best_epoch,
avg_val_loss=avg_val_loss if valid_loader is not None else 0.0,
avg_epoch_loss=avg_epoch_loss,
args=args,
paths=paths,
logger=logger,
checkpoint_type="epoch",
save_full_model=True,
)
# Save best model based on validation loss
if valid_loader is not None and avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
best_epoch = epoch
ModelUtils.save_training_checkpoint(
model=model,
optimizer=optimizer,
epoch=epoch,
samples_processed=samples_processed,
batches_processed=batches_processed,
train_loss_history=train_loss_history,
valid_loss_history=valid_loss_history,
valid_metrics_history=valid_metrics_history,
best_val_loss=best_val_loss,
best_epoch=best_epoch,
avg_val_loss=avg_val_loss,
avg_epoch_loss=avg_epoch_loss,
args=args,
paths=paths,
logger=logger,
checkpoint_type="best",
save_full_model=True,
)
# Generate plots at the end of training
logger.info("Generating training summary plots...")
# Plot losses
plot_loss_histories(
train_loss_history,
valid_loss_history,
filename=f"training_validation_loss_{args.prefix}.png",
save_dir=paths.results,
)
# Plot metrics (only validation metrics available)
plot_metric_histories(
valid_metrics_history,
args.varnames_list,
metric_names,
filename=f"validation_metrics_{args.prefix}",
save_dir=paths.results,
)
# Plot average metrics
plot_average_metrics(
valid_metrics_history,
metric_names,
filename=f"average_metrics_{args.prefix}.png",
save_dir=paths.results,
)
logger.info("Training summary plots generated successfully!")
logger.info("Generating spatiotemporal coverage plots...")
# Plot training data
plot_spatiotemporal_histograms(
steps,
tindex_lim=(train_dataset.stime, train_dataset.etime),
centers=train_dataset.center_tracker,
tindices=train_dataset.tindex_tracker,
mode=train_dataset.mode,
filename=f"{args.prefix}_",
save_dir=paths.results,
)
# Plot validation data if available
if valid_dataset is not None:
plot_spatiotemporal_histograms(
steps,
tindex_lim=(valid_dataset.stime, valid_dataset.etime),
centers=valid_dataset.center_tracker,
tindices=valid_dataset.tindex_tracker,
mode=valid_dataset.mode,
filename=f"{args.prefix}_",
save_dir=paths.results,
)
logger.info("Spatiotemporal coverage plots completed!")
# Save final checkpoint at the end of training
if args.save_model:
ModelUtils.save_training_checkpoint(
model=model,
optimizer=optimizer,
epoch=args.num_epochs - 1,
samples_processed=samples_processed,
batches_processed=batches_processed,
train_loss_history=train_loss_history,
valid_loss_history=valid_loss_history,
valid_metrics_history=valid_metrics_history,
best_val_loss=best_val_loss,
best_epoch=best_epoch,
avg_val_loss=avg_val_loss if valid_loader is not None else 0.0,
avg_epoch_loss=avg_epoch_loss,
args=args,
paths=paths,
logger=logger,
checkpoint_type="final",
save_full_model=True,
)
logger.info("Final model checkpoint saved successfully!")
logger.info("Training process completed successfully!")
if __name__ == "__main__":
main()