Neural Architectures

IPSL-AID relies on UNet-based architectures, adapted for climate data:

  • ADM-style UNet

  • Conditional UNet variants

  • Support for static and dynamic covariates

  • Flexible input / output channel definitions

Architectures are selected using runtime parameters, allowing rapid experimentation without code changes.
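
One way such runtime selection can work is a registry that maps the arch string to a model class. The sketch below is hypothetical: ARCHITECTURES, register, build_model, and ADMUNetStub are illustrative names, not IPSL-AID's API. Only the arch and model_kwargs keys come from the configuration example in this section.

```python
from typing import Any, Callable, Dict

# Hypothetical registry: maps an architecture name to a model class.
ARCHITECTURES: Dict[str, Callable[..., Any]] = {}


def register(name: str):
    """Decorator that records a model class under a runtime name."""
    def decorator(cls):
        ARCHITECTURES[name] = cls
        return cls
    return decorator


@register("adm")
class ADMUNetStub:
    """Stand-in for a real network; it only stores its keyword arguments."""
    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs


def build_model(opts: Dict[str, Any]) -> Any:
    """Instantiate the architecture named by opts["arch"] with its overrides."""
    arch = opts["arch"]
    if arch not in ARCHITECTURES:
        raise ValueError(f"unknown architecture {arch!r}; choose from {sorted(ARCHITECTURES)}")
    return ARCHITECTURES[arch](**opts.get("model_kwargs", {}))


model = build_model({"arch": "adm", "model_kwargs": {"model_channels": 128, "num_blocks": 3}})
print(type(model).__name__, model.kwargs)
```

With this pattern, adding an architecture only requires registering another class; the training entry point and the configuration format stay unchanged.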

[Figure: Schematic of the UNet architecture used in IPSL-AID, showing the encoder, decoder, and attention components.]

Default Configuration

The default UNet configuration includes:

  • Base channel count: \(C_\mathrm{base} = 128\)

  • Channel multipliers per resolution: [1, 2, 3, 4]

  • Residual blocks per resolution: 3

  • Self-attention at resolutions: [32, 16, 8]

  • Dropout probability: \(p = 0.10\)

  • Embedding dimension: \(C_\mathrm{emb} = 4 \times C_\mathrm{base}\)
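
The quantities implied by this configuration can be checked with a few lines of arithmetic. The sketch assumes, as in ADM-style UNets, that level \(l\) has \(C_\mathrm{base} \times m_l\) channels (with \(m_l\) the channel multiplier), and it uses the 64-channels-per-head rule from the attention description below; the function name derived_config is illustrative.

```python
def derived_config(model_channels=128, channel_mult=(1, 2, 3, 4),
                   channel_mult_emb=4, channels_per_head=64):
    """Per-level channel counts, embedding width, and attention heads per level."""
    level_channels = [model_channels * m for m in channel_mult]
    emb_channels = model_channels * channel_mult_emb
    heads_per_level = [c // channels_per_head for c in level_channels]
    return level_channels, emb_channels, heads_per_level


channels, emb, heads = derived_config()
print(channels)  # [128, 256, 384, 512]
print(emb)       # 512
print(heads)     # [2, 4, 6, 8]
```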

Architecture Components

Encoder

Progressive downsampling with convolutional and residual blocks. At level \(l\), the feature map has height \(H_l = \lfloor H / 2^l \rfloor\), width \(W_l = \lfloor W / 2^l \rfloor\), and \(C_l = m_l \, C_\mathrm{base}\) channels, where \(m_l\) is the channel multiplier for that level.
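
The floor formula can be evaluated for any grid. In the sketch below, the 721 × 1440 grid (the size of a 0.25° global latitude-longitude grid) is a hypothetical example, and matching attn_resolutions against the level height assumes a square input, as in the default [128, 128] configuration.

```python
def level_shapes(height, width, num_levels=4):
    """Feature-map size (H_l, W_l) = (floor(H / 2^l), floor(W / 2^l)) for each level l."""
    return [(height // 2**l, width // 2**l) for l in range(num_levels)]


def attention_levels(resolution, attn_resolutions=(32, 16, 8), num_levels=4):
    """Levels whose (square) resolution appears in attn_resolutions."""
    shapes = level_shapes(resolution, resolution, num_levels)
    return [l for l, (h, _) in enumerate(shapes) if h in attn_resolutions]


print(level_shapes(721, 1440))  # [(721, 1440), (360, 720), (180, 360), (90, 180)]
print(attention_levels(128))    # [2, 3]
```

Under this formula, a 128 × 128 input with four levels bottoms out at 16 × 16, so the 8 entry in attn_resolutions is never reached.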

Decoder

Mirrors the encoder, using upsampling blocks and skip connections from the encoder level at the same resolution.
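
The skip connections behave like a stack: each encoder level pushes its output, and the decoder pops the outputs in reverse order and concatenates them channel-wise. This shape-only sketch is simplified (one skip per level, whereas ADM-style UNets store a skip after every block) and assumes H and W are divisible by \(2^{L-1}\) for \(L\) levels, so upsampled maps line up with their skips. The function name unet_shapes is illustrative.

```python
def unet_shapes(height, width, base=128, channel_mult=(1, 2, 3, 4)):
    """Trace (channels, H, W) through a simplified encoder/decoder with skips."""
    skips, encoder, decoder = [], [], []
    h, w = height, width
    for level, mult in enumerate(channel_mult):
        if level > 0:
            h, w = h // 2, w // 2              # downsample by 2
        c = base * mult
        encoder.append((c, h, w))
        skips.append((c, h, w))                # push the skip for this resolution
    for level in reversed(range(len(channel_mult))):
        skip_c, skip_h, skip_w = skips.pop()   # last pushed = same resolution
        assert (skip_h, skip_w) == (h, w), "upsampled map must match its skip"
        c_in = c + skip_c                      # channel-wise concatenation
        c = base * channel_mult[level]
        decoder.append((c_in, c, h, w))        # (input channels, output channels, H, W)
        if level > 0:
            h, w = h * 2, w * 2                # upsample by 2
    return encoder, decoder


encoder, decoder = unet_shapes(128, 128)
for c_in, c_out, h, w in decoder:
    print(f"{h}x{w}: {c_in} -> {c_out} channels")
```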

Attention

Multi-head self-attention at the resolutions listed in attn_resolutions (64 channels per head):

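\[\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\Big(\mathbf{Q}\mathbf{K}^\top / \sqrt{d_k}\Big)\mathbf{V}\]

where each row of \(\mathbf{Q}\), \(\mathbf{K}\), and \(\mathbf{V}\) corresponds to one spatial position and \(d_k = 64\) is the number of channels per head.
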
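
A NumPy sketch of this operation for one sample, with the feature map flattened to \(N = H_l W_l\) positions. It is illustrative only: the projection matrices w_qkv and w_out are random stand-ins, and the normalization and residual connection that usually surround an attention block are omitted.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)    # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_self_attention(x, w_qkv, w_out, channels_per_head=64):
    """x: (C, N) feature map flattened over N = H * W positions; returns (C, N)."""
    c, n = x.shape
    num_heads = c // channels_per_head
    q, k, v = np.split(w_qkv @ x, 3, axis=0)   # each (C, N)

    def split_heads(t):                        # (C, N) -> (heads, N, d_k)
        return t.reshape(num_heads, channels_per_head, n).transpose(0, 2, 1)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)
    weights = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(channels_per_head))  # (heads, N, N)
    out = weights @ v                          # (heads, N, d_k)
    out = out.transpose(0, 2, 1).reshape(c, n) # merge heads back to (C, N)
    return w_out @ out


rng = np.random.default_rng(0)
C, H, W = 128, 4, 4                            # 2 heads of 64 channels
x = rng.standard_normal((C, H * W))
w_qkv = rng.standard_normal((3 * C, C)) / np.sqrt(C)
w_out = rng.standard_normal((C, C)) / np.sqrt(C)
y = multi_head_self_attention(x, w_qkv, w_out)
print(y.shape)  # (128, 16)
```

Each head builds its own \(N \times N\) weight matrix, which is what makes attention expensive at fine resolutions.
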
Conditioning

Support for:

  • Noise level embeddings

  • Class conditioning (e.g., season, region)

  • Augmentation embeddings

  • Spatiotemporal context (latitude, longitude, time)

Embedding Layers

Noise levels \(\sigma\) are represented using a sinusoidal positional embedding:

\[\mathbf{e}_\sigma = \text{PE}(\sigma) \in \mathbb{R}^{C_\mathrm{base}}\]

The embedding is then processed by two fully connected layers with SiLU activations, producing a vector of dimension \(C_\mathrm{emb}\).
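
A NumPy sketch of the noise-embedding path, assuming transformer-style frequencies \(\omega_i = 10000^{-i/(C_\mathrm{base}/2)}\) with a cosine half followed by a sine half. The exact frequency schedule and the weights (random here) are assumptions, not taken from IPSL-AID. Under EDM preconditioning, the scalar fed to the embedding is typically \(c_\mathrm{noise}(\sigma) = \tfrac14 \ln \sigma\) rather than \(\sigma\) itself, which the usage lines below mimic.

```python
import numpy as np


def silu(x):
    return x / (1.0 + np.exp(-x))


def sinusoidal_embedding(noise_labels, num_channels=128, max_positions=10000):
    """PE(sigma): (batch,) -> (batch, num_channels), cosine half then sine half."""
    noise_labels = np.atleast_1d(np.asarray(noise_labels, dtype=np.float64))
    half = num_channels // 2
    freqs = (1.0 / max_positions) ** (np.arange(half) / half)   # geometric frequencies
    args = np.outer(noise_labels, freqs)                         # (batch, half)
    return np.concatenate([np.cos(args), np.sin(args)], axis=1)


def noise_embedding(noise_labels, w0, b0, w1, b1):
    """Two fully connected layers with SiLU: C_base -> C_emb -> C_emb."""
    e = sinusoidal_embedding(noise_labels, num_channels=w0.shape[0])
    h = silu(e @ w0 + b0)
    return silu(h @ w1 + b1)


rng = np.random.default_rng(0)
c_base, c_emb = 128, 512
w0, b0 = rng.standard_normal((c_base, c_emb)) / np.sqrt(c_base), np.zeros(c_emb)
w1, b1 = rng.standard_normal((c_emb, c_emb)) / np.sqrt(c_emb), np.zeros(c_emb)
sigma = np.array([0.002, 1.0, 80.0])
emb = noise_embedding(0.25 * np.log(sigma), w0, b0, w1, b1)
print(emb.shape)  # (3, 512)
```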

Conditioning Strategies

  1. Spatial Conditioning: Low-resolution inputs concatenated channel-wise

  2. Global Conditioning: Scalar features projected and added to embeddings

  3. Adaptive Normalization: Feature-wise modulation based on conditioning

  4. Cross-Attention: Attention between features and conditioning vectors
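
Strategies 2 and 3 can be sketched together in NumPy: scalar features are projected and added to the embedding vector, and the embedding then produces a per-channel scale and shift applied after group normalization, in the (1 + scale) form used by ADM-style residual blocks. All weights are random stand-ins, the two scalar features are hypothetical, and the function names are illustrative.

```python
import numpy as np


def group_norm(h, num_groups=32, eps=1e-5):
    """Normalize (C, H, W) features within groups of consecutive channels."""
    g = h.reshape(num_groups, -1)
    g = (g - g.mean(axis=1, keepdims=True)) / np.sqrt(g.var(axis=1, keepdims=True) + eps)
    return g.reshape(h.shape)


def add_global_condition(emb, scalars, w_proj):
    """Global conditioning: project scalar features and add them to the embedding."""
    return emb + w_proj @ scalars


def adaptive_group_norm(h, emb, w_affine, num_groups=32):
    """Adaptive normalization: emb -> (scale, shift), one pair per channel."""
    scale, shift = np.split(w_affine @ emb, 2)     # each (C,)
    return group_norm(h, num_groups) * (1.0 + scale[:, None, None]) + shift[:, None, None]


rng = np.random.default_rng(0)
C, H, W, C_emb = 128, 8, 8, 512
h = rng.standard_normal((C, H, W))
emb = rng.standard_normal(C_emb)
scalars = np.array([0.5, -1.0])                    # two hypothetical global features
emb = add_global_condition(emb, scalars, rng.standard_normal((C_emb, 2)))
out = adaptive_group_norm(h, emb, 0.01 * rng.standard_normal((2 * C, C_emb)))
print(out.shape)  # (128, 8, 8)
```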

Input/Output Specification

from torch import nn


class DhariwalUNet(nn.Module):
    def __init__(
        self,
        img_resolution,                     # Image resolution as a tuple (height, width).
        in_channels,                        # Number of input channels (physical variables).
        out_channels,                       # Number of output channels.
        label_dim           = 0,            # Number of class labels, 0 = unconditional.
        augment_dim         = 0,            # Augmentation label dimensionality, 0 = no augmentation.
        model_channels      = 128,          # Base multiplier for the number of channels.
        channel_mult        = [1,2,3,4],    # Per-resolution multipliers for the number of channels.
        channel_mult_emb    = 4,            # Multiplier for the dimensionality of the embedding vector.
        num_blocks          = 3,            # Number of residual blocks per resolution.
        attn_resolutions    = [32,16,8],    # List of resolutions with self-attention.
        dropout             = 0.10,         # Dropout probability of intermediate activations.
        label_dropout       = 0,            # Dropout probability of class labels for classifier-free guidance.
        diffusion_model     = True,         # Whether to use the UNet for diffusion models.
    ):
        super().__init__()
        ...  # Implementation omitted.

Climate-Specific Adaptations

  1. Periodic Boundary Handling: Special convolutions for longitude wrapping

  2. Spatial Context: Incorporation of latitude/longitude grids

  3. Topography Integration: Terrain elevation as conditioning input

  4. Land-Sea Masks: Binary masks for land/ocean differentiation
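
Longitude wrapping is commonly implemented with circular padding followed by an unpadded ("valid") convolution, so the eastern and western edges see each other as neighbours. The NumPy sketch below assumes a (channels, latitude, longitude) layout and replicates values at the latitude boundaries; how IPSL-AID treats the poles is not stated here, so that choice is an assumption. Like deep-learning frameworks, it computes a cross-correlation.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def pad_lat_lon(field, pad=1):
    """field: (C, n_lat, n_lon). Periodic in longitude, edge-replicated in latitude."""
    field = np.pad(field, ((0, 0), (0, 0), (pad, pad)), mode="wrap")  # wrap east-west
    return np.pad(field, ((0, 0), (pad, pad), (0, 0)), mode="edge")   # no wrap north-south


def periodic_conv2d(field, kernel):
    """k x k convolution that preserves the grid size. kernel: (C_out, C_in, k, k)."""
    k = kernel.shape[-1]
    padded = pad_lat_lon(field, pad=k // 2)
    windows = sliding_window_view(padded, (k, k), axis=(1, 2))  # (C_in, n_lat, n_lon, k, k)
    return np.einsum("chwij,ocij->ohw", windows, kernel)


rng = np.random.default_rng(0)
field = rng.standard_normal((2, 6, 12))        # 2 variables on a 6 x 12 grid
kernel = rng.standard_normal((4, 2, 3, 3))
out = periodic_conv2d(field, kernel)
print(out.shape)  # (4, 6, 12)
```

Because the padding is circular in longitude, shifting the input east-west shifts the output by the same amount, with no artificial seam at the dateline.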

Configuration Examples

opts = dict(
    arch="adm",
    precond="edm",
    img_resolution=[128, 128],

    # --------------------------------------------------
    # Data channels
    # --------------------------------------------------
    in_channels=3,        # ["t2m", "u10", "v10"]
    cond_channels=7,      # 3 variables + z (topography) + LSM (land-sea mask) + lat + lon
    out_channels=3,       # same as in_channels

    label_dim=2,          # conditional model
    use_fp16=True,

    # --------------------------------------------------
    # Architecture overrides
    # --------------------------------------------------
    model_kwargs=dict(
        model_channels=128,
        channel_mult=[1, 2, 3, 4],
        num_blocks=3,
        dropout=0.1,
    )
)
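
The channel counts in this configuration can be cross-checked by assembling a conditioning tensor in the order given by the cond_channels comment. The sketch assumes the low-resolution variables are already regridded to the 128 × 128 target grid, that lat and lon enter as raw coordinate grids (normalization omitted), and that the network input is the noisy target concatenated with the conditioning channels, following the channel-wise spatial conditioning strategy. build_condition_stack and all arrays are hypothetical.

```python
import numpy as np


def build_condition_stack(lowres_vars, topography, land_sea_mask, lat, lon):
    """Stack 3 variables + z + LSM + lat + lon into a (7, n_lat, n_lon) array."""
    lat2d, lon2d = np.meshgrid(lat, lon, indexing="ij")            # (n_lat, n_lon) each
    static = np.stack([topography, land_sea_mask, lat2d, lon2d])   # (4, n_lat, n_lon)
    return np.concatenate([lowres_vars, static], axis=0)


n_lat, n_lon = 128, 128                                           # img_resolution=[128, 128]
rng = np.random.default_rng(0)
lowres = rng.standard_normal((3, n_lat, n_lon))                   # t2m, u10, v10 (regridded)
topo = rng.standard_normal((n_lat, n_lon))
lsm = (rng.random((n_lat, n_lon)) > 0.5).astype(float)            # binary land-sea mask
lat = np.linspace(-90.0, 90.0, n_lat)
lon = np.linspace(0.0, 360.0, n_lon, endpoint=False)

cond = build_condition_stack(lowres, topo, lsm, lat, lon)
noisy_target = rng.standard_normal((3, n_lat, n_lon))             # in_channels=3
net_input = np.concatenate([noisy_target, cond], axis=0)
print(cond.shape, net_input.shape)  # (7, 128, 128) (10, 128, 128)
```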

Performance Considerations

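  • Memory: Channel multipliers and the base channel count control memory usage; the weights of a convolution grow with the product of its input and output channels.

  • Speed: Self-attention cost grows quadratically with the number of spatial positions \(H_l W_l\), so attention at high resolutions is the most expensive.

  • Accuracy: More channels and residual blocks generally improve quality, at the cost of memory and compute.

  • Overfitting: Dropout and other regularization are important for small datasets.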
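
A back-of-the-envelope sketch of the two main costs under the default configuration and a square 128 × 128 input: the number of attention-score entries per sample (heads × \(N^2\), with \(N = H_l W_l\) and 64 channels per head) and the parameter count of one 3 × 3 convolution at each level (\(9 C_\mathrm{in} C_\mathrm{out}\) weights plus biases). It counts only these two operations, not a full model.

```python
def attention_score_entries(height, width, channels, channels_per_head=64):
    """Entries in the (heads, N, N) score tensor for one sample, N = height * width."""
    n = height * width
    return (channels // channels_per_head) * n * n


def conv3x3_params(c_in, c_out):
    """Weights plus biases of a single 3x3 convolution."""
    return 9 * c_in * c_out + c_out


base, mult = 128, [1, 2, 3, 4]
for level, m in enumerate(mult):
    res = 128 // 2**level
    c = base * m
    print(f"level {level}: {res}x{res}, C={c}, "
          f"attention entries={attention_score_entries(res, res, c):,}, "
          f"conv3x3 params={conv3x3_params(c, c):,}")
```

Attention at level 0 would need about 537 million score entries per sample, roughly 85 times more than at level 2 (about 6.3 million), which is why attn_resolutions is limited to coarse levels. Conversely, a level-3 convolution has about 16 times more parameters than a level-0 one, so the channel multipliers dominate parameter memory.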