Inference Modes

The trained model can be used in multiple inference modes:

  • Global inference (full spatial coverage)

  • Regional inference (single region or subdomain)

  • Direct prediction

  • Sampler-based diffusion inference

This flexibility makes IPSL-AID suitable for both research experiments and downstream climate impact studies.

Global Inference

For global coverage, the global ERA5 grid of \(1440\times721\) points is divided into 20 fixed overlapping blocks, and predictions are produced block by block:

  1. Tiling: Divide globe into overlapping blocks

  2. Processing: Run inference on each block

  3. Merging: Stitch blocks together with blending

  4. Postprocessing: Apply consistency checks and corrections
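The tiling and merging steps above can be sketched as follows. This is an illustrative implementation, not the exact IPSL-AID merging code: the `blend_tiles` helper, the triangular ramp weights, and the block layout are assumptions chosen to show how overlapping predictions can be stitched without visible seams.

```python
import numpy as np

def blend_tiles(tiles, origins, tile_shape, grid_shape):
    """Merge overlapping per-block predictions into one global field.

    tiles:      list of (th, tw) arrays, one prediction per block
    origins:    list of (row, col) upper-left corners on the global grid
    tile_shape: (th, tw) block size
    grid_shape: (H, W) global grid, e.g. (721, 1440) for ERA5
    """
    th, tw = tile_shape
    # Triangular ramp weights down-weight tile edges so overlapping
    # predictions blend smoothly instead of creating hard seams.
    wy = 1.0 - np.abs(np.linspace(-1, 1, th))
    wx = 1.0 - np.abs(np.linspace(-1, 1, tw))
    weight = np.maximum(np.outer(wy, wx), 1e-6)

    acc = np.zeros(grid_shape)
    norm = np.zeros(grid_shape)
    for tile, (r, c) in zip(tiles, origins):
        acc[r:r + th, c:c + tw] += tile * weight
        norm[r:r + th, c:c + tw] += weight
    # Weighted average wherever at least one tile contributed.
    return acc / np.maximum(norm, 1e-12)
```

Regions covered by a single block keep that block's prediction exactly; in overlap zones the result is a weighted average that favors whichever block is closer to its own center.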

Regional Inference

Regional inference allows running the model on a specific geographic subset of the global domain instead of processing the entire globe.

This mode is particularly useful for regional studies (e.g., Europe, North America, Southeast Asia), where only a limited area is of interest.

Conceptually, the model operates on a spatial window extracted from the global grid, centered on a given location with a fixed spatial extent.

Configuration:

inference:
  run_type: inference_regional

  # Option 1: predefined region
  region: "europe"

  # Option 2: custom region (lat, lon)
  region_center: 50.0 10.0

  # Region size (lat_size, lon_size)
  region_size: 144 360

  # Supported sizes:
  # lat can be 144 or 288
  # lon can be 360 or 720

Region selection:

Two approaches are available:

  • Predefined region: Use a named region (e.g., "us", "europe", "asia"). The corresponding spatial boundaries are internally defined.

  • Custom region: Specify a center point using latitude and longitude (region_center). The model will extract a region centered around this location.

Either region or region_center must be provided.
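For the custom-region case, the mapping from a `region_center` to a grid-index window can be sketched as below. The `region_window` helper and the exact grid conventions (latitude running 90 to -90, longitude 0 to 360 at 0.25 degree resolution, longitude wrapping at the dateline) are assumptions for illustration; the model's internal indexing may differ.

```python
import numpy as np

def region_window(center_lat, center_lon, lat_size, lon_size,
                  n_lat=721, n_lon=1440, res=0.25):
    """Map a (lat, lon) center to a grid-index window (illustrative).

    Assumes an ERA5-style grid: latitude runs 90 -> -90 over n_lat rows
    and longitude runs 0 -> 360 over n_lon columns (wrapping).
    Returns (row_start, row_stop) and the wrapped column indices.
    """
    row_c = int(round((90.0 - center_lat) / res))
    col_c = int(round((center_lon % 360.0) / res))
    # Clamp so the window stays inside the latitude range.
    row_start = int(np.clip(row_c - lat_size // 2, 0, n_lat - lat_size))
    rows = (row_start, row_start + lat_size)
    # Longitude wraps around, so take column indices modulo n_lon.
    cols = (np.arange(lon_size) + col_c - lon_size // 2) % n_lon
    return rows, cols
```

With `region_center: 50.0 10.0` and `region_size: 144 360` this yields a 144 x 360 window centered over Europe, with columns wrapping across the prime meridian.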

Sampling Procedure

High-resolution samples are generated by numerically solving the reverse-time SDE. The sampler uses the second-order Heun scheme for improved accuracy:

Algorithm:

  1. Initialize: \(\mathbf{x}_0 \sim \mathcal{N}(\mathbf{0}, t_0^2 \mathbf{1})\)

  2. For each step \(i = 0, \dots, N-1\):

     a. Optionally add a noise increment \(\gamma_i\)

     b. Compute the denoising direction \(\mathbf{d}_i\)

     c. Update the latent: \(\mathbf{x}_{i+1} = \hat{\mathbf{x}}_i + (t_{i+1} - \hat{t}_i) \, \mathbf{d}_i\)

     d. Apply a second-order correction if \(t_{i+1} \neq 0\)

  3. Return: Final denoised sample \(\mathbf{x}_N\)
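The algorithm above is the standard EDM-style stochastic Heun sampler; a minimal NumPy sketch is shown below. The `denoise(x, t)` callable stands in for the trained denoiser, and the function signature is an assumption for illustration, not the repository's actual sampler API.

```python
import numpy as np

def heun_sampler(denoise, x, t_steps, s_churn=0.0, s_min=0.0,
                 s_max=np.inf, s_noise=1.0, rng=None):
    """Second-order Heun sampler for the reverse-time SDE (sketch).

    denoise(x, t): trained denoiser returning the clean estimate
    x:             initial latent, drawn as N(0, t_steps[0]^2 * I)
    t_steps:       decreasing noise schedule ending at 0
    """
    rng = rng or np.random.default_rng()
    n = len(t_steps) - 1
    for t_cur, t_next in zip(t_steps[:-1], t_steps[1:]):
        # (a) optional churn: temporarily raise the noise level by gamma_i
        gamma = min(s_churn / n, np.sqrt(2) - 1) if s_min <= t_cur <= s_max else 0.0
        t_hat = t_cur * (1 + gamma)
        x_hat = x + np.sqrt(t_hat**2 - t_cur**2) * s_noise * rng.standard_normal(x.shape)
        # (b) denoising direction d_i
        d_cur = (x_hat - denoise(x_hat, t_hat)) / t_hat
        # (c) Euler update toward t_{i+1}
        x = x_hat + (t_next - t_hat) * d_cur
        # (d) 2nd-order (Heun) correction, skipped at the final step
        if t_next > 0:
            d_next = (x - denoise(x, t_next)) / t_next
            x = x_hat + (t_next - t_hat) * 0.5 * (d_cur + d_next)
    return x
```

With a perfect denoiser the trajectory contracts toward the data manifold; the churn steps (a) inject controlled stochasticity so repeated runs yield distinct samples.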

Sampling Parameters

sampling:
  steps: 20                    # Number of sampling steps
  sigma_min: 0.002             # Minimum noise level
  sigma_max: 80.0              # Maximum noise level
  rho: 7.0                     # Controls step distribution
  sampler: "heun"              # Integration scheme
  s_churn: 40.0                # Stochasticity parameter
  s_min: 0.05                  # Minimum stochasticity
  s_max: 50.0                  # Maximum stochasticity
  s_noise: 1.003               # Noise scale
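The `steps`, `sigma_min`, `sigma_max`, and `rho` parameters define the noise schedule \(t_0 > t_1 > \dots > t_N = 0\). Assuming the standard EDM parameterization (which these parameter names suggest), the schedule can be computed as:

```python
import numpy as np

def edm_schedule(steps=20, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise schedule implied by the sampling config (EDM parameterization).

    rho controls how steps are distributed: larger rho concentrates
    more steps near sigma_min, where fine detail is resolved.
    """
    i = np.arange(steps)
    t = (sigma_max ** (1 / rho)
         + i / (steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    return np.append(t, 0.0)  # final step denoises all the way to t = 0
```

The schedule starts at `sigma_max`, ends at `sigma_min` before the final jump to zero, and is strictly decreasing throughout.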

Ensemble Generation

For uncertainty quantification, generate multiple samples:

  1. Multiple seeds: Different random seeds for sampling

  2. Parameter variations: Different sampler settings

  3. Model ensembles: Average predictions from multiple checkpoints

  4. Statistical analysis: Compute means, variances, quantiles
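The multiple-seeds approach (item 1) combined with the statistical analysis (item 4) can be sketched as follows; `generate_ensemble` and its `sample_fn(seed)` interface are illustrative assumptions, not part of the IPSL-AID API.

```python
import numpy as np

def generate_ensemble(sample_fn, n_members=10, base_seed=0):
    """Draw an ensemble by re-running the sampler with different seeds.

    sample_fn(seed) -> array; any callable wrapping the diffusion sampler.
    Returns per-pixel mean, standard deviation, and 5/95% quantiles.
    """
    members = np.stack([sample_fn(base_seed + k) for k in range(n_members)])
    return {
        "mean": members.mean(axis=0),
        "std": members.std(axis=0),      # spread, for spread-skill checks
        "q05": np.quantile(members, 0.05, axis=0),
        "q95": np.quantile(members, 0.95, axis=0),
    }
```

The same pattern extends to items 2 and 3: vary sampler settings or checkpoints inside `sample_fn` instead of (or in addition to) the seed.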

Postprocessing

After generation:

  1. Add residuals: \(\mathbf{y}^{\mathrm{HR}} = \mathbf{y}^{\mathrm{CU}} + \mathbf{R}'\)

  2. Denormalize: Convert from normalized to physical units

  3. Quality checks: Validate physical constraints

  4. Format conversion: Save in standard formats (NetCDF, GeoTIFF)
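Steps 1-3 above can be sketched as below. The `postprocess` helper is illustrative: the exact normalization statistics and the order in which the residual \(\mathbf{R}'\) is denormalized depend on the training setup, and the non-negativity clamp is just one example of a physical consistency check.

```python
import numpy as np

def postprocess(residual_norm, coarse_upsampled, mean, std,
                nonnegative=False):
    """Turn a normalized residual prediction into a physical field.

    residual_norm:    model output R' in normalized units
    coarse_upsampled: y^CU, the upsampled coarse field in physical units
    mean, std:        normalization statistics used during training
    """
    residual = residual_norm * std + mean   # denormalize R'
    y_hr = coarse_upsampled + residual      # y^HR = y^CU + R'
    if nonnegative:                         # simple physical constraint,
        y_hr = np.maximum(y_hr, 0.0)        # e.g. precipitation >= 0
    return y_hr
```

Step 4 (format conversion) would then write `y_hr` to NetCDF or GeoTIFF with the appropriate coordinate metadata.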

Performance Optimization

  • Batch inference: Process multiple time steps simultaneously

  • Memory management: Clear intermediate results

  • GPU utilization: Maximize GPU occupancy

  • I/O optimization: Efficient reading/writing of large files

Real-time Applications

For near real-time downscaling:

  1. Streaming input: Ingest coarse forecasts

  2. Fast inference: Optimized sampler settings

  3. Caching: Reuse computations where possible

  4. Parallelization: Distribute across multiple GPUs/nodes

Validation and Evaluation

During inference, compute:

  1. Deterministic metrics: MAE, RMSE, R² against observations

  2. Probabilistic metrics: CRPS, spread-skill ratio

  3. Spatial statistics: Power spectra, variograms

  4. Extreme values: Quantile scores, tail statistics
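Two of the metrics above can be sketched directly; the CRPS estimator below is the standard sample-based form \(\mathrm{CRPS} = \mathbb{E}|X - y| - \tfrac{1}{2}\mathbb{E}|X - X'|\) over ensemble members \(X, X'\), which may differ from the estimator used inside `run_validation`.

```python
import numpy as np

def rmse(pred, obs):
    """Root-mean-square error of a deterministic prediction."""
    return float(np.sqrt(np.mean((pred - obs) ** 2)))

def crps_ensemble(members, obs):
    """Sample-based CRPS: E|X - y| - 0.5 * E|X - X'|.

    members: array with the ensemble axis first, shape (M, ...)
    obs:     observations broadcastable against one member
    """
    members = np.asarray(members)
    term1 = np.mean(np.abs(members - obs), axis=0)
    term2 = np.mean(np.abs(members[:, None] - members[None, :]), axis=(0, 1))
    return float(np.mean(term1 - 0.5 * term2))
```

CRPS reduces to the mean absolute error for a one-member "ensemble" and is zero only when every member matches the observation, which makes it a natural probabilistic counterpart to RMSE/MAE.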

Example Usage

from IPSL_AID.evaluater import run_validation

# Load trained model
model = load_model("checkpoints/corresponding_experiment/best_model.pth")

# Run global inference
avg_val_loss, val_metrics = run_validation(
    model,
    valid_dataset,
    valid_loader,
    loss_fn,
    norm_mapping,
    normalization_type,
    index_mapping,
    args,
    steps,
    device,
    logger,
    epoch=0,
    writer=writer,
    plot_every_n_epochs=1,
    edm_sampler_steps=20,
    paths=paths,
    compute_crps=True,
)

# Check results: plots are written to results/corresponding_experiment/*.png