# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Kazem Ardaneh
#
# ============================================================================
# ORIGINAL WORK (NVIDIA)
# ============================================================================
# This work is a derivative of "Elucidating the Design Space of
# Diffusion-Based Generative Models" by NVIDIA CORPORATION & AFFILIATES.
#
# Original work: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.
# Original license: Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
# Original source: https://github.com/NVlabs/edm
#
# ============================================================================
# MODIFICATIONS AND ADDITIONS (IPSL / CNRS / Sorbonne University)
# ============================================================================
# Modifications to loss functions include:
# 1. Added conditional image support to all loss functions
# - Extended VPLoss, VELoss, EDMLoss with conditional_img parameter
# - Modified __call__ method to pass conditional images to the network
# - Updated documentation to reflect conditional image usage
#
# 2. Added UnetLoss class for non-diffusion UNet training
# - Created new loss class for direct image-to-image prediction tasks
# - Supports MSE, L1, and Smooth L1 loss types
# - Compatible with UNet architectures (DhariwalUNet, SongUNet)
# - Includes support for data augmentation and conditioning
#
# 3. Enhanced documentation
# - Added comprehensive docstrings for all classes and methods
# - Included mathematical formulas and training procedures
# - Added usage examples and parameter descriptions
#
# 4. Added comprehensive unit tests
# - Created TestLosses class with test methods for all loss functions
# - Added tests for loss gradients and numerical stability
# - Added tests with data augmentation
# - Added loss comparison tests
# - Added rectangular resolution support tests
#
# 5. Code quality improvements
# - Added type hints for better code clarity
# - Improved variable naming
# - Added input validation where appropriate
#
# ============================================================================
# LICENSE
# ============================================================================
# This derivative work is licensed under the same terms as the original:
# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/
# ============================================================================
# ACKNOWLEDGMENTS
# ============================================================================
# We thank the NVIDIA team for their excellent work on EDM and for making it
# available under an open license that enables further research and development.
"""
Diffusion model loss functions and testing utilities.
This module implements various loss functions for diffusion models including:
- VPLoss: Variance Preserving loss from Score-Based Generative Modeling
- VELoss: Variance Exploding loss from Score-Based Generative Modeling
- EDMLoss: Improved loss from Elucidating the Design Space of Diffusion-Based Generative Models
"""
import torch
import unittest
from unittest.mock import Mock
from IPSL_AID.networks import VPPrecond, VEPrecond, EDMPrecond, DhariwalUNet
# ----------------------------------------------------------------------------
# Loss function corresponding to the variance preserving (VP) formulation
class VPLoss:
    """
    Denoising loss for the variance preserving (VP) diffusion formulation.

    A time ``t`` is drawn per sample, mapped to a noise level ``sigma(t)``,
    and the network is asked to recover the clean image from
    ``x_t = x_0 + sigma(t) * eps``. The squared error is weighted by
    ``1 / sigma(t)**2``.

    Parameters
    ----------
    beta_d : float, optional
        Coefficient of the quadratic term ``0.5 * beta_d * t**2`` in the
        exponent of the noise schedule (see :meth:`sigma`). Default is 19.9.
    beta_min : float, optional
        Coefficient of the linear term ``beta_min * t`` in the exponent of
        the noise schedule. Default is 0.1.
    epsilon_t : float, optional
        Lower end of the sampled time range, keeping ``t`` away from 0 where
        ``sigma`` vanishes and the weight diverges. Default is 1e-5.

    Attributes
    ----------
    beta_d : float
        Stored quadratic schedule coefficient.
    beta_min : float
        Stored linear schedule coefficient.
    epsilon_t : float
        Stored minimum time.

    References
    ----------
    - Song et al., "Score-Based Generative Modeling through Stochastic
      Differential Equations", 2020.
    """

    def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5):
        """
        Store the noise-schedule parameters.

        Parameters
        ----------
        beta_d : float, optional
            Quadratic coefficient of the schedule exponent. Default is 19.9.
        beta_min : float, optional
            Linear coefficient of the schedule exponent. Default is 0.1.
        epsilon_t : float, optional
            Minimum sampled time. Default is 1e-5.
        """
        self.beta_d = beta_d
        self.beta_min = beta_min
        self.epsilon_t = epsilon_t

    def __call__(
        self, net, images, conditional_img=None, labels=None, augment_pipe=None
    ):
        """
        Compute the element-wise VP loss for a batch of images.

        Parameters
        ----------
        net : callable
            Denoiser invoked as
            ``net(x_noisy, sigma, conditional_img, labels, augment_labels=...)``.
        images : torch.Tensor
            Clean images; the first dimension is the batch dimension.
        conditional_img : optional
            Conditioning input forwarded unchanged to ``net``. Default is None.
        labels : optional
            Labels forwarded unchanged to ``net``. Default is None.
        augment_pipe : callable, optional
            If given, called as ``augment_pipe(images)`` and expected to
            return ``(augmented_images, augment_labels)``. Default is None.

        Returns
        -------
        torch.Tensor
            Unreduced weighted squared error
            ``(1 / sigma**2) * (net_output - y)**2``, where ``y`` is the
            (optionally augmented) clean batch.

        Notes
        -----
        The order of random draws is: per-sample uniform time, then the
        augmentation pipeline (if any), then the Gaussian noise.
        """
        num_samples = images.shape[0]
        # One uniform draw per sample, shaped [B, 1, 1, 1] so it broadcasts
        # over the remaining image dimensions.
        uniform = torch.rand([num_samples, 1, 1, 1], device=images.device)
        # Affine map of u in [0, 1) onto times between epsilon_t and 1.
        t = 1 + uniform * (self.epsilon_t - 1)
        sigma = self.sigma(t)
        weight = 1 / sigma**2

        if augment_pipe is None:
            y, augment_labels = images, None
        else:
            y, augment_labels = augment_pipe(images)

        noise = torch.randn_like(y) * sigma
        denoised = net(
            y + noise, sigma, conditional_img, labels, augment_labels=augment_labels
        )
        return weight * ((denoised - y) ** 2)

    def sigma(self, t):
        """
        Return the noise level for time ``t``.

        Parameters
        ----------
        t : torch.Tensor or float
            Time value(s); converted with ``torch.as_tensor``.

        Returns
        -------
        torch.Tensor
            ``sqrt(exp(0.5 * beta_d * t**2 + beta_min * t) - 1)``, with the
            same shape as ``t``. Evaluates to 0 at ``t = 0``.
        """
        t = torch.as_tensor(t)
        exponent = 0.5 * self.beta_d * (t**2) + self.beta_min * t
        return (exponent.exp() - 1).sqrt()
