Source code for IPSL_AID.evaluater

# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Kazem Ardaneh, Kishanthan Kingston
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/

# ruff: noqa: E731
import numpy as np
import torch
import pandas as pd
from tqdm import tqdm
from IPSL_AID.diagnostics import (
    plot_validation_hexbin,
    plot_comparison_hexbin,
    plot_validation_pdfs,
    plot_power_spectra,
    plot_qq_quantiles,
    plot_surface,
    plot_zoom_comparison,
    plot_MAE_map,
    plot_error_map,
    plot_metrics_heatmap,
    plot_validation_mvcorr,
    plot_validation_mvcorr_space,
    plot_temporal_series_comparison,
)
import unittest
from unittest.mock import Mock, patch


class MetricTracker:
    """
    Accumulate weighted metric values and report summary statistics.

    Every call to :meth:`update` contributes ``value`` with weight ``count``.
    Three running accumulators are kept: the weighted sum of values, the
    weighted sum of squared values and the total weight. The weighted mean,
    weighted standard deviation and the square root of the weighted mean are
    derived from them on demand.

    Attributes
    ----------
    value : float
        Running weighted sum, ``sum(value_i * count_i)``.
    count : int
        Running total weight, ``sum(count_i)``.
    value_sq : float
        Running weighted sum of squares, ``sum(value_i**2 * count_i)``.

    Examples
    --------
    >>> tracker = MetricTracker()
    >>> tracker.update(10.0, 5)  # value=10.0 observed over 5 samples
    >>> tracker.update(20.0, 3)  # value=20.0 observed over 3 samples
    >>> print(tracker.getmean())  # (10*5 + 20*3) / (5+3) = 110/8
    13.75
    >>> print(tracker.getsqrtmean())  # sqrt(13.75)
    3.7080992435478315
    """

    def __init__(self):
        """Create a tracker whose accumulators all start at zero."""
        self.reset()

    def reset(self):
        """
        Set every accumulator back to zero.

        Returns
        -------
        None
        """
        self.value, self.count, self.value_sq = 0.0, 0, 0.0

    def update(self, value, count):
        """
        Fold a new metric value into the running accumulators.

        Parameters
        ----------
        value : float
            Metric value to accumulate.
        count : int
            Weight of ``value`` (typically the number of samples it covers).

        Returns
        -------
        None
        """
        self.count += count
        self.value += value * count
        self.value_sq += (value**2) * count

    def getmean(self):
        """
        Return the weighted mean ``value / count``.

        Returns
        -------
        float
            Weighted mean of every value seen so far.

        Raises
        ------
        ZeroDivisionError
            If nothing has been accumulated yet (``count == 0``).
        """
        if self.count == 0:
            raise ZeroDivisionError("Cannot compute mean with zero samples")
        return self.value / self.count

    def getstd(self):
        """
        Return the weighted standard deviation ``sqrt(E[x^2] - E[x]^2)``.

        Returns
        -------
        float
            Weighted standard deviation of every value seen so far.

        Raises
        ------
        ZeroDivisionError
            If nothing has been accumulated yet (``count == 0``).
        """
        if self.count == 0:
            raise ZeroDivisionError("Cannot compute std with zero samples")
        mean = self.getmean()
        second_moment = self.value_sq / self.count
        # Cancellation in E[x^2] - E[x]^2 can dip slightly below zero;
        # clamp before taking the root.
        return np.sqrt(max(second_moment - mean**2, 0.0))

    def getsqrtmean(self):
        """
        Return ``sqrt(value / count)``, the square root of the weighted mean.

        Useful when the accumulated metric is a squared error, so the result
        is an RMSE-style quantity.

        Returns
        -------
        float
            Square root of the weighted mean.

        Raises
        ------
        ZeroDivisionError
            If nothing has been accumulated yet (``count == 0``).
        """
        return np.sqrt(self.getmean())
def mae_all(pred, true):
    """
    Compute the Mean Absolute Error (MAE) between predictions and targets.

    Parameters
    ----------
    pred : torch.Tensor
        Model predictions.
    true : torch.Tensor
        Ground-truth values (combined with ``pred`` under torch broadcasting).

    Returns
    -------
    tuple
        ``(num_elements, mae_value)``:

        - num_elements (int): ``pred.numel()``, suitable as the weight for
          :class:`MetricTracker`.
        - mae_value (torch.Tensor): 0-d tensor, ``mean(abs(pred - true))``.

    Examples
    --------
    >>> pred = torch.tensor([1.0, 2.0, 3.0])
    >>> true = torch.tensor([1.1, 1.9, 3.2])
    >>> num_elements, mae = mae_all(pred, true)
    >>> print(f"MAE: {mae.item():.4f}, Elements: {num_elements}")
    MAE: 0.1333, Elements: 3
    """
    absolute_error = (pred - true).abs()
    return pred.numel(), absolute_error.mean()
def nmae_all(pred, true, eps=1e-8):
    """
    Compute the Normalized Mean Absolute Error (NMAE).

    NMAE = MAE(pred, true) / (mean(abs(true)) + eps)

    Parameters
    ----------
    pred : torch.Tensor
        Predicted values from the model.
    true : torch.Tensor
        Ground truth values (combined with ``pred`` under torch broadcasting).
    eps : float, optional
        Small value added to the normaliser to avoid division by zero.
        Default is 1e-8.

    Returns
    -------
    tuple
        ``(num_elements, nmae_value)`` where:

        - num_elements (int): ``pred.numel()``, suitable as the weight for
          :class:`MetricTracker`.
        - nmae_value (torch.Tensor): 0-d tensor holding the normalized mean
          absolute error.

    Examples
    --------
    >>> pred = torch.tensor([1.0, 2.0, 3.0])
    >>> true = torch.tensor([1.1, 1.9, 3.2])
    >>> num_elements, nmae = nmae_all(pred, true)
    >>> print(f"NMAE: {nmae.item():.4f}, Elements: {num_elements}")
    NMAE: 0.0645, Elements: 3

    Notes
    -----
    In the example, MAE = 0.4 / 3 and mean(abs(true)) = 6.2 / 3, so
    NMAE = 0.4 / 6.2 ≈ 0.0645.
    """
    num_elements = pred.numel()
    mae = torch.mean(torch.abs(pred - true))
    # eps keeps the ratio finite when the target is identically zero.
    norm = torch.mean(torch.abs(true)) + eps
    nmae = mae / norm
    return num_elements, nmae
# To verify with Kazem
def crps_ensemble_all(pred_ens, true):
    """
    Continuous Ranked Probability Score (CRPS) for an ensemble.

    Computes the CRPS of ensemble predictions at every point and returns the
    number of scored points together with the mean CRPS.

    Parameters
    ----------
    pred_ens : torch.Tensor
        Ensemble predictions with the members along dimension 0, shape
        ``[N_ens, *S]``, e.g. ``[N_ens, N_pixels]`` or ``[N_ens, C, H, W]``.
    true : torch.Tensor
        Ground truth values, shape ``S`` (i.e. ``pred_ens.shape[1:]``).

    Returns
    -------
    tuple
        (num_elements, crps_mean) where:

        - num_elements (int): Number of scored points (``prod(S)``)
        - crps_mean (torch.Tensor): Mean CRPS over all points (0-d)

    Notes
    -----
    The CRPS for an ensemble is computed as:

        CRPS = E|X - y| - 0.5 * E|X - X'|

    where X and X' are independent ensemble members and y is the observation.
    Both expectations use the empirical ensemble distribution (the spread term
    averages over all n**2 member pairs). With the members sorted,
    x_(1) <= ... <= x_(n), the spread term has the closed form

        0.5 * E|X - X'| = (1 / n**2) * sum_{k=1}^{n-1} k (n - k) (x_(k+1) - x_(k))

    which costs O(n log n) per point instead of O(n**2).

    Examples
    --------
    >>> pred_ens = torch.tensor([[1.0], [3.0]])
    >>> true = torch.tensor([2.0])
    >>> crps_ensemble_all(pred_ens, true)[1].item()
    0.5
    """
    # Number of ensemble members
    n = pred_ens.shape[0]

    # Sort members along the ensemble axis for the closed-form spread term
    pred_ens_sorted, _ = torch.sort(pred_ens, dim=0)

    # Term 1: E|X - y|
    term1 = torch.mean(torch.abs(pred_ens - true.unsqueeze(0)), dim=0)

    # Term 2: 0.5 * E|X - X'| from consecutive gaps of the sorted members
    diff = pred_ens_sorted[1:] - pred_ens_sorted[:-1]
    weight = torch.arange(1, n, device=pred_ens.device) * torch.arange(
        n - 1, 0, -1, device=pred_ens.device
    )
    # Shape the weights as [n - 1, 1, ..., 1] so they broadcast along the
    # ensemble axis for any number of trailing dimensions; a plain
    # unsqueeze(1) only aligns correctly for 2-D input.
    weight = weight.reshape((n - 1,) + (1,) * (pred_ens.dim() - 1))
    term2 = torch.sum(diff * weight, dim=0) / (n**2)

    crps_pixel = term1 - term2  # shape pred_ens.shape[1:]

    # Final aggregation
    num_elements = crps_pixel.numel()
    crps_mean = crps_pixel.mean()
    return num_elements, crps_mean
def rmse_all(pred, true):
    """
    Compute the Root Mean Square Error (RMSE) between predictions and targets.

    Parameters
    ----------
    pred : torch.Tensor
        Model predictions.
    true : torch.Tensor
        Ground-truth values (combined with ``pred`` under torch broadcasting).

    Returns
    -------
    tuple
        ``(num_elements, rmse_value)``:

        - num_elements (int): ``pred.numel()``, suitable as the weight for
          :class:`MetricTracker`.
        - rmse_value (torch.Tensor): 0-d tensor,
          ``sqrt(mean((pred - true)**2))``.

    Examples
    --------
    >>> pred = torch.tensor([1.0, 2.0, 3.0])
    >>> true = torch.tensor([1.1, 1.9, 3.2])
    >>> num_elements, rmse = rmse_all(pred, true)
    >>> print(f"RMSE: {rmse.item():.4f}, Elements: {num_elements}")
    RMSE: 0.1414, Elements: 3
    """
    residual = pred - true
    root_mean_square = residual.pow(2).mean().sqrt()
    return pred.numel(), root_mean_square
def r2_all(pred, true):
    """
    Compute the coefficient of determination (R2) of predictions vs targets.

    R2 = 1 - sum((true - pred)^2) / (sum((true - mean(true))^2) + 1e-12)

    The small constant in the denominator avoids a division by zero when the
    target is constant. Everything stays in torch, so CPU and GPU tensors are
    both supported.

    Parameters
    ----------
    pred : torch.Tensor
        Model predictions.
    true : torch.Tensor
        Ground-truth values; must have exactly the same shape as ``pred``.

    Returns
    -------
    tuple
        ``(num_elements, r2_value)``:

        - num_elements (int): ``pred.numel()``.
        - r2_value (torch.Tensor): 0-d tensor with the R2 score.

    Raises
    ------
    RuntimeError
        If ``pred`` and ``true`` have different shapes.
    """
    if pred.shape != true.shape:
        raise RuntimeError(f"Shape mismatch: pred {pred.shape} vs true {true.shape}")

    eps = 1e-12  # guards against zero variance in the target
    target = true.reshape(-1)
    estimate = pred.reshape(-1)

    residual_ss = ((target - estimate) ** 2).sum()
    total_ss = ((target - target.mean()) ** 2).sum()

    return pred.numel(), 1.0 - residual_ss / (total_ss + eps)
def pearson_all(pred, true):
    """
    Compute the Pearson correlation coefficient of predictions vs targets.

    rho = Cov(pred, true) / (std(pred) * std(true))

    Both tensors are flattened and the coefficient is read from the
    off-diagonal entry of ``torch.corrcoef`` applied to the stacked pair.

    Parameters
    ----------
    pred : torch.Tensor
        Model predictions.
    true : torch.Tensor
        Ground-truth values; must have exactly the same shape as ``pred``.

    Returns
    -------
    tuple
        ``(num_elements, pearson_value)``:

        - num_elements (int): ``pred.numel()``.
        - pearson_value (torch.Tensor): 0-d tensor with the correlation.

    Raises
    ------
    RuntimeError
        If ``pred`` and ``true`` have different shapes.
    """
    if pred.shape != true.shape:
        raise RuntimeError(f"Shape mismatch: {pred.shape} vs {true.shape}")

    # Row 0 = predictions, row 1 = truth; corrcoef expects variables as rows.
    paired_rows = torch.stack((pred.reshape(-1), true.reshape(-1)))
    correlation = torch.corrcoef(paired_rows)
    return pred.numel(), correlation[0, 1]
[docs] def kl_divergence_all(pred, true): """ Compute the Kullback–Leibler (KL) divergence between predicted and ground truth distributions using histogram-based estimation. Parameters ---------- pred : torch.Tensor Predicted values from the model. true : torch.Tensor Ground truth values. Returns ------- tuple (num_elements, kl_value) where: - num_elements (int): Total number of elements in the tensors. - kl_value (torch.Tensor): KL divergence value. Notes ----- The KL divergence is defined as: KL(P|Q) = sum_i P_i * log(P_i / Q_i) where: - P represents the true distribution - Q represents the predicted distribution """ if pred.shape != true.shape: raise RuntimeError(f"Shape mismatch: {pred.shape} vs {true.shape}") num_elements = pred.numel() n_bins = 100 eps = 1e-12 # Flatten tensors to 1D vectors pred_flat = pred.reshape(-1) true_flat = true.reshape(-1) # Combine for percentile computation all_values = torch.cat([pred_flat, true_flat]) # Percentile clipping data_min = torch.quantile(all_values, 0.0025) data_max = torch.quantile(all_values, 0.995) data_range = data_max - data_min x_min = data_min - 0.05 * data_range x_max = data_max + 0.05 * data_range hist_pred = torch.histc(pred_flat, bins=n_bins, min=x_min.item(), max=x_max.item()) hist_true = torch.histc(true_flat, bins=n_bins, min=x_min.item(), max=x_max.item()) # Add epsilon hist_pred = hist_pred + eps hist_true = hist_true + eps # Normalize to probability mass hist_pred = hist_pred / hist_pred.sum() hist_true = hist_true / hist_true.sum() # KL divergence kl_value = torch.sum(hist_true * torch.log(hist_true / hist_pred)) return num_elements, kl_value
def denormalize(
    data,
    stats,
    norm_type,
    device,
    var_name=None,
    data_type=None,
    debug=False,
    logger=None,
):
    """
    Denormalize a data tensor using the inverse of the normalization operation.

    Parameters
    ----------
    data : torch.Tensor
        Normalized tensor to denormalize.
    stats : object
        Object exposing the statistics needed by ``norm_type`` as attributes:
        ``vmin``/``vmax`` for the min-max variants, ``vmean``/``vstd`` for
        the standard variants, and ``median``/``iqr`` for ``"robust"``. Each
        statistic is converted with ``torch.tensor`` and compared to zero, so
        it is presumably a scalar.
    norm_type : str
        Normalization type used originally: ``"minmax"``, ``"minmax_11"``,
        ``"standard"``, ``"robust"``, ``"log1p_minmax"`` or
        ``"log1p_standard"``.
    device : torch.device
        Device on which the statistic tensors are created.
    var_name : str, optional
        Variable name for debugging.
    data_type : str, optional
        Data type for debugging (e.g., "residual", "coarse").
    debug : bool, optional
        Enable debug logging.
    logger : Logger, optional
        Logger instance for debug output.

    Returns
    -------
    torch.Tensor
        Tensor with the same shape as ``data`` in the original units:

        - ``"minmax"``: ``data * (vmax - vmin) + vmin``
        - ``"minmax_11"``: ``(data + 1) / 2 * (vmax - vmin) + vmin``
        - ``"standard"``: ``data * vstd + vmean``
        - ``"robust"``: ``data * iqr + median``
        - ``"log1p_minmax"``: ``expm1(data * (vmax - vmin) + vmin)``
        - ``"log1p_standard"``: ``expm1(data * vstd + vmean)``

        When the scale statistic is zero (a constant variable), the result is
        filled with the offset statistic (``vmin``, ``vmean`` or ``median``,
        passed through ``expm1`` for the log1p variants). That is the only
        value the variable can take, so the constant is recovered instead of
        being replaced by zeros.

    Raises
    ------
    ValueError
        If ``norm_type`` is not supported.
    """
    # Add debug logging at the start
    if debug and logger:
        # Create context string
        context = ""
        if var_name:
            context = f" for {var_name}"
        if data_type:
            context += f" ({data_type})"

        logger.info(
            f"Denormalizing{context} with type '{norm_type}'\n"
            f" └── Denormalization stats;\n"
            f" └── vmin: {getattr(stats, 'vmin', None)}\n"
            f" └── vmax: {getattr(stats, 'vmax', None)}\n"
            f" └── vmean: {getattr(stats, 'vmean', None)}\n"
            f" └── vstd: {getattr(stats, 'vstd', None)}\n"
            f" └── median: {getattr(stats, 'median', None)}\n"
            f" └── iqr: {getattr(stats, 'iqr', None)}\n"
            f" └── q1: {getattr(stats, 'q1', None)}\n"
            f" └── q3: {getattr(stats, 'q3', None)}"
        )

    def _as_tensor(value):
        # Match data's dtype so the arithmetic below keeps the output dtype.
        return torch.tensor(value, dtype=data.dtype, device=device)

    def _inverse_affine(normalized, scale, offset):
        # A zero scale means the variable was constant; its only possible
        # value is the offset, so fill with it rather than returning zeros.
        if scale == 0:
            return torch.zeros_like(normalized) + offset
        return normalized * scale + offset

    # ------------------ MIN-MAX ------------------
    if norm_type == "minmax":
        vmin = _as_tensor(stats.vmin)
        vmax = _as_tensor(stats.vmax)
        return _inverse_affine(data, vmax - vmin, vmin)

    # ------------------ MIN-MAX [-1, 1] -----------------
    if norm_type == "minmax_11":
        vmin = _as_tensor(stats.vmin)
        vmax = _as_tensor(stats.vmax)
        return _inverse_affine((data + 1) / 2, vmax - vmin, vmin)

    # ------------------ STANDARD -----------------
    if norm_type == "standard":
        mean = _as_tensor(stats.vmean)
        std = _as_tensor(stats.vstd)
        return _inverse_affine(data, std, mean)

    # ------------------ ROBUST -------------------
    if norm_type == "robust":
        median = _as_tensor(stats.median)
        iqr = _as_tensor(stats.iqr)
        return _inverse_affine(data, iqr, median)

    # ------------------ LOG1P + MIN-MAX ------------------
    if norm_type == "log1p_minmax":
        log_min = _as_tensor(stats.vmin)
        log_max = _as_tensor(stats.vmax)
        return torch.expm1(_inverse_affine(data, log_max - log_min, log_min))

    # ------------------ LOG1P + STANDARD ------------------
    if norm_type == "log1p_standard":
        mean = _as_tensor(stats.vmean)
        std = _as_tensor(stats.vstd)
        return torch.expm1(_inverse_affine(data, std, mean))

    raise ValueError(f"Unsupported norm_type '{norm_type}'")
