"""
Evaluation utilities for RTnn model assessment.
This module provides comprehensive evaluation tools for radiative transfer
neural network models, including custom loss functions, metric computation,
and visualization helpers.
The module includes:
- Custom loss functions (NMSE, NMAE, combined MSE-MAE, LogCosh, Weighted MSE)
- Metric calculators for evaluation (MSE, MAE, MBE, R², NMSE, NMAE, MARE, GMRAE)
- Data normalization/de-normalization utilities
- Absorption rate calculations
- Main evaluation loop for LSM models
Dependencies
------------
torch : For tensor operations and loss functions
numpy : For numerical operations
plot_helper : For visualization utilities
"""
import torch
import torch.nn as nn
import numpy as np
import sys
import os
from tqdm import tqdm
sys.path.append("..")
from rtnn.diagnostics import plot_flux_and_abs, plot_flux_and_abs_lines
class NMSELoss(nn.Module):
    """
    Normalized Mean Squared Error loss.

    The plain MSE between prediction and target is divided by the mean
    square of the target, giving a scale-invariant error measure that is
    useful when the magnitude of the target variable varies.

    Parameters
    ----------
    eps : float, optional
        Stabilizer added to the denominator to avoid division by zero.
        Default is 1e-8.

    Examples
    --------
    >>> criterion = NMSELoss()
    >>> loss = criterion(predictions, targets)
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.mse = nn.MSELoss()

    def forward(self, pred, target):
        """Return MSE(pred, target) divided by mean(target**2) + eps."""
        denominator = torch.mean(target**2) + self.eps
        return self.mse(pred, target) / denominator
class NMAELoss(nn.Module):
    """
    Normalized Mean Absolute Error loss.

    The plain MAE between prediction and target is divided by the mean
    absolute value of the target, yielding a scale-invariant error metric.

    Parameters
    ----------
    eps : float, optional
        Stabilizer added to the denominator to avoid division by zero.
        Default is 1e-8.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.l1 = nn.L1Loss()

    def forward(self, pred, target):
        """Return MAE(pred, target) divided by mean(|target|) + eps."""
        denominator = torch.mean(torch.abs(target)) + self.eps
        return self.l1(pred, target) / denominator
class MetricTracker:
    """
    Accumulates weighted metric values and reports summary statistics.

    Each call to :meth:`update` contributes ``value`` with weight
    ``count``; the tracker then exposes the weighted mean, the weighted
    standard deviation, and the square root of the weighted mean.

    Attributes
    ----------
    value : float
        Running weighted sum of metric values.
    count : int
        Total weight (number of samples) accumulated so far.

    Examples
    --------
    >>> tracker = MetricTracker()
    >>> tracker.update(10.0, 5)  # value=10.0, count=5 samples
    >>> tracker.update(20.0, 3)  # value=20.0, count=3 samples
    >>> print(tracker.getmean())  # (10*5 + 20*3) / (5+3) = 110/8 = 13.75
    13.75
    >>> print(tracker.getsqrtmean())  # sqrt(13.75)
    3.7080992435478315
    """

    def __init__(self):
        """Start with an empty accumulator."""
        self.reset()

    def reset(self):
        """
        Clear all accumulated values.

        Returns
        -------
        None
        """
        self.value = 0.0
        self.count = 0
        self.value_sq = 0.0

    def update(self, value, count):
        """
        Add a metric observation with the given weight.

        Parameters
        ----------
        value : float
            Metric value to accumulate.
        count : int
            Weight (number of samples) associated with ``value``.

        Returns
        -------
        None
        """
        self.count += count
        self.value += value * count
        self.value_sq += (value**2) * count

    def _require_samples(self, what):
        """Raise ZeroDivisionError when nothing has been accumulated yet."""
        if self.count == 0:
            raise ZeroDivisionError(f"Cannot compute {what} with zero samples")

    def getmean(self):
        """
        Weighted mean of all tracked values.

        Returns
        -------
        float
            ``value / count``.

        Raises
        ------
        ZeroDivisionError
            If no values have been added (count == 0).
        """
        self._require_samples("mean")
        return self.value / self.count

    def getstd(self):
        """
        Weighted standard deviation of all tracked values.

        Computed as ``sqrt(E[x^2] - E[x]^2)``.

        Returns
        -------
        float

        Raises
        ------
        ZeroDivisionError
            If no values have been added (count == 0).
        """
        self._require_samples("std")
        mu = self.getmean()
        var = self.value_sq / self.count - mu * mu
        # Clamp tiny negative variances caused by floating-point round-off.
        return np.sqrt(var if var > 0.0 else 0.0)

    def getsqrtmean(self):
        """
        Square root of the weighted mean of all tracked values.

        Returns
        -------
        float
            ``sqrt(value / count)``.

        Raises
        ------
        ZeroDivisionError
            If no values have been added (count == 0).
        """
        return np.sqrt(self.getmean())
