# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Kazem Ardaneh, Kishanthan Kingston, Pierre Chapel, Rosie Eade
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/
import os
# Point Cartopy at a pre-staged, offline natural-earth data directory
# (compute nodes typically have no network access). setdefault lets an
# externally exported CARTOPY_DATA_DIR take precedence.
# NOTE(review): presumably must be set before `import cartopy` below to
# take effect — confirm against the cartopy config docs.
os.environ.setdefault(
"CARTOPY_DATA_DIR",
"/leonardo_work/EUHPC_D27_095/cartopy_data",
)
import unittest
from datetime import datetime
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.colors as mcolors
import matplotlib.patches as patches
from matplotlib.patches import ConnectionPatch
import matplotlib as mpl
from scipy import stats
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import mpltex
from sklearn.metrics import r2_score
import seaborn as sns
import pandas as pd
# ---------------------------------------------
# COMPLETE MATPLOTLIB STYLE CONFIGURATION
# ---------------------------------------------
# Global rcParams applied once at import time (see mpl.rcParams.update at
# the end of this block); every figure produced by this module inherits
# these defaults.
params = {
# DPI & figure settings (left at matplotlib defaults)
# "figure.dpi": 150,
# "savefig.dpi": 300,
# Fonts
"font.family": "DejaVu Sans",
"mathtext.rm": "arial",
"font.size": 12,  # General font size (affects ax.text())
"font.style": "normal",  # 'normal', 'italic', 'oblique'
"font.weight": "normal",  # 'normal', 'bold', 'heavy', 'light', 'ultrabold', 'ultralight'
"font.stretch": "normal",  # Font stretch
# Line properties
"lines.linewidth": 2,
"lines.dashed_pattern": [4, 2],
"lines.dashdot_pattern": [6, 3, 2, 3],
"lines.dotted_pattern": [2, 3],
# Axis labels and titles
"axes.labelsize": 15,
"axes.titlesize": 15,
# Tick settings
"xtick.labelsize": 12,
"ytick.labelsize": 12,
"xtick.major.size": 6,
"ytick.major.size": 6,
"xtick.direction": "out",
"ytick.direction": "out",
# Legend
"legend.fontsize": 10,
"legend.loc": "best",
"legend.frameon": False,
# Text properties
"text.color": "black",  # Default text color
"text.usetex": False,  # LaTeX rendering disabled; plain mathtext is used
"text.hinting": "auto",  # Text hinting
"text.antialiased": True,  # Text anti-aliasing
"text.latex.preamble": "",  # LaTeX preamble (unused while usetex=False)
}
mpl.rcParams.update(params)
# ============================================================================
# PLOTTING CONFIGURATION
# ============================================================================
class PlotConfig:
    """Central configuration for all plotting functions.

    Holds colormap choices, fixed colorbar ranges for the various error
    diagnostics, geographic-feature line widths, and small helpers that map
    variable identifiers to colormaps, human-readable labels and unit
    conversions.
    """

    # General settings
    DEFAULT_SAVE_DIR = "./results"
    DEFAULT_FIGSIZE_MULTIPLIER = 4

    # Colormap per variable, matched by *substring* against the lowercased
    # variable name in insertion order (see get_colormap); "default" is the
    # fallback and must stay last.
    COLORMAPS = {
        "temperature": "rainbow",
        "temp": "rainbow",
        "2t": "rainbow",
        "zonal": "BrBG_r",
        "10u": "BrBG_r",
        "meridional": "BrBG_r",
        "10v": "BrBG_r",
        "tp": "Blues",
        "TP": "Blues",
        "precipitation": "Blues",
        "dewpoint": "rainbow",
        "d2m": "rainbow",
        "surface temperature": "rainbow",
        "st": "rainbow",
        "pressure": "viridis",
        "pres": "viridis",
        "humidity": "Greens",
        "humid": "Greens",
        "wind": "coolwarm",
        "speed": "coolwarm",
        "mae": "Reds",
        "error": "Reds",
        "divergence": "seismic",
        "curl": "seismic",
        "ssr": "seismic",
        "default": "viridis",
    }

    # Fixed visualization ranges for signed differences (Prediction − Truth)
    FIXED_DIFF_RANGES = {
        "T2M": (-5.0, 5.0),  # K
        "temperature": (-5.0, 5.0),
        "2t": (-5.0, 5.0),
        "VAR_2T": (-5.0, 5.0),
        "U10": (-5.0, 5.0),  # m/s
        "10u": (-5.0, 5.0),
        "meridional": (-5.0, 5.0),
        "VAR_10U": (-5.0, 5.0),
        "V10": (-5.0, 5.0),  # m/s
        "10v": (-5.0, 5.0),
        "VAR_10V": (-5.0, 5.0),
        "TP": (-0.5, 0.5),  # mm/h
        "tp": (-0.5, 0.5),
        "VAR_TP": (-0.5, 0.5),
        "VAR_D2M": (-5.0, 5.0),  # K
        "VAR_ST": (-5.0, 5.0),  # K
    }

    # Fixed visualization ranges for error maps
    FIXED_DIFF_RANGES_ERRORS = {
        "VAR_2T": (0, 0.01),  # K
        "VAR_10U": (0, 3.0),  # m/s
        "VAR_10V": (0, 3.0),  # m/s
        "VAR_TP": (0, 0.5),  # mm/h
        "VAR_D2M": (0, 1.0),  # K
        "VAR_ST": (0, 1.0),  # K
        "Temp": (0, 3.0),
        "Press": (0, 3.0),
        "Humid": (0, 3.0),
        "Wind": (0, 3.0),
    }

    # Fixed visualization ranges for Mean Absolute Error (MAE) maps
    FIXED_MAE_RANGES = {
        "T2M": (0.0, 3.0),
        "temperature": (0.0, 3.0),
        "2t": (0.0, 3.0),
        "VAR_2T": (0.0, 3.0),
        "U10": (0.0, 3.0),
        "10u": (0.0, 3.0),
        "meridional": (0.0, 3.0),
        "VAR_10U": (0.0, 3.0),
        "V10": (0.0, 3.0),
        "10v": (0.0, 3.0),
        "VAR_10V": (0.0, 3.0),
        "TP": (0.0, 1.0),
        "tp": (0.0, 1.0),
        "VAR_TP": (0.0, 1.0),
        "VAR_D2M": (0.0, 3.0),
        "VAR_ST": (0.0, 3.0),
    }

    # Fixed visualization ranges for Spread Skill Ratio (SSR) maps
    FIXED_SSR_RANGES = {
        "T2M": (0.0, 3.0),
        "temperature": (0.0, 3.0),
        "2t": (0.0, 3.0),
        "VAR_2T": (0.0, 3.0),
        "U10": (0.0, 3.0),
        "10u": (0.0, 3.0),
        "meridional": (0.0, 3.0),
        "VAR_10U": (0.0, 3.0),
        "V10": (0.0, 3.0),
        "10v": (0.0, 3.0),
        "VAR_10V": (0.0, 3.0),
        "TP": (0.0, 3.0),
        "tp": (0.0, 3.0),
        "VAR_TP": (0.0, 1.0),
        "VAR_D2M": (0.0, 3.0),
        "VAR_ST": (0.0, 3.0),
    }

    # Geographic feature line widths / style
    COASTLINE_w = 0.5
    BORDER_w = 0.5
    LAKE_w = 0.5
    BORDER_STYLE = "--"

    # Colorbar geometry
    COLORBAR_h = 0.02
    COLORBAR_PAD = 0.05

    # Readable labels for known, prefix-stripped variable identifiers
    # (used by get_plot_name; exact-match keys, case-sensitive).
    _PLOT_NAMES = {
        "2T": "Temperature [K]",
        "10U": "Zonal Wind [m/s]",
        "10V": "Meridional Wind [m/s]",
        "MSLP": "Sea Level Pressure",
        "T2M": "2m Temperature [K]",
        "U10": "10m Zonal Wind [m/s]",
        "V10": "10m Meridional Wind [m/s]",
        "TP": "Precipitation [mm/h]",
        "tp": "Precipitation [mm/h]",
        "D2M": "Dewpoint [K]",
        "ST": "Surface Temperature [K]",
    }

    @classmethod
    def get_colormap(cls, variable_name):
        """Return the colormap name for *variable_name*.

        The first COLORMAPS key contained in the lowercased name wins;
        falls back to COLORMAPS["default"].
        """
        var_lower = variable_name.lower()
        for key, cmap in cls.COLORMAPS.items():
            if key in var_lower:
                return cmap
        return cls.COLORMAPS["default"]

    @classmethod
    def get_plot_name(cls, variable_name):
        """Convert a variable identifier to a human-readable plot label."""
        # Strip common prefixes first
        name = variable_name.replace("VAR_", "").replace("var_", "")
        special = cls._PLOT_NAMES.get(name)
        if special is not None:
            return special
        # Generic fallback: underscores to spaces, Title Case
        return name.replace("_", " ").title()

    @classmethod
    def convert_units(cls, variable_name, data):
        """
        Safe unit conversion when required.

        - NEVER modifies the input in place
        - Returns a new array/value only if a conversion is needed
        """
        name = variable_name.lower()
        if name in ["tp", "var_tp", "precipitation"]:
            return data * 1000.0  # m to mm
        return data

    @staticmethod
    def get_fixed_diff_range(var_name):
        """Fixed range for signed differences (Prediction − Truth), or None."""
        return PlotConfig.FIXED_DIFF_RANGES.get(var_name, None)

    @staticmethod
    def get_fixed_diff_range_errors(var_name):
        """Fixed range for error maps, or None if unknown."""
        return PlotConfig.FIXED_DIFF_RANGES_ERRORS.get(var_name, None)

    @staticmethod
    def get_fixed_mae_range(var_name):
        """Fixed range for Mean Absolute Error (MAE) maps, or None."""
        return PlotConfig.FIXED_MAE_RANGES.get(var_name, None)

    @staticmethod
    def get_fixed_ssr_range(var_name):
        """Fixed range for Spread Skill Ratio (SSR) maps, or None."""
        return PlotConfig.FIXED_SSR_RANGES.get(var_name, None)
