# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Kazem Ardaneh
#
# ============================================================================
# ORIGINAL WORK (NVIDIA)
# ============================================================================
# This work is a derivative of "Elucidating the Design Space of
# Diffusion-Based Generative Models" by NVIDIA CORPORATION & AFFILIATES.
#
# Original work: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.
# Original license: Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
# Original source: https://github.com/NVlabs/edm
#
# ============================================================================
# MODIFICATIONS AND ADDITIONS (IPSL / CNRS / Sorbonne University)
# ============================================================================
# Modifications include:
# 1. Added rectangular resolution support for non-square images
# - Modified SongUNet and DhariwalUNet to handle (height, width) tuples
# - Updated resolution calculations throughout the network
# - Fixed attention resolution checks for rectangular inputs
#
# 2. Added conditional image support to all preconditioners
# - Extended VPPrecond, VEPrecond, iDDPMPrecond, and EDMPrecond
# - Added condition_img parameter for image-based conditioning
# - Implemented channel-wise concatenation for conditional inputs
#
# 3. Enhanced documentation and type hints
# - Added comprehensive docstrings for all classes and methods
# - Improved parameter descriptions and usage examples
#
# 4. Added comprehensive unit tests
# - Created TestDiffusionNetworks class
# - Added tests for square and rectangular resolutions
# - Added tests for all preconditioners with conditional images
# - Added parameter count validation tests
#
# 5. Code quality improvements
# - Replaced generic assertions with specific error messages
# - Added input validation for resolution parameters
# - Improved variable naming for clarity
#
# ============================================================================
# LICENSE
# ============================================================================
# This derivative work is licensed under the same terms as the original:
# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/
# ============================================================================
# ACKNOWLEDGMENTS
# ============================================================================
# We thank the NVIDIA team for their excellent work on EDM and for making it
# available under an open license that enables further research and development.
import numpy as np
import torch
import unittest
from torch.nn.functional import silu
# ----------------------------------------------------------------------------
# Unified routine for initializing weights and biases.
[docs]
def weight_init(shape, mode, fan_in, fan_out):
"""
Initialize weights using various initialization methods.
Parameters
----------
shape : tuple of ints
The shape of the weight tensor to initialize.
mode : str
The initialization method to use. Options are:
- 'xavier_uniform': Xavier uniform initialization
- 'xavier_normal': Xavier normal initialization
- 'kaiming_uniform': Kaiming uniform initialization (also known as He initialization)
- 'kaiming_normal': Kaiming normal initialization (also known as He initialization)
fan_in : int
Number of input units in the weight tensor.
fan_out : int
Number of output units in the weight tensor.
Returns
-------
torch.Tensor
A tensor of the specified shape with values initialized according
to the chosen method.
Raises
------
ValueError
If an invalid initialization mode is provided.
"""
if mode == "xavier_uniform":
return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1)
if mode == "xavier_normal":
return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape)
if mode == "kaiming_uniform":
return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1)
if mode == "kaiming_normal":
return np.sqrt(1 / fan_in) * torch.randn(*shape)
raise ValueError(f'Invalid init mode "{mode}"')
# ----------------------------------------------------------------------------
# Fully-connected layer.
[docs]
class Linear(torch.nn.Module):
"""
A linear (fully connected) layer with customizable weight initialization.
This layer applies a linear transformation to the incoming data: ``y = x W^T + b``.
Parameters
----------
in_features : int
Size of each input sample.
out_features : int
Size of each output sample.
bias : bool, optional
If set to False, the layer will not learn an additive bias.
Default is True.
init_mode : str, optional
Weight initialization method. Options are:
- ``xavier_uniform``: Xavier uniform initialization
- ``xavier_normal``: Xavier normal initialization
- ``kaiming_uniform``: Kaiming uniform initialization (He initialization)
- ``kaiming_normal``: Kaiming normal initialization (He initialization)
Default is ``kaiming_normal``.
init_weight : float or int, optional
Scaling factor for the initialized weights.
Default is 1.
init_bias : float or int, optional
Scaling factor for the initialized bias.
Default is 0.
Attributes
----------
weight : torch.nn.Parameter
The learnable weights of the layer of shape (out_features, in_features).
bias : torch.nn.Parameter or None
The learnable bias of the layer of shape (out_features,). If bias=False,
this attribute is set to None.
"""
[docs]
def __init__(
self,
in_features,
out_features,
bias=True,
init_mode="kaiming_normal",
init_weight=1,
init_bias=0,
):
"""
Initialize the Linear layer.
Parameters
----------
in_features : int
Size of each input sample.
out_features : int
Size of each output sample.
bias : bool, optional
If set to False, the layer will not learn an additive bias.
Default is True.
init_mode : str, optional
Weight initialization method.
Default is 'kaiming_normal'.
init_weight : float or int, optional
Scaling factor for the initialized weights.
Default is 1.
init_bias : float or int, optional
Scaling factor for the initialized bias.
Default is 0.
"""
super().__init__()
self.in_features = in_features
self.out_features = out_features
init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features)
self.weight = torch.nn.Parameter(
weight_init([out_features, in_features], **init_kwargs) * init_weight
)
self.bias = (
torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias)
if bias
else None
)
[docs]
def forward(self, x):
"""
Forward pass of the linear layer.
Parameters
----------
x : torch.Tensor
Input tensor of shape (batch_size, in_features) or
``(batch_size, *, in_features)`` where ``*`` means any number of
additional dimensions.
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, out_features) or
``(batch_size, *, out_features)``.
Notes
-----
The operation performed is: ``output = x @ weight^T + bias``.
The bias is added in-place for efficiency when possible.
"""
x = x @ self.weight.to(x.dtype).t()
if self.bias is not None:
x = x.add_(self.bias.to(x.dtype))
return x
# ----------------------------------------------------------------------------
# Convolutional layer with optional up/downsampling.
[docs]
class Conv2d(torch.nn.Module):
"""
2D convolutional layer with optional upsampling, downsampling, and fused resampling.
This layer implements a 2D convolution that can optionally include upsampling
or downsampling operations with configurable resampling filters. It supports
both separate and fused resampling modes for efficiency.
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
kernel : int
Size of the convolutional kernel (square kernel).
bias : bool, optional
If set to False, the layer will not learn an additive bias.
Default is True.
up : bool, optional
If True, upsample the input by a factor of 2 before convolution.
Cannot be True if `down` is also True.
Default is False.
down : bool, optional
If True, downsample the output by a factor of 2 after convolution.
Cannot be True if `up` is also True.
Default is False.
resample_filter : list, optional
Coefficients of the 1D resampling filter that will be turned into a 2D filter.
Default is [1, 1] (bilinear filter).
fused_resample : bool, optional
If True, fuse the resampling operation with the convolution for efficiency.
Default is False.
init_mode : str, optional
Weight initialization method. Options are:
- 'xavier_uniform': Xavier uniform initialization
- 'xavier_normal': Xavier normal initialization
- 'kaiming_uniform': Kaiming uniform initialization (He initialization)
- 'kaiming_normal': Kaiming normal initialization (He initialization)
Default is 'kaiming_normal'.
init_weight : float or int, optional
Scaling factor for the initialized weights.
Default is 1.
init_bias : float or int, optional
Scaling factor for the initialized bias.
Default is 0.
Attributes
----------
weight : torch.nn.Parameter or None
The learnable weights of the convolution of shape
(out_channels, in_channels, kernel, kernel). If kernel is 0, this is None.
bias : torch.nn.Parameter or None
The learnable bias of the convolution of shape (out_channels,).
If kernel is 0 or bias is False, this is None.
resample_filter : torch.Tensor or None
The 2D resampling filter used for upsampling or downsampling.
Registered as a buffer (non-learnable parameter).
Raises
------
AssertionError
If both `up` and `down` are set to True.
Notes
-----
- When `kernel` is 0, no convolution is performed, only resampling if enabled.
- The resampling filter is created by taking the outer product of the 1D filter
with itself to create a separable 2D filter, then normalized.
- Fused resampling combines the resampling and convolution operations into
single operations for better performance.
Examples
--------
>>> # Standard convolution
>>> conv = Conv2d(3, 16, kernel=3)
>>> x = torch.randn(4, 3, 32, 32)
>>> out = conv(x)
>>> out.shape
torch.Size([4, 16, 32, 32])
>>> # Convolution with downsampling
>>> conv_down = Conv2d(3, 16, kernel=3, down=True)
>>> out = conv_down(x)
>>> out.shape
torch.Size([4, 16, 16, 16])
>>> # Convolution with upsampling
>>> conv_up = Conv2d(3, 16, kernel=3, up=True)
>>> out = conv_up(x)
>>> out.shape
torch.Size([4, 16, 64, 64])
"""
[docs]
def __init__(
self,
in_channels,
out_channels,
kernel,
bias=True,
up=False,
down=False,
resample_filter=[1, 1],
fused_resample=False,
init_mode="kaiming_normal",
init_weight=1,
init_bias=0,
):
"""
Initialize the Conv2d layer.
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
kernel : int
Size of the convolutional kernel.
bias : bool, optional
Whether to include a bias term.
Default is True.
up : bool, optional
Whether to upsample the input.
Default is False.
down : bool, optional
Whether to downsample the output.
Default is False.
resample_filter : list, optional
Coefficients of the 1D resampling filter.
Default is [1, 1].
fused_resample : bool, optional
Whether to fuse resampling with convolution.
Default is False.
init_mode : str, optional
Weight initialization method.
Default is 'kaiming_normal'.
init_weight : float or int, optional
Scaling factor for weight initialization.
Default is 1.
init_bias : float or int, optional
Scaling factor for bias initialization.
Default is 0.
"""
assert not (up and down)
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.up = up
self.down = down
self.fused_resample = fused_resample
init_kwargs = dict(
mode=init_mode,
fan_in=in_channels * kernel * kernel,
fan_out=out_channels * kernel * kernel,
)
self.weight = (
torch.nn.Parameter(
weight_init([out_channels, in_channels, kernel, kernel], **init_kwargs)
* init_weight
)
if kernel
else None
)
self.bias = (
torch.nn.Parameter(weight_init([out_channels], **init_kwargs) * init_bias)
if kernel and bias
else None
)
f = torch.as_tensor(resample_filter, dtype=torch.float32)
f = f.ger(f).unsqueeze(0).unsqueeze(1) / f.sum().square()
self.register_buffer("resample_filter", f if up or down else None)
[docs]
def forward(self, x):
"""
Forward pass of the Conv2d layer.
Parameters
----------
x : torch.Tensor
Input tensor of shape (batch_size, in_channels, height, width).
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, out_channels, out_height, out_width).
If `up` is True, spatial dimensions are doubled.
If `down` is True, spatial dimensions are halved.
Notes
-----
The method handles four main cases:
1. Fused upsampling + convolution
2. Fused convolution + downsampling
3. Separate up/down sampling followed by convolution
4. Standard convolution only
"""
w = self.weight.to(x.dtype) if self.weight is not None else None
b = self.bias.to(x.dtype) if self.bias is not None else None
f = (
self.resample_filter.to(x.dtype)
if self.resample_filter is not None
else None
)
w_pad = w.shape[-1] // 2 if w is not None else 0
f_pad = (f.shape[-1] - 1) // 2 if f is not None else 0
if self.fused_resample and self.up and w is not None:
x = torch.nn.functional.conv_transpose2d(
x,
f.mul(4).tile([self.in_channels, 1, 1, 1]),
groups=self.in_channels,
stride=2,
padding=max(f_pad - w_pad, 0),
)
x = torch.nn.functional.conv2d(x, w, padding=max(w_pad - f_pad, 0))
elif self.fused_resample and self.down and w is not None:
x = torch.nn.functional.conv2d(x, w, padding=w_pad + f_pad)
x = torch.nn.functional.conv2d(
x,
f.tile([self.out_channels, 1, 1, 1]),
groups=self.out_channels,
stride=2,
)
else:
if self.up:
x = torch.nn.functional.conv_transpose2d(
x,
f.mul(4).tile([self.in_channels, 1, 1, 1]),
groups=self.in_channels,
stride=2,
padding=f_pad,
)
if self.down:
x = torch.nn.functional.conv2d(
x,
f.tile([self.in_channels, 1, 1, 1]),
groups=self.in_channels,
stride=2,
padding=f_pad,
)
if w is not None:
x = torch.nn.functional.conv2d(x, w, padding=w_pad)
if b is not None:
x = x.add_(b.reshape(1, -1, 1, 1))
return x
# ----------------------------------------------------------------------------
# Group normalization.
[docs]
class GroupNorm(torch.nn.Module):
"""
Group Normalization layer.
This layer implements Group Normalization, which divides channels into groups
and computes within each group the mean and variance for normalization.
It is particularly effective for small batch sizes and often used as an
alternative to Batch Normalization.
Parameters
----------
num_channels : int
Number of input channels.
num_groups : int, optional
Number of groups to divide the channels into. Must be a divisor of
the number of channels. The actual number of groups may be reduced
to satisfy `min_channels_per_group`.
Default is 32.
min_channels_per_group : int, optional
Minimum number of channels per group. If the division would result
in fewer channels per group, the number of groups is reduced.
Default is 4.
eps : float, optional
A small constant added to the denominator for numerical stability.
Default is 1e-5.
Attributes
----------
weight : torch.nn.Parameter
Learnable scale parameter of shape (num_channels,).
Initialized to ones.
bias : torch.nn.Parameter
Learnable bias parameter of shape (num_channels,).
Initialized to zeros.
Notes
-----
- Group Normalization is independent of batch size, making it suitable
for variable batch sizes and small batch training.
- The number of groups is automatically adjusted to ensure each group has
at least `min_channels_per_group` channels.
- This layer uses PyTorch's built-in `torch.nn.functional.group_norm`.
"""
[docs]
def __init__(self, num_channels, num_groups=32, min_channels_per_group=4, eps=1e-5):
"""
Initialize the GroupNorm layer.
Parameters
----------
num_channels : int
Number of input channels.
num_groups : int, optional
Desired number of groups.
Default is 32.
min_channels_per_group : int, optional
Minimum channels per group.
Default is 4.
eps : float, optional
Small constant for numerical stability.
Default is 1e-5.
"""
super().__init__()
self.num_groups = min(num_groups, num_channels // min_channels_per_group)
self.eps = eps
self.weight = torch.nn.Parameter(torch.ones(num_channels))
self.bias = torch.nn.Parameter(torch.zeros(num_channels))
[docs]
def forward(self, x):
"""
Forward pass of the GroupNorm layer.
Parameters
----------
x : torch.Tensor
Input tensor of shape (batch_size, num_channels, height, width).
Returns
-------
torch.Tensor
Normalized tensor of same shape as input.
Notes
-----
The normalization is performed across spatial dimensions and within
each group of channels, maintaining the original mean and variance
statistics per group.
"""
x = torch.nn.functional.group_norm(
x,
num_groups=self.num_groups,
weight=self.weight.to(x.dtype),
bias=self.bias.to(x.dtype),
eps=self.eps,
)
return x
# ----------------------------------------------------------------------------
# Attention weight computation, i.e., softmax(Q^T * K).
# Performs all computation using FP32, but uses the original datatype for
# inputs/outputs/gradients to conserve memory.
[docs]
class AttentionOp(torch.autograd.Function):
"""
Custom autograd function for scaled dot-product attention weight computation.
This function computes attention weights using scaled dot-product attention:
w = softmax(Q·K^T / √d_k), where d_k is the dimension of the key vectors.
It implements both forward and backward passes for gradient computation.
Notes
-----
- This is a stateless operation that uses torch.autograd.Function for custom backward.
- The forward pass computes attention weights in float32 for numerical stability.
- The backward pass computes gradients using the chain rule for softmax and matrix multiplication.
- This implementation is optimized for memory efficiency during backward pass.
"""
[docs]
@staticmethod
def forward(ctx, q, k):
"""
Forward pass for attention weight computation.
Parameters
----------
ctx : torch.autograd.function.BackwardCFunction
Context object to save tensors for backward pass.
q : torch.Tensor
Query tensor of shape (batch_size, channels, query_length).
k : torch.Tensor
Key tensor of shape (batch_size, channels, key_length).
Returns
-------
torch.Tensor
Attention weights of shape (batch_size, query_length, key_length).
Each row represents attention distribution for a query position.
Notes
-----
- Computes w = softmax(Q·K^T / √d_k) where d_k = k.shape[1] (channel dimension).
- Uses float32 for computation to maintain numerical stability.
- Saves q, k, and w in context for backward pass.
"""
w = (
torch.einsum(
"ncq,nck->nqk",
q.to(torch.float32),
(k / np.sqrt(k.shape[1])).to(torch.float32),
)
.softmax(dim=2)
.to(q.dtype)
)
ctx.save_for_backward(q, k, w)
return w
[docs]
@staticmethod
def backward(ctx, dw):
"""
Backward pass for attention weight computation.
Parameters
----------
ctx : torch.autograd.function.BackwardCFunction
Context object containing saved tensors from forward pass.
dw : torch.Tensor
Gradient of loss with respect to attention weights.
Shape: (batch_size, query_length, key_length).
Returns
-------
dq : torch.Tensor
Gradient with respect to query tensor.
Shape: (batch_size, channels, query_length).
dk : torch.Tensor
Gradient with respect to key tensor.
Shape: (batch_size, channels, key_length).
Notes
-----
- Uses the saved tensors q, k, w from forward pass.
- Computes gradient of softmax using PyTorch's internal softmax_backward.
- Applies chain rule for the scaled dot-product operation.
- Maintains original dtypes of input tensors.
"""
q, k, w = ctx.saved_tensors
db = torch._softmax_backward_data(
grad_output=dw.to(torch.float32),
output=w.to(torch.float32),
dim=2,
input_dtype=torch.float32,
)
dq = torch.einsum("nck,nqk->ncq", k.to(torch.float32), db).to(
q.dtype
) / np.sqrt(k.shape[1])
dk = torch.einsum("ncq,nqk->nck", q.to(torch.float32), db).to(
k.dtype
) / np.sqrt(k.shape[1])
return dq, dk
# ----------------------------------------------------------------------------
# Unified U-Net block with optional up/downsampling and self-attention.
# Represents the union of all features employed by the DDPM++, NCSN++, and
# ADM architectures.
[docs]
class UNetBlock(torch.nn.Module):
"""
U-Net block with optional attention, up/down sampling, and adaptive scaling.
This block implements a residual block commonly used in U-Net architectures
for diffusion models and image-to-image translation. It supports:
- Residual connections with optional skip scaling
- Adaptive scaling/shifting via conditioning embeddings
- Multi-head self-attention mechanisms
- Upsampling or downsampling operations
- Dropout for regularization
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
emb_channels : int
Number of embedding (conditioning) channels.
up : bool, optional
If True, upsample the input by a factor of 2.
Default is False.
down : bool, optional
If True, downsample the output by a factor of 2.
Default is False.
attention : bool, optional
If True, include multi-head self-attention in the block.
Default is False.
num_heads : int, optional
Number of attention heads. If None, computed as out_channels // channels_per_head.
Default is None.
channels_per_head : int, optional
Number of channels per attention head when num_heads is None.
Default is 64.
dropout : float, optional
Dropout probability applied after the first activation.
Default is 0.
skip_scale : float, optional
Scaling factor applied to the residual connection.
Default is 1.
eps : float, optional
Epsilon value for GroupNorm layers for numerical stability.
Default is 1e-5.
resample_filter : list, optional
Coefficients for the resampling filter used in up/down sampling.
Default is [1, 1].
resample_proj : bool, optional
If True, use a 1x1 convolution in the skip connection when resampling.
Default is False.
adaptive_scale : bool, optional
If True, use both scale and shift parameters from the embedding.
If False, use only shift parameters.
Default is True.
init : dict, optional
Initialization parameters for most convolutional layers.
Default is empty dict.
init_zero : dict, optional
Initialization parameters for final convolutional layers (zero initialization).
Default is {'init_weight': 0}.
init_attn : dict, optional
Initialization parameters for attention layers.
If None, uses the same as `init`.
Default is None.
Attributes
----------
norm0, norm1, norm2 : GroupNorm
Group normalization layers.
conv0, conv1 : Conv2d
Convolutional layers.
affine : Linear
Linear layer for conditioning embedding.
skip : Conv2d or None
Skip connection projection (1x1 conv) if input and output channels differ or resampling.
qkv, proj : Conv2d
Attention query-key-value and projection layers (if attention is enabled).
Notes
-----
- The block follows a pre-activation residual structure: norm -> activation -> conv.
- When `adaptive_scale=True`, the conditioning embedding provides both scale and shift parameters.
- The attention mechanism uses multi-head self-attention within the spatial dimensions.
- The skip connection is automatically added when input/output channels differ or when resampling.
"""
[docs]
def __init__(
self,
in_channels,
out_channels,
emb_channels,
up=False,
down=False,
attention=False,
num_heads=None,
channels_per_head=64,
dropout=0,
skip_scale=1,
eps=1e-5,
resample_filter=[1, 1],
resample_proj=False,
adaptive_scale=True,
init=dict(),
init_zero=dict(init_weight=0),
init_attn=None,
):
"""
Initialize the UNetBlock.
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
emb_channels : int
Number of embedding channels.
up : bool, optional
Enable upsampling.
down : bool, optional
Enable downsampling.
attention : bool, optional
Enable attention mechanism.
num_heads : int, optional
Number of attention heads.
channels_per_head : int, optional
Channels per attention head.
dropout : float, optional
Dropout probability.
skip_scale : float, optional
Scaling factor for skip connection.
eps : float, optional
Epsilon for GroupNorm.
resample_filter : list, optional
Filter for resampling.
resample_proj : bool, optional
Use projection in skip connection when resampling.
adaptive_scale : bool, optional
Use adaptive scaling from embedding.
init : dict, optional
Initialization parameters.
init_zero : dict, optional
Zero initialization parameters.
init_attn : dict, optional
Attention initialization parameters.
"""
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.emb_channels = emb_channels
self.num_heads = (
0
if not attention
else num_heads
if num_heads is not None
else out_channels // channels_per_head
)
self.dropout = dropout
self.skip_scale = skip_scale
self.adaptive_scale = adaptive_scale
self.norm0 = GroupNorm(num_channels=in_channels, eps=eps)
self.conv0 = Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel=3,
up=up,
down=down,
resample_filter=resample_filter,
**init,
)
self.affine = Linear(
in_features=emb_channels,
out_features=out_channels * (2 if adaptive_scale else 1),
**init,
)
self.norm1 = GroupNorm(num_channels=out_channels, eps=eps)
self.conv1 = Conv2d(
in_channels=out_channels, out_channels=out_channels, kernel=3, **init_zero
)
self.skip = None
if out_channels != in_channels or up or down:
kernel = 1 if resample_proj or out_channels != in_channels else 0
self.skip = Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel=kernel,
up=up,
down=down,
resample_filter=resample_filter,
**init,
)
if self.num_heads:
self.norm2 = GroupNorm(num_channels=out_channels, eps=eps)
self.qkv = Conv2d(
in_channels=out_channels,
out_channels=out_channels * 3,
kernel=1,
**(init_attn if init_attn is not None else init),
)
self.proj = Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel=1,
**init_zero,
)
[docs]
def forward(self, x, emb):
"""
Forward pass of the UNetBlock.
Parameters
----------
x : torch.Tensor
Input tensor of shape (batch_size, in_channels, height, width).
emb : torch.Tensor
Conditioning embedding of shape (batch_size, emb_channels).
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, out_channels, out_height, out_width).
Notes
-----
The forward pass consists of:
1. Initial normalization and convolution (with optional up/down sampling)
2. Adaptive scaling/shifting from conditioning embedding
3. Second normalization, dropout, and convolution
4. Skip connection with scaling
5. Optional multi-head self-attention
"""
orig = x
x = self.conv0(silu(self.norm0(x)))
params = self.affine(emb).unsqueeze(2).unsqueeze(3).to(x.dtype)
if self.adaptive_scale:
scale, shift = params.chunk(chunks=2, dim=1)
x = silu(torch.addcmul(shift, self.norm1(x), scale + 1))
else:
x = silu(self.norm1(x.add_(params)))
x = self.conv1(
torch.nn.functional.dropout(x, p=self.dropout, training=self.training)
)
x = x.add_(self.skip(orig) if self.skip is not None else orig)
x = x * self.skip_scale
if self.num_heads:
q, k, v = (
self.qkv(self.norm2(x))
.reshape(
x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1
)
.unbind(2)
)
w = AttentionOp.apply(q, k)
a = torch.einsum("nqk,nck->ncq", w, v)
x = self.proj(a.reshape(*x.shape)).add_(x)
x = x * self.skip_scale
return x
# ----------------------------------------------------------------------------
# Timestep embedding used in the DDPM++ and ADM architectures.
[docs]
class PositionalEmbedding(torch.nn.Module):
"""
Sinusoidal positional embedding for sequences or timesteps.
This module generates sinusoidal embeddings for input positions, commonly used
in transformer architectures and diffusion models to provide temporal or
positional information to the model.
Parameters
----------
num_channels : int
Dimensionality of the embedding vectors. Must be even.
max_positions : int, optional
Maximum number of positions (or timesteps) for which embeddings are generated.
Determines the frequency scaling. Default is 10000.
endpoint : bool, optional
If True, scales frequencies such that the last frequency is 1/2 of the first.
If False, uses the standard scaling. Default is False.
Attributes
----------
num_channels : int
Dimensionality of the embedding vectors.
max_positions : int
Maximum positions for frequency scaling.
endpoint : bool
Whether to use endpoint scaling.
Notes
-----
- The embedding uses sine and cosine functions of different frequencies to
create a unique encoding for each position.
- The frequencies are computed as:
freqs = (1 / max_positions) ** (2i / num_channels) for i in range(num_channels//2)
or with endpoint adjustment.
- The output embedding is the concatenation of [cos(x*freqs), sin(x*freqs)].
- This implementation is based on the original Transformer positional encoding
and the diffusion model timestep embedding.
"""
[docs]
def __init__(self, num_channels, max_positions=10000, endpoint=False):
"""
Initialize the PositionalEmbedding module.
Parameters
----------
num_channels : int
Dimensionality of the embedding vectors.
max_positions : int, optional
Maximum number of positions for frequency scaling.
Default is 10000.
endpoint : bool, optional
Whether to use endpoint scaling.
Default is False.
"""
super().__init__()
self.num_channels = num_channels
self.max_positions = max_positions
self.endpoint = endpoint
[docs]
def forward(self, x):
"""
Forward pass to generate positional embeddings.
Parameters
----------
x : torch.Tensor
Input tensor of positions (or timesteps) of shape (batch_size,) or (n,).
Values are typically integers in [0, max_positions-1].
Returns
-------
torch.Tensor
Positional embeddings of shape (len(x), num_channels).
Notes
-----
- The input tensor `x` is typically a 1D tensor of position indices.
- The output is a 2D tensor where each row corresponds to the embedding
of the respective position.
- The embedding uses the device and dtype of the input tensor `x`.
"""
freqs = torch.arange(
start=0, end=self.num_channels // 2, dtype=torch.float32, device=x.device
)
freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0))
freqs = (1 / self.max_positions) ** freqs
x = x.ger(freqs.to(x.dtype))
x = torch.cat([x.cos(), x.sin()], dim=1)
return x
# ----------------------------------------------------------------------------
# Timestep embedding used in the NCSN++ architecture.
[docs]
class FourierEmbedding(torch.nn.Module):
"""
Random Fourier feature embedding for positional encoding.
This module generates random Fourier features (RFF) for input positions or
coordinates, mapping low-dimensional inputs to a higher-dimensional space
using random frequency sampling. Commonly used in neural fields, kernel
methods, and coordinate-based neural networks.
Parameters
----------
num_channels : int
Dimensionality of the embedding vectors. Must be even.
scale : float, optional
Standard deviation for sampling the random frequencies.
Determines the frequency distribution. Default is 16.
Attributes
----------
freqs : torch.Tensor (buffer)
Random frequencies sampled from a normal distribution with mean 0
and standard deviation `scale`. Shape: (num_channels // 2,).
Notes
-----
- The frequencies are randomly initialized and fixed (non-learnable).
- The embedding uses sine and cosine projections of the input multiplied
by 2π times the random frequencies.
- This technique approximates shift-invariant kernels via Bochner's theorem.
- Unlike learned embeddings, this provides a fixed, deterministic mapping
from input space to embedding space.
"""
[docs]
def __init__(self, num_channels, scale=16):
"""
Initialize the FourierEmbedding module.
Parameters
----------
num_channels : int
Dimensionality of the embedding vectors.
scale : float, optional
Standard deviation for frequency sampling.
Default is 16.
"""
super().__init__()
self.register_buffer("freqs", torch.randn(num_channels // 2) * scale)
[docs]
def forward(self, x):
"""
Forward pass to generate Fourier feature embeddings.
Parameters
----------
x : torch.Tensor
Input tensor of shape (batch_size,) or (n,).
Typically continuous values representing positions or coordinates.
Returns
-------
torch.Tensor
Fourier feature embeddings of shape (len(x), num_channels).
Notes
-----
- The transformation is: x ↦ [cos(2π * freqs * x), sin(2π * freqs * x)].
- The output dimension is twice the number of frequencies (num_channels).
- This embedding is deterministic given the fixed random frequencies.
"""
x = x.ger((2 * np.pi * self.freqs).to(x.dtype))
x = torch.cat([x.cos(), x.sin()], dim=1)
return x
# ----------------------------------------------------------------------------
# DDPM++ and NCSN++ architectures
[docs]
class SongUNet(torch.nn.Module):
"""
U-Net architecture for diffusion models based on Song et al. (2020).
This implementation supports both DDPM++ and NCSN++ architectures with
configurable encoder/decoder types, attention mechanisms, and conditioning.
It handles both square and rectangular input resolutions.
Parameters
----------
img_resolution : int or tuple
Input image resolution. If int, assumes square images (img_resolution x img_resolution).
If tuple, should be (height, width).
in_channels : int
Number of input color channels.
out_channels : int
Number of output color channels.
label_dim : int, optional
Number of class labels. Set to 0 for unconditional generation.
Default is 0.
augment_dim : int, optional
Dimensionality of augmentation labels (e.g., time-dependent augmentation).
Set to 0 for no augmentation. Default is 0.
model_channels : int, optional
Base channel multiplier for the network.
Default is 128.
channel_mult : list of int, optional
Channel multipliers for each resolution level.
Default is [1, 2, 2, 2].
channel_mult_emb : int, optional
Multiplier for embedding dimensionality relative to model_channels.
Default is 4.
num_blocks : int, optional
Number of residual blocks per resolution.
Default is 4.
attn_resolutions : list of int, optional
List of resolutions (minimum dimension) to apply self-attention.
Default is [16].
dropout : float, optional
Dropout probability for intermediate activations.
Default is 0.10.
label_dropout : float, optional
Dropout probability for class labels (classifier-free guidance).
Default is 0.
embedding_type : str, optional
Type of timestep embedding: 'positional' for DDPM++, 'fourier' for NCSN++.
Default is 'positional'.
channel_mult_noise : int, optional
Multiplier for noise embedding dimensionality: 1 for DDPM++, 2 for NCSN++.
Default is 1.
encoder_type : str, optional
Encoder architecture: 'standard' for DDPM++, 'skip' or 'residual' for NCSN++.
Default is 'standard'.
decoder_type : str, optional
Decoder architecture: 'standard' for both, 'skip' for NCSN++.
Default is 'standard'.
resample_filter : list, optional
Resampling filter coefficients: [1,1] for DDPM++, [1,3,3,1] for NCSN++.
Default is [1,1].
Attributes
----------
img_resolution : tuple
Input image resolution as (height, width).
img_height : int
Input image height.
img_width : int
Input image width.
label_dropout : float
Class label dropout probability.
map_noise : PositionalEmbedding or FourierEmbedding
Noise/timestep embedding module.
map_label : Linear or None
Class label embedding module.
map_augment : Linear or None
Augmentation label embedding module.
map_layer0, map_layer1 : Linear
Embedding transformation layers.
enc : torch.nn.ModuleDict
Encoder modules organized by resolution.
dec : torch.nn.ModuleDict
Decoder modules organized by resolution.
Raises
------
AssertionError
If embedding_type is not 'fourier' or 'positional'.
If encoder_type is not 'standard', 'skip', or 'residual'.
If decoder_type is not 'standard' or 'skip'.
If img_resolution tuple doesn't have exactly 2 elements.
Notes
-----
- The architecture follows a U-Net structure with skip connections.
- Supports multiple conditioning types: noise (timestep), class labels, augmentations.
- Attention is applied at specified resolutions to capture long-range dependencies.
- Different encoder/decoder types and embedding methods allow emulating DDPM++ or NCSN++.
- Rectangular resolutions are supported by tracking height and width separately.
References
----------
- Ho et al., "Denoising Diffusion Probabilistic Models" (DDPM)
- Song et al., "Score-Based Generative Modeling through Stochastic Differential Equations" (NCSN++)
"""
[docs]
def __init__(
self,
img_resolution, # Image resolution as tuple (height, width)
in_channels, # Number of color channels at input.
out_channels, # Number of color channels at output.
label_dim=0, # Number of class labels, 0 = unconditional.
augment_dim=0, # Augmentation label dimensionality, 0 = no augmentation.
model_channels=128, # Base multiplier for the number of channels.
channel_mult=[
1,
2,
2,
2,
], # Per-resolution multipliers for the number of channels.
channel_mult_emb=4, # Multiplier for the dimensionality of the embedding vector.
num_blocks=4, # Number of residual blocks per resolution.
attn_resolutions=[16], # List of resolutions with self-attention.
dropout=0.10, # Dropout probability of intermediate activations.
label_dropout=0, # Dropout probability of class labels for classifier-free guidance.
embedding_type="positional", # Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++.
channel_mult_noise=1, # Timestep embedding size: 1 for DDPM++, 2 for NCSN++.
encoder_type="standard", # Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++.
decoder_type="standard", # Decoder architecture: 'standard' for both DDPM++ and NCSN++.
resample_filter=[
1,
1,
], # Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++.
):
"""
Initialize the SongUNet.
Parameters
----------
img_resolution : int or tuple
Image resolution.
in_channels : int
Input channels.
out_channels : int
Output channels.
label_dim : int, optional
Class label dimension.
augment_dim : int, optional
Augmentation label dimension.
model_channels : int, optional
Base channel multiplier.
channel_mult : list, optional
Channel multipliers per resolution.
channel_mult_emb : int, optional
Embedding channel multiplier.
num_blocks : int, optional
Blocks per resolution.
attn_resolutions : list, optional
Resolutions for attention.
dropout : float, optional
Dropout probability.
label_dropout : float, optional
Label dropout probability.
embedding_type : str, optional
Embedding type.
channel_mult_noise : int, optional
Noise embedding multiplier.
encoder_type : str, optional
Encoder type.
decoder_type : str, optional
Decoder type.
resample_filter : list, optional
Resampling filter coefficients.
"""
assert embedding_type in ["fourier", "positional"]
assert encoder_type in ["standard", "skip", "residual"]
assert decoder_type in ["standard", "skip"]
# Handle rectangular resolution
if isinstance(img_resolution, (tuple, list)):
assert (
len(img_resolution) == 2
), "img_resolution must be a tuple/list (height, width)"
self.img_resolution = img_resolution
self.img_height, self.img_width = img_resolution
else:
self.img_resolution = (img_resolution, img_resolution)
self.img_height = self.img_width = img_resolution
super().__init__()
self.label_dropout = label_dropout
emb_channels = model_channels * channel_mult_emb
noise_channels = model_channels * channel_mult_noise
init = dict(init_mode="xavier_uniform")
init_zero = dict(init_mode="xavier_uniform", init_weight=1e-5)
init_attn = dict(init_mode="xavier_uniform", init_weight=np.sqrt(0.2))
block_kwargs = dict(
emb_channels=emb_channels,
num_heads=1,
dropout=dropout,
skip_scale=np.sqrt(0.5),
eps=1e-6,
resample_filter=resample_filter,
resample_proj=True,
adaptive_scale=False,
init=init,
init_zero=init_zero,
init_attn=init_attn,
)
# Mapping.
self.map_noise = (
PositionalEmbedding(num_channels=noise_channels, endpoint=True)
if embedding_type == "positional"
else FourierEmbedding(num_channels=noise_channels)
)
self.map_label = (
Linear(in_features=label_dim, out_features=noise_channels, **init)
if label_dim
else None
)
self.map_augment = (
Linear(
in_features=augment_dim, out_features=noise_channels, bias=False, **init
)
if augment_dim
else None
)
self.map_layer0 = Linear(
in_features=noise_channels, out_features=emb_channels, **init
)
self.map_layer1 = Linear(
in_features=emb_channels, out_features=emb_channels, **init
)
# Encoder.
self.enc = torch.nn.ModuleDict()
cout = in_channels
caux = in_channels
for level, mult in enumerate(channel_mult):
# Calculate current resolution level
res_h = self.img_height >> level
res_w = self.img_width >> level
res_key = f"{res_h}x{res_w}"
if level == 0:
cin = cout
cout = model_channels
self.enc[f"{res_key}_conv"] = Conv2d(
in_channels=cin, out_channels=cout, kernel=3, **init
)
else:
self.enc[f"{res_key}_down"] = UNetBlock(
in_channels=cout, out_channels=cout, down=True, **block_kwargs
)
if encoder_type == "skip":
self.enc[f"{res_key}_aux_down"] = Conv2d(
in_channels=caux,
out_channels=caux,
kernel=0,
down=True,
resample_filter=resample_filter,
)
self.enc[f"{res_key}_aux_skip"] = Conv2d(
in_channels=caux, out_channels=cout, kernel=1, **init
)
if encoder_type == "residual":
self.enc[f"{res_key}_aux_residual"] = Conv2d(
in_channels=caux,
out_channels=cout,
kernel=3,
down=True,
resample_filter=resample_filter,
fused_resample=True,
**init,
)
caux = cout
for idx in range(num_blocks):
cin = cout
cout = model_channels * mult
# Check attention for rectangular resolution
attn = min(res_h, res_w) in attn_resolutions
self.enc[f"{res_key}_block{idx}"] = UNetBlock(
in_channels=cin, out_channels=cout, attention=attn, **block_kwargs
)
skips = [
block.out_channels for name, block in self.enc.items() if "aux" not in name
]
# Decoder.
self.dec = torch.nn.ModuleDict()
for level, mult in reversed(list(enumerate(channel_mult))):
# Calculate current resolution level
res_h = self.img_height >> level
res_w = self.img_width >> level
res_key = f"{res_h}x{res_w}"
if level == len(channel_mult) - 1:
self.dec[f"{res_key}_in0"] = UNetBlock(
in_channels=cout, out_channels=cout, attention=True, **block_kwargs
)
self.dec[f"{res_key}_in1"] = UNetBlock(
in_channels=cout, out_channels=cout, **block_kwargs
)
else:
self.dec[f"{res_key}_up"] = UNetBlock(
in_channels=cout, out_channels=cout, up=True, **block_kwargs
)
for idx in range(num_blocks + 1):
cin = cout + skips.pop()
cout = model_channels * mult
# Check attention for rectangular resolution
attn = idx == num_blocks and (min(res_h, res_w) in attn_resolutions)
self.dec[f"{res_key}_block{idx}"] = UNetBlock(
in_channels=cin, out_channels=cout, attention=attn, **block_kwargs
)
if decoder_type == "skip" or level == 0:
if decoder_type == "skip" and level < len(channel_mult) - 1:
self.dec[f"{res_key}_aux_up"] = Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel=0,
up=True,
resample_filter=resample_filter,
)
self.dec[f"{res_key}_aux_norm"] = GroupNorm(num_channels=cout, eps=1e-6)
self.dec[f"{res_key}_aux_conv"] = Conv2d(
in_channels=cout, out_channels=out_channels, kernel=3, **init_zero
)
[docs]
def forward(self, x, noise_labels, class_labels, augment_labels=None):
"""
Forward pass through the U-Net.
Parameters
----------
x : torch.Tensor
Input tensor of shape (batch_size, in_channels, height, width).
noise_labels : torch.Tensor
Noise/timestep labels of shape (batch_size,).
class_labels : torch.Tensor or None
Class labels of shape (batch_size,) or (batch_size, label_dim).
Can be None if label_dim is 0.
augment_labels : torch.Tensor or None, optional
Augmentation labels of shape (batch_size, augment_dim).
Can be None if augment_dim is 0.
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, out_channels, height, width).
Notes
-----
- The forward pass consists of three main stages:
1. Embedding mapping: combines noise, class, and augmentation embeddings.
2. Encoder: extracts hierarchical features with optional skip connections.
3. Decoder: reconstructs output with skip connections from encoder.
- Classifier-free guidance is supported via label_dropout.
- The noise embedding uses sinusoidal (positional) or Fourier features.
"""
# Mapping.
emb = self.map_noise(noise_labels)
emb = (
emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape)
) # swap sin/cos
if self.map_label is not None:
tmp = class_labels
if self.training and self.label_dropout:
tmp = tmp * (
torch.rand([x.shape[0], 1], device=x.device) >= self.label_dropout
).to(tmp.dtype)
emb = emb + self.map_label(tmp * np.sqrt(self.map_label.in_features))
if self.map_augment is not None and augment_labels is not None:
emb = emb + self.map_augment(augment_labels)
emb = silu(self.map_layer0(emb))
emb = silu(self.map_layer1(emb))
# Encoder.
skips = []
aux = x
for name, block in self.enc.items():
if "aux_down" in name:
aux = block(aux)
elif "aux_skip" in name:
x = skips[-1] = x + block(aux)
elif "aux_residual" in name:
x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2)
else:
x = block(x, emb) if isinstance(block, UNetBlock) else block(x)
skips.append(x)
# Decoder.
aux = None
tmp = None
for name, block in self.dec.items():
if "aux_up" in name:
aux = block(aux)
elif "aux_norm" in name:
tmp = block(x)
elif "aux_conv" in name:
tmp = block(silu(tmp))
aux = tmp if aux is None else tmp + aux
else:
if x.shape[1] != block.in_channels:
x = torch.cat([x, skips.pop()], dim=1)
x = block(x, emb)
return aux
# ----------------------------------------------------------------------------
# ADM architecture
[docs]
class DhariwalUNet(torch.nn.Module):
"""
U-Net architecture based on Dhariwal & Nichol (2021) for diffusion models.
This implementation follows the ADM (Ablated Diffusion Model) architecture
with configurable attention mechanisms, conditioning, and rectangular resolution
support. It features a U-Net structure with skip connections, group normalization,
and optional conditioning on class labels and augmentation.
Parameters
----------
img_resolution : int or tuple
Input image resolution. If int, assumes square images (img_resolution x img_resolution).
If tuple, should be (height, width).
in_channels : int
Number of input color channels.
out_channels : int
Number of output color channels.
label_dim : int, optional
Number of class labels. Set to 0 for unconditional generation.
Default is 0.
augment_dim : int, optional
Dimensionality of augmentation labels (e.g., time-dependent augmentation).
Set to 0 for no augmentation. Default is 0.
model_channels : int, optional
Base channel multiplier for the network.
Default is 128.
channel_mult : list of int, optional
Channel multipliers for each resolution level.
Default is [1, 2, 3, 4].
channel_mult_emb : int, optional
Multiplier for embedding dimensionality relative to model_channels.
Default is 4.
num_blocks : int, optional
Number of residual blocks per resolution.
Default is 3.
attn_resolutions : list of int, optional
List of resolutions (minimum dimension) to apply self-attention.
Default is [32, 16, 8].
dropout : float, optional
Dropout probability for intermediate activations.
Default is 0.10.
label_dropout : float, optional
Dropout probability for class labels (classifier-free guidance).
Default is 0.
diffusion_model : bool, optional
Whether to configure the network for diffusion models.
If True, includes timestep embedding; if False, only uses label conditioning.
Default is True.
Attributes
----------
img_resolution : tuple
Input image resolution as (height, width).
img_height : int
Input image height.
img_width : int
Input image width.
label_dropout : float
Class label dropout probability.
map_noise : PositionalEmbedding or None
Noise/timestep embedding module (if diffusion_model=True).
map_label : Linear or None
Class label embedding module.
map_augment : Linear or None
Augmentation label embedding module.
map_layer0, map_layer1 : Linear
Embedding transformation layers.
enc : torch.nn.ModuleDict
Encoder modules organized by resolution.
dec : torch.nn.ModuleDict
Decoder modules organized by resolution.
out_norm : GroupNorm
Final group normalization layer.
out_conv : Conv2d
Final convolutional output layer.
Raises
------
AssertionError
If img_resolution tuple doesn't have exactly 2 elements.
Notes
-----
- The architecture is based on the U-Net from "Diffusion Models Beat GANs on Image Synthesis".
- Features group normalization throughout and attention at multiple resolutions.
- Supports classifier-free guidance via label_dropout.
- Can be used for both diffusion models and other conditional generation tasks.
- Rectangular resolutions are supported by tracking height and width separately.
References
----------
- Dhariwal & Nichol, "Diffusion Models Beat GANs on Image Synthesis", 2021.
"""
[docs]
def __init__(
self,
img_resolution, # Image resolution as tuple (height, width)
in_channels, # Number of color channels at input.
out_channels, # Number of color channels at output.
label_dim=0, # Number of class labels, 0 = unconditional.
augment_dim=0, # Augmentation label dimensionality, 0 = no augmentation.
model_channels=128, # Base multiplier for the number of channels.
channel_mult=[
1,
2,
3,
4,
], # Per-resolution multipliers for the number of channels.
channel_mult_emb=4, # Multiplier for the dimensionality of the embedding vector.
num_blocks=3, # Number of residual blocks per resolution.
attn_resolutions=[32, 16, 8], # List of resolutions with self-attention.
dropout=0.10, # List of resolutions with self-attention.
label_dropout=0, # Dropout probability of class labels for classifier-free guidance.
diffusion_model=True, # Whether to use the Unet for diffusion models.
):
"""
Initialize the DhariwalUNet.
Parameters
----------
img_resolution : int or tuple
Image resolution.
in_channels : int
Input channels.
out_channels : int
Output channels.
label_dim : int, optional
Class label dimension.
augment_dim : int, optional
Augmentation label dimension.
model_channels : int, optional
Base channel multiplier.
channel_mult : list, optional
Channel multipliers per resolution.
channel_mult_emb : int, optional
Embedding channel multiplier.
num_blocks : int, optional
Blocks per resolution.
attn_resolutions : list, optional
Resolutions for attention.
dropout : float, optional
Dropout probability.
label_dropout : float, optional
Label dropout probability.
diffusion_model : bool, optional
Whether to configure for diffusion models.
"""
# Handle rectangular resolution
if isinstance(img_resolution, (tuple, list)):
assert (
len(img_resolution) == 2
), "img_resolution must be a tuple (height, width)"
self.img_resolution = img_resolution
self.img_height, self.img_width = img_resolution
else:
self.img_resolution = (img_resolution, img_resolution)
self.img_height = self.img_width = img_resolution
super().__init__()
self.label_dropout = label_dropout
emb_channels = model_channels * channel_mult_emb
init = dict(
init_mode="kaiming_uniform",
init_weight=np.sqrt(1 / 3),
init_bias=np.sqrt(1 / 3),
)
init_zero = dict(init_mode="kaiming_uniform", init_weight=0, init_bias=0)
block_kwargs = dict(
emb_channels=emb_channels,
channels_per_head=64,
dropout=dropout,
init=init,
init_zero=init_zero,
)
# Mapping.
self.map_noise = (
PositionalEmbedding(num_channels=model_channels)
if diffusion_model
else None
)
self.map_augment = (
Linear(
in_features=augment_dim,
out_features=model_channels,
bias=False,
**init_zero,
)
if augment_dim
else None
)
self.map_layer0 = Linear(
in_features=model_channels, out_features=emb_channels, **init
)
self.map_layer1 = Linear(
in_features=emb_channels, out_features=emb_channels, **init
)
self.map_label = (
Linear(
in_features=label_dim,
out_features=emb_channels,
bias=False,
init_mode="kaiming_normal",
init_weight=np.sqrt(label_dim),
)
if label_dim
else None
)
# Encoder.
self.enc = torch.nn.ModuleDict()
cout = in_channels
for level, mult in enumerate(channel_mult):
# Calculate current resolution level
res_h = self.img_height >> level
res_w = self.img_width >> level
res_key = f"{res_h}x{res_w}"
if level == 0:
cin = cout
cout = model_channels * mult
self.enc[f"{res_key}_conv"] = Conv2d(
in_channels=cin, out_channels=cout, kernel=3, **init
)
else:
self.enc[f"{res_key}_down"] = UNetBlock(
in_channels=cout, out_channels=cout, down=True, **block_kwargs
)
for idx in range(num_blocks):
cin = cout
cout = model_channels * mult
# Check attention for rectangular resolution
attn = min(res_h, res_w) in attn_resolutions
self.enc[f"{res_key}_block{idx}"] = UNetBlock(
in_channels=cin, out_channels=cout, attention=attn, **block_kwargs
)
skips = [block.out_channels for block in self.enc.values()]
# Decoder.
self.dec = torch.nn.ModuleDict()
for level, mult in reversed(list(enumerate(channel_mult))):
# Calculate current resolution level
res_h = self.img_height >> level
res_w = self.img_width >> level
res_key = f"{res_h}x{res_w}"
if level == len(channel_mult) - 1:
self.dec[f"{res_key}_in0"] = UNetBlock(
in_channels=cout, out_channels=cout, attention=True, **block_kwargs
)
self.dec[f"{res_key}_in1"] = UNetBlock(
in_channels=cout, out_channels=cout, **block_kwargs
)
else:
self.dec[f"{res_key}_up"] = UNetBlock(
in_channels=cout, out_channels=cout, up=True, **block_kwargs
)
for idx in range(num_blocks + 1):
cin = cout + skips.pop()
cout = model_channels * mult
# Check attention for rectangular resolution
attn = min(res_h, res_w) in attn_resolutions
self.dec[f"{res_key}_block{idx}"] = UNetBlock(
in_channels=cin, out_channels=cout, attention=attn, **block_kwargs
)
self.out_norm = GroupNorm(num_channels=cout)
self.out_conv = Conv2d(
in_channels=cout, out_channels=out_channels, kernel=3, **init_zero
)
[docs]
def forward(self, x, noise_labels=None, class_labels=None, augment_labels=None):
"""
Forward pass through the Dhariwal U-Net.
Parameters
----------
x : torch.Tensor
Input tensor of shape (batch_size, in_channels, height, width).
noise_labels : torch.Tensor or None
Noise/timestep labels of shape (batch_size,).
Required if diffusion_model=True, otherwise optional.
class_labels : torch.Tensor or None
Class labels of shape (batch_size,) or (batch_size, label_dim).
Can be None if label_dim is 0.
augment_labels : torch.Tensor or None, optional
Augmentation labels of shape (batch_size, augment_dim).
Can be None if augment_dim is 0.
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, out_channels, height, width).
Notes
-----
- The forward pass combines conditioning embeddings (noise, class, augmentation)
and processes through encoder-decoder with skip connections.
- When diffusion_model=False, the noise_labels can be omitted.
- Classifier-free guidance is implemented via label_dropout during training.
"""
# Mapping.
emb = torch.zeros([1, self.map_layer1.in_features], device=x.device)
if self.map_label is not None:
tmp = class_labels
if self.training and self.label_dropout:
tmp = tmp * (
torch.rand([x.shape[0], 1], device=x.device) >= self.label_dropout
).to(tmp.dtype)
emb = self.map_label(tmp)
if self.map_noise is not None:
emb_n = self.map_noise(noise_labels)
emb_n = silu(self.map_layer0(emb_n))
emb_n = self.map_layer1(emb_n)
emb = emb + emb_n
if self.map_augment is not None and augment_labels is not None:
emb = emb + self.map_augment(augment_labels)
emb = silu(emb)
# Encoder.
skips = []
for block in self.enc.values():
x = block(x, emb) if isinstance(block, UNetBlock) else block(x)
skips.append(x)
# Decoder.
for block in self.dec.values():
if x.shape[1] != block.in_channels:
x = torch.cat([x, skips.pop()], dim=1)
x = block(x, emb)
x = self.out_conv(silu(self.out_norm(x)))
return x
# ----------------------------------------------------------------------------
# Preconditioning corresponding to the variance preserving (VP) formulation
[docs]
class VPPrecond(torch.nn.Module):
"""
Variance Preserving (VP) preconditioning for diffusion models.
This class implements preconditioning for the Variance Preserving formulation
of diffusion models, as described in Song et al. (2020). It wraps a base U-Net
model and applies the appropriate scaling and conditioning for VP SDEs.
Parameters
----------
img_resolution : int or tuple
Input image resolution. If int, assumes square images.
If tuple, should be (height, width).
in_channels : int
Number of input color channels.
out_channels : int
Number of output color channels.
label_dim : int, optional
Number of class labels. Set to 0 for unconditional generation.
Default is 0.
use_fp16 : bool, optional
Whether to execute the underlying model at FP16 precision for speed.
Default is False.
beta_d : float, optional
Extent of the noise level schedule. Controls the rate of noise increase.
Default is 19.9.
beta_min : float, optional
Initial slope of the noise level schedule.
Default is 0.1.
M : int, optional
Original number of timesteps in the DDPM formulation.
Default is 1000.
epsilon_t : float, optional
Minimum t-value used during training. Prevents numerical issues.
Default is 1e-5.
model_type : str, optional
Class name of the underlying U-Net model ('SongUNet' or 'DhariwalUNet').
Default is 'SongUNet'.
**model_kwargs : dict
Additional keyword arguments passed to the underlying model.
Attributes
----------
img_resolution : tuple
Input image resolution as (height, width).
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
label_dim : int
Number of class labels.
use_fp16 : bool
Whether to use FP16 precision.
sigma_min : float
Minimum noise level (sigma) corresponding to epsilon_t.
sigma_max : float
Maximum noise level (sigma) corresponding to t=1.
model : torch.nn.Module
The underlying U-Net model.
Notes
-----
- The VP formulation maintains unit variance throughout the diffusion process.
- The noise schedule follows: σ(t) = sqrt(exp(0.5*β_d*t² + β_min*t) - 1)
- The preconditioning applies scaling factors: c_skip, c_out, c_in, c_noise
- Supports conditional generation via class labels and condition images.
- Implements the continuous-time formulation of diffusion models.
References
----------
- Song et al., "Score-Based Generative Modeling through Stochastic Differential Equations", 2020.
"""
[docs]
def __init__(
self,
img_resolution, # Image resolution as tuple (height, width)
in_channels, # Number of color channels.
out_channels, # Number of color channels at output.
label_dim=0, # Number of class labels, 0 = unconditional.
use_fp16=False, # Execute the underlying model at FP16 precision?
beta_d=19.9, # Extent of the noise level schedule.
beta_min=0.1, # Initial slope of the noise level schedule.
M=1000, # Original number of timesteps in the DDPM formulation.
epsilon_t=1e-5, # Minimum t-value used during training.
model_type="SongUNet", # Class name of the underlying model.
**model_kwargs, # Keyword arguments for the underlying model.
):
"""
Initialize the VPPrecond module.
Parameters
----------
img_resolution : int or tuple
Image resolution.
in_channels : int
Input channels.
out_channels : int
Output channels.
label_dim : int, optional
Class label dimension.
use_fp16 : bool, optional
Use FP16 precision.
beta_d : float, optional
Noise schedule extent.
beta_min : float, optional
Initial noise schedule slope.
M : int, optional
Number of timesteps.
epsilon_t : float, optional
Minimum t-value.
model_type : str, optional
Underlying model class name.
**model_kwargs : dict
Additional model arguments.
"""
super().__init__()
# Store resolution for compatibility
if isinstance(img_resolution, (tuple, list)):
self.img_resolution = img_resolution
else:
self.img_resolution = (img_resolution, img_resolution)
self.in_channels = in_channels
self.out_channels = out_channels
self.label_dim = label_dim
self.use_fp16 = use_fp16
self.beta_d = beta_d
self.beta_min = beta_min
self.M = M
self.epsilon_t = epsilon_t
self.sigma_min = float(self.sigma(epsilon_t))
self.sigma_max = float(self.sigma(1))
self.model = globals()[model_type](
img_resolution=img_resolution,
in_channels=in_channels,
out_channels=self.out_channels,
label_dim=label_dim,
**model_kwargs,
)
[docs]
def forward(
self,
x,
sigma,
condition_img=None,
class_labels=None,
force_fp32=False,
**model_kwargs,
):
"""
Forward pass with VP preconditioning.
Parameters
----------
x : torch.Tensor
Input noisy tensor of shape (batch_size, in_channels, height, width).
sigma : torch.Tensor
Noise level(s) of shape (batch_size,) or scalar.
condition_img : torch.Tensor, optional
Condition image tensor of same spatial dimensions as x.
Default is None.
class_labels : torch.Tensor, optional
Class labels for conditioning of shape (batch_size,) or (batch_size, label_dim).
Default is None.
force_fp32 : bool, optional
Force FP32 precision even if use_fp16 is True.
Default is False.
**model_kwargs : dict
Additional keyword arguments passed to the underlying model.
Returns
-------
torch.Tensor
Denoised output of shape (batch_size, out_channels, height, width).
Notes
-----
- Applies the preconditioning: D(x) = c_skip * x + c_out * F(c_in * x, c_noise)
- Where F is the underlying U-Net model.
- c_in, c_out, c_skip, c_noise are computed from sigma according to VP formulation.
- Condition images are concatenated along the channel dimension.
"""
in_img = (
torch.cat([x, condition_img], dim=1) if condition_img is not None else x
)
sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
if self.label_dim == 0:
class_labels = None
elif class_labels is None:
class_labels = torch.zeros([1, self.label_dim], device=in_img.device)
else:
class_labels = class_labels.to(torch.float32).reshape(-1, self.label_dim)
dtype = (
torch.float16
if (self.use_fp16 and not force_fp32 and in_img.device.type == "cuda")
else torch.float32
)
c_skip = 1
c_out = -sigma
c_in = 1 / (sigma**2 + 1).sqrt()
c_noise = (self.M - 1) * self.sigma_inv(sigma)
F_x = self.model(
(c_in * in_img).to(dtype),
c_noise.flatten(),
class_labels=class_labels,
**model_kwargs,
).to(dtype)
assert F_x.dtype == dtype
D_x = c_skip * x + c_out * F_x.to(torch.float32)
return D_x
[docs]
def sigma(self, t):
"""
Compute noise level sigma for given time t.
Parameters
----------
t : float or torch.Tensor
Time value(s) in [epsilon_t, 1].
Returns
-------
torch.Tensor
Noise level sigma corresponding to t.
Notes
-----
Formula: σ(t) = sqrt(exp(0.5*β_d*t² + β_min*t) - 1)
"""
t = torch.as_tensor(t)
return ((0.5 * self.beta_d * (t**2) + self.beta_min * t).exp() - 1).sqrt()
[docs]
def sigma_inv(self, sigma):
"""
Inverse function: compute time t for given noise level sigma.
Parameters
----------
sigma : float or torch.Tensor
Noise level(s).
Returns
-------
torch.Tensor
Time t corresponding to sigma.
Notes
-----
Formula: t = (sqrt(β_min² + 2*β_d*log(1+σ²)) - β_min) / β_d
"""
sigma = torch.as_tensor(sigma)
return (
(self.beta_min**2 + 2 * self.beta_d * (1 + sigma**2).log()).sqrt()
- self.beta_min
) / self.beta_d
[docs]
def round_sigma(self, sigma):
"""
Round noise level(s) for compatibility with discrete schedules.
Parameters
----------
sigma : float or torch.Tensor
Noise level(s).
Returns
-------
torch.Tensor
Rounded noise level(s).
"""
return torch.as_tensor(sigma)
# ----------------------------------------------------------------------------
# Preconditioning corresponding to the variance exploding (VE) formulation
[docs]
class VEPrecond(torch.nn.Module):
"""
Variance Exploding (VE) preconditioning for diffusion models.
This class implements preconditioning for the Variance Exploding formulation
of diffusion models, as described in Song et al. (2020). It wraps a base U-Net
model and applies the appropriate scaling and conditioning for VE SDEs.
Parameters
----------
img_resolution : int or tuple
Input image resolution. If int, assumes square images.
If tuple, should be (height, width).
in_channels : int
Number of input color channels.
out_channels : int
Number of output color channels.
label_dim : int, optional
Number of class labels. Set to 0 for unconditional generation.
Default is 0.
use_fp16 : bool, optional
Whether to execute the underlying model at FP16 precision for speed.
Default is False.
sigma_min : float, optional
Minimum supported noise level.
Default is 0.02.
sigma_max : float, optional
Maximum supported noise level.
Default is 100.
model_type : str, optional
Class name of the underlying U-Net model ('SongUNet' or 'DhariwalUNet').
Default is 'SongUNet'.
**model_kwargs : dict
Additional keyword arguments passed to the underlying model.
Attributes
----------
img_resolution : tuple
Input image resolution as (height, width).
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
label_dim : int
Number of class labels.
use_fp16 : bool
Whether to use FP16 precision.
sigma_min : float
Minimum noise level.
sigma_max : float
Maximum noise level.
model : torch.nn.Module
The underlying U-Net model.
Notes
-----
- The VE formulation uses a simple geometric noise schedule.
- The preconditioning applies scaling factors: c_skip, c_out, c_in, c_noise
- c_noise = 0.5 * log(sigma) maps noise levels to conditioning inputs.
- Supports conditional generation via class labels and condition images.
References
----------
- Song et al., "Score-Based Generative Modeling through Stochastic Differential Equations", 2020.
"""
[docs]
def __init__(
self,
img_resolution, # Image resolution as tuple (height, width)
in_channels, # Number of color channels.
out_channels, # Number of color channels at output.
label_dim=0, # Number of class labels, 0 = unconditional.
use_fp16=False, # Execute the underlying model at FP16 precision?
sigma_min=0.02, # Minimum supported noise level.
sigma_max=100, # Maximum supported noise level.
model_type="SongUNet", # Class name of the underlying model.
**model_kwargs, # Keyword arguments for the underlying model.
):
"""
Initialize the VEPrecond module.
Parameters
----------
img_resolution : int or tuple
Image resolution.
in_channels : int
Input channels.
out_channels : int
Output channels.
label_dim : int, optional
Class label dimension.
use_fp16 : bool, optional
Use FP16 precision.
sigma_min : float, optional
Minimum noise level.
sigma_max : float, optional
Maximum noise level.
model_type : str, optional
Underlying model class name.
**model_kwargs : dict
Additional model arguments.
"""
super().__init__()
# Store resolution for compatibility
if isinstance(img_resolution, (tuple, list)):
self.img_resolution = img_resolution
else:
self.img_resolution = (img_resolution, img_resolution)
self.in_channels = in_channels
self.out_channels = out_channels
self.label_dim = label_dim
self.use_fp16 = use_fp16
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.model = globals()[model_type](
img_resolution=img_resolution,
in_channels=in_channels,
out_channels=self.out_channels,
label_dim=label_dim,
**model_kwargs,
)
[docs]
def forward(
self,
x,
sigma,
condition_img=None,
class_labels=None,
force_fp32=False,
**model_kwargs,
):
"""
Forward pass with VE preconditioning.
Parameters
----------
x : torch.Tensor
Input noisy tensor of shape (batch_size, in_channels, height, width).
sigma : torch.Tensor
Noise level(s) of shape (batch_size,) or scalar.
condition_img : torch.Tensor, optional
Condition image tensor of same spatial dimensions as x.
Default is None.
class_labels : torch.Tensor, optional
Class labels for conditioning of shape (batch_size,) or (batch_size, label_dim).
Default is None.
force_fp32 : bool, optional
Force FP32 precision even if use_fp16 is True.
Default is False.
**model_kwargs : dict
Additional keyword arguments passed to the underlying model.
Returns
-------
torch.Tensor
Denoised output of shape (batch_size, out_channels, height, width).
Notes
-----
- Applies the preconditioning: D(x) = c_skip * x + c_out * F(c_in * x, c_noise)
- Where F is the underlying U-Net model.
- For VE: c_skip = 1, c_out = sigma, c_in = 1, c_noise = 0.5 * log(sigma)
- Condition images are concatenated along the channel dimension.
"""
in_img = (
torch.cat([x, condition_img], dim=1) if condition_img is not None else x
)
sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
if self.label_dim == 0:
class_labels = None
elif class_labels is None:
class_labels = torch.zeros([1, self.label_dim], device=in_img.device)
else:
class_labels = class_labels.to(torch.float32).reshape(-1, self.label_dim)
dtype = (
torch.float16
if (self.use_fp16 and not force_fp32 and in_img.device.type == "cuda")
else torch.float32
)
c_skip = 1
c_out = sigma
c_in = 1
c_noise = (0.5 * sigma).log()
F_x = self.model(
(c_in * in_img).to(dtype),
c_noise.flatten(),
class_labels=class_labels,
**model_kwargs,
).to(dtype)
assert F_x.dtype == dtype
D_x = c_skip * x + c_out * F_x.to(torch.float32)
return D_x
[docs]
def round_sigma(self, sigma):
"""
Round noise level(s) for compatibility with discrete schedules.
Parameters
----------
sigma : float or torch.Tensor
Noise level(s).
Returns
-------
torch.Tensor
Rounded noise level(s).
"""
return torch.as_tensor(sigma)
# ----------------------------------------------------------------------------
# Preconditioning corresponding to improved DDPM (iDDPM) formulation
[docs]
class iDDPMPrecond(torch.nn.Module):
"""
Improved DDPM (iDDPM) preconditioning for diffusion models.
This class implements the improved preconditioning scheme from the iDDPM paper,
which refines the noise schedule and preconditioning for better sample quality.
It provides a bridge between discrete-time DDPM formulations and continuous-time
SDE formulations.
Parameters
----------
img_resolution : int or tuple
Input image resolution. If int, assumes square images.
If tuple, should be (height, width).
in_channels : int
Number of input color channels.
out_channels : int
Number of output color channels.
label_dim : int, optional
Number of class labels. Set to 0 for unconditional generation.
Default is 0.
use_fp16 : bool, optional
Whether to execute the underlying model at FP16 precision for speed.
Default is False.
C_1 : float, optional
Timestep adjustment parameter for low noise levels.
Default is 0.001.
C_2 : float, optional
Timestep adjustment parameter for high noise levels.
Default is 0.008.
M : int, optional
Original number of timesteps in the DDPM formulation.
Default is 1000.
model_type : str, optional
Class name of the underlying U-Net model ('SongUNet' or 'DhariwalUNet').
Default is 'DhariwalUNet'.
**model_kwargs : dict
Additional keyword arguments passed to the underlying model.
Attributes
----------
img_resolution : tuple
Input image resolution as (height, width).
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
label_dim : int
Number of class labels.
use_fp16 : bool
Whether to use FP16 precision.
sigma_min : float
Minimum noise level (learned from schedule).
sigma_max : float
Maximum noise level (learned from schedule).
u : torch.Tensor (buffer)
Learned noise schedule of length M+1.
model : torch.nn.Module
The underlying U-Net model.
Notes
-----
- The iDDPM formulation improves upon DDPM with a refined noise schedule.
- The noise schedule is learned during initialization via backward recursion.
- Uses alpha_bar schedule: ᾱ(j) = sin(π/2 * j/M/(C₂+1))²
- Implements discrete-time preconditioning with improved numerical stability.
References
----------
- Nichol & Dhariwal, "Improved Denoising Diffusion Probabilistic Models", 2021.
"""
[docs]
def __init__(
self,
img_resolution, # Image resolution as tuple (height, width)
in_channels, # Number of color channels.
out_channels, # Number of color channels at output.
label_dim=0, # Number of class labels, 0 = unconditional.
use_fp16=False, # Execute the underlying model at FP16 precision?
C_1=0.001, # Timestep adjustment at low noise levels.
C_2=0.008, # Timestep adjustment at high noise levels.
M=1000, # Original number of timesteps in the DDPM formulation.
model_type="DhariwalUNet", # Class name of the underlying model.
**model_kwargs, # Keyword arguments for the underlying model.
):
"""
Initialize the iDDPMPrecond module.
Parameters
----------
img_resolution : int or tuple
Image resolution.
in_channels : int
Input channels.
out_channels : int
Output channels.
label_dim : int, optional
Class label dimension.
use_fp16 : bool, optional
Use FP16 precision.
C_1 : float, optional
Low noise adjustment.
C_2 : float, optional
High noise adjustment.
M : int, optional
Number of timesteps.
model_type : str, optional
Underlying model class name.
**model_kwargs : dict
Additional model arguments.
"""
super().__init__()
# Store resolution for compatibility
if isinstance(img_resolution, (tuple, list)):
self.img_resolution = img_resolution
else:
self.img_resolution = (img_resolution, img_resolution)
self.in_channels = in_channels
self.out_channels = out_channels
self.label_dim = label_dim
self.use_fp16 = use_fp16
self.C_1 = C_1
self.C_2 = C_2
self.M = M
self.model = globals()[model_type](
img_resolution=img_resolution,
in_channels=in_channels,
out_channels=self.out_channels,
label_dim=label_dim,
**model_kwargs,
)
u = torch.zeros(M + 1)
for j in range(M, 0, -1): # M, ..., 1
u[j - 1] = (
(u[j] ** 2 + 1)
/ (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=C_1)
- 1
).sqrt()
self.register_buffer("u", u)
self.sigma_min = float(u[M - 1])
self.sigma_max = float(u[0])
[docs]
def forward(
self,
x,
sigma,
condition_img=None,
class_labels=None,
force_fp32=False,
**model_kwargs,
):
"""
Forward pass with iDDPM preconditioning.
Parameters
----------
x : torch.Tensor
Input noisy tensor of shape (batch_size, in_channels, height, width).
sigma : torch.Tensor
Noise level(s) of shape (batch_size,) or scalar.
condition_img : torch.Tensor, optional
Condition image tensor of same spatial dimensions as x.
Default is None.
class_labels : torch.Tensor, optional
Class labels for conditioning of shape (batch_size,) or (batch_size, label_dim).
Default is None.
force_fp32 : bool, optional
Force FP32 precision even if use_fp16 is True.
Default is False.
**model_kwargs : dict
Additional keyword arguments passed to the underlying model.
Returns
-------
torch.Tensor
Denoised output of shape (batch_size, out_channels, height, width).
Notes
-----
- Applies the preconditioning: D(x) = c_skip * x + c_out * F(c_in * x, c_noise)
- Where F is the underlying U-Net model.
- For iDDPM: c_skip = 1, c_out = -σ, c_in = 1/√(σ²+1)
- Condition images are concatenated along the channel dimension.
- c_noise maps sigma to discrete timesteps for the underlying model.
"""
if condition_img is not None:
in_img = torch.cat([x, condition_img], dim=1) # [B, C + C_cond, H, W]
else:
in_img = x
sigma = sigma.reshape(-1, 1, 1, 1)
# Prepare class labels
if self.label_dim == 0:
class_labels = None
elif class_labels is None:
class_labels = torch.zeros(
[in_img.shape[0], self.label_dim], device=in_img.device
)
else:
class_labels = class_labels.to(torch.float32).reshape(-1, self.label_dim)
dtype = (
torch.float16
if (self.use_fp16 and not force_fp32 and in_img.device.type == "cuda")
else torch.float32
)
# Diffusion coefficients
c_skip = 1
c_out = -sigma
c_in = 1 / (sigma**2 + 1).sqrt()
# Noise label calculation for model
c_noise = (
self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32)
)
# Forward pass through underlying model
F_x = self.model(
(c_in * in_img).to(dtype),
noise_labels=c_noise.flatten(),
class_labels=class_labels,
**model_kwargs,
).to(dtype)
assert F_x.dtype == dtype
D_x = c_skip * x + c_out * F_x[:, : self.in_channels].to(torch.float32)
return D_x
[docs]
def alpha_bar(self, j):
"""
Compute alpha_bar for timestep j in the improved schedule.
Parameters
----------
j : int or torch.Tensor
Timestep index (0 <= j <= M).
Returns
-------
torch.Tensor
ᾱ(j) = sin(π/2 * j/M/(C₂+1))²
"""
j = torch.as_tensor(j)
return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2
[docs]
def round_sigma(self, sigma, return_index=False):
"""
Round noise level(s) to the nearest value in the learned schedule.
Parameters
----------
sigma : torch.Tensor
Noise level(s).
return_index : bool, optional
If True, return the index in the schedule instead of the value.
Default is False.
Returns
-------
torch.Tensor
Rounded noise level(s) or indices.
"""
sigma = torch.as_tensor(sigma)
index = torch.cdist(
sigma.to(self.u.device).to(torch.float32).reshape(1, -1, 1),
self.u.reshape(1, -1, 1),
).argmin(2)
result = index if return_index else self.u[index.flatten()].to(sigma.dtype)
return result.reshape(sigma.shape).to(sigma.device)
# ----------------------------------------------------------------------------
# Improved preconditioning proposed in the paper "Elucidating the Design
# Space of Diffusion-Based Generative Models" (EDM).
[docs]
class EDMPrecond(torch.nn.Module):
"""
EDM preconditioning for diffusion models.
This class implements the EDM (Elucidating Diffusion Models) preconditioning
scheme, which provides a unified framework for various diffusion formulations
with optimized preconditioning coefficients.
Parameters
----------
img_resolution : int or tuple
Input image resolution. If int, assumes square images.
If tuple, should be (height, width).
in_channels : int
Number of input color channels.
out_channels : int
Number of output color channels.
label_dim : int, optional
Number of class labels. Set to 0 for unconditional generation.
Default is 0.
use_fp16 : bool, optional
Whether to execute the underlying model at FP16 precision for speed.
Default is False.
sigma_min : float, optional
Minimum supported noise level.
Default is 0.
sigma_max : float, optional
Maximum supported noise level.
Default is float('inf').
sigma_data : float, optional
Standard deviation of the training data.
Default is 1.0.
model_type : str, optional
Class name of the underlying U-Net model ('SongUNet' or 'DhariwalUNet').
Default is 'DhariwalUNet'.
**model_kwargs : dict
Additional keyword arguments passed to the underlying model.
Attributes
----------
img_resolution : tuple
Input image resolution as (height, width).
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
label_dim : int
Number of class labels.
use_fp16 : bool
Whether to use FP16 precision.
sigma_min : float
Minimum noise level.
sigma_max : float
Maximum noise level.
sigma_data : float
Training data standard deviation.
model : torch.nn.Module
The underlying U-Net model.
Notes
-----
- The EDM formulation provides a unified preconditioning scheme that
generalizes VP, VE, and other diffusion formulations.
- Preconditioning coefficients: c_skip = σ_data²/(σ²+σ_data²)
c_out = σ·σ_data/√(σ²+σ_data²), c_in = 1/√(σ_data²+σ²)
- c_noise = log(σ)/4 provides the noise conditioning input.
- This formulation often yields better sample quality and training stability.
References
----------
- Karras et al., "Elucidating the Design Space of Diffusion-Based Generative Models", 2022.
"""
[docs]
def __init__(
self,
img_resolution, # Image resolution.
in_channels, # Number of input channels.
out_channels, # Number of output channels.
label_dim=0, # Number of class labels.
use_fp16=False, # FP16 execution?
sigma_min=0, # Min noise level.
sigma_max=float("inf"), # Max noise level.
sigma_data=1.0, # Training data std.
model_type="DhariwalUNet", # Underlying model class.
**model_kwargs, # Keyword args.
):
"""
Initialize the EDMPrecond module.
Parameters
----------
img_resolution : int or tuple
Image resolution.
in_channels : int
Input channels.
out_channels : int
Output channels.
label_dim : int, optional
Class label dimension.
use_fp16 : bool, optional
Use FP16 precision.
sigma_min : float, optional
Minimum noise level.
sigma_max : float, optional
Maximum noise level.
sigma_data : float, optional
Training data standard deviation.
model_type : str, optional
Underlying model class name.
**model_kwargs : dict
Additional model arguments.
"""
super().__init__()
# Store resolution for compatibility
if isinstance(img_resolution, (tuple, list)):
self.img_resolution = img_resolution
else:
self.img_resolution = (img_resolution, img_resolution)
self.in_channels = in_channels
self.out_channels = out_channels
self.label_dim = label_dim
self.use_fp16 = use_fp16
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.sigma_data = sigma_data
# keep names exactly the same
self.model = globals()[model_type](
img_resolution=img_resolution,
in_channels=in_channels,
out_channels=out_channels,
label_dim=label_dim,
**model_kwargs,
)
[docs]
def forward(
self,
x,
sigma,
condition_img=None,
class_labels=None,
force_fp32=True,
**model_kwargs,
):
"""
Forward pass with EDM preconditioning.
Parameters
----------
x : torch.Tensor
Input noisy tensor of shape (batch_size, in_channels, height, width).
sigma : torch.Tensor
Noise level(s) of shape (batch_size,) or scalar.
condition_img : torch.Tensor, optional
Condition image tensor of same spatial dimensions as x.
Default is None.
class_labels : torch.Tensor, optional
Class labels for conditioning of shape (batch_size,) or (batch_size, label_dim).
Default is None.
force_fp32 : bool, optional
Force FP32 precision even if use_fp16 is True.
Default is True.
**model_kwargs : dict
Additional keyword arguments passed to the underlying model.
Returns
-------
torch.Tensor
Denoised output of shape (batch_size, out_channels, height, width).
Notes
-----
- Applies the EDM preconditioning: D(x) = c_skip * x + c_out * F(c_in * x, c_noise)
- Where F is the underlying U-Net model.
- EDM coefficients: c_skip = σ_data²/(σ²+σ_data²)
c_out = σ·σ_data/√(σ²+σ_data²), c_in = 1/√(σ_data²+σ²)
- Condition images are concatenated along the channel dimension.
- c_noise = log(σ)/4 provides the noise conditioning.
"""
# -----------------------------
# Input concatenation
# -----------------------------
if condition_img is not None:
in_img = torch.cat([x, condition_img], dim=1)
else:
in_img = x
sigma = sigma.reshape(-1, 1, 1, 1)
# -----------------------------
# Class labels (same variable name, UNet#2-compatible)
# -----------------------------
if self.label_dim == 0:
class_labels = None
else:
if class_labels is None:
class_labels = torch.zeros(
[in_img.shape[0], self.label_dim], device=in_img.device
)
else:
class_labels = class_labels.to(torch.float32).reshape(
-1, self.label_dim
)
# -----------------------------
# Precision logic
# -----------------------------
dtype = (
torch.float16
if (self.use_fp16 and not force_fp32 and in_img.device.type == "cuda")
else torch.float32
)
# -----------------------------
# EDM coefficients
# -----------------------------
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
c_out = sigma * self.sigma_data / torch.sqrt(sigma**2 + self.sigma_data**2)
c_in = 1.0 / torch.sqrt(self.sigma_data**2 + sigma**2)
c_noise = sigma.log() / 4
# -----------------------------
# Call the UNet
# -----------------------------
F_x = self.model(
(c_in * in_img).to(dtype),
noise_labels=c_noise.flatten(), # required by UNet
class_labels=class_labels, # UNet optional labels
**model_kwargs,
).to(dtype)
# -----------------------------
# Output reconstruction
# -----------------------------
D_x = c_skip * x + c_out * F_x
return D_x
[docs]
def round_sigma(self, sigma):
"""
Round noise level(s) for compatibility with discrete schedules.
Parameters
----------
sigma : float or torch.Tensor
Noise level(s).
Returns
-------
torch.Tensor
Rounded noise level(s).
Notes
-----
In EDM, sigma is continuous, so rounding is typically a no-op
unless implementing a discrete schedule variant.
"""
return torch.as_tensor(sigma)
# ----------------------------------------------------------------------------
#
[docs]
class TestDiffusionNetworks(unittest.TestCase):
"""Unit tests for diffusion network architectures."""
[docs]
def __init__(self, methodName="runTest", logger=None):
super().__init__(methodName)
self.logger = logger
[docs]
def setUp(self):
"""Set up test fixtures."""
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.batch_size = 2
self.in_channels = 3
self.cond_channels = 2
self.out_channels = 3
self.label_dim = 2
if self.logger:
self.logger.info(f"Test setup complete - using device: {self.device}")
[docs]
def test_song_unet_square_resolution(self):
"""Test SongUNet with square resolution."""
if self.logger:
self.logger.info("Testing SongUNet with square resolution")
img_resolution = 64
total_in_channels = self.in_channels + self.cond_channels
model = SongUNet(
img_resolution=img_resolution,
in_channels=total_in_channels, # Use total channels including conditional
out_channels=self.out_channels,
label_dim=self.label_dim,
model_channels=32,
channel_mult=[1, 2],
attn_resolutions=[32],
embedding_type="positional",
).to(self.device)
# Test forward pass - concatenate input and conditional image
x = torch.randn(
self.batch_size, self.in_channels, img_resolution, img_resolution
).to(self.device)
cond_img = torch.randn(
self.batch_size, self.cond_channels, img_resolution, img_resolution
).to(self.device)
input_img = torch.cat(
[x, cond_img], dim=1
) # Concatenate along channel dimension
noise_labels = torch.randn(self.batch_size).to(self.device)
class_labels = torch.randint(0, self.label_dim, (self.batch_size,)).to(
self.device
)
with torch.no_grad():
output = model(input_img, noise_labels, class_labels)
self.assertEqual(
output.shape,
(self.batch_size, self.out_channels, img_resolution, img_resolution),
)
if self.logger:
self.logger.info(
f"✅ SongUNet square test passed - output shape: {output.shape}"
)
[docs]
def test_song_unet_rectangular_resolution(self):
"""Test SongUNet with rectangular resolution."""
if self.logger:
self.logger.info("Testing SongUNet with rectangular resolution")
img_resolution = (64, 32)
total_in_channels = self.in_channels + self.cond_channels
model = SongUNet(
img_resolution=img_resolution,
in_channels=total_in_channels, # Use total channels including conditional
out_channels=self.out_channels,
label_dim=self.label_dim,
model_channels=32,
channel_mult=[1, 2],
attn_resolutions=[16],
embedding_type="fourier",
).to(self.device)
# Test forward pass - concatenate input and conditional image
x = torch.randn(self.batch_size, self.in_channels, *img_resolution).to(
self.device
)
cond_img = torch.randn(self.batch_size, self.cond_channels, *img_resolution).to(
self.device
)
input_img = torch.cat(
[x, cond_img], dim=1
) # Concatenate along channel dimension
noise_labels = torch.randn(self.batch_size).to(self.device)
class_labels = torch.randint(0, self.label_dim, (self.batch_size,)).to(
self.device
)
with torch.no_grad():
output = model(input_img, noise_labels, class_labels)
self.assertEqual(
output.shape, (self.batch_size, self.out_channels, *img_resolution)
)
if self.logger:
self.logger.info(
f"✅ SongUNet rectangular test passed - output shape: {output.shape}"
)
[docs]
def test_dhariwal_unet(self):
"""Test DhariwalUNet architecture."""
if self.logger:
self.logger.info("Testing DhariwalUNet")
img_resolution = (128, 64)
total_in_channels = self.in_channels + self.cond_channels
model = DhariwalUNet(
img_resolution=img_resolution,
in_channels=total_in_channels, # Use total channels including conditional
out_channels=self.out_channels,
label_dim=self.label_dim,
model_channels=32,
channel_mult=[1, 2],
attn_resolutions=[32, 16],
).to(self.device)
# Test forward pass - concatenate input and conditional image
x = torch.randn(self.batch_size, self.in_channels, *img_resolution).to(
self.device
)
cond_img = torch.randn(self.batch_size, self.cond_channels, *img_resolution).to(
self.device
)
input_img = torch.cat(
[x, cond_img], dim=1
) # Concatenate along channel dimension
noise_labels = torch.randn(self.batch_size).to(self.device)
class_labels = (
torch.randint(0, self.label_dim, (self.batch_size,)).to(self.device).float()
) # Convert to float
with torch.no_grad():
output = model(input_img, noise_labels, class_labels)
self.assertEqual(
output.shape, (self.batch_size, self.out_channels, *img_resolution)
)
if self.logger:
self.logger.info(
f"✅ DhariwalUNet test passed - output shape: {output.shape}"
)
[docs]
def test_vp_preconditioner(self):
"""Test VPPrecond with conditional images."""
if self.logger:
self.logger.info("Testing VPPrecond")
img_resolution = 64
total_in_channels = self.in_channels + self.cond_channels
model = VPPrecond(
img_resolution=img_resolution,
in_channels=total_in_channels,
out_channels=self.out_channels,
label_dim=self.label_dim,
use_fp16=False,
model_type="SongUNet",
model_channels=32,
channel_mult=[1, 2],
).to(self.device)
# Test forward pass
x = torch.randn(
self.batch_size, self.in_channels, img_resolution, img_resolution
).to(self.device)
cond_img = torch.randn(
self.batch_size, self.cond_channels, img_resolution, img_resolution
).to(self.device)
sigma = torch.tensor([0.1, 0.5], device=self.device)
class_labels = torch.randint(
0, self.label_dim, (self.batch_size, 2), device=self.device
)
with torch.no_grad():
output = model(x, sigma, condition_img=cond_img, class_labels=class_labels)
self.assertEqual(output.shape, x.shape)
if self.logger:
self.logger.info(f"✅ VPPrecond test passed - output shape: {output.shape}")
[docs]
def test_ve_preconditioner(self):
"""Test VEPrecond with conditional images."""
if self.logger:
self.logger.info("Testing VEPrecond")
img_resolution = (64, 32)
total_in_channels = self.in_channels + self.cond_channels
model = VEPrecond(
img_resolution=img_resolution,
in_channels=total_in_channels,
out_channels=self.out_channels,
label_dim=self.label_dim,
use_fp16=False,
model_type="SongUNet",
model_channels=32,
channel_mult=[1, 2],
).to(self.device)
# Test forward pass
x = torch.randn(self.batch_size, self.in_channels, *img_resolution).to(
self.device
)
cond_img = torch.randn(self.batch_size, self.cond_channels, *img_resolution).to(
self.device
)
sigma = torch.tensor([0.1, 0.5], device=self.device)
class_labels = torch.randint(
0, self.label_dim, (self.batch_size, 2), device=self.device
)
with torch.no_grad():
output = model(x, sigma, condition_img=cond_img, class_labels=class_labels)
self.assertEqual(output.shape, x.shape)
if self.logger:
self.logger.info(f"✅ VEPrecond test passed - output shape: {output.shape}")
[docs]
def test_edm_preconditioner(self):
"""Test EDMPrecond with conditional images."""
if self.logger:
self.logger.info("Testing EDMPrecond")
img_resolution = (128, 64)
total_in_channels = self.in_channels + self.cond_channels
model = EDMPrecond(
img_resolution=img_resolution,
in_channels=total_in_channels,
out_channels=self.out_channels,
label_dim=self.label_dim,
use_fp16=False,
model_type="DhariwalUNet",
model_channels=32,
channel_mult=[1, 2],
).to(self.device)
# Test forward pass
x = torch.randn(self.batch_size, self.in_channels, *img_resolution).to(
self.device
)
cond_img = torch.randn(self.batch_size, self.cond_channels, *img_resolution).to(
self.device
)
sigma = torch.tensor([0.1, 0.5], device=self.device)
class_labels = torch.randint(
0, self.label_dim, (self.batch_size, 2), device=self.device
)
with torch.no_grad():
output = model(x, sigma, condition_img=cond_img, class_labels=class_labels)
self.assertEqual(output.shape, x.shape)
if self.logger:
self.logger.info(
f"✅ EDMPrecond test passed - output shape: {output.shape}"
)
[docs]
def test_parameter_counts(self):
"""Test that all models have reasonable parameter counts."""
if self.logger:
self.logger.info("Testing parameter counts")
configs = [
(
"SongUNet-Small",
SongUNet,
{"model_channels": 32, "channel_mult": [1, 2]},
),
(
"SongUNet-Medium",
SongUNet,
{"model_channels": 64, "channel_mult": [1, 2, 2]},
),
(
"DhariwalUNet-Small",
DhariwalUNet,
{"model_channels": 32, "channel_mult": [1, 2]},
),
]
for name, model_class, kwargs in configs:
with self.subTest(model=name):
model = model_class(
img_resolution=64,
in_channels=self.in_channels
+ self.cond_channels, # Use total channels
out_channels=self.out_channels,
label_dim=self.label_dim,
**kwargs,
).to(self.device)
total_params = sum(p.numel() for p in model.parameters())
self.assertGreater(
total_params, 1000
) # Should have at least 1K parameters
if self.logger:
self.logger.info(f"✅ {name} parameter count: {total_params:,}")
[docs]
def tearDown(self):
"""Clean up after tests."""
if self.logger:
self.logger.info("Network tests completed successfully")
# ----------------------------------------------------------------------------