def get_loss_function(loss_type, args, logger=None):
    """
    Factory function to instantiate the requested loss function.

    Parameters
    ----------
    loss_type : str
        Type of loss function. Supported:
        - 'mse': Mean Squared Error
        - 'mae': Mean Absolute Error
        - 'nmae': Normalized Mean Absolute Error
        - 'nmse': Normalized Mean Squared Error
        - 'smoothl1': Smooth L1 Loss (Huber-like)
        - 'huber': Huber Loss
        Any other value raises ValueError. Note that 'wmse' and
        'logcosh' are NOT handled by this factory even though other
        parts of the module mention them.
    args : argparse.Namespace
        Arguments containing loss-specific parameters (e.g., beta_delta
        for Huber/SmoothL1).
    logger : logging.Logger, optional
        If given, the selected loss is reported at INFO level.

    Returns
    -------
    torch.nn.Module
        Initialized loss function.

    Raises
    ------
    ValueError
        If loss_type is not supported or required parameters are missing.

    Examples
    --------
    >>> args = argparse.Namespace(beta_delta=1.0)
    >>> criterion = get_loss_function('huber', args)
    """
    if loss_type == "mse":
        if logger:
            logger.info("Using Mean Squared Error (MSE) loss")
        return nn.MSELoss()
    elif loss_type == "mae":
        if logger:
            logger.info("Using Mean Absolute Error (MAE) loss")
        return nn.L1Loss()
    elif loss_type == "nmae":
        if logger:
            logger.info("Using Normalized Mean Absolute Error (NMAE) loss")
        return NMAELoss()
    elif loss_type == "nmse":
        if logger:
            logger.info("Using Normalized Mean Squared Error (NMSE) loss")
        return NMSELoss()
    elif loss_type in ["smoothl1", "huber"]:
        # Both variants share a single threshold parameter supplied via args.
        if not hasattr(args, "beta_delta"):
            raise ValueError(f"{loss_type.capitalize()}Loss requires --beta_delta")
        if logger:
            logger.info(
                f"Using {loss_type.capitalize()} loss with delta={args.beta_delta}"
            )
        return (
            nn.SmoothL1Loss(beta=args.beta_delta)
            if loss_type == "smoothl1"
            else nn.HuberLoss(delta=args.beta_delta)
        )
    else:
        # Spell out the valid options so a typo is immediately diagnosable.
        raise ValueError(
            f"Unsupported loss type: {loss_type} "
            f"(supported: mse, mae, nmae, nmse, smoothl1, huber)"
        )
def mse_all(pred, true):
    """
    Mean Squared Error between two tensors.

    Parameters
    ----------
    pred : torch.Tensor
        Predicted values.
    true : torch.Tensor
        Reference values.

    Returns
    -------
    tuple
        (num_elements, mse_value)
    """
    squared_diff = (pred - true) ** 2
    return pred.numel(), squared_diff.mean()
def mbe_all(pred, true):
    """
    Mean Bias Error (signed mean deviation) between two tensors.

    Parameters
    ----------
    pred : torch.Tensor
        Predicted values.
    true : torch.Tensor
        Reference values.

    Returns
    -------
    tuple
        (num_elements, mbe_value)
    """
    bias = pred - true
    return pred.numel(), bias.mean()
def mae_all(pred, true):
    """
    Mean Absolute Error between two tensors.

    Parameters
    ----------
    pred : torch.Tensor
        Predicted values.
    true : torch.Tensor
        Reference values.

    Returns
    -------
    tuple
        (num_elements, mae_value)
    """
    abs_diff = (pred - true).abs()
    return pred.numel(), abs_diff.mean()
