Source code for IPSL_AID.evaluater

# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Kazem Ardaneh, Kishanthan Kingston
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/

# ruff: noqa: E731
import numpy as np
import torch
import pandas as pd
from tqdm import tqdm
from IPSL_AID.diagnostics import (
    plot_validation_hexbin,
    plot_comparison_hexbin,
    plot_validation_pdfs,
    plot_power_spectra,
    plot_qq_quantiles,
    plot_surface,
    plot_zoom_comparison,
    plot_MAE_map,
    plot_error_map,
    plot_metrics_heatmap,
    plot_validation_mvcorr,
    plot_validation_mvcorr_space,
    plot_temporal_series_comparison,
)


[docs] class MetricTracker: """ A utility class for tracking and computing statistics of metric values. This class maintains a running average of metric values and provides methods to compute mean and root mean squared values. Attributes ---------- value : float Cumulative weighted sum of metric values count : int Total number of samples processed Examples -------- >>> tracker = MetricTracker() >>> tracker.update(10.0, 5) # value=10.0, count=5 samples >>> tracker.update(20.0, 3) # value=20.0, count=3 samples >>> print(tracker.getmean()) # (10*5 + 20*3) / (5+3) = 110/8 = 13.75 13.75 >>> print(tracker.getsqrtmean()) # sqrt(13.75) 3.7080992435478315 """
[docs] def __init__(self): """ Initialize MetricTracker with zero values. """ self.reset()
[docs] def reset(self): """ Reset all tracked values to zero. Returns ------- None """ self.value = 0.0 self.count = 0 self.value_sq = 0.0
[docs] def update(self, value, count): """ Update the tracker with new metric values. Parameters ---------- value : float The metric value to add count : int Number of samples this value represents (weight) Returns ------- None """ self.count += count self.value += value * count self.value_sq += (value**2) * count
[docs] def getmean(self): """ Calculate the mean of all tracked values. Returns ------- float Weighted mean of all values: total_value / total_count Raises ------ ZeroDivisionError If no values have been added (count == 0) """ if self.count == 0: raise ZeroDivisionError("Cannot compute mean with zero samples") return self.value / self.count
[docs] def getstd(self): """ Calculate the standard deviation of all tracked values. Returns ------- float Weighted standard deviation of all values: sqrt(E(x^2) - (E(x))^2) Raises ------ ZeroDivisionError If no values have been added (count == 0) """ if self.count == 0: raise ZeroDivisionError("Cannot compute std with zero samples") mean = self.getmean() variance = self.value_sq / self.count - mean**2 return np.sqrt(max(variance, 0.0)) # numerical safety
[docs] def getsqrtmean(self): """ Calculate the square root of the mean of all tracked values. Returns ------- float Square root of the weighted mean: sqrt(total_value / total_count) Raises ------ ZeroDivisionError If no values have been added (count == 0) """ return np.sqrt(self.getmean())
[docs] def mae_all(pred, true): """ Calculate Mean Absolute Error (MAE) between predicted and true values. Computes the MAE metric and returns both the number of elements and the mean absolute error value. Parameters ---------- pred : torch.Tensor Predicted values from the model true : torch.Tensor Ground truth values Returns ------- tuple (num_elements, mae_value) where: - num_elements (int): Total number of elements in the tensors - mae_value (torch.Tensor): Mean absolute error value Examples -------- >>> pred = torch.tensor([1.0, 2.0, 3.0]) >>> true = torch.tensor([1.1, 1.9, 3.2]) >>> num_elements, mae = mae_all(pred, true) >>> print(f"MAE: {mae.item():.4f}, Elements: {num_elements}") MAE: 0.1333, Elements: 3 Notes ----- The MAE is calculated as: mean(abs(pred - true)) This function is useful for tracking metrics with MetricTracker """ num_elements = pred.numel() mae_value = torch.mean(torch.abs(pred - true)) return num_elements, mae_value
[docs] def nmae_all(pred, true, eps=1e-8): """ Normalized Mean Absolute Error (NMAE). NMAE = MAE(pred, true) / mean(abs(true)) Computes the NMAE metric and returns both the number of elements and the normalized mean absolute error value. Parameters ---------- pred : torch.Tensor Predicted values from the model true : torch.Tensor Ground truth values eps : float Small value to avoid division by zero Returns ------- tuple (num_elements, mae_value) where: - num_elements (int): Total number of elements in the tensors - mae_value (torch.Tensor): Mean absolute error value Examples -------- >>> pred = torch.tensor([1.0, 2.0, 3.0]) >>> true = torch.tensor([1.1, 1.9, 3.2]) >>> num_elements, nmae = nmae_all(pred, true) >>> print(f"NMAE: {nmae.item():.4f}, Elements: {num_elements}") NMAE: 0.047059, Elements: 3 Notes ----- The NMAE is calculated as: MAE(pred, true) / mean(abs(true)) This function is useful for tracking metrics with MetricTracker """ num_elements = pred.numel() mae = torch.mean(torch.abs(pred - true)) norm = torch.mean(torch.abs(true)) + eps nmae = mae / norm return num_elements, nmae
# To verify with Kazem
[docs] def crps_ensemble_all(pred_ens, true): """ Continuous Ranked Probability Score (CRPS) for an ensemble. Computes the CRPS metric for ensemble predictions and returns both the number of elements and the mean CRPS value. Parameters ---------- pred_ens : torch.Tensor Ensemble predictions, shape [N_ens, N_pixels] true : torch.Tensor Ground truth values, shape [N_pixels] Returns ------- tuple (num_elements, crps_mean) where: - num_elements (int): Total number of elements in the tensors - crps_mean (torch.Tensor): Mean CRPS Notes ----- The CRPS for an ensemble is computed as: CRPS = E|X - y| - 0.5 * E|X - X'| where X and X' are independent ensemble members and y is the observation. """ # Number of ensemble members n = pred_ens.shape[0] # Sort ensemble pred_ens_sorted, _ = torch.sort(pred_ens, dim=0) # Term 1: E|X - y| term1 = torch.mean(torch.abs(pred_ens - true.unsqueeze(0)), dim=0) # Term 2: ensemble spread term diff = pred_ens_sorted[1:] - pred_ens_sorted[:-1] weight = torch.arange(1, n, device=pred_ens.device) * torch.arange( n - 1, 0, -1, device=pred_ens.device ) term2 = torch.sum(diff * weight.unsqueeze(1), dim=0) / (n**2) crps_pixel = term1 - term2 # [N_pixels] # Final aggregation num_elements = crps_pixel.numel() crps_mean = crps_pixel.mean() return num_elements, crps_mean
[docs] def rmse_all(pred, true): """ Calculate Root Mean Square Error (RMSE) between predicted and true values. Computes the RMSE metric and returns both the number of elements and the root mean square error value. Parameters ---------- pred : torch.Tensor Predicted values from the model true : torch.Tensor Ground truth values Returns ------- tuple (num_elements, rmse_value) where: - num_elements (int): Total number of elements in the tensors - rmse_value (torch.Tensor): Root mean square error value Examples -------- >>> pred = torch.tensor([1.0, 2.0, 3.0]) >>> true = torch.tensor([1.1, 1.9, 3.2]) >>> num_elements, rmse = rmse_all(pred, true) >>> print(f"RMSE: {rmse.item():.4f}, Elements: {num_elements}") RMSE: 0.1414, Elements: 3 Notes ----- The RMSE is calculated as: sqrt(mean((pred - true)^2)) This function is useful for tracking metrics with MetricTracker """ num_elements = pred.numel() mse = torch.mean((pred - true) ** 2) rmse_value = torch.sqrt(mse) return num_elements, rmse_value
[docs] def r2_all(pred, true): """ Calculate R2 (coefficient of determination) between predicted and true values. Computes the R2 metric and returns both the number of elements and the R2 value. Parameters ---------- pred : torch.Tensor Predicted values from the model true : torch.Tensor Ground truth values Returns ------- tuple (num_elements, r2_value) where: - num_elements (int): Total number of elements in the tensors - r2_value (torch.Tensor): R2 score Notes ----- R2 is calculated as: R2 = 1 - sum((true - pred)^2) / sum((true - mean(true))^2) This implementation is fully torch-based and works on CPU and GPU. """ if pred.shape != true.shape: raise RuntimeError(f"Shape mismatch: pred {pred.shape} vs true {true.shape}") eps = 1e-12 # Small value to avoid division by zero when variance is zero num_elements = pred.numel() # Flatten pred_flat = pred.reshape(-1) true_flat = true.reshape(-1) # Residual sum of squares ss_res = torch.sum((true_flat - pred_flat) ** 2) # Total sum of squares true_mean = torch.mean(true_flat) ss_tot = torch.sum((true_flat - true_mean) ** 2) # R2 score r2_value = 1.0 - ss_res / (ss_tot + eps) return num_elements, r2_value
[docs] def pearson_all(pred, true): """ Compute the Pearson correlation coefficient between predicted and ground truth values using torch.corrcoef. Parameters ---------- pred : torch.Tensor Predicted values from the model. true : torch.Tensor Ground truth values. Returns ------- tuple (num_elements, pearson_value) where: - num_elements (int): Total number of elements in the tensors. - pearson_value (torch.Tensor): Pearson correlation coefficient. Notes ----- The Pearson correlation coefficient is defined as: rho = Cov(pred, true) / (std(pred) * std(true)) """ if pred.shape != true.shape: raise RuntimeError(f"Shape mismatch: {pred.shape} vs {true.shape}") num_elements = pred.numel() # Flatten tensors to 1D vectors pred_flat = pred.reshape(-1) true_flat = true.reshape(-1) # Stack into a 2 x N matrix required by torch.corrcoef stacked = torch.stack([pred_flat, true_flat], dim=0) # Compute correlation matrix corr_matrix = torch.corrcoef(stacked) # Extract Pearson correlation coefficient between # predictions (row 0) and truth (row 1) pearson_value = corr_matrix[0, 1] return num_elements, pearson_value
[docs] def kl_divergence_all(pred, true): """ Compute the Kullback–Leibler (KL) divergence between predicted and ground truth distributions using histogram-based estimation. Parameters ---------- pred : torch.Tensor Predicted values from the model. true : torch.Tensor Ground truth values. Returns ------- tuple (num_elements, kl_value) where: - num_elements (int): Total number of elements in the tensors. - kl_value (torch.Tensor): KL divergence value. Notes ----- The KL divergence is defined as: KL(P|Q) = sum_i P_i * log(P_i / Q_i) where: - P represents the true distribution - Q represents the predicted distribution """ if pred.shape != true.shape: raise RuntimeError(f"Shape mismatch: {pred.shape} vs {true.shape}") num_elements = pred.numel() n_bins = 100 eps = 1e-12 # Flatten tensors to 1D vectors pred_flat = pred.reshape(-1) true_flat = true.reshape(-1) # Combine for percentile computation all_values = torch.cat([pred_flat, true_flat]) # Percentile clipping data_min = torch.quantile(all_values, 0.0025) data_max = torch.quantile(all_values, 0.995) data_range = data_max - data_min x_min = data_min - 0.05 * data_range x_max = data_max + 0.05 * data_range hist_pred = torch.histc(pred_flat, bins=n_bins, min=x_min.item(), max=x_max.item()) hist_true = torch.histc(true_flat, bins=n_bins, min=x_min.item(), max=x_max.item()) # Add epsilon hist_pred = hist_pred + eps hist_true = hist_true + eps # Normalize to probability mass hist_pred = hist_pred / hist_pred.sum() hist_true = hist_true / hist_true.sum() # KL divergence kl_value = torch.sum(hist_true * torch.log(hist_true / hist_pred)) return num_elements, kl_value
[docs] def denormalize( data, stats, norm_type, device, var_name=None, data_type=None, debug=False, logger=None, ): """ Denormalize a data tensor using the inverse of the normalization operation. Parameters ---------- data : torch.Tensor Normalized tensor to denormalize. stats : object Object containing the required statistics. norm_type : str Normalization type used originally. device : torch.device Device for tensor operations. var_name : str, optional Variable name for debugging. data_type : str, optional Data type for debugging (e.g., "residual", "coarse"). debug : bool, optional Enable debug logging. logger : Logger, optional Logger instance for debug output. """ # Add debug logging at the start if debug and logger: # Create context string context = "" if var_name: context = f" for {var_name}" if data_type: context += f" ({data_type})" logger.info( f"Denormalizing{context} with type '{norm_type}'\n" f" └── Denormalization stats;\n" f" └── vmin: {getattr(stats, 'vmin', None)}\n" f" └── vmax: {getattr(stats, 'vmax', None)}\n" f" └── vmean: {getattr(stats, 'vmean', None)}\n" f" └── vstd: {getattr(stats, 'vstd', None)}\n" f" └── median: {getattr(stats, 'median', None)}\n" f" └── iqr: {getattr(stats, 'iqr', None)}\n" f" └── q1: {getattr(stats, 'q1', None)}\n" f" └── q3: {getattr(stats, 'q3', None)}" ) # ------------------ MIN-MAX ------------------ if norm_type == "minmax": vmin = torch.tensor(stats.vmin, dtype=data.dtype, device=device) vmax = torch.tensor(stats.vmax, dtype=data.dtype, device=device) denom = vmax - vmin if denom == 0: return torch.zeros_like(data) return data * denom + vmin # ------------------ MIN-MAX [-1, 1] ----------------- elif norm_type == "minmax_11": vmin = torch.tensor(stats.vmin, dtype=data.dtype, device=device) vmax = torch.tensor(stats.vmax, dtype=data.dtype, device=device) denom = vmax - vmin if denom == 0: return torch.zeros_like(data) return ((data + 1) / 2) * denom + vmin # ------------------ STANDARD ----------------- elif norm_type == "standard": mean = torch.tensor(stats.vmean, dtype=data.dtype, device=device) std = torch.tensor(stats.vstd, dtype=data.dtype, device=device) if std == 0: return torch.zeros_like(data) return data * std + mean # ------------------ ROBUST ------------------- elif norm_type == "robust": median = torch.tensor(stats.median, dtype=data.dtype, device=device) iqr = torch.tensor(stats.iqr, dtype=data.dtype, device=device) if iqr == 0: return torch.zeros_like(data) return data * iqr + median # ------------------ LOG1P + MIN-MAX ------------------ elif norm_type == "log1p_minmax": log_min = torch.tensor(stats.vmin, dtype=data.dtype, device=device) log_max = torch.tensor(stats.vmax, dtype=data.dtype, device=device) denom = log_max - log_min if denom == 0: return torch.zeros_like(data) log_data = data * denom + log_min return torch.expm1(log_data) # ------------------ LOG1P + STANDARD ------------------ elif norm_type == "log1p_standard": mean = torch.tensor(stats.vmean, dtype=data.dtype, device=device) std = torch.tensor(stats.vstd, dtype=data.dtype, device=device) if std == 0: return torch.zeros_like(data) log_data = data * std + mean return torch.expm1(log_data) else: raise ValueError(f"Unsupported norm_type '{norm_type}'")
[docs] @torch.no_grad() def edm_sampler( model, image_input, class_labels=None, num_steps=40, sigma_min=0.02, sigma_max=80.0, rho=7, S_churn=40, S_min=0, S_max=float("inf"), S_noise=1, ): """ EDM sampler for diffusion model inference. Original work: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. Original source: https://github.com/NVlabs/edm Parameters ---------- model : torch.nn.Module Diffusion model image_input : torch.Tensor Conditioning input (coarse + constants) class_labels : torch.Tensor, optional Time conditioning labels num_steps : int, optional Number of sampling steps sigma_min : float, optional Minimum noise level sigma_max : float, optional Maximum noise level rho : float, optional Time step exponent S_churn : int, optional Stochasticity parameter S_min : float, optional Minimum stochasticity threshold S_max : float, optional Maximum stochasticity threshold S_noise : float, optional Noise scale for stochasticity Returns ------- torch.Tensor Generated residual predictions """ batch_size, _, H, W = image_input.shape # Get the actual model (unwrap DataParallel if needed) if isinstance(model, torch.nn.DataParallel): model = model.module # init noise init_noise = torch.randn( (batch_size, model.out_channels, H, W), dtype=image_input.dtype, device=image_input.device, ) # Adjust noise levels based on what's supported by the model. sigma_min = max(sigma_min, model.sigma_min) sigma_max = min(sigma_max, model.sigma_max) # Time step discretization. step_indices = torch.arange( num_steps, dtype=image_input.dtype, device=image_input.device ) t_steps = ( sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) ) ** rho t_steps = torch.cat( [model.round_sigma(t_steps), torch.zeros_like(t_steps[:1])] ) # t_N = 0 # Main sampling loop. x_next = init_noise * t_steps[0] for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 x_cur = x_next # Increase noise temporarily. gamma = ( min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 ) t_hat = model.round_sigma(t_cur + gamma * t_cur) x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * torch.randn_like(x_cur) # Euler step. denoised = model(x_hat, t_hat, image_input, class_labels).to(torch.float64) d_cur = (x_hat - denoised) / t_hat x_next = x_hat + (t_next - t_hat) * d_cur # Apply 2nd order correction. if i < num_steps - 1: denoised = model(x_next, t_next, image_input, class_labels).to( torch.float64 ) d_prime = (x_next - denoised) / t_next x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) return x_next.detach()
[docs] @torch.no_grad() def sampler( epoch, batch_idx, model, image_input, class_labels=None, num_steps=18, sigma_min=None, sigma_max=None, rho=7, solver="heun", discretization="edm", schedule="linear", scaling="none", epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1, S_churn=40, S_min=0, S_max=float("inf"), S_noise=1, logger=None, ): """ General sampler for diffusion model inference with multiple configurations. Original work: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. Original source: https://github.com/NVlabs/edm Parameters ---------- model : torch.nn.Module Diffusion model image_input : torch.Tensor Conditioning input (coarse + constants) class_labels : torch.Tensor, optional Time conditioning labels num_steps : int, optional Number of sampling steps sigma_min : float, optional Minimum noise level sigma_max : float, optional Maximum noise level rho : float, optional Time step exponent for EDM discretization solver : str, optional Solver type: 'euler' or 'heun' discretization : str, optional Discretization type: 'vp', 've', 'iddpm', or 'edm' schedule : str, optional Noise schedule: 'vp', 've', or 'linear' scaling : str, optional Scaling type: 'vp' or 'none' epsilon_s : float, optional Small epsilon for VP schedule C_1 : float, optional Constant for IDDPM discretization C_2 : float, optional Constant for IDDPM discretization M : int, optional Number of steps for IDDPM discretization alpha : float, optional Parameter for Heun's method S_churn : int, optional Stochasticity parameter S_min : float, optional Minimum stochasticity threshold S_max : float, optional Maximum stochasticity threshold S_noise : float, optional Noise scale for stochasticity logger : logging.Logger, optional Logger instance for logging sampler parameters Returns ------- torch.Tensor Generated residual predictions """ # Only the original asserts with messages assert solver in [ "euler", "heun", ], f"Solver must be 'euler' or 'heun', but got '{solver}'" assert ( discretization in ["vp", "ve", "iddpm", "edm"] ), f"Discretization must be 'vp', 've', 'iddpm' or 'edm', but got '{discretization}'" assert schedule in [ "vp", "ve", "linear", ], f"Schedule must be 'vp', 've' or 'linear', but got '{schedule}'" assert scaling in [ "vp", "none", ], f"Scaling must be 'vp' or 'none', but got '{scaling}'" batch_size, _, H, W = image_input.shape # Get the actual model (unwrap DataParallel if needed) if isinstance(model, torch.nn.DataParallel): model = model.module # Initialize noise latents = torch.randn( (batch_size, model.out_channels, H, W), dtype=image_input.dtype, device=image_input.device, ) # Helper functions for VP & VE noise level schedules. vp_sigma = lambda beta_d, beta_min: ( lambda t: (np.e ** (0.5 * beta_d * (t**2) + beta_min * t) - 1) ** 0.5 ) vp_sigma_deriv = lambda beta_d, beta_min: ( lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t)) ) vp_sigma_inv = lambda beta_d, beta_min: ( lambda sigma: ( ((beta_min**2 + 2 * beta_d * (sigma**2 + 1).log()).sqrt() - beta_min) / beta_d ) ) ve_sigma = lambda t: t.sqrt() ve_sigma_deriv = lambda t: 0.5 / t.sqrt() ve_sigma_inv = lambda sigma: sigma**2 # Select default noise level range based on the specified time step discretization. if sigma_min is None: vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=epsilon_s) sigma_min = {"vp": vp_def, "ve": 0.02, "iddpm": 0.002, "edm": 0.002}[ discretization ] if sigma_max is None: vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=1) sigma_max = {"vp": vp_def, "ve": 100, "iddpm": 81, "edm": 80}[discretization] # Log sampler parameters if logger is provided if logger is not None and epoch == 0 and batch_idx == 0: logger.info("=== Sampler Parameters ===") logger.info(f" └── num_steps: {num_steps}") logger.info(f" └── solver: {solver}") logger.info(f" └── discretization: {discretization}") logger.info(f" └── schedule: {schedule}") logger.info(f" └── scaling: {scaling}") logger.info(f" └── sigma_min: {sigma_min}") logger.info(f" └── sigma_max: {sigma_max}") logger.info(f" └── rho: {rho}") logger.info(f" └── S_churn: {S_churn}") logger.info(f" └── S_min: {S_min}") logger.info(f" └── S_max: {S_max}") logger.info(f" └── S_noise: {S_noise}") logger.info(f" └── epsilon_s: {epsilon_s}") logger.info(f" └── C_1: {C_1}") logger.info(f" └── C_2: {C_2}") logger.info(f" └── M: {M}") logger.info(f" └── alpha: {alpha}") logger.info("==========================") # Adjust noise levels based on what's supported by the network. sigma_min = max(sigma_min, model.sigma_min) sigma_max = min(sigma_max, model.sigma_max) # Compute corresponding betas for VP. vp_beta_d = ( 2 * (np.log(sigma_min**2 + 1) / epsilon_s - np.log(sigma_max**2 + 1)) / (epsilon_s - 1) ) vp_beta_min = np.log(sigma_max**2 + 1) - 0.5 * vp_beta_d # Define time steps in terms of noise level. step_indices = torch.arange( num_steps, dtype=image_input.dtype, device=image_input.device ) if discretization == "vp": orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1) sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps) elif discretization == "ve": orig_t_steps = (sigma_max**2) * ( (sigma_min**2 / sigma_max**2) ** (step_indices / (num_steps - 1)) ) sigma_steps = ve_sigma(orig_t_steps) elif discretization == "iddpm": u = torch.zeros(M + 1, dtype=image_input.dtype, device=image_input.device) alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2 for j in torch.arange(M, 0, -1, device=image_input.device): # M, ..., 1 u[j - 1] = ( (u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1 ).sqrt() u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)] sigma_steps = u_filtered[ ((len(u_filtered) - 1) / (num_steps - 1) * step_indices) .round() .to(torch.int64) ] else: assert discretization == "edm" sigma_steps = ( sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) ) ** rho # Define noise level schedule. if schedule == "vp": sigma = vp_sigma(vp_beta_d, vp_beta_min) sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min) sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min) elif schedule == "ve": sigma = ve_sigma sigma_deriv = ve_sigma_deriv sigma_inv = ve_sigma_inv else: assert schedule == "linear" sigma = lambda t: t sigma_deriv = lambda t: 1 sigma_inv = lambda sigma: sigma # Define scaling schedule. if scaling == "vp": s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt() s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3) else: assert scaling == "none" s = lambda t: 1 s_deriv = lambda t: 0 # Compute final time steps based on the corresponding noise levels. t_steps = sigma_inv(model.round_sigma(sigma_steps)) t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 # Main sampling loop. t_next = t_steps[0] x_next = latents.to(image_input.dtype) * (sigma(t_next) * s(t_next)) for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 x_cur = x_next # Increase noise temporarily. gamma = ( min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0 ) t_hat = sigma_inv(model.round_sigma(sigma(t_cur) + gamma * sigma(t_cur))) x_hat = s(t_hat) / s(t_cur) * x_cur + ( sigma(t_hat) ** 2 - sigma(t_cur) ** 2 ).clip(min=0).sqrt() * s(t_hat) * S_noise * torch.randn_like(x_cur) # Euler step. h = t_next - t_hat denoised = model(x_hat / s(t_hat), sigma(t_hat), image_input, class_labels).to( image_input.dtype ) d_cur = ( sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat) ) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised x_prime = x_hat + alpha * h * d_cur t_prime = t_hat + alpha * h # Apply 2nd order correction. if solver == "euler" or i == num_steps - 1: x_next = x_hat + h * d_cur else: assert solver == "heun" denoised = model( x_prime / s(t_prime), sigma(t_prime), image_input, class_labels ).to(image_input.dtype) d_prime = ( sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime) ) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised x_next = x_hat + h * ( (1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime ) return x_next.detach()
[docs] def reconstruct_original_layout( epoch, args, paths, steps, all_data, dataset, device, logger ): """ Robust reconstruction using dataset information directly. Parameters: ----------- all_data : dict Dictionary containing lists of batches for: - 'predictions': model predictions [B, C, H, W] - 'coarse': coarse resolution data [B, C, H, W] - 'fine': fine resolution ground truth [B, C, H, W] - 'lat': latitude coordinates [B, H] - 'lon': longitude coordinates [B, W] dataset : torch.utils.data.Dataset The validation dataset instance device : torch.device Device to store tensors on logger : Logger Logger instance for logging Returns: -------- dict: Reconstructed data with metadata """ # Get dataset parameters time_batchs = len(dataset.time_batchs) sbatch = dataset.sbatch total_dataset_samples = len(dataset) # time_batchs * sbatch # dataset_times = dataset.loaded_dfs.time.values # Get total samples from all batches total_batch_samples = sum(batch.shape[0] for batch in all_data["predictions"]) logger.info("Dataset reconstruction info:") logger.info(f" └── time_batchs: {time_batchs}") logger.info(f" └── sbatch: {sbatch}") logger.info(f" └── total dataset samples: {total_dataset_samples}") logger.info(f" └── total batch samples: {total_batch_samples}") # Handle different scenarios if total_batch_samples > total_dataset_samples: error_msg = ( f"More batch samples ({total_batch_samples}) than dataset samples ({total_dataset_samples})! " f"Something is wrong with the DataLoader." ) logger.error(error_msg) raise elif total_batch_samples < total_dataset_samples: logger.info( f"Note: Batch samples ({total_batch_samples}) < dataset samples ({total_dataset_samples})" ) logger.info("This is normal if DataLoader has drop_last=True") # Get sample shape pred_shape = all_data["predictions"][0].shape[1:] # [C, H, W] C, H, W = pred_shape logger.info(f"Sample shape: C={C}, H={H}, W={W}") # Initialize reconstruction arrays reconstructions = {} for key in ["predictions", "coarse", "fine"]: reconstructions[key] = torch.zeros( time_batchs, sbatch, C, H, W, device=device, dtype=all_data[key][0].dtype ) logger.info(f"Initialized {key} with shape: {reconstructions[key].shape}") reconstructions["lat"] = torch.zeros( time_batchs, sbatch, H, device=device, dtype=all_data["lat"][0].dtype ) reconstructions["lon"] = torch.zeros( time_batchs, sbatch, W, device=device, dtype=all_data["lon"][0].dtype ) logger.info(f"Initialized lat with shape: {reconstructions['lat'].shape}") logger.info(f"Initialized lon with shape: {reconstructions['lon'].shape}") # Create position tracking position_filled = torch.zeros(time_batchs, sbatch, dtype=torch.bool, device=device) # Map each dataset index to position index_to_position = {} for idx in range(total_dataset_samples): sindex = idx % sbatch tindex = idx // sbatch index_to_position[idx] = (tindex, sindex) logger.info(f"Created index mapping for {total_dataset_samples} samples") # Reconstruct using dataset indices dataset_idx = 0 total_reconstructed = 0 logger.info("Starting reconstruction process...") for batch_idx in range(len(all_data["predictions"])): batch = all_data["predictions"][batch_idx] batch_size = batch.shape[0] logger.info( f"Processing batch {batch_idx+1}/{len(all_data['predictions'])} with size {batch_size}" ) for i_in_batch in range(batch_size): # We can only reconstruct up to dataset samples if dataset_idx >= total_dataset_samples: logger.warning( f"Stopping at dataset_idx {dataset_idx} (dataset has {total_dataset_samples} samples)" ) break tindex, sindex = index_to_position[dataset_idx] # Store all data for key in ["predictions", "coarse", "fine"]: reconstructions[key][tindex, sindex] = all_data[key][batch_idx][ i_in_batch ] reconstructions["lat"][tindex, sindex] = all_data["lat"][batch_idx][ i_in_batch ] reconstructions["lon"][tindex, sindex] = all_data["lon"][batch_idx][ i_in_batch ] position_filled[tindex, sindex] = True total_reconstructed += 1 dataset_idx += 1 # Free memory for this batch for key in ("predictions", "coarse", "fine", "lat", "lon"): all_data[key][batch_idx] = None # Break if we've reached dataset limit if dataset_idx >= total_dataset_samples: break logger.info(f"Successfully reconstructed {total_reconstructed} samples") # Check results filled_count = position_filled.sum().item() if filled_count != total_reconstructed: logger.warning( f"filled_count ({filled_count}) != total_reconstructed ({total_reconstructed})" ) if filled_count < total_dataset_samples: missing = total_dataset_samples - filled_count logger.info( f"Note: {missing}/{total_dataset_samples} samples not reconstructed" ) logger.info("This is expected with drop_last=True in DataLoader") # Metadata metadata = { "time_batchs": time_batchs, "sbatch": sbatch, "total_dataset_samples": total_dataset_samples, "total_batch_samples": total_batch_samples, "total_reconstructed": total_reconstructed, "position_filled": position_filled, "index_to_position": index_to_position, "filled_ratio": filled_count / total_dataset_samples if total_dataset_samples > 0 else 0, "reconstruction_device": str(device), } logger.info("Reconstruction completed successfully") # Check if we need to combine spatial blocks for inference if args.run_type in ["inference", "inference_regional"]: logger.info( "Inference mode is active - combining spatial blocks to reconstruct full domain..." ) # Get evaluation slices directly from the DataPreprocessor if hasattr(dataset, "eval_slices"): eval_slices = dataset.eval_slices logger.info(f"Found {len(eval_slices)} evaluation slices") # Determine the spatial extent covered by evaluation slices. # In regional inference, slices may not start at index 0, so the domain size # is computed from the min/max slice indices. lat_min = min(s[0] for s in eval_slices) lat_max = max(s[1] for s in eval_slices) lon_min = min(s[2] for s in eval_slices) lon_max = max(s[3] for s in eval_slices) covered_H = lat_max - lat_min covered_W = lon_max - lon_min logger.info(f"Dataset dimensions: H={dataset.H}, W={dataset.W}") logger.info(f"Blocks cover: H={covered_H}, W={covered_W}") # Initialize coordinate arrays lat_reconstructed = torch.zeros(covered_H, device=device) lon_reconstructed = torch.zeros(covered_W, device=device) # Track which coordinates we've filled (must fill all!) lat_filled = torch.zeros(covered_H, dtype=torch.bool, device=device) lon_filled = torch.zeros(covered_W, dtype=torch.bool, device=device) # Initialize arrays for the COVERED area combined_data = {} for key in ["predictions", "coarse", "fine"]: combined_data[key] = torch.zeros( time_batchs, C, covered_H, covered_W, device=device, dtype=reconstructions[key].dtype, ) # Track grid coverage (must cover all!) coverage_mask = torch.zeros( covered_H, covered_W, dtype=torch.bool, device=device ) # Combine blocks and reconstruct coordinates blocks_placed = 0 for t in range(time_batchs): for spatial_idx, (lat_start, lat_end, lon_start, lon_end) in enumerate( eval_slices ): # Shift slice indices into the local reconstruction coordinate system. # This is required for regional inference where slices do not start at 0. # For global inference lat_min=lon_min=0 so indices remain unchanged. lat_start -= lat_min lat_end -= lat_min lon_start -= lon_min lon_end -= lon_min if spatial_idx >= sbatch: error_msg = ( f"CRITICAL ERROR: Slice index {spatial_idx} exceeds sbatch {sbatch}. " f"eval_slices has {len(eval_slices)} slices but only {sbatch} spatial blocks reconstructed." ) logger.error(error_msg) raise ValueError(error_msg) # Place block in combined array for key in ["predictions", "coarse", "fine"]: combined_data[key][ t, :, lat_start:lat_end, lon_start:lon_end ] = reconstructions[key][t, spatial_idx] # Reconstruct LATITUDE coordinates from this block block_lat = reconstructions["lat"][t, spatial_idx] # [H_block] lat_reconstructed[lat_start:lat_end] = block_lat lat_filled[lat_start:lat_end] = True # Reconstruct LONGITUDE coordinates from this block block_lon = reconstructions["lon"][t, spatial_idx] # [W_block] lon_reconstructed[lon_start:lon_end] = block_lon lon_filled[lon_start:lon_end] = True # Mark grid coverage coverage_mask[lat_start:lat_end, lon_start:lon_end] = True blocks_placed += 1 logger.info(f"Combined {blocks_placed} spatial blocks") # VERIFY COMPLETE COVERAGE - RAISE ERROR IF INCOMPLETE # Check latitude coordinate coverage lat_missing = (~lat_filled).sum().item() if lat_missing > 0: missing_indices = torch.nonzero(~lat_filled).squeeze().cpu().numpy() error_msg = ( f"CRITICAL ERROR: Latitude coordinate reconstruction incomplete!\n" f"Missing {lat_missing}/{covered_H} latitude coordinates.\n" f"Missing indices: {missing_indices[:10]}{'...' if len(missing_indices) > 10 else ''}\n" f"This indicates blocks don't cover the full latitude range." ) logger.error(error_msg) raise ValueError(error_msg) # Check longitude coordinate coverage lon_missing = (~lon_filled).sum().item() if lon_missing > 0: missing_indices = torch.nonzero(~lon_filled).squeeze().cpu().numpy() error_msg = ( f"CRITICAL ERROR: Longitude coordinate reconstruction incomplete!\n" f"Missing {lon_missing}/{covered_W} longitude coordinates.\n" f"Missing indices: {missing_indices[:10]}{'...' if len(missing_indices) > 10 else ''}\n" f"This indicates blocks don't cover the full longitude range." ) logger.error(error_msg) raise ValueError(error_msg) # Check grid coverage uncovered_cells = (~coverage_mask).sum().item() if uncovered_cells > 0: # Find where coverage is missing missing_mask = ~coverage_mask missing_positions = torch.nonzero(missing_mask) error_msg = ( f"CRITICAL ERROR: Grid coverage incomplete!\n" f"Missing {uncovered_cells}/{covered_H*covered_W} grid cells.\n" f"Coverage: {coverage_mask.sum().item()/(covered_H*covered_W)*100:.1f}%\n" f"First 10 missing positions (lat, lon): {missing_positions[:10].cpu().numpy().tolist()}" ) logger.error(error_msg) raise ValueError(error_msg) # Fix longitude discontinuity when blocks cross the 0°/360° meridian. # np.unwrap keeps the longitude coordinate monotonic # Ex: 358,359,0,1 to 358,359,360,361 lon_reconstructed = torch.from_numpy( np.rad2deg(np.unwrap(np.deg2rad(lon_reconstructed.cpu().numpy()))) ).to(device) # All checks passed - reconstruction is complete logger.info("✅ Coordinate reconstruction complete") logger.info("✅ Grid coverage complete") logger.info( f"Latitude range: {lat_reconstructed.min():.2f} to {lat_reconstructed.max():.2f}" ) logger.info( f"Longitude range: {lon_reconstructed.min():.2f} to {lon_reconstructed.max():.2f}" ) # Add reconstruction info to metadata metadata["coverage_info"] = { "covered_H": covered_H, "covered_W": covered_W, "full_H": dataset.H, "full_W": dataset.W, "coverage_complete": True, "coordinates_complete": True, "lat_range": [ lat_reconstructed.min().item(), lat_reconstructed.max().item(), ], "lon_range": [ lon_reconstructed.min().item(), lon_reconstructed.max().item(), ], "lat_reconstructed": lat_reconstructed.cpu(), "lon_reconstructed": lon_reconstructed.cpu(), } # Store reconstructed coordinates in reconstructions dict reconstructions["lat_reconstructed"] = lat_reconstructed reconstructions["lon_reconstructed"] = lon_reconstructed # Add combined data to reconstructions dict reconstructions["combined"] = combined_data else: logger.error( "Could not find eval_slices in dataset. Cannot combine spatial blocks." ) raise AttributeError( "Dataset missing 'eval_slices' attribute for inference reconstruction." ) logger.info(f"Generating block wise plots for epoch {epoch}...") # Loop through spatial blocks for spatial_idx in range(sbatch): # Extract data for this spatial block # shape: [time_batchs, C, H, W] predictions_block = reconstructions["predictions"][:, spatial_idx] fine_block = reconstructions["fine"][:, spatial_idx] coarse_block = reconstructions["coarse"][:, spatial_idx] # lat_block = reconstructions['lat'][:, spatial_idx] # lon_block = reconstructions['lon'][:, spatial_idx] # 0. QQ Plot save_path = plot_qq_quantiles( predictions_block, # [time_batchs, C, H, W] fine_block, # [time_batchs, C, H, W] coarse_block, # [time_batchs, C, H, W] variable_names=args.varnames_list, units=None, # You might want to add units to args quantiles=[0.90, 0.95, 0.975, 0.99, 0.995], filename=f"{args.run_type}_qq_epoch_{epoch}_spatial_block_{spatial_idx:03d}.png", save_dir=paths.results, ) logger.info(f"Saved QQ plot to {save_path}") # 1. Validation Hexbin Plot save_path = plot_validation_hexbin( predictions=predictions_block, targets=fine_block, variable_names=args.varnames_list, filename=f"{args.run_type}_validation_hexbin_epoch_{epoch}_sblock_{spatial_idx:03d}.png", save_dir=paths.results, ) logger.info(f"Saved validation hexbin plot to: {save_path}") # 2. Comparison Hexbin Plot save_path = plot_comparison_hexbin( predictions=predictions_block, targets=fine_block, coarse_inputs=coarse_block, variable_names=args.varnames_list, filename=f"{args.run_type}_comparison_hexbin_epoch_{epoch}_sblock_{spatial_idx:03d}.png", save_dir=paths.results, ) logger.info(f"Saved comparison hexbin plot to: {save_path}") # 3. Validation PDFs Plot save_path = plot_validation_pdfs( predictions=predictions_block, targets=fine_block, coarse_inputs=coarse_block, variable_names=args.varnames_list, filename=f"{args.run_type}_validation_pdfs_epoch_{epoch}_sblock_{spatial_idx:03d}.png", save_dir=paths.results, ) logger.info(f"Saved validation PDFs plot to: {save_path}") # 4. Power Spectra Plot dlon = getattr(steps, "d_longitude", None) dlat = getattr(steps, "d_latitude", None) assert dlon is not None, "d_longitude not found in steps" assert dlat is not None, "d_latitude not found in steps" save_path = plot_power_spectra( predictions=predictions_block, targets=fine_block, coarse_inputs=coarse_block, dlat=dlat, dlon=dlon, variable_names=args.varnames_list, filename=f"{args.run_type}_power_spectra_epoch_{epoch}_sblock_{spatial_idx:03d}.png", save_dir=paths.results, ) logger.info(f"Saved power spectra plot to: {save_path}") # 5. MAE map plot (time-averaged) # Latitude and longitude coordinates for this spatial block. # Coordinates are time-invariant, so we take them from the first time index (t = 0). first_time_idx = 0 # Get coordinates for this spatial block lat_block = reconstructions["lat"][first_time_idx, spatial_idx] # [H] lon_block = reconstructions["lon"][first_time_idx, spatial_idx] # [W] save_path = plot_MAE_map( predictions=predictions_block, # [T, C, H, W] targets=fine_block, # [T, C, H, W] lat_1d=lat_block, # [H] lon_1d=lon_block, # [W] variable_names=args.varnames_list, filename=f"{args.run_type}_mae_map_epoch_{epoch}_sblock_{spatial_idx:03d}.png", save_dir=paths.results, ) logger.info(f"Saved MAE map to: {save_path}") # 6. Multivariate Correlation Maps # Convert 1D lat/lon to 2D meshgrid lat_2d, lon_2d = torch.meshgrid(lat_block, lon_block, indexing="ij") save_path = plot_validation_mvcorr( predictions=predictions_block, # [T, C, H, W] targets=fine_block, # [T, C, H, W] coarse_inputs=coarse_block, # optional lat=lat_2d.numpy(), lon=lon_2d.numpy(), variable_names=args.varnames_list, filename=f"{args.run_type}_mvcorr_epoch_{epoch}_sblock_{spatial_idx:03d}.png", save_dir=paths.results, ) logger.info(f"Saved multivariate correlation map to: {save_path}") # 7. Surface plot coarse = reconstructions["coarse"][ first_time_idx : first_time_idx + 1, spatial_idx ] fine = reconstructions["fine"][first_time_idx : first_time_idx + 1, spatial_idx] pred = reconstructions["predictions"][ first_time_idx : first_time_idx + 1, spatial_idx ] save_path = plot_surface( predictions=pred, targets=fine, coarse_inputs=coarse, lat_1d=lat_block, lon_1d=lon_block, variable_names=args.varnames_list, filename=f"{args.run_type}_plot_surface_epoch_{epoch}_sblock_{spatial_idx:03d}.png", save_dir=paths.results, ) logger.info(f"Saved surface plot to: {save_path}") # 8. Temporal series save_path = plot_temporal_series_comparison( predictions=predictions_block, targets=fine_block, # coarse_inputs=coarse_block, variable_names=args.varnames_list, filename=f"{args.run_type}_temporal_series_epoch_{epoch}_sblock_{spatial_idx:03d}.png", save_dir=paths.results, ) logger.info(f"Saved temporal series plot to: {save_path}") # 9. Multivariate spatial correlation time series save_path = plot_validation_mvcorr_space( predictions=predictions_block, targets=fine_block, coarse_inputs=coarse_block, variable_names=args.varnames_list, filename=f"{args.run_type}_mvcorr_space_epoch_{epoch}_sblock_{spatial_idx:03d}.png", save_dir=paths.results, ) logger.info( f"Saved multivariate spatial correlation time series to: {save_path}" ) # For inference mode, also generate full domain plots if args.run_type in ["inference", "inference_regional"]: assert ( "combined" in reconstructions ), "Combined data not found in reconstructions for inference mode" logger.info( f"Generating full domain plots for inference mode, epoch {epoch}..." ) # Get combined data for full domain predictions_full = reconstructions["combined"][ "predictions" ] # [time_batchs, C, covered_H, covered_W] fine_full = reconstructions["combined"][ "fine" ] # [time_batchs, C, covered_H, covered_W] coarse_full = reconstructions["combined"][ "coarse" ] # [time_batchs, C, covered_H, covered_W] lat_full = reconstructions["lat_reconstructed"] # [covered_H] lon_full = reconstructions["lon_reconstructed"] # [covered_W] # Generate full domain versions of all plots # 0. QQ Plot for full domain (averaged over space) save_path = plot_qq_quantiles( predictions_full, # [time_batchs, C, H, W] fine_full, # [time_batchs, C, H, W] coarse_full, # [time_batchs, C, H, W] variable_names=args.varnames_list, units=None, quantiles=[0.90, 0.95, 0.975, 0.99, 0.995], filename=f"{args.run_type}_full_domain_qq_epoch_{epoch}.png", save_dir=paths.results, save_npz=True, ) logger.info(f"Saved full domain QQ plot to {save_path}") # 1. Validation Hexbin Plot for full domain save_path = plot_validation_hexbin( predictions=predictions_full, targets=fine_full, variable_names=args.varnames_list, filename=f"{args.run_type}_full_domain_validation_hexbin_epoch_{epoch}.png", save_dir=paths.results, ) logger.info(f"Saved full domain validation hexbin plot to: {save_path}") # 2. Comparison Hexbin Plot for full domain save_path = plot_comparison_hexbin( predictions=predictions_full, targets=fine_full, coarse_inputs=coarse_full, variable_names=args.varnames_list, filename=f"{args.run_type}_full_domain_comparison_hexbin_epoch_{epoch}.png", save_dir=paths.results, ) logger.info(f"Saved full domain comparison hexbin plot to: {save_path}") # 3. Validation PDFs Plot for full domain save_path = plot_validation_pdfs( predictions=predictions_full, targets=fine_full, coarse_inputs=coarse_full, variable_names=args.varnames_list, filename=f"{args.run_type}_full_domain_validation_pdfs_epoch_{epoch}.png", save_dir=paths.results, save_npz=True, ) logger.info(f"Saved full domain validation PDFs plot to: {save_path}") # 4. Power Spectra Plot for full domain dlon = getattr(steps, "d_longitude", None) dlat = getattr(steps, "d_latitude", None) assert dlon is not None, "d_longitude not found in steps" assert dlat is not None, "d_latitude not found in steps" save_path = plot_power_spectra( predictions=predictions_full, targets=fine_full, coarse_inputs=coarse_full, dlat=dlat, dlon=dlon, variable_names=args.varnames_list, filename=f"{args.run_type}_full_domain_power_spectra_epoch_{epoch}.png", save_dir=paths.results, save_npz=True, ) logger.info(f"Saved full domain power spectra plot to: {save_path}") # 5. MAE map Plot for full domain save_path = plot_MAE_map( predictions=predictions_full, targets=fine_full, lat_1d=lat_full, lon_1d=lon_full, variable_names=args.varnames_list, filename=f"{args.run_type}_full_domain_mae_map_epoch_{epoch}.png", save_dir=paths.results, ) logger.info(f"Saved full domain MAE map to: {save_path}") # 6. Error map Plot for full domain save_path = plot_error_map( predictions=predictions_full, targets=fine_full, lat_1d=lat_full, lon_1d=lon_full, variable_names=args.varnames_list, filename=f"{args.run_type}_full_domain_error_map_epoch_{epoch}.png", save_dir=paths.results, ) logger.info(f"Saved full domain error map to: {save_path}") # 7. Surface plots for first few time steps of full domain num_time_steps_to_plot = min(3, time_batchs) for time_idx in range(num_time_steps_to_plot): # Extract single time step pred_single_time = predictions_full[time_idx : time_idx + 1] # [1, C, H, W] fine_single_time = fine_full[time_idx : time_idx + 1] # [1, C, H, W] coarse_single_time = coarse_full[time_idx : time_idx + 1] # [1, C, H, W] tindex = dataset.time_batchs[time_idx] timestamp = pd.to_datetime( dataset.loaded_dfs.time.values[tindex] ).to_pydatetime() save_path = plot_surface( predictions=pred_single_time, targets=fine_single_time, coarse_inputs=coarse_single_time, lat_1d=lat_full, lon_1d=lon_full, timestamp=timestamp, variable_names=args.varnames_list, filename=f"{args.run_type}_full_domain_surface_epoch_{epoch}_time_{time_idx:03d}.png", save_dir=paths.results, ) logger.info( f"Saved full domain surface plot (time {time_idx}) to: {save_path}" ) # Zoom comparison plot for full domain (only for global inference) if args.run_type == "inference": save_path = plot_zoom_comparison( predictions=pred_single_time, targets=fine_single_time, lat_1d=lat_full, lon_1d=lon_full, variable_names=args.varnames_list, filename=f"{args.run_type}_full_domain_zoom_comparison_epoch_{epoch}_time_{time_idx:03d}.png", save_dir=paths.results, ) logger.info( f"Saved full domain zoom comparison (time {time_idx}) to: {save_path}" ) # 8. Multivariate Correlation Maps for full domain # Convert 1D lat/lon to 2D meshgrid lat_2d_full, lon_2d_full = torch.meshgrid(lat_full, lon_full, indexing="ij") save_path = plot_validation_mvcorr( predictions=predictions_full, # [T, C, H, W] targets=fine_full, # [T, C, H, W] coarse_inputs=coarse_full, lat=lat_2d_full.numpy(), lon=lon_2d_full.numpy(), variable_names=args.varnames_list, filename=f"{args.run_type}_full_domain_mvcorr_epoch_{epoch}.png", save_dir=paths.results, ) logger.info(f"Saved full domain multivariate correlation map to: {save_path}") # 9. Temporal series for full domain save_path = plot_temporal_series_comparison( predictions=predictions_full, targets=fine_full, # coarse_inputs=coarse_full, variable_names=args.varnames_list, filename=f"{args.run_type}_full_domain_temporal_series_epoch_{epoch}.png", save_dir=paths.results, ) logger.info(f"Saved full domain temporal series plot to: {save_path}") # 10. Multivariate spatial correlation time series for full domain save_path = plot_validation_mvcorr_space( predictions=predictions_full, targets=fine_full, coarse_inputs=coarse_full, variable_names=args.varnames_list, filename=f"{args.run_type}_full_domain_mvcorr_space_epoch_{epoch}.png", save_dir=paths.results, ) logger.info( f"Saved full domain multivariate spatial correlation time series to: {save_path}" ) return {"data": reconstructions, "metadata": metadata, "device": device}
[docs] def generate_residuals_norm( model, features, labels, targets, loss_fn, args, device, logger, epoch=0, batch_idx=0, inference_type="sampler", ): """ Generate normalized residuals for all variables. Parameters ---------- model : torch.nn.Module Diffusion model features : torch.Tensor Input feature tensor provided to the model labels : torch.Tensor Conditioning labels provided to the model targets : torch.Tensor Ground truth target tensor used for noise injection in direct inference loss_fn : callable Loss function args : argparse.Namespace Command line arguments device : torch.device Training device logger : Logger Logger instance epoch : int Current epoch number inference_type : str, optional Inference mode, either "direct" (deterministic) or "sampler" (stochastic diffusion sampling) Returns ------- torch.Tensor [B, C, H, W] residuals in normalized space """ # Generate samples for metrics calculation # Choose direct for rapid evaluation, sampler for full quality if inference_type == "direct": if args.debug: logger.info("Using direct inference/evaluation mode (deterministic)") if args.precond == "unet": # Direct prediction for unet generated_residuals = model(features, class_labels=labels) else: rnd_normal = torch.randn([targets.shape[0], 1, 1, 1], device=targets.device) sigma = (rnd_normal * loss_fn.P_std + loss_fn.P_mean).exp() noisy_targets = targets + torch.randn_like(targets) * sigma generated_residuals = model(noisy_targets, sigma, features, labels) elif inference_type == "sampler": if args.precond == "unet": raise ValueError("UNet does not support sampler inference") if args.debug and logger is not None: logger.info("Using sampler inference/evaluation mode (stochastic)") logger.info(f"Starting EDM sampler with {args.num_steps} steps") generated_residuals = sampler( epoch, batch_idx, model, features, labels, num_steps=args.num_steps, sigma_min=args.sigma_min, sigma_max=args.sigma_max, rho=args.rho, solver=args.solver, S_churn=args.s_churn, S_min=args.s_min, S_max=args.s_max, S_noise=args.s_noise, logger=logger, ) else: logger.error(f"Unknown inference_type: {inference_type}") raise return generated_residuals
[docs] def run_validation( model, valid_dataset, valid_loader, loss_fn, norm_mapping, normalization_type, index_mapping, args, steps, device, logger, epoch, writer=None, plot_every_n_epochs=None, paths=None, compute_crps=False, crps_batch_size=2, crps_ensemble_size=10, ): """ Run validation on the model. Parameters ---------- model : torch.nn.Module Diffusion model valid_loader : DataLoader Validation data loader loss_fn : callable Loss function norm_mapping : dict Normalization statistics normalization_type : EasyDict Normalization types for each variable args : argparse.Namespace Command line arguments device : torch.device Training device logger : Logger Logger instance epoch : int Current epoch number writer : SummaryWriter, optional TensorBoard writer plot_every_n_epochs : int, optional Frequency (in epochs) at which validation plots are generated paths : dict, optional Paths used for saving reconstructions and plots compute_crps : bool, optional Whether to compute CRPS using stochastic ensemble sampling crps_batch_size : int, optional Number of validation batches used for CRPS computation crps_ensemble_size : int, optional Number of ensemble members used to estimate CRPS Returns ------- tuple (avg_val_loss, val_metrics) - average validation loss and metrics dictionary """ # Define available metrics metric_names = [ "MAE", "NMAE", "RMSE", "R2", "PEARSON", "KL", ] # You can add more metrics here like ["MAE", "MSE", "RMSE"] metric_funcs = { "MAE": mae_all, "NMAE": nmae_all, "RMSE": rmse_all, "R2": r2_all, "PEARSON": pearson_all, "KL": kl_divergence_all, # You can add more metrics here: # "MSE": mse_all, } # Add CRPS only if requested if compute_crps: metric_names.append("CRPS") metric_funcs["CRPS"] = crps_ensemble_all # Separate deterministic metrics from CRPS. # CRPS is handled separately due to its stochastic and expensive nature. deterministic_metrics = [m for m in metric_names if m != "CRPS"] model.eval() val_loss = MetricTracker() # Create metrics for both model predictions and coarse baseline. # This is done in two steps because deterministic metrics (MAE, NMAE) # are computed for both model predictions and the coarse baseline, # whereas CRPS is a probabilistic metric and is only defined for # stochastic model outputs (no coarse vs fine CRPS). val_metrics = {} for k in args.varnames_list: for m in deterministic_metrics: val_metrics[f"{k}_pred_vs_fine_{m}"] = ( MetricTracker() ) # Model prediction vs true fine val_metrics[f"{k}_coarse_vs_fine_{m}"] = ( MetricTracker() ) # Coarse vs true fine (baseline) if compute_crps: val_metrics[f"{k}_pred_vs_fine_CRPS"] = MetricTracker() # Add average metrics across all variables for each metric type for m in deterministic_metrics: val_metrics[f"average_pred_vs_fine_{m}"] = MetricTracker() val_metrics[f"average_coarse_vs_fine_{m}"] = MetricTracker() if compute_crps: val_metrics["average_pred_vs_fine_CRPS"] = MetricTracker() all_data = {"predictions": [], "coarse": [], "fine": [], "lat": [], "lon": []} crps_batches = [] logger.info(f"Running validation for epoch {epoch}...") logger.info(f"EDM Sampler parameters: steps={args.num_steps}") with torch.no_grad(): val_loop = tqdm( enumerate(valid_loader), total=len(valid_loader), desc=f"Validation Epoch {epoch}", ) for batch_idx, batch in val_loop: # Move data to device features = batch["inputs"].to(device) targets = batch["targets"].to(device) coarse = batch["coarse"].to(device) # coarse_norm = batch["coarse_norm"].to(device) # Number of variables (channels) n_vars = len(args.varnames_list) # Extract normalized coarse field from model inputs coarse_norm = features[:, :n_vars] fine = batch["fine"].to(device) lat_batch = batch["corrdinates"]["lat"].to(device) lon_batch = batch["corrdinates"]["lon"].to(device) if epoch == 0 and batch_idx == 0: logger.info( f"Validation batch idx:{batch_idx}\n" f"features shape:{features.shape}, targets shape:{targets.shape}\n" f"coarse shape:{coarse.shape}, fine shape:{fine.shape}\n" f"lat shape:{lat_batch.shape}, lon shape:{lon_batch.shape}" ) # Prepare labels if args.time_normalization == "linear": labels = torch.stack( (batch["doy"].to(device), batch["hour"].to(device)), dim=1 ) elif args.time_normalization == "cos_sin": labels = torch.stack( ( batch["doy_sin"].to(device), batch["doy_cos"].to(device), batch["hour_sin"].to(device), batch["hour_cos"].to(device), ), dim=1, ) # Calculate validation loss with torch.amp.autocast(device_type=device.type, dtype=features.dtype): loss = loss_fn(model, targets, features, labels) # unet loss is a scalar, so no need for mean if args.precond != "unet": loss = loss.mean() val_loss.update(loss.item(), targets.shape[0]) # Store a limited number of batches for CRPS computation. # CRPS is expensive, so we only keep the first crps_batch_size batches # and reuse the existing features and labels. if compute_crps and len(crps_batches) < crps_batch_size: crps_batches.append( { "features": features, "labels": labels, "batch": batch, } ) # Track batch-level averages for overall metrics for each metric type batch_metric_sums = { m: {"pred": MetricTracker(), "coarse": MetricTracker()} for m in deterministic_metrics } generated_residual = generate_residuals_norm( model=model, features=features, labels=labels, targets=targets, loss_fn=loss_fn, args=args, device=device, logger=logger, epoch=epoch, batch_idx=batch_idx, inference_type=args.inference_type, ) batch_predictions = [] # Reconstruct final images for var_name in args.varnames_list: # Get the correct channel index for this variable iv = index_mapping[var_name] # Reconstruct final image: coarse + residual coarse_var_norm = coarse_norm[:, iv : iv + 1] final_prediction_norm = ( coarse_var_norm + generated_residual[:, iv : iv + 1] ) # Calculate metrics against ground truth fine data fine_var = fine[:, iv : iv + 1] coarse_var = coarse[:, iv : iv + 1] norm_type = normalization_type[var_name] if norm_type.startswith("log1p"): stats_fine = norm_mapping[f"{var_name}_fine_log"] else: stats_fine = norm_mapping[f"{var_name}_fine"] final_prediction = denormalize( final_prediction_norm, stats_fine, norm_type, device, var_name=var_name, data_type="fine", debug=args.debug, logger=logger, ) batch_predictions.append(final_prediction) # Calculate all metrics for this variable for metric_name in deterministic_metrics: metric_func = metric_funcs[metric_name] # Model prediction vs fine num_elements_pred, metric_value_pred = metric_func( final_prediction, fine_var ) val_metrics[f"{var_name}_pred_vs_fine_{metric_name}"].update( metric_value_pred.item(), num_elements_pred ) # Coarse vs fine (baseline metric) num_elements_coarse, metric_value_coarse = metric_func( coarse_var, fine_var ) val_metrics[f"{var_name}_coarse_vs_fine_{metric_name}"].update( metric_value_coarse.item(), num_elements_coarse ) # Accumulate for batch averages batch_metric_sums[metric_name]["pred"].update( metric_value_pred.item(), num_elements_pred ) batch_metric_sums[metric_name]["coarse"].update( metric_value_coarse.item(), num_elements_coarse ) final_prediction_batch = torch.cat(batch_predictions, dim=1) # [B, C, H, W] # Store only needed data for reconstruction # Validation outputs are accumulated and immediately moved to CPU # to avoid CUDA out-of-memory errors. all_data["predictions"].append(final_prediction_batch.detach().cpu()) all_data["coarse"].append(coarse.detach().cpu()) all_data["fine"].append(fine.detach().cpu()) all_data["lat"].append(lat_batch.detach().cpu()) # [B, H] all_data["lon"].append(lon_batch.detach().cpu()) # [B, W] # Update overall average metrics for this batch for each metric type for metric_name in deterministic_metrics: batch_avg_pred = batch_metric_sums[metric_name]["pred"].getmean() batch_avg_coarse = batch_metric_sums[metric_name]["coarse"].getmean() val_metrics[f"average_pred_vs_fine_{metric_name}"].update( batch_avg_pred, 1 ) val_metrics[f"average_coarse_vs_fine_{metric_name}"].update( batch_avg_coarse, 1 ) # Update progress bar (show first metric by default) primary_metric = deterministic_metrics[0] batch_avg_pred = batch_metric_sums[primary_metric]["pred"].getmean() batch_avg_coarse = batch_metric_sums[primary_metric]["coarse"].getmean() val_loop.set_postfix( { "Val Loss": f"{loss.item():.4f}", "Avg Val Loss": f"{val_loss.getmean():.4f}", f"Avg Pred {primary_metric}": f"{batch_avg_pred:.4f}", f"Avg Coarse {primary_metric}": f"{batch_avg_coarse:.4f}", } ) torch.cuda.empty_cache() avg_val_loss = val_loss.getmean() # To verify with Kazem # Compute CRPS only if requested and if some batches were collected. # CRPS is evaluated using an ensemble of stochastic sampler runs. if compute_crps and len(crps_batches) > 0: logger.info( "CRPS configuration summary:\n" f" └── Number of CRPS batches: {len(crps_batches)}\n" f" └── Ensemble size: {crps_ensemble_size}" ) for item in tqdm(crps_batches, desc="CRPS batches", total=len(crps_batches)): features = item["features"] labels = item["labels"] batch = item["batch"] # Generate an ensemble of predictions using the sampler ens_preds = [] for _ in tqdm(range(crps_ensemble_size), desc="CRPS ensemble", leave=False): generated_residual = generate_residuals_norm( model=model, features=features, labels=labels, targets=batch["targets"].to(device), loss_fn=loss_fn, args=args, device=device, logger=None, epoch=epoch, batch_idx=-1, # not tied to validation loop inference_type="sampler", ) # Reconstruct final prediction reconstructed_vars = [] # Extract normalized coarse field from inputs n_vars = len(args.varnames_list) coarse_norm = features[:, :n_vars] for var_name in args.varnames_list: iv = index_mapping[var_name] # coarse_var_norm = batch["coarse_norm"][:, iv:iv+1].to(device) coarse_var_norm = coarse_norm[:, iv : iv + 1] final_pred_norm = ( coarse_var_norm + generated_residual[:, iv : iv + 1] ) norm_type = normalization_type[var_name] if norm_type.startswith("log1p"): stats_fine = norm_mapping[f"{var_name}_fine_log"] else: stats_fine = norm_mapping[f"{var_name}_fine"] final_pred = denormalize( final_pred_norm, stats_fine, norm_type, device, var_name=var_name, data_type="fine", debug=args.debug, logger=logger, ) reconstructed_vars.append(final_pred) # Final reconstructed prediction for this ensemble member final_prediction = torch.cat(reconstructed_vars, dim=1) # [B, C, H, W] ens_preds.append(final_prediction) # Compute CRPS per variable pred_ens = torch.stack(ens_preds, dim=0) # [N_ens, B, C, H, W] for var_name in args.varnames_list: iv = index_mapping[var_name] pred_ens_var = pred_ens[:, :, iv : iv + 1, :, :] # [N_ens, B, 1, H, W] fine_var = batch["fine"][:, iv : iv + 1].to(device) pred_ens_flat = pred_ens_var.reshape(crps_ensemble_size, -1) true_flat = fine_var.reshape(-1) # Compute CRPS per variable using ensemble predictions. num_elem, crps_mean = crps_ensemble_all(pred_ens_flat, true_flat) # Update per-variable CRPS tracker val_metrics[f"{var_name}_pred_vs_fine_CRPS"].update( crps_mean.item(), num_elem ) # Global average CRPS tracker val_metrics["average_pred_vs_fine_CRPS"].update( crps_mean.item(), num_elem ) # Log validation results logger.info(f"Validation Epoch {epoch} - Average Loss: {avg_val_loss:.4f}") logger.info("=" * 60) logger.info("VALIDATION METRICS SUMMARY:") logger.info("=" * 60) # Log overall metrics for each metric type for metric_name in metric_names: if metric_name == "CRPS": # Log CRPS only when it has been computed to avoid empty MetricTracker access. if compute_crps: final_avg_pred = val_metrics["average_pred_vs_fine_CRPS"].getmean() std_avg_pred = val_metrics["average_pred_vs_fine_CRPS"].getstd() logger.info("OVERALL CRPS:") logger.info( f" └── Average Prediction vs Fine CRPS: {final_avg_pred:.5f} ± {std_avg_pred:.5f}" ) else: final_avg_pred = val_metrics[ f"average_pred_vs_fine_{metric_name}" ].getmean() final_avg_coarse = val_metrics[ f"average_coarse_vs_fine_{metric_name}" ].getmean() std_avg_pred = val_metrics[f"average_pred_vs_fine_{metric_name}"].getstd() std_avg_coarse = val_metrics[ f"average_coarse_vs_fine_{metric_name}" ].getstd() logger.info(f"OVERALL {metric_name} METRICS:") logger.info( f" └── Average Prediction vs Fine {metric_name}: {final_avg_pred:.4f} ± {std_avg_pred:.4f}" ) logger.info( f" └── Average Coarse vs Fine {metric_name}: {final_avg_coarse:.4f} ± {std_avg_coarse:.4f}" ) logger.info("") # Log per-variable metrics logger.info("PER-VARIABLE METRICS:") for var_name in args.varnames_list: logger.info(f" └── {var_name}:") for metric_name in metric_names: if metric_name == "CRPS": # Log CRPS only when it has been computed to avoid empty MetricTracker access. if compute_crps: crps_var = val_metrics[f"{var_name}_pred_vs_fine_CRPS"].getmean() crps_std = val_metrics[f"{var_name}_pred_vs_fine_CRPS"].getstd() logger.info(" └── CRPS:") logger.info( f" └── Model Pred vs Fine: {crps_var:.5f} ± {crps_std:.5f}" ) else: pred_metric = val_metrics[ f"{var_name}_pred_vs_fine_{metric_name}" ].getmean() pred_std = val_metrics[ f"{var_name}_pred_vs_fine_{metric_name}" ].getstd() coarse_metric = val_metrics[ f"{var_name}_coarse_vs_fine_{metric_name}" ].getmean() coarse_std = val_metrics[ f"{var_name}_coarse_vs_fine_{metric_name}" ].getstd() logger.info(f" └── {metric_name}:") logger.info( f" └── Model Pred vs Fine: {pred_metric:.4f} ± {pred_std:.4f}" ) logger.info( f" └── Coarse vs Fine: {coarse_metric:.4f} ± {coarse_std:.4f}" ) # To verify with Kazem # Global heatmap of validation metrics (per variable × metric) if paths is not None: try: heatmap_path = plot_metrics_heatmap( valid_metrics_history=val_metrics, variable_names=args.varnames_list, metric_names=metric_names, filename=f"{args.run_type}_validation_metrics_epoch_{epoch}", save_dir=paths.results, ) logger.info(f"Saved validation metrics heatmap to: {heatmap_path}") except Exception as e: logger.warning(f"Could not generate metrics heatmap: {e}") # Check if we should create plots for this batch should_plot = ( plot_every_n_epochs is not None and epoch % plot_every_n_epochs == 0 and paths is not None ) if should_plot: logger.info("Reconstructing and plots ...") _ = reconstruct_original_layout( epoch, args, paths, steps, all_data=all_data, dataset=valid_dataset, # device=device, # Keep on the same device --> OOM device=torch.device( "cpu" ), # reconstruction & plotting on CPU to avoid cuda out of memory logger=logger, # Pass the logger ) # Log to TensorBoard if writer is provided if writer is not None: writer.add_scalar("Loss/val_epoch", avg_val_loss, epoch) # Log overall metrics for each metric type for metric_name in metric_names: if metric_name == "CRPS": # Log CRPS only when it has been computed to avoid empty MetricTracker access. if compute_crps: final_avg_pred = val_metrics["average_pred_vs_fine_CRPS"].getmean() std_pred = val_metrics["average_pred_vs_fine_CRPS"].getstd() writer.add_scalar( "Metrics/average_pred_vs_fine_CRPS", final_avg_pred, epoch ) writer.add_scalar( "Metrics/average_pred_vs_fine_CRPS_std", std_pred, epoch ) else: final_avg_pred = val_metrics[ f"average_pred_vs_fine_{metric_name}" ].getmean() std_pred = val_metrics[f"average_pred_vs_fine_{metric_name}"].getstd() final_avg_coarse = val_metrics[ f"average_coarse_vs_fine_{metric_name}" ].getmean() std_coarse = val_metrics[ f"average_coarse_vs_fine_{metric_name}" ].getstd() writer.add_scalar( f"Metrics/average_pred_vs_fine_{metric_name}", final_avg_pred, epoch ) writer.add_scalar( f"Metrics/average_pred_vs_fine_{metric_name}_std", std_pred, epoch ) writer.add_scalar( f"Metrics/average_coarse_vs_fine_{metric_name}", final_avg_coarse, epoch, ) writer.add_scalar( f"Metrics/average_coarse_vs_fine_{metric_name}_std", std_coarse, epoch, ) # Log per-variable metrics for var_name in args.varnames_list: for metric_name in metric_names: if metric_name == "CRPS": # Log CRPS only when it has been computed to avoid empty MetricTracker access. if compute_crps: crps_var = val_metrics[ f"{var_name}_pred_vs_fine_CRPS" ].getmean() crps_var_std = val_metrics[ f"{var_name}_pred_vs_fine_CRPS" ].getstd() writer.add_scalar( f"Metrics/{var_name}_pred_vs_fine_CRPS", crps_var, epoch ) writer.add_scalar( f"Metrics/{var_name}_pred_vs_fine_CRPS_std", crps_var_std, epoch, ) else: pred_metric = val_metrics[ f"{var_name}_pred_vs_fine_{metric_name}" ].getmean() pred_metric_std = val_metrics[ f"{var_name}_pred_vs_fine_{metric_name}" ].getstd() coarse_metric = val_metrics[ f"{var_name}_coarse_vs_fine_{metric_name}" ].getmean() coarse_metric_std = val_metrics[ f"{var_name}_coarse_vs_fine_{metric_name}" ].getstd() writer.add_scalar( f"Metrics/{var_name}_pred_vs_fine_{metric_name}", pred_metric, epoch, ) writer.add_scalar( f"Metrics/{var_name}_pred_vs_fine_{metric_name}_std", pred_metric_std, epoch, ) writer.add_scalar( f"Metrics/{var_name}_coarse_vs_fine_{metric_name}", coarse_metric, epoch, ) writer.add_scalar( f"Metrics/{var_name}_coarse_vs_fine_{metric_name}_std", coarse_metric_std, epoch, ) return avg_val_loss, val_metrics