def plot_validation_hexbin(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    coarse_inputs=None,  # Coarse inputs for comparison (optional)
    variable_names=None,  # List of variable names
    filename="validation_hexbin.png",
    save_dir="./results",
    figsize_multiplier=4,  # Base size per subplot
):
    """
    Create hexbin plots comparing model predictions vs ground truth, one
    panel per variable, with an identity line and R²/MAE/RMSE annotations.

    Parameters
    ----------
    predictions : torch.Tensor or np.ndarray
        Model predictions of shape [batch_size, num_variables, h, w].
    targets : torch.Tensor or np.ndarray
        Ground truth of shape [batch_size, num_variables, h, w].
    coarse_inputs : torch.Tensor or np.ndarray, optional
        Coarse inputs of shape [batch_size, num_variables, h, w]
        (accepted for API symmetry; not plotted here).
    variable_names : list of str, optional
        Names of the variables for subplot titles.
    filename : str, optional
        Output filename.
    save_dir : str, optional
        Directory to save the plot.
    figsize_multiplier : int, optional
        Base size multiplier for subplots.

    Returns
    -------
    str
        Path of the saved figure.
    """
    # Tensors → numpy (duck-typed so torch need not be imported here)
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if coarse_inputs is not None and hasattr(coarse_inputs, "detach"):
        coarse_inputs = coarse_inputs.detach().cpu().numpy()

    _, num_vars, _, _ = predictions.shape
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(num_vars)]

    # All panels on a single row (one column per variable)
    ncols = num_vars
    nrows = 1
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(ncols * figsize_multiplier, figsize_multiplier)
    )
    axes_flat = np.atleast_1d(axes).ravel()
    for ax in axes_flat:
        ax.set_box_aspect(1)

    hb = None  # last hexbin handle, reused for the shared colorbar
    for i, (var_name, ax) in enumerate(zip(variable_names, axes_flat)):
        # Flatten spatial dims after unit conversion (e.g. precip m → mm)
        pred_flat = PlotConfig.convert_units(var_name, predictions[:, i]).reshape(-1)
        target_flat = PlotConfig.convert_units(var_name, targets[:, i]).reshape(-1)

        hb = ax.hexbin(
            target_flat, pred_flat, gridsize=100, cmap="jet", bins="log", mincnt=1
        )

        # Identity (y = x) reference line over the joint data range
        min_val = min(target_flat.min(), pred_flat.min())
        max_val = max(target_flat.max(), pred_flat.max())
        ax.plot([min_val, max_val], [min_val, max_val], "r--", alpha=0.7)

        # Skill metrics annotated in the panel corner
        r2 = r2_score(target_flat, pred_flat)
        mae = np.mean(np.abs(pred_flat - target_flat))
        rmse = np.sqrt(np.mean((pred_flat - target_flat) ** 2))
        textstr = f"$R^2$: {r2:.3f}\nMAE: {mae:.3f}\nRMSE: {rmse:.3f}"
        ax.text(
            0.05,
            0.95,
            textstr,
            transform=ax.transAxes,
            fontsize=10,
            verticalalignment="top",
        )

        ax.set_title(PlotConfig.get_plot_name(var_name))
        ax.xaxis.set_major_locator(ticker.MaxNLocator(5))
        ax.yaxis.set_major_locator(ticker.MaxNLocator(5))

        # y-label only on the leftmost column, x-label only on the bottom row
        ax.set_ylabel("Predicted Values" if i % ncols == 0 else "")
        ax.set_xlabel("True Values" if i >= (nrows - 1) * ncols else "")

    # Colorbar attached to the LAST axis
    ax_last = axes_flat[min(num_vars - 1, len(axes_flat) - 1)]
    cax = ax_last.inset_axes([1.05, 0.0, 0.04, 1.0])  # [x, y, width, height]
    cbar = fig.colorbar(hb, cax=cax)
    cbar.set_label(r"$\log_{10}[\mathrm{Count}]$")
    plt.subplots_adjust(
        hspace=0.1, wspace=0.3, left=0.1, right=0.9, top=0.9, bottom=0.1
    )

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def plot_comparison_hexbin(
    predictions,
    targets,
    coarse_inputs,
    variable_names=None,
    filename="comparison_hexbin.png",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Create hexbin comparison plots between model predictions, ground truth,
    and coarse inputs.

    For each variable, creates two side-by-side hexbin panels:
    1. Model predictions vs ground truth
    2. Coarse inputs vs ground truth
    Each panel includes an identity line and R²/MAE metrics; one colorbar
    (log10 counts, shared limits) serves all panels.

    Parameters
    ----------
    predictions : torch.Tensor or np.ndarray
        Model predictions of shape [batch_size, num_variables, h, w].
    targets : torch.Tensor or np.ndarray
        Ground truth of shape [batch_size, num_variables, h, w].
    coarse_inputs : torch.Tensor or np.ndarray
        Coarse inputs of shape [batch_size, num_variables, h, w].
    variable_names : list of str, optional
        Names of the variables. If None, uses VAR_0, VAR_1, etc.
    filename : str, optional
        Output filename.
    save_dir : str, optional
        Directory to save the plot.
    figsize_multiplier : int, optional
        Base size multiplier for subplots.

    Returns
    -------
    save_path : str
        Path to the saved figure.
    """
    # Tensors → numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(coarse_inputs, "detach"):
        coarse_inputs = coarse_inputs.detach().cpu().numpy()

    _, num_vars, _, _ = predictions.shape
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(num_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]

    def _flat_triplet(idx, var_name):
        """Unit-converted, flattened (target, pred, coarse) for variable idx."""
        tgt = PlotConfig.convert_units(var_name, targets[:, idx]).reshape(-1)
        prd = PlotConfig.convert_units(var_name, predictions[:, idx]).reshape(-1)
        crs = PlotConfig.convert_units(var_name, coarse_inputs[:, idx]).reshape(-1)
        return tgt, prd, crs

    # ----------------------------------------
    # Pre-pass: collect hexbin densities on throwaway axes so that a global
    # colorbar range can be fixed before the real panels are drawn.
    # ----------------------------------------
    all_counts = []
    for i, var_name in enumerate(variable_names):
        target_flat, pred_flat, coarse_flat = _flat_triplet(i, var_name)
        fig_tmp, ax_tmp = plt.subplots()
        hb1 = ax_tmp.hexbin(
            target_flat, pred_flat, gridsize=100, cmap="jet", bins="log", mincnt=1
        )
        hb2 = ax_tmp.hexbin(
            target_flat, coarse_flat, gridsize=100, cmap="jet", bins="log", mincnt=1
        )
        all_counts.append(hb1.get_array())
        all_counts.append(hb2.get_array())
        plt.close(fig_tmp)

    # Global colorbar limits (log10 counts, since bins="log")
    all_counts = np.concatenate(all_counts)
    global_vmin = np.min(all_counts)
    global_vmax = np.max(all_counts)

    # ----------------------------------------
    # Actual plot: rows = variables, columns = (model, coarse)
    # ----------------------------------------
    fig, axes = plt.subplots(
        num_vars,
        2,
        figsize=(2 * figsize_multiplier, num_vars * figsize_multiplier * 0.8),
    )
    plt.subplots_adjust(
        hspace=0.3, wspace=0.4, left=0.1, right=0.9, top=0.9, bottom=0.1
    )
    if num_vars == 1:
        axes = axes.reshape(1, -1)

    last_hb = None
    for i, var_name in enumerate(variable_names):
        target_flat, pred_flat, coarse_flat = _flat_triplet(i, var_name)

        # Per-variable axis limits with a small margin
        var_min = min(target_flat.min(), pred_flat.min(), coarse_flat.min())
        var_max = max(target_flat.max(), pred_flat.max(), coarse_flat.max())
        margin = 0.05 * (var_max - var_min)
        plot_min = var_min - margin
        plot_max = var_max + margin

        # Left panel: model vs true (default identity-line width);
        # right panel: coarse vs true (thin identity line, as in the
        # original layout).
        panels = (
            (0, pred_flat, "Model", "Model Values", None),
            (1, coarse_flat, "Coarse", "Coarse Values", 1),
        )
        for col, series, label, ylab, id_lw in panels:
            ax = axes[i, col]
            hb = ax.hexbin(
                target_flat,
                series,
                gridsize=100,
                cmap="jet",
                bins="log",
                mincnt=1,
                vmin=global_vmin,
                vmax=global_vmax,
            )
            last_hb = hb  # store for the shared colorbar
            ax.set_xlim(plot_min, plot_max)
            ax.set_ylim(plot_min, plot_max)
            if id_lw is None:
                ax.plot([plot_min, plot_max], [plot_min, plot_max], "r--", alpha=0.7)
            else:
                ax.plot(
                    [plot_min, plot_max],
                    [plot_min, plot_max],
                    "r--",
                    alpha=0.7,
                    linewidth=id_lw,
                )
            r2 = r2_score(target_flat, series)
            mae = np.mean(np.abs(series - target_flat))
            ax.text(
                0.05,
                0.95,
                f"$R^2$: {r2:.3f}\nMAE: {mae:.3f}",
                transform=ax.transAxes,
                va="top",
            )
            ax.set_title(f"{plot_variable_names[i]} – {label} vs True")
            ax.set_ylabel(ylab)
            ax.set_xlabel("True Values" if i == num_vars - 1 else "")

    # ----------------------------------------
    # Single shared colorbar
    # ----------------------------------------
    cbar_ax = fig.add_axes([0.98, 0.1, 0.02, 0.8])
    fig.colorbar(last_hb, cax=cbar_ax, label=r"$\log_{10}[\mathrm{Count}]$")

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def plot_metric_histories(
    valid_metrics_history,
    variable_names,
    metric_names,
    filename="validation_metrics",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Create row-based panel plots: one figure per metric, rows = variables,
    shared x-axis (epoch), log-scale y.

    Parameters
    ----------
    valid_metrics_history : dict
        Dict from the training loop storing metric histories; expected keys
        are ``<var>_pred_vs_fine_<metric>`` and ``<var>_coarse_vs_fine_<metric>``.
    variable_names : list of str
        Names of variables.
    metric_names : list of str
        List of metric names (e.g. ["MAE"]).
    filename : str
        Prefix for saved figures; each file is ``<filename>_<metric>.png``.
    save_dir : str
        Directory where images are saved.
    figsize_multiplier : int
        Row height in inches.

    Returns
    -------
    str or None
        Path of the last figure saved, or None if metric_names is empty.
    """
    os.makedirs(save_dir, exist_ok=True)
    num_vars = len(variable_names)
    save_path = None  # stays None when metric_names is empty
    for metric in metric_names:
        # Rows = variables, 1 column, shared x-axis
        fig, axes = plt.subplots(
            nrows=num_vars,
            ncols=1,
            figsize=(6, figsize_multiplier * num_vars),
            squeeze=False,
            sharex=True,
        )
        plt.subplots_adjust(hspace=0.1, left=0.15, right=0.95, top=0.9, bottom=0.1)
        for i, var in enumerate(variable_names):
            ax = axes[i, 0]
            key_pred = f"{var}_pred_vs_fine_{metric}"
            key_coarse = f"{var}_coarse_vs_fine_{metric}"
            if (
                key_pred not in valid_metrics_history
                or key_coarse not in valid_metrics_history
            ):
                # Placeholder panel when either series is missing
                ax.text(0.5, 0.5, "Missing Data", ha="center", va="center")
                ax.set_yscale("log")
                continue
            pred_hist = valid_metrics_history[key_pred]
            coarse_hist = valid_metrics_history[key_coarse]
            # Fresh linestyle cycle per panel so both panels use the same pair
            linestyles = mpltex.linestyle_generator(markers=[])
            ax.plot(pred_hist, label="Pred vs Fine", **next(linestyles))
            ax.plot(coarse_hist, label="Coarse vs Fine", **next(linestyles))
            ax.set_yscale("log")
            ax.set_ylabel(f"{metric} ({var})")
            ax.grid(True, alpha=0.3)
            ax.legend()
            # Only the bottom row shows the x-axis label
            if i == num_vars - 1:
                ax.set_xlabel("Epoch")
            else:
                ax.tick_params(labelbottom=False)
        # Fix: honor the `filename` prefix (the original wrote a hard-coded
        # placeholder and ignored the parameter).
        save_path = os.path.join(save_dir, f"{filename}_{metric}.png")
        plt.savefig(save_path, bbox_inches="tight")
        plt.close(fig)
    return save_path
def plot_metrics_heatmap(
    valid_metrics_history,
    variable_names,
    metric_names,
    filename="validation_metrics_heatmap",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Plot a heatmap of mean validation metrics (variables × metrics).

    Parameters
    ----------
    valid_metrics_history : dict
        Dict from the validation loop storing metric trackers; expected keys
        are ``<var>_pred_vs_fine_<metric>``.
    variable_names : list of str
        Names of variables (heatmap rows).
    metric_names : list of str
        List of metric names (["MAE", "NMAE", "RMSE", "R²"]).
    filename : str
        Output name (without extension); saved as ``<filename>.png``.
    save_dir : str
        Directory where the image is saved.
    figsize_multiplier : float
        Controls overall figure size.

    Returns
    -------
    str
        Path of the saved figure.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Build one column of mean values per metric; NaN for missing/empty keys
    data = {}
    for metric in metric_names:
        values = []
        for var in variable_names:
            key = f"{var}_pred_vs_fine_{metric}"
            tracker = valid_metrics_history.get(key)
            # NOTE(review): tracker looks like a project metric-tracker
            # exposing .count and .getmean() — confirm against the
            # training-loop implementation.
            if tracker is not None and tracker.count > 0:
                value = tracker.getmean()
                # Convert only dimensional metrics (ratios like R² stay as-is)
                if metric.lower() in ["mae", "rmse", "crps"]:
                    value = PlotConfig.convert_units(var, value)
            else:
                value = np.nan
            values.append(value)
        data[metric] = values

    df = pd.DataFrame(data, index=variable_names)
    fig_width = figsize_multiplier + len(metric_names)
    fig_height = 0.6 * len(variable_names) + figsize_multiplier / 2
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    sns.heatmap(
        df, ax=ax, cmap="viridis", annot=True, fmt=".3f", linewidths=0.8, cbar=True
    )
    ax.set_title("Validation metrics")
    ax.set_xlabel("Metric")
    ax.set_ylabel("Variable")
    plt.tight_layout()

    # Fix: honor the `filename` parameter (the original saved to a
    # hard-coded placeholder name).
    save_path = os.path.join(save_dir, f"{filename}.png")
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def plot_loss_histories(
    train_loss_history,
    valid_loss_history,
    filename="training_validation_loss.png",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Plot training and validation loss in a single log-scale panel.

    Parameters
    ----------
    train_loss_history : list or array
        History of training loss values (one per epoch).
    valid_loss_history : list or array
        History of validation loss values; skipped when empty or all falsy.
    filename : str
        Output image file name for the plot.
    save_dir : str
        Directory to save the plot.
    figsize_multiplier : int
        Figure height in inches.

    Returns
    -------
    str
        Path of the saved figure.
    """
    # Ensure inputs are lists
    if not isinstance(train_loss_history, list):
        train_loss_history = list(train_loss_history)
    if not isinstance(valid_loss_history, list):
        valid_loss_history = list(valid_loss_history)

    fig = plt.figure(figsize=(6, figsize_multiplier))
    ax = fig.add_subplot(111)
    epochs = range(len(train_loss_history))

    linestyles = mpltex.linestyle_generator(markers=[])
    ax.plot(epochs, train_loss_history, label="Training Loss", **next(linestyles))
    # NOTE(review): assumes valid_loss_history, when plotted, has the same
    # length as train_loss_history — confirm in the training loop.
    if valid_loss_history and any(valid_loss_history):
        ax.plot(epochs, valid_loss_history, label="Validation Loss", **next(linestyles))

    ax.set_yscale("log")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss Value")
    ax.legend()
    ax.grid(True, alpha=0.3)

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    # (an unreachable print() that followed the return was removed)
    return save_path
def plot_average_metrics(
    valid_metrics_history,
    metric_names,  # List of metrics to plot
    filename="average_metrics.png",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Plot metrics averaged across all variables, one row per metric, with a
    shared x-axis (epoch) and log-scale y.

    Each row plots (when present in the history dict):
    - ``average_pred_vs_fine_<metric>``
    - ``average_coarse_vs_fine_<metric>``

    Parameters
    ----------
    valid_metrics_history : dict
        Dictionary containing validation metrics history.
    metric_names : list of str
        Names of metrics to plot.
    filename : str
        Output image file name for the plot.
    save_dir : str
        Directory to save the plot.
    figsize_multiplier : int
        Row height in inches.

    Returns
    -------
    str or None
        Path of the saved figure, or None when metric_names is empty.
    """
    if not metric_names:
        print("No metric names provided")
        return

    num_rows = len(metric_names)
    fig, axes = plt.subplots(
        nrows=num_rows,
        ncols=1,
        figsize=(6, figsize_multiplier * num_rows),
        squeeze=False,
        sharex=True,
    )
    plt.subplots_adjust(hspace=0.1, left=0.15, right=0.95, top=0.95, bottom=0.1)

    def _plot_history(ax, key, label, linestyles):
        """Plot one metric history curve when its key is present."""
        if key in valid_metrics_history:
            hist = valid_metrics_history[key]
            if not isinstance(hist, list):
                hist = list(hist)
            ax.plot(hist, label=label, **next(linestyles))

    for idx, metric in enumerate(metric_names):
        ax = axes[idx, 0]
        linestyles = mpltex.linestyle_generator(markers=[])
        _plot_history(ax, f"average_pred_vs_fine_{metric}", "Pred vs Fine", linestyles)
        _plot_history(
            ax, f"average_coarse_vs_fine_{metric}", "Coarse vs Fine", linestyles
        )
        ax.set_yscale("log")
        ax.set_ylabel(metric.replace("_", " ").title())
        ax.grid(True, alpha=0.3)
        ax.legend()
        # Only the bottom row gets the x-label / tick labels
        if idx == num_rows - 1:
            ax.set_xlabel("Epoch")
        else:
            ax.set_xlabel("")
            ax.tick_params(labelbottom=False)

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def plot_spatiotemporal_histograms(
    steps,
    tindex_lim,
    centers,
    tindices,
    mode="train",
    filename="",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Plot two 2D hexagonal-bin histograms of spatio-temporal data coverage:
    latitude center vs temporal index, and longitude center vs temporal index.

    Hexagonal binning gives smoother density estimation than rectangular
    binning; a single colorbar is shared between both panels and its scale
    is normalized to the maximum count across both histograms. Hexagons with
    zero count (mincnt=1) are not displayed.

    Parameters
    ----------
    steps : EasyDict
        Coordinate dimensions/limits; expected to have attributes
        'latitude' (or 'lat') and 'longitude' (or 'lon') specifying the
        maximum spatial indices.
    tindex_lim : tuple
        (min_time, max_time) temporal index limits.
    centers : list of tuples
        (lat_center, lon_center) for each data sample.
    tindices : list or array-like
        Temporal index per sample; same length as ``centers``.
    mode : str
        Dataset mode identifier, typically "train" or "validation";
        used in the output filename.
    filename : str, optional
        Prefix prepended to the output filename (default: empty string, so
        the file is ``spatiotemporal_<mode>_hexbin.png``).
    save_dir : str
        Directory where the plot is saved (created if missing).
    figsize_multiplier : int
        Base panel size in inches.

    Returns
    -------
    str or None
        Path of the saved figure, or None when there is no data.

    Examples
    --------
    >>> steps = EasyDict({'latitude': 180, 'longitude': 360})
    >>> plot_spatiotemporal_histograms(steps, (0, 1000),
    ...     [(10, 20), (15, 25)], [0, 5], "train", save_dir="./plots")
    """
    # NOTE(review): empty-input guard assumes list inputs; numpy arrays
    # would make this truth test ambiguous — confirm caller passes lists.
    if not centers or not tindices:
        print(f"No data to plot for {mode} mode")
        return

    # Convert to numpy arrays for efficient processing
    centers = np.array(centers)
    lat_centers = centers[:, 0]
    lon_centers = centers[:, 1]
    tindices = np.array(tindices)

    # Spatial limits, with fallback attribute names
    max_lat = getattr(steps, "latitude", getattr(steps, "lat", None))
    max_lon = getattr(steps, "longitude", getattr(steps, "lon", None))
    min_time, max_time = tindex_lim

    fig, (ax1, ax2) = plt.subplots(
        1, 2, figsize=(2 * figsize_multiplier, figsize_multiplier), sharey=True
    )
    plt.subplots_adjust(
        hspace=0.1, wspace=0.1, left=0.1, right=0.9, top=0.9, bottom=0.1
    )

    def _coverage_panel(ax, xvals, xmax, xlabel):
        """Draw one coverage hexbin (spatial index vs temporal index)."""
        hx = ax.hexbin(
            xvals,
            tindices,
            gridsize=100,  # number of hexagons in x-direction
            extent=[0, xmax, min_time, max_time],  # data limits
            cmap="jet",
            mincnt=1,  # only show hexagons with at least 1 count
            edgecolors="none",  # no borders on hexagons
        )
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_xlim(0, xmax)
        ax.set_ylim(min_time, max_time)
        ax.grid(True, alpha=0.3, linestyle="--")
        return hx

    hex1 = _coverage_panel(ax1, lat_centers, max_lat, "Latitude Center Index")
    ax1.set_ylabel("Temporal Index", fontsize=12)
    hex2 = _coverage_panel(ax2, lon_centers, max_lon, "Longitude Center Index")

    # Normalize color scale to the maximum count across both panels
    max_count = 1
    for hx in (hex1, hex2):
        arr = hx.get_array()
        if arr is not None and len(arr) > 0:
            max_count = max(max_count, arr.max())
    hex1.set_clim(0, max_count)
    hex2.set_clim(0, max_count)

    # Single shared colorbar. These hexbins use linear counts (no
    # bins="log"), so the label is plain "Count" — the original log10
    # label did not match the data.
    cbar_ax = fig.add_axes([0.93, 0.1, 0.02, 0.8])
    fig.colorbar(hex1, cax=cbar_ax, label="Count")

    os.makedirs(save_dir, exist_ok=True)
    # Fix: use `filename` as a prefix as documented (the original wrote a
    # hard-coded placeholder and discarded the parameter).
    out_name = f"{filename}spatiotemporal_{mode}_hexbin.png"
    save_path = os.path.join(save_dir, out_name)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def plot_surface(
    predictions,
    targets,
    coarse_inputs,
    lat_1d,
    lon_1d,
    timestamp=None,
    variable_names=None,
    filename="forecast_plot.png",
    save_dir=None,
    figsize_multiplier=None,
):
    """
    Plot side-by-side forecast maps (coarse input, true target, model
    prediction, and signed difference) for one or more meteorological
    variables over a geographic domain.

    The figure has 4 rows (Coarse, Truth, Prediction, Pred - Truth) and one
    column per variable, drawn on a PlateCarree projection centred on the
    domain's mid-longitude. A horizontal colorbar is placed above each
    column (field values) and below it (signed difference).

    Parameters
    ----------
    predictions : torch.Tensor or np.ndarray
        Model predictions at target resolution with shape [1, n_vars, H, W].
    targets : torch.Tensor or np.ndarray
        Ground-truth high-resolution data with shape [1, n_vars, H, W].
    coarse_inputs : torch.Tensor or np.ndarray
        Coarse-resolution input data with shape [1, n_vars, H, W].
    lat_1d : array-like
        1D array of latitude coordinates with shape [H].
    lon_1d : array-like
        1D array of longitude coordinates with shape [W].
    timestamp : datetime.datetime, optional
        Forecast timestamp; currently only printed, not drawn on the figure.
    variable_names : list of str, optional
        Variable names or identifiers; defaults to ``VAR_i``.
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot; defaults to ``PlotConfig.DEFAULT_SAVE_DIR``.
    figsize_multiplier : int, optional
        Base size multiplier for subplots. Resolved from PlotConfig but not
        otherwise used here: panel sizes are fixed below.

    Returns
    -------
    str
        Path of the saved PNG file.
    """
    # Use defaults from config if not provided
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if figsize_multiplier is None:
        figsize_multiplier = PlotConfig.DEFAULT_FIGSIZE_MULTIPLIER
    # Convert tensors to numpy if needed (duck-typed on .detach so plain
    # ndarrays pass through untouched)
    if hasattr(coarse_inputs, "detach"):
        coarse_inputs = coarse_inputs.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(lat_1d, "detach"):
        lat_1d = lat_1d.detach().cpu().numpy()
    if hasattr(lon_1d, "detach"):
        lon_1d = lon_1d.detach().cpu().numpy()
    # Create 2D meshgrid from 1D coordinates
    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    # Grid shape taken from the data, not from the coordinate arrays
    h, w = coarse_inputs[0, 0].shape
    # Latitude runs max -> min so row 0 of the array is the northernmost row
    lat_block = np.linspace(lat_max, lat_min, h)
    lon_block = np.linspace(lon_min, lon_max, w)
    lat, lon = np.meshgrid(lat_block, lon_block, indexing="ij")
    # Projection center
    lon_center = float((lon_min + lon_max) / 2)
    # Check data dimensions
    n_vars = coarse_inputs.shape[1]
    if targets.shape[1] != n_vars:
        raise ValueError(
            f"targets data has {targets.shape[1]} variables but coarse_inputs has {n_vars}"
        )
    if predictions.shape[1] != n_vars:
        raise ValueError(
            f"predictions data has {predictions.shape[1]} variables but coarse_inputs has {n_vars}"
        )
    # Default variable names if not provided
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]
    # Derive plot names and colormaps
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]
    cmaps = [PlotConfig.get_colormap(var) for var in variable_names]
    # Derive vmin/vmax from data for each variable (for coarse_inputs, truth, prediction)
    vmin_list = []
    vmax_list = []
    # Derive vmin/vmax for difference plots (signed difference)
    diff_vmin_list = []
    diff_vmax_list = []
    for i in range(n_vars):
        var_name = variable_names[i]
        # Convert to display units before computing the color range so the
        # range matches what is actually plotted
        coarse_i = PlotConfig.convert_units(var_name, coarse_inputs[0, i])
        target_i = PlotConfig.convert_units(var_name, targets[0, i])
        pred_i = PlotConfig.convert_units(var_name, predictions[0, i])
        # Get combined data range for this variable (coarse_inputs, truth, prediction)
        all_data = np.concatenate(
            [coarse_i.flatten(), target_i.flatten(), pred_i.flatten()]
        )
        # Calculate vmin/vmax (2%-98% quantiles to clip outliers)
        all_data_flat = all_data[~np.isnan(all_data)]
        if len(all_data_flat) > 0:
            q_low, q_high = np.quantile(all_data_flat, [0.02, 0.98])
            vmin, vmax = float(q_low), float(q_high)
        else:
            vmin, vmax = -1, 1
        # Ensure vmin < vmax; fall back to the full data range otherwise
        if vmin >= vmax:
            vmin, vmax = float(np.nanmin(all_data)), float(np.nanmax(all_data))
        vmin_list.append(vmin)
        vmax_list.append(vmax)
        # Calculate signed difference between prediction and truth.
        # NOTE(review): the difference range is computed on RAW (unconverted)
        # values while the difference row below plots converted values —
        # confirm this is intended when convert_units is not the identity.
        fixed_range = PlotConfig.get_fixed_diff_range(var_name)
        diff_data = (predictions[0, i] - targets[0, i]).flatten()
        diff_data = diff_data[~np.isnan(diff_data)]
        if fixed_range is not None:
            diff_vmin, diff_vmax = fixed_range
        else:
            if len(diff_data) > 0:
                # For signed difference, we want a symmetric range around 0
                max_abs_diff = np.max(np.abs(diff_data))
                diff_vmin = -max_abs_diff * 1.1  # Add 10% padding
                diff_vmax = max_abs_diff * 1.1  # Add 10% padding
                # If all differences are zero or very small
                if diff_vmax <= 0.001:
                    diff_vmin, diff_vmax = -0.1, 0.1
            else:
                diff_vmin, diff_vmax = -1, 1
        diff_vmin_list.append(diff_vmin)
        diff_vmax_list.append(diff_vmax)
    # Use fixed figure size instead of geo_ratio calculation
    # This ensures rectangular panels regardless of location
    base_width_per_panel = 4.5  # Same as original scale
    base_height_per_panel = 3.0  # Keep this as is
    fig_width = base_width_per_panel * n_vars
    fig_height = base_height_per_panel * 4  # 4 rows
    # Set up figure
    fig, axes = plt.subplots(
        4,
        n_vars,  # 4 rows, n_vars columns
        figsize=(fig_width, fig_height),
        subplot_kw={
            "projection": ccrs.PlateCarree(central_longitude=lon_center)
        },
        gridspec_kw={"wspace": 0.1, "hspace": 0.1},  # Keep spacing
        squeeze=False,  # always get a 2D axes array, even for n_vars == 1
    )
    # Main title: currently only logged to stdout, not drawn on the figure
    if timestamp is not None:
        print(f"Forecast for {timestamp.strftime('%Y-%m-%d %H:%M')}")
    # Plot each variable
    for col_idx in range(n_vars):
        var_name = variable_names[col_idx]
        # Convert to display units for plotting
        coarse_inputs_data = PlotConfig.convert_units(
            var_name, coarse_inputs[0, col_idx]
        )
        targets_data = PlotConfig.convert_units(var_name, targets[0, col_idx])
        pred_data = PlotConfig.convert_units(var_name, predictions[0, col_idx])
        diff_data = pred_data - targets_data  # Signed difference (pred - truth)
        # Store image objects for rows that need colorbars
        im_coar = None
        im_diff = None
        # Process all rows
        for row_idx in range(4):
            ax = axes[row_idx, col_idx]
            # Select data based on row
            if row_idx == 0:
                data = coarse_inputs_data
                vmin, vmax = vmin_list[col_idx], vmax_list[col_idx]
                cmap = cmaps[col_idx]
            elif row_idx == 1:
                data = targets_data
                vmin, vmax = vmin_list[col_idx], vmax_list[col_idx]
                cmap = cmaps[col_idx]
            elif row_idx == 2:
                data = pred_data
                vmin, vmax = vmin_list[col_idx], vmax_list[col_idx]
                cmap = cmaps[col_idx]
            else:  # row_idx == 3
                data = diff_data
                vmin, vmax = diff_vmin_list[col_idx], diff_vmax_list[col_idx]
                cmap = "RdBu_r"  # Diverging colormap for differences
            # Create the plot
            im = ax.pcolormesh(
                lon,
                lat,
                data,
                vmin=vmin,
                vmax=vmax,
                cmap=cmap,
                transform=ccrs.PlateCarree(),
                shading="auto",
            )
            # Store image objects for rows that need colorbars
            if row_idx == 0:
                im_coar = im
            elif row_idx == 3:
                im_diff = im
            # Set extent and geographic features
            ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
            ax.coastlines(linewidth=0.6)
            ax.add_feature(
                cfeature.BORDERS.with_scale("50m"),
                linewidth=0.9,
                linestyle="--",
                edgecolor="black",
                zorder=11,
            )
            ax.add_feature(
                cfeature.LAKES.with_scale("50m"),
                edgecolor="black",
                facecolor="none",
                linewidth=0.9,
                zorder=9,
            )
            ax.set_xticks([])
            ax.set_yticks([])
        # Add horizontal colorbar ABOVE the COARSE row (row 0); it is shared
        # with the truth and prediction rows since they use the same range
        if im_coar is not None:
            ax_coar = axes[0, col_idx]
            # Position at top of panel: [x, y, width, height] where y > 1.0 places it above
            cax_top = ax_coar.inset_axes([0.1, 1.05, 0.8, 0.05])
            cbar = fig.colorbar(im_coar, cax=cax_top, orientation="horizontal")
            cbar.set_label(f"{plot_variable_names[col_idx]}")
            cax_top.xaxis.set_ticks_position("top")
            cax_top.xaxis.set_label_position("top")
        # Add colorbar below the DIFFERENCE row (row 3)
        if im_diff is not None:
            ax_diff = axes[3, col_idx]
            cax_diff = ax_diff.inset_axes([0.1, -0.12, 0.8, 0.05])
            fig.colorbar(
                im_diff,
                cax=cax_diff,
                orientation="horizontal",
                label=f"Δ {plot_variable_names[col_idx]} (Pred - Truth)",
            )
    # Add row labels on the left side
    row_labels = ["Coarse", "Truth", "Prediction", "Pred - Truth"]
    for row_idx, label in enumerate(row_labels):
        axes[row_idx, 0].text(
            -0.12,
            0.5,
            label,
            transform=axes[row_idx, 0].transAxes,
            va="center",
            ha="right",
            rotation="vertical",
            fontsize=12,
        )
    # Adjust layout - give more room at bottom for colorbars
    fig.subplots_adjust(
        top=0.90, bottom=0.25, left=0.10, right=0.95, wspace=0.1, hspace=0.15
    )
    # Save figure
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def plot_ensemble_surface(
    predictions_ens,
    lat_1d,
    lon_1d,
    variable_names,
    timestamp=None,
    filename="ensemble_surface.png",
    save_dir="./results",
):
    """
    Plot ensemble members, ensemble mean, and ensemble spread.

    The figure has 5 rows (the first three ensemble members, the ensemble
    mean, and the per-pixel standard deviation) and one column per
    variable, on a PlateCarree projection centred on the domain.

    Parameters
    ----------
    predictions_ens : torch.Tensor or np.ndarray
        Ensemble predictions of shape [n_ensemble_members, n_vars, H, W].
    lat_1d : array-like
        1D array of latitude coordinates with shape [H].
    lon_1d : array-like
        1D array of longitude coordinates with shape [W].
    variable_names : list of str
        Variable names or identifiers (required, unlike sibling plotters).
    timestamp : datetime.datetime, optional
        Forecast timestamp; currently only printed, not drawn on the figure.
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot.

    Raises
    ------
    ValueError
        If fewer than 3 ensemble members are provided.

    Returns
    -------
    str
        Path of the saved PNG file.
    """
    # Accept torch tensors as well as plain arrays
    if torch.is_tensor(predictions_ens):
        predictions_ens = predictions_ens.detach().cpu().numpy()
    N_ens, C, H, W = predictions_ens.shape
    # The first three members are plotted individually below
    if N_ens < 3:
        raise ValueError("Need at least 3 ensemble members")
    # Ensemble statistics (over the member axis)
    ensemble_mean = np.mean(predictions_ens, axis=0)
    ensemble_std = np.std(predictions_ens, axis=0)
    lat_1d = np.asarray(lat_1d)
    lon_1d = np.asarray(lon_1d)
    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    # Latitude runs max -> min so array row 0 is the northernmost row
    lat_block = np.linspace(lat_max, lat_min, H)
    lon_block = np.linspace(lon_min, lon_max, W)
    lat, lon = np.meshgrid(lat_block, lon_block, indexing="ij")
    lon_center = float((lon_min + lon_max) / 2)
    plot_variable_names = [PlotConfig.get_plot_name(v) for v in variable_names]
    cmaps = [PlotConfig.get_colormap(v) for v in variable_names]
    # Compute a shared per-variable color range from all members + the mean
    vmin_list = []
    vmax_list = []
    for i in range(C):
        var = variable_names[i]
        ens_members = [
            PlotConfig.convert_units(var, predictions_ens[k, i]) for k in range(N_ens)
        ]
        mean_i = PlotConfig.convert_units(var, ensemble_mean[i])
        all_data = np.concatenate(
            [x.flatten() for x in ens_members] + [mean_i.flatten()]
        )
        all_data = all_data[~np.isnan(all_data)]
        # 2%-98% quantiles to clip outliers; fall back for all-NaN input
        if len(all_data) > 0:
            q_low, q_high = np.quantile(all_data, [0.02, 0.98])
            vmin, vmax = float(q_low), float(q_high)
        else:
            vmin, vmax = -1, 1
        # Degenerate quantile range: use the full data range instead
        if vmin >= vmax:
            vmin = float(np.nanmin(all_data))
            vmax = float(np.nanmax(all_data))
        vmin_list.append(vmin)
        vmax_list.append(vmax)
    n_rows = 5
    n_cols = C
    base_width_per_panel = 4.5
    base_height_per_panel = 3.0
    fig_width = base_width_per_panel * n_cols
    fig_height = base_height_per_panel * n_rows
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(fig_width, fig_height),
        subplot_kw={"projection": ccrs.PlateCarree(central_longitude=lon_center)},
        gridspec_kw={"wspace": 0.1, "hspace": 0.15},
        squeeze=False,  # always a 2D axes array, even for a single variable
    )
    row_labels = [
        "Prediction 1",
        "Prediction 2",
        "Prediction 3",
        "Ensemble Mean",
        "Ensemble std (σ)",
    ]
    for col in range(n_cols):
        var = variable_names[col]
        # Fields for the five rows, in display units
        member1 = PlotConfig.convert_units(var, predictions_ens[0, col])
        member2 = PlotConfig.convert_units(var, predictions_ens[1, col])
        member3 = PlotConfig.convert_units(var, predictions_ens[2, col])
        mean_field = PlotConfig.convert_units(var, ensemble_mean[col])
        # NOTE(review): convert_units is applied to the std field too;
        # correct only for purely multiplicative conversions — confirm.
        std_field = PlotConfig.convert_units(var, ensemble_std[col])
        rows_data = [member1, member2, member3, mean_field, std_field]
        im_main = None  # first member image, used for the top colorbar
        im_spread = None  # std image, used for the bottom colorbar
        for row in range(n_rows):
            ax = axes[row, col]
            if row == 4:
                # Spread row: sequential colormap starting at zero
                cmap = "Reds"
                vmin = 0
                vmax = np.nanmax(std_field)
            else:
                cmap = cmaps[col]
                vmin = vmin_list[col]
                vmax = vmax_list[col]
            im = ax.pcolormesh(
                lon,
                lat,
                rows_data[row],
                vmin=vmin,
                vmax=vmax,
                cmap=cmap,
                transform=ccrs.PlateCarree(),
                shading="auto",
            )
            ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
            ax.coastlines(linewidth=0.6)
            ax.add_feature(
                cfeature.BORDERS.with_scale("50m"),
                linewidth=0.9,
                linestyle="--",
                edgecolor="black",
                zorder=11,
            )
            ax.add_feature(
                cfeature.LAKES.with_scale("50m"),
                edgecolor="black",
                facecolor="none",
                linewidth=0.9,
                zorder=9,
            )
            ax.set_xticks([])
            ax.set_yticks([])
            if row == 4:
                im_spread = im
            else:
                if im_main is None:
                    im_main = im
        # Shared colorbar above the first row (members + mean share one range)
        ax_top = axes[0, col]
        cax_top = ax_top.inset_axes([0.1, 1.05, 0.8, 0.05])
        cbar = fig.colorbar(im_main, cax=cax_top, orientation="horizontal")
        cbar.set_label(plot_variable_names[col])
        cax_top.xaxis.set_ticks_position("top")
        cax_top.xaxis.set_label_position("top")
        # Separate colorbar below the spread row
        ax_bottom = axes[4, col]
        cax_bottom = ax_bottom.inset_axes([0.1, -0.12, 0.8, 0.05])
        fig.colorbar(
            im_spread,
            cax=cax_bottom,
            orientation="horizontal",
            label=f"Std {plot_variable_names[col]}",
        )
    # Row labels on the left edge
    for r, label in enumerate(row_labels):
        axes[r, 0].text(
            -0.12,
            0.5,
            label,
            transform=axes[r, 0].transAxes,
            va="center",
            ha="right",
            rotation="vertical",
            fontsize=12,
        )
    # Timestamp is only logged, not drawn on the figure
    if timestamp is not None:
        print(f"Ensemble predictions — {timestamp}")
    fig.subplots_adjust(
        top=0.90,
        bottom=0.25,
        left=0.10,
        right=0.95,
        wspace=0.1,
        hspace=0.15,
    )
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def plot_zoom_comparison(
    predictions,
    targets,
    lat_1d,
    lon_1d,
    variable_names=None,
    filename="zoom_plot.png",
    save_dir=None,
    zoom_box=None,
):
    """
    Plot a comparison between ground truth and model predictions with a
    geographic zoom.

    The figure has 4 rows per variable column: the global truth field with
    the zoom region outlined in red, the zoomed truth, the zoomed
    prediction, and the zoomed pointwise absolute error (MAE). Red guide
    lines connect the zoom box to the zoom panel.

    Parameters
    ----------
    predictions : torch.Tensor or np.ndarray
        Model predictions at target resolution with shape [1, n_vars, H, W].
    targets : torch.Tensor or np.ndarray
        Ground-truth high-resolution data with shape [1, n_vars, H, W].
    lat_1d : array-like
        1D array of latitude coordinates with shape [H].
    lon_1d : array-like
        1D array of longitude coordinates with shape [W].
    variable_names : list of str, optional
        Variable names or identifiers; defaults to ``VAR_i``.
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot; defaults to ``PlotConfig.DEFAULT_SAVE_DIR``.
    zoom_box : dict, optional
        Zoom region with keys "lat_min", "lat_max", "lon_min", "lon_max",
        in the same coordinate convention as ``lat_1d``/``lon_1d``.

    Returns
    -------
    str
        Path of the saved PNG file.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if zoom_box is None:
        zoom_box = {"lat_min": -23, "lat_max": 13, "lon_min": 255, "lon_max": 345}
    # Convert tensors to numpy if needed
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(lat_1d, "detach"):
        lat_1d = lat_1d.detach().cpu().numpy()
    if hasattr(lon_1d, "detach"):
        lon_1d = lon_1d.detach().cpu().numpy()
    lat = lat_1d
    lon = lon_1d
    lon2d, lat2d = np.meshgrid(lon, lat)
    lat_min, lat_max = lat.min(), lat.max()
    lon_min, lon_max = lon.min(), lon.max()
    lon_center = float((lon_min + lon_max) / 2)
    # Boolean masks selecting the zoom window along each coordinate axis
    lat_mask = (lat >= zoom_box["lat_min"]) & (lat <= zoom_box["lat_max"])
    lon_mask = (lon >= zoom_box["lon_min"]) & (lon <= zoom_box["lon_max"])
    lat_zoom = lat[lat_mask]
    lon_zoom = lon[lon_mask]
    lon_zoom2d, lat_zoom2d = np.meshgrid(lon_zoom, lat_zoom)
    n_vars = targets.shape[1]
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(v) for v in variable_names]
    cmaps = [PlotConfig.get_colormap(v) for v in variable_names]
    proj_global = ccrs.PlateCarree(central_longitude=lon_center)
    proj_zoom = ccrs.PlateCarree()
    base_width_per_panel = 4.5
    base_height_per_panel = 2.5
    fig = plt.figure(figsize=(base_width_per_panel * n_vars, base_height_per_panel * 4))
    # Manual axes layout, in figure-fraction coordinates
    left_margin = 0.08
    right_margin = 0.02
    bottom_margin = 0.08
    top_margin = 0.06
    hspace = 0.002
    wspace = 0.008
    total_width = 1 - left_margin - right_margin
    total_height = 1 - bottom_margin - top_margin
    col_width = total_width / n_vars
    row_height = total_height / 4
    axes = np.empty((4, n_vars), dtype=object)
    for row in range(4):
        for col in range(n_vars):
            # Row 0 is the global overview; rows 1-3 use the plain zoom projection
            proj = proj_global if row == 0 else proj_zoom
            x0 = left_margin + col * col_width + wspace / 2
            y0 = 1 - top_margin - (row + 1) * row_height + hspace / 2
            w = col_width - wspace
            h = row_height - hspace
            axes[row, col] = fig.add_axes([x0, y0, w, h], projection=proj)
    coastline = cfeature.COASTLINE.with_scale("50m")
    borders = cfeature.BORDERS.with_scale("50m")
    for col in range(n_vars):
        var = variable_names[col]
        truth = PlotConfig.convert_units(var, targets[0, col])
        pred = PlotConfig.convert_units(var, predictions[0, col])
        mae = np.abs(pred - truth)
        cmap = cmaps[col]
        # Shared color range from the 2%-98% quantiles of truth + prediction;
        # guard against all-NaN input (consistent with plot_surface)
        all_data = np.concatenate([truth.flatten(), pred.flatten()])
        all_data = all_data[~np.isnan(all_data)]
        if len(all_data) > 0:
            vmin, vmax = np.quantile(all_data, [0.02, 0.98])
        else:
            vmin, vmax = -1, 1
        mae_valid = mae[~np.isnan(mae)]
        mae_vmax = np.quantile(mae_valid, 0.98) if len(mae_valid) > 0 else 1.0
        # Sub-fields restricted to the zoom window
        truth_zoom = truth[np.ix_(lat_mask, lon_mask)]
        pred_zoom = pred[np.ix_(lat_mask, lon_mask)]
        mae_zoom = mae[np.ix_(lat_mask, lon_mask)]
        # ---- Row 0: global truth with the zoom box outlined ----
        ax = axes[0, col]
        im = ax.pcolormesh(
            lon2d,
            lat2d,
            truth,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        # NOTE(review): unlike sibling functions no crs= is passed here, so
        # cartopy interprets the extents in the axes' own coordinate system
        ax.set_extent([lon_min, lon_max, lat_min, lat_max])
        ax.add_feature(coastline, linewidth=0.6)
        ax.add_feature(borders, linewidth=0.5)
        # Red rectangle marking the zoom region on the global map
        rect = patches.Rectangle(
            (zoom_box["lon_min"], zoom_box["lat_min"]),
            zoom_box["lon_max"] - zoom_box["lon_min"],
            zoom_box["lat_max"] - zoom_box["lat_min"],
            linewidth=2,
            edgecolor="red",
            facecolor="none",
            transform=ccrs.PlateCarree(),
        )
        ax.add_patch(rect)
        # Guide lines from the top corners of the box to the zoom panel
        zoom_ax = axes[1, col]
        fig.add_artist(
            ConnectionPatch(
                xyA=(zoom_box["lon_min"], zoom_box["lat_max"]),
                coordsA=ccrs.PlateCarree()._as_mpl_transform(ax),
                xyB=(0, 1),
                coordsB=zoom_ax.transAxes,
                color="red",
                linewidth=1.5,
            )
        )
        fig.add_artist(
            ConnectionPatch(
                xyA=(zoom_box["lon_max"], zoom_box["lat_max"]),
                coordsA=ccrs.PlateCarree()._as_mpl_transform(ax),
                xyB=(1, 1),
                coordsB=zoom_ax.transAxes,
                color="red",
                linewidth=1.5,
            )
        )
        im_global = im
        # ---- Row 1: truth, zoomed ----
        ax = axes[1, col]
        ax.pcolormesh(
            lon_zoom2d,
            lat_zoom2d,
            truth_zoom,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        # Extent from the actual selected grid points (not the requested box)
        ax.set_extent(
            [
                lon_zoom.min(),
                lon_zoom.max(),
                lat_zoom.min(),
                lat_zoom.max(),
            ],
            crs=ccrs.PlateCarree(),
        )
        ax.add_feature(coastline, linewidth=0.6)
        ax.add_feature(borders, linewidth=0.5)
        # ---- Row 2: prediction, zoomed ----
        ax = axes[2, col]
        ax.pcolormesh(
            lon_zoom2d,
            lat_zoom2d,
            pred_zoom,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        ax.set_extent(
            [
                lon_zoom.min(),
                lon_zoom.max(),
                lat_zoom.min(),
                lat_zoom.max(),
            ],
            crs=ccrs.PlateCarree(),
        )
        ax.add_feature(coastline, linewidth=0.6)
        ax.add_feature(borders, linewidth=0.5)
        # ---- Row 3: absolute error, zoomed ----
        ax = axes[3, col]
        im_mae = ax.pcolormesh(
            lon_zoom2d,
            lat_zoom2d,
            mae_zoom,
            cmap="Reds",
            vmin=0,
            vmax=mae_vmax,
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        ax.set_extent(
            [
                lon_zoom.min(),
                lon_zoom.max(),
                lat_zoom.min(),
                lat_zoom.max(),
            ],
            crs=ccrs.PlateCarree(),
        )
        ax.add_feature(coastline, linewidth=0.6)
        ax.add_feature(borders, linewidth=0.5)
        # Colorbars: MAE below row 3, field values above row 0
        cax = ax.inset_axes([0.15, -0.25, 0.7, 0.05])
        fig.colorbar(im_mae, cax=cax, orientation="horizontal").set_label(
            f"MAE {plot_variable_names[col]}"
        )
        ax_top = axes[0, col]
        cax_top = ax_top.inset_axes([0.1, 1.05, 0.8, 0.05])
        cbar = fig.colorbar(im_global, cax=cax_top, orientation="horizontal")
        cbar.set_label(plot_variable_names[col])
        cax_top.xaxis.set_ticks_position("top")
        cax_top.xaxis.set_label_position("top")
        for r in range(4):
            axes[r, col].set_xticks([])
            axes[r, col].set_yticks([])
    # Row labels on the left edge
    row_labels = ["Truth", "Truth (Zoom)", "Prediction (Zoom)", "MAE"]
    for i, label in enumerate(row_labels):
        axes[i, 0].text(
            -0.15,
            0.5,
            label,
            transform=axes[i, 0].transAxes,
            rotation=90,
            va="center",
            ha="right",
        )
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    # Close this figure explicitly (was plt.close(), which closes whatever
    # figure is "current"), consistent with the sibling plot functions
    plt.close(fig)
    return save_path
[docs]
def plot_global_surface_robinson(
    predictions,
    targets,
    coarse_inputs,
    lat_1d,
    lon_1d,
    timestamp=None,
    variable_names=None,
    filename="global_robinson.png",
    save_dir=None,
    figsize_multiplier=None,
):
    """
    Plot coarse, truth, prediction and difference fields in Robinson projection.

    The figure has 4 rows (Coarse, Truth, Prediction, Pred − Truth) and one
    column per variable on a global Robinson map, with a colorbar above each
    column and a difference colorbar below it.

    Parameters
    ----------
    predictions : torch.Tensor or np.ndarray
        Model predictions at target resolution with shape [1, n_vars, H, W].
    targets : torch.Tensor or np.ndarray
        Ground-truth high-resolution data with shape [1, n_vars, H, W].
    coarse_inputs : torch.Tensor or np.ndarray
        Coarse-resolution input data with shape [1, n_vars, H, W].
    lat_1d : array-like
        1D array of latitude coordinates with shape [H].
    lon_1d : array-like
        1D array of longitude coordinates with shape [W].
    timestamp : datetime.datetime, optional
        Forecast timestamp, used in the figure suptitle.
    variable_names : list of str, optional
        Variable names or identifiers; defaults to ``VAR_i``.
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot; defaults to ``PlotConfig.DEFAULT_SAVE_DIR``.
    figsize_multiplier : int, optional
        Base size multiplier for subplots. Resolved from PlotConfig but not
        otherwise used here: the figure size is fixed below.

    Returns
    -------
    str
        Path of the saved PNG file.
    """
    # Use defaults from config if not provided
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if figsize_multiplier is None:
        figsize_multiplier = PlotConfig.DEFAULT_FIGSIZE_MULTIPLIER
    # Convert tensors to numpy if needed
    if hasattr(coarse_inputs, "detach"):
        coarse_inputs = coarse_inputs.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(lat_1d, "detach"):
        lat_1d = lat_1d.detach().cpu().numpy()
    if hasattr(lon_1d, "detach"):
        lon_1d = lon_1d.detach().cpu().numpy()
    # Create 2D meshgrid from 1D coordinates
    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    # Grid shape taken from the data
    h, w = coarse_inputs[0, 0].shape
    # Latitude runs max -> min so array row 0 is the northernmost row
    lat_block = np.linspace(lat_max, lat_min, h)
    lon_block = np.linspace(lon_min, lon_max, w)
    lat2d, lon2d = np.meshgrid(lat_block, lon_block, indexing="ij")
    lon2d = ((lon2d + 180) % 360) - 180  # normalize longitudes to [-180, 180)
    # Check data dimensions
    n_vars = coarse_inputs.shape[1]
    if targets.shape[1] != n_vars:
        raise ValueError(
            f"targets data has {targets.shape[1]} variables but coarse_inputs has {n_vars}"
        )
    if predictions.shape[1] != n_vars:
        raise ValueError(
            f"predictions data has {predictions.shape[1]} variables but coarse_inputs has {n_vars}"
        )
    # Default variable names if not provided
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]
    # Derive plot names and colormaps
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]
    cmaps = [PlotConfig.get_colormap(var) for var in variable_names]
    # Derive vmin/vmax from data for each variable (coarse, truth, prediction).
    # NOTE(review): unlike plot_surface, no PlotConfig.convert_units is
    # applied anywhere in this function — fields are plotted in raw units,
    # but labelled with the (possibly converted-unit) plot names. Confirm
    # this is intended.
    vmin_list = []
    vmax_list = []
    # Derive vmin/vmax for difference plots (signed difference)
    diff_vmin_list = []
    diff_vmax_list = []
    for i in range(n_vars):
        # Get combined data range for this variable (coarse, truth, prediction)
        all_data = np.concatenate(
            [
                coarse_inputs[0, i].flatten(),
                targets[0, i].flatten(),
                predictions[0, i].flatten(),
            ]
        )
        # Calculate vmin/vmax (2%-98% quantiles to clip outliers)
        all_data_flat = all_data[~np.isnan(all_data)]
        if len(all_data_flat) > 0:
            q_low, q_high = np.quantile(all_data_flat, [0.02, 0.98])
            vmin, vmax = float(q_low), float(q_high)
        else:
            vmin, vmax = -1, 1
        # Ensure vmin < vmax; fall back to the full data range otherwise
        if vmin >= vmax:
            vmin, vmax = float(np.nanmin(all_data)), float(np.nanmax(all_data))
        vmin_list.append(vmin)
        vmax_list.append(vmax)
        # Calculate signed difference between prediction and truth
        diff_data = (predictions[0, i] - targets[0, i]).flatten()
        diff_data = diff_data[~np.isnan(diff_data)]
        if len(diff_data) > 0:
            # For signed difference, we want a symmetric range around 0
            max_abs_diff = np.max(np.abs(diff_data))
            diff_vmin = -max_abs_diff * 1.1  # Add 10% padding
            diff_vmax = max_abs_diff * 1.1  # Add 10% padding
            # If all differences are zero or very small
            if diff_vmax <= 0.001:
                diff_vmin, diff_vmax = -0.1, 0.1
        else:
            diff_vmin, diff_vmax = -1, 1
        diff_vmin_list.append(diff_vmin)
        diff_vmax_list.append(diff_vmax)
    # Set up figure
    fig, axes = plt.subplots(
        4,
        n_vars,  # 4 rows, n_vars columns
        figsize=(4.5 * n_vars, 3.2 * 4),
        subplot_kw={"projection": ccrs.Robinson()},
        gridspec_kw={"hspace": 0.12, "wspace": 0.05},
    )
    # Without squeeze=False, a single column yields a 1D axes array
    if n_vars == 1:
        axes = axes.reshape(4, 1)
    row_labels = ["Coarse", "Truth", "Prediction", "Pred − Truth"]
    # Plot each variable
    for col in range(n_vars):
        coarse = coarse_inputs[0, col]
        truth = targets[0, col]
        pred = predictions[0, col]
        diff = pred - truth
        data_rows = [coarse, truth, pred, diff]
        # First three rows share the field range; last row uses the diff range
        vmins = [vmin_list[col]] * 3 + [diff_vmin_list[col]]
        vmaxs = [vmax_list[col]] * 3 + [diff_vmax_list[col]]
        cmaps_row = [cmaps[col]] * 3 + ["RdBu_r"]
        for row in range(4):
            ax = axes[row, col]
            ax.set_global()
            # Create the plot
            im = ax.pcolormesh(
                lon2d,
                lat2d,
                data_rows[row],
                transform=ccrs.PlateCarree(),
                cmap=cmaps_row[row],
                vmin=vmins[row],
                vmax=vmaxs[row],
                shading="auto",
            )
            ax.coastlines(linewidth=0.9)
            ax.add_feature(
                cfeature.BORDERS.with_scale("50m"),
                linewidth=0.9,
                linestyle="--",
                edgecolor="black",
                zorder=11,
            )
            ax.add_feature(
                cfeature.LAKES.with_scale("50m"),
                edgecolor="black",
                facecolor="none",
                linewidth=0.9,
                zorder=9,
            )
            ax.set_xticks([])
            ax.set_yticks([])
            # Row labels on the left edge, first column only
            if col == 0:
                ax.text(
                    -0.08,
                    0.5,
                    row_labels[row],
                    transform=ax.transAxes,
                    va="center",
                    ha="right",
                    rotation=90,
                    fontsize=12,
                )
            # Colorbars: field range above row 0, diff range below row 3
            if row == 0:
                cax = ax.inset_axes([0.1, 1.08, 0.8, 0.05])
                cb = fig.colorbar(im, cax=cax, orientation="horizontal")
                cb.set_label(plot_variable_names[col])
                cax.xaxis.set_ticks_position("top")
                cax.xaxis.set_label_position("top")
            if row == 3:
                cax = ax.inset_axes([0.1, -0.12, 0.8, 0.05])
                fig.colorbar(
                    im,
                    cax=cax,
                    orientation="horizontal",
                    label=f"Δ {plot_variable_names[col]} (Pred - Truth)",
                )
    if timestamp is not None:
        fig.suptitle(
            f"Global Robinson diagnostic – {timestamp.strftime('%Y-%m-%d %H:%M')}",
            fontsize=16,
            y=0.96,
        )
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def plot_MAE_map(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    lat_1d,
    lon_1d,
    timestamp=None,
    variable_names=None,
    filename="validation_mae_map.png",
    save_dir=None,
    figsize_multiplier=None,  # Base size per subplot
):
    """
    Plot spatial MAE maps averaged over all time steps:
    MAE(x, y) = mean_t(abs(prediction - target)),
    computed in display units, one map panel per variable.

    Parameters
    ----------
    predictions : torch.Tensor or np.ndarray
        Model predictions of shape [batch_size, num_variables, h, w].
    targets : torch.Tensor or np.ndarray
        Ground truth of shape [batch_size, num_variables, h, w].
    lat_1d : array-like
        1D array of latitude coordinates with shape [H].
    lon_1d : array-like
        1D array of longitude coordinates with shape [W].
    timestamp : datetime.datetime, optional
        Forecast timestamp, used in the figure suptitle.
    variable_names : list of str, optional
        Variable names or identifiers; defaults to ``VAR_i``.
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot; defaults to ``PlotConfig.DEFAULT_SAVE_DIR``.
    figsize_multiplier : int, optional
        Base size multiplier for subplots. Resolved from PlotConfig but not
        otherwise used here: the panel size is fixed below.

    Returns
    -------
    str
        Path of the saved PNG file.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if figsize_multiplier is None:
        figsize_multiplier = PlotConfig.DEFAULT_FIGSIZE_MULTIPLIER
    # Convert tensors to numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(lat_1d, "detach"):
        lat_1d = lat_1d.detach().cpu().numpy()
    if hasattr(lon_1d, "detach"):
        lon_1d = lon_1d.detach().cpu().numpy()
    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    # T is the time/batch axis averaged over below
    T, n_vars, h, w = predictions.shape
    # Latitude runs max -> min so array row 0 is the northernmost row
    lat_block = np.linspace(lat_max, lat_min, h)
    lon_block = np.linspace(lon_min, lon_max, w)
    lat, lon = np.meshgrid(lat_block, lon_block, indexing="ij")
    lon_center = float((lon_min + lon_max) / 2)
    if targets.shape[1] != n_vars:
        raise ValueError("targets and predictions must have same number of variables")
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]
    # One shared MAE colormap for all variables (not per-variable colormaps)
    cmaps = PlotConfig.get_colormap("mae")
    vmin_list, vmax_list = [], []
    # MAE averaged over time for color scaling
    for i in range(n_vars):
        # Compute MAE in display units so ranges match what is plotted
        pred_i = PlotConfig.convert_units(variable_names[i], predictions[:, i])
        tgt_i = PlotConfig.convert_units(variable_names[i], targets[:, i])
        mae_data = np.mean(np.abs(pred_i - tgt_i), axis=0)
        mae_flat = mae_data.flatten()
        mae_flat = mae_flat[~np.isnan(mae_flat)]
        # Fixed per-variable MAE range from config wins over data-driven range
        fixed_range = PlotConfig.get_fixed_mae_range(variable_names[i])
        if fixed_range is not None:
            vmin, vmax = fixed_range
        else:
            # 2%-98% quantiles to clip outliers; fall back for all-NaN input
            if len(mae_flat) > 0:
                q_low, q_high = np.quantile(mae_flat, [0.02, 0.98])
                vmin, vmax = float(q_low), float(q_high)
            else:
                vmin, vmax = 0.0, 1.0
            # Degenerate quantile range: use the full data range instead
            if vmin >= vmax:
                vmin, vmax = float(np.nanmin(mae_flat)), float(np.nanmax(mae_flat))
        vmin_list.append(vmin)
        vmax_list.append(vmax)
    base_width_per_panel = 4.5
    base_height_per_panel = 3.0
    fig_width = base_width_per_panel * n_vars
    fig_height = base_height_per_panel
    fig, axes = plt.subplots(
        1,
        n_vars,
        figsize=(fig_width, fig_height),
        subplot_kw={
            "projection": ccrs.PlateCarree(central_longitude=lon_center)
        },
        gridspec_kw={"wspace": 0.1},
    )
    # plt.subplots returns a bare Axes for a single panel; normalize to a list
    if n_vars == 1:
        axes = [axes]
    if timestamp is not None:
        fig.suptitle(
            f"MAE Map (time-averaged) - {timestamp.strftime('%Y-%m-%d %H:%M')}",
            fontsize=16,
            y=1.05,
        )
    for col_idx in range(n_vars):
        ax = axes[col_idx]
        pred_i = PlotConfig.convert_units(
            variable_names[col_idx], predictions[:, col_idx]
        )
        tgt_i = PlotConfig.convert_units(variable_names[col_idx], targets[:, col_idx])
        # MAE averaged over all time steps
        mae_data = np.mean(np.abs(pred_i - tgt_i), axis=0)
        im = ax.pcolormesh(
            lon,
            lat,
            mae_data,
            vmin=vmin_list[col_idx],
            vmax=vmax_list[col_idx],
            cmap=cmaps,
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
        ax.coastlines(linewidth=0.6)
        ax.add_feature(
            cfeature.BORDERS.with_scale("50m"),
            linewidth=0.6,
            linestyle="--",
            edgecolor="black",
            zorder=11,
        )
        ax.add_feature(
            cfeature.LAKES.with_scale("50m"),
            edgecolor="black",
            facecolor="none",
            linewidth=0.6,
            zorder=9,
        )
        ax.set_xticks([])
        ax.set_yticks([])
        # Horizontal colorbar below each panel
        cax = ax.inset_axes([0.1, -0.15, 0.8, 0.05])
        fig.colorbar(
            im,
            cax=cax,
            orientation="horizontal",
            label=f"MAE {plot_variable_names[col_idx]}",
        )
    fig.subplots_adjust(top=0.85, bottom=0.25, left=0.08, right=0.95)
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def plot_error_map(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    lat_1d,
    lon_1d,
    timestamp=None,
    variable_names=None,
    filename="validation_error_map.png",
    save_dir=None,
    figsize_multiplier=None,
):
    """
    Plot spatial error maps averaged over all time steps.

    Precipitation-like variables (``var_tp`` / ``precip`` / ``precipitation``)
    are shown as absolute error (mm/h); every other variable is shown as
    relative error |pred - true| / (|true| + eps).

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    lat_1d : array-like
        1D array of latitude coordinates with shape [H].
    lon_1d : array-like
        1D array of longitude coordinates with shape [W].
    timestamp : datetime.datetime, optional
        Forecast timestamp to include in the plot title.
    variable_names : list of str, optional
        Variable names or identifiers.
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot.
    figsize_multiplier : int, optional
        Base size multiplier for subplots.

    Returns
    -------
    str
        Path of the saved figure.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if figsize_multiplier is None:
        figsize_multiplier = PlotConfig.DEFAULT_FIGSIZE_MULTIPLIER
    # Convert tensors to numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(lat_1d, "detach"):
        lat_1d = lat_1d.detach().cpu().numpy()
    if hasattr(lon_1d, "detach"):
        lon_1d = lon_1d.detach().cpu().numpy()
    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    T, n_vars, h, w = predictions.shape
    # Rebuild the regular lat/lon grid of the fields (latitude decreasing,
    # matching the source data orientation).
    lat_block = np.linspace(lat_max, lat_min, h)
    lon_block = np.linspace(lon_min, lon_max, w)
    lat, lon = np.meshgrid(lat_block, lon_block, indexing="ij")
    lon_center = float((lon_min + lon_max) / 2)
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(v) for v in variable_names]
    cmaps = PlotConfig.get_colormap("error")
    eps = 1e-6  # avoids division by zero in the relative error
    precip_aliases = ("var_tp", "precip", "precipitation")
    # Compute each time-averaged error map once and reuse it for both the
    # color scaling and the plot itself (previously computed twice).
    err_maps, err_labels = [], []
    vmin_list, vmax_list = [], []
    for i in range(n_vars):
        var = variable_names[i]
        pred_i = PlotConfig.convert_units(var, predictions[:, i])
        tgt_i = PlotConfig.convert_units(var, targets[:, i])
        if var.lower() in precip_aliases:
            err = np.mean(np.abs(pred_i - tgt_i), axis=0)
            err_labels.append("Absolute Error (mm/h)")
        else:
            err = np.mean(np.abs(pred_i - tgt_i) / (np.abs(tgt_i) + eps), axis=0)
            err_labels.append("Relative Error")
        err_maps.append(err)
        err_flat = err[~np.isnan(err)]
        fixed_range = PlotConfig.get_fixed_diff_range_errors(var)
        if fixed_range is not None:
            vmin, vmax = fixed_range
        elif err_flat.size > 0:
            # 10% headroom above the max so extreme pixels are not saturated.
            vmin, vmax = 0, 1.1 * np.max(err_flat)
        else:
            vmin, vmax = 0, 1
        vmin_list.append(vmin)
        vmax_list.append(vmax)
    base_w, base_h = 4.5, 3.0
    fig, axes = plt.subplots(
        1,
        n_vars,
        figsize=(base_w * n_vars, base_h),
        subplot_kw={"projection": ccrs.PlateCarree(central_longitude=lon_center)},
        gridspec_kw={"wspace": 0.1},
    )
    if n_vars == 1:
        axes = [axes]
    if timestamp is not None:
        # Consistent with plot_mae_map: show the forecast time in the figure
        # title (the parameter was previously accepted but never used).
        fig.suptitle(
            f"Error Map (time-averaged) - {timestamp.strftime('%Y-%m-%d %H:%M')}",
            fontsize=16,
            y=1.05,
        )
    for i in range(n_vars):
        ax = axes[i]
        im = ax.pcolormesh(
            lon,
            lat,
            err_maps[i],
            vmin=vmin_list[i],
            vmax=vmax_list[i],
            cmap=cmaps,
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
        ax.coastlines(linewidth=0.6)
        ax.add_feature(
            cfeature.BORDERS.with_scale("50m"),
            linewidth=0.6,
            linestyle="--",
            edgecolor="black",
            zorder=11,
        )
        ax.add_feature(
            cfeature.LAKES.with_scale("50m"),
            edgecolor="black",
            facecolor="none",
            linewidth=0.6,
            zorder=9,
        )
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(plot_variable_names[i])
        # Horizontal colorbar below each panel.
        cax = ax.inset_axes([0.1, -0.15, 0.8, 0.05])
        fig.colorbar(
            im,
            cax=cax,
            orientation="horizontal",
            label=err_labels[i],
        )
    fig.subplots_adjust(top=0.85, bottom=0.25, left=0.08, right=0.95)
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def spread_skill_ratio(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    variable_names,
    pixel_wise=False,
):
    """
    Compute the spread skill ratio (SSR) of predictions with respect to targets.

    Implements equation (15) of "Why Should Ensemble Spread Match the RMSE of
    the Ensemble Mean?", Fortin et al.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [ensemble_size, batch_size, num_variables, h, w].
        It is very important not to switch dimensions order.
        ensemble_size must be greater or equal than 2 for the SSR to be computed.
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Variable names or identifiers.
    pixel_wise : bool
        If True, computes and returns the SSR for each pixel independently.
        If False, computes and returns the SSR averaged over all pixels and
        all timesteps. Defaults to False.

    Returns
    -------
    np.array of shape [num_variables, h, w] if pixel_wise == True
    or of shape [num_variables,] if pixel_wise == False (default)
    """
    # Accept torch tensors transparently.
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if len(predictions.shape) != 5:
        raise ValueError(
            "predictions needs to be 5 dimensional tensor / array [ensemble_size, temporal_size, n_vars, h, w]."
        )
    if predictions.shape[0] == 1:
        raise ValueError(
            "predictions needs to contain more than 1 member to compute spread skill ratio."
        )
    ensemble_size = predictions.shape[0]
    n_vars = predictions.shape[2]
    if targets.shape[1] != n_vars:
        raise ValueError("targets and predictions must have same number of variables")
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]
    # Finite-ensemble corrective factor sqrt((E + 1) / E) from eq. (15)
    # of Fortin et al.
    correction = np.sqrt((ensemble_size + 1) / ensemble_size)
    # Reduce over time only (keep the spatial map) or over everything.
    reduce_axes = 0 if pixel_wise else None
    ratios = []
    for idx in range(n_vars):
        members = PlotConfig.convert_units(
            variable_names[idx], predictions[:, :, idx]
        )  # [E,T,h,w]
        truth = PlotConfig.convert_units(
            variable_names[idx], targets[:, idx]
        )  # [T,h,w]
        ensemble_mean = np.mean(members, axis=0)  # [T,h,w]
        # Skill: RMSE of the ensemble mean against the ground truth.
        skill = np.sqrt(np.mean((ensemble_mean - truth) ** 2, axis=reduce_axes))
        # Spread: sqrt of the mean ensemble variance, times the corrective
        # factor accounting for the finite number of members.
        spread = (
            np.sqrt(np.mean(np.var(members, axis=0), axis=reduce_axes)) * correction
        )
        ratios.append(np.divide(spread, skill))
    return np.array(ratios)
[docs]
def plot_spread_skill_ratio_map(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    lat_1d,
    lon_1d,
    timestamp=None,
    variable_names=None,
    filename="validation_spread_skill_ratio_map.png",
    save_dir=None,
    figsize_multiplier=None,  # Base size per subplot
):
    """
    Plot spatial spread skill ratio maps averaged over all time steps for
    each individual pixel.

    The formula implemented is equation (15) in article "Why Should Ensemble
    Spread Match the RMSE of the Ensemble Mean?", Fortin et al.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [ensemble_size, batch_size, num_variables, h, w]
        It is very important not to switch dimensions order.
        ensemble_size must be greater or equal than 2 for spread skill ratio
        to be computed.
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    lat_1d : array-like
        1D array of latitude coordinates with shape [H].
    lon_1d : array-like
        1D array of longitude coordinates with shape [W].
    timestamp : datetime.datetime, optional
        Currently unused; kept for interface symmetry with the other map
        plotting helpers.
    variable_names : list of str, optional
        Variable names or identifiers.
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot.
    figsize_multiplier : int, optional
        Base size multiplier for subplots.

    Returns
    -------
    str
        Path of the saved figure.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if figsize_multiplier is None:
        figsize_multiplier = PlotConfig.DEFAULT_FIGSIZE_MULTIPLIER
    # Convert tensors to numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(lat_1d, "detach"):
        lat_1d = lat_1d.detach().cpu().numpy()
    if hasattr(lon_1d, "detach"):
        lon_1d = lon_1d.detach().cpu().numpy()
    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    if len(predictions.shape) != 5:
        raise ValueError(
            "predictions needs to be 5 dimensional tensor / array [ensemble_size, temporal_size, n_vars, h, w]."
        )
    if predictions.shape[0] == 1:
        raise ValueError(
            "predictions needs to contain more than 1 member to compute spread skill ratio."
        )
    E, T, n_vars, h, w = predictions.shape
    # Rebuild the regular lat/lon grid (latitude decreasing, matching the
    # source data orientation).
    lat_block = np.linspace(lat_max, lat_min, h)
    lon_block = np.linspace(lon_min, lon_max, w)
    lat, lon = np.meshgrid(lat_block, lon_block, indexing="ij")
    lon_center = float((lon_min + lon_max) / 2)
    if targets.shape[1] != n_vars:
        raise ValueError("targets and predictions must have same number of variables")
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]
    cmaps = PlotConfig.get_colormap("SSR")
    ssr_list = spread_skill_ratio(predictions, targets, variable_names, pixel_wise=True)
    ssr_mean_list = spread_skill_ratio(
        predictions, targets, variable_names, pixel_wise=False
    )
    # Per-variable color range: fixed range if configured, otherwise clip at
    # the 2%/98% quantiles of the pixel-wise SSR values.
    vmin_list, vmax_list = [], []
    for i in range(n_vars):
        ssr_i_flat = ssr_list[i].flatten()
        fixed_range = PlotConfig.get_fixed_ssr_range(variable_names[i])
        if fixed_range is not None:
            vmin, vmax = fixed_range
        elif len(ssr_i_flat) > 0:
            q_low, q_high = np.quantile(ssr_i_flat, [0.02, 0.98])
            vmin, vmax = float(q_low), float(q_high)
        else:
            vmin, vmax = 0.0, 2.0
        if vmin >= vmax:
            # Degenerate range (e.g. constant field): fall back to the full
            # data range.
            vmin, vmax = float(np.nanmin(ssr_i_flat)), float(np.nanmax(ssr_i_flat))
        vmin_list.append(vmin)
        vmax_list.append(vmax)
    base_width_per_panel = 4.5
    base_height_per_panel = 3.0
    fig, axes = plt.subplots(
        1,
        n_vars,
        figsize=(base_width_per_panel * n_vars, base_height_per_panel),
        subplot_kw={"projection": ccrs.PlateCarree(central_longitude=lon_center)},
        gridspec_kw={"wspace": 0.1},
    )
    if n_vars == 1:
        axes = [axes]
    for col_idx in range(n_vars):
        ax = axes[col_idx]
        ssr_data_i = ssr_list[col_idx]
        ssr_data_i_mean = ssr_mean_list[col_idx]
        vmin = vmin_list[col_idx]
        vmax = vmax_list[col_idx]
        # Diverging norm centred on SSR = 1 (perfectly calibrated ensemble).
        # TwoSlopeNorm requires vmin < vcenter < vmax *strictly*, hence the
        # strict comparisons: the previous duplicated branches raised a
        # ValueError whenever vmin or vmax equalled 1 exactly.
        if vmin < 1 < vmax:
            norm = mcolors.TwoSlopeNorm(vmin=vmin, vcenter=1, vmax=vmax)
        elif vmax > 1:
            norm = mcolors.TwoSlopeNorm(vmin=0, vcenter=1, vmax=vmax)
        else:
            norm = mcolors.TwoSlopeNorm(vmin=0, vcenter=1, vmax=2)
        im = ax.pcolormesh(
            lon,
            lat,
            ssr_data_i,
            cmap=cmaps,
            norm=norm,
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
        ax.coastlines(linewidth=0.6)
        ax.add_feature(
            cfeature.BORDERS.with_scale("50m"),
            linewidth=0.6,
            linestyle="--",
            edgecolor="black",
            zorder=11,
        )
        ax.add_feature(
            cfeature.LAKES.with_scale("50m"),
            edgecolor="black",
            facecolor="none",
            linewidth=0.6,
            zorder=9,
        )
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(
            f"{plot_variable_names[col_idx]} (SSR={ssr_data_i_mean:.2f})",
            pad=10,
        )
        # Horizontal colorbar below each panel.
        cax = ax.inset_axes([0.1, -0.15, 0.8, 0.05])
        fig.colorbar(
            im,
            cax=cax,
            orientation="horizontal",
            label=f"SSR {plot_variable_names[col_idx]}",
        )
    fig.subplots_adjust(top=0.85, bottom=0.25, left=0.08, right=0.95)
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def plot_spread_skill_ratio_hexbin(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    variable_names=None,
    filename="validation_spread_skill_ratio_hexbin.png",
    save_dir=None,
    figsize_multiplier=None,  # Base size per subplot
):
    """
    Plot spread-vs-skill hexbin density plots, where each point represents a
    prediction for a single pixel, single timestep:
        SSR(x, y) = spread(x, y) / skill(x, y)
    where spread(x, y) is the standard deviation of the ensemble member
    predictions and skill(x, y) is the absolute error of the ensemble-mean
    prediction (no temporal averaging is applied here; each sample is one
    point in the density plot).

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [ensemble_size, batch_size, num_variables, h, w]
        It is very important not to switch dimensions order.
        ensemble_size must be greater or equal than 2 for spread skill ratio
        to be computed.
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Variable names or identifiers.
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot.
    figsize_multiplier : int, optional
        Base size multiplier for subplots.

    Returns
    -------
    str
        Path of the saved figure.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if figsize_multiplier is None:
        figsize_multiplier = PlotConfig.DEFAULT_FIGSIZE_MULTIPLIER
    # Convert tensors to numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if len(predictions.shape) != 5:
        raise ValueError(
            "predictions needs to be 5 dimensional tensor / array [ensemble_size, temporal_size, n_vars, h, w]."
        )
    if predictions.shape[0] == 1:
        raise ValueError(
            "predictions needs to contain more than 1 member to compute spread skill ratio."
        )
    E, T, n_vars, h, w = predictions.shape
    if targets.shape[1] != n_vars:
        raise ValueError("targets and predictions must have same number of variables")
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]
    # Per-pixel, per-timestep skill proxy (|ensemble mean - truth|) and
    # spread (ensemble standard deviation), flattened for the hexbin.
    abs_err_list = []
    spread_list = []
    mean_ssr_list = spread_skill_ratio(
        predictions, targets, variable_names, pixel_wise=False
    )
    for i in range(n_vars):
        pred_i = PlotConfig.convert_units(variable_names[i], predictions[:, :, i])
        tgt_i = PlotConfig.convert_units(variable_names[i], targets[:, i])
        mean_pred_i = np.mean(pred_i, axis=0)
        # NOTE: this is the per-sample absolute error of the ensemble mean;
        # it is what the "RMSE" axis of the hexbin actually displays.
        abs_err_list.append(np.abs(mean_pred_i - tgt_i).flatten())
        spread_list.append(np.std(pred_i, axis=0).flatten())
    ncols = n_vars
    nrows = (n_vars + ncols - 1) // ncols  # always 1 with ncols == n_vars
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols * figsize_multiplier, figsize_multiplier),
    )
    axes = np.atleast_1d(axes).ravel()
    for ax in axes:
        ax.set_box_aspect(1)
    last_hb = None
    for i, ax in enumerate(axes):
        if i >= n_vars:
            ax.set_visible(False)
            continue
        abs_err = abs_err_list[i]
        spread = spread_list[i]
        hb = ax.hexbin(
            abs_err,
            spread,
            gridsize=100,
            cmap="jet",
            bins="log",
            mincnt=1,
        )
        last_hb = hb
        textstr = f"SSR: {mean_ssr_list[i]:.3f}"
        ax.text(
            0.05,
            0.95,
            textstr,
            transform=ax.transAxes,
            fontsize=10,
            verticalalignment="top",
        )
        var_min = min(abs_err.min(), spread.min())
        var_max = max(abs_err.max(), spread.max())
        margin = 0.05 * (var_max - var_min)
        plot_min = var_min - margin
        plot_max = var_max + margin
        # Identity line: points on it have spread equal to error
        # (perfect calibration).
        ax.plot([plot_min, plot_max], [plot_min, plot_max], "r--", alpha=0.7)
        ax.set_xlim(plot_min, plot_max)
        ax.set_ylim(plot_min, plot_max)
        ax.set_title(plot_variable_names[i])
        ax.xaxis.set_major_locator(ticker.MaxNLocator(5))
        ax.yaxis.set_major_locator(ticker.MaxNLocator(5))
        if i % ncols == 0:
            ax.set_ylabel("Ensemble std")
        else:
            ax.set_ylabel("")
        if i >= (nrows - 1) * ncols:
            ax.set_xlabel("RMSE")
        else:
            ax.set_xlabel("")
    # Shared log-count colorbar anchored to the last populated panel.
    ax_last = axes[min(n_vars - 1, len(axes) - 1)]
    cax = ax_last.inset_axes([1.05, 0.0, 0.04, 1.0])
    cbar = fig.colorbar(last_hb, cax=cax)
    cbar.set_label(r"$\log_{10}[\mathrm{Count}]$")
    plt.subplots_adjust(
        hspace=0.1, wspace=0.3, left=0.1, right=0.9, top=0.9, bottom=0.1
    )
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def plot_validation_pdfs(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    coarse_inputs=None,  # Coarse inputs for comparison (optional)
    variable_names=None,  # List of variable names
    filename="validation_pdfs.png",
    save_dir="./results",
    figsize_multiplier=4,  # Base size per subplot
    save_npz=False,
):
    """
    Create PDF (Probability Density Function) plots comparing distributions of
    model predictions vs ground truth for all variables.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    coarse_inputs : torch.Tensor or np.array, optional
        Coarse inputs of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Names of the variables for subplot titles
    filename : str, optional
        Output filename
    save_dir : str, optional
        Directory to save the plot
    figsize_multiplier : int, optional
        Base size multiplier for subplots
    save_npz : bool, optional
        If True, saves the PDF diagnostics to a compressed .npz file.

    Returns
    -------
    str
        Path of the saved figure (the function also writes the figure, and
        optionally the .npz diagnostics, to disk).

    Notes
    -----
    - Creates horizontal subplots (one per variable) showing PDFs
    - Each subplot shows up to 3 lines: Predictions, Ground Truth, and Coarse Inputs
    - Uses automatic color and linestyle cycling based on global matplotlib settings
    - Calculates and displays key statistics for each distribution
    - Handles both PyTorch tensors and numpy arrays

    Examples
    --------
    >>> predictions = np.random.randn(10, 3, 64, 64)  # 10 samples, 3 variables
    >>> targets = np.random.randn(10, 3, 64, 64)
    >>> plot_validation_pdfs(predictions, targets, variable_names=['Temp', 'Pres', 'Humid'])
    """
    # Convert to numpy if they're tensors
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if coarse_inputs is not None and hasattr(coarse_inputs, "detach"):
        coarse_inputs = coarse_inputs.detach().cpu().numpy()
    batch_size, num_vars, h, w = predictions.shape
    # Default variable names if not provided
    if variable_names is None:
        variable_names = [f"Variable {i + 1}" for i in range(num_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]
    # Calculate grid dimensions for horizontal layout
    ncols = num_vars
    nrows = 1  # Single row for horizontal layout
    # Create figure with horizontal subplots
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(ncols * figsize_multiplier, figsize_multiplier)
    )
    # Handle single subplot case (plt.subplots returns a bare Axes, not an
    # array, when there is only one panel)
    if num_vars == 1:
        axes = np.array([axes])
    if axes.ndim == 0:
        axes = np.array([axes])
    axes = axes.flatten()
    for ax in axes:
        ax.set_box_aspect(1)
    plt.subplots_adjust(
        hspace=0.1, wspace=0.3, left=0.1, right=0.9, top=0.9, bottom=0.1
    )
    if save_npz:
        pdf_npz_data = {}
    # Plot PDF for each variable
    for i, (var_name, ax) in enumerate(zip(variable_names, axes)):
        if i >= num_vars:
            ax.set_visible(False)
            continue
        # Fresh linestyle cycle per panel so the three curves always get the
        # same styles across panels
        linestyles = mpltex.linestyle_generator(markers=[])
        # Flatten the spatial dimensions
        pred_i = PlotConfig.convert_units(var_name, predictions[:, i])
        tgt_i = PlotConfig.convert_units(var_name, targets[:, i])
        plot_name = plot_variable_names[i]
        pred_flat = pred_i.reshape(-1)
        target_flat = tgt_i.reshape(-1)
        # Collect all data for combined range
        all_data = [pred_flat, target_flat]
        if coarse_inputs is not None:
            # coarse_flat = coarse_inputs[:, i, :, :].flatten() #.mean(axis=0).reshape(-1)
            coarse_i = PlotConfig.convert_units(var_name, coarse_inputs[:, i])
            coarse_flat = coarse_i.reshape(-1)
            all_data.append(coarse_flat)
        # Calculate global range for consistent x-axis, clipping the extreme
        # tails so outliers do not stretch the plot
        all_values = np.concatenate(all_data)
        data_min = np.percentile(all_values, 0.25)  # 0.25th percentile (lower tail)
        data_max = np.percentile(all_values, 99.5)  # 99.5th percentile (upper tail)
        data_range = data_max - data_min
        # Extend range slightly for better visualization
        x_min = data_min - 0.05 * data_range
        x_max = data_max + 0.05 * data_range
        # Create bins for PDF calculation
        n_bins = 100
        bins = np.linspace(x_min, x_max, n_bins + 1)
        # Small epsilon to avoid log(0)
        epsilon = 1e-12
        # Plot log PDFs
        # Plot predictions
        hist_pred, bin_edges = np.histogram(pred_flat, bins=bins, density=True)
        bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
        log_hist_pred = np.log10(hist_pred + epsilon)
        ax.plot(bin_centers, log_hist_pred, label="Pred", **next(linestyles))
        # Plot ground truth
        hist_target, _ = np.histogram(target_flat, bins=bins, density=True)
        log_hist_target = np.log10(hist_target + epsilon)
        ax.plot(bin_centers, log_hist_target, label="Truth", **next(linestyles))
        # Plot coarse inputs if available
        if coarse_inputs is not None:
            hist_coarse, _ = np.histogram(coarse_flat, bins=bins, density=True)
            log_hist_coarse = np.log10(hist_coarse + epsilon)
            ax.plot(bin_centers, log_hist_coarse, label="Coarse", **next(linestyles))
        # Calculate and display statistics
        stats_text = []
        # Predictions statistics
        pred_mean = np.mean(pred_flat)
        pred_std = np.std(pred_flat)
        stats_text.append(f"Predictions: μ={pred_mean:.3f}, σ={pred_std:.3f}")
        # Ground truth statistics
        target_mean = np.mean(target_flat)
        target_std = np.std(target_flat)
        stats_text.append(f"Ground Truth: μ={target_mean:.3f}, σ={target_std:.3f}")
        # Coarse statistics if available
        if coarse_inputs is not None:
            coarse_mean = np.mean(coarse_flat)
            coarse_std = np.std(coarse_flat)
            stats_text.append(f"Coarse: μ={coarse_mean:.3f}, σ={coarse_std:.3f}")
        # Calculate KL divergence between predictions and ground truth
        # (histograms are shifted by epsilon, then renormalized, so the log
        # ratio is always finite)
        hist_pred_safe = hist_pred + epsilon
        hist_target_safe = hist_target + epsilon
        # Normalize to probability distributions
        hist_pred_safe = hist_pred_safe / np.sum(hist_pred_safe)
        hist_target_safe = hist_target_safe / np.sum(hist_target_safe)
        kl_divergence = np.sum(
            hist_target_safe * np.log(hist_target_safe / hist_pred_safe)
        )
        # Add KL divergence to statistics
        stats_text.append(f"KL Divergence: {kl_divergence:.4f}")
        # Calculate correlation coefficient
        correlation = np.corrcoef(pred_flat, target_flat)[0, 1]
        stats_text.append(f"Correlation: {correlation:.4f}")
        # Add statistics as text box
        # stats_str = '\n'.join(stats_text)
        # ax.text(0.5, 1.02, stats_str, transform=ax.transAxes,
        #         verticalalignment='bottom', horizontalalignment='center')
        # Log statistics instead of plotting them
        print(f"[PDF stats] {plot_name}")
        print(f"  Predictions: μ={pred_mean:.3f}, σ={pred_std:.3f}")
        print(f"  Ground Truth: μ={target_mean:.3f}, σ={target_std:.3f}")
        if coarse_inputs is not None:
            print(f"  Coarse: μ={coarse_mean:.3f}, σ={coarse_std:.3f}")
        print(f"  KL Divergence: {kl_divergence:.4f}")
        print(f"  Correlation: {correlation:.4f}")
        # ax.set_xlabel(f'{var_name}')
        ax.set_xlabel(plot_name)
        # Only show y-label for leftmost subplot
        if i == 0:
            # ax.set_ylabel('log₁₀(PDF)')
            ax.set_ylabel(r"$\log_{10}(\mathrm{PDF})$")
        # Add grid
        ax.grid(True, alpha=0.3, linestyle="--")
        # Add legend
        ax.legend()
        # Set consistent x-limits
        # ax.set_xlim(x_min, x_max)
        # Set y-limits for log plot (handle cases where log values might be very negative)
        y_min = min(log_hist_pred.min(), log_hist_target.min())
        if coarse_inputs is not None:
            y_min = min(y_min, log_hist_coarse.min())
        y_max = max(log_hist_pred.max(), log_hist_target.max())
        if coarse_inputs is not None:
            y_max = max(y_max, log_hist_coarse.max())
        # Add small margin to y-limits
        y_margin = 0.1 * (y_max - y_min) if y_max > y_min else 0.1
        ax.set_ylim(y_min - y_margin, y_max + y_margin)
        # Use scientific notation for large ranges
        if data_range > 1000:
            ax.ticklabel_format(style="sci", axis="x", scilimits=(0, 0))
        if save_npz:
            # Store the per-variable curves and scalar diagnostics under
            # "<var>__pdf__<field>" keys
            key = f"{var_name}__pdf__"
            pdf_npz_data[key + "bin_centers"] = bin_centers
            pdf_npz_data[key + "log_pred"] = log_hist_pred
            pdf_npz_data[key + "log_truth"] = log_hist_target
            pdf_npz_data[key + "mean_pred"] = pred_mean
            pdf_npz_data[key + "std_pred"] = pred_std
            pdf_npz_data[key + "mean_truth"] = target_mean
            pdf_npz_data[key + "std_truth"] = target_std
            pdf_npz_data[key + "kl"] = kl_divergence
            pdf_npz_data[key + "corr"] = correlation
            if coarse_inputs is not None:
                pdf_npz_data[key + "log_coarse"] = log_hist_coarse
                pdf_npz_data[key + "mean_coarse"] = coarse_mean
                pdf_npz_data[key + "std_coarse"] = coarse_std
    # Ensure save directory exists
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    if save_npz:
        # Diagnostics go next to the figure with the same stem
        npz_path = os.path.splitext(save_path)[0] + ".npz"
        np.savez_compressed(npz_path, **pdf_npz_data)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def plot_power_spectra(
predictions, # Model predictions
targets, # Ground truth
dlat, # Grid spacing in latitude (degrees)
dlon, # Grid spacing in longitude (degrees)
coarse_inputs=None, # Coarse inputs for comparison (optional)
variable_names=None, # List of variable names
filename="power_spectra_physical.png",
save_dir="./results",
figsize_multiplier=4,
save_npz=False,
):
"""
Calculate and plot power spectra with proper physical wavenumbers.
Parameters
----------
predictions : torch.Tensor or np.array
Model predictions of shape [batch_size, num_variables, nh, nw]
targets : torch.Tensor or np.array
Ground truth of shape [batch_size, num_variables, nh, nw]
dlat : float
Grid spacing in latitude (degrees)
dlon : float
Grid spacing in longitude (degrees)
coarse_inputs : torch.Tensor or np.array, optional
Coarse inputs of shape [batch_size, num_variables, nh, nw]
variable_names : list of str, optional
Names of the variable names for subplot titles
filename : str, optional
Output filename
save_dir : str, optional
Directory to save the plot
figsize_multiplier : int, optional
Base size multiplier for subplots
save_npz : bool, optional
If True, saves the PDF diagnostics to a compressed .npz file.
Returns
-------
None
"""
# Convert to numpy if they're tensors
if hasattr(predictions, "detach"):
predictions = predictions.detach().cpu().numpy()
if hasattr(targets, "detach"):
targets = targets.detach().cpu().numpy()
if coarse_inputs is not None and hasattr(coarse_inputs, "detach"):
coarse_inputs = coarse_inputs.detach().cpu().numpy()
batch_size, num_vars, nh, nw = predictions.shape
# Default variable names if not provided
if variable_names is None:
variable_names = [f"Variable {i + 1}" for i in range(num_vars)]
# plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]
# Calculate wavenumbers
# FFT frequencies are in cycles per grid spacing
fft_freq_lat = np.fft.fftfreq(nh, d=dlat) # cycles per degree in lat direction
fft_freq_lon = np.fft.fftfreq(nw, d=dlon) # cycles per degree in lon direction
# Shift frequencies so zero is at center
fft_freq_lat_shifted = np.fft.fftshift(fft_freq_lat)
fft_freq_lon_shifted = np.fft.fftshift(fft_freq_lon)
# Create 2D wavenumber grid
k_lat, k_lon = np.meshgrid(fft_freq_lon_shifted, fft_freq_lat_shifted)
# Calculate magnitude of wavenumber vector (in cycles/degree)
k_mag = np.sqrt(k_lat**2 + k_lon**2)
# Create bins for radial averaging
max_k = np.min([np.max(np.abs(fft_freq_lat)), np.max(np.abs(fft_freq_lon))])
k_bins = np.linspace(0, max_k, min(nh, nw) // 2)
k_centers = 0.5 * (k_bins[1:] + k_bins[:-1])
# Create figure
ncols = num_vars
nrows = 1 # 2 Two rows: one for 2D spectrum, one for 1D spectrum
fig, axes = plt.subplots(
nrows,
ncols,
figsize=(ncols * figsize_multiplier, nrows * figsize_multiplier),
squeeze=False,
) # nrows * figsize_multiplier
plt.subplots_adjust(
hspace=0.2, wspace=0.3, left=0.1, right=0.9, top=0.9, bottom=0.1
)
axes = axes.ravel()
for ax in axes:
ax.set_box_aspect(1)
"""
# Handle single subplot case
if num_vars == 1:
axes = np.array([[axes[0]], [axes[1]]])
elif axes.ndim == 1:
axes = axes.reshape(nrows, ncols)
"""
if save_npz:
spectra_npz_data = {}
spectra_npz_data["__meta__dlat"] = dlat
spectra_npz_data["__meta__dlon"] = dlon
spectra_npz_data["__meta__variables"] = np.array(variable_names)
# Process each variable
for i, var_name in enumerate(variable_names):
if i >= num_vars:
continue
linestyles = mpltex.linestyle_generator(markers=[])
# plot_name = plot_variable_names[i]
# Initialize arrays for averaged PSDs
psd2d_pred_sum = np.zeros((nh, nw))
psd2d_target_sum = np.zeros((nh, nw))
if coarse_inputs is not None:
psd2d_coarse_sum = np.zeros((nh, nw))
# Process each sample in the batch
for b in range(batch_size):
# Predictions
# field_pred = predictions[b, i]
field_pred = PlotConfig.convert_units(var_name, predictions[b, i])
psd2d_pred = calculate_psd2d_simple(field_pred)
psd2d_pred_sum += psd2d_pred
# Targets
# field_target = targets[b, i]
field_target = PlotConfig.convert_units(var_name, targets[b, i])
psd2d_target = calculate_psd2d_simple(field_target)
psd2d_target_sum += psd2d_target
# Coarse inputs
if coarse_inputs is not None:
# field_coarse = coarse_inputs[b, i]
field_coarse = PlotConfig.convert_units(var_name, coarse_inputs[b, i])
psd2d_coarse = calculate_psd2d_simple(field_coarse)
psd2d_coarse_sum += psd2d_coarse
# Average over batch
psd2d_pred_avg = psd2d_pred_sum / batch_size
psd2d_target_avg = psd2d_target_sum / batch_size
if coarse_inputs is not None:
psd2d_coarse_avg = psd2d_coarse_sum / batch_size
# Calculate 1D radial spectra
psd1d_pred = radial_average_psd(psd2d_pred_avg, k_mag, k_bins)
psd1d_target = radial_average_psd(psd2d_target_avg, k_mag, k_bins)
if coarse_inputs is not None:
psd1d_coarse = radial_average_psd(psd2d_coarse_avg, k_mag, k_bins)
if save_npz:
key = f"{var_name}__spectra__"
spectra_npz_data[key + "k"] = k_centers
spectra_npz_data[key + "psd_pred"] = psd1d_pred
spectra_npz_data[key + "psd_truth"] = psd1d_target
if coarse_inputs is not None:
spectra_npz_data[key + "psd_coarse"] = psd1d_coarse
"""
# --- Plot 2D PSD (top row) ---
ax_top = axes[0, i] if num_vars > 1 else axes[0]
# Use k_lon and k_lat for the axes instead of lat/lon
k_lon_min, k_lon_max = fft_freq_lon_shifted[0], fft_freq_lon_shifted[-1]
k_lat_min, k_lat_max = fft_freq_lat_shifted[0], fft_freq_lat_shifted[-1]
im = ax_top.imshow(np.log10(psd2d_pred_avg + 1e-12),
cmap=cmap_white_jet,
aspect='auto',
origin='lower',
extent=[k_lon_min, k_lon_max, k_lat_min, k_lat_max])
#ax_top.set_title(f'{var_name}')
ax_top.set_title(plot_name)
# Only add y-axis label for leftmost column
if i == 0:
ax_top.set_ylabel(r'$\mathrm{k_{lat}}$ (cycles/°)')
else:
ax_top.set_ylabel('')
# Remove y-axis tick labels for non-leftmost columns
ax_top.tick_params(axis='y', labelleft=False)
# Always show x-axis label
ax_top.set_xlabel(r'$\mathrm{k_{lon}}$ (cycles/°)')
# Add grid for better readability
ax_top.grid(True, alpha=0.3, linestyle='--')
# Add colorbar for the last column only
if i == num_vars - 1:
cax = ax_top.inset_axes([1.05, 0, 0.05, 1]) # [x, y, w, h] relative to axes
cbar = plt.colorbar(im, cax=cax, orientation='vertical')
cbar.set_label('log₁₀(PSD)')
"""
# --- Plot 1D Radial Spectrum (bottom row) ---
# ax_bottom = axes[1, i] if num_vars > 1 else axes[1]
ax_bottom = axes[i]
# Plot all spectra
ax_bottom.loglog(k_centers, psd1d_pred, label="Pred", **next(linestyles))
ax_bottom.loglog(k_centers, psd1d_target, label="Truth", **next(linestyles))
if coarse_inputs is not None:
ax_bottom.loglog(
k_centers, psd1d_coarse, label="Coarse", **next(linestyles)
)
# Only add y-axis label for leftmost column
if i == 0:
ax_bottom.set_ylabel("PSD(k)")
else:
ax_bottom.set_ylabel("")
# Always show x-axis label
ax_bottom.set_xlabel("Wavenumber k [cycles/°]")
ax_bottom.legend()
ax_bottom.grid(True, alpha=0.3, which="both")
# Set reasonable axis limits
valid = (k_centers > 0) & (psd1d_target > 0)
if np.any(valid):
ax_bottom.set_xlim(k_centers[valid][0] * 0.8, k_centers[valid][-1] * 1.2)
# Find y-range
y_min = min(psd1d_pred[valid].min(), psd1d_target[valid].min())
y_max = max(psd1d_pred[valid].max(), psd1d_target[valid].max())
if coarse_inputs is not None:
y_min = min(y_min, psd1d_coarse[valid].min())
y_max = max(y_max, psd1d_coarse[valid].max())
ax_bottom.set_ylim(y_min * 0.5, y_max * 2.0)
# Save figure
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, filename)
if save_npz:
npz_path = os.path.splitext(save_path)[0] + ".npz"
np.savez_compressed(npz_path, **spectra_npz_data)
plt.savefig(save_path, bbox_inches="tight")
plt.close(fig)
return save_path
[docs]
def calculate_psd2d_simple(field):
    """
    Compute the 2D power spectral density of a field without preprocessing.

    The field is transformed with a 2D FFT, the zero-frequency component is
    shifted to the array centre, and the squared magnitude is returned.
    """
    spectrum = np.fft.fftshift(np.fft.fft2(field))
    return np.abs(spectrum) ** 2
[docs]
def radial_average_psd(psd2d, k_mag, k_bins):
    """
    Radially average a 2D PSD into a 1D spectrum using wavenumber magnitude.

    Each bin of ``k_bins`` receives the mean PSD over all pixels whose
    ``k_mag`` falls in that bin; the mean is then scaled by the annulus
    area 2*pi*k*dk so the result is a proper spectral density.
    """
    # Bin the flattened PSD values by their wavenumber magnitude.
    binned_mean, _, _ = stats.binned_statistic(
        k_mag.ravel(), psd2d.ravel(), statistic="mean", bins=k_bins
    )
    # Annulus area per bin: 2*pi*k*dk evaluated at the bin centre.
    centers = 0.5 * (k_bins[:-1] + k_bins[1:])
    widths = k_bins[1:] - k_bins[:-1]
    annulus = 2 * np.pi * centers * widths
    # Scale only bins with a non-zero annulus area (guards the k=0 bin).
    nonzero = annulus > 0
    binned_mean[nonzero] = binned_mean[nonzero] * annulus[nonzero]
    return binned_mean
[docs]
def plot_qq_quantiles(
    predictions,  # Model predictions
    targets,  # Ground truth
    coarse_inputs,  # Coarse inputs
    variable_names=None,  # List of variable names
    units=None,  # List of units for each variable
    quantiles=[0.90, 0.95, 0.975, 0.99, 0.995],  # NOTE(review): mutable default; never mutated here, but a tuple would be safer
    filename="qq_quantiles.png",
    save_dir="./results",
    figsize_multiplier=4,
    save_npz=False,
):
    """
    Create QQ-plots at different quantiles comparing model predictions and
    coarse inputs against ground truth.

    For each variable, plots quantiles of predictions and coarse inputs
    against quantiles of ground truth with a 1:1 reference line.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    coarse_inputs : torch.Tensor or np.array
        Coarse inputs of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Names of the variables for subplot titles.
        If None, uses ["VAR_0", "VAR_1", ...]
    units : list of str, optional
        Units for each variable for axis labels.
        If None, uses empty strings.
    quantiles : list of float, optional
        Quantile values to plot (e.g., [0.90, 0.95, 0.975, 0.99, 0.995])
    filename : str, optional
        Output filename
    save_dir : str, optional
        Directory to save the plot
    figsize_multiplier : int, optional
        Base size multiplier for subplots
    save_npz : bool, optional
        If True, saves the quantile diagnostics to a compressed .npz file.

    Returns
    -------
    save_path : str
        Path to the saved figure
    """
    # Convert tensors → numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(coarse_inputs, "detach"):
        coarse_inputs = coarse_inputs.detach().cpu().numpy()
    batch_size, num_vars, h, w = predictions.shape
    # Default variable names if not provided
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(num_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]
    # Default units if not provided
    if units is None:
        units = [""] * num_vars
    # Figure setup: one square panel per variable on a single row
    fig, axes = plt.subplots(
        1,
        num_vars,
        figsize=(num_vars * figsize_multiplier, figsize_multiplier),
        constrained_layout=True,
    )
    if num_vars > 1:
        axes = axes.ravel()
    # Handle single subplot case
    else:
        axes = np.array([axes])
    for ax in axes:
        ax.set_box_aspect(1)
    if save_npz:
        qq_npz_data = {}
        qq_npz_data["__meta__variables"] = np.array(variable_names)
        qq_npz_data["__meta__quantiles"] = np.array(quantiles)
    for i, var_name in enumerate(variable_names):
        linestyles = mpltex.linestyle_generator(lines=[])
        ax = axes[i]
        plot_name = plot_variable_names[i]
        # Convert to plotting units; quantiles are taken over batch + space
        # target_vals = targets[:, i]
        # pred_vals = predictions[:, i]
        # coarse_vals = coarse_inputs[:, i]
        pred_vals = PlotConfig.convert_units(var_name, predictions[:, i])
        target_vals = PlotConfig.convert_units(var_name, targets[:, i])
        coarse_vals = PlotConfig.convert_units(var_name, coarse_inputs[:, i])
        # Compute quantiles
        qs_target = np.quantile(target_vals, quantiles)
        qs_pred = np.quantile(pred_vals, quantiles)
        qs_coarse = np.quantile(coarse_vals, quantiles)
        if save_npz:
            key = f"{var_name}__qq__"
            qq_npz_data[key + "quantiles"] = np.array(quantiles)
            qq_npz_data[key + "truth"] = qs_target
            qq_npz_data[key + "pred"] = qs_pred
            qq_npz_data[key + "coarse"] = qs_coarse
        print(f"[QQ Quantiles] {plot_name}")
        for q, qt, qp, qc in zip(quantiles, qs_target, qs_pred, qs_coarse):
            print(f" q={q:.3f} | Truth={qt:.4f} | Pred={qp:.4f} | Coarse={qc:.4f} ")
        # ---- Plot predicted quantiles (one marker per quantile level) ----
        for q_idx, q in enumerate(quantiles):
            ax.plot(
                qs_target[q_idx],
                qs_pred[q_idx],
                label=f"{q * 100:.1f}%",
                **next(linestyles),
            )
        # ---- Plot coarse quantiles ----
        ax.plot(
            qs_target,
            qs_coarse,
            c="black",
            marker="s",
            label="Coarse",
            linestyle="None",
        )
        # ---- 1:1 reference line ----
        # Calculate appropriate limits for this variable
        min_val = min(qs_target.min(), qs_pred.min(), qs_coarse.min())
        max_val = max(qs_target.max(), qs_pred.max(), qs_coarse.max())
        margin = 0.0
        plot_min = min_val - margin
        plot_max = max_val + margin
        ax.plot(
            [plot_min, plot_max], [plot_min, plot_max], "r--", alpha=0.7, label="1:1"
        )
        # Set axis limits
        # ax.set_xlim(plot_min, plot_max)
        # ax.set_ylim(plot_min, plot_max)
        ax.xaxis.set_major_locator(ticker.MaxNLocator(4))
        ax.yaxis.set_major_locator(ticker.MaxNLocator(4))
        # Labels and formatting
        # ax.set_title(var_name)
        ax.set_title(plot_name)
        # Add unit to labels if provided
        unit_str = f" ({units[i]})" if units[i] else ""
        # Only add y-axis label for leftmost plot
        if i == 0:
            ax.set_ylabel(f"Predicted/Coarse quantiles{unit_str}")
        ax.set_xlabel(f"True quantiles{unit_str}")
        ax.grid(True, linestyle="--", alpha=0.3)
        # Add legend only for first subplot
        if i == 0:
            ax.legend()
    # Save figure
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    if save_npz:
        npz_path = os.path.splitext(save_path)[0] + ".npz"
        np.savez_compressed(npz_path, **qq_npz_data)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def dry_frequency_map(array, threshold):
    """
    Compute a map of the per-pixel frequency of dry conditions.

    A pixel counts as dry at a given time step when its value is strictly
    below ``threshold``; the returned map holds, for each location, the
    fraction of time steps that were dry.

    Parameters
    ----------
    array : torch.Tensor or np.array
        Precipitation fields of shape [batch_size, h, w]
    threshold : float
        threshold for precipitation (expressed in mm): under it, pixel is considered dry.

    Returns
    -------
    np.ndarray(np.float64) of shape [h,w]
    """
    # Accept torch tensors transparently.
    if hasattr(array, "detach"):
        array = array.detach().cpu().numpy()
    is_dry = (array < threshold).astype(np.float64)
    return is_dry.mean(axis=0)
[docs]
def plot_dry_frequency_map(
    predictions,  # Model predictions precipitation (fine predicted)
    targets,  # Ground truth precipitation (fine true)
    threshold,  # threshold to define dry and wet (in mm)
    lat_1d,
    lon_1d,
    filename="validation_dry_frequency_map.png",
    save_dir=None,
    figsize_multiplier=None,  # Base size per subplot
):
    """
    Plot spatial dry pixels proportion maps. Value of each pixel corresponds
    to the frequency of dry weather for this pixel. Three stacked panels are
    drawn: predicted frequency, target frequency, and their difference.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, h, w]
    threshold : float
        threshold for precipitation (expressed in mm): under it, pixel is considered dry.
    lat_1d : array-like
        1D array of latitude coordinates with shape [H].
    lon_1d : array-like
        1D array of longitude coordinates with shape [W].
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot.
    figsize_multiplier : int, optional
        Base size multiplier for subplots.

    Returns
    -------
    save_path : str
        Path to the saved figure.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if figsize_multiplier is None:
        figsize_multiplier = PlotConfig.DEFAULT_FIGSIZE_MULTIPLIER
    # Convert tensors to numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(lat_1d, "detach"):
        lat_1d = lat_1d.detach().cpu().numpy()
    if hasattr(lon_1d, "detach"):
        lon_1d = lon_1d.detach().cpu().numpy()
    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    _, h, w = targets.shape
    # Rebuild a regular grid spanning the coordinate extents
    # (latitude descending so row 0 is the northernmost row).
    lat_block = np.linspace(lat_max, lat_min, h)
    lon_block = np.linspace(lon_min, lon_max, w)
    lat, lon = np.meshgrid(lat_block, lon_block, indexing="ij")
    lon_center = float((lon_min + lon_max) / 2)
    cmap = PlotConfig.get_colormap(
        "dry frequency"
    )  # need to define the colormap in PlotConfig
    # convert units :
    predictions = PlotConfig.convert_units("precipitation", predictions)
    targets = PlotConfig.convert_units("precipitation", targets)
    dry_freq_pred_map = dry_frequency_map(predictions, threshold)
    # dry_freq_pred = np.mean(dry_freq_pred_map)
    dry_freq_targ_map = dry_frequency_map(targets, threshold)
    # dry_freq_targ = np.mean(dry_freq_targ_map)
    # Frequencies are proportions, so the colour scale is fixed to [0, 1]
    vmin = 0
    vmax = 1
    base_width_per_panel = 4.5
    base_height_per_panel = 3.0
    fig_width = base_width_per_panel
    fig_height = 3 * base_height_per_panel
    fig, axes = plt.subplots(
        3,
        figsize=(fig_width, fig_height),
        subplot_kw={
            "projection": ccrs.PlateCarree(central_longitude=lon_center)
        },  # ccrs.Mercator(central_longitude=lon_center)
        gridspec_kw={"wspace": 0.1},
    )
    fig.subplots_adjust(
        top=0.9, bottom=0.1, left=0.1, right=0.9, wspace=0.1, hspace=0.1
    )
    # Panel 0: predicted dry frequency
    im = axes[0].pcolormesh(
        lon,
        lat,
        dry_freq_pred_map,
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
        transform=ccrs.PlateCarree(),
        shading="auto",
    )
    axes[0].set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
    axes[0].coastlines(linewidth=0.6)
    axes[0].add_feature(
        cfeature.BORDERS.with_scale("50m"),
        linewidth=0.6,
        linestyle="--",
        edgecolor="black",
        zorder=11,
    )
    axes[0].add_feature(
        cfeature.LAKES.with_scale("50m"),
        edgecolor="black",
        facecolor="none",
        linewidth=0.6,
        zorder=9,
    )
    # ax.set_aspect("auto")
    axes[0].set_xticks([])
    axes[0].set_yticks([])
    axes[0].set_title("Predicted")
    # Panel 1: target dry frequency (same fixed colour scale)
    im = axes[1].pcolormesh(
        lon,
        lat,
        dry_freq_targ_map,
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
        transform=ccrs.PlateCarree(),
        shading="auto",
    )
    axes[1].set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
    axes[1].coastlines(linewidth=0.6)
    axes[1].add_feature(
        cfeature.BORDERS.with_scale("50m"),
        linewidth=0.6,
        linestyle="--",
        edgecolor="black",
        zorder=11,
    )
    axes[1].add_feature(
        cfeature.LAKES.with_scale("50m"),
        edgecolor="black",
        facecolor="none",
        linewidth=0.6,
        zorder=9,
    )
    # ax.set_aspect("auto")
    axes[1].set_xticks([])
    axes[1].set_yticks([])
    axes[1].set_title("Target")
    # Shared colorbar spanning panels 0 and 1
    pos0 = axes[0].get_position()
    pos1 = axes[1].get_position()
    bottom = pos1.y0
    top = pos0.y1
    height = top - bottom
    cax1 = fig.add_axes([0.92, bottom, 0.03, height])
    fig.colorbar(im, cax=cax1, label="frequency")
    # vmax_diff = max(
    #     np.abs(np.max(dry_freq_pred_map - dry_freq_targ_map)),
    #     np.abs(np.min(dry_freq_pred_map - dry_freq_targ_map)),
    # )
    # norm_diff = mcolors.TwoSlopeNorm(vmin=-1, vcenter=0, vmax = 1)
    # Panel 2: prediction - target difference on a symmetric diverging scale
    im = axes[2].pcolormesh(
        lon,
        lat,
        dry_freq_pred_map - dry_freq_targ_map,
        norm=mcolors.TwoSlopeNorm(vmin=-vmax, vcenter=0, vmax=vmax),
        cmap="seismic",
        transform=ccrs.PlateCarree(),
        shading="auto",
    )
    axes[2].set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
    axes[2].coastlines(linewidth=0.6)
    axes[2].add_feature(
        cfeature.BORDERS.with_scale("50m"),
        linewidth=0.6,
        linestyle="--",
        edgecolor="black",
        zorder=11,
    )
    axes[2].add_feature(
        cfeature.LAKES.with_scale("50m"),
        edgecolor="black",
        facecolor="none",
        linewidth=0.6,
        zorder=9,
    )
    # ax.set_aspect("auto")
    axes[2].set_xticks([])
    axes[2].set_yticks([])
    axes[2].set_title("Predicted frequency - Target frequency")
    # Separate colorbar for the difference panel
    pos2 = axes[2].get_position()
    cax2 = fig.add_axes([0.92, pos2.y0, 0.03, pos2.height])
    fig.colorbar(im, cax=cax2, label="frequency")
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def calculate_pearsoncorr_nparray(arr1, arr2, axis=0):
    """
    Calculate Pearson correlation between 2 N-dimensional numpy arrays.

    Parameters
    ----------
    arr1 : numpy.ndarray
        First N-dimensional array.
    arr2 : numpy.ndarray
        Second N-dimensional array (must have same shape as arr1).
    axis : int or tuple of int, default=0
        Axis or tuple of axes over which to compute correlation.

    Returns
    -------
    numpy.ndarray
        Pearson correlation coefficients, with the reduced axis/axes removed
        from the input shape. Entries with zero variance are set to 0.0
        instead of inf/nan.
    """
    if arr1.shape != arr2.shape:
        raise ValueError(
            f"Arrays must have the same shape. Got {arr1.shape} and {arr2.shape}"
        )
    # Remove the mean along the reduction axis/axes.
    dev1 = arr1 - arr1.mean(axis=axis, keepdims=True)
    dev2 = arr2 - arr2.mean(axis=axis, keepdims=True)
    # Covariance term and the product of the standard-deviation terms.
    cov = (dev1 * dev2).sum(axis=axis)
    norm = np.sqrt((dev1**2).sum(axis=axis) * (dev2**2).sum(axis=axis))
    # Guarded division: entries where the denominator vanishes stay 0.0.
    return np.divide(cov, norm, out=np.zeros_like(cov), where=norm != 0)
[docs]
def plot_validation_mvcorr_space(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    coarse_inputs=None,  # Coarse inputs for comparison (optional)
    variable_names=None,  # List of variable names
    filename="validation_mvcorr_space.png",
    save_dir="./results",
    figsize_multiplier=4,  # Base size per subplot
):
    """
    Compute multivariate correlation over the space dimensions and plot as time-series,
    comparing model predictions vs ground truth, for all combinations of variables.
    Uses Pearson's correlation coefficient.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    coarse_inputs : torch.Tensor or np.array, optional
        Coarse inputs of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Names of the variables for subplot titles
    filename : str, optional
        Output filename
    save_dir : str, optional
        Directory to save the plot
    figsize_multiplier : int, optional
        Base size multiplier for subplots

    Returns
    -------
    save_path : str
        Path to the saved figure ("0" when fewer than 2 variables are given)
    """
    # Convert to numpy if they're tensors
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if coarse_inputs is not None and hasattr(coarse_inputs, "detach"):
        coarse_inputs = coarse_inputs.detach().cpu().numpy()
    batch_size, num_vars, h, w = predictions.shape
    # At least two variables are needed to form a pair
    if num_vars < 2:
        print("ERROR: need at least 2 variables but num_vars < 2")
        return "0"
    # Default variable names if not provided
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(num_vars)]
    # Make list of tuples defining variable combinations
    list_var_combos = []
    for ii in range(num_vars - 1):
        for jj in range(num_vars - 1 - ii):
            list_var_combos.append((ii, ii + jj + 1))
    # Calculate grid dimensions
    ncols = 1
    nrows = int(num_vars * (num_vars - 1) / 2)  # no. distinct pairs of input variables
    fwidth = 6  # longitude range
    fheight = nrows * figsize_multiplier
    # Set up figure
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(fwidth, fheight), squeeze=False, sharex=True
    )
    axes = axes.flatten()
    # One fixed style per series so all subplots match
    linestyles = mpltex.linestyle_generator(markers=[])
    style_truth = next(linestyles)
    style_pred = next(linestyles)
    style_coarse = next(linestyles) if coarse_inputs is not None else None
    var_name_combo_list = []
    # Plot correlation timeseries for each combination of variables
    # max_count = 0
    for i, varComb in enumerate(list_var_combos):
        var_name_combo = variable_names[varComb[0]] + " & " + variable_names[varComb[1]]
        var_name_combo_list.append(var_name_combo)
        # Compute Correlation over the spatial axes -> one value per time step
        pred_corr = calculate_pearsoncorr_nparray(
            predictions[:, varComb[0], :, :],
            predictions[:, varComb[1], :, :],
            axis=(1, 2),
        )
        target_corr = calculate_pearsoncorr_nparray(
            targets[:, varComb[0], :, :], targets[:, varComb[1], :, :], axis=(1, 2)
        )
        if coarse_inputs is not None:
            coarse_corr = calculate_pearsoncorr_nparray(
                coarse_inputs[:, varComb[0], :, :],
                coarse_inputs[:, varComb[1], :, :],
                axis=(1, 2),
            )
        ax = axes[i]
        time_index = range(batch_size)
        ax.plot(time_index, target_corr, label="Truth", linewidth=1.0, **style_truth)
        ax.plot(time_index, pred_corr, label="Prediction", linewidth=1.0, **style_pred)
        if coarse_inputs is not None:
            ax.plot(
                time_index, coarse_corr, label="Coarse", linewidth=1.0, **style_coarse
            )
        ax.grid(True, alpha=0.3)
        ax.set_ylabel(var_name_combo)
        # ax.set_ylim(-1, 1)
        ax.set_xlim(0, batch_size - 1)
        if i == 0:
            ax.legend()
    axes[0].set_title("Spatial Pearson Correlation Over Time")
    axes[-1].set_xlabel("Time Step")
    # Ensure save directory exists
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def plot_validation_mvcorr(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    lat,
    lon,
    coarse_inputs=None,  # Coarse inputs for comparison (optional)
    variable_names=None,  # List of variable names
    filename="validation_mvcorr_time.png",
    save_dir="./results",
    figsize_multiplier=4,  # Base size per subplot
):
    """
    Compute multivariate correlation over the time dimension and plot as maps,
    comparing model predictions vs ground truth, for all combinations of variables.
    Uses Pearson's correlation coefficient.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    lat : array-like
        2D array of latitude coordinates with shape [h, w].
    lon : array-like
        2D array of longitude coordinates with shape [h, w].
    coarse_inputs : torch.Tensor or np.array, optional
        Coarse inputs of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Names of the variables for subplot titles
    filename : str, optional
        Output filename
    save_dir : str, optional
        Directory to save the plot
    figsize_multiplier : int, optional
        Base size multiplier for subplots

    Returns
    -------
    save_path : str
        Path to the saved figure ("0" when fewer than 2 variables are given)
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if figsize_multiplier is None:
        figsize_multiplier = PlotConfig.DEFAULT_FIGSIZE_MULTIPLIER
    # Convert to numpy if they're tensors
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if coarse_inputs is not None and hasattr(coarse_inputs, "detach"):
        coarse_inputs = coarse_inputs.detach().cpu().numpy()
    if hasattr(lat, "detach"):
        lat = lat.detach().cpu().numpy()
    if hasattr(lon, "detach"):
        lon = lon.detach().cpu().numpy()
    lat_min, lat_max = lat.min(), lat.max()
    lon_min, lon_max = lon.min(), lon.max()
    T, n_vars, h, w = predictions.shape
    # NOTE(review): the lat/lon arguments are replaced below by a regular grid
    # spanning their extents (latitude descending); only their min/max are
    # actually used. T/n_vars duplicate batch_size/num_vars unpacked below.
    lat_block = np.linspace(lat_max, lat_min, h)
    lon_block = np.linspace(lon_min, lon_max, w)
    lat, lon = np.meshgrid(lat_block, lon_block, indexing="ij")
    lon_center = float((lon_min + lon_max) / 2)
    batch_size, num_vars, h, w = predictions.shape
    # At least two variables are needed to form a pair
    if num_vars < 2:
        print("ERROR: need at least 2 variables but num_vars < 2")
        return "0"
    # Default variable names if not provided
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(num_vars)]
    # Make list of tuples defining variable combinations
    list_var_combos = []
    for ii in range(num_vars - 1):
        for jj in range(num_vars - 1 - ii):
            list_var_combos.append((ii, ii + jj + 1))
    # Calculate grid dimensions
    ncols = 2
    if coarse_inputs is not None:
        ncols = 3
    nrows = int(num_vars * (num_vars - 1) / 2)  # no. distinct pairs of input variables
    base_width_per_panel = 4.5
    base_height_per_panel = 3.0
    fig_width = base_width_per_panel * ncols
    fig_height = base_height_per_panel * nrows
    # Summary statistics vs target, filled in the loop; currently only
    # consumed by the disabled heatmap section at the end of this function
    spa_cor_out = np.zeros([nrows, ncols - 1])
    spa_rmse_out = np.zeros([nrows, ncols - 1])
    # Set up figure
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(fig_width, fig_height),
        subplot_kw={"projection": ccrs.PlateCarree(central_longitude=lon_center)},
        squeeze=False,
        gridspec_kw={"wspace": 0.1},
    )
    # Define geographic features
    coastline = cfeature.COASTLINE.with_scale("50m")
    borders = cfeature.BORDERS.with_scale("50m")
    # lakes = cfeature.LAKES.with_scale('50m')
    var_name_combo_list = []
    # Plot each combination of variables
    # max_count = 0
    for i, varComb in enumerate(list_var_combos):
        var_name_combo = variable_names[varComb[0]] + " & " + variable_names[varComb[1]]
        var_name_combo_list.append(var_name_combo)
        # Compute Correlation over time -> one map per variable pair
        pred_corr = calculate_pearsoncorr_nparray(
            predictions[:, varComb[0], :, :], predictions[:, varComb[1], :, :], axis=0
        )
        target_corr = calculate_pearsoncorr_nparray(
            targets[:, varComb[0], :, :], targets[:, varComb[1], :, :], axis=0
        )
        if coarse_inputs is not None:
            coarse_corr = calculate_pearsoncorr_nparray(
                coarse_inputs[:, varComb[0], :, :],
                coarse_inputs[:, varComb[1], :, :],
                axis=0,
            )
        # Pattern correlation of each correlation map against the target map
        spa_cor_out[i, 0] = np.corrcoef(
            pred_corr.reshape(pred_corr.size), target_corr.reshape(target_corr.size)
        )[0, 1]
        if coarse_inputs is not None:
            spa_cor_out[i, 1] = np.corrcoef(
                coarse_corr.reshape(coarse_corr.size),
                target_corr.reshape(target_corr.size),
            )[0, 1]
        # RMSE of each correlation map against the target map
        spa_rmse_out[i, 0] = np.sqrt((np.square(pred_corr - target_corr)).mean())
        if coarse_inputs is not None:
            spa_rmse_out[i, 1] = np.sqrt((np.square(coarse_corr - target_corr)).mean())
        # Col 0: Truth
        ax_target = axes[i, 0]
        ax_target.pcolormesh(
            lon,
            lat,
            target_corr,
            vmin=-1.0,
            vmax=1.0,
            cmap="RdBu",
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        ax_target.add_feature(coastline, linewidth=PlotConfig.COASTLINE_w)
        ax_target.add_feature(
            borders,
            linewidth=PlotConfig.BORDER_w,
            edgecolor="black",
            linestyle=PlotConfig.BORDER_STYLE,
        )
        # ax_target.set_aspect("auto")
        ax_target.set_extent(
            [lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree()
        )
        # Col 1: Prediction
        ax_pred = axes[i, 1]
        im_pred = ax_pred.pcolormesh(
            lon,
            lat,
            pred_corr,
            vmin=-1.0,
            vmax=1.0,
            cmap="RdBu",
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        ax_pred.add_feature(coastline, linewidth=PlotConfig.COASTLINE_w)
        ax_pred.add_feature(
            borders,
            linewidth=PlotConfig.BORDER_w,
            edgecolor="black",
            linestyle=PlotConfig.BORDER_STYLE,
        )
        # ax_pred.set_aspect("auto")
        ax_pred.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
        if coarse_inputs is not None:
            # Col 2: Coarse input
            ax_coar = axes[i, 2]
            ax_coar.pcolormesh(
                lon,
                lat,
                coarse_corr,
                vmin=-1.0,
                vmax=1.0,
                cmap="RdBu",
                transform=ccrs.PlateCarree(),
                shading="auto",
            )
            ax_coar.add_feature(coastline, linewidth=PlotConfig.COASTLINE_w)
            ax_coar.add_feature(
                borders,
                linewidth=PlotConfig.BORDER_w,
                edgecolor="black",
                linestyle=PlotConfig.BORDER_STYLE,
            )
            # ax_coar.set_aspect("auto")
            ax_coar.set_extent(
                [lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree()
            )
    # Add col labels
    col_labels = ["Truth", "Prediction"]
    if coarse_inputs is not None:
        col_labels = ["Truth", "Prediction", "Coarse"]
    for col_idx, label in enumerate(col_labels):
        axes[0, col_idx].set_title(label)
    # Add row labels
    for row_idx, label in enumerate(var_name_combo_list):
        axes[row_idx, 0].text(
            -0.1,
            0.5,
            label,
            transform=axes[row_idx, 0].transAxes,
            va="center",
            ha="right",
            rotation="vertical",
            fontsize=12,
        )
    # Add colorbar spanning the full height of the panel grid
    fig.subplots_adjust(top=0.9, bottom=0.1, left=0.1, right=0.9, wspace=0.1)
    pos_top = axes[0, 0].get_position()
    pos_bottom = axes[-1, 0].get_position()
    bottom = pos_bottom.y0
    top = pos_top.y1
    height = top - bottom
    cbar_ax = fig.add_axes([0.92, bottom, 0.015, height])
    fig.colorbar(im_pred, cax=cbar_ax, label=r"Temporal Pearson Correlation")
    # Ensure save directory exists
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    # Disabled: summary heatmaps of spatial correlation / RMSE wrt target
    """
    # _________________________________________
    # Output summary map statistics as heatmaps
    # Spatial Correlation and Spatial RMSE wrt target
    # Setupt axis labels
    xLabels=['Prediction']
    if coarse_inputs is not None: xLabels=['Prediction','Coarse']
    yLabels=var_name_combo_list
    fig, (ax1,ax2) = plt.subplots(ncols=2, figsize=((ncols+2)*2,4))#, layout='constrained')
    sns.heatmap(spa_cor_out, ax=ax1, cbar=False, linewidth=0.5, annot=True, fmt='.3f', xticklabels=xLabels, yticklabels=yLabels, vmin=0.0, vmax=1.0, cmap=plt.get_cmap('Reds'))
    fig.colorbar(ax1.collections[0], ax=ax1, location="left", use_gridspec=False, pad=0.1, label="correlation")
    ax1.tick_params(axis='y', pad=90, length=0)
    ax1.tick_params(axis='x', length=0)
    ax1.yaxis.set_label_position("left")
    sns.heatmap(spa_rmse_out, ax=ax2, cbar=False, linewidth=0.5, annot=True, fmt='.3f', xticklabels=xLabels, yticklabels=[""]*ncols, vmin=0.0, vmax=0.3, cmap=plt.get_cmap('Reds_r'))
    fig.colorbar(ax2.collections[0], ax=ax2, location="right", use_gridspec=False, pad=0.1, label="RMSE")
    ax2.tick_params(rotation=0, length=0)
    ax2.yaxis.set_label_position("right")
    # Ensure save directory exists
    os.makedirs(save_dir, exist_ok=True)
    filenameCR='SpCorrRmse_'+filename
    save_path = os.path.join(save_dir, filenameCR)
    plt.savefig(save_path, bbox_inches='tight')
    plt.close()
    """
    return save_path
[docs]
def plot_temporal_series_comparison(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    coarse_inputs=None,  # Coarse inputs for comparison (optional)
    variable_names=None,  # List of variable names
    filename="validation_temp_series.png",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Plot spatially averaged temporal series for each variable.

    One subplot per variable, stacked vertically; each shows the spatial
    mean of truth, prediction and (optionally) coarse input over time.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    coarse_inputs : torch.Tensor or np.array, optional
        Coarse inputs of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Names of the variables for subplot titles
    filename : str, optional
        Output filename
    save_dir : str, optional
        Directory to save the plot
    figsize_multiplier : int, optional
        Base size multiplier for subplots

    Returns
    -------
    save_path : str
        Path to the saved figure

    Raises
    ------
    ValueError
        If predictions/coarse_inputs shapes differ from targets, or the
        number of variable names does not match num_vars.
    """
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if coarse_inputs is not None and hasattr(coarse_inputs, "detach"):
        coarse_inputs = coarse_inputs.detach().cpu().numpy()
    if predictions.shape != targets.shape:
        raise ValueError(f"Shape mismatch: {predictions.shape} vs {targets.shape}")
    if coarse_inputs is not None and coarse_inputs.shape != targets.shape:
        raise ValueError(
            f"Coarse shape mismatch: {coarse_inputs.shape} vs {targets.shape}"
        )
    batch_size, num_vars, h, w = predictions.shape
    # Default variable names if not provided
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(num_vars)]
    if len(variable_names) != num_vars:
        raise ValueError(
            f"{len(variable_names)} variable names but num_vars={num_vars}"
        )
    fig = plt.figure(figsize=(6, figsize_multiplier * num_vars))
    # One fixed style per series so all subplots match
    linestyles = mpltex.linestyle_generator(markers=[])
    style_truth = next(linestyles)
    style_pred = next(linestyles)
    style_coarse = next(linestyles) if coarse_inputs is not None else None
    # Loop over variables
    for i, var in enumerate(variable_names):
        ax = fig.add_subplot(num_vars, 1, i + 1)
        pred_vals = PlotConfig.convert_units(var, predictions[:, i])
        true_vals = PlotConfig.convert_units(var, targets[:, i])
        # Spatial mean over H and W dimensions
        s_pred = pred_vals.mean(axis=(1, 2))
        s_true = true_vals.mean(axis=(1, 2))
        # Temporal axis
        time_index = range(batch_size)
        ax.plot(time_index, s_true, label="Truth", linewidth=1.0, **style_truth)
        ax.plot(time_index, s_pred, label="Prediction", linewidth=1.0, **style_pred)
        if coarse_inputs is not None:
            coarse_vals = PlotConfig.convert_units(var, coarse_inputs[:, i])
            s_coarse = coarse_vals.mean(axis=(1, 2))
            ax.plot(time_index, s_coarse, label="Coarse", linewidth=1.0, **style_coarse)
        ax.set_title(var)
        ax.grid(True, alpha=0.3)
        ax.set_ylabel("Spatial mean")
        # Only the bottom subplot keeps its x tick labels and axis label
        if i == num_vars - 1:
            ax.set_xlabel("Time index")
        else:
            ax.tick_params(labelbottom=False)
        ax.legend()
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def ranks(
    predictions,  # Model predictions precipitation (fine predicted)
    targets,  # Ground truth precipitation (fine true)
):
    """
    Compute ranks of predictions compared to targets.

    For every (time, lat, lon) point the rank is the number of ensemble
    members falling below the target value, with exact ties counted as 0.5
    (mid-rank convention). A calibrated ensemble yields a flat rank
    histogram.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [ensemble_size, batch_size, h, w]
    targets : torch.Tensor or np.array
        Targets of shape [batch_size, h, w]

    Returns
    -------
    np.ndarray(np.float64) of shape [batch_size*h*w,]
    """
    # convert to numpy if tensor :
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    nb_ens, T, L, W = predictions.shape
    # Flatten all points so each column is one (time, lat, lon) location
    predictions_ens = predictions.reshape(nb_ens, T * L * W)
    targets = targets.reshape(1, T * L * W)
    diff = predictions_ens - targets
    # Mid-rank: strictly-below members count 1, exact ties count 0.5.
    # Fix: accumulate in float64 so the result matches the documented
    # return dtype (float32 masks previously produced a float32 array).
    mask_leq = (diff <= 0).astype(np.float64)
    mask_l = (diff < 0).astype(np.float64)
    mask = (mask_leq + mask_l) / 2
    return np.sum(mask, axis=0)
[docs]
def plot_ranks(
    predictions,  # model predictions
    targets,  # ground truth
    variable_names=None,  # list of variable names
    filename="ranks.png",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Draw one rank histogram per variable for an ensemble of predictions.

    A perfectly calibrated ensemble produces a flat histogram; the dashed
    red line marks the corresponding uniform frequency 1/(ensemble_size+1).

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [ensemble_size, batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Names of the variables for subplot titles.
        If None, uses ["VAR_0", "VAR_1", ...]
    filename : str, optional
        Output filename
    save_dir : str, optional
        Directory to save the plot
    figsize_multiplier : int, optional
        Base size multiplier for subplots

    Returns
    -------
    save_path : str
        Path to the saved figure
    """
    # Work on numpy arrays regardless of the input container.
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    ensemble_size, batch_size, num_vars, h, w = predictions.shape
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(num_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(v) for v in variable_names]
    # One square panel per variable, laid out on a single row.
    fig, axes = plt.subplots(
        1,
        num_vars,
        figsize=(num_vars * figsize_multiplier, figsize_multiplier),
        constrained_layout=True,
    )
    axes = axes.ravel() if num_vars > 1 else np.array([axes])
    for ax in axes:
        ax.set_box_aspect(1)
    # Expected bar height for a calibrated ensemble (flat histogram).
    uniform_freq = 1 / (ensemble_size + 1)
    for idx, plot_name in enumerate(plot_variable_names):
        ax = axes[idx]
        rank_values = ranks(
            predictions=predictions[:, :, idx, :, :],
            targets=targets[:, idx, :, :],
        )
        ax.hist(rank_values, bins=np.arange(ensemble_size + 2), density=True)
        ax.plot(
            [0, ensemble_size + 1],
            [uniform_freq, uniform_freq],
            linestyle="--",
            color="red",
        )
        ax.set_title(plot_name)
        ax.set_xlabel("ranks")
        ax.set_ylabel("frequency")
    # Persist the figure and release its resources.
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def get_divergence(u_tensor, v_tensor, spacing):
    """
    Compute the horizontal divergence du/dx + dv/dy of a windfield.

    Gradients are taken with torch.gradient: the u-gradient along dim=-2
    and the v-gradient along dim=-1.
    NOTE(review): this treats dim=-2 as the zonal (x) axis, matching the
    docstring's "(longitude, latitude)" ordering of the last two dims;
    callers that pass [batch, lat, lon] arrays should confirm this is the
    intended convention.

    Parameters
    ----------
    u_tensor : torch.Tensor or np.array, shape [...,h,w]
        Zonal wind component; arbitrary leading dimensions, same shape
        as v_tensor.
    v_tensor : torch.Tensor or np.array
        Meridional wind component; same shape as u_tensor.
    spacing : float
        Grid resolution used for the finite-difference gradients.

    Returns
    -------
    np.ndarray(np.float64) of same shape as u_tensor and v_tensor
    """
    # torch.gradient needs tensors; wrap numpy inputs without copying.
    if isinstance(u_tensor, np.ndarray):
        u_tensor = torch.from_numpy(u_tensor)
    if isinstance(v_tensor, np.ndarray):
        v_tensor = torch.from_numpy(v_tensor)
    (du_dx,) = torch.gradient(u_tensor, spacing=spacing, dim=-2)
    (dv_dy,) = torch.gradient(v_tensor, spacing=spacing, dim=-1)
    divergence = du_dx + dv_dy
    return divergence.detach().cpu().numpy()
[docs]
def get_curl(u_tensor, v_tensor, spacing):
    """
    Compute the vertical curl component dv/dx - du/dy of a windfield.

    Gradients are taken with torch.gradient: the v-gradient along dim=-2
    and the u-gradient along dim=-1.
    NOTE(review): as in get_divergence, dim=-2 is treated as the zonal (x)
    axis per the docstring's "(longitude, latitude)" ordering — callers
    passing [batch, lat, lon] arrays should confirm this convention.

    Parameters
    ----------
    u_tensor : torch.Tensor or np.array, shape [...,h,w]
        Zonal wind component; arbitrary leading dimensions, same shape
        as v_tensor.
    v_tensor : torch.Tensor or np.array
        Meridional wind component; same shape as u_tensor.
    spacing : float
        Grid resolution used for the finite-difference gradients.

    Returns
    -------
    np.ndarray(np.float64) of same shape as u_tensor and v_tensor
    """
    # torch.gradient needs tensors; wrap numpy inputs without copying.
    if isinstance(u_tensor, np.ndarray):
        u_tensor = torch.from_numpy(u_tensor)
    if isinstance(v_tensor, np.ndarray):
        v_tensor = torch.from_numpy(v_tensor)
    (du_dy,) = torch.gradient(u_tensor, spacing=spacing, dim=-1)
    (dv_dx,) = torch.gradient(v_tensor, spacing=spacing, dim=-2)
    curl = dv_dx - du_dy
    return curl.detach().cpu().numpy()
[docs]
def plot_mean_divergence_map(
    u_prediction,  # predicted zonal wind component
    v_prediction,  # predicted meridional wind component
    u_target,  # true zonal wind component
    v_target,  # true meridional wind component
    spacing,
    lat_1d,
    lon_1d,
    filename="mean_divergence.png",
    save_dir=None,
    figsize_multiplier=None,  # kept for API compatibility; not used
):
    """
    Plot time-mean horizontal divergence maps of predicted and target winds.

    Produces three vertically stacked map panels — Target, Prediction, and
    Predicted - Target — sharing a symmetric diverging color scale centered
    on zero, with one colorbar for the two data panels and a second one for
    the error panel.

    Parameters
    ----------
    u_prediction : torch.Tensor or np.array
        Model predictions of shape [batch_size, h, w] for the zonal wind.
        Must have the same shape as v_prediction.
    v_prediction : torch.Tensor or np.array
        Model predictions of shape [batch_size, h, w] for the meridional wind.
    u_target : torch.Tensor or np.array
        Ground truth of shape [batch_size, h, w] for the zonal wind.
        Must have the same shape as v_target.
    v_target : torch.Tensor or np.array
        Ground truth of shape [batch_size, h, w] for the meridional wind.
    spacing : float
        Spatial resolution of the windfield, used to compute the gradients.
    lat_1d : array-like
        1D array of latitude coordinates with shape [H].
    lon_1d : array-like
        1D array of longitude coordinates with shape [W].
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot; defaults to PlotConfig.DEFAULT_SAVE_DIR.
    figsize_multiplier : int, optional
        Unused; retained for interface compatibility with sibling plotters.

    Returns
    -------
    save_path : str
        Path to the saved figure.
    """

    def _draw_panel(ax, field, title, norm, cmap, lon, lat, extent):
        # One filled map panel: mesh + coastlines/borders/lakes, no ticks.
        mesh = ax.pcolormesh(
            lon,
            lat,
            field,
            norm=norm,
            cmap=cmap,
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        ax.set_extent(extent, crs=ccrs.PlateCarree())
        ax.coastlines(linewidth=0.6)
        ax.add_feature(
            cfeature.BORDERS.with_scale("50m"),
            linewidth=0.6,
            linestyle="--",
            edgecolor="black",
            zorder=11,
        )
        ax.add_feature(
            cfeature.LAKES.with_scale("50m"),
            edgecolor="black",
            facecolor="none",
            linewidth=0.6,
            zorder=9,
        )
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title)
        return mesh

    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if figsize_multiplier is None:
        figsize_multiplier = PlotConfig.DEFAULT_FIGSIZE_MULTIPLIER
    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    _, h, w = u_target.shape
    # Plotting grid: latitude decreasing along rows to match array layout.
    lat_block = np.linspace(lat_max, lat_min, h)
    lon_block = np.linspace(lon_min, lon_max, w)
    lat, lon = np.meshgrid(lat_block, lon_block, indexing="ij")
    lon_center = float((lon_min + lon_max) / 2)
    cmap = PlotConfig.get_colormap("divergence")
    # Convert wind components to plotting units before differentiating.
    u_prediction = PlotConfig.convert_units("wind", u_prediction)
    v_prediction = PlotConfig.convert_units("wind", v_prediction)
    u_target = PlotConfig.convert_units("wind", u_target)
    v_target = PlotConfig.convert_units("wind", v_target)
    div_prediction = get_divergence(u_prediction, v_prediction, spacing)
    div_target = get_divergence(u_target, v_target, spacing)
    mean_div_prediction = np.mean(div_prediction, axis=0)
    mean_div_target = np.mean(div_target, axis=0)
    # Symmetric color limits around zero so the diverging cmap is centered.
    vmin = min(np.min(mean_div_prediction), np.min(mean_div_target))
    vmax = max(np.max(mean_div_prediction), np.max(mean_div_target))
    vmax = max(np.abs(vmax), np.abs(vmin))
    if vmax == 0:
        # Guard: TwoSlopeNorm requires vmin < vcenter < vmax.
        vmax = np.finfo(np.float64).tiny
    norm = mcolors.TwoSlopeNorm(vmin=-vmax, vcenter=0, vmax=vmax)
    base_width_per_panel = 4.5
    base_height_per_panel = 3.0
    fig, axes = plt.subplots(
        3,
        1,
        figsize=(base_width_per_panel, 3 * base_height_per_panel),
        subplot_kw={
            "projection": ccrs.PlateCarree(central_longitude=lon_center)
        },
        gridspec_kw={"wspace": 0.1},
    )
    fig.subplots_adjust(
        top=0.9, bottom=0.1, left=0.1, right=0.9, wspace=0.1, hspace=0.1
    )
    extent = [lon_min, lon_max, lat_min, lat_max]
    im = _draw_panel(
        axes[0], mean_div_target, "Target", norm, cmap, lon, lat, extent
    )
    im = _draw_panel(
        axes[1], mean_div_prediction, "Prediction", norm, cmap, lon, lat, extent
    )
    # Shared colorbar spanning the two data panels (axes are stacked).
    pos0 = axes[0].get_position()
    pos1 = axes[1].get_position()
    cax1 = fig.add_axes(
        [0.92, pos1.y0, 0.02, pos0.y1 - pos1.y0]
    )  # [left, bottom, width, height]
    fig.colorbar(im, cax=cax1, orientation="vertical", label="divergence")
    im = _draw_panel(
        axes[2],
        mean_div_prediction - mean_div_target,
        "Predicted - Target",
        norm,
        cmap,
        lon,
        lat,
        extent,
    )
    # Separate colorbar for the error panel.
    pos2 = axes[2].get_position()
    cax2 = fig.add_axes([0.92, pos2.y0, 0.02, pos2.height])
    fig.colorbar(im, cax=cax2, orientation="vertical", label="divergence error")
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
[docs]
def plot_mean_curl_map(
    u_prediction,  # predicted zonal wind component
    v_prediction,  # predicted meridional wind component
    u_target,  # true zonal wind component
    v_target,  # true meridional wind component
    spacing,
    lat_1d,
    lon_1d,
    filename="mean_curl.png",
    save_dir=None,
    figsize_multiplier=None,  # kept for API compatibility; not used
):
    """
    Plot time-mean curl maps of predicted and target wind fields.

    Produces three vertically stacked map panels — Target, Prediction, and
    Predicted - Target — sharing a symmetric diverging color scale centered
    on zero, with one colorbar for the two data panels and a second one for
    the error panel.

    Parameters
    ----------
    u_prediction : torch.Tensor or np.array
        Model predictions of shape [batch_size, h, w] for the zonal wind.
        Must have the same shape as v_prediction.
    v_prediction : torch.Tensor or np.array
        Model predictions of shape [batch_size, h, w] for the meridional wind.
    u_target : torch.Tensor or np.array
        Ground truth of shape [batch_size, h, w] for the zonal wind.
        Must have the same shape as v_target.
    v_target : torch.Tensor or np.array
        Ground truth of shape [batch_size, h, w] for the meridional wind.
    spacing : float
        Spatial resolution of the windfield, used to compute the gradients.
    lat_1d : array-like
        1D array of latitude coordinates with shape [H].
    lon_1d : array-like
        1D array of longitude coordinates with shape [W].
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot; defaults to PlotConfig.DEFAULT_SAVE_DIR.
    figsize_multiplier : int, optional
        Unused; retained for interface compatibility with sibling plotters.

    Returns
    -------
    save_path : str
        Path to the saved figure.
    """

    def _draw_curl_panel(ax, field, title, norm, cmap, lon, lat, extent):
        # One filled map panel: mesh + coastlines/borders/lakes, no ticks.
        mesh = ax.pcolormesh(
            lon,
            lat,
            field,
            norm=norm,
            cmap=cmap,
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        ax.set_extent(extent, crs=ccrs.PlateCarree())
        ax.coastlines(linewidth=0.6)
        ax.add_feature(
            cfeature.BORDERS.with_scale("50m"),
            linewidth=0.6,
            linestyle="--",
            edgecolor="black",
            zorder=11,
        )
        ax.add_feature(
            cfeature.LAKES.with_scale("50m"),
            edgecolor="black",
            facecolor="none",
            linewidth=0.6,
            zorder=9,
        )
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title)
        return mesh

    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if figsize_multiplier is None:
        figsize_multiplier = PlotConfig.DEFAULT_FIGSIZE_MULTIPLIER
    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    _, h, w = u_target.shape
    # Plotting grid: latitude decreasing along rows to match array layout.
    lat_block = np.linspace(lat_max, lat_min, h)
    lon_block = np.linspace(lon_min, lon_max, w)
    lat, lon = np.meshgrid(lat_block, lon_block, indexing="ij")
    lon_center = float((lon_min + lon_max) / 2)
    cmap = PlotConfig.get_colormap("curl")
    # Convert wind components to plotting units before differentiating.
    u_prediction = PlotConfig.convert_units("wind", u_prediction)
    v_prediction = PlotConfig.convert_units("wind", v_prediction)
    u_target = PlotConfig.convert_units("wind", u_target)
    v_target = PlotConfig.convert_units("wind", v_target)
    curl_prediction = get_curl(u_prediction, v_prediction, spacing)
    curl_target = get_curl(u_target, v_target, spacing)
    mean_curl_prediction = np.mean(curl_prediction, axis=0)
    mean_curl_target = np.mean(curl_target, axis=0)
    # Symmetric color limits around zero so the diverging cmap is centered.
    vmin = min(np.min(mean_curl_prediction), np.min(mean_curl_target))
    vmax = max(np.max(mean_curl_prediction), np.max(mean_curl_target))
    vmax = max(np.abs(vmax), np.abs(vmin))
    if vmax == 0:
        # Guard: TwoSlopeNorm requires vmin < vcenter < vmax.
        vmax = np.finfo(np.float64).tiny
    norm = mcolors.TwoSlopeNorm(vmin=-vmax, vcenter=0, vmax=vmax)
    base_width_per_panel = 4.5
    base_height_per_panel = 3.0
    fig, axes = plt.subplots(
        3,
        1,
        figsize=(base_width_per_panel, 3 * base_height_per_panel),
        subplot_kw={
            "projection": ccrs.PlateCarree(central_longitude=lon_center)
        },
        gridspec_kw={"wspace": 0.1},
    )
    fig.subplots_adjust(
        top=0.9, bottom=0.1, left=0.1, right=0.9, wspace=0.1, hspace=0.1
    )
    extent = [lon_min, lon_max, lat_min, lat_max]
    im = _draw_curl_panel(
        axes[0], mean_curl_target, "Target", norm, cmap, lon, lat, extent
    )
    im = _draw_curl_panel(
        axes[1], mean_curl_prediction, "Prediction", norm, cmap, lon, lat, extent
    )
    # Shared colorbar spanning the two data panels (axes are stacked).
    pos0 = axes[0].get_position()
    pos1 = axes[1].get_position()
    cax1 = fig.add_axes(
        [0.92, pos1.y0, 0.02, pos0.y1 - pos1.y0]
    )  # [left, bottom, width, height]
    fig.colorbar(im, cax=cax1, orientation="vertical", label="curl")
    im = _draw_curl_panel(
        axes[2],
        mean_curl_prediction - mean_curl_target,
        "Predicted - Target",
        norm,
        cmap,
        lon,
        lat,
        extent,
    )
    # Separate colorbar for the error panel.
    pos2 = axes[2].get_position()
    cax2 = fig.add_axes([0.92, pos2.y0, 0.02, pos2.height])
    fig.colorbar(im, cax=cax2, orientation="vertical", label="curl error")
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
# ============================================================================
# Plotting Functions Test Suite
# ============================================================================
[docs]
class TestPlottingFunctions(unittest.TestCase):
"""Unit tests for plotting functions with visible output for styling adjustment."""
[docs]
def __init__(self, methodName="runTest", logger=None):
    """Create the test case, optionally attaching a logger for progress output."""
    super().__init__(methodName)
    # Optional logging.Logger-like object; tests only call .info() on it
    # and skip logging entirely when it is None.
    self.logger = logger
[docs]
def setUp(self):
    """Set up test fixtures.

    Builds the output directory, synthetic prediction/target/coarse arrays
    with shared spatial patterns, lat/lon axes, a synthetic metric history,
    and synthetic loss curves.

    NOTE: all random draws happen in a fixed order under np.random.seed(42);
    do not reorder statements in this method or fixture values change.
    """
    self.output_dir = "./test_plots"
    os.makedirs(self.output_dir, exist_ok=True)
    # Generate realistic synthetic test data
    np.random.seed(42)
    self.batch_size = 50
    self.num_vars = 4
    self.h = 64
    self.w = 64
    if self.logger:
        self.logger.info(
            f"Test setup complete - output directory: {self.output_dir}"
        )
        self.logger.info(
            f"Batch size: {self.batch_size}, Variables: {self.num_vars}, Resolution: {self.h}x{self.w}"
        )
    # Create correlated data for realistic plots
    x = np.linspace(0, 4 * np.pi, self.w)
    y = np.linspace(0, 4 * np.pi, self.h)
    X, Y = np.meshgrid(x, y)
    # Four distinct spatial patterns, cycled across variables below.
    patterns = [
        np.sin(X) * np.cos(Y),
        np.exp(-0.1 * (X - 10) ** 2 - 0.1 * (Y - 10) ** 2),
        X * Y / 100,
        np.sin(0.5 * X) * np.cos(0.5 * Y) + 0.5 * np.sin(2 * X) * np.cos(2 * Y),
    ]
    self.predictions = np.zeros((self.batch_size, self.num_vars, self.h, self.w))
    self.targets = np.zeros((self.batch_size, self.num_vars, self.h, self.w))
    self.coarse_inputs = np.zeros((self.batch_size, self.num_vars, self.h, self.w))
    # Same base pattern per variable; per-sample noise/scale/offset make
    # predictions, targets and coarse inputs correlated but distinct.
    for i in range(self.num_vars):
        base_pattern = patterns[i % len(patterns)]
        for b in range(self.batch_size):
            noise_pred = np.random.normal(0, 0.1, (self.h, self.w))
            noise_target = np.random.normal(0, 0.1, (self.h, self.w))
            noise_coarse = np.random.normal(0, 0.2, (self.h, self.w))
            scale = 1.0 + 0.1 * np.random.random()
            offset = 0.1 * np.random.random()
            self.predictions[b, i] = base_pattern * scale + offset + noise_pred
            self.targets[b, i] = (
                base_pattern * (scale + 0.05) + offset + 0.05 + noise_target
            )
            self.coarse_inputs[b, i] = (
                base_pattern * (scale - 0.1) + offset - 0.1 + noise_coarse
            )
    self.variable_names = [
        "Temp",
        "Press",
        "Humid",
        "Wind",
    ]
    # Create lat/lon arrays for spatial tests
    self.lat = np.linspace(-90, 90, self.h)
    self.lon = np.linspace(-180, 180, self.w)
    # Create comprehensive metrics history
    # Keys follow "{var}_{pred|coarse}_vs_fine_{metric}"; r2 improves
    # (increases) while rmse/mae decay over the 10 fake epochs.
    self.valid_metrics_history = {}
    metrics = ["rmse", "mae", "r2"]
    for var in self.variable_names:
        var_key = var.split(" ")[0]
        for metric in metrics:
            base_val_pred = 0.8 if metric == "r2" else 1.0
            base_val_coarse = 0.6 if metric == "r2" else 1.5
            decay = np.linspace(0, 0.3, 10)
            if metric == "r2":
                self.valid_metrics_history[f"{var_key}_pred_vs_fine_{metric}"] = (
                    base_val_pred + decay
                )
                self.valid_metrics_history[f"{var_key}_coarse_vs_fine_{metric}"] = (
                    base_val_coarse + decay * 0.5
                )
            else:
                self.valid_metrics_history[f"{var_key}_pred_vs_fine_{metric}"] = (
                    base_val_pred - decay
                )
                self.valid_metrics_history[f"{var_key}_coarse_vs_fine_{metric}"] = (
                    base_val_coarse - decay * 0.5
                )
    # Add average metrics
    for metric in metrics:
        self.valid_metrics_history[f"average_pred_vs_fine_{metric}"] = (
            0.1 + np.linspace(0, 0.2, 10)
        )
        self.valid_metrics_history[f"average_coarse_vs_fine_{metric}"] = (
            0.7 + np.linspace(0, 0.2, 10)
        )
    # Loss histories: exponentially decaying curves plus noise.
    self.train_loss_history = np.exp(
        -np.linspace(0, 2, 20)
    ) + 0.1 * np.random.random(20)
    self.valid_loss_history = np.exp(
        -np.linspace(0, 1.5, 20)
    ) + 0.15 * np.random.random(20)
# ============================================================================
# SINGLE COMPREHENSIVE TEST FOR EACH DIAGNOSTIC METHOD
# ============================================================================
[docs]
def test_validation_hexbin_comprehensive(self):
    """Comprehensive test for validation hexbin plots."""
    if self.logger:
        self.logger.info("Testing validation hexbin plots comprehensively")

    def _assert_written(path):
        # Every plotting call must return the path of a file it created.
        self.assertTrue(os.path.exists(path), f"File not found: {path}")

    # Test 1: standard configuration with numpy inputs
    _assert_written(
        plot_validation_hexbin(
            predictions=self.predictions,
            targets=self.targets,
            variable_names=self.variable_names,
            save_dir=self.output_dir,
            filename="validation_hexbin_standard.png",
            figsize_multiplier=5,
        )
    )
    _assert_written(
        plot_comparison_hexbin(
            predictions=self.predictions,
            targets=self.targets,
            coarse_inputs=self.coarse_inputs,
            variable_names=self.variable_names,
            filename="comparison_hexbin_standard.png",
            save_dir=self.output_dir,
        )
    )
    # Test 2: same data as PyTorch tensors
    predictions_tensor = torch.from_numpy(self.predictions)
    targets_tensor = torch.from_numpy(self.targets)
    coarse_tensor = torch.from_numpy(self.coarse_inputs)
    _assert_written(
        plot_validation_hexbin(
            predictions=predictions_tensor,
            targets=targets_tensor,
            variable_names=self.variable_names,
            save_dir=self.output_dir,
            filename="validation_hexbin_torch.png",
        )
    )
    _assert_written(
        plot_comparison_hexbin(
            predictions=predictions_tensor,
            targets=targets_tensor,
            coarse_inputs=coarse_tensor,
            variable_names=self.variable_names,
            filename="comparison_hexbin_torch.png",
            save_dir=self.output_dir,
        )
    )
    # Test 3: single-variable slices
    single_pred = self.predictions[:, 0:1, :, :]
    single_target = self.targets[:, 0:1, :, :]
    single_coarse = self.coarse_inputs[:, 0:1, :, :]
    _assert_written(
        plot_validation_hexbin(
            predictions=single_pred,
            targets=single_target,
            variable_names=[self.variable_names[0]],
            save_dir=self.output_dir,
            filename="validation_hexbin_single.png",
            figsize_multiplier=6,
        )
    )
    _assert_written(
        plot_comparison_hexbin(
            predictions=single_pred,
            targets=single_target,
            coarse_inputs=single_coarse,
            variable_names=[self.variable_names[0]],
            filename="comparison_hexbin_single.png",
            save_dir=self.output_dir,
        )
    )
    if self.logger:
        self.logger.info("✅ All validation hexbin tests passed")
[docs]
def test_validation_pdfs_comprehensive(self):
    """Comprehensive test for validation PDF plots."""
    if self.logger:
        self.logger.info("Testing validation PDF plots comprehensively")

    def _assert_written(path):
        # Every plotting call must return the path of a file it created.
        self.assertTrue(os.path.exists(path), f"File not found: {path}")

    # Test 1: standard configuration including coarse inputs
    _assert_written(
        plot_validation_pdfs(
            predictions=self.predictions,
            targets=self.targets,
            coarse_inputs=self.coarse_inputs,
            variable_names=self.variable_names,
            save_dir=self.output_dir,
            filename="validation_pdfs_standard.png",
        )
    )
    # Test 2: coarse inputs omitted
    _assert_written(
        plot_validation_pdfs(
            predictions=self.predictions,
            targets=self.targets,
            coarse_inputs=None,
            variable_names=self.variable_names,
            save_dir=self.output_dir,
            filename="validation_pdfs_no_coarse.png",
        )
    )
    # Test 3: same data as PyTorch tensors
    predictions_tensor = torch.from_numpy(self.predictions)
    targets_tensor = torch.from_numpy(self.targets)
    coarse_tensor = torch.from_numpy(self.coarse_inputs)
    _assert_written(
        plot_validation_pdfs(
            predictions=predictions_tensor,
            targets=targets_tensor,
            coarse_inputs=coarse_tensor,
            variable_names=self.variable_names,
            save_dir=self.output_dir,
            filename="validation_pdfs_torch.png",
        )
    )
    if self.logger:
        self.logger.info("✅ All validation PDF tests passed")
[docs]
def test_power_spectra_comprehensive(self):
    """Comprehensive test for power spectra plots."""
    if self.logger:
        self.logger.info("Testing power spectra plots comprehensively")

    def _assert_written(path):
        # Every plotting call must return the path of a file it created.
        self.assertTrue(os.path.exists(path), f"File not found: {path}")

    # Grid resolution derived from the fixture lat/lon axes
    dlat = np.abs(self.lat[1] - self.lat[0])
    dlon = np.abs(self.lon[1] - self.lon[0])
    # Test 1: standard configuration with coarse inputs
    _assert_written(
        plot_power_spectra(
            predictions=self.predictions,
            targets=self.targets,
            coarse_inputs=self.coarse_inputs,
            dlat=dlat,
            dlon=dlon,
            variable_names=self.variable_names,
            save_dir=self.output_dir,
            filename="power_spectra_standard.png",
        )
    )
    # Test 2: coarse inputs omitted
    _assert_written(
        plot_power_spectra(
            predictions=self.predictions,
            targets=self.targets,
            coarse_inputs=None,
            dlat=dlat,
            dlon=dlon,
            variable_names=self.variable_names,
            save_dir=self.output_dir,
            filename="power_spectra_no_coarse.png",
        )
    )
    # Test 3: same data as PyTorch tensors
    predictions_tensor = torch.from_numpy(self.predictions)
    targets_tensor = torch.from_numpy(self.targets)
    coarse_tensor = torch.from_numpy(self.coarse_inputs)
    _assert_written(
        plot_power_spectra(
            predictions=predictions_tensor,
            targets=targets_tensor,
            coarse_inputs=coarse_tensor,
            dlat=dlat,
            dlon=dlon,
            variable_names=self.variable_names,
            save_dir=self.output_dir,
            filename="power_spectra_torch.png",
        )
    )
    if self.logger:
        self.logger.info("✅ All power spectra tests passed")
[docs]
def test_spatiotemporal_histograms_comprehensive(self):
    """Comprehensive test for spatiotemporal histograms.

    Builds four clusters of random (lat, lon, time) samples and checks
    that plot_spatiotemporal_histograms writes its output file.
    NOTE: random draws rely on the seed set in setUp; keep draw order.
    """
    if self.logger:
        self.logger.info("Testing spatiotemporal histograms comprehensively")

    # Minimal stand-in for the grid-steps object the plotter receives.
    # NOTE(review): assumes only `latitude`/`longitude` attributes are
    # read by plot_spatiotemporal_histograms — confirm against its code.
    class MockSteps:
        latitude = 180
        longitude = 360

    steps = MockSteps()
    # Test 1: Dense data
    tindex_lim = (0, 365)
    # n_samples = 2000
    centers = []
    tindices = []
    # Four spatio-temporal clusters; "n" is the sample count per cluster.
    clusters = [
        {
            "lat_range": (30, 60),
            "lon_range": (200, 250),
            "time_range": (0, 100),
            "n": 500,
        },
        {
            "lat_range": (10, 40),
            "lon_range": (100, 150),
            "time_range": (100, 200),
            "n": 400,
        },
        {
            "lat_range": (50, 80),
            "lon_range": (300, 350),
            "time_range": (200, 300),
            "n": 600,
        },
        {
            "lat_range": (0, 30),
            "lon_range": (50, 100),
            "time_range": (300, 365),
            "n": 500,
        },
    ]
    # Draw one (lat, lon, tindex) triple per sample, per cluster.
    for cluster in clusters:
        for _ in range(cluster["n"]):
            lat = np.random.randint(
                cluster["lat_range"][0], cluster["lat_range"][1]
            )
            lon = np.random.randint(
                cluster["lon_range"][0], cluster["lon_range"][1]
            )
            tindex = np.random.randint(
                cluster["time_range"][0], cluster["time_range"][1]
            )
            centers.append((lat, lon))
            tindices.append(tindex)
    expected_path = plot_spatiotemporal_histograms(
        steps=steps,
        tindex_lim=tindex_lim,
        centers=centers,
        tindices=tindices,
        mode="train",
        filename="spatiotemporal_dense_",
        save_dir=self.output_dir,
    )
    self.assertTrue(
        os.path.exists(expected_path), f"File not found: {expected_path}"
    )
    if self.logger:
        self.logger.info("✅ All spatiotemporal histogram tests passed")
[docs]
def test_plot_surface_comprehensive(self):
    """Comprehensive test for surface plots.

    Exercises plot_surface with numpy arrays and with torch tensors on a
    small regional grid; asserts the output files exist.
    NOTE: random draws rely on the seed set in setUp; keep draw order.
    """
    if self.logger:
        self.logger.info("Testing surface plots comprehensively")
    # Test case 1: Standard configuration
    lat_1d = np.linspace(30, 50, 48)
    lon_1d = np.linspace(-120, -80, 68)
    # Create synthetic data
    batch_size = 1
    n_vars = 3
    h, w = 48, 68
    # Create spatial patterns
    x = np.linspace(0, 3 * np.pi, w)
    y = np.linspace(0, 3 * np.pi, h)
    X, Y = np.meshgrid(x, y)
    # Initialize arrays
    coarse_inputs = np.zeros((batch_size, n_vars, h, w))
    targets = np.zeros((batch_size, n_vars, h, w))
    pred = np.zeros((batch_size, n_vars, h, w))
    # Sinusoid, Gaussian blob, and linear-gradient base fields.
    base_patterns = [
        np.sin(X / 2) * np.cos(Y / 2),
        np.exp(-0.01 * (X - 24) ** 2 - 0.01 * (Y - 24) ** 2),
        X * Y / 200,
    ]
    # Noise amplitude decreases coarse -> target -> prediction, so the
    # prediction is the closest to the target.
    for i in range(n_vars):
        base_pattern = base_patterns[i % len(base_patterns)]
        pattern = base_pattern * 20 + 280  # Temperature-like
        coarse_inputs[0, i] = pattern + np.random.randn(h, w) * 2
        targets[0, i] = pattern + np.random.randn(h, w) * 1
        pred[0, i] = targets[0, i] + np.random.randn(h, w) * 0.3
    variable_names = ["Temp", "Press", "Humid"]
    timestamp = datetime(2024, 1, 1, 12, 0)
    # Test with numpy arrays
    expected_path = plot_surface(
        coarse_inputs=coarse_inputs,
        targets=targets,
        predictions=pred,
        lat_1d=lat_1d,
        lon_1d=lon_1d,
        timestamp=timestamp,
        variable_names=variable_names,
        filename="plot_surface_standard.png",
        save_dir=self.output_dir,
    )
    self.assertTrue(
        os.path.exists(expected_path), f"File not found: {expected_path}"
    )
    # Test with PyTorch tensors
    coarse_inputs_tensor = torch.from_numpy(coarse_inputs.copy())
    targets_tensor = torch.from_numpy(targets.copy())
    pred_tensor = torch.from_numpy(pred.copy())
    expected_path = plot_surface(
        coarse_inputs=coarse_inputs_tensor,
        targets=targets_tensor,
        predictions=pred_tensor,
        lat_1d=lat_1d,
        lon_1d=lon_1d,
        timestamp=timestamp,
        variable_names=variable_names,
        filename="plot_surface_torch.png",
        save_dir=self.output_dir,
    )
    self.assertTrue(
        os.path.exists(expected_path), f"File not found: {expected_path}"
    )
    if self.logger:
        self.logger.info("✅ All surface plot tests passed")
[docs]
def test_plot_ensemble_surface_comprehensive(self):
    """Comprehensive test for ensemble surface plots.

    Builds a synthetic 5-member ensemble with increasing per-member noise
    and checks plot_ensemble_surface writes a non-empty file for both
    numpy and torch inputs.
    NOTE: random draws rely on the seed set in setUp; keep draw order.
    """
    if self.logger:
        self.logger.info("Testing ensemble surface plots comprehensively")
    lat_1d = np.linspace(30, 50, 48)
    lon_1d = np.linspace(-120, -80, 68)
    # Ensemble configuration
    N_ens = 5
    n_vars = 3
    H, W = 48, 68
    x = np.linspace(0, 3 * np.pi, W)
    y = np.linspace(0, 3 * np.pi, H)
    X, Y = np.meshgrid(x, y)
    # different types of spatial patterns
    base_patterns = [
        np.sin(X / 2) * np.cos(Y / 2),  # sinusoidal pattern
        np.exp(-0.01 * (X - 24) ** 2 - 0.01 * (Y - 24) ** 2),  # Gaussian blob
        X * Y / 200,  # linear gradient
    ]
    # ensemble array: [N_ens, n_vars, H, W]
    predictions_ens = np.zeros((N_ens, n_vars, H, W))
    # Generate ensemble members
    for k in range(N_ens):
        for i in range(n_vars):
            # Select a base spatial pattern for each variable
            base = base_patterns[i % len(base_patterns)]
            # Scale to realistic physical values
            signal = base * 20 + 280
            # Add Gaussian noise to simulate ensemble spread
            # (noise std grows with member index k)
            predictions_ens[k, i] = signal + np.random.randn(H, W) * (0.5 + k * 0.2)
    variable_names = ["Temp", "Press", "Humid"]
    timestamp = datetime(2024, 1, 1, 12, 0)
    # Test 1: numpy
    expected_path = plot_ensemble_surface(
        predictions_ens=predictions_ens,
        lat_1d=lat_1d,
        lon_1d=lon_1d,
        variable_names=variable_names,
        timestamp=timestamp,
        filename="plot_ensemble_surface_numpy.png",
        save_dir=self.output_dir,
    )
    self.assertTrue(os.path.exists(expected_path))
    self.assertGreater(os.path.getsize(expected_path), 0)
    # Test 2: torch
    expected_path = plot_ensemble_surface(
        predictions_ens=torch.from_numpy(predictions_ens),
        lat_1d=lat_1d,
        lon_1d=lon_1d,
        variable_names=variable_names,
        timestamp=timestamp,
        filename="plot_ensemble_surface_torch.png",
        save_dir=self.output_dir,
    )
    self.assertTrue(os.path.exists(expected_path))
    self.assertGreater(os.path.getsize(expected_path), 0)
    if self.logger:
        self.logger.info("✅ All ensemble surface plot tests passed")
[docs]
def test_plot_zoom_comparison_comprehensive(self):
    """Comprehensive test for zoom comparison plots.

    Uses a global grid with a zoom box over a sub-region and checks that
    plot_zoom_comparison writes a non-empty file for numpy and torch
    inputs.
    NOTE: random draws rely on the seed set in setUp; keep draw order.
    """
    if self.logger:
        self.logger.info("Testing zoom comparison plots comprehensively")
    # Grid
    lat_1d = np.linspace(-90, 90, 144)
    lon_1d = np.linspace(0, 360, 360, endpoint=False)
    batch_size = 1
    n_vars = 3
    H, W = 144, 360
    targets = np.zeros((batch_size, n_vars, H, W))
    predictions = np.zeros((batch_size, n_vars, H, W))
    # Near-constant fields per variable with small noise; prediction noise
    # is half the target noise.
    for i in range(n_vars):
        base = np.ones((H, W)) * (280 + i * 5)
        targets[0, i] = base + np.random.randn(H, W) * 1.0
        predictions[0, i] = targets[0, i] + np.random.randn(H, W) * 0.5
    variable_names = ["Temp", "Press", "Humid"]
    # Zoom region in degrees (longitudes on the 0-360 convention).
    zoom_box = {
        "lat_min": -23,
        "lat_max": 13,
        "lon_min": 255,
        "lon_max": 345,
    }
    # Test 1: numpy
    expected_path = plot_zoom_comparison(
        predictions=predictions,
        targets=targets,
        lat_1d=lat_1d,
        lon_1d=lon_1d,
        variable_names=variable_names,
        filename="plot_zoom_numpy.png",
        save_dir=self.output_dir,
        zoom_box=zoom_box,
    )
    self.assertTrue(os.path.exists(expected_path))
    self.assertGreater(os.path.getsize(expected_path), 0)
    # Test 2: torch
    expected_path = plot_zoom_comparison(
        predictions=torch.from_numpy(predictions),
        targets=torch.from_numpy(targets),
        lat_1d=lat_1d,
        lon_1d=lon_1d,
        variable_names=variable_names,
        filename="plot_zoom_torch.png",
        save_dir=self.output_dir,
        zoom_box=zoom_box,
    )
    self.assertTrue(os.path.exists(expected_path))
    self.assertGreater(os.path.getsize(expected_path), 0)
    if self.logger:
        self.logger.info("✅ All zoom comparison plot tests passed")
[docs]
def test_plot_global_surface_robinson_comprehensive(self):
    """Comprehensive test for global Robinson surface plots."""
    if self.logger:
        self.logger.info("Testing global Robinson surface plots comprehensively")
    # GLOBAL domain, 90 latitude rows x 180 longitude columns.
    height, width = 90, 180
    lat_1d = np.linspace(-90, 90, height)
    lon_1d = np.linspace(-180, 180, width)
    n_vars = 3
    # Smooth synthetic global patterns, perturbed per field.
    x = np.linspace(-np.pi, np.pi, width)
    y = np.linspace(-np.pi / 2, np.pi / 2, height)
    X, Y = np.meshgrid(x, y)
    coarse_inputs = np.zeros((1, n_vars, height, width))
    targets = np.zeros((1, n_vars, height, width))
    pred = np.zeros((1, n_vars, height, width))
    base_patterns = [np.sin(X) * np.cos(Y), np.cos(2 * X) * np.sin(Y), X * Y]
    for var_idx in range(n_vars):
        pattern = base_patterns[var_idx] * 10 + 280
        coarse_inputs[0, var_idx] = pattern + np.random.randn(height, width) * 2
        targets[0, var_idx] = pattern + np.random.randn(height, width)
        pred[0, var_idx] = targets[0, var_idx] + np.random.randn(height, width) * 0.3
    variable_names = ["Temp", "Press", "Humid"]
    timestamp = datetime(2024, 1, 1, 12, 0)
    # Exercise both the numpy and the torch input code paths.
    for out_name, as_input in (
        ("plot_global_robinson_standard.png", lambda arr: arr),
        ("plot_global_robinson_torch.png", torch.from_numpy),
    ):
        expected_path = plot_global_surface_robinson(
            predictions=as_input(pred),
            targets=as_input(targets),
            coarse_inputs=as_input(coarse_inputs),
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            timestamp=timestamp,
            variable_names=variable_names,
            filename=out_name,
            save_dir=self.output_dir,
        )
        self.assertTrue(os.path.exists(expected_path))
    if self.logger:
        self.logger.info("✅ All global Robinson surface plot tests passed")
[docs]
def test_plot_mae_map_comprehensive(self):
    """Comprehensive test for time-averaged MAE spatial map plots."""
    if self.logger:
        self.logger.info("Testing MAE map plots comprehensively")
    # Regional grid: 48 latitudes x 68 longitudes.
    lat_1d = np.linspace(30, 50, 48)
    lon_1d = np.linspace(-120, -80, 68)
    # Crop fixtures to the matching spatial resolution.
    predictions = self.predictions[:, :, :48, :68]
    targets = self.targets[:, :, :48, :68]
    # (filename, predictions, targets, variable names) per scenario:
    # standard numpy, torch tensors, single-variable edge case.
    cases = (
        (
            "validation_mae_map_standard.png",
            predictions,
            targets,
            self.variable_names,
        ),
        (
            "validation_mae_map_torch.png",
            torch.from_numpy(predictions),
            torch.from_numpy(targets),
            self.variable_names,
        ),
        (
            "validation_mae_map_single_var.png",
            predictions[:, 0:1],
            targets[:, 0:1],
            [self.variable_names[0]],
        ),
    )
    for out_name, preds, targs, names in cases:
        expected_path = plot_MAE_map(
            predictions=preds,
            targets=targs,
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            variable_names=names,
            save_dir=self.output_dir,
            filename=out_name,
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
    if self.logger:
        self.logger.info("✅ All MAE map plot tests passed")
[docs]
def test_plot_error_map_comprehensive(self):
    """Comprehensive test for time-averaged ERROR spatial map plots."""
    if self.logger:
        self.logger.info("Testing ERROR map plots comprehensively")
    # Regional grid: 48 latitudes x 68 longitudes.
    lat_1d = np.linspace(30, 50, 48)
    lon_1d = np.linspace(-120, -80, 68)
    # Crop fixtures to the matching spatial resolution.
    predictions = self.predictions[:, :, :48, :68]
    targets = self.targets[:, :, :48, :68]
    # (filename, predictions, targets, variable names) per scenario:
    # standard numpy, torch tensors, single-variable edge case.
    cases = (
        (
            "validation_error_map_standard.png",
            predictions,
            targets,
            self.variable_names,
        ),
        (
            "validation_error_map_torch.png",
            torch.from_numpy(predictions),
            torch.from_numpy(targets),
            self.variable_names,
        ),
        (
            "validation_error_map_single_var.png",
            predictions[:, 0:1],
            targets[:, 0:1],
            [self.variable_names[0]],
        ),
    )
    for out_name, preds, targs, names in cases:
        expected_path = plot_error_map(
            predictions=preds,
            targets=targs,
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            variable_names=names,
            save_dir=self.output_dir,
            filename=out_name,
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
    if self.logger:
        self.logger.info("✅ All ERROR map plot tests passed")
[docs]
def test_plot_spread_skill_ratio_map_comprehensive(self):
    """Comprehensive test for time-averaged spread-skill-ratio map plots."""
    if self.logger:
        self.logger.info("Testing SSR map plots comprehensively")
    # Regional grid: 48 latitudes x 68 longitudes.
    lat_1d = np.linspace(30, 50, 48)
    lon_1d = np.linspace(-120, -80, 68)
    # Crop fixtures to the matching spatial resolution.
    predictions = self.predictions[:, :, :48, :68]
    # Build a 10-member ensemble by perturbing the deterministic predictions.
    T, C, h, w = predictions.shape
    noise = np.random.normal(size=(10, T, C, h, w))
    predictions_ensemble = predictions + noise
    targets = self.targets[:, :, :48, :68]
    # (filename, ensemble predictions, targets, variable names) per scenario:
    # standard numpy, torch tensors, single-variable edge case.
    cases = (
        (
            "validation_ssr_map_standard.png",
            predictions_ensemble,
            targets,
            self.variable_names,
        ),
        (
            "validation_ssr_map_torch.png",
            torch.from_numpy(predictions_ensemble),
            torch.from_numpy(targets),
            self.variable_names,
        ),
        (
            "validation_ssr_map_single_var.png",
            predictions_ensemble[:, :, 0:1],
            targets[:, 0:1],
            [self.variable_names[0]],
        ),
    )
    for out_name, preds, targs, names in cases:
        expected_path = plot_spread_skill_ratio_map(
            predictions=preds,
            targets=targs,
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            variable_names=names,
            save_dir=self.output_dir,
            filename=out_name,
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
    if self.logger:
        self.logger.info("✅ All spread skill ratio map plot tests passed")
[docs]
def test_plot_spread_skill_ratio_hexbin_comprehensive(self):
    """Comprehensive test for spread skill ratio hexbin plots"""
    if self.logger:
        self.logger.info("Testing SSR hexbin plots comprehensively")
    # Crop fixtures to the matching spatial resolution.
    predictions = self.predictions[:, :, :48, :68]
    # Build a 10-member ensemble by perturbing the deterministic predictions.
    T, C, h, w = predictions.shape
    noise = np.random.normal(size=(10, T, C, h, w))
    predictions_ensemble = predictions + noise
    targets = self.targets[:, :, :48, :68]
    # (filename, ensemble predictions, targets, variable names) per scenario:
    # standard numpy, torch tensors, single-variable edge case.
    cases = (
        (
            "validation_ssr_hexbin_standard.png",
            predictions_ensemble,
            targets,
            self.variable_names,
        ),
        (
            "validation_ssr_hexbin_torch.png",
            torch.from_numpy(predictions_ensemble),
            torch.from_numpy(targets),
            self.variable_names,
        ),
        (
            "validation_ssr_hexbin_single_var.png",
            predictions_ensemble[:, :, 0:1],
            targets[:, 0:1],
            [self.variable_names[0]],
        ),
    )
    for out_name, preds, targs, names in cases:
        expected_path = plot_spread_skill_ratio_hexbin(
            predictions=preds,
            targets=targs,
            variable_names=names,
            save_dir=self.output_dir,
            filename=out_name,
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
    if self.logger:
        self.logger.info("✅ All spread skill ratio hexbin plot tests passed")
[docs]
def test_plot_mean_divergence_map_comprehensive(self):
    """Comprehensive test for mean divergence map plots."""
    if self.logger:
        self.logger.info("Testing divergence map plots comprehensively")
    # Regional grid: 48 latitudes x 68 longitudes.
    lat_1d = np.linspace(30, 50, 48)
    lon_1d = np.linspace(-120, -80, 68)
    # Treat the first two variables as the (u, v) vector components.
    u_pred = self.predictions[:, 0, :48, :68]
    v_pred = self.predictions[:, 1, :48, :68]
    u_target = self.targets[:, 0, :48, :68]
    v_target = self.targets[:, 1, :48, :68]
    # Exercise both the numpy and the torch input code paths.
    for out_name, as_input in (
        ("validation_mean_divergence_map_standard.png", lambda arr: arr),
        ("validation_mean_divergence_map_torch.png", torch.from_numpy),
    ):
        expected_path = plot_mean_divergence_map(
            as_input(u_pred),
            as_input(v_pred),
            as_input(u_target),
            as_input(v_target),
            spacing=1,
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            save_dir=self.output_dir,
            filename=out_name,
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
    if self.logger:
        self.logger.info("✅ All mean divergence map plot tests passed")
[docs]
def test_plot_mean_curl_map_comprehensive(self):
    """Comprehensive test for mean curl map plots."""
    if self.logger:
        self.logger.info("Testing curl map plots comprehensively")
    # Regional grid: 48 latitudes x 68 longitudes.
    lat_1d = np.linspace(30, 50, 48)
    lon_1d = np.linspace(-120, -80, 68)
    # Treat the first two variables as the (u, v) vector components.
    u_pred = self.predictions[:, 0, :48, :68]
    v_pred = self.predictions[:, 1, :48, :68]
    u_target = self.targets[:, 0, :48, :68]
    v_target = self.targets[:, 1, :48, :68]
    # Exercise both the numpy and the torch input code paths.
    for out_name, as_input in (
        ("validation_mean_curl_map_standard.png", lambda arr: arr),
        ("validation_mean_curl_map_torch.png", torch.from_numpy),
    ):
        expected_path = plot_mean_curl_map(
            as_input(u_pred),
            as_input(v_pred),
            as_input(u_target),
            as_input(v_target),
            spacing=1,
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            save_dir=self.output_dir,
            filename=out_name,
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
    if self.logger:
        self.logger.info("✅ All mean curl map plot tests passed")
[docs]
def test_plot_dry_frequency_map_comprehensive(self):
    """Comprehensive test for dry frequency map plots."""
    if self.logger:
        self.logger.info("Testing dry frequency map plots comprehensively")
    # Regional grid: 48 latitudes x 68 longitudes.
    lat_1d = np.linspace(30, 50, 48)
    lon_1d = np.linspace(-120, -80, 68)
    # Crop fixtures to the matching spatial resolution; use the first
    # variable as the precipitation-like field.
    predictions = self.predictions[:, :, :48, :68]
    targets = self.targets[:, :, :48, :68]
    # Exercise both the numpy and the torch input code paths.
    for out_name, as_input in (
        ("validation_dry_frequency_map_standard.png", lambda arr: arr),
        ("validation_dry_frequency_map_torch.png", torch.from_numpy),
    ):
        expected_path = plot_dry_frequency_map(
            predictions=as_input(predictions[:, 0, :, :]),
            targets=as_input(targets[:, 0, :, :]),
            threshold=1,
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            save_dir=self.output_dir,
            filename=out_name,
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
    if self.logger:
        self.logger.info("✅ All dry frequency map plot tests passed")
[docs]
def test_dry_frequency_map(self):
    """Comprehensive test for the dry frequency map compute function."""
    if self.logger:
        self.logger.info(
            "Testing dry frequency map compute function comprehensively"
        )
    predictions = self.predictions[:, :, :48, :68]
    # The time axis is reduced away: output should be a (h, w) map.
    expected_shape = predictions.shape[-2:]
    # Numpy input, then the torch code path.
    for field in (
        predictions[:, 0, :, :],
        torch.from_numpy(predictions[:, 0, :, :]),
    ):
        result = dry_frequency_map(field, 1)
        self.assertTrue(result.shape == expected_shape)
    if self.logger:
        self.logger.info("✅ All dry frequency tests passed")
[docs]
def test_divergence(self):
    """Comprehensive test for the divergence compute function."""
    if self.logger:
        self.logger.info("Testing divergence compute function comprehensively")
    u = self.predictions[:, 0, :48, :68]
    v = self.predictions[:, 1, :48, :68]
    # Both the numpy and the torch code paths must preserve the input shape.
    for u_in, v_in in ((u, v), (torch.from_numpy(u), torch.from_numpy(v))):
        result = get_divergence(u_in, v_in, spacing=1)
        self.assertTrue(result.shape == u.shape)
    if self.logger:
        self.logger.info("✅ All divergence tests passed")
[docs]
def test_curl(self):
    """Comprehensive test for the curl compute function."""
    if self.logger:
        self.logger.info("Testing curl compute function comprehensively")
    u = self.predictions[:, 0, :48, :68]
    v = self.predictions[:, 1, :48, :68]
    # Both the numpy and the torch code paths must preserve the input shape.
    for u_in, v_in in ((u, v), (torch.from_numpy(u), torch.from_numpy(v))):
        result = get_curl(u_in, v_in, spacing=1)
        self.assertTrue(result.shape == u.shape)
    if self.logger:
        self.logger.info("✅ All curl tests passed")
[docs]
def test_metric_plots_comprehensive(self):
    """Comprehensive test for metric plots."""
    if self.logger:
        self.logger.info("Testing metric plots comprehensively")
    # Variable names are "Name (unit)"; the plotters want the bare name.
    short_names = [name.split(" ")[0] for name in self.variable_names]
    metric_names = ["rmse", "mae", "r2"]
    # Per-variable metric histories.
    histories_path = plot_metric_histories(
        valid_metrics_history=self.valid_metrics_history,
        variable_names=short_names,
        metric_names=metric_names,
        save_dir=self.output_dir,
        filename="metric_histories_comprehensive",
    )
    self.assertTrue(
        os.path.exists(histories_path), f"File not found: {histories_path}"
    )
    # Train / validation loss curves.
    loss_path = plot_loss_histories(
        train_loss_history=self.train_loss_history,
        valid_loss_history=self.valid_loss_history,
        save_dir=self.output_dir,
        filename="loss_histories_standard.png",
    )
    self.assertTrue(os.path.exists(loss_path), f"File not found: {loss_path}")
    # Metrics averaged over variables.
    average_path = plot_average_metrics(
        valid_metrics_history=self.valid_metrics_history,
        metric_names=metric_names,
        save_dir=self.output_dir,
        filename="average_metrics_standard.png",
    )
    self.assertTrue(
        os.path.exists(average_path), f"File not found: {average_path}"
    )
    if self.logger:
        self.logger.info("✅ All metric plot tests passed")
[docs]
def test_plot_metrics_heatmap_comprehensive(self):
    """Comprehensive test for validation metrics heatmap."""
    if self.logger:
        self.logger.info("Testing metrics heatmap")

    class DummyMetricTracker:
        # Minimal stand-in exposing the getmean() API the heatmap expects.
        def __init__(self, values):
            self.values = np.asarray(values)
            self.count = len(self.values)

        def getmean(self):
            return float(np.mean(self.values)) if self.count > 0 else np.nan

    metric_names = ["rmse", "mae", "r2"]
    short_names = [name.split(" ")[0] for name in self.variable_names]
    # Wrap the existing synthetic histories in tracker objects, keyed the
    # same way the production metrics dict is keyed.
    valid_metrics_trackers = {
        f"{var}_pred_vs_fine_{metric}": DummyMetricTracker(
            self.valid_metrics_history[f"{var}_pred_vs_fine_{metric}"]
        )
        for var in short_names
        for metric in metric_names
    }
    expected_path = plot_metrics_heatmap(
        valid_metrics_history=valid_metrics_trackers,
        variable_names=short_names,
        metric_names=metric_names,
        save_dir=self.output_dir,
        filename="metrics_heatmap_comprehensive",
    )
    self.assertTrue(
        os.path.exists(expected_path), f"File not found: {expected_path}"
    )
    if self.logger:
        self.logger.info("✅ Metrics heatmap test passed")
[docs]
def test_qq_quantiles_comprehensive(self):
    """Comprehensive test for QQ-quantiles plots."""
    if self.logger:
        self.logger.info("Testing QQ-quantiles plots comprehensively")
    # (filename, predictions, targets, coarse inputs, names, quantiles):
    # standard numpy run, single-variable edge case, torch tensors.
    cases = (
        (
            "qq_quantiles_standard.png",
            self.predictions,
            self.targets,
            self.coarse_inputs,
            self.variable_names,
            [0.90, 0.95, 0.975, 0.99, 0.995],
        ),
        (
            "qq_quantiles_single_var.png",
            self.predictions[:, 0:1],
            self.targets[:, 0:1],
            self.coarse_inputs[:, 0:1],
            ["Temperature (K)"],
            [0.90, 0.95, 0.99],
        ),
        (
            "qq_quantiles_torch.png",
            torch.from_numpy(self.predictions),
            torch.from_numpy(self.targets),
            torch.from_numpy(self.coarse_inputs),
            self.variable_names,
            [0.90, 0.95, 0.975, 0.99, 0.995],
        ),
    )
    for out_name, preds, targs, coarse, names, quantiles in cases:
        expected_path = plot_qq_quantiles(
            predictions=preds,
            targets=targs,
            coarse_inputs=coarse,
            variable_names=names,
            quantiles=quantiles,
            save_dir=self.output_dir,
            filename=out_name,
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
    if self.logger:
        self.logger.info("✅ All QQ-quantiles tests passed")
[docs]
def test_mv_correlation(self):
    """Test for correlation over the time dimension for pairs of variables.
    Test for correlation over the spatial dimensions.
    """
    # Define lat lon grid.
    # Data layout is [T, C, H, W]: axis 2 holds the latitude rows and
    # axis 3 the longitude columns (the other tests slice
    # [:, :, :48, :68] against 48 latitudes / 68 longitudes).
    # BUGFIX: the previous code assigned w = shape[2] and h = shape[3],
    # which transposed the lat/lon grids relative to the data whenever
    # H != W.
    h = self.predictions.shape[2]
    w = self.predictions.shape[3]
    dlat = 20
    dlon = 40
    lat1 = 30
    lon1 = -120
    # np.meshgrid(x, y) (default 'xy' indexing) returns arrays of shape
    # (len(y), len(x)) == (h, w), matching the data's spatial shape.
    lon, lat = np.meshgrid(
        np.linspace(lon1, lon1 + dlon, w), np.linspace(lat1, lat1 + dlat, h)
    )
    # Test 1: Standard configuration Numpy arrays for correlation over time dimension
    expected_path = plot_validation_mvcorr(
        predictions=self.predictions,
        targets=self.targets,
        lat=lat,
        lon=lon,
        variable_names=self.variable_names,
        save_dir=self.output_dir,
        filename="validation_mvcorr_numpy.png",
        figsize_multiplier=3,
    )
    self.assertTrue(
        os.path.exists(expected_path), f"File not found: {expected_path}"
    )
    # Same, but also comparing against the coarse inputs.
    expected_path = plot_validation_mvcorr(
        predictions=self.predictions,
        targets=self.targets,
        lat=lat,
        lon=lon,
        coarse_inputs=self.coarse_inputs,
        variable_names=self.variable_names,
        save_dir=self.output_dir,
        filename="comparison_mvcorr_numpy.png",
        figsize_multiplier=3,
    )
    self.assertTrue(
        os.path.exists(expected_path), f"File not found: {expected_path}"
    )
    # Test 2: Standard configuration PyTorch tensors for correlation over time dimension
    coarse_tensor = torch.from_numpy(self.coarse_inputs.copy())
    fine_tensor = torch.from_numpy(self.targets.copy())
    pred_tensor = torch.from_numpy(self.predictions.copy())
    expected_path = plot_validation_mvcorr(
        predictions=pred_tensor,
        targets=fine_tensor,
        lat=lat,
        lon=lon,
        coarse_inputs=coarse_tensor,
        variable_names=self.variable_names,
        save_dir=self.output_dir,
        filename="comparison_mvcorr_torch.png",
        figsize_multiplier=3,
    )
    self.assertTrue(
        os.path.exists(expected_path), f"File not found: {expected_path}"
    )
    # Test 3: Standard configuration Numpy arrays for correlation over space dimensions
    expected_path = plot_validation_mvcorr_space(
        predictions=self.predictions,
        targets=self.targets,
        variable_names=self.variable_names,
        save_dir=self.output_dir,
        filename="comparison_mv_corr_space_numpy.png",
        figsize_multiplier=3,
    )
    self.assertTrue(
        os.path.exists(expected_path), f"File not found: {expected_path}"
    )
    # Test 4: Mixed inputs (numpy predictions/targets, torch coarse inputs)
    # for correlation over space dimensions.
    expected_path = plot_validation_mvcorr_space(
        predictions=self.predictions,
        targets=self.targets,
        coarse_inputs=coarse_tensor,
        variable_names=self.variable_names,
        save_dir=self.output_dir,
        filename="comparison_mvcorr_space_torch.png",
        figsize_multiplier=3,
    )
    self.assertTrue(
        os.path.exists(expected_path), f"File not found: {expected_path}"
    )
    if self.logger:
        self.logger.info("✅ All correlation plots tests passed")
[docs]
def test_temporal_series_comparison_comprehensive(self):
    """Comprehensive test for spatially averaged temporal series comparison."""
    if self.logger:
        self.logger.info("Testing temporal series comparison comprehensively")
    # Reinterpret the batch axis as a time axis of length T.
    T = 100
    C = self.num_vars
    H = self.h
    W = self.w
    time = np.linspace(0, 4 * np.pi, T)
    predictions = np.zeros((T, C, H, W))
    targets = np.zeros((T, C, H, W))
    # Fixed spatial pattern, identical for every time step and variable.
    spatial_pattern = (
        np.sin(np.linspace(0, 2 * np.pi, W))[None, :]
        * np.cos(np.linspace(0, 2 * np.pi, H))[:, None]
    )
    for c in range(C):
        for t in range(T):
            # Per-variable seasonal signal on top of the spatial pattern;
            # predictions add small Gaussian noise to the targets.
            seasonal_signal = np.sin(time[t]) * (c + 1)
            targets[t, c] = seasonal_signal + spatial_pattern
            predictions[t, c] = (
                seasonal_signal + spatial_pattern + np.random.normal(0, 0.1, (H, W))
            )
    # Exercise both the numpy and the torch input code paths.
    for out_name, as_input in (
        ("temporal_series_numpy.png", lambda arr: arr),
        ("temporal_series_torch.png", torch.from_numpy),
    ):
        expected_path = plot_temporal_series_comparison(
            predictions=as_input(predictions),
            targets=as_input(targets),
            variable_names=self.variable_names,
            save_dir=self.output_dir,
            filename=out_name,
        )
        self.assertTrue(
            os.path.exists(expected_path),
            f"File not found: {expected_path}",
        )
    # Mismatched spatial shapes must be rejected.
    with self.assertRaises(ValueError):
        plot_temporal_series_comparison(
            predictions=predictions,
            targets=targets[:, :, :-1, :],  # wrong shape
            variable_names=self.variable_names,
            save_dir=self.output_dir,
            filename="temporal_series_error.png",
        )
    if self.logger:
        self.logger.info("✅ All temporal series comparison tests passed")
[docs]
def test_ranks(self):
    """Comprehensive test for the ranks compute function."""
    if self.logger:
        self.logger.info("Testing ranks compute function comprehensively")
    predictions = self.predictions[:, :, :48, :68]
    targets = self.targets[:, 0, :48, :68]
    ensemble_size = 10
    # Fake an ensemble by repeating the first variable's field along a
    # new leading member axis.
    predictions_repeated = np.repeat(
        predictions[None, :, 0, :, :], axis=0, repeats=ensemble_size
    )
    # ranks() returns one rank per target pixel, flattened.
    expected_shape = targets.flatten().shape
    # Numpy inputs first, then the torch code path.
    result = ranks(predictions=predictions_repeated, targets=targets)
    self.assertTrue(result.shape == expected_shape)
    result_torch = ranks(
        predictions=torch.from_numpy(predictions_repeated),
        targets=torch.from_numpy(targets),
    )
    self.assertTrue(result_torch.shape == expected_shape)
    if self.logger:
        self.logger.info("✅ All ranks tests passed")
[docs]
def test_plot_ranks(self):
    """Comprehensive test for the ranks plot function."""
    if self.logger:
        self.logger.info("Testing plot_ranks function comprehensively")
    predictions = self.predictions[:, :, :48, :68]
    targets = self.targets[:, :, :48, :68]
    ensemble_size = 10
    # Test 1 : standard numpy inputs.
    # Fake an ensemble by repeating the deterministic predictions along a
    # new leading member axis.
    predictions_repeated = np.repeat(
        predictions[None, :, :, :, :], axis=0, repeats=ensemble_size
    )
    expected_path = plot_ranks(
        predictions=predictions_repeated,
        targets=targets,
        variable_names=self.variable_names,
        save_dir=self.output_dir,
        filename="ranks.png",
    )
    self.assertTrue(
        os.path.exists(expected_path), f"File not found: {expected_path}"
    )
    # Test 2 : torch inputs
    expected_path_torch = plot_ranks(
        predictions=torch.from_numpy(predictions_repeated),
        targets=torch.from_numpy(targets),
        variable_names=self.variable_names,
        save_dir=self.output_dir,
        filename="ranks_torch.png",
    )
    # BUGFIX: the failure message previously interpolated expected_path
    # (the numpy-run output) instead of the torch-run path being asserted.
    self.assertTrue(
        os.path.exists(expected_path_torch),
        f"File not found: {expected_path_torch}",
    )
    if self.logger:
        self.logger.info("✅ All ranks plot tests passed")
[docs]
def tearDown(self):
    """Clean up after tests."""
    # The output directory is intentionally kept so the generated plots
    # can be inspected after the run.
    if not self.logger:
        return
    self.logger.info(
        f"Plotting tests completed - plots available in: {self.output_dir}"
    )
[docs]
class TestSSRFunction(unittest.TestCase):
    """Unit tests for the spread_skill_ratio function."""

    def __init__(self, methodName="runTest", logger=None):
        """Standard TestCase init plus an optional logger."""
        super().__init__(methodName)
        self.logger = logger

    def setUp(self):
        """Set up test fixtures."""
        if self.logger:
            self.logger.info("Setting up spread_skill_ratio function test fixtures")

    def test_ssr_basic(self):
        """Test SSR with simple known values."""
        if self.logger:
            self.logger.info("Testing SSR basic functionality")
        np.random.seed(0)  # reproducibility
        # Targets fixed at zero (the ensemble mean); ensemble members are
        # iid N(0, 0.1) draws, shape [N_ens=10, T, C, h, w].
        true = np.zeros((100, 1, 10, 10))  # shape [T,C,h,w]
        pred_ens = np.random.normal(loc=0, scale=0.1, size=(10, 100, 1, 10, 10))
        ssr = spread_skill_ratio(
            predictions=pred_ens, targets=true, variable_names=None, pixel_wise=False
        )[0]  # single-element array -> scalar
        # Expected SSR ~ sqrt(1.1) * 3 ~ 3.15. Sketch: with the targets at
        # the ensemble mean, the RMSE tends to the member std. Expanding the
        # spread definition around the per-pixel ensemble mean of R = 10 iid
        # members (cross-member covariance terms vanish) and applying the
        # (R+1)/R correction factor yields spread ~ sqrt(1.1) * 3 * std,
        # hence SSR = spread / RMSE ~ 3.15.
        self.assertGreaterEqual(ssr, 0.0)
        self.assertAlmostEqual(ssr, 3.15, places=1)
        if self.logger:
            self.logger.info(f"SSR computed : {ssr:.2f} vs SSR expected : ~ 3.15")
            self.logger.info("✅ SSR basic test passed")

    def test_ssr_one_when_perfect_prediction(self):
        """SSR should be ~1 when predictions and truth share one distribution."""
        if self.logger:
            self.logger.info("Testing SSR perfect prediction")
        np.random.seed(0)  # reproducibility
        # Truth and all ensemble members drawn from the same N(0, 0.1).
        true = np.random.normal(loc=0, scale=0.1, size=(100, 1, 10, 10))
        pred_ens = np.random.normal(loc=0, scale=0.1, size=(10, 100, 1, 10, 10))
        ssr = spread_skill_ratio(
            predictions=pred_ens, targets=true, variable_names=None, pixel_wise=False
        )[0]  # single-element array -> scalar
        self.assertAlmostEqual(ssr, 1.0, places=1)
        if self.logger:
            self.logger.info(f"SSR computed : {ssr:.2f} vs SSR expected : ~ 1.0")
            self.logger.info("✅ SSR perfect prediction test passed")
# ----------------------------------------------------------------------------