def r2_all(pred, true):
    """
    Coefficient of determination (R^2) between predictions and targets.

    Parameters
    ----------
    pred : torch.Tensor
        Predicted values from the model.
    true : torch.Tensor
        Ground truth values.

    Returns
    -------
    tuple
        (num_elements, r2_value) where:
        - num_elements (int): Total number of elements in the tensors
        - r2_value (torch.Tensor): R2 score

    Raises
    ------
    RuntimeError
        If the two tensors do not share the same shape.

    Notes
    -----
    R2 = 1 - SS_res / SS_tot, with a small epsilon added to SS_tot to
    guard against a zero-variance target. Fully torch-based; works on
    CPU and GPU alike.
    """
    if pred.shape != true.shape:
        raise RuntimeError(f"Shape mismatch: pred {pred.shape} vs true {true.shape}")
    eps = 1e-12  # avoids division by zero when the target has no variance
    flat_pred = pred.reshape(-1)
    flat_true = true.reshape(-1)
    ss_res = torch.sum((flat_true - flat_pred) ** 2)
    ss_tot = torch.sum((flat_true - flat_true.mean()) ** 2)
    return pred.numel(), 1.0 - ss_res / (ss_tot + eps)
def nmae_all(pred, true):
    """
    Normalized Mean Absolute Error: MAE divided by mean(|true|).

    Parameters
    ----------
    pred : torch.Tensor
        Predicted values.
    true : torch.Tensor
        Reference values.

    Returns
    -------
    tuple
        (num_elements, nmae_value)
    """
    numerator = torch.mean(torch.abs(pred - true))
    denominator = torch.mean(torch.abs(true)) + 1e-8  # stabilizer
    return pred.numel(), numerator / denominator
def nmse_all(pred, true):
    """
    Normalized Mean Squared Error: MSE divided by mean(true**2).

    Parameters
    ----------
    pred : torch.Tensor
        Predicted values.
    true : torch.Tensor
        Reference values.

    Returns
    -------
    tuple
        (num_elements, nmse_value)
    """
    numerator = torch.mean((pred - true) ** 2)
    denominator = torch.mean(true**2) + 1e-8  # stabilizer
    return pred.numel(), numerator / denominator
def mare_all(pred, true):
    """
    Mean Absolute Relative Error: mean of |pred - true| / |true|.

    Parameters
    ----------
    pred : torch.Tensor
        Predicted values.
    true : torch.Tensor
        Reference values.

    Returns
    -------
    tuple
        (num_elements, mare_value)
    """
    # 1e-8 guards against division by zero where the target vanishes.
    rel = (pred - true).abs() / (true.abs() + 1e-8)
    return pred.numel(), rel.mean()
def gmrae_all(pred, true):
    """
    Geometric Mean Relative Absolute Error.

    The relative absolute errors are averaged in log space, so a single
    large outlier cannot dominate the aggregate.

    Parameters
    ----------
    pred : torch.Tensor
        Predicted values.
    true : torch.Tensor
        Reference values.

    Returns
    -------
    tuple
        (num_elements, gmrae_value)
    """
    eps = 1e-8  # guards both the ratio and the log against zeros
    rel = (pred - true).abs() / (true.abs() + eps)
    geo_mean = torch.exp(torch.log(rel + eps).mean())
    return pred.numel(), geo_mean
