# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Kazem Ardaneh
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/
from IPSL_AID.utils import EasyDict
# Import all diffusion components
from IPSL_AID.networks import VPPrecond, VEPrecond, EDMPrecond, SongUNet, DhariwalUNet
from IPSL_AID.loss import VPLoss, VELoss, EDMLoss, UnetLoss
# ============================================================================
# Model + Loss Loader
# ============================================================================
[docs]
def load_model_and_loss(opts, logger=None, device="cpu"):
    """
    Build a diffusion model or direct U-Net together with its loss function.

    The preconditioner (or bare U-Net class) and the loss class are chosen
    from ``opts.precond``; the default network hyperparameters are chosen
    from ``opts.arch``. The two options are selected independently of each
    other. The model is instantiated, moved to ``device``, and a summary of
    its configuration is written through ``logger.info`` (or ``print``).

    Parameters
    ----------
    opts : EasyDict or dict
        Configuration dictionary containing model parameters. Must include:

        - arch : str
            Architecture type: 'ddpmpp', 'ncsnpp', or 'adm'
            (matched case-insensitively).
        - precond : str
            Preconditioning type: 'vp', 've', 'edm', or 'unet'
            (matched case-insensitively).
        - img_resolution : int or tuple
            Image resolution, forwarded unchanged to the model constructor.
        - in_channels : int
            Number of input channels.
        - out_channels : int
            Number of output channels.
        - label_dim : int
            Dimension of label conditioning, forwarded to the model
            constructor.
        - use_fp16 : bool
            Mixed-precision flag. Only read and forwarded when
            ``precond`` is not 'unet'.

        May include:

        - cond_channels : int
            Extra conditioning channels added to ``in_channels`` for
            diffusion models. Not used when ``precond`` is 'unet'.
        - model_kwargs : dict
            Additional model-specific parameters that override the
            architecture defaults. A missing key or a value of ``None``
            means "no overrides".
    logger : logging.Logger, optional
        Logger instance for output messages. If None, uses print().
        Default is None.
    device : str or torch.device, optional
        Device to load the model onto ('cpu', 'cuda', etc.).
        Default is 'cpu'.

    Returns
    -------
    model : torch.nn.Module
        Initialized model instance (preconditioner or U-Net), already moved
        to ``device``.
    loss_fn : VPLoss, VELoss, EDMLoss or UnetLoss
        Loss instance matching the selected preconditioner, constructed
        with no arguments.

    Raises
    ------
    ValueError
        If an invalid architecture or preconditioner type is specified.

    Notes
    -----
    - Default hyperparameters exist for three architectures:
        * 'ddpmpp' -> SongUNet with positional embedding
        * 'ncsnpp' -> SongUNet with Fourier embedding
        * 'adm'    -> DhariwalUNet
    - When precond='unet', the U-Net class is instantiated directly (no
      preconditioner wrapper) and receives ``diffusion_model=False``.
    - Total and trainable parameter counts are logged.
    """
    log = logger.info if logger else print
    opts = EasyDict(opts)

    # Normalise both selectors the same way so 'VP' and 'vp' are equivalent,
    # matching the case-insensitive handling of the architecture name.
    precond = opts.precond.lower()
    arch = opts.arch.lower()
    diffusion_model = precond != "unet"

    # --------------------------------------------------------
    # Preconditioner + matching loss
    # --------------------------------------------------------
    diffusion_choices = {
        "vp": (VPPrecond, VPLoss, "Using VP preconditioner & VPLoss"),
        "ve": (VEPrecond, VELoss, "Using VE preconditioner & VELoss"),
        "edm": (EDMPrecond, EDMLoss, "Using EDM preconditioner & EDMLoss"),
    }
    if precond in diffusion_choices:
        precond_class, loss_class, message = diffusion_choices[precond]
        log(message)
    elif precond == "unet":
        # Direct U-Net without preconditioning: the network class itself is
        # instantiated instead of a preconditioner wrapper.
        if arch == "adm":
            precond_class = DhariwalUNet
        elif arch in ("ddpmpp", "ncsnpp"):
            precond_class = SongUNet
        else:
            raise ValueError(f"❌ Invalid arch '{opts.arch}' for direct U-Net")
        loss_class = UnetLoss
        log("Using direct U-Net & UnetLoss")
    else:
        raise ValueError(f"❌ Invalid opts.precond '{opts.precond}'")

    # --------------------------------------------------------
    # Architecture network kwargs
    # --------------------------------------------------------
    arch_defaults = _default_network_kwargs(arch, diffusion_model)
    if arch_defaults is None:
        raise ValueError(f"❌ Invalid opts.arch '{opts.arch}'")
    defaults, arch_label = arch_defaults

    network_kwargs = EasyDict()
    network_kwargs.update(defaults)
    if diffusion_model:
        log(f"Architecture {arch_label} selected")
    else:
        log(f"Architecture {arch_label} selected for direct U-Net")

    # Allow overrides from opts.model_kwargs. An explicit None is treated
    # like a missing key instead of crashing inside dict.update().
    user_kwargs = getattr(opts, "model_kwargs", None)
    if user_kwargs is not None:
        log("Overriding with user model_kwargs")
        network_kwargs.update(user_kwargs)

    # --------------------------------------------------------
    # Create model
    # --------------------------------------------------------
    log("Instantiating model...")
    if diffusion_model:
        log("Diffusion model enabled")
        cond_channels = opts.cond_channels if "cond_channels" in opts else 0
        total_in = opts.in_channels + cond_channels
        input_channels_desc = (
            f"{total_in} (base: {opts.in_channels} "
            f"+ cond: {total_in - opts.in_channels})"
        )
        log(f"Total input channels calculated: {input_channels_desc}")
        model = precond_class(
            img_resolution=opts.img_resolution,
            in_channels=total_in,
            out_channels=opts.out_channels,
            label_dim=opts.label_dim,
            use_fp16=opts.use_fp16,
            **network_kwargs,
        )
    else:
        log("Diffusion model disabled, direct U-Net, no preconditioning")
        input_channels_desc = f"{opts.in_channels}"
        # The direct U-Net receives neither cond_channels nor use_fp16.
        model = precond_class(
            img_resolution=opts.img_resolution,
            in_channels=opts.in_channels,
            out_channels=opts.out_channels,
            label_dim=opts.label_dim,
            **network_kwargs,
        )

    model = model.to(device)
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # --------------------------------------------------------
    # Comprehensive Model Information Logging
    # --------------------------------------------------------
    log("Model Summary:")
    log(f" └── Model Type: {type(model).__name__}")
    log(f" └── Preconditioner: {opts.precond.upper()}")
    log(f" └── Architecture: {opts.arch.upper()}")
    log(f" └── Input Channels: {input_channels_desc}")
    log(f" └── Output Channels: {opts.out_channels}")
    log(f" └── Label Dimension: {opts.label_dim}")
    log(f" └── Image Resolution: {opts.img_resolution}")
    if diffusion_model:
        log(f" └── FP16 Enabled: {opts.use_fp16}")
    else:
        log(" └── FP16 Enabled: N/A for direct U-Net")
    # Both counts use a thousands separator so they are directly comparable.
    log(
        f" └── Model Parameters - Total: {total_num:,}, "
        f"Trainable: {trainable_num:,}"
    )

    # Log network architecture details
    log("Network Architecture:")
    for key, value in network_kwargs.items():
        log(f" └── {key}: {value}")

    # Report where the parameters actually live and their dtype, read from
    # the first parameter; the `device` argument itself is left untouched.
    first_param = next(model.parameters())
    log(f"Device: {first_param.device}")
    log(f"Model Data Type: {first_param.dtype}")

    # --------------------------------------------------------
    # Loss function instance
    # --------------------------------------------------------
    loss_fn = loss_class()
    log(f"Loss function instantiated: {loss_class.__name__}")
    if diffusion_model:
        log(f" └── Loss Type: {opts.precond.upper()} Diffusion Loss")
    else:
        # The direct U-Net path is not a diffusion loss; say so explicitly.
        log(" └── Loss Type: Direct U-Net Loss")
    return model, loss_fn