@torch.no_grad()
def edm_sampler(
    model,
    image_input,
    class_labels=None,
    num_steps=40,
    sigma_min=0.02,
    sigma_max=80.0,
    rho=7,
    S_churn=40,
    S_min=0,
    S_max=float("inf"),
    S_noise=1,
):
    """
    EDM stochastic sampler (Euler step + 2nd-order Heun correction) for
    diffusion model inference.

    Original work: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.
    Original source: https://github.com/NVlabs/edm

    Parameters
    ----------
    model : torch.nn.Module
        Diffusion model. A ``torch.nn.DataParallel`` wrapper is unwrapped.
        The model must expose ``out_channels``, ``sigma_min``, ``sigma_max``
        and ``round_sigma``. It is called as
        ``model(x, sigma, image_input, class_labels)``.
    image_input : torch.Tensor
        Conditioning input (coarse + constants), unpacked as a 4-D
        ``[B, C_in, H, W]`` tensor.
    class_labels : torch.Tensor, optional
        Time conditioning labels, passed through to the model.
    num_steps : int, optional
        Number of sampling steps. Must be >= 2, because the discretization
        divides by ``num_steps - 1``.
    sigma_min : float, optional
        Minimum noise level (raised to ``model.sigma_min`` if lower).
    sigma_max : float, optional
        Maximum noise level (lowered to ``model.sigma_max`` if higher).
    rho : float, optional
        Time step exponent of the EDM discretization.
    S_churn : int, optional
        Stochasticity parameter. The per-step noise increase factor is
        ``gamma = min(S_churn / num_steps, sqrt(2) - 1)``.
    S_min : float, optional
        Churn is applied only when ``S_min <= t_cur <= S_max``.
    S_max : float, optional
        See ``S_min``.
    S_noise : float, optional
        Scale of the fresh noise injected during churn.

    Returns
    -------
    torch.Tensor
        Generated residual predictions of shape
        ``[B, model.out_channels, H, W]``, detached.

    Notes
    -----
    NOTE(review): the model output is cast to ``torch.float64``, so
    ``x_next`` is promoted to float64 after the first step. It is then fed
    back to the model as float64 and returned as float64. ``sampler`` casts
    to ``image_input.dtype`` instead; confirm the model accepts float64 input
    and that callers expect a float64 result.
    """
    batch_size, _, H, W = image_input.shape

    # Get the actual model (unwrap DataParallel if needed)
    if isinstance(model, torch.nn.DataParallel):
        model = model.module

    # init noise: unit Gaussian, scaled to sigma_max-ish by t_steps[0] below
    init_noise = torch.randn(
        (batch_size, model.out_channels, H, W),
        dtype=image_input.dtype,
        device=image_input.device,
    )

    # Adjust noise levels based on what's supported by the model.
    sigma_min = max(sigma_min, model.sigma_min)
    sigma_max = min(sigma_max, model.sigma_max)

    # Time step discretization: interpolate sigma^(1/rho) linearly from
    # sigma_max to sigma_min, then raise back to the power rho.
    step_indices = torch.arange(
        num_steps, dtype=image_input.dtype, device=image_input.device
    )
    t_steps = (
        sigma_max ** (1 / rho)
        + step_indices
        / (num_steps - 1)
        * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
    ) ** rho
    t_steps = torch.cat(
        [model.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]
    )  # t_N = 0

    # Main sampling loop.
    x_next = init_noise * t_steps[0]
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily (churn): raise the noise level from
        # t_cur to t_hat and add matching fresh noise.
        gamma = (
            min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        )
        t_hat = model.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * torch.randn_like(x_cur)

        # Euler step.
        denoised = model(x_hat, t_hat, image_input, class_labels).to(torch.float64)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction. Skipped on the last step, where
        # t_next is the appended zero and the division below would be by 0.
        if i < num_steps - 1:
            denoised = model(x_next, t_next, image_input, class_labels).to(
                torch.float64
            )
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

    return x_next.detach()