def unnorm_mpas(pred, targ, norm_mapping, normalization_type, idxmap):
    """
    Reverse normalization to recover original physical values.

    Every supported scheme is the composition of an affine map and an
    optional inverse link: the normalized channel is first mapped back
    through ``x * scale + offset`` and then, for the log1p/sqrt variants,
    through ``expm1`` or squaring. The per-scheme (scale, offset) lookup
    is factored into a helper so each scheme is declared once instead of
    duplicating the channel-assignment code nine times.

    Parameters
    ----------
    pred : torch.Tensor
        Normalized predictions of shape (batch, channels, seq_length).
    targ : torch.Tensor
        Normalized targets of the same shape.
    norm_mapping : dict
        Per-variable normalization statistics.
    normalization_type : dict
        Per-variable normalization type; variables missing from the dict
        default to 'log1p_minmax'.
    idxmap : dict
        Mapping from channel index to variable name.

    Returns
    -------
    tuple
        (unnormalized_predictions, unnormalized_targets)

    Raises
    ------
    ValueError
        If a variable uses an unsupported normalization type.

    Supported normalization types:
    - minmax: x * (max - min) + min
    - standard: x * std + mean
    - robust: x * iqr + median
    - log1p_*: expm1(x * scale + offset)
    - sqrt_*: (x * scale + offset) ** 2

    Examples
    --------
    >>> idxmap = {0: 'collim_alb', 1: 'collim_tran'}
    >>> upred, utarg = unnorm_mpas(pred, targ, norm_mapping, norm_type, idxmap)
    """

    def _affine_params(norm, norm_type):
        """Return (scale, offset) of the de-normalization affine map,
        or None when the scheme is unknown."""
        if norm_type == "standard":
            return norm["vstd"], norm["vmean"]
        if norm_type == "minmax":
            return norm["vmax"] - norm["vmin"], norm["vmin"]
        if norm_type == "robust":
            return norm["iqr"], norm["median"]
        parts = norm_type.split("_", 1)
        if len(parts) != 2 or parts[0] not in ("log1p", "sqrt"):
            return None
        # Statistic keys use the 'log'/'sqrt' prefix (e.g. log_min, sqrt_std).
        tag = "log" if parts[0] == "log1p" else "sqrt"
        base = parts[1]
        if base == "standard":
            return norm[f"{tag}_std"], norm[f"{tag}_mean"]
        if base == "minmax":
            return norm[f"{tag}_max"] - norm[f"{tag}_min"], norm[f"{tag}_min"]
        if base == "robust":
            return norm[f"{tag}_iqr"], norm[f"{tag}_median"]
        return None

    device = pred.device
    upred = torch.zeros_like(pred, device=device)
    utarg = torch.zeros_like(targ, device=device)
    for i, var_name in idxmap.items():
        norm_type = normalization_type.get(var_name, "log1p_minmax")
        norm = norm_mapping[var_name]
        params = _affine_params(norm, norm_type)
        if params is None:
            raise ValueError(
                f"Unsupported normalization type '{norm_type}' for variable '{var_name}'"
            )
        scale, offset = params
        unnorm_pred = pred[:, i, :] * scale + offset
        unnorm_targ = targ[:, i, :] * scale + offset
        # Apply the inverse link of the forward transform, if any.
        if norm_type.startswith("log1p_"):
            unnorm_pred = torch.expm1(unnorm_pred)
            unnorm_targ = torch.expm1(unnorm_targ)
        elif norm_type.startswith("sqrt_"):
            unnorm_pred = unnorm_pred**2
            unnorm_targ = unnorm_targ**2
        upred[:, i, :] = unnorm_pred
        utarg[:, i, :] = unnorm_targ
    return upred, utarg
def calc_abs(pred, targ, p=None):
    """
    Derive absorption rates from predicted and target fluxes.

    Channels 0-1 of the input hold the first up/down flux pair and
    channels 2-3 the second; each pair is reduced via :func:`calc_hr`
    to the derivative of its net flux (up - down), or to a heating rate
    when pressure levels are supplied.

    Parameters
    ----------
    pred : torch.Tensor
        Predictions of shape (batch, 4, seq_length).
    targ : torch.Tensor
        Targets of the same shape.
    p : torch.Tensor, optional
        Pressure levels; forwarded to :func:`calc_hr` for a
        pressure-gradient heating-rate calculation.

    Returns
    -------
    tuple
        (abs12_pred, abs12_targ, abs34_pred, abs34_targ), each of shape
        (batch, 1, seq_length - 1).

    Notes
    -----
    - If p is None: returns d(net) where net = up - down
    - If p is provided: returns heating rate using d(net)/dp
    """
    results = []
    for tensor in (pred, targ):
        results.append(calc_hr(tensor[:, 0:1, :], tensor[:, 1:2, :], p))
        results.append(calc_hr(tensor[:, 2:3, :], tensor[:, 3:4, :], p))
    # results = [pred12, pred34, targ12, targ34]; reorder to the contract.
    return results[0], results[2], results[1], results[3]
