# Source code for IPSL_AID.model

# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Kazem Ardaneh
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/

import torch
from IPSL_AID.utils import EasyDict
import unittest

# Import all diffusion components
from IPSL_AID.networks import VPPrecond, VEPrecond, EDMPrecond, SongUNet, DhariwalUNet
from IPSL_AID.loss import VPLoss, VELoss, EDMLoss, UnetLoss


# ============================================================================
# Model + Loss Loader
# ============================================================================


def load_model_and_loss(opts, logger=None, device="cpu"):
    """
    Load a diffusion model or U-Net with corresponding loss function.

    This function initializes and configures a generative model (diffusion
    or direct U-Net) along with its corresponding loss function based on
    the provided options. It supports multiple architectures and
    preconditioning schemes.

    Parameters
    ----------
    opts : EasyDict or dict
        Configuration dictionary containing model parameters. Must include:
        - arch : str
            Architecture type: 'ddpmpp', 'ncsnpp', or 'adm'.
        - precond : str
            Preconditioning type: 'vp', 've', 'edm', or 'unet'.
        - img_resolution : int or tuple
            Image resolution (height, width).
        - in_channels : int
            Number of input channels.
        - out_channels : int
            Number of output channels.
        - label_dim : int
            Dimension of label conditioning (0 for unconditional).
        - use_fp16 : bool
            Whether to use mixed precision (FP16). Required only for
            diffusion preconditioners ('vp', 've', 'edm').
        - cond_channels : int, optional
            Extra conditioning channels concatenated to the input
            (diffusion models only).
        - model_kwargs : dict, optional
            Additional model-specific parameters to override defaults.
    logger : logging.Logger, optional
        Logger instance for output messages. If None, uses print().
        Default is None.
    device : str or torch.device, optional
        Device to load the model onto ('cpu', 'cuda', etc.).
        Default is 'cpu'.

    Returns
    -------
    model : torch.nn.Module
        Initialized model instance (preconditioner or U-Net).
    loss_fn : torch.nn.Module or callable
        Corresponding loss function for the model.

    Raises
    ------
    ValueError
        If an invalid architecture or preconditioner type is specified.

    Notes
    -----
    - The function supports three main architectures:
        * DDPM++ (Song et al., 2020) with VP preconditioning
        * NCSN++ (Song et al., 2020) with VE preconditioning
        * ADM (Dhariwal & Nichol, 2021) with EDM preconditioning
    - When precond='unet', uses a direct U-Net without diffusion
      preconditioning.
    - Model parameters are counted and logged for transparency.
    - Default hyperparameters are provided for each architecture but can
      be overridden via opts.model_kwargs.
    """
    log = logger.info if logger else print
    opts = EasyDict(opts)
    # 'unet' means a plain regression U-Net; everything else is a diffusion
    # preconditioner wrapped around the backbone.
    diffusion_model = opts.precond != "unet"
    arch = opts.arch.lower()

    # --------------------------------------------------------
    # Preconditioner + matching loss
    # --------------------------------------------------------
    if opts.precond == "vp":
        precond_class = VPPrecond
        loss_class = VPLoss
        log("Using VP preconditioner & VPLoss")
    elif opts.precond == "ve":
        precond_class = VEPrecond
        loss_class = VELoss
        log("Using VE preconditioner & VELoss")
    elif opts.precond == "edm":
        precond_class = EDMPrecond
        loss_class = EDMLoss
        log("Using EDM preconditioner & EDMLoss")
    elif opts.precond == "unet":
        if arch == "adm":
            precond_class = DhariwalUNet  # Direct U-Net without preconditioning
        elif arch in ["ddpmpp", "ncsnpp"]:
            precond_class = SongUNet  # Direct U-Net without preconditioning
        else:
            raise ValueError(f"❌ Invalid arch '{opts.arch}' for direct U-Net")
        loss_class = UnetLoss
        log("Using direct U-Net & UnetLoss")
    else:
        raise ValueError(f"❌ Invalid opts.precond '{opts.precond}'")

    # --------------------------------------------------------
    # Architecture network kwargs
    # --------------------------------------------------------
    network_kwargs = EasyDict()
    if arch == "ddpmpp":
        network_kwargs.update(
            dict(
                model_type="SongUNet",
                embedding_type="positional",
                encoder_type="standard",
                decoder_type="standard",
                channel_mult_noise=1,
                resample_filter=[1, 1],
                model_channels=128,
                channel_mult=[2, 2, 2],
            )
        )
        log("Architecture DDPM++ / SongUNet selected")
    elif arch == "ncsnpp":
        network_kwargs.update(
            dict(
                model_type="SongUNet",
                embedding_type="fourier",
                encoder_type="residual",
                decoder_type="standard",
                channel_mult_noise=2,
                resample_filter=[1, 3, 3, 1],
                model_channels=128,
                channel_mult=[2, 2, 2],
            )
        )
        log("Architecture NCSN++ / SongUNet selected")
    elif arch == "adm":
        if diffusion_model:
            network_kwargs.update(
                dict(
                    model_type="DhariwalUNet",
                    model_channels=128,
                    channel_mult=[1, 2, 3, 4],
                    num_blocks=2,
                )
            )
            log("Architecture ADM / DhariwalUNet selected")
        else:
            # Direct U-Net: the class is instantiated directly, so no
            # model_type key; diffusion_model=False is forwarded to the net.
            network_kwargs.update(
                dict(
                    model_channels=128,
                    channel_mult=[1, 2, 3, 4],
                    num_blocks=2,
                    diffusion_model=False,
                )
            )
            log("Architecture ADM / DhariwalUNet selected for direct U-Net")
    else:
        raise ValueError(f"❌ Invalid opts.arch '{opts.arch}'")

    # Allow overrides from opts.model_kwargs
    if hasattr(opts, "model_kwargs"):
        log("Overriding with user model_kwargs")
        network_kwargs.update(opts.model_kwargs)

    # --------------------------------------------------------
    # Create model
    # --------------------------------------------------------
    log("Instantiating model...")
    if diffusion_model:
        log("Diffusion model enabled")
        # Conditioning channels (if any) are concatenated to the input.
        total_in = opts.in_channels + (
            opts.cond_channels if "cond_channels" in opts else 0
        )
    else:
        log("Diffusion model disabled, direct U-Net, no preconditioning")

    if diffusion_model:
        model = precond_class(
            img_resolution=opts.img_resolution,
            in_channels=total_in,
            out_channels=opts.out_channels,
            label_dim=opts.label_dim,
            use_fp16=opts.use_fp16,
            **network_kwargs,
        )
    else:
        model = precond_class(
            img_resolution=opts.img_resolution,
            in_channels=opts.in_channels,
            out_channels=opts.out_channels,
            label_dim=opts.label_dim,
            **network_kwargs,
        )
    model = model.to(device)

    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # --------------------------------------------------------
    # Comprehensive Model Information Logging
    # --------------------------------------------------------
    log("Model Summary:")
    log(f" └── Model Type: {type(model).__name__}")
    log(f" └── Preconditioner: {opts.precond.upper()}")
    log(f" └── Architecture: {opts.arch.upper()}")
    if diffusion_model:
        log(
            f" └── Input Channels: {total_in} (base: {opts.in_channels} + cond: {total_in - opts.in_channels})"
        )
    else:
        log(f" └── Input Channels: {opts.in_channels}")
    log(f" └── Output Channels: {opts.out_channels}")
    log(f" └── Label Dimension: {opts.label_dim}")
    log(f" └── Image Resolution: {opts.img_resolution}")
    if diffusion_model:
        log(f" └── FP16 Enabled: {opts.use_fp16}")
    else:
        log(" └── FP16 Enabled: N/A for direct U-Net")
    # NOTE: trainable count now formatted with ',' like the total count.
    log(f" └── Model Parameters - Total: {total_num:,}, Trainable: {trainable_num:,}")

    # Log network architecture details
    log("Network Architecture:")
    for key, value in network_kwargs.items():
        log(f" └── {key}: {value}")

    # Log device information (read back from the parameters rather than
    # trusting the requested device; do not shadow the `device` argument).
    param_device = next(model.parameters()).device
    log(f"Device: {param_device}")

    # Log model dtype information
    dtype = next(model.parameters()).dtype
    log(f"Model Data Type: {dtype}")

    # --------------------------------------------------------
    # Loss function instance
    # --------------------------------------------------------
    loss_fn = loss_class()
    log(f"Loss function instantiated: {loss_class.__name__}")
    log(f" └── Loss Type: {opts.precond.upper()} Diffusion Loss")

    return model, loss_fn
