Inference Modes
The trained model can be used in multiple inference modes:
- Global inference (full spatial coverage)
- Regional inference (single region or subdomain)
- Direct prediction
- Sampler-based diffusion inference
This flexibility makes IPSL-AID suitable for both research experiments and downstream climate impact studies.
Global Inference
For global coverage, the \(1440\times721\) ERA5 grid is divided into 20 fixed blocks, and predictions are produced block by block:
1. Tiling: Divide the globe into overlapping blocks
2. Processing: Run inference on each block
3. Merging: Stitch the blocks together with blending
4. Postprocessing: Apply consistency checks and corrections
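The tiling and merging steps can be sketched as below. The helper names, block sizes, and the uniform-averaging blend are illustrative assumptions; the actual 20-block layout and blending weights used by IPSL-AID are not reproduced here.

```python
import numpy as np

def split_into_blocks(field, block_h, block_w, overlap):
    """Split a 2-D (lat, lon) field into overlapping blocks (hypothetical helper)."""
    H, W = field.shape
    blocks, coords = [], []
    step_h, step_w = block_h - overlap, block_w - overlap
    for i in range(0, max(H - overlap, 1), step_h):
        for j in range(0, max(W - overlap, 1), step_w):
            # Clamp so the last blocks stay inside the grid
            i0, j0 = min(i, H - block_h), min(j, W - block_w)
            blocks.append(field[i0:i0 + block_h, j0:j0 + block_w])
            coords.append((i0, j0))
    return blocks, coords

def merge_blocks(blocks, coords, shape, block_h, block_w):
    """Stitch blocks back together, averaging wherever blocks overlap."""
    out = np.zeros(shape)
    weight = np.zeros(shape)
    for blk, (i0, j0) in zip(blocks, coords):
        out[i0:i0 + block_h, j0:j0 + block_w] += blk
        weight[i0:i0 + block_h, j0:j0 + block_w] += 1.0
    return out / np.maximum(weight, 1.0)
```

Splitting and re-merging an unchanged field is a useful sanity check: with uniform averaging, the round trip reproduces the input exactly.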
Regional Inference
Regional inference allows running the model on a specific geographic subset of the global domain instead of processing the entire globe.
This mode is particularly useful for regional studies (e.g., Europe, North America, Southeast Asia), where only a limited area is of interest.
Conceptually, the model operates on a spatial window extracted from the global grid, centered on a given location with a fixed spatial extent.
Configuration:
```yaml
inference:
  run_type: inference_regional
  # Option 1: predefined region
  region: "europe"
  # Option 2: custom region (lat, lon)
  region_center: 50.0 10.0
  # Region size (lat_size, lon_size)
  region_size: 144 360
  # Supported sizes:
  #   lat can be 144 or 288
  #   lon can be 360 or 720
```
Region selection:
Two approaches are available:
- Predefined region: Use a named region (e.g., `"us"`, `"europe"`, `"asia"`). The corresponding spatial boundaries are defined internally.
- Custom region: Specify a center point via latitude and longitude (`region_center`). The model extracts a region centered on this location.

Either `region` or `region_center` must be provided.
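Conceptually, the custom-region option amounts to slicing a window out of the global grid. A minimal sketch, assuming a \(721\times1440\) ERA5 grid starting at 90°N, 0°E with 0.25° spacing and longitudinal wrap-around (the function name and exact indexing convention are assumptions):

```python
import numpy as np

def extract_region(field, center_lat, center_lon, lat_size, lon_size,
                   lat_res=0.25, lon_res=0.25):
    """Extract a (lat_size, lon_size) window centered on (lat, lon)
    from a global (lat, lon) grid; illustrative, not IPSL-AID's API."""
    n_lat, n_lon = field.shape
    row = int(round((90.0 - center_lat) / lat_res))   # 90N maps to row 0
    col = int(round((center_lon % 360.0) / lon_res))  # 0E maps to col 0
    # Clamp latitudes to the grid; wrap longitudes around the dateline
    r0 = int(np.clip(row - lat_size // 2, 0, n_lat - lat_size))
    cols = (np.arange(lon_size) + col - lon_size // 2) % n_lon
    return field[r0:r0 + lat_size][:, cols]
```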
Sampling Procedure
High-resolution samples are generated by numerically solving the reverse-time SDE. The sampler uses the second-order Heun scheme for improved accuracy:
Algorithm:
1. Initialize: \(\mathbf{x}_0 \sim \mathcal{N}(\mathbf{0}, t_0^2 \mathbf{I})\)
2. For each step \(i = 0, \dots, N-1\):
   a. Optionally add a noise increment \(\gamma_i\)
   b. Compute the denoising direction \(\mathbf{d}_i\)
   c. Update the latent: \(\mathbf{x}_{i+1} = \hat{\mathbf{x}}_i + (t_{i+1} - \hat{t}_i) \, \mathbf{d}_i\)
   d. Apply a second-order correction if \(t_{i+1} \neq 0\)
3. Return: the final denoised sample \(\mathbf{x}_N\)
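The loop above can be sketched in NumPy following the EDM-style Heun formulation. Here `denoise(x, t)` is an assumed interface returning the denoised estimate at noise level `t`; the function signature and parameter handling are illustrative, not IPSL-AID's actual API.

```python
import numpy as np

def heun_sampler(denoise, x, sigmas, s_churn=0.0, s_min=0.0,
                 s_max=float("inf"), s_noise=1.0, rng=None):
    """Second-order Heun sampler for the reverse-time SDE (sketch)."""
    rng = rng or np.random.default_rng(0)
    n = len(sigmas) - 1
    for i in range(n):
        t, t_next = sigmas[i], sigmas[i + 1]
        # (a) optional noise increment ("churn") raises the noise level to t_hat
        gamma = min(s_churn / n, np.sqrt(2) - 1) if s_min <= t <= s_max else 0.0
        t_hat = t * (1 + gamma)
        x_hat = x + np.sqrt(max(t_hat**2 - t**2, 0.0)) * s_noise \
            * rng.standard_normal(x.shape)
        # (b) denoising direction (Euler step)
        d = (x_hat - denoise(x_hat, t_hat)) / t_hat
        # (c) update the latent
        x = x_hat + (t_next - t_hat) * d
        # (d) second-order (Heun) correction, skipped at the final step
        if t_next != 0:
            d2 = (x - denoise(x, t_next)) / t_next
            x = x_hat + (t_next - t_hat) * 0.5 * (d + d2)
    return x
```

With an idealized denoiser whose clean estimate is identically zero, the trajectory contracts toward zero as the noise level reaches zero, which makes for an easy correctness check.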
Sampling Parameters
```yaml
sampling:
  steps: 20          # Number of sampling steps
  sigma_min: 0.002   # Minimum noise level
  sigma_max: 80.0    # Maximum noise level
  rho: 7.0           # Controls step distribution
  sampler: "heun"    # Integration scheme
  s_churn: 40.0      # Stochasticity parameter
  s_min: 0.05        # Minimum stochasticity
  s_max: 50.0        # Maximum stochasticity
  s_noise: 1.003     # Noise scale
```
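How `steps`, `sigma_min`, `sigma_max`, and `rho` combine into a noise-level schedule can be sketched with the EDM-style rho-spaced schedule; this is an assumed construction, shown only to make the role of `rho` concrete (larger `rho` concentrates more steps at low noise levels):

```python
import numpy as np

def edm_sigma_schedule(steps=20, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Rho-spaced noise levels from sigma_max down to sigma_min (sketch)."""
    i = np.arange(steps)
    sigmas = (sigma_max ** (1 / rho)
              + i / (steps - 1)
              * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    return np.append(sigmas, 0.0)  # final entry: denoise fully to sigma = 0
```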
Ensemble Generation
For uncertainty quantification, generate multiple samples:
- Multiple seeds: Different random seeds for sampling
- Parameter variations: Different sampler settings
- Model ensembles: Average predictions from multiple checkpoints
- Statistical analysis: Compute means, variances, quantiles
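The seed-based approach can be sketched as follows; `sample_fn` is a hypothetical wrapper around one sampler run, and the returned statistics are examples of what might be computed:

```python
import numpy as np

def generate_ensemble(sample_fn, n_members=8, base_seed=0):
    """Draw an ensemble by varying the sampling seed, then summarize it."""
    members = np.stack([sample_fn(base_seed + k) for k in range(n_members)])
    return {
        "mean": members.mean(axis=0),
        "std": members.std(axis=0),           # per-pixel spread
        "q10": np.quantile(members, 0.10, axis=0),
        "q90": np.quantile(members, 0.90, axis=0),
    }
```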
Postprocessing
After generation:
1. Add residuals: \(\mathbf{y}^{\mathrm{HR}} = \mathbf{y}^{\mathrm{CU}} + \mathbf{R}'\)
2. Denormalize: Convert from normalized to physical units
3. Quality checks: Validate physical constraints
4. Format conversion: Save in standard formats (NetCDF, GeoTIFF)
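Steps 1-3 can be sketched in a few lines. The affine denormalization (`x * std + mean`) and the lower-bound clip (e.g., non-negative precipitation) are assumptions about the pipeline, not IPSL-AID's exact implementation:

```python
import numpy as np

def postprocess(y_coarse_up, residual, mean, std, lower=None):
    """Add predicted residual, denormalize, apply a simple constraint check."""
    y = y_coarse_up + residual   # y_HR = y_CU + R'
    y = y * std + mean           # back to physical units (assumed affine norm)
    if lower is not None:        # e.g., clip precipitation at zero
        y = np.maximum(y, lower)
    return y
```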
Performance Optimization
- Batch inference: Process multiple time steps simultaneously
- Memory management: Clear intermediate results
- GPU utilization: Maximize GPU occupancy
- I/O optimization: Efficient reading/writing of large files
Real-time Applications
For near real-time downscaling:
- Streaming input: Ingest coarse forecasts
- Fast inference: Optimized sampler settings
- Caching: Reuse computations where possible
- Parallelization: Distribute across multiple GPUs/nodes
Validation and Evaluation
During inference, compute:
- Deterministic metrics: MAE, RMSE, R² against observations
- Probabilistic metrics: CRPS, spread-skill ratio
- Spatial statistics: Power spectra, variograms
- Extreme values: Quantile scores, tail statistics
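Two of these metrics can be sketched directly. The function names and the all-grid-points averaging are assumptions; in practice one would typically weight by latitude:

```python
import numpy as np

def deterministic_metrics(pred, obs):
    """MAE, RMSE, and R^2 over all grid points (minimal sketch)."""
    err = pred - obs
    ss_res = np.sum(err ** 2)
    ss_tot = np.sum((obs - obs.mean()) ** 2)
    return {
        "mae": np.abs(err).mean(),
        "rmse": np.sqrt((err ** 2).mean()),
        "r2": 1.0 - ss_res / ss_tot,
    }

def spread_skill_ratio(ensemble, obs):
    """Mean ensemble spread over the RMSE of the ensemble mean;
    values near 1 indicate a well-calibrated ensemble."""
    spread = ensemble.std(axis=0).mean()
    rmse = np.sqrt(((ensemble.mean(axis=0) - obs) ** 2).mean())
    return spread / rmse
```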
Example Usage
```python
from IPSL_AID.evaluater import run_validation

# Load the trained model
model = load_model("checkpoints/corresponding_experiment/best_model.pth")

# Run global inference; the arguments below (datasets, loaders, mappings,
# logger, etc.) are defined earlier in the pipeline setup
avg_val_loss, val_metrics = run_validation(
    model,
    valid_dataset,
    valid_loader,
    loss_fn,
    norm_mapping,
    normalization_type,
    index_mapping,
    args,
    steps,
    device,
    logger,
    epoch=0,
    writer=writer,
    plot_every_n_epochs=1,
    edm_sampler_steps=20,
    paths=paths,
    compute_crps=True,
)

# Check results: diagnostic plots are written to
# results/corresponding_experiment/*.png
```