def calc_hr(up, down, p=None):
    """
    Calculate heating rate from upwelling and downwelling fluxes.

    The net flux ``net = up - down`` is finite-differenced along the
    level (last) axis. Without pressure levels the negative difference
    of the net flux is returned; with pressure levels the atmospheric
    heating rate in K/day is computed from d(net)/dp.

    Parameters
    ----------
    up : torch.Tensor
        Upwelling flux tensor of shape (batch, channels, seq_length).
    down : torch.Tensor
        Downwelling flux tensor of the same shape.
    p : torch.Tensor, optional
        Pressure levels in hPa. May be 1-D (seq_length,), 2-D
        (batch, seq_length), or 3-D (batch, channels, seq_length);
        lower-rank inputs are broadcast over the missing leading
        dimensions. (The previous implementation required 3-D ``p``
        even though the docstring promised 1-D/2-D support.)

    Returns
    -------
    torch.Tensor
        Shape (batch, channels, seq_length - 1).
        If p is None: the negative net-flux difference -d(net).
        If p is provided: heating rate
        ``(g * 8.64e4) / (cp * 100) * d(net)/dp``
        with g = 9.8066 m/s^2 and cp = 7*R/2, R = 287 J/(kg K).

    Notes
    -----
    - Finite differences: net[i] - net[i-1] and dp = p[i] - p[i-1]
    - The factor 8.64e4 converts per-second rates to per-day
    - The factor 100 converts pressure from hPa to Pa

    Examples
    --------
    >>> hr = calc_hr(up, down)
    >>> hr.shape
    torch.Size([32, 4, 9])  # for seq_length=10
    >>> pressure = torch.linspace(1000, 100, 10)  # hPa
    >>> calc_hr(up, down, p=pressure).shape
    torch.Size([32, 4, 9])
    """
    net = up - down
    # Forward difference along the level axis; length shrinks by one.
    dnet = net[:, :, 1:] - net[:, :, :-1]
    if p is None:
        return -dnet
    # Promote 1-D/2-D pressure inputs (as documented) to 3-D so they
    # broadcast against (batch, channels, levels).
    if p.dim() == 1:
        p = p.view(1, 1, -1)
    elif p.dim() == 2:
        p = p.unsqueeze(1)
    dp = p[:, :, 1:] - p[:, :, :-1]
    g = 9.8066  # gravitational acceleration, m/s^2
    r = 287.0  # gas constant for dry air, J/(kg K)
    cp = 7.0 * r / 2.0  # specific heat at constant pressure
    fac = g * 8.64e4 / (cp * 100)  # W/m^2 per hPa -> K/day
    return dnet / dp * fac