@torch.no_grad()
def sampler(
    epoch,
    batch_idx,
    model,
    image_input,
    class_labels=None,
    num_steps=18,
    sigma_min=None,
    sigma_max=None,
    rho=7,
    solver="heun",
    discretization="edm",
    schedule="linear",
    scaling="none",
    epsilon_s=1e-3,
    C_1=0.001,
    C_2=0.008,
    M=1000,
    alpha=1,
    S_churn=40,
    S_min=0,
    S_max=float("inf"),
    S_noise=1,
    logger=None,
):
    """
    General sampler for diffusion model inference with multiple configurations.

    Original work: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.
    Original source: https://github.com/NVlabs/edm

    Parameters
    ----------
    epoch : int
        Current epoch. Sampler parameters are logged only when ``logger`` is
        given and both ``epoch`` and ``batch_idx`` are 0.
    batch_idx : int
        Current batch index (see ``epoch``).
    model : torch.nn.Module
        Diffusion model. A ``torch.nn.DataParallel`` wrapper is unwrapped.
        The model must expose ``out_channels``, ``sigma_min``, ``sigma_max``
        and ``round_sigma``. It is called as
        ``model(x, sigma, image_input, class_labels)``.
    image_input : torch.Tensor
        Conditioning input (coarse + constants), unpacked as a 4-D
        ``[B, C_in, H, W]`` tensor.
    class_labels : torch.Tensor, optional
        Time conditioning labels, passed through to the model.
    num_steps : int, optional
        Number of sampling steps. Must be >= 2, because the discretizations
        divide by ``num_steps - 1``.
    sigma_min : float, optional
        Minimum noise level. If None, the default depends on
        ``discretization``: vp -> VP sigma at ``epsilon_s``, ve -> 0.02,
        iddpm/edm -> 0.002. It is then raised to ``model.sigma_min`` if lower.
    sigma_max : float, optional
        Maximum noise level. If None, the default depends on
        ``discretization``: vp -> VP sigma at t=1, ve -> 100, iddpm -> 81,
        edm -> 80. It is then lowered to ``model.sigma_max`` if higher.
    rho : float, optional
        Time step exponent for EDM discretization.
    solver : str, optional
        Solver type: 'euler' or 'heun'.
    discretization : str, optional
        Discretization type: 'vp', 've', 'iddpm', or 'edm'.
    schedule : str, optional
        Noise schedule: 'vp', 've', or 'linear'.
    scaling : str, optional
        Scaling type: 'vp' or 'none'.
    epsilon_s : float, optional
        Smallest time for the VP discretization/schedule.
    C_1 : float, optional
        Lower clip on the alpha_bar ratio in the IDDPM discretization.
    C_2 : float, optional
        Offset in the IDDPM cosine alpha_bar.
    M : int, optional
        Number of underlying steps for IDDPM discretization.
    alpha : float, optional
        Fraction of the step at which Heun's second evaluation is made
        (``t_prime = t_hat + alpha * h``).
    S_churn : int, optional
        Stochasticity parameter. The per-step noise increase factor is
        ``gamma = min(S_churn / num_steps, sqrt(2) - 1)``.
    S_min : float, optional
        Churn is applied only when ``S_min <= sigma(t_cur) <= S_max``.
    S_max : float, optional
        See ``S_min``.
    S_noise : float, optional
        Scale of the fresh noise injected during churn.
    logger : logging.Logger, optional
        Logger instance for logging sampler parameters.

    Returns
    -------
    torch.Tensor
        Generated residual predictions of shape
        ``[B, model.out_channels, H, W]``, detached.

    Raises
    ------
    AssertionError
        If ``solver``, ``discretization``, ``schedule`` or ``scaling`` is not
        one of the supported values. These checks use ``assert`` and are
        therefore skipped under ``python -O``.
    """
    # Validate the configuration strings (assert-based, kept from upstream EDM).
    assert solver in [
        "euler",
        "heun",
    ], f"Solver must be 'euler' or 'heun', but got '{solver}'"
    assert (
        discretization in ["vp", "ve", "iddpm", "edm"]
    ), f"Discretization must be 'vp', 've', 'iddpm' or 'edm', but got '{discretization}'"
    assert schedule in [
        "vp",
        "ve",
        "linear",
    ], f"Schedule must be 'vp', 've' or 'linear', but got '{schedule}'"
    assert scaling in [
        "vp",
        "none",
    ], f"Scaling must be 'vp' or 'none', but got '{scaling}'"

    batch_size, _, H, W = image_input.shape

    # Get the actual model (unwrap DataParallel if needed)
    if isinstance(model, torch.nn.DataParallel):
        model = model.module

    # Initialize noise
    latents = torch.randn(
        (batch_size, model.out_channels, H, W),
        dtype=image_input.dtype,
        device=image_input.device,
    )

    # Helper functions for VP & VE noise level schedules.
    vp_sigma = lambda beta_d, beta_min: (
        lambda t: (np.e ** (0.5 * beta_d * (t**2) + beta_min * t) - 1) ** 0.5
    )
    # NOTE: the inner lambda calls `sigma`, which resolves to the local
    # `sigma` assigned in "Define noise level schedule" below. It is only
    # used when schedule == "vp", where that `sigma` is the VP sigma.
    vp_sigma_deriv = lambda beta_d, beta_min: (
        lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t))
    )
    # Here `sigma` is the inner lambda's own parameter (a tensor of noise
    # levels), shadowing the local schedule function.
    vp_sigma_inv = lambda beta_d, beta_min: (
        lambda sigma: (
            ((beta_min**2 + 2 * beta_d * (sigma**2 + 1).log()).sqrt() - beta_min)
            / beta_d
        )
    )
    ve_sigma = lambda t: t.sqrt()
    ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
    ve_sigma_inv = lambda sigma: sigma**2

    # Select default noise level range based on the specified time step discretization.
    if sigma_min is None:
        vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=epsilon_s)
        sigma_min = {"vp": vp_def, "ve": 0.02, "iddpm": 0.002, "edm": 0.002}[
            discretization
        ]
    if sigma_max is None:
        vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=1)
        sigma_max = {"vp": vp_def, "ve": 100, "iddpm": 81, "edm": 80}[discretization]

    # Log sampler parameters once (first batch of epoch 0) if logger is provided
    if logger is not None and epoch == 0 and batch_idx == 0:
        logger.info("=== Sampler Parameters ===")
        logger.info(f" └── num_steps: {num_steps}")
        logger.info(f" └── solver: {solver}")
        logger.info(f" └── discretization: {discretization}")
        logger.info(f" └── schedule: {schedule}")
        logger.info(f" └── scaling: {scaling}")
        logger.info(f" └── sigma_min: {sigma_min}")
        logger.info(f" └── sigma_max: {sigma_max}")
        logger.info(f" └── rho: {rho}")
        logger.info(f" └── S_churn: {S_churn}")
        logger.info(f" └── S_min: {S_min}")
        logger.info(f" └── S_max: {S_max}")
        logger.info(f" └── S_noise: {S_noise}")
        logger.info(f" └── epsilon_s: {epsilon_s}")
        logger.info(f" └── C_1: {C_1}")
        logger.info(f" └── C_2: {C_2}")
        logger.info(f" └── M: {M}")
        logger.info(f" └── alpha: {alpha}")
        logger.info("==========================")

    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, model.sigma_min)
    sigma_max = min(sigma_max, model.sigma_max)

    # Compute corresponding betas for VP (computed unconditionally; only
    # used by the "vp" discretization/schedule).
    vp_beta_d = (
        2
        * (np.log(sigma_min**2 + 1) / epsilon_s - np.log(sigma_max**2 + 1))
        / (epsilon_s - 1)
    )
    vp_beta_min = np.log(sigma_max**2 + 1) - 0.5 * vp_beta_d

    # Define time steps in terms of noise level.
    step_indices = torch.arange(
        num_steps, dtype=image_input.dtype, device=image_input.device
    )
    if discretization == "vp":
        orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
        sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
    elif discretization == "ve":
        orig_t_steps = (sigma_max**2) * (
            (sigma_min**2 / sigma_max**2) ** (step_indices / (num_steps - 1))
        )
        sigma_steps = ve_sigma(orig_t_steps)
    elif discretization == "iddpm":
        # Build the M+1 IDDPM noise levels backwards from u[M] = 0, keep those
        # inside [sigma_min, sigma_max], then pick num_steps evenly spaced ones.
        u = torch.zeros(M + 1, dtype=image_input.dtype, device=image_input.device)
        alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
        for j in torch.arange(M, 0, -1, device=image_input.device):  # M, ..., 1
            u[j - 1] = (
                (u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1
            ).sqrt()
        u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
        sigma_steps = u_filtered[
            ((len(u_filtered) - 1) / (num_steps - 1) * step_indices)
            .round()
            .to(torch.int64)
        ]
    else:
        assert discretization == "edm"
        sigma_steps = (
            sigma_max ** (1 / rho)
            + step_indices
            / (num_steps - 1)
            * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
        ) ** rho

    # Define noise level schedule.
    if schedule == "vp":
        sigma = vp_sigma(vp_beta_d, vp_beta_min)
        sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
        sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
    elif schedule == "ve":
        sigma = ve_sigma
        sigma_deriv = ve_sigma_deriv
        sigma_inv = ve_sigma_inv
    else:
        assert schedule == "linear"
        sigma = lambda t: t
        sigma_deriv = lambda t: 1
        sigma_inv = lambda sigma: sigma

    # Define scaling schedule.
    if scaling == "vp":
        s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
        s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
    else:
        assert scaling == "none"
        s = lambda t: 1
        s_deriv = lambda t: 0

    # Compute final time steps based on the corresponding noise levels.
    t_steps = sigma_inv(model.round_sigma(sigma_steps))
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])  # t_N = 0

    # Main sampling loop.
    t_next = t_steps[0]  # only used to scale the initial noise
    x_next = latents.to(image_input.dtype) * (sigma(t_next) * s(t_next))
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily (churn).
        gamma = (
            min(S_churn / num_steps, np.sqrt(2) - 1)
            if S_min <= sigma(t_cur) <= S_max
            else 0
        )
        t_hat = sigma_inv(model.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
        x_hat = s(t_hat) / s(t_cur) * x_cur + (
            sigma(t_hat) ** 2 - sigma(t_cur) ** 2
        ).clip(min=0).sqrt() * s(t_hat) * S_noise * torch.randn_like(x_cur)

        # Euler step.
        h = t_next - t_hat
        denoised = model(x_hat / s(t_hat), sigma(t_hat), image_input, class_labels).to(
            image_input.dtype
        )
        d_cur = (
            sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)
        ) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised
        x_prime = x_hat + alpha * h * d_cur
        t_prime = t_hat + alpha * h

        # Apply 2nd order correction. The last step always uses plain Euler
        # (its t_next is the appended zero).
        if solver == "euler" or i == num_steps - 1:
            x_next = x_hat + h * d_cur
        else:
            assert solver == "heun"
            denoised = model(
                x_prime / s(t_prime), sigma(t_prime), image_input, class_labels
            ).to(image_input.dtype)
            d_prime = (
                sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)
            ) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised
            x_next = x_hat + h * (
                (1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime
            )

    return x_next.detach()
[docs] def reconstruct_original_layout( epoch, args, paths, steps, all_data, dataset, device, logger ): """ Robust reconstruction using dataset information directly. Parameters: ----------- all_data : dict Dictionary containing lists of batches for: - 'predictions': model predictions [B, C, H, W] - 'coarse': coarse resolution data [B, C, H, W] - 'fine': fine resolution ground truth [B, C, H, W] - 'lat': latitude coordinates [B, H] - 'lon': longitude coordinates [B, W] dataset : torch.utils.data.Dataset The validation dataset instance device : torch.device Device to store tensors on logger : Logger Logger instance for logging Returns: -------- dict: Reconstructed data with metadata """ # Get dataset parameters time_batchs = len(dataset.time_batchs) sbatch = dataset.sbatch total_dataset_samples = len(dataset) # time_batchs * sbatch # dataset_times = dataset.loaded_dfs.time.values # Get total samples from all batches total_batch_samples = sum(batch.shape[0] for batch in all_data["predictions"]) logger.info("Dataset reconstruction info:") logger.info(f" └── time_batchs: {time_batchs}") logger.info(f" └── sbatch: {sbatch}") logger.info(f" └── total dataset samples: {total_dataset_samples}") logger.info(f" └── total batch samples: {total_batch_samples}") # Handle different scenarios if total_batch_samples > total_dataset_samples: error_msg = ( f"More batch samples ({total_batch_samples}) than dataset samples ({total_dataset_samples})! " f"Something is wrong with the DataLoader." 
) logger.error(error_msg) raise elif total_batch_samples < total_dataset_samples: logger.info( f"Note: Batch samples ({total_batch_samples}) < dataset samples ({total_dataset_samples})" ) logger.info("This is normal if DataLoader has drop_last=True") # Get sample shape pred_shape = all_data["predictions"][0].shape[1:] # [C, H, W] C, H, W = pred_shape logger.info(f"Sample shape: C={C}, H={H}, W={W}") # Initialize reconstruction arrays reconstructions = {} for key in ["predictions", "coarse", "fine"]: reconstructions[key] = torch.zeros( time_batchs, sbatch, C, H, W, device=device, dtype=all_data[key][0].dtype ) logger.info(f"Initialized {key} with shape: {reconstructions[key].shape}") reconstructions["lat"] = torch.zeros( time_batchs, sbatch, H, device=device, dtype=all_data["lat"][0].dtype ) reconstructions["lon"] = torch.zeros( time_batchs, sbatch, W, device=device, dtype=all_data["lon"][0].dtype ) logger.info(f"Initialized lat with shape: {reconstructions['lat'].shape}") logger.info(f"Initialized lon with shape: {reconstructions['lon'].shape}") # Create position tracking position_filled = torch.zeros(time_batchs, sbatch, dtype=torch.bool, device=device) # Map each dataset index to position index_to_position = {} for idx in range(total_dataset_samples): sindex = idx % sbatch tindex = idx // sbatch index_to_position[idx] = (tindex, sindex) logger.info(f"Created index mapping for {total_dataset_samples} samples") # Reconstruct using dataset indices dataset_idx = 0 total_reconstructed = 0 logger.info("Starting reconstruction process...") for batch_idx in range(len(all_data["predictions"])): batch = all_data["predictions"][batch_idx] batch_size = batch.shape[0] logger.info( f"Processing batch {batch_idx+1}/{len(all_data['predictions'])} with size {batch_size}" ) for i_in_batch in range(batch_size): # We can only reconstruct up to dataset samples if dataset_idx >= total_dataset_samples: logger.warning( f"Stopping at dataset_idx {dataset_idx} (dataset has 
{total_dataset_samples} samples)" ) break tindex, sindex = index_to_position[dataset_idx] # Store all data for key in ["predictions", "coarse", "fine"]: reconstructions[key][tindex, sindex] = all_data[key][batch_idx][ i_in_batch ] reconstructions["lat"][tindex, sindex] = all_data["lat"][batch_idx][ i_in_batch ] reconstructions["lon"][tindex, sindex] = all_data["lon"][batch_idx][ i_in_batch ] position_filled[tindex, sindex] = True total_reconstructed += 1 dataset_idx += 1 # Free memory for this batch for key in ("predictions", "coarse", "fine", "lat", "lon"): all_data[key][batch_idx] = None # Break if we've reached dataset limit if dataset_idx >= total_dataset_samples: break logger.info(f"Successfully reconstructed {total_reconstructed} samples") # Check results filled_count = position_filled.sum().item() if filled_count != total_reconstructed: logger.warning( f"filled_count ({filled_count}) != total_reconstructed ({total_reconstructed})" ) if filled_count < total_dataset_samples: missing = total_dataset_samples - filled_count logger.info( f"Note: {missing}/{total_dataset_samples} samples not reconstructed" ) logger.info("This is expected with drop_last=True in DataLoader") # Metadata metadata = { "time_batchs": time_batchs, "sbatch": sbatch, "total_dataset_samples": total_dataset_samples, "total_batch_samples": total_batch_samples, "total_reconstructed": total_reconstructed, "position_filled": position_filled, "index_to_position": index_to_position, "filled_ratio": filled_count / total_dataset_samples if total_dataset_samples > 0 else 0, "reconstruction_device": str(device), } logger.info("Reconstruction completed successfully") # Check if we need to combine spatial blocks for inference if args.run_type in ["inference", "inference_regional"]: logger.info( "Inference mode is active - combining spatial blocks to reconstruct full domain..." 
) # Get evaluation slices directly from the DataPreprocessor if hasattr(dataset, "eval_slices"): eval_slices = dataset.eval_slices logger.info(f"Found {len(eval_slices)} evaluation slices") # Determine the spatial extent covered by evaluation slices. # In regional inference, slices may not start at index 0, so the domain size # is computed from the min/max slice indices. lat_min = min(s[0] for s in eval_slices) lat_max = max(s[1] for s in eval_slices) lon_min = min(s[2] for s in eval_slices) lon_max = max(s[3] for s in eval_slices) covered_H = lat_max - lat_min covered_W = lon_max - lon_min logger.info(f"Dataset dimensions: H={dataset.H}, W={dataset.W}") logger.info(f"Blocks cover: H={covered_H}, W={covered_W}") # Initialize coordinate arrays lat_reconstructed = torch.zeros(covered_H, device=device) lon_reconstructed = torch.zeros(covered_W, device=device) # Track which coordinates we've filled (must fill all!) lat_filled = torch.zeros(covered_H, dtype=torch.bool, device=device) lon_filled = torch.zeros(covered_W, dtype=torch.bool, device=device) # Initialize arrays for the COVERED area combined_data = {} for key in ["predictions", "coarse", "fine"]: combined_data[key] = torch.zeros( time_batchs, C, covered_H, covered_W, device=device, dtype=reconstructions[key].dtype, ) # Track grid coverage (must cover all!) coverage_mask = torch.zeros( covered_H, covered_W, dtype=torch.bool, device=device ) # Combine blocks and reconstruct coordinates blocks_placed = 0 for t in range(time_batchs): for spatial_idx, (lat_start, lat_end, lon_start, lon_end) in enumerate( eval_slices ): # Shift slice indices into the local reconstruction coordinate system. # This is required for regional inference where slices do not start at 0. # For global inference lat_min=lon_min=0 so indices remain unchanged. lat_start -= lat_min lat_end -= lat_min lon_start -= lon_min lon_end -= lon_min if spatial_idx >= sbatch: error_msg = ( f"CRITICAL ERROR: Slice index {spatial_idx} exceeds sbatch {sbatch}. 
" f"eval_slices has {len(eval_slices)} slices but only {sbatch} spatial blocks reconstructed." ) logger.error(error_msg) raise ValueError(error_msg) # Place block in combined array for key in ["predictions", "coarse", "fine"]: combined_data[key][ t, :, lat_start:lat_end, lon_start:lon_end ] = reconstructions[key][t, spatial_idx] # Reconstruct LATITUDE coordinates from this block block_lat = reconstructions["lat"][t, spatial_idx] # [H_block] lat_reconstructed[lat_start:lat_end] = block_lat lat_filled[lat_start:lat_end] = True # Reconstruct LONGITUDE coordinates from this block block_lon = reconstructions["lon"][t, spatial_idx] # [W_block] lon_reconstructed[lon_start:lon_end] = block_lon lon_filled[lon_start:lon_end] = True # Mark grid coverage coverage_mask[lat_start:lat_end, lon_start:lon_end] = True blocks_placed += 1 logger.info(f"Combined {blocks_placed} spatial blocks") # VERIFY COMPLETE COVERAGE - RAISE ERROR IF INCOMPLETE # Check latitude coordinate coverage lat_missing = (~lat_filled).sum().item() if lat_missing > 0: missing_indices = torch.nonzero(~lat_filled).squeeze().cpu().numpy() error_msg = ( f"CRITICAL ERROR: Latitude coordinate reconstruction incomplete!\n" f"Missing {lat_missing}/{covered_H} latitude coordinates.\n" f"Missing indices: {missing_indices[:10]}{'...' if len(missing_indices) > 10 else ''}\n" f"This indicates blocks don't cover the full latitude range." ) logger.error(error_msg) raise ValueError(error_msg) # Check longitude coordinate coverage lon_missing = (~lon_filled).sum().item() if lon_missing > 0: missing_indices = torch.nonzero(~lon_filled).squeeze().cpu().numpy() error_msg = ( f"CRITICAL ERROR: Longitude coordinate reconstruction incomplete!\n" f"Missing {lon_missing}/{covered_W} longitude coordinates.\n" f"Missing indices: {missing_indices[:10]}{'...' if len(missing_indices) > 10 else ''}\n" f"This indicates blocks don't cover the full longitude range." 
) logger.error(error_msg) raise ValueError(error_msg) # Check grid coverage uncovered_cells = (~coverage_mask).sum().item() if uncovered_cells > 0: # Find where coverage is missing missing_mask = ~coverage_mask missing_positions = torch.nonzero(missing_mask) error_msg = ( f"CRITICAL ERROR: Grid coverage incomplete!\n" f"Missing {uncovered_cells}/{covered_H*covered_W} grid cells.\n" f"Coverage: {coverage_mask.sum().item()/(covered_H*covered_W)*100:.1f}%\n" f"First 10 missing positions (lat, lon): {missing_positions[:10].cpu().numpy().tolist()}" ) logger.error(error_msg) raise ValueError(error_msg) # Fix longitude discontinuity when blocks cross the 0°/360° meridian. # np.unwrap keeps the longitude coordinate monotonic # Ex: 358,359,0,1 to 358,359,360,361 lon_reconstructed = torch.from_numpy( np.rad2deg(np.unwrap(np.deg2rad(lon_reconstructed.cpu().numpy()))) ).to(device) # All checks passed - reconstruction is complete logger.info("✅ Coordinate reconstruction complete") logger.info("✅ Grid coverage complete") logger.info( f"Latitude range: {lat_reconstructed.min():.2f} to {lat_reconstructed.max():.2f}" ) logger.info( f"Longitude range: {lon_reconstructed.min():.2f} to {lon_reconstructed.max():.2f}" ) # Add reconstruction info to metadata metadata["coverage_info"] = { "covered_H": covered_H, "covered_W": covered_W, "full_H": dataset.H, "full_W": dataset.W, "coverage_complete": True, "coordinates_complete": True, "lat_range": [ lat_reconstructed.min().item(), lat_reconstructed.max().item(), ], "lon_range": [ lon_reconstructed.min().item(), lon_reconstructed.max().item(), ], "lat_reconstructed": lat_reconstructed.cpu(), "lon_reconstructed": lon_reconstructed.cpu(), } # Store reconstructed coordinates in reconstructions dict reconstructions["lat_reconstructed"] = lat_reconstructed reconstructions["lon_reconstructed"] = lon_reconstructed # Add combined data to reconstructions dict reconstructions["combined"] = combined_data else: logger.error( "Could not find eval_slices 
in dataset. Cannot combine spatial blocks." ) raise AttributeError( "Dataset missing 'eval_slices' attribute for inference reconstruction." ) logger.info(f"Generating block wise plots for epoch {epoch}...") # Loop through spatial blocks for spatial_idx in range(sbatch): # Extract data for this spatial block # shape: [time_batchs, C, H, W] predictions_block = reconstructions["predictions"][:, spatial_idx] fine_block = reconstructions["fine"][:, spatial_idx] coarse_block = reconstructions["coarse"][:, spatial_idx] # lat_block = reconstructions['lat'][:, spatial_idx] # lon_block = reconstructions['lon'][:, spatial_idx] # 0. QQ Plot save_path = plot_qq_quantiles( predictions_block, # [time_batchs, C, H, W] fine_block, # [time_batchs, C, H, W] coarse_block, # [time_batchs, C, H, W] variable_names=args.varnames_list, units=None, # You might want to add units to args quantiles=[0.90, 0.95, 0.975, 0.99, 0.995], filename=f"{args.run_type}_qq_epoch_{epoch}_spatial_block_{spatial_idx:03d}.png", save_dir=paths.results, ) logger.info(f"Saved QQ plot to {save_path}") # 1. Validation Hexbin Plot save_path = plot_validation_hexbin( predictions=predictions_block, targets=fine_block, variable_names=args.varnames_list, filename=f"{args.run_type}_validation_hexbin_epoch_{epoch}_sblock_{spatial_idx:03d}.png", save_dir=paths.results, ) logger.info(f"Saved validation hexbin plot to: {save_path}") # 2. Comparison Hexbin Plot save_path = plot_comparison_hexbin( predictions=predictions_block, targets=fine_block, coarse_inputs=coarse_block, variable_names=args.varnames_list, filename=f"{args.run_type}_comparison_hexbin_epoch_{epoch}_sblock_{spatial_idx:03d}.png", save_dir=paths.results, ) logger.info(f"Saved comparison hexbin plot to: {save_path}") # 3. 
Validation PDFs Plot save_path = plot_validation_pdfs( predictions=predictions_block, targets=fine_block, coarse_inputs=coarse_block, variable_names=args.varnames_list, filename=f"{args.run_type}_validation_pdfs_epoch_{epoch}_sblock_{spatial_idx:03d}.png", save_dir=paths.results, ) logger.info(f"Saved validation PDFs plot to: {save_path}") # 4. Power Spectra Plot dlon = getattr(steps, "d_longitude", None) dlat = getattr(steps, "d_latitude", None) assert dlon is not None, "d_longitude not found in steps" assert dlat is not None, "d_latitude not found in steps" save_path = plot_power_spectra( predictions=predictions_block, targets=fine_block, coarse_inputs=coarse_block, dlat=dlat, dlon=dlon, variable_names=args.varnames_list, filename=f"{args.run_type}_power_spectra_epoch_{epoch}_sblock_{spatial_idx:03d}.png", save_dir=paths.results, ) logger.info(f"Saved power spectra plot to: {save_path}") # 5. MAE map plot (time-averaged) # Latitude and longitude coordinates for this spatial block. # Coordinates are time-invariant, so we take them from the first time index (t = 0). first_time_idx = 0 # Get coordinates for this spatial block lat_block = reconstructions["lat"][first_time_idx, spatial_idx] # [H] lon_block = reconstructions["lon"][first_time_idx, spatial_idx] # [W] save_path = plot_MAE_map( predictions=predictions_block, # [T, C, H, W] targets=fine_block, # [T, C, H, W] lat_1d=lat_block, # [H] lon_1d=lon_block, # [W] variable_names=args.varnames_list, filename=f"{args.run_type}_mae_map_epoch_{epoch}_sblock_{spatial_idx:03d}.png", save_dir=paths.results, ) logger.info(f"Saved MAE map to: {save_path}") # 6. 
Multivariate Correlation Maps # Convert 1D lat/lon to 2D meshgrid lat_2d, lon_2d = torch.meshgrid(lat_block, lon_block, indexing="ij") save_path = plot_validation_mvcorr( predictions=predictions_block, # [T, C, H, W] targets=fine_block, # [T, C, H, W] coarse_inputs=coarse_block, # optional lat=lat_2d.numpy(), lon=lon_2d.numpy(), variable_names=args.varnames_list, filename=f"{args.run_type}_mvcorr_epoch_{epoch}_sblock_{spatial_idx:03d}.png", save_dir=paths.results, ) logger.info(f"Saved multivariate correlation map to: {save_path}") # 7. Surface plot coarse = reconstructions["coarse"][ first_time_idx : first_time_idx + 1, spatial_idx ] fine = reconstructions["fine"][first_time_idx : first_time_idx + 1, spatial_idx] pred = reconstructions["predictions"][ first_time_idx : first_time_idx + 1, spatial_idx ] save_path = plot_surface( predictions=pred, targets=fine, coarse_inputs=coarse, lat_1d=lat_block, lon_1d=lon_block, variable_names=args.varnames_list, filename=f"{args.run_type}_plot_surface_epoch_{epoch}_sblock_{spatial_idx:03d}.png", save_dir=paths.results, ) logger.info(f"Saved surface plot to: {save_path}") # 8. Temporal series save_path = plot_temporal_series_comparison( predictions=predictions_block, targets=fine_block, # coarse_inputs=coarse_block, variable_names=args.varnames_list, filename=f"{args.run_type}_temporal_series_epoch_{epoch}_sblock_{spatial_idx:03d}.png", save_dir=paths.results, ) logger.info(f"Saved temporal series plot to: {save_path}") # 9. 
Multivariate spatial correlation time series save_path = plot_validation_mvcorr_space( predictions=predictions_block, targets=fine_block, coarse_inputs=coarse_block, variable_names=args.varnames_list, filename=f"{args.run_type}_mvcorr_space_epoch_{epoch}_sblock_{spatial_idx:03d}.png", save_dir=paths.results, ) logger.info( f"Saved multivariate spatial correlation time series to: {save_path}" ) # For inference mode, also generate full domain plots if args.run_type in ["inference", "inference_regional"]: assert ( "combined" in reconstructions ), "Combined data not found in reconstructions for inference mode" logger.info( f"Generating full domain plots for inference mode, epoch {epoch}..." ) # Get combined data for full domain predictions_full = reconstructions["combined"][ "predictions" ] # [time_batchs, C, covered_H, covered_W] fine_full = reconstructions["combined"][ "fine" ] # [time_batchs, C, covered_H, covered_W] coarse_full = reconstructions["combined"][ "coarse" ] # [time_batchs, C, covered_H, covered_W] lat_full = reconstructions["lat_reconstructed"] # [covered_H] lon_full = reconstructions["lon_reconstructed"] # [covered_W] # Generate full domain versions of all plots # 0. QQ Plot for full domain (averaged over space) save_path = plot_qq_quantiles( predictions_full, # [time_batchs, C, H, W] fine_full, # [time_batchs, C, H, W] coarse_full, # [time_batchs, C, H, W] variable_names=args.varnames_list, units=None, quantiles=[0.90, 0.95, 0.975, 0.99, 0.995], filename=f"{args.run_type}_full_domain_qq_epoch_{epoch}.png", save_dir=paths.results, save_npz=True, ) logger.info(f"Saved full domain QQ plot to {save_path}") # 1. Validation Hexbin Plot for full domain save_path = plot_validation_hexbin( predictions=predictions_full, targets=fine_full, variable_names=args.varnames_list, filename=f"{args.run_type}_full_domain_validation_hexbin_epoch_{epoch}.png", save_dir=paths.results, ) logger.info(f"Saved full domain validation hexbin plot to: {save_path}") # 2. 
Comparison Hexbin Plot for full domain save_path = plot_comparison_hexbin( predictions=predictions_full, targets=fine_full, coarse_inputs=coarse_full, variable_names=args.varnames_list, filename=f"{args.run_type}_full_domain_comparison_hexbin_epoch_{epoch}.png", save_dir=paths.results, ) logger.info(f"Saved full domain comparison hexbin plot to: {save_path}") # 3. Validation PDFs Plot for full domain save_path = plot_validation_pdfs( predictions=predictions_full, targets=fine_full, coarse_inputs=coarse_full, variable_names=args.varnames_list, filename=f"{args.run_type}_full_domain_validation_pdfs_epoch_{epoch}.png", save_dir=paths.results, save_npz=True, ) logger.info(f"Saved full domain validation PDFs plot to: {save_path}") # 4. Power Spectra Plot for full domain dlon = getattr(steps, "d_longitude", None) dlat = getattr(steps, "d_latitude", None) assert dlon is not None, "d_longitude not found in steps" assert dlat is not None, "d_latitude not found in steps" save_path = plot_power_spectra( predictions=predictions_full, targets=fine_full, coarse_inputs=coarse_full, dlat=dlat, dlon=dlon, variable_names=args.varnames_list, filename=f"{args.run_type}_full_domain_power_spectra_epoch_{epoch}.png", save_dir=paths.results, save_npz=True, ) logger.info(f"Saved full domain power spectra plot to: {save_path}") # 5. MAE map Plot for full domain save_path = plot_MAE_map( predictions=predictions_full, targets=fine_full, lat_1d=lat_full, lon_1d=lon_full, variable_names=args.varnames_list, filename=f"{args.run_type}_full_domain_mae_map_epoch_{epoch}.png", save_dir=paths.results, ) logger.info(f"Saved full domain MAE map to: {save_path}") # 6. Error map Plot for full domain save_path = plot_error_map( predictions=predictions_full, targets=fine_full, lat_1d=lat_full, lon_1d=lon_full, variable_names=args.varnames_list, filename=f"{args.run_type}_full_domain_error_map_epoch_{epoch}.png", save_dir=paths.results, ) logger.info(f"Saved full domain error map to: {save_path}") # 7. 
Surface plots for first few time steps of full domain num_time_steps_to_plot = min(3, time_batchs) for time_idx in range(num_time_steps_to_plot): # Extract single time step pred_single_time = predictions_full[time_idx : time_idx + 1] # [1, C, H, W] fine_single_time = fine_full[time_idx : time_idx + 1] # [1, C, H, W] coarse_single_time = coarse_full[time_idx : time_idx + 1] # [1, C, H, W] tindex = dataset.time_batchs[time_idx] timestamp = pd.to_datetime( dataset.loaded_dfs.time.values[tindex] ).to_pydatetime() save_path = plot_surface( predictions=pred_single_time, targets=fine_single_time, coarse_inputs=coarse_single_time, lat_1d=lat_full, lon_1d=lon_full, timestamp=timestamp, variable_names=args.varnames_list, filename=f"{args.run_type}_full_domain_surface_epoch_{epoch}_time_{time_idx:03d}.png", save_dir=paths.results, ) logger.info( f"Saved full domain surface plot (time {time_idx}) to: {save_path}" ) # Zoom comparison plot for full domain (only for global inference) if args.run_type == "inference": save_path = plot_zoom_comparison( predictions=pred_single_time, targets=fine_single_time, lat_1d=lat_full, lon_1d=lon_full, variable_names=args.varnames_list, filename=f"{args.run_type}_full_domain_zoom_comparison_epoch_{epoch}_time_{time_idx:03d}.png", save_dir=paths.results, ) logger.info( f"Saved full domain zoom comparison (time {time_idx}) to: {save_path}" ) # 8. Multivariate Correlation Maps for full domain # Convert 1D lat/lon to 2D meshgrid lat_2d_full, lon_2d_full = torch.meshgrid(lat_full, lon_full, indexing="ij") save_path = plot_validation_mvcorr( predictions=predictions_full, # [T, C, H, W] targets=fine_full, # [T, C, H, W] coarse_inputs=coarse_full, lat=lat_2d_full.numpy(), lon=lon_2d_full.numpy(), variable_names=args.varnames_list, filename=f"{args.run_type}_full_domain_mvcorr_epoch_{epoch}.png", save_dir=paths.results, ) logger.info(f"Saved full domain multivariate correlation map to: {save_path}") # 9. 
Temporal series for full domain save_path = plot_temporal_series_comparison( predictions=predictions_full, targets=fine_full, # coarse_inputs=coarse_full, variable_names=args.varnames_list, filename=f"{args.run_type}_full_domain_temporal_series_epoch_{epoch}.png", save_dir=paths.results, ) logger.info(f"Saved full domain temporal series plot to: {save_path}") # 10. Multivariate spatial correlation time series for full domain save_path = plot_validation_mvcorr_space( predictions=predictions_full, targets=fine_full, coarse_inputs=coarse_full, variable_names=args.varnames_list, filename=f"{args.run_type}_full_domain_mvcorr_space_epoch_{epoch}.png", save_dir=paths.results, ) logger.info( f"Saved full domain multivariate spatial correlation time series to: {save_path}" ) return {"data": reconstructions, "metadata": metadata, "device": device}
def generate_residuals_norm(
    model,
    features,
    labels,
    targets,
    loss_fn,
    args,
    device,
    logger,
    epoch=0,
    batch_idx=0,
    inference_type="sampler",
):
    """
    Generate normalized residuals for all variables.

    Parameters
    ----------
    model : torch.nn.Module
        Diffusion model (or a plain UNet when ``args.precond == "unet"``)
    features : torch.Tensor
        Input feature tensor provided to the model
    labels : torch.Tensor
        Conditioning labels provided to the model
    targets : torch.Tensor
        Ground truth target tensor. Only used by the non-UNet "direct" path,
        where it is perturbed with random noise before being passed to the
        model.
    loss_fn : callable
        Loss function. The non-UNet "direct" path reads its ``P_mean`` and
        ``P_std`` attributes to draw the per-sample noise level.
    args : argparse.Namespace
        Command line arguments. Reads ``debug`` and ``precond``; the sampler
        path additionally reads ``num_steps``, ``sigma_min``, ``sigma_max``,
        ``rho``, ``solver``, ``s_churn``, ``s_min``, ``s_max``, ``s_noise``.
    device : torch.device
        Training device (not used by this function; kept for API
        compatibility)
    logger : Logger or None
        Logger instance. May be ``None``, in which case nothing is logged.
    epoch : int, optional
        Current epoch number, forwarded to the sampler
    batch_idx : int, optional
        Batch index, forwarded to the sampler
    inference_type : str, optional
        ``"direct"`` for a single forward pass (rapid evaluation) or
        ``"sampler"`` for full stochastic diffusion sampling.

    Returns
    -------
    torch.Tensor
        [B, C, H, W] residuals in normalized space

    Raises
    ------
    ValueError
        If ``inference_type`` is unknown, or if sampler inference is
        requested for a UNet preconditioner.
    """
    # Choose direct for rapid evaluation, sampler for full quality.
    if inference_type == "direct":
        if args.debug and logger is not None:
            logger.info("Using direct inference/evaluation mode (deterministic)")
        if args.precond == "unet":
            # The UNet predicts the residual directly from its inputs.
            return model(features, class_labels=labels)
        # Single denoising step: draw a per-sample noise level
        # sigma = exp(N(P_mean, P_std^2)), add noise of that scale to the
        # targets and let the model denoise them once.
        rnd_normal = torch.randn([targets.shape[0], 1, 1, 1], device=targets.device)
        sigma = (rnd_normal * loss_fn.P_std + loss_fn.P_mean).exp()
        noisy_targets = targets + torch.randn_like(targets) * sigma
        return model(noisy_targets, sigma, features, labels)

    if inference_type == "sampler":
        if args.precond == "unet":
            raise ValueError("UNet does not support sampler inference")
        if args.debug and logger is not None:
            logger.info("Using sampler inference/evaluation mode (stochastic)")
            logger.info(f"Starting EDM sampler with {args.num_steps} steps")
        return sampler(
            epoch,
            batch_idx,
            model,
            features,
            labels,
            num_steps=args.num_steps,
            sigma_min=args.sigma_min,
            sigma_max=args.sigma_max,
            rho=args.rho,
            solver=args.solver,
            S_churn=args.s_churn,
            S_min=args.s_min,
            S_max=args.s_max,
            S_noise=args.s_noise,
            logger=logger,
        )

    # A bare `raise` here (outside an except block) used to surface as an
    # unrelated "No active exception to reraise" RuntimeError.
    if logger is not None:
        logger.error(f"Unknown inference_type: {inference_type}")
    raise ValueError(f"Unknown inference_type: {inference_type}")
def _build_time_labels(batch, args, device):
    """
    Build the conditioning label tensor from the batch time encodings.

    Parameters
    ----------
    batch : dict
        Batch dictionary from the validation loader
    args : argparse.Namespace
        Must provide ``time_normalization`` (``"linear"`` or ``"cos_sin"``)
    device : torch.device
        Device the labels are moved to

    Returns
    -------
    torch.Tensor
        Time encodings stacked along ``dim=1``

    Raises
    ------
    ValueError
        If ``args.time_normalization`` is not a supported mode (previously
        this left ``labels`` unbound and failed with UnboundLocalError).
    """
    if args.time_normalization == "linear":
        keys = ("doy", "hour")
    elif args.time_normalization == "cos_sin":
        keys = ("doy_sin", "doy_cos", "hour_sin", "hour_cos")
    else:
        raise ValueError(
            f"Unknown time_normalization: {args.time_normalization!r} "
            "(expected 'linear' or 'cos_sin')"
        )
    return torch.stack([batch[key].to(device) for key in keys], dim=1)


def _denormalize_predictions(
    coarse_norm,
    generated_residual,
    args,
    index_mapping,
    normalization_type,
    norm_mapping,
    device,
    logger,
):
    """
    Rebuild per-variable predictions (normalized coarse + residual) and
    denormalize them with the fine-resolution statistics.

    Returns
    -------
    list of torch.Tensor
        One channel-sliced tensor (``[:, iv:iv+1]``) per entry of
        ``args.varnames_list``, in that order.
    """
    predictions = []
    for var_name in args.varnames_list:
        iv = index_mapping[var_name]
        final_prediction_norm = (
            coarse_norm[:, iv : iv + 1] + generated_residual[:, iv : iv + 1]
        )
        norm_type = normalization_type[var_name]
        # log1p-normalized variables keep their statistics under a "_log" key.
        if norm_type.startswith("log1p"):
            stats_fine = norm_mapping[f"{var_name}_fine_log"]
        else:
            stats_fine = norm_mapping[f"{var_name}_fine"]
        predictions.append(
            denormalize(
                final_prediction_norm,
                stats_fine,
                norm_type,
                device,
                var_name=var_name,
                data_type="fine",
                debug=args.debug,
                logger=logger,
            )
        )
    return predictions


def _log_validation_summary(
    logger, val_metrics, metric_names, variable_names, epoch, avg_val_loss, compute_crps
):
    """Log the average loss plus overall and per-variable metrics (mean ± std)."""
    logger.info(f"Validation Epoch {epoch} - Average Loss: {avg_val_loss:.4f}")
    logger.info("=" * 60)
    logger.info("VALIDATION METRICS SUMMARY:")
    logger.info("=" * 60)

    # Overall metrics for each metric type
    for metric_name in metric_names:
        if metric_name == "CRPS":
            # CRPS trackers only exist when CRPS was requested.
            if compute_crps:
                tracker = val_metrics["average_pred_vs_fine_CRPS"]
                logger.info("OVERALL CRPS:")
                logger.info(
                    f" └── Average Prediction vs Fine CRPS: "
                    f"{tracker.getmean():.5f} ± {tracker.getstd():.5f}"
                )
            continue
        pred = val_metrics[f"average_pred_vs_fine_{metric_name}"]
        coarse = val_metrics[f"average_coarse_vs_fine_{metric_name}"]
        logger.info(f"OVERALL {metric_name} METRICS:")
        logger.info(
            f" └── Average Prediction vs Fine {metric_name}: "
            f"{pred.getmean():.4f} ± {pred.getstd():.4f}"
        )
        logger.info(
            f" └── Average Coarse vs Fine {metric_name}: "
            f"{coarse.getmean():.4f} ± {coarse.getstd():.4f}"
        )

    logger.info("")
    # Per-variable metrics
    logger.info("PER-VARIABLE METRICS:")
    for var_name in variable_names:
        logger.info(f" └── {var_name}:")
        for metric_name in metric_names:
            if metric_name == "CRPS":
                if compute_crps:
                    tracker = val_metrics[f"{var_name}_pred_vs_fine_CRPS"]
                    logger.info(" └── CRPS:")
                    logger.info(
                        f" └── Model Pred vs Fine: "
                        f"{tracker.getmean():.5f} ± {tracker.getstd():.5f}"
                    )
                continue
            pred = val_metrics[f"{var_name}_pred_vs_fine_{metric_name}"]
            coarse = val_metrics[f"{var_name}_coarse_vs_fine_{metric_name}"]
            logger.info(f" └── {metric_name}:")
            logger.info(
                f" └── Model Pred vs Fine: "
                f"{pred.getmean():.4f} ± {pred.getstd():.4f}"
            )
            logger.info(
                f" └── Coarse vs Fine: "
                f"{coarse.getmean():.4f} ± {coarse.getstd():.4f}"
            )


def _write_validation_tensorboard(
    writer, val_metrics, metric_names, variable_names, epoch, avg_val_loss, compute_crps
):
    """
    Write the validation loss and every metric tracker to TensorBoard.

    Each tracker ``key`` is written as ``Metrics/{key}`` (mean) followed by
    ``Metrics/{key}_std`` (standard deviation); overall averages come first,
    then per-variable values.
    """

    def add_mean_std(key):
        tracker = val_metrics[key]
        writer.add_scalar(f"Metrics/{key}", tracker.getmean(), epoch)
        writer.add_scalar(f"Metrics/{key}_std", tracker.getstd(), epoch)

    def sources(metric_name):
        # CRPS is only defined for model predictions (no coarse baseline).
        return ("pred",) if metric_name == "CRPS" else ("pred", "coarse")

    writer.add_scalar("Loss/val_epoch", avg_val_loss, epoch)
    active_metrics = [m for m in metric_names if m != "CRPS" or compute_crps]
    for metric_name in active_metrics:
        for src in sources(metric_name):
            add_mean_std(f"average_{src}_vs_fine_{metric_name}")
    for var_name in variable_names:
        for metric_name in active_metrics:
            for src in sources(metric_name):
                add_mean_std(f"{var_name}_{src}_vs_fine_{metric_name}")


def run_validation(
    model,
    valid_dataset,
    valid_loader,
    loss_fn,
    norm_mapping,
    normalization_type,
    index_mapping,
    args,
    steps,
    device,
    logger,
    epoch,
    writer=None,
    plot_every_n_epochs=None,
    paths=None,
    compute_crps=False,
    crps_batch_size=2,
    crps_ensemble_size=10,
):
    """
    Run validation on the model.

    Parameters
    ----------
    model : torch.nn.Module
        Diffusion model
    valid_dataset : Dataset
        Validation dataset, passed to ``reconstruct_original_layout`` when
        plots are generated
    valid_loader : DataLoader
        Validation data loader
    loss_fn : callable
        Loss function
    norm_mapping : dict
        Normalization statistics
    normalization_type : EasyDict
        Normalization types for each variable
    index_mapping : dict
        Maps each variable name to its channel index in the model tensors
    args : argparse.Namespace
        Command line arguments
    steps : object
        Forwarded to ``reconstruct_original_layout`` (which reads its
        ``d_longitude`` / ``d_latitude`` attributes); only used for plotting
    device : torch.device
        Training device
    logger : Logger
        Logger instance
    epoch : int
        Current epoch number
    writer : SummaryWriter, optional
        TensorBoard writer
    plot_every_n_epochs : int, optional
        Frequency (in epochs) at which validation plots are generated.
        ``None`` or ``0`` disables plotting.
    paths : object, optional
        Paths object; its ``results`` attribute is the output directory for
        reconstructions and plots
    compute_crps : bool, optional
        Whether to compute CRPS using stochastic ensemble sampling
    crps_batch_size : int, optional
        Number of validation batches used for CRPS computation
    crps_ensemble_size : int, optional
        Number of ensemble members used to estimate CRPS

    Returns
    -------
    tuple
        (avg_val_loss, val_metrics) - average validation loss and a dict of
        MetricTracker objects keyed by metric name

    Raises
    ------
    ValueError
        If ``args.time_normalization`` is not ``"linear"`` or ``"cos_sin"``.
    """
    # Define available metrics (more can be added, e.g. "MSE": mse_all).
    metric_names = ["MAE", "NMAE", "RMSE", "R2", "PEARSON", "KL"]
    metric_funcs = {
        "MAE": mae_all,
        "NMAE": nmae_all,
        "RMSE": rmse_all,
        "R2": r2_all,
        "PEARSON": pearson_all,
        "KL": kl_divergence_all,
    }
    if compute_crps:
        metric_names.append("CRPS")
        metric_funcs["CRPS"] = crps_ensemble_all

    # CRPS is stochastic and expensive, so it is handled separately from the
    # deterministic metrics (which are also computed for the coarse baseline).
    deterministic_metrics = [m for m in metric_names if m != "CRPS"]
    n_vars = len(args.varnames_list)

    model.eval()
    val_loss = MetricTracker()

    # Trackers: model vs fine and coarse baseline vs fine for deterministic
    # metrics; CRPS only for the model (there is no coarse-vs-fine CRPS).
    val_metrics = {}
    for k in args.varnames_list:
        for m in deterministic_metrics:
            val_metrics[f"{k}_pred_vs_fine_{m}"] = MetricTracker()
            val_metrics[f"{k}_coarse_vs_fine_{m}"] = MetricTracker()
        if compute_crps:
            val_metrics[f"{k}_pred_vs_fine_CRPS"] = MetricTracker()
    # Averages across all variables for each metric type
    for m in deterministic_metrics:
        val_metrics[f"average_pred_vs_fine_{m}"] = MetricTracker()
        val_metrics[f"average_coarse_vs_fine_{m}"] = MetricTracker()
    if compute_crps:
        val_metrics["average_pred_vs_fine_CRPS"] = MetricTracker()

    all_data = {"predictions": [], "coarse": [], "fine": [], "lat": [], "lon": []}
    crps_batches = []

    logger.info(f"Running validation for epoch {epoch}...")
    logger.info(f"EDM Sampler parameters: steps={args.num_steps}")

    with torch.no_grad():
        val_loop = tqdm(
            enumerate(valid_loader),
            total=len(valid_loader),
            desc=f"Validation Epoch {epoch}",
        )
        for batch_idx, batch in val_loop:
            features = batch["inputs"].to(device)
            targets = batch["targets"].to(device)
            coarse = batch["coarse"].to(device)
            fine = batch["fine"].to(device)
            # The normalized coarse field is the first n_vars input channels.
            coarse_norm = features[:, :n_vars]
            lat_batch = batch["corrdinates"]["lat"].to(device)
            lon_batch = batch["corrdinates"]["lon"].to(device)

            if epoch == 0 and batch_idx == 0:
                logger.info(
                    f"Validation batch idx:{batch_idx}\n"
                    f"features shape:{features.shape}, targets shape:{targets.shape}\n"
                    f"coarse shape:{coarse.shape}, fine shape:{fine.shape}\n"
                    f"lat shape:{lat_batch.shape}, lon shape:{lon_batch.shape}"
                )

            labels = _build_time_labels(batch, args, device)

            # Validation loss
            with torch.amp.autocast(device_type=device.type, dtype=features.dtype):
                loss = loss_fn(model, targets, features, labels)
            # The unet loss is already a scalar.
            if args.precond != "unet":
                loss = loss.mean()
            val_loss.update(loss.item(), targets.shape[0])

            # Keep only the first crps_batch_size batches for CRPS, reusing
            # the already prepared features and labels.
            if compute_crps and len(crps_batches) < crps_batch_size:
                crps_batches.append(
                    {"features": features, "labels": labels, "batch": batch}
                )

            # Per-batch accumulators used for the across-variable averages
            batch_metric_sums = {
                m: {"pred": MetricTracker(), "coarse": MetricTracker()}
                for m in deterministic_metrics
            }

            generated_residual = generate_residuals_norm(
                model=model,
                features=features,
                labels=labels,
                targets=targets,
                loss_fn=loss_fn,
                args=args,
                device=device,
                logger=logger,
                epoch=epoch,
                batch_idx=batch_idx,
                inference_type=args.inference_type,
            )

            batch_predictions = _denormalize_predictions(
                coarse_norm,
                generated_residual,
                args,
                index_mapping,
                normalization_type,
                norm_mapping,
                device,
                logger,
            )

            for var_name, final_prediction in zip(
                args.varnames_list, batch_predictions
            ):
                iv = index_mapping[var_name]
                fine_var = fine[:, iv : iv + 1]
                coarse_var = coarse[:, iv : iv + 1]
                for metric_name in deterministic_metrics:
                    metric_func = metric_funcs[metric_name]
                    n_pred, value_pred = metric_func(final_prediction, fine_var)
                    n_coarse, value_coarse = metric_func(coarse_var, fine_var)
                    value_pred = value_pred.item()
                    value_coarse = value_coarse.item()
                    val_metrics[f"{var_name}_pred_vs_fine_{metric_name}"].update(
                        value_pred, n_pred
                    )
                    val_metrics[f"{var_name}_coarse_vs_fine_{metric_name}"].update(
                        value_coarse, n_coarse
                    )
                    batch_metric_sums[metric_name]["pred"].update(value_pred, n_pred)
                    batch_metric_sums[metric_name]["coarse"].update(
                        value_coarse, n_coarse
                    )

            final_prediction_batch = torch.cat(batch_predictions, dim=1)

            # Move outputs to CPU immediately to avoid CUDA out-of-memory
            # errors while accumulating the whole validation set.
            all_data["predictions"].append(final_prediction_batch.detach().cpu())
            all_data["coarse"].append(coarse.detach().cpu())
            all_data["fine"].append(fine.detach().cpu())
            all_data["lat"].append(lat_batch.detach().cpu())
            all_data["lon"].append(lon_batch.detach().cpu())

            # Each batch contributes one sample to the across-variable averages.
            for metric_name in deterministic_metrics:
                val_metrics[f"average_pred_vs_fine_{metric_name}"].update(
                    batch_metric_sums[metric_name]["pred"].getmean(), 1
                )
                val_metrics[f"average_coarse_vs_fine_{metric_name}"].update(
                    batch_metric_sums[metric_name]["coarse"].getmean(), 1
                )

            # Progress bar shows the first metric by default.
            primary_metric = deterministic_metrics[0]
            batch_avg_pred = batch_metric_sums[primary_metric]["pred"].getmean()
            batch_avg_coarse = batch_metric_sums[primary_metric]["coarse"].getmean()
            val_loop.set_postfix(
                {
                    "Val Loss": f"{loss.item():.4f}",
                    "Avg Val Loss": f"{val_loss.getmean():.4f}",
                    f"Avg Pred {primary_metric}": f"{batch_avg_pred:.4f}",
                    f"Avg Coarse {primary_metric}": f"{batch_avg_coarse:.4f}",
                }
            )
            torch.cuda.empty_cache()

    avg_val_loss = val_loss.getmean()  # NOTE: to verify with Kazem

    # CRPS from an ensemble of stochastic sampler runs on the stored batches.
    if compute_crps and crps_batches:
        logger.info(
            "CRPS configuration summary:\n"
            f" └── Number of CRPS batches: {len(crps_batches)}\n"
            f" └── Ensemble size: {crps_ensemble_size}"
        )
        # Sampling only needs forward passes; never build autograd graphs.
        with torch.no_grad():
            for item in tqdm(crps_batches, desc="CRPS batches", total=len(crps_batches)):
                features = item["features"]
                labels = item["labels"]
                batch = item["batch"]
                # Invariant across ensemble members, so computed once per batch.
                targets = batch["targets"].to(device)
                coarse_norm = features[:, :n_vars]

                ens_preds = []
                for _ in tqdm(
                    range(crps_ensemble_size), desc="CRPS ensemble", leave=False
                ):
                    generated_residual = generate_residuals_norm(
                        model=model,
                        features=features,
                        labels=labels,
                        targets=targets,
                        loss_fn=loss_fn,
                        args=args,
                        device=device,
                        logger=None,
                        epoch=epoch,
                        batch_idx=-1,  # not tied to the validation loop
                        inference_type="sampler",
                    )
                    member = _denormalize_predictions(
                        coarse_norm,
                        generated_residual,
                        args,
                        index_mapping,
                        normalization_type,
                        norm_mapping,
                        device,
                        logger,
                    )
                    ens_preds.append(torch.cat(member, dim=1))

                # Leading dimension indexes the ensemble member.
                pred_ens = torch.stack(ens_preds, dim=0)
                for var_name in args.varnames_list:
                    iv = index_mapping[var_name]
                    pred_ens_flat = pred_ens[:, :, iv : iv + 1].reshape(
                        crps_ensemble_size, -1
                    )
                    true_flat = batch["fine"][:, iv : iv + 1].to(device).reshape(-1)
                    num_elem, crps_mean = crps_ensemble_all(pred_ens_flat, true_flat)
                    crps_value = crps_mean.item()
                    val_metrics[f"{var_name}_pred_vs_fine_CRPS"].update(
                        crps_value, num_elem
                    )
                    val_metrics["average_pred_vs_fine_CRPS"].update(
                        crps_value, num_elem
                    )

    _log_validation_summary(
        logger,
        val_metrics,
        metric_names,
        args.varnames_list,
        epoch,
        avg_val_loss,
        compute_crps,
    )

    # Global heatmap of validation metrics (per variable x metric); plotting
    # is best-effort and must not abort validation.
    if paths is not None:
        try:
            heatmap_path = plot_metrics_heatmap(
                valid_metrics_history=val_metrics,
                variable_names=args.varnames_list,
                metric_names=metric_names,
                filename=f"{args.run_type}_validation_metrics_epoch_{epoch}",
                save_dir=paths.results,
            )
            logger.info(f"Saved validation metrics heatmap to: {heatmap_path}")
        except Exception as e:
            logger.warning(f"Could not generate metrics heatmap: {e}")

    # A falsy frequency (None or 0) disables plotting; 0 previously raised
    # ZeroDivisionError in the modulo.
    should_plot = (
        bool(plot_every_n_epochs)
        and epoch % plot_every_n_epochs == 0
        and paths is not None
    )
    if should_plot:
        logger.info("Reconstructing and plots ...")
        reconstruct_original_layout(
            epoch,
            args,
            paths,
            steps,
            all_data=all_data,
            dataset=valid_dataset,
            # Reconstruction & plotting on CPU to avoid CUDA out-of-memory.
            device=torch.device("cpu"),
            logger=logger,
        )

    if writer is not None:
        _write_validation_tensorboard(
            writer,
            val_metrics,
            metric_names,
            args.varnames_list,
            epoch,
            avg_val_loss,
            compute_crps,
        )

    return avg_val_loss, val_metrics
class TestMetricTracker(unittest.TestCase):
    """
    Unit tests for the MetricTracker class.

    Covers initialisation, reset, weighted updates (including the running
    sum of squares ``value_sq`` that backs ``getstd``) and the derived
    statistics ``getmean``, ``getstd`` and ``getsqrtmean``.
    """

    def __init__(self, methodName="runTest", logger=None):
        """
        Create the test case, optionally attaching a logger.

        Parameters
        ----------
        methodName : str, optional
            Name of the test method to run, forwarded to ``unittest.TestCase``.
        logger : object or None, optional
            Anything with an ``info`` method (e.g. a ``logging.Logger``); when
            given, each test reports its progress through it.
        """
        super().__init__(methodName)
        self.logger = logger

    def setUp(self):
        """Set up test fixtures."""
        if self.logger:
            self.logger.info("Setting up MetricTracker test fixtures")

    def test_metric_tracker_init(self):
        """Test MetricTracker initialization."""
        if self.logger:
            self.logger.info("Testing MetricTracker initialization")
        tracker = MetricTracker()
        self.assertEqual(tracker.value, 0.0)
        self.assertEqual(tracker.count, 0)
        # The sum of squares feeds getstd(), so it must also start at zero.
        self.assertEqual(tracker.value_sq, 0.0)
        if self.logger:
            self.logger.info("✅ MetricTracker initialization test passed")

    def test_metric_tracker_reset(self):
        """Test MetricTracker reset method."""
        if self.logger:
            self.logger.info("Testing MetricTracker reset")
        tracker = MetricTracker()
        tracker.value = 10.5
        tracker.count = 5
        tracker.value_sq = 42.0
        tracker.reset()
        self.assertEqual(tracker.value, 0.0)
        self.assertEqual(tracker.count, 0)
        self.assertEqual(tracker.value_sq, 0.0)
        if self.logger:
            self.logger.info("✅ MetricTracker reset test passed")

    def test_metric_tracker_update(self):
        """Test MetricTracker update method (weighted sum and sum of squares)."""
        if self.logger:
            self.logger.info("Testing MetricTracker update")
        tracker = MetricTracker()

        # First update
        tracker.update(10.0, 5)
        self.assertEqual(tracker.value, 50.0)  # 10 * 5
        self.assertEqual(tracker.count, 5)
        self.assertEqual(tracker.value_sq, 500.0)  # 10^2 * 5

        # Second update
        tracker.update(20.0, 3)
        self.assertEqual(tracker.value, 110.0)  # 50 + 20*3
        self.assertEqual(tracker.count, 8)  # 5 + 3
        self.assertEqual(tracker.value_sq, 1700.0)  # 500 + 20^2 * 3

        # Third update with zero count: a zero weight must not change anything
        tracker.update(30.0, 0)
        self.assertEqual(tracker.value, 110.0)  # Unchanged
        self.assertEqual(tracker.count, 8)  # Unchanged
        self.assertEqual(tracker.value_sq, 1700.0)  # Unchanged

        if self.logger:
            self.logger.info("✅ MetricTracker update test passed")

    def test_metric_tracker_getmean(self):
        """Test MetricTracker getmean method."""
        if self.logger:
            self.logger.info("Testing MetricTracker getmean")
        tracker = MetricTracker()

        # Test with valid updates
        tracker.update(10.0, 5)
        tracker.update(20.0, 3)
        mean = tracker.getmean()
        expected_mean = 110.0 / 8  # (10*5 + 20*3) / (5+3) = 110/8 = 13.75
        self.assertAlmostEqual(mean, expected_mean, places=6)

        # Test with zero count (should raise ZeroDivisionError)
        tracker.reset()
        with self.assertRaises(ZeroDivisionError):
            tracker.getmean()

        if self.logger:
            self.logger.info("✅ MetricTracker getmean test passed")

    def test_metric_tracker_getstd(self):
        """Test MetricTracker getstd method."""
        if self.logger:
            self.logger.info("Testing MetricTracker getstd")
        tracker = MetricTracker()

        # Known values
        # Values: [10 (×5), 20 (×3)]
        # mean = 13.75
        # E[x^2] = (10^2 * 5 + 20^2 * 3) / 8 = (500 + 1200) / 8 = 212.5
        # variance = 212.5 - 13.75^2 = 23.4375
        # std = sqrt(23.4375) ≈ 4.841229
        tracker.update(10.0, 5)
        tracker.update(20.0, 3)
        std = tracker.getstd()
        expected_std = np.sqrt(212.5 - 13.75**2)
        self.assertAlmostEqual(std, expected_std, places=6)

        # Test with zero count (should raise ZeroDivisionError)
        tracker.reset()
        with self.assertRaises(ZeroDivisionError):
            tracker.getstd()

        if self.logger:
            self.logger.info("✅ MetricTracker getstd test passed")

    def test_metric_tracker_getsqrtmean(self):
        """Test MetricTracker getsqrtmean method."""
        if self.logger:
            self.logger.info("Testing MetricTracker getsqrtmean")
        tracker = MetricTracker()

        tracker.update(16.0, 2)  # mean = 16, sqrt = 4
        tracker.update(4.0, 2)  # mean = (16*2 + 4*2)/4 = 10, sqrt = sqrt(10)
        sqrtmean = tracker.getsqrtmean()
        expected_sqrtmean = np.sqrt(10.0)  # sqrt(10) ≈ 3.16227766
        self.assertAlmostEqual(sqrtmean, expected_sqrtmean, places=6)

        # Test with zero count (should raise ZeroDivisionError)
        tracker.reset()
        with self.assertRaises(ZeroDivisionError):
            tracker.getsqrtmean()

        if self.logger:
            self.logger.info("✅ MetricTracker getsqrtmean test passed")

    def test_metric_tracker_example_from_docstring(self):
        """Test the example provided in the docstring."""
        if self.logger:
            self.logger.info("Testing MetricTracker docstring example")
        tracker = MetricTracker()
        tracker.update(10.0, 5)
        tracker.update(20.0, 3)
        mean = tracker.getmean()
        sqrtmean = tracker.getsqrtmean()

        # Expected values from docstring example
        expected_mean = 110.0 / 8  # 13.75
        expected_sqrtmean = np.sqrt(13.75)  # 3.7080992435478315
        self.assertAlmostEqual(mean, expected_mean, places=6)
        self.assertAlmostEqual(sqrtmean, expected_sqrtmean, places=6)

        if self.logger:
            self.logger.info("✅ MetricTracker docstring example test passed")
class TestErrorMetrics(unittest.TestCase):
    """Unit tests for error metrics (MAE, NMAE, RMSE, R2, PEARSON) and KL divergence."""

    def __init__(self, methodName="runTest", logger=None):
        """
        Create the test case, optionally attaching a logger.

        Parameters
        ----------
        methodName : str, optional
            Name of the test method to run, forwarded to ``unittest.TestCase``.
        logger : object or None, optional
            Anything with an ``info`` method (e.g. a ``logging.Logger``); when
            given, each test reports its progress through it.
        """
        super().__init__(methodName)
        self.logger = logger

    def setUp(self):
        """Set up test fixtures."""
        if self.logger:
            self.logger.info("Setting up error metrics test fixtures")
        # Metric name -> function under test. Every function is called as
        # func(pred, true) and returns a (num_elements, value) pair.
        self.metrics = {
            "MAE": mae_all,
            "NMAE": nmae_all,
            "RMSE": rmse_all,
            "R2": r2_all,
            "PEARSON": pearson_all,
        }

    def _compute_expected(self, metric_name, pred, true):
        """
        Compute a reference value for ``metric_name`` directly with torch.

        Parameters
        ----------
        metric_name : str
            One of "MAE", "NMAE", "RMSE", "R2", "PEARSON".
        pred, true : torch.Tensor
            Prediction and ground truth (same shape in every call here); all
            elements are pooled into a single scalar.

        Returns
        -------
        torch.Tensor
            0-d tensor holding the expected metric value.

        Raises
        ------
        ValueError
            If ``metric_name`` is not recognised.
        """
        mae = torch.mean(torch.abs(pred - true))
        if metric_name == "MAE":
            return mae
        elif metric_name == "NMAE":
            # MAE normalised by the mean absolute truth; 0 when truth is all zeros.
            denom = torch.mean(torch.abs(true))
            return mae / denom if denom != 0 else torch.zeros_like(mae)
        elif metric_name == "RMSE":
            diff = pred - true
            return torch.sqrt(torch.mean(diff**2))
        elif metric_name == "R2":
            true_flat = true.reshape(-1)
            pred_flat = pred.reshape(-1)
            ss_res = torch.sum((true_flat - pred_flat) ** 2)
            ss_tot = torch.sum((true_flat - torch.mean(true_flat)) ** 2)
            # Small epsilon guards against a constant ground truth (ss_tot == 0).
            return 1.0 - ss_res / (ss_tot + 1e-12)
        elif metric_name == "PEARSON":
            pred_flat = pred.reshape(-1)
            true_flat = true.reshape(-1)
            stacked = torch.stack([pred_flat, true_flat], dim=0)
            corr = torch.corrcoef(stacked)[0, 1]
            return corr
        else:
            raise ValueError(metric_name)

    def test_basic(self):
        """Test error metrics with simple tensors."""
        if self.logger:
            self.logger.info("Testing error metrics basic functionality")
        pred = torch.tensor([1.0, 2.0, 3.0])
        true = torch.tensor([1.1, 1.9, 3.2])

        for name, func in self.metrics.items():
            with self.subTest(metric=name):
                num_elements, value = func(pred, true)
                expected = self._compute_expected(name, pred, true)
                self.assertEqual(num_elements, 3)
                self.assertAlmostEqual(value.item(), expected.item(), places=4)

        if self.logger:
            self.logger.info("✅ Error metrics basic test passed")

    def test_exact_match(self):
        """Test error metrics with identical tensors."""
        if self.logger:
            self.logger.info("Testing error metrics with identical tensors")
        pred = torch.tensor([1.0, 2.0, 3.0, 4.0])
        true = torch.tensor([1.0, 2.0, 3.0, 4.0])

        for name, func in self.metrics.items():
            with self.subTest(metric=name):
                num_elements, value = func(pred, true)
                self.assertEqual(num_elements, 4)
                # Skill scores (R2, PEARSON) equal 1.0 for a perfect prediction,
                # whereas error metrics (MAE, NMAE, RMSE) equal 0.0.
                if name in ("R2", "PEARSON"):
                    self.assertAlmostEqual(value.item(), 1.0, places=6)
                else:
                    self.assertEqual(value.item(), 0.0)

        if self.logger:
            self.logger.info("✅ Error metrics exact match test passed")

    def test_multi_dimensional(self):
        """Test error metrics with multi-dimensional tensors."""
        if self.logger:
            self.logger.info("Testing error metrics with multi-dimensional tensors")
        # Seed so the tight (places=6) comparisons are reproducible run to run.
        torch.manual_seed(0)
        pred = torch.randn(2, 3, 4, 5)  # Batch size 2, channels 3, height 4, width 5
        true = torch.randn(2, 3, 4, 5)

        for name, func in self.metrics.items():
            with self.subTest(metric=name):
                num_elements, value = func(pred, true)
                expected = self._compute_expected(name, pred, true)
                self.assertEqual(num_elements, pred.numel())
                self.assertAlmostEqual(value.item(), expected.item(), places=6)

        if self.logger:
            self.logger.info("✅ Error metrics multi-dimensional test passed")

    def test_different_shapes(self):
        """Test error metrics with matching shapes and with mismatched shapes."""
        if self.logger:
            self.logger.info("Testing error metrics with different shapes")
        # Seed so the tight (places=6) comparisons are reproducible run to run.
        torch.manual_seed(0)
        pred = torch.randn(2, 3, 4)
        true = torch.randn(2, 3, 4)

        # This should work fine since shapes match
        for name, func in self.metrics.items():
            with self.subTest(metric=name):
                num_elements, value = func(pred, true)
                expected = self._compute_expected(name, pred, true)
                self.assertEqual(num_elements, 2 * 3 * 4)  # 24
                self.assertIsInstance(value, torch.Tensor)
                self.assertAlmostEqual(value.item(), expected.item(), places=6)

        # Test with mismatched shapes (should fail): (2, 3, 4) vs (2, 4, 3)
        # cannot be broadcast together.
        true_wrong = torch.randn(2, 4, 3)  # Different shape
        for name, func in self.metrics.items():
            with self.subTest(metric=name):
                with self.assertRaises(RuntimeError):
                    func(pred, true_wrong)

        if self.logger:
            self.logger.info("✅ Error metrics shape tests passed")

    def test_dtype_preservation(self):
        """Test that error metrics preserve data types."""
        if self.logger:
            self.logger.info("Testing error metrics data type preservation")
        for dtype in (torch.float32, torch.float64):
            pred = torch.tensor([1.0, 2.0, 3.0], dtype=dtype)
            true = torch.tensor([1.1, 1.9, 3.2], dtype=dtype)
            for name, func in self.metrics.items():
                with self.subTest(metric=name, dtype=dtype):
                    _, value = func(pred, true)
                    self.assertEqual(value.dtype, dtype)

        if self.logger:
            self.logger.info("✅ Error metrics dtype preservation test passed")

    def test_example_from_docstring(self):
        """Test error metrics examples from their docstrings."""
        if self.logger:
            self.logger.info("Testing error metrics docstring examples")

        pred = torch.tensor([1.0, 2.0, 3.0])
        true = torch.tensor([1.1, 1.9, 3.2])

        for name, func in self.metrics.items():
            with self.subTest(metric=name):
                num_elements, value = func(pred, true)
                expected = self._compute_expected(name, pred, true)
                self.assertEqual(num_elements, 3)
                self.assertAlmostEqual(value.item(), expected.item(), places=6)

        if self.logger:
            self.logger.info("✅ Error metrics docstring example tests passed")

    # KL is estimated via histograms (a numerical approximation), so no exact
    # analytical expected value exists; the KL tests below check properties
    # instead (≈0 for identical inputs, finite, non-negative, larger for a
    # shifted distribution).

    def test_kl_divergence_basic(self):
        """Test KL divergence properties."""
        if self.logger:
            self.logger.info("Testing KL divergence basic properties")

        torch.manual_seed(0)

        # Identical distributions → KL ≈ 0
        true = torch.randn(1000)
        pred_same = true.clone()
        num_elements, kl_same = kl_divergence_all(pred_same, true)
        self.assertEqual(num_elements, true.numel())
        self.assertTrue(torch.isfinite(kl_same))
        self.assertAlmostEqual(kl_same.item(), 0.0, places=4)

        # Different distributions → KL > 0
        pred_shifted = true + 2.0  # shift distribution
        _, kl_diff = kl_divergence_all(pred_shifted, true)
        self.assertTrue(torch.isfinite(kl_diff))
        self.assertGreaterEqual(kl_diff.item(), 0.0)
        self.assertGreater(kl_diff.item(), kl_same.item())

        if self.logger:
            self.logger.info("✅ KL divergence basic test passed")

    def test_kl_different_shapes(self):
        """
        KL divergence should raise RuntimeError if tensor shapes differ.
        """
        if self.logger:
            self.logger.info("Testing KL divergence shape mismatch")

        pred = torch.randn(10)
        true = torch.randn(5)

        with self.assertRaises(RuntimeError):
            kl_divergence_all(pred, true)

        if self.logger:
            self.logger.info("✅ KL shape mismatch test passed")

    def test_kl_dtype_preservation(self):
        """
        Ensure KL divergence preserves the input tensor dtype.
        """
        if self.logger:
            self.logger.info("Testing KL divergence dtype preservation")

        # Seed so the histogram inputs are reproducible run to run.
        torch.manual_seed(0)
        true_f32 = torch.randn(500, dtype=torch.float32)
        pred_f32 = true_f32 + 0.1
        _, kl_f32 = kl_divergence_all(pred_f32, true_f32)
        self.assertEqual(kl_f32.dtype, torch.float32)

        true_f64 = true_f32.double()
        pred_f64 = pred_f32.double()
        _, kl_f64 = kl_divergence_all(pred_f64, true_f64)
        self.assertEqual(kl_f64.dtype, torch.float64)

        if self.logger:
            self.logger.info("✅ KL dtype preservation test passed")

    def test_kl_multi_dimensional(self):
        """
        KL divergence should correctly handle multi-dimensional tensors
        by flattening them internally.
        """
        if self.logger:
            self.logger.info("Testing KL divergence with multi-dimensional tensors")

        torch.manual_seed(0)
        pred = torch.randn(2, 3, 4, 5)
        true = torch.randn(2, 3, 4, 5)

        num_elements, kl_value = kl_divergence_all(pred, true)

        self.assertEqual(num_elements, pred.numel())
        self.assertTrue(torch.isfinite(kl_value))
        self.assertGreaterEqual(kl_value.item(), 0.0)

        if self.logger:
            self.logger.info("✅ KL multi-dimensional test passed")
class TestCRPSFunction(unittest.TestCase):
    """
    Unit tests for the ``crps_ensemble_all`` ensemble CRPS metric.

    Ensembles are laid out as ``[N_ens, N_pixels]`` and the truth as
    ``[N_pixels]`` in every test below.
    """

    def __init__(self, methodName="runTest", logger=None):
        """
        Build the test case, optionally attaching a logger.

        Parameters
        ----------
        methodName : str, optional
            Test method to run; passed straight to ``unittest.TestCase``.
        logger : object or None, optional
            Object with an ``info`` method used for progress messages.
        """
        super().__init__(methodName)
        self.logger = logger

    def _log(self, message):
        """Send *message* to the attached logger, if there is one."""
        if self.logger:
            self.logger.info(message)

    def setUp(self):
        """Prepare (log-only) fixtures before each test."""
        self._log("Setting up crps_ensemble_all test fixtures")

    def test_crps_basic(self):
        """CRPS on a tiny hand-made ensemble is finite and non-negative."""
        self._log("Testing CRPS basic functionality")

        observed = torch.tensor([2.0, 2.0])
        # Three members (rows) forecasting two pixels (columns).
        ensemble = torch.tensor([[1.0, 3.0], [2.0, 2.0], [3.0, 1.0]])

        n_points, score = crps_ensemble_all(ensemble, observed)

        self.assertEqual(n_points, 2)
        self.assertTrue(torch.isfinite(score))
        self.assertGreaterEqual(score.item(), 0.0)

        self._log("✅ CRPS basic test passed")

    def test_crps_zero_when_perfect_prediction(self):
        """CRPS vanishes when every member reproduces the truth exactly."""
        self._log("Testing CRPS perfect prediction")

        observed = torch.tensor([1.0, 2.0, 3.0])
        ensemble = torch.stack([observed] * 3)  # three identical members

        n_points, score = crps_ensemble_all(ensemble, observed)

        self.assertEqual(n_points, 3)
        self.assertAlmostEqual(score.item(), 0.0, places=6)

        self._log("✅ CRPS perfect prediction test passed")

    def test_crps_equals_mae_for_single_member(self):
        """With a one-member ensemble, CRPS collapses to the MAE."""
        self._log("Testing CRPS equals MAE for single ensemble member")

        observed = torch.tensor([1.0, 2.0, 3.0])
        forecast = torch.tensor([1.5, 1.5, 2.5])

        n_points, score = crps_ensemble_all(forecast[None, :], observed)
        reference_mae = (forecast - observed).abs().mean()

        self.assertEqual(n_points, 3)
        self.assertAlmostEqual(score.item(), reference_mae.item(), places=6)

        self._log("✅ CRPS single-member equals MAE test passed")

    def test_crps_multi_dimensional_flatten(self):
        """CRPS over a flattened 4x5 field with a six-member ensemble."""
        self._log("Testing CRPS multi-dimensional flattened input")

        torch.manual_seed(0)
        # Draw order matters for the seeded RNG: truth first, then ensemble.
        observed = torch.randn(4 * 5)
        ensemble = torch.randn(6, 4 * 5)

        n_points, score = crps_ensemble_all(ensemble, observed)

        self.assertEqual(n_points, 20)
        self.assertTrue(torch.isfinite(score))
        self.assertGreaterEqual(score.item(), 0.0)

        self._log("✅ CRPS multi-dimensional test passed")

    def test_crps_dtype_preservation(self):
        """The CRPS result keeps the floating-point dtype of its inputs."""
        self._log("Testing CRPS dtype preservation")

        for float_type in (torch.float32, torch.float64):
            observed = torch.tensor([1.0, 2.0], dtype=float_type)
            ensemble = torch.tensor([[1.5, 2.5]], dtype=float_type)
            _, score = crps_ensemble_all(ensemble, observed)
            self.assertEqual(score.dtype, float_type)

        self._log("✅ CRPS dtype preservation test passed")
class TestDenormalizeFunction(unittest.TestCase):
    """Unit tests for denormalize function."""

    def __init__(self, methodName="runTest", logger=None):
        """
        Create the test case, optionally attaching a logger.

        Parameters
        ----------
        methodName : str, optional
            Name of the test method to run, forwarded to ``unittest.TestCase``.
        logger : object or None, optional
            Anything with an ``info`` method (e.g. a ``logging.Logger``); when
            given, each test reports its progress through it.
        """
        super().__init__(methodName)
        self.logger = logger

    def setUp(self):
        """Set up test fixtures: run on CUDA when available, else CPU."""
        if self.logger:
            self.logger.info("Setting up denormalize test fixtures")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def test_denormalize_minmax(self):
        """Test denormalize with minmax normalization."""
        if self.logger:
            self.logger.info("Testing denormalize - minmax")

        # Create test data
        data = torch.tensor([0.0, 0.5, 1.0], dtype=torch.float32).to(self.device)

        # Create stats object
        class Stats:
            vmin = 10.0
            vmax = 20.0

        stats = Stats()

        # Denormalize
        result = denormalize(data, stats, "minmax", self.device)

        # Expected: data * (vmax - vmin) + vmin
        # For values [0, 0.5, 1.0] and range [10, 20]
        expected = torch.tensor([10.0, 15.0, 20.0], dtype=torch.float32).to(self.device)
        torch.testing.assert_close(result, expected)

        if self.logger:
            self.logger.info("✅ denormalize minmax test passed")

    def test_denormalize_minmax_11(self):
        """Test denormalize with minmax_11 normalization."""
        if self.logger:
            self.logger.info("Testing denormalize - minmax_11")

        # Create test data in range [-1, 1]
        data = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float32).to(self.device)

        # Create stats object
        class Stats:
            vmin = 0.0
            vmax = 100.0

        stats = Stats()

        # Denormalize
        result = denormalize(data, stats, "minmax_11", self.device)

        # Expected: ((data + 1) / 2) * (vmax - vmin) + vmin
        # For values [-1, 0, 1] and range [0, 100]
        # (-1+1)/2 * 100 + 0 = 0
        # (0+1)/2 * 100 + 0 = 50
        # (1+1)/2 * 100 + 0 = 100
        expected = torch.tensor([0.0, 50.0, 100.0], dtype=torch.float32).to(self.device)
        torch.testing.assert_close(result, expected)

        if self.logger:
            self.logger.info("✅ denormalize minmax_11 test passed")

    def test_denormalize_standard(self):
        """Test denormalize with standard normalization."""
        if self.logger:
            self.logger.info("Testing denormalize - standard")

        # Create test data (normalized, mean=0, std=1)
        data = torch.tensor([-2.0, 0.0, 2.0], dtype=torch.float32).to(self.device)

        # Create stats object
        class Stats:
            vmean = 50.0
            vstd = 10.0

        stats = Stats()

        # Denormalize
        result = denormalize(data, stats, "standard", self.device)

        # Expected: data * std + mean
        # For values [-2, 0, 2] with mean=50, std=10
        expected = torch.tensor([30.0, 50.0, 70.0], dtype=torch.float32).to(self.device)
        torch.testing.assert_close(result, expected)

        if self.logger:
            self.logger.info("✅ denormalize standard test passed")

    def test_denormalize_robust(self):
        """Test denormalize with robust normalization."""
        if self.logger:
            self.logger.info("Testing denormalize - robust")

        # Create test data
        data = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float32).to(self.device)

        # Create stats object
        class Stats:
            median = 50.0
            iqr = 20.0

        stats = Stats()

        # Denormalize
        result = denormalize(data, stats, "robust", self.device)

        # Expected: data * iqr + median
        # For values [-1, 0, 1] with median=50, iqr=20
        expected = torch.tensor([30.0, 50.0, 70.0], dtype=torch.float32).to(self.device)
        torch.testing.assert_close(result, expected)

        if self.logger:
            self.logger.info("✅ denormalize robust test passed")

    def test_denormalize_log1p_minmax(self):
        """Test denormalize with log1p_minmax normalization."""
        if self.logger:
            self.logger.info("Testing denormalize - log1p_minmax")

        # Normalized data
        data = torch.tensor([0.0, 0.5, 1.0], dtype=torch.float32).to(self.device)

        # Create stats object
        class Stats:
            # log1p(0) = 0
            vmin = 0.0
            vmax = torch.log1p(torch.tensor(9.0)).item()  # log1p(9) ≈ 2.3026

        stats = Stats()

        # Denormalize
        result = denormalize(data, stats, "log1p_minmax", self.device)

        # Expected values:
        # z=0   → log1p(x)=0, x=0
        # z=0.5 → log1p(x) = 0.5 * 2.302585 = 1.1513, x = exp(1.151293) - 1 ≈ 2.1623
        # z=1   → log1p(x)≈2.3026, x=9
        expected = torch.expm1(data * (stats.vmax - stats.vmin) + stats.vmin)
        torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)

        if self.logger:
            self.logger.info("✅ denormalize log1p_minmax test passed")

    def test_denormalize_log1p_standard(self):
        """Test denormalize with log1p_standard normalization."""
        if self.logger:
            self.logger.info("Testing denormalize - log1p_standard")

        # Normalized data
        data = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float32).to(self.device)

        # Create stats object
        class Stats:
            vmean = torch.log1p(torch.tensor(4.0)).item()  # log1p(4) ≈ 1.6094
            vstd = 0.5

        stats = Stats()

        # Denormalize
        result = denormalize(data, stats, "log1p_standard", self.device)

        # Expected: x = expm1(z * vstd + vmean)
        #   z=-1 → log1p(x)≈1.1094, x≈2.03
        #   z=0  → log1p(x)≈1.6094, x=4
        #   z=1  → log1p(x)≈2.1094, x≈7.24
        # Build the log-space targets from the exact stats.vmean rather than
        # the truncated literal 1.6094: that truncation alone shifts the z=±1
        # expectations by ~1e-4–3e-4, a sizeable share of the tolerance.
        log_targets = torch.tensor(
            [
                stats.vmean - stats.vstd,
                stats.vmean,
                stats.vmean + stats.vstd,
            ],
            dtype=torch.float32,
        )
        expected = torch.expm1(log_targets).to(self.device)
        torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)

        if self.logger:
            self.logger.info("✅ denormalize log1p_standard test passed")

    def test_denormalize_zero_denominator(self):
        """Test denormalize with zero denominator."""
        if self.logger:
            self.logger.info("Testing denormalize - zero denominator")

        # Test minmax with zero range
        data = torch.tensor([0.5, 1.0, 1.5], dtype=torch.float32).to(self.device)

        class StatsZeroRange:
            vmin = 10.0
            vmax = 10.0  # Zero range

        stats_zero = StatsZeroRange()
        result = denormalize(data, stats_zero, "minmax", self.device)
        # NOTE(review): the plain minmax formula would give vmin (10.0) here;
        # expecting zeros assumes denormalize special-cases a zero range —
        # confirm against the implementation.
        expected_zero = torch.zeros_like(data)
        torch.testing.assert_close(result, expected_zero)

        # Test standard with zero std
        class StatsZeroStd:
            vmean = 50.0
            vstd = 0.0

        stats_zero_std = StatsZeroStd()
        result_std = denormalize(data, stats_zero_std, "standard", self.device)
        # NOTE(review): likewise, data * 0 + vmean would give 50.0; zeros
        # assume a zero-std special case in denormalize — confirm.
        expected_zero_std = torch.zeros_like(data)
        torch.testing.assert_close(result_std, expected_zero_std)

        if self.logger:
            self.logger.info("✅ denormalize zero denominator test passed")

    def test_denormalize_unsupported_type(self):
        """Test denormalize with unsupported normalization type."""
        if self.logger:
            self.logger.info("Testing denormalize - unsupported type")

        data = torch.tensor([1.0], dtype=torch.float32).to(self.device)

        class Stats:
            pass

        stats = Stats()

        # Should raise ValueError for unsupported type
        with self.assertRaises(ValueError):
            denormalize(data, stats, "unsupported_type", self.device)

        if self.logger:
            self.logger.info("✅ denormalize unsupported type test passed")