def _default_network_kwargs(arch, diffusion_model):
    """
    Return the default network kwargs and a display label for ``arch``.

    Parameters
    ----------
    arch : str
        Lower-cased architecture name: 'ddpmpp', 'ncsnpp', or 'adm'.
    diffusion_model : bool
        True when the network is wrapped by a preconditioner, in which case
        the kwargs start with ``model_type``. False for a direct U-Net, in
        which case the kwargs end with ``diffusion_model=False``.

    Returns
    -------
    tuple of (dict, str) or None
        ``(kwargs, label)`` where ``label`` looks like 'DDPM++ / SongUNet',
        or None when ``arch`` is not recognised. A fresh dict (with fresh
        lists) is built on every call so callers may mutate it freely.
    """
    if arch == "ddpmpp":
        family, model_type = "DDPM++", "SongUNet"
        shared = dict(
            embedding_type="positional",
            encoder_type="standard",
            decoder_type="standard",
            channel_mult_noise=1,
            resample_filter=[1, 1],
            model_channels=128,
            channel_mult=[2, 2, 2],
        )
    elif arch == "ncsnpp":
        family, model_type = "NCSN++", "SongUNet"
        shared = dict(
            embedding_type="fourier",
            encoder_type="residual",
            decoder_type="standard",
            channel_mult_noise=2,
            resample_filter=[1, 3, 3, 1],
            model_channels=128,
            channel_mult=[2, 2, 2],
        )
    elif arch == "adm":
        family, model_type = "ADM", "DhariwalUNet"
        shared = dict(
            model_channels=128,
            channel_mult=[1, 2, 3, 4],
            num_blocks=2,
        )
    else:
        return None

    if diffusion_model:
        # The preconditioner is told which network class to build.
        kwargs = {"model_type": model_type, **shared}
    else:
        # The network class is built directly and told it is not part of a
        # diffusion model.
        kwargs = {**shared, "diffusion_model": False}
    return kwargs, f"{family} / {model_type}"