[docs]
def run_validation(
    loader, model, norm_mapping, normalization_type, index_mapping, device, args, epoch
):
    """
    Evaluate model accuracy on LSM dataset.

    Performs comprehensive evaluation including:
    - Loss computation for main fluxes and absorption rates
    - Metric calculation (NMAE, NMSE, R²)
    - Optional plotting of predictions vs targets (final epoch only)

    Parameters
    ----------
    loader : torch.utils.data.DataLoader
        Data loader for evaluation dataset. Each batch is assumed to
        yield (feature, targets) with an extra leading grouping
        dimension that gets folded into the batch axis below —
        TODO confirm against the dataset implementation.
    model : torch.nn.Module
        Trained model to evaluate.
    norm_mapping : dict
        Normalization statistics for variables.
    normalization_type : dict
        Normalization types per variable.
    index_mapping : dict
        Mapping from channel indices to variable names.
    device : torch.device
        Device to run evaluation on.
    args : argparse.Namespace
        Arguments containing loss type, beta (flux/absorption loss
        weighting), num_epochs, and output-folder names.
    epoch : int
        Current epoch number (plots are produced when it equals
        args.num_epochs - 1).

    Returns
    -------
    tuple
        (valid_loss, valid_metrics) where valid_metrics is a dictionary
        containing computed metrics for fluxes, abs12, and abs34.
        Metrics whose key ends in 'mse' are reported as the square root
        of the tracked mean (an RMSE-style figure).

    Examples
    --------
    >>> valid_loss, metrics = run_validation(
    ...     test_loader, model, norm_mapping, norm_type, idxmap, device, args, epoch
    ... )
    >>> print(f"Validation loss: {valid_loss:.4f}")
    >>> print(f"NMAE: {metrics['fluxes_NMAE']:.4f}")
    """
    model.eval()
    # NOTE(review): 'wmse' and 'logcosh' pass this allow-list, but
    # get_loss_function (as visible in this module) raises ValueError for
    # them — confirm they are handled elsewhere before relying on them.
    valid_loss_types = [
        "mse",
        "mae",
        "nmae",
        "nmse",
        "wmse",
        "logcosh",
        "smoothl1",
        "huber",
    ]
    loss_type = args.loss_type.lower()
    assert (
        loss_type in valid_loss_types
    ), f"Invalid loss_type (should be one of {valid_loss_types})"
    func = get_loss_function(loss_type, args)
    # One MetricTracker per (output, metric) pair, e.g. 'fluxes_NMAE'.
    metric_names = ["NMAE", "NMSE", "R2"]
    metric_funcs = {"NMAE": nmae_all, "NMSE": nmse_all, "R2": r2_all}
    output_keys = ["fluxes", "abs12", "abs34"]
    valid_metrics = {
        f"{k}_{m}": MetricTracker() for k in output_keys for m in metric_names
    }
    valid_loss = MetricTracker()
    # Progress bar for validation
    loop = tqdm(
        enumerate(loader),
        total=len(loader),
        desc=f"Validation Epoch {epoch}",
        leave=False,
    )
    with torch.no_grad():
        for batch_idx, (feature, targets) in loop:
            feature_shape = feature.shape
            target_shape = targets.shape
            # Fold the first two dimensions into one effective batch axis,
            # leaving (inner_batch, channels, seq_length) for the model.
            inner_batch_size = feature_shape[0] * feature_shape[1]
            feature = feature.reshape(
                inner_batch_size, feature_shape[2], feature_shape[3]
            ).to(device=device)
            targets = targets.reshape(
                inner_batch_size, target_shape[2], target_shape[3]
            ).to(device=device)
            predicts = model(feature)
            # De-normalize so absorption rates are computed in physical units.
            predicts_unnorm, targets_unnorm = unnorm_mpas(
                predicts, targets, norm_mapping, normalization_type, index_mapping
            )
            abs12_predict, abs12_target, abs34_predict, abs34_target = calc_abs(
                predicts_unnorm, targets_unnorm
            )
            # Metrics for fluxes use the NORMALIZED tensors; the absorption
            # metrics use the de-normalized derivatives from calc_abs.
            output_dict = {
                "fluxes": (predicts, targets),
                "abs12": (abs12_predict, abs12_target),
                "abs34": (abs34_predict, abs34_target),
            }
            for key in output_keys:
                pred, tgt = output_dict[key]
                for metric in metric_names:
                    metric_key = f"{key}_{metric}"
                    if metric_key not in valid_metrics:
                        raise KeyError(
                            f"Metric key '{metric_key}' not found in valid_metrics"
                        )
                    count, value = metric_funcs[metric](pred, tgt)
                    valid_metrics[metric_key].update(value.item(), count)
            # Combine the flux loss and the two absorption losses into a
            # single scalar, weighted by args.beta and by element counts.
            main_count, main_val = predicts.numel(), func(predicts, targets)
            abs12_count, abs12_val = (
                abs12_predict.numel(),
                func(abs12_predict, abs12_target),
            )
            abs34_count, abs34_val = (
                abs34_predict.numel(),
                func(abs34_predict, abs34_target),
            )
            weighted_loss = (1.0 - args.beta) * main_val * main_count + args.beta * (
                abs12_val * abs12_count + abs34_val * abs34_count
            )
            total_count = (1.0 - args.beta) * main_count + args.beta * (
                abs12_count + abs34_count
            )
            total_loss = weighted_loss / total_count
            # Each batch contributes with weight 1 (batch-mean of batch losses).
            valid_loss.update(total_loss.item(), 1)
            loop.set_postfix(loss=total_loss.item())
            # Diagnostic plots are generated only on the final epoch.
            if epoch == args.num_epochs - 1:
                print("making plot", batch_idx)
                base_dir = os.path.join("results", args.main_folder, args.sub_folder)
                # plot_RTM(predicts_unnorm, targets_unnorm, os.path.join(base_dir, f"Flux{batch_idx}_{args.test_year}.png"))
                # plot_HeatRate(abs12_predict, abs12_target, abs34_predict, abs34_target, os.path.join(base_dir, f"Abs{batch_idx}_{args.test_year}.png"))
                plot_flux_and_abs_lines(
                    predicts_unnorm,
                    targets_unnorm,
                    abs12_predict=abs12_predict,
                    abs12_target=abs12_target,
                    abs34_predict=abs34_predict,
                    abs34_target=abs34_target,
                    filename=os.path.join(
                        base_dir, f"Lineplot_Flux_Abs{batch_idx}_{args.test_year}.png"
                    ),
                )
                plot_flux_and_abs(
                    predicts_unnorm,
                    targets_unnorm,
                    abs12_predict=abs12_predict,
                    abs12_target=abs12_target,
                    abs34_predict=abs34_predict,
                    abs34_target=abs34_target,
                    filename=os.path.join(
                        base_dir, f"flux_abs_hexbin_{batch_idx}_{args.test_year}.png"
                    ),
                )
    # NOTE: keys ending in 'mse' (NMSE included) are reported as
    # sqrt(mean) — an RMSE-style aggregate of the per-batch values.
    return valid_loss.getmean(), {
        k: (tracker.getsqrtmean() if k.lower().endswith("mse") else tracker.getmean())
        for k, tracker in valid_metrics.items()
    }