[docs] class TestRunValidation(unittest.TestCase): """Unit tests for run_validation function focusing on return values verification."""
[docs] def __init__(self, methodName="runTest", logger=None): super().__init__(methodName) self.logger = logger
[docs] def setUp(self): """Set up test fixtures.""" if self.logger: self.logger.info("Setting up run_validation test fixtures") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
[docs] def test_val_loss_and_metrics_across_3_batches_consistent_shape(self): """Verify avg_val_loss and val_metrics with 3 batches of consistent shape.""" if self.logger: self.logger.info( "Testing avg_val_loss and val_metrics with 3 consistent batches" ) # ======================================================================== # SETUP: Debug version - track everything # ======================================================================== # Create mock model mock_model = Mock() mock_loss_fn = Mock() mock_logger = Mock() mock_valid_dataset = Mock() mock_steps = Mock() # Consistent batch configuration batch_size = 2 num_channels = 5 num_vars = 2 H, W = 8, 8 # Create 3 batches batch1 = { "inputs": torch.randn(batch_size, num_channels, H, W).to(self.device), "targets": torch.randn(batch_size, num_vars, H, W).to(self.device), "coarse": torch.stack( [ torch.ones(batch_size, 1, H, W) * 10.0, torch.ones(batch_size, 1, H, W) * 20.0, ], dim=1, ) .squeeze(2) .to(self.device), "fine": torch.stack( [ torch.ones(batch_size, 1, H, W) * 12.0, torch.ones(batch_size, 1, H, W) * 22.0, ], dim=1, ) .squeeze(2) .to(self.device), "doy": torch.tensor([100, 150]).to(self.device), "hour": torch.tensor([12, 18]).to(self.device), "corrdinates": { "lat": torch.randn(batch_size, H, W).to(self.device), "lon": torch.randn(batch_size, H, W).to(self.device), }, } batch2 = { "inputs": torch.randn(batch_size, num_channels, H, W).to(self.device), "targets": torch.randn(batch_size, num_vars, H, W).to(self.device), "coarse": torch.stack( [ torch.ones(batch_size, 1, H, W) * 15.0, torch.ones(batch_size, 1, H, W) * 25.0, ], dim=1, ) .squeeze(2) .to(self.device), "fine": torch.stack( [ torch.ones(batch_size, 1, H, W) * 17.0, torch.ones(batch_size, 1, H, W) * 27.0, ], dim=1, ) .squeeze(2) .to(self.device), "doy": torch.tensor([200, 250]).to(self.device), "hour": torch.tensor([6, 12]).to(self.device), "corrdinates": { "lat": torch.randn(batch_size, H, W).to(self.device), "lon": torch.randn(batch_size, 
H, W).to(self.device), }, } batch3 = { "inputs": torch.randn(batch_size, num_channels, H, W).to(self.device), "targets": torch.randn(batch_size, num_vars, H, W).to(self.device), "coarse": torch.stack( [ torch.ones(batch_size, 1, H, W) * 20.0, torch.ones(batch_size, 1, H, W) * 30.0, ], dim=1, ) .squeeze(2) .to(self.device), "fine": torch.stack( [ torch.ones(batch_size, 1, H, W) * 21.0, torch.ones(batch_size, 1, H, W) * 31.0, ], dim=1, ) .squeeze(2) .to(self.device), "doy": torch.tensor([50, 100]).to(self.device), "hour": torch.tensor([20, 8]).to(self.device), "corrdinates": { "lat": torch.randn(batch_size, H, W).to(self.device), "lon": torch.randn(batch_size, H, W).to(self.device), }, } batch1["inputs"][:, :num_vars] = batch1["coarse"] batch2["inputs"][:, :num_vars] = batch2["coarse"] batch3["inputs"][:, :num_vars] = batch3["coarse"] mock_valid_loader = [batch1, batch2, batch3] # Mock args mock_args = Mock() mock_args.varnames_list = ["temp", "pressure"] mock_args.time_normalization = "linear" mock_args.debug = False mock_args.inference_type = "direct" # ======================================================================== # SETUP: Better mock tracking # ======================================================================== # Track loss calls loss_calls = [] def loss_fn_side_effect(model, targets, features, labels): # Determine which batch based on the hour values (unique per batch) hour_sum = ( labels[:, 1].sum().item() ) # labels is [batch_size, 2] where second column is hour # Map hour sum to batch if hour_sum == (12 + 18): # Batch 1 loss_value = 0.35 batch_num = 1 elif hour_sum == (6 + 12): # Batch 2 loss_value = 0.85 batch_num = 2 else: # Batch 3 loss_value = 1.35 batch_num = 3 # Create loss tensor loss_tensor = torch.full_like(targets, loss_value) loss_calls.append((batch_num, loss_value, loss_tensor.mean().item())) return loss_tensor mock_loss_fn.side_effect = loss_fn_side_effect mock_loss_fn.P_mean = 0.0 mock_loss_fn.P_std = 1.0 # Track model calls 
model_calls = [] def model_side_effect(x, sigma, condition_img=None, class_labels=None): batch_size_local = x.shape[0] # Simple: return based on call count call_num = len(model_calls) model_calls.append(call_num) if call_num == 0: # Batch 1 return ( torch.stack( [ torch.ones(batch_size_local, 1, H, W) * 1.0, torch.ones(batch_size_local, 1, H, W) * 2.0, ], dim=1, ) .squeeze(2) .to(self.device) ) elif call_num == 1: # Batch 2 return ( torch.stack( [ torch.ones(batch_size_local, 1, H, W) * 1.5, torch.ones(batch_size_local, 1, H, W) * 2.5, ], dim=1, ) .squeeze(2) .to(self.device) ) else: # Batch 3 return ( torch.stack( [ torch.ones(batch_size_local, 1, H, W) * 0.5, torch.ones(batch_size_local, 1, H, W) * 1.5, ], dim=1, ) .squeeze(2) .to(self.device) ) mock_model.side_effect = model_side_effect # Mock normalization mock_norm_mapping = {} mock_normalization_type = {} mock_index_mapping = {} class MockStats: vmin = 0.0 vmax = 1.0 vmean = 0.0 vstd = 1.0 mock_norm_mapping["temp_fine"] = MockStats() mock_norm_mapping["pressure_fine"] = MockStats() mock_normalization_type["temp"] = "minmax" mock_normalization_type["pressure"] = "minmax" mock_index_mapping["temp"] = 0 mock_index_mapping["pressure"] = 1 # ======================================================================== # EXECUTE: Run validation # ======================================================================== with patch("tqdm.tqdm", side_effect=lambda x, **kwargs: x): with patch("torch.amp.autocast"): avg_val_loss, val_metrics = run_validation( model=mock_model, valid_dataset=mock_valid_dataset, valid_loader=mock_valid_loader, loss_fn=mock_loss_fn, norm_mapping=mock_norm_mapping, normalization_type=mock_normalization_type, index_mapping=mock_index_mapping, args=mock_args, steps=mock_steps, device=self.device, logger=mock_logger, epoch=1, writer=None, plot_every_n_epochs=None, paths=None, ) # ======================================================================== # DEBUG: Print what happened # 
======================================================================== if self.logger: self.logger.info(f"Loss calls: {loss_calls}") self.logger.info(f"Model calls: {model_calls}") self.logger.info(f"avg_val_loss = {avg_val_loss}") # Calculate expected based on actual loss values if loss_calls: actual_losses = [loss_entry[1] for loss_entry in loss_calls] expected = sum( loss_value * batch_size for loss_value in actual_losses ) / (len(actual_losses) * batch_size) self.logger.info(f"Actual losses: {actual_losses}") self.logger.info(f"Expected avg: {expected}") # ======================================================================== # VERIFICATION 1: avg_val_loss calculation # ======================================================================== # Get the actual loss values that were used actual_loss_values = ( [loss_entry[1] for loss_entry in loss_calls] if loss_calls else [0.35, 0.85, 1.35] ) # Calculate expected based on ACTUAL values expected_avg_loss = sum( loss_value * batch_size for loss_value in actual_loss_values ) / (len(actual_loss_values) * batch_size) self.assertAlmostEqual( avg_val_loss, expected_avg_loss, places=5, msg=f"Expected avg_val_loss={expected_avg_loss:.5f} (based on losses {actual_loss_values}), got {avg_val_loss:.5f}", ) if self.logger: self.logger.info( f"✅ avg_val_loss verified: {avg_val_loss:.5f} (expected: {expected_avg_loss:.5f})" ) # ======================================================================== # VERIFICATION 2: Compute expected MAE values manually # ======================================================================== # Each batch has 2 samples × 8×8 elements = 128 elements per batch per variable # Total elements per variable: 3 batches × 128 = 384 # TEMP variable calculations: # Batch 1: coarse=10, residual=1.0, pred=11.0, fine=12.0 → MAE=|11-12|=1.0 # Batch 2: coarse=15, residual=1.5, pred=16.5, fine=17.0 → MAE=|16.5-17|=0.5 # Batch 3: coarse=20, residual=0.5, pred=20.5, fine=21.0 → MAE=|20.5-21|=0.5 # 
Element-weighted temp pred MAE: # (1.0×128 + 0.5×128 + 0.5×128) / 384 = (128 + 64 + 64)/384 = 256/384 = 0.6666667 expected_temp_pred_mae = 256 / 384 # Temp coarse MAE: # Batch 1: |10-12|=2.0, Batch 2: |15-17|=2.0, Batch 3: |20-21|=1.0 # (2.0×128 + 2.0×128 + 1.0×128)/384 = (256 + 256 + 128)/384 = 640/384 = 1.6666667 expected_temp_coarse_mae = 640 / 384 # PRESSURE variable calculations: # Batch 1: coarse=20, residual=2.0, pred=22.0, fine=22.0 → MAE=|22-22|=0.0 # Batch 2: coarse=25, residual=2.5, pred=27.5, fine=27.0 → MAE=|27.5-27|=0.5 # Batch 3: coarse=30, residual=1.5, pred=31.5, fine=31.0 → MAE=|31.5-31|=0.5 # Element-weighted pressure pred MAE: # (0.0×128 + 0.5×128 + 0.5×128)/384 = (0 + 64 + 64)/384 = 128/384 = 0.3333333 expected_pressure_pred_mae = 128 / 384 # Pressure coarse MAE (same as temp): # (2.0×128 + 2.0×128 + 1.0×128)/384 = 640/384 = 1.6666667 expected_pressure_coarse_mae = 640 / 384 # ======================================================================== # VERIFICATION 3: Verify per-variable MAE values # ======================================================================== # Get actual values actual_temp_pred = val_metrics["temp_pred_vs_fine_MAE"].getmean() actual_temp_coarse = val_metrics["temp_coarse_vs_fine_MAE"].getmean() actual_pressure_pred = val_metrics["pressure_pred_vs_fine_MAE"].getmean() actual_pressure_coarse = val_metrics["pressure_coarse_vs_fine_MAE"].getmean() # Verify temp MAE self.assertAlmostEqual( actual_temp_pred, expected_temp_pred_mae, places=5, msg=f"Temp pred MAE: expected {expected_temp_pred_mae:.5f}, got {actual_temp_pred:.5f}", ) self.assertAlmostEqual( actual_temp_coarse, expected_temp_coarse_mae, places=5, msg=f"Temp coarse MAE: expected {expected_temp_coarse_mae:.5f}, got {actual_temp_coarse:.5f}", ) # Verify pressure MAE self.assertAlmostEqual( actual_pressure_pred, expected_pressure_pred_mae, places=5, msg=f"Pressure pred MAE: expected {expected_pressure_pred_mae:.5f}, got {actual_pressure_pred:.5f}", ) 
self.assertAlmostEqual( actual_pressure_coarse, expected_pressure_coarse_mae, places=5, msg=f"Pressure coarse MAE: expected {expected_pressure_coarse_mae:.5f}, got {actual_pressure_coarse:.5f}", ) if self.logger: self.logger.info("✅ Per-variable MAE verified:") self.logger.info( f" └── temp_pred: {actual_temp_pred:.5f} (expected: {expected_temp_pred_mae:.5f})" ) self.logger.info( f" └── temp_coarse: {actual_temp_coarse:.5f} (expected: {expected_temp_coarse_mae:.5f})" ) self.logger.info( f" └── pressure_pred: {actual_pressure_pred:.5f} (expected: {expected_pressure_pred_mae:.5f})" ) self.logger.info( f" └── pressure_coarse: {actual_pressure_coarse:.5f} (expected: {expected_pressure_coarse_mae:.5f})" ) # ======================================================================== # VERIFICATION 4: Compute and verify average metrics # ======================================================================== # Average metrics are computed per-batch (not weighted by elements) # Let's compute expected batch-level averages: # Batch 1 averages: # - pred: (temp_pred=1.0, pressure_pred=0.0) → avg = (1.0+0.0)/2 = 0.5 # - coarse: (temp_coarse=2.0, pressure_coarse=2.0) → avg = (2.0+2.0)/2 = 2.0 # Batch 2 averages: # - pred: (0.5 + 0.5)/2 = 0.5 # - coarse: (2.0 + 2.0)/2 = 2.0 # Batch 3 averages: # - pred: (0.5 + 0.5)/2 = 0.5 # - coarse: (1.0 + 1.0)/2 = 1.0 # Overall averages (simple mean across batches): expected_avg_pred = (0.5 + 0.5 + 0.5) / 3 # = 0.5 expected_avg_coarse = (2.0 + 2.0 + 1.0) / 3 # = 1.6666667 actual_avg_pred = val_metrics["average_pred_vs_fine_MAE"].getmean() actual_avg_coarse = val_metrics["average_coarse_vs_fine_MAE"].getmean() self.assertAlmostEqual( actual_avg_pred, expected_avg_pred, places=5, msg=f"Avg pred MAE: expected {expected_avg_pred:.5f}, got {actual_avg_pred:.5f}", ) self.assertAlmostEqual( actual_avg_coarse, expected_avg_coarse, places=5, msg=f"Avg coarse MAE: expected {expected_avg_coarse:.5f}, got {actual_avg_coarse:.5f}", ) if self.logger: 
self.logger.info("✅ Average metrics verified:") self.logger.info( f" └── avg_pred: {actual_avg_pred:.5f} (expected: {expected_avg_pred:.5f})" ) self.logger.info( f" └── avg_coarse: {actual_avg_coarse:.5f} (expected: {expected_avg_coarse:.5f})" ) # ======================================================================== # VERIFICATION 5: Verify MetricTracker counts # ======================================================================== # Per-variable trackers should count total elements # 3 batches × 2 samples × 8×8 elements = 384 elements per variable total_elements_per_var = 3 * batch_size * H * W # 384 self.assertEqual( val_metrics["temp_pred_vs_fine_MAE"].count, total_elements_per_var, f"Temp pred tracker count should be {total_elements_per_var}, got {val_metrics['temp_pred_vs_fine_MAE'].count}", ) self.assertEqual( val_metrics["temp_coarse_vs_fine_MAE"].count, total_elements_per_var, f"Temp coarse tracker count should be {total_elements_per_var}, got {val_metrics['temp_coarse_vs_fine_MAE'].count}", ) self.assertEqual( val_metrics["pressure_pred_vs_fine_MAE"].count, total_elements_per_var, f"Pressure pred tracker count should be {total_elements_per_var}, got {val_metrics['pressure_pred_vs_fine_MAE'].count}", ) self.assertEqual( val_metrics["pressure_coarse_vs_fine_MAE"].count, total_elements_per_var, f"Pressure coarse tracker count should be {total_elements_per_var}, got {val_metrics['pressure_coarse_vs_fine_MAE'].count}", ) # Average trackers should count number of batches self.assertEqual( val_metrics["average_pred_vs_fine_MAE"].count, 3, f"Avg pred tracker count should be 3, got {val_metrics['average_pred_vs_fine_MAE'].count}", ) self.assertEqual( val_metrics["average_coarse_vs_fine_MAE"].count, 3, f"Avg coarse tracker count should be 3, got {val_metrics['average_coarse_vs_fine_MAE'].count}", ) if self.logger: self.logger.info( f"✅ Tracker counts verified: per-var={total_elements_per_var}, avg=3" ) # 
======================================================================== # VERIFICATION 6: Verify all MetricTracker values are positive # ======================================================================== for key, tracker in val_metrics.items(): if tracker.count > 0: value = tracker.getmean() # R2 can be negative when the model performs worse than predicting the mean. # # PEARSON can be negative or undefined (NaN when variance is zero). # Therefore, we only enforce non-negativity for error-based metrics. if "R2" not in key and "PEARSON" not in key: self.assertGreaterEqual( value, 0.0, f"{key} should be non-negative, got {value}" ) if self.logger: self.logger.info( "✅ All error-based metric values (except R2 and PEARSON) are non-negative" ) # ======================================================================== # VERIFICATION 7: Verify function call counts # ======================================================================== self.assertEqual( mock_loss_fn.call_count, 3, f"loss_fn should be called 3 times, got {mock_loss_fn.call_count}", ) self.assertEqual( mock_model.call_count, 3, f"model should be called 3 times, got {mock_model.call_count}", ) if self.logger: self.logger.info( f"✅ Function calls verified: loss_fn={mock_loss_fn.call_count}, model={mock_model.call_count}" ) # ======================================================================== # VERIFICATION 8: Final summary # ======================================================================== if self.logger: self.logger.info("=" * 60) self.logger.info("COMPREHENSIVE VERIFICATION SUMMARY:") self.logger.info("=" * 60) self.logger.info(f"avg_val_loss: {avg_val_loss:.5f}") self.logger.info(f"temp_pred_mae: {actual_temp_pred:.5f}") self.logger.info(f"temp_coarse_mae: {actual_temp_coarse:.5f}") self.logger.info(f"pressure_pred_mae: {actual_pressure_pred:.5f}") self.logger.info(f"pressure_coarse_mae: {actual_pressure_coarse:.5f}") self.logger.info(f"avg_pred_mae: {actual_avg_pred:.5f}") 
self.logger.info(f"avg_coarse_mae: {actual_avg_coarse:.5f}") self.logger.info(f"Tracker counts: per-var={total_elements_per_var}, avg=3") self.logger.info( f"Function calls: loss_fn={mock_loss_fn.call_count}, model={mock_model.call_count}" ) self.logger.info("=" * 60) self.logger.info("✅ ALL VERIFICATIONS PASSED!")
def test_crps_zero_when_predictions_equal_fine(self):
    """
    CRPS must vanish for a perfect ensemble.

    The ``sampler`` name in this module is patched so that every ensemble
    member is the exact residual ``fine - coarse``. ``coarse + residual``
    therefore reproduces the fine field, and each variable's
    ``<var>_pred_vs_fine_CRPS`` tracker returned by ``run_validation`` is
    asserted to average 0 (to 6 decimal places).
    """
    if self.logger:
        self.logger.info("Testing CRPS = 0 when predictions equal fine")

    n_samples, n_channels, n_vars = 2, 5, 1
    height = width = 8
    dev = self.device

    target_fine = torch.ones(n_samples, n_vars, height, width, device=dev)
    target_coarse = torch.zeros_like(target_fine)

    single_batch = dict(
        inputs=torch.zeros(n_samples, n_channels, height, width, device=dev),
        targets=torch.zeros(n_samples, n_vars, height, width, device=dev),
        coarse=target_coarse,
        fine=target_fine,
        doy=torch.tensor([100, 200], device=dev),
        hour=torch.tensor([12, 18], device=dev),
        # NOTE(review): the key is spelled "corrdinates" on purpose here;
        # presumably it must match what run_validation reads — confirm
        # before renaming.
        corrdinates=dict(
            lat=torch.zeros(n_samples, height, device=dev),
            lon=torch.zeros(n_samples, width, device=dev),
        ),
    )

    def perfect_sampler(*_args, **_kwargs):
        # Exact residuals: coarse + residual == fine for every member.
        return target_fine - target_coarse

    loss_stub = Mock()
    loss_stub.return_value = torch.zeros_like(single_batch["targets"])
    loss_stub.P_mean = 0.0
    loss_stub.P_std = 1.0

    class MockStats:
        # Identity-like min/max statistics for the "minmax" normalization.
        vmin = 0.0
        vmax = 1.0

    run_args = Mock()
    run_args.varnames_list = ["var"]
    run_args.time_normalization = "linear"
    run_args.inference_type = "sampler"
    run_args.debug = False

    # Patchers are entered left to right, matching a nested-with layout.
    tqdm_patch = patch("tqdm.tqdm", side_effect=lambda x, **kwargs: x)
    sampler_patch = patch(__name__ + ".sampler", side_effect=perfect_sampler)
    autocast_patch = patch("torch.amp.autocast")
    with tqdm_patch, sampler_patch, autocast_patch:
        _, metrics = run_validation(
            model=Mock(),
            valid_dataset=Mock(),
            valid_loader=[single_batch],
            loss_fn=loss_stub,
            norm_mapping={"var_fine": MockStats()},
            normalization_type={"var": "minmax"},
            index_mapping={"var": 0},
            args=run_args,
            steps=Mock(),
            device=dev,
            logger=Mock(),
            epoch=0,
            compute_crps=True,
            crps_batch_size=1,
            crps_ensemble_size=5,
            writer=None,
            plot_every_n_epochs=None,
            paths=None,
        )

    for var in run_args.varnames_list:
        crps_key = f"{var}_pred_vs_fine_CRPS"
        self.assertIn(crps_key, metrics)
        self.assertAlmostEqual(metrics[crps_key].getmean(), 0.0, places=6)

    if self.logger:
        self.logger.info("✅ CRPS test passed (predictions == fine, CRPS = 0)")