# ----------------------------------------------------------------------------
# Loss function corresponding to the variance exploding (VE) formulation
class VELoss:
    """
    Denoising loss for the variance exploding (VE) diffusion formulation.

    A noise level is drawn per sample from the geometric schedule
    ``sigma = sigma_min * (sigma_max / sigma_min) ** u`` with
    ``u ~ Uniform[0, 1)``, the clean image is perturbed as
    ``x_t = x_0 + sigma * eps``, and the squared denoising error is weighted
    by ``1 / sigma**2``.

    Parameters
    ----------
    sigma_min : float, optional
        Noise level obtained at ``u = 0``. Default is 0.02.
    sigma_max : float, optional
        Noise level approached as ``u`` tends to 1. Default is 100.

    Attributes
    ----------
    sigma_min : float
        Stored lower noise level.
    sigma_max : float
        Stored upper noise level.

    References
    ----------
    - Song et al., "Score-Based Generative Modeling through Stochastic
      Differential Equations", 2020.
    """

    def __init__(self, sigma_min=0.02, sigma_max=100):
        """
        Store the bounds of the geometric noise schedule.

        Parameters
        ----------
        sigma_min : float, optional
            Lower noise level. Default is 0.02.
        sigma_max : float, optional
            Upper noise level. Default is 100.
        """
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def __call__(
        self, net, images, conditional_img=None, labels=None, augment_pipe=None
    ):
        """
        Compute the element-wise VE loss for a batch of images.

        Parameters
        ----------
        net : callable
            Denoiser invoked as
            ``net(x_noisy, sigma, conditional_img, labels, augment_labels=...)``.
        images : torch.Tensor
            Clean images; the first dimension is the batch dimension.
        conditional_img : optional
            Conditioning input forwarded unchanged to ``net``. Default is None.
        labels : optional
            Labels forwarded unchanged to ``net``. Default is None.
        augment_pipe : callable, optional
            If given, called as ``augment_pipe(images)`` and expected to
            return ``(augmented_images, augment_labels)``. Default is None.

        Returns
        -------
        torch.Tensor
            Unreduced weighted squared error
            ``(1 / sigma**2) * (net_output - y)**2``, where ``y`` is the
            (optionally augmented) clean batch.

        Notes
        -----
        The order of random draws is: per-sample uniform exponent, then the
        augmentation pipeline (if any), then the Gaussian noise.
        """
        num_samples = images.shape[0]
        # Per-sample exponent in [0, 1), shaped [B, 1, 1, 1] for broadcasting.
        exponent = torch.rand([num_samples, 1, 1, 1], device=images.device)
        # Geometric interpolation between sigma_min and sigma_max.
        sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** exponent)
        weight = 1 / sigma**2

        if augment_pipe is None:
            y, augment_labels = images, None
        else:
            y, augment_labels = augment_pipe(images)

        noise = torch.randn_like(y) * sigma
        denoised = net(
            y + noise, sigma, conditional_img, labels, augment_labels=augment_labels
        )
        return weight * ((denoised - y) ** 2)
# ----------------------------------------------------------------------------
# Improved loss function proposed in the paper "Elucidating the Design Space
# of Diffusion-Based Generative Models" (EDM).
class EDMLoss:
    """
    Denoising loss from the EDM paper (Karras et al., 2022).

    The noise level is drawn per sample from a log-normal distribution,
    ``log(sigma) ~ Normal(P_mean, P_std)``, the clean image is perturbed as
    ``x_t = x_0 + sigma * eps``, and the squared denoising error is weighted
    by ``(sigma**2 + sigma_data**2) / (sigma * sigma_data)**2``.

    Parameters
    ----------
    P_mean : float, optional
        Mean of ``log(sigma)``. Default is -1.2.
    P_std : float, optional
        Standard deviation of ``log(sigma)``. Default is 1.2.
    sigma_data : float, optional
        Data standard deviation used in the loss weight. Default is 1.0.

    Attributes
    ----------
    P_mean : float
        Stored mean of ``log(sigma)``.
    P_std : float
        Stored standard deviation of ``log(sigma)``.
    sigma_data : float
        Stored data standard deviation.

    References
    ----------
    - Karras et al., "Elucidating the Design Space of Diffusion-Based
      Generative Models", 2022.
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=1.0):
        """
        Store the sampling and weighting parameters.

        Parameters
        ----------
        P_mean : float, optional
            Mean of ``log(sigma)``. Default is -1.2.
        P_std : float, optional
            Standard deviation of ``log(sigma)``. Default is 1.2.
        sigma_data : float, optional
            Data standard deviation used in the weight. Default is 1.0.
        """
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data

    def __call__(
        self, net, images, conditional_img=None, labels=None, augment_pipe=None
    ):
        """
        Compute the element-wise EDM loss for a batch of images.

        Parameters
        ----------
        net : callable
            Denoiser invoked as
            ``net(x_noisy, sigma, conditional_img, labels, augment_labels=...)``.
        images : torch.Tensor
            Clean images; the first dimension is the batch dimension.
        conditional_img : optional
            Conditioning input forwarded unchanged to ``net``. Default is None.
        labels : optional
            Labels forwarded unchanged to ``net``. Default is None.
        augment_pipe : callable, optional
            If given, called as ``augment_pipe(images)`` and expected to
            return ``(augmented_images, augment_labels)``. Default is None.

        Returns
        -------
        torch.Tensor
            Unreduced weighted squared error
            ``weight * (net_output - y)**2``, where ``y`` is the (optionally
            augmented) clean batch.

        Notes
        -----
        The order of random draws is: per-sample normal for ``log(sigma)``,
        then the augmentation pipeline (if any), then the Gaussian noise.
        """
        num_samples = images.shape[0]
        # Standard-normal draw per sample, shaped [B, 1, 1, 1] for broadcasting.
        standard_normal = torch.randn([num_samples, 1, 1, 1], device=images.device)
        log_sigma = standard_normal * self.P_std + self.P_mean
        sigma = log_sigma.exp()
        weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2

        if augment_pipe is None:
            y, augment_labels = images, None
        else:
            y, augment_labels = augment_pipe(images)

        noise = torch.randn_like(y) * sigma
        denoised = net(
            y + noise, sigma, conditional_img, labels, augment_labels=augment_labels
        )
        return weight * ((denoised - y) ** 2)
# ----------------------------------------------------------------------------
# Loss function for UNet architectures.
class UnetLoss:
    """
    Supervised loss for UNets that predict an image directly.

    No diffusion noise process is involved: the network is evaluated on the
    input images and its output is compared against separately supplied
    target images with a standard PyTorch criterion.

    Parameters
    ----------
    loss_type : str, optional
        Criterion to use:
        - ``mse``: ``torch.nn.MSELoss``
        - ``l1``: ``torch.nn.L1Loss``
        - ``smooth_l1``: ``torch.nn.SmoothL1Loss``
        Default is ``mse``.
    reduction : str, optional
        Reduction passed to the criterion: ``mean``, ``sum`` or ``none``.
        Default is ``mean``.

    Attributes
    ----------
    loss_type : str
        Selected criterion name.
    reduction : str
        Selected reduction.
    loss_fn : torch.nn.Module
        Instantiated PyTorch criterion.

    Raises
    ------
    ValueError
        If ``loss_type`` or ``reduction`` is not one of the supported values.
    """

    # Reductions accepted by the PyTorch criteria used here.
    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, loss_type="mse", reduction="mean"):
        """
        Select and instantiate the criterion.

        Parameters
        ----------
        loss_type : str, optional
            One of ``mse``, ``l1``, ``smooth_l1``. Default is ``mse``.
        reduction : str, optional
            One of ``mean``, ``sum``, ``none``. Default is ``mean``.

        Raises
        ------
        ValueError
            If ``loss_type`` or ``reduction`` is not supported.
        """
        criteria = {
            "mse": torch.nn.MSELoss,
            "l1": torch.nn.L1Loss,
            "smooth_l1": torch.nn.SmoothL1Loss,
        }
        # Guard the lookup so unhashable inputs still raise ValueError.
        criterion_cls = criteria.get(loss_type) if isinstance(loss_type, str) else None
        if criterion_cls is None:
            raise ValueError(f"Unknown loss_type: {loss_type}")
        # PyTorch only rejects a bad reduction at the first forward call;
        # check here so misconfiguration surfaces at construction time.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"Unknown reduction: {reduction} (expected one of {self._REDUCTIONS})"
            )

        self.loss_type = loss_type
        self.reduction = reduction
        self.loss_fn = criterion_cls(reduction=reduction)

    def __call__(self, net, targets, images, labels=None, augment_pipe=None):
        """
        Compute the loss between the network prediction and ``targets``.

        Parameters
        ----------
        net : callable
            Model invoked as
            ``net(images, class_labels=labels, augment_labels=augment_labels)``.
        targets : torch.Tensor
            Reference images the prediction is compared against. Must be
            compatible in shape with the network output.
        images : torch.Tensor
            Input images fed to the network.
        labels : optional
            Labels forwarded to ``net`` as ``class_labels``. Default is None.
        augment_pipe : callable, optional
            If given, called as ``augment_pipe(images)`` and expected to
            return ``(augmented_images, augment_labels)``. Default is None.

        Returns
        -------
        torch.Tensor
            Output of the criterion: a scalar for ``mean`` / ``sum``
            reduction, an element-wise tensor for ``none``.

        Notes
        -----
        Augmentation is applied to ``images`` only; ``targets`` are used as
        given. Callers using a geometric augmentation must make sure the
        targets remain aligned with the augmented inputs.
        """
        if augment_pipe is None:
            augment_labels = None
        else:
            images, augment_labels = augment_pipe(images)

        prediction = net(images, class_labels=labels, augment_labels=augment_labels)
        # The prediction is scored against the targets, not the inputs.
        return self.loss_fn(prediction, targets)
# ----------------------------------------------------------------------------
# Unit tests
class TestLosses(unittest.TestCase):
    """Unit tests for diffusion models and loss functions."""

    def __init__(self, methodName="runTest", logger=None):
        """
        Create the test case.

        Parameters
        ----------
        methodName : str, optional
            Name of the test method to run, forwarded to ``unittest.TestCase``.
        logger : optional
            Object with an ``info`` method; progress messages are emitted
            only when it is truthy. Default is None.
        """
        super().__init__(methodName)
        self.logger = logger

    # ------------------------------------------------------------------
    # Private helpers shared by the tests below
    # ------------------------------------------------------------------
    def _log(self, message):
        """Send ``message`` to the logger when one was supplied."""
        if self.logger:
            self.logger.info(message)

    def _random_image(self, channels):
        """Return a random batch with ``channels`` channels on the test device."""
        return torch.randn(self.batch_size, channels, *self.img_resolution).to(
            self.device
        )

    def _random_labels(self):
        """Return random labels of shape (batch_size, label_dim)."""
        return torch.randn(self.batch_size, self.label_dim, device=self.device)

    def _make_precond(self, precond_cls, model_type, model_channels, channel_mult):
        """Build a preconditioned model taking image plus conditioning channels."""
        network = precond_cls(
            img_resolution=self.img_resolution,
            in_channels=self.in_channels + self.cond_channels,
            out_channels=self.out_channels,
            label_dim=self.label_dim,
            use_fp16=False,
            model_type=model_type,
            model_channels=model_channels,
            channel_mult=channel_mult,
        )
        return network.to(self.device)

    def _make_unet(self, in_channels):
        """Build a non-diffusion DhariwalUNet with ``in_channels`` inputs."""
        network = DhariwalUNet(
            img_resolution=self.img_resolution,
            in_channels=in_channels,
            out_channels=self.out_channels,
            label_dim=self.label_dim,
            diffusion_model=False,
        )
        return network.to(self.device)

    def _conditional_loss(self, precond_cls, model_type, loss_fn):
        """Run ``loss_fn`` on a fresh model and random batch; return (images, loss)."""
        model = self._make_precond(precond_cls, model_type, 64, [1, 2])
        images = self._random_image(self.in_channels)
        cond_img = self._random_image(self.cond_channels)
        labels = self._random_labels()
        loss = loss_fn(model, images, conditional_img=cond_img, labels=labels)
        return images, loss

    def _assert_elementwise_positive(self, loss, images):
        """Check the loss is per-element (image-shaped) with a positive mean."""
        self.assertEqual(loss.shape, images.shape)
        self.assertGreater(loss.mean().item(), 0)

    # ------------------------------------------------------------------
    # Fixtures
    # ------------------------------------------------------------------
    def setUp(self):
        """Set up test fixtures."""
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)
        self.batch_size = 2
        self.in_channels = 3
        self.cond_channels = 7
        self.out_channels = 3
        self.label_dim = 2
        # Rectangular resolution on purpose (height != width).
        self.img_resolution = (64, 128)
        if self.logger:
            self.logger.info(f"Test setup complete - using device: {self.device}")

    def tearDown(self):
        """Clean up after tests."""
        self._log("Loss function tests completed successfully")

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------
    def test_vp_loss(self):
        """Test VP loss function."""
        self._log("Testing VPLoss")
        images, loss = self._conditional_loss(VPPrecond, "SongUNet", VPLoss())
        self._assert_elementwise_positive(loss, images)
        if self.logger:
            self.logger.info(
                f"✅ VPLoss test passed - loss shape: {loss.shape}, mean: {loss.mean().item():.4f}"
            )

    def test_ve_loss(self):
        """Test VE loss function."""
        self._log("Testing VELoss")
        images, loss = self._conditional_loss(VEPrecond, "SongUNet", VELoss())
        self._assert_elementwise_positive(loss, images)
        if self.logger:
            self.logger.info(
                f"✅ VELoss test passed - loss shape: {loss.shape}, mean: {loss.mean().item():.4f}"
            )

    def test_edm_loss(self):
        """Test EDM loss function."""
        self._log("Testing EDMLoss")
        images, loss = self._conditional_loss(EDMPrecond, "DhariwalUNet", EDMLoss())
        self._assert_elementwise_positive(loss, images)
        if self.logger:
            self.logger.info(
                f"✅ EDMLoss test passed - loss shape: {loss.shape}, mean: {loss.mean().item():.4f}"
            )

    def test_unet_loss(self):
        """Test UnetLoss function."""
        self._log("Testing UnetLoss")
        # Plain UNet (no diffusion, no extra conditioning channels).
        input_channels = 5
        model = self._make_unet(input_channels)
        loss_fn = UnetLoss()

        images = self._random_image(input_channels)
        targets = self._random_image(self.out_channels)
        labels = self._random_labels()

        loss = loss_fn(model, targets, images, labels=labels)

        # Default reduction is "mean", so the result is a 0-d tensor.
        self.assertEqual(loss.shape, ())
        self.assertGreater(loss.item(), 0)
        if self.logger:
            self.logger.info(f"✅ UnetLoss test passed - loss value: {loss.item():.4f}")

    def test_loss_comparison(self):
        """Compare different loss functions on the same model."""
        self._log("Testing loss function comparison")

        unet_model = self._make_unet(self.in_channels)
        model = self._make_precond(VPPrecond, "SongUNet", 64, [1, 2])

        vp_loss = VPLoss()
        ve_loss = VELoss()
        edm_loss = EDMLoss()
        unet_loss_fn = UnetLoss()

        images = self._random_image(self.in_channels)
        targets = self._random_image(self.out_channels)
        cond_img = self._random_image(self.cond_channels)
        labels = self._random_labels()

        vp_loss_val = vp_loss(model, images, conditional_img=cond_img, labels=labels)
        ve_loss_val = ve_loss(model, images, conditional_img=cond_img, labels=labels)
        edm_loss_val = edm_loss(model, images, conditional_img=cond_img, labels=labels)
        unet_loss_val = unet_loss_fn(unet_model, targets, images, labels=labels)

        # The three diffusion losses are element-wise and must agree in shape.
        self.assertEqual(vp_loss_val.shape, ve_loss_val.shape)
        self.assertEqual(ve_loss_val.shape, edm_loss_val.shape)
        for diffusion_loss in (vp_loss_val, ve_loss_val, edm_loss_val):
            self.assertGreater(diffusion_loss.mean().item(), 0)
        self.assertGreater(unet_loss_val.item(), 0)

        if self.logger:
            self.logger.info("✅ Loss comparison test passed")
            self.logger.info(f" └── VPLoss mean: {vp_loss_val.mean().item():.4f}")
            self.logger.info(f" └── VELoss mean: {ve_loss_val.mean().item():.4f}")
            self.logger.info(f" └── EDMLoss mean: {edm_loss_val.mean().item():.4f}")
            self.logger.info(f" └── UnetLoss (scalar): {unet_loss_val.item():.4f}")

    def test_loss_with_augmentation(self):
        """Test loss functions with data augmentation."""
        self._log("Testing loss with augmentation")

        # Stand-in augmentation pipeline returning (images, augment_labels).
        augmented_images = self._random_image(self.in_channels)
        augmented_labels = torch.randint(
            0, 2, (self.batch_size, 1), device=self.device
        )
        augment_pipe = Mock()
        augment_pipe.return_value = (augmented_images, augmented_labels)

        model = self._make_precond(VPPrecond, "SongUNet", 64, [1, 2])
        loss_fn = VPLoss()

        images = self._random_image(self.in_channels)
        cond_img = self._random_image(self.cond_channels)
        labels = self._random_labels()

        loss = loss_fn(
            model,
            images,
            conditional_img=cond_img,
            labels=labels,
            augment_pipe=augment_pipe,
        )

        self._assert_elementwise_positive(loss, images)
        if self.logger:
            self.logger.info(
                f"✅ Loss with augmentation test passed - loss shape: {loss.shape}"
            )

    def test_loss_gradients(self):
        """Test that loss computation supports gradient computation."""
        self._log("Testing loss gradients")

        # Smaller network than in the other tests.
        model = self._make_precond(VPPrecond, "SongUNet", 32, [1])
        loss_fn = VPLoss()

        image_shape = (self.batch_size, self.in_channels, *self.img_resolution)
        cond_shape = (self.batch_size, self.cond_channels, *self.img_resolution)
        # The input is a leaf tensor that tracks gradients.
        images = torch.randn(*image_shape, device=self.device, requires_grad=True)
        cond_img = torch.randn(*cond_shape, device=self.device)
        labels = self._random_labels()
        torch.cuda.empty_cache()

        loss = loss_fn(model, images, conditional_img=cond_img, labels=labels)
        total_loss = loss.mean()
        total_loss.backward()

        # Gradients must reach the input and be finite.
        self.assertIsNotNone(images.grad)
        self.assertTrue(torch.isfinite(images.grad).all())
        self._log("✅ Loss gradients test passed - gradients computed successfully")
# ----------------------------------------------------------------------------