# ============================================================================
# Unit Tests
# ============================================================================
class TestModelLoader(unittest.TestCase):
    """Unit tests for model and loss loader."""

    def __init__(self, methodName="runTest", logger=None):
        """Allow a logger to be injected so tests can report progress."""
        super().__init__(methodName)
        self.logger = logger

    def setUp(self):
        """Set up test fixtures."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = 2
        self.in_channels = 3
        self.cond_channels = 7
        self.out_channels = 3
        self.label_dim = 4
        self.img_resolution = (144, 360)
        if self.logger:
            self.logger.info(f"Test setup complete - using device: {self.device}")

    def _run_diffusion_combo(self, arch, precond, resolution=None, check_loss=True):
        """Shared driver for diffusion arch/precond combination tests.

        Builds the opts dict, loads model + loss, runs a conditional
        forward pass, asserts the output shape, and optionally checks
        the per-pixel loss shape.

        Parameters
        ----------
        arch : str
            Architecture name passed to the loader.
        precond : str
            Preconditioner name passed to the loader.
        resolution : tuple, optional
            (H, W) override; defaults to self.img_resolution.
        check_loss : bool, optional
            Whether to compute and shape-check the loss. Default True.

        Returns
        -------
        output : torch.Tensor
            Model output from the forward pass.
        loss : torch.Tensor or None
            Loss tensor, or None when check_loss is False.
        """
        resolution = self.img_resolution if resolution is None else resolution
        opts = EasyDict(
            {
                "arch": arch,
                "precond": precond,
                "img_resolution": resolution,
                "in_channels": self.in_channels,
                "cond_channels": self.cond_channels,
                "out_channels": self.out_channels,
                "label_dim": self.label_dim,
                "use_fp16": False,
            }
        )
        model, loss_fn = load_model_and_loss(opts, self.logger, self.device)

        # Forward pass with conditioning image and class labels.
        x = torch.randn(self.batch_size, self.in_channels, *resolution).to(
            self.device
        )
        cond_img = torch.randn(
            self.batch_size, self.cond_channels, *resolution
        ).to(self.device)
        labels = torch.randint(0, self.label_dim, (self.batch_size, self.label_dim)).to(
            self.device
        )
        sigma = torch.tensor([0.1, 0.5], device=self.device)
        with torch.no_grad():
            output = model(x, sigma, condition_img=cond_img, class_labels=labels)
        self.assertEqual(output.shape, x.shape)

        loss = None
        if check_loss:
            loss = loss_fn(model, x, conditional_img=cond_img, labels=labels)
            self.assertEqual(loss.shape, x.shape)
        return output, loss

    def test_ddpmpp_vp_combination(self):
        """Test DDPM++ architecture with VP preconditioner."""
        if self.logger:
            self.logger.info("Testing DDPM++ + VP combination")
        output, loss = self._run_diffusion_combo("ddpmpp", "vp")
        if self.logger:
            self.logger.info(
                f"✅ DDPM++ + VP test passed - output shape: {output.shape}, loss shape: {loss.shape}"
            )

    def test_ncsnpp_ve_combination(self):
        """Test NCSN++ architecture with VE preconditioner."""
        if self.logger:
            self.logger.info("Testing NCSN++ + VE combination")
        output, loss = self._run_diffusion_combo("ncsnpp", "ve")
        if self.logger:
            self.logger.info(
                f"✅ NCSN++ + VE test passed - output shape: {output.shape}, loss shape: {loss.shape}"
            )

    def test_adm_edm_combination(self):
        """Test ADM architecture with EDM preconditioner."""
        if self.logger:
            self.logger.info("Testing ADM + EDM combination")
        output, loss = self._run_diffusion_combo("adm", "edm")
        if self.logger:
            self.logger.info(
                f"✅ ADM + EDM test passed - output shape: {output.shape}, loss shape: {loss.shape}"
            )

    def test_adm_unet_combination(self):
        """Using ADM architecture as direct U-Net without preconditioning."""
        if self.logger:
            self.logger.info("Testing ADM + UNet combination")
        input_channels = 5  # Example input channels
        opts = EasyDict(
            {
                "arch": "adm",
                "precond": "unet",
                "img_resolution": self.img_resolution,
                "in_channels": input_channels,
                "out_channels": self.out_channels,
                "label_dim": self.label_dim,
            }
        )
        model, loss_fn = load_model_and_loss(opts, self.logger, self.device)

        # Forward pass: direct regression, no sigma / conditioning image.
        x = torch.randn(self.batch_size, input_channels, *self.img_resolution).to(
            self.device
        )
        y = torch.randn(self.batch_size, self.out_channels, *self.img_resolution).to(
            self.device
        )
        labels = torch.randn(self.batch_size, self.label_dim, device=self.device)
        with torch.no_grad():
            output = model(x, class_labels=labels)
        self.assertEqual(output.shape, y.shape)

        # UnetLoss reduces to a scalar, unlike the per-pixel diffusion losses.
        loss = loss_fn(model, y, x, labels=labels)
        self.assertEqual(loss.shape, ())
        if self.logger:
            self.logger.info(
                f"✅ ADM + UNet test passed - output shape: {output.shape}, loss shape: {loss.shape}"
            )

    def test_rectangular_resolution(self):
        """Test loader with rectangular resolution."""
        if self.logger:
            self.logger.info("Testing rectangular resolution")
        output, _ = self._run_diffusion_combo(
            "ddpmpp", "vp", resolution=(128, 64), check_loss=False
        )
        if self.logger:
            self.logger.info(
                f"✅ Rectangular resolution test passed - output shape: {output.shape}"
            )

    def test_model_kwargs_override(self):
        """Test that model_kwargs can override default settings."""
        if self.logger:
            self.logger.info("Testing model_kwargs override")
        opts = EasyDict(
            {
                "arch": "ddpmpp",
                "precond": "vp",
                "img_resolution": self.img_resolution,
                "in_channels": self.in_channels,
                "cond_channels": self.cond_channels,
                "out_channels": self.out_channels,
                "label_dim": self.label_dim,
                "use_fp16": False,
                "model_kwargs": {
                    "model_channels": 64,  # Override default
                    "channel_mult": [1, 2],  # Override default
                },
            }
        )
        model, loss_fn = load_model_and_loss(opts, self.logger, self.device)

        # Verify the model has the overridden parameters
        total_params = sum(p.numel() for p in model.parameters())
        self.assertLess(
            total_params, 10_000_000
        )  # Should be smaller with overridden settings
        if self.logger:
            self.logger.info(
                f"✅ Model kwargs override test passed - total params: {total_params:,}"
            )

    def test_no_conditional_channels(self):
        """Test loader without conditional channels."""
        if self.logger:
            self.logger.info("Testing without conditional channels")
        opts = EasyDict(
            {
                "arch": "ddpmpp",
                "precond": "vp",
                "img_resolution": self.img_resolution,
                "in_channels": self.in_channels,
                "out_channels": self.out_channels,
                # No cond_channels specified
                "label_dim": self.label_dim,
                "use_fp16": False,
            }
        )
        model, loss_fn = load_model_and_loss(opts, self.logger, self.device)

        # Test forward pass without conditional image
        x = torch.randn(self.batch_size, self.in_channels, *self.img_resolution).to(
            self.device
        )
        labels = torch.randint(0, self.label_dim, (self.batch_size, self.label_dim)).to(
            self.device
        )
        sigma = torch.tensor([0.1, 0.5], device=self.device)
        with torch.no_grad():
            output = model(x, sigma, class_labels=labels)  # No condition_img
        self.assertEqual(output.shape, x.shape)
        if self.logger:
            self.logger.info(
                f"✅ No conditional channels test passed - output shape: {output.shape}"
            )

    def test_invalid_combinations(self):
        """Test that invalid combinations raise appropriate errors."""
        if self.logger:
            self.logger.info("Testing invalid combinations")

        # Test invalid architecture
        with self.assertRaises(ValueError):
            opts = EasyDict(
                {
                    "arch": "invalid_arch",
                    "precond": "vp",
                    "img_resolution": self.img_resolution,
                    "in_channels": self.in_channels,
                    "out_channels": self.out_channels,
                    "label_dim": self.label_dim,
                    "use_fp16": False,
                }
            )
            load_model_and_loss(opts, self.logger, self.device)

        # Test invalid preconditioner
        with self.assertRaises(ValueError):
            opts = EasyDict(
                {
                    "arch": "ddpmpp",
                    "precond": "invalid_precond",
                    "img_resolution": self.img_resolution,
                    "in_channels": self.in_channels,
                    "out_channels": self.out_channels,
                    "label_dim": self.label_dim,
                    "use_fp16": False,
                }
            )
            load_model_and_loss(opts, self.logger, self.device)
        if self.logger:
            self.logger.info("✅ Invalid combinations test passed")

    def tearDown(self):
        """Clean up after tests."""
        if self.logger:
            self.logger.info("Model tests completed successfully")
# ----------------------------------------------------------------------------