def test_generate_residuals_matches_fine(self):
    """
    ``coarse + generate_residuals_norm(...)`` must reconstruct ``fine``.

    The module-level ``sampler`` is patched to return the exact residual
    ``fine - coarse``. The generated residuals must keep the shape of
    ``fine``. Added back onto ``coarse``, they must equal ``fine`` within
    ``atol=1e-6``.
    """
    dev = self.device
    field_shape = (2, 2, 4, 4)  # (batch_size, num_vars, H, W)
    n_batch, _, height, width = field_shape

    fine = torch.ones(field_shape, device=dev)
    coarse = torch.zeros_like(fine)
    targets = torch.zeros_like(fine)
    features = torch.zeros(n_batch, 1, height, width, device=dev)
    labels = torch.zeros(n_batch, 2, device=dev)

    loss_stub = Mock(P_mean=0.0, P_std=1.0)
    cfg = Mock(varnames_list=["var1", "var2"], debug=False)

    # Perfect sampler: always yields the exact residual.
    with patch(__name__ + ".sampler", side_effect=lambda *a, **k: fine - coarse):
        residuals = generate_residuals_norm(
            model=Mock(),
            features=features,
            labels=labels,
            targets=targets,
            loss_fn=loss_stub,
            args=cfg,
            device=dev,
            logger=Mock(),
            inference_type="sampler",
        )

    self.assertEqual(residuals.shape, fine.shape)
    self.assertTrue(torch.allclose(coarse + residuals, fine, atol=1e-6))

    if self.logger:
        self.logger.info("✅ The reconstructed prediction matches the fine data")
def tearDown(self):
    """
    Per-test cleanup hook.

    ``unittest`` calls this after each test method (whenever ``setUp``
    succeeded), whatever the test's outcome. The message below is therefore
    logged even for failing tests. Nothing is released here; the method only
    logs when ``self.logger`` is truthy.
    """
    logger = self.logger
    if not logger:
        return
    logger.info("Evaluater tests completed successfully")
# ----------------------------------------------------------------------------