# Source code for IPSL_AID.diagnostics

# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Kazem Ardaneh, Kishanthan Kingston, Pierre Chapel, Rosie Eade
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/

import os

# Set the cartopy data directory BEFORE cartopy is imported below, so the
# setting is picked up at import time. setdefault keeps any value already
# present in the environment.
# NOTE(review): presumably points cartopy at pre-staged shapefiles on the
# cluster so no network download is attempted — confirm on deployment host.
os.environ.setdefault(
    "CARTOPY_DATA_DIR",
    "/leonardo_work/EUHPC_D27_095/cartopy_data",
)
import unittest
from datetime import datetime
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.colors as mcolors
import matplotlib.patches as patches
from matplotlib.patches import ConnectionPatch
import matplotlib as mpl
from scipy import stats
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import mpltex
from sklearn.metrics import r2_score
import seaborn as sns
import pandas as pd


# ---------------------------------------------
# COMPLETE MATPLOTLIB STYLE CONFIGURATION
# ---------------------------------------------
# ---------------------------------------------
# COMPLETE MATPLOTLIB STYLE CONFIGURATION
# ---------------------------------------------
# rcParams are grouped by concern and merged into a single `params` dict,
# which is applied globally once at import time.

# Font configuration (font.size also affects ax.text())
_FONT_PARAMS = {
    # DPI & figure settings intentionally left at matplotlib defaults:
    # "figure.dpi": 150,
    # "savefig.dpi": 300,
    "font.family": "DejaVu Sans",
    "mathtext.rm": "arial",
    "font.size": 12,
    "font.style": "normal",  # 'normal', 'italic', 'oblique'
    "font.weight": "normal",  # 'normal', 'bold', 'heavy', 'light', ...
    "font.stretch": "normal",
}

# Line properties
_LINE_PARAMS = {
    "lines.linewidth": 2,
    "lines.dashed_pattern": [4, 2],
    "lines.dashdot_pattern": [6, 3, 2, 3],
    "lines.dotted_pattern": [2, 3],
}

# Axis labels, titles and tick settings
_AXES_TICK_PARAMS = {
    "axes.labelsize": 15,
    "axes.titlesize": 15,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "xtick.major.size": 6,
    "ytick.major.size": 6,
    "xtick.direction": "out",
    "ytick.direction": "out",
}

# Legend and text properties
_LEGEND_TEXT_PARAMS = {
    "legend.fontsize": 10,
    "legend.loc": "best",
    "legend.frameon": False,
    "text.color": "black",
    "text.usetex": False,  # no LaTeX rendering
    "text.hinting": "auto",
    "text.antialiased": True,
    "text.latex.preamble": "",
}

params = {
    **_FONT_PARAMS,
    **_LINE_PARAMS,
    **_AXES_TICK_PARAMS,
    **_LEGEND_TEXT_PARAMS,
}

mpl.rcParams.update(params)

# ============================================================================
# PLOTTING CONFIGURATION
# ============================================================================


class PlotConfig:
    """Central configuration for all plotting functions.

    Holds colormap choices, fixed colour ranges for the diagnostic maps,
    unit-conversion rules and a few geometry constants shared by every
    plotting helper in this module.
    """

    # General settings
    DEFAULT_SAVE_DIR = "./results"
    DEFAULT_FIGSIZE_MULTIPLIER = 4

    # Colormap per variable. Keys are matched as SUBSTRINGS of the lowered
    # variable name, in insertion order, so more specific keys come first.
    COLORMAPS = {
        "temperature": "rainbow",
        "temp": "rainbow",
        "2t": "rainbow",
        "zonal": "BrBG_r",
        "10u": "BrBG_r",
        "meridional": "BrBG_r",
        "10v": "BrBG_r",
        "tp": "Blues",
        "TP": "Blues",
        "precipitation": "Blues",
        "dewpoint": "rainbow",
        "d2m": "rainbow",
        "surface temperature": "rainbow",
        "st": "rainbow",
        "pressure": "viridis",
        "pres": "viridis",
        "humidity": "Greens",
        "humid": "Greens",
        "wind": "coolwarm",
        "speed": "coolwarm",
        "mae": "Reds",
        "error": "Reds",
        "divergence": "seismic",
        "curl": "seismic",
        "ssr": "seismic",
        "default": "viridis",
    }

    # Fixed visualization ranges for signed differences (Prediction − Truth)
    FIXED_DIFF_RANGES = {
        "T2M": (-5.0, 5.0),  # K
        "temperature": (-5.0, 5.0),
        "2t": (-5.0, 5.0),
        "VAR_2T": (-5.0, 5.0),
        "U10": (-5.0, 5.0),  # m/s
        "10u": (-5.0, 5.0),
        "meridional": (-5.0, 5.0),
        "VAR_10U": (-5.0, 5.0),
        "V10": (-5.0, 5.0),  # m/s
        "10v": (-5.0, 5.0),
        "VAR_10V": (-5.0, 5.0),
        "TP": (-0.5, 0.5),  # mm/h
        "tp": (-0.5, 0.5),
        "VAR_TP": (-0.5, 0.5),
        "VAR_D2M": (-5.0, 5.0),  # K
        "VAR_ST": (-5.0, 5.0),  # K
    }

    # Fixed visualization ranges for error maps
    FIXED_DIFF_RANGES_ERRORS = {
        "VAR_2T": (0, 0.01),  # K
        "VAR_10U": (0, 3.0),  # m/s
        "VAR_10V": (0, 3.0),  # m/s
        "VAR_TP": (0, 0.5),  # mm/h
        "VAR_D2M": (0, 1.0),  # K
        "VAR_ST": (0, 1.0),  # K
        "Temp": (0, 3.0),
        "Press": (0, 3.0),
        "Humid": (0, 3.0),
        "Wind": (0, 3.0),
    }

    # Fixed visualization ranges for Mean Absolute Error maps
    FIXED_MAE_RANGES = {
        "T2M": (0.0, 3.0),
        "temperature": (0.0, 3.0),
        "2t": (0.0, 3.0),
        "VAR_2T": (0.0, 3.0),
        "U10": (0.0, 3.0),
        "10u": (0.0, 3.0),
        "meridional": (0.0, 3.0),
        "VAR_10U": (0.0, 3.0),
        "V10": (0.0, 3.0),
        "10v": (0.0, 3.0),
        "VAR_10V": (0.0, 3.0),
        "TP": (0.0, 1.0),
        "tp": (0.0, 1.0),
        "VAR_TP": (0.0, 1.0),
        "VAR_D2M": (0.0, 3.0),
        "VAR_ST": (0.0, 3.0),
    }

    # Fixed visualization ranges for Spread Skill Ratio maps
    FIXED_SSR_RANGES = {
        "T2M": (0.0, 3.0),
        "temperature": (0.0, 3.0),
        "2t": (0.0, 3.0),
        "VAR_2T": (0.0, 3.0),
        "U10": (0.0, 3.0),
        "10u": (0.0, 3.0),
        "meridional": (0.0, 3.0),
        "VAR_10U": (0.0, 3.0),
        "V10": (0.0, 3.0),
        "10v": (0.0, 3.0),
        "VAR_10V": (0.0, 3.0),
        "TP": (0.0, 3.0),
        "tp": (0.0, 3.0),
        "VAR_TP": (0.0, 1.0),
        "VAR_D2M": (0.0, 3.0),
        "VAR_ST": (0.0, 3.0),
    }

    # Geographic feature line widths/styles (used by map plots)
    COASTLINE_w = 0.5
    BORDER_w = 0.5
    LAKE_w = 0.5
    BORDER_STYLE = "--"

    # Colorbar geometry
    COLORBAR_h = 0.02
    COLORBAR_PAD = 0.05

    # Exact-match display names; anything else falls through to Title Case.
    _SPECIAL_PLOT_NAMES = {
        "2T": "Temperature [K]",
        "10U": "Zonal Wind [m/s]",
        "10V": "Meridional Wind [m/s]",
        "MSLP": "Sea Level Pressure",
        "T2M": "2m Temperature [K]",
        "U10": "10m Zonal Wind [m/s]",
        "V10": "10m Meridional Wind [m/s]",
        "TP": "Precipitation [mm/h]",
        "tp": "Precipitation [mm/h]",
        "D2M": "Dewpoint [K]",
        "ST": "Surface Temperature [K]",
    }

    @classmethod
    def get_colormap(cls, variable_name):
        """Get appropriate colormap for a variable (first substring match)."""
        lowered = variable_name.lower()
        return next(
            (cmap for key, cmap in cls.COLORMAPS.items() if key in lowered),
            cls.COLORMAPS["default"],
        )

    @classmethod
    def get_plot_name(cls, variable_name):
        """Convert variable name to readable plot name."""
        # Strip the common VAR_ prefix first
        name = variable_name.replace("VAR_", "").replace("var_", "")
        special = cls._SPECIAL_PLOT_NAMES.get(name)
        if special is not None:
            return special
        # General conversion: underscores to spaces, Title Case
        return name.replace("_", " ").title()

    @classmethod
    def convert_units(cls, variable_name, data):
        """
        Safe unit conversion when required.

        - NEVER modifies input
        - Returns a new array only if conversion is needed
          (currently precipitation: m -> mm)
        """
        if variable_name.lower() in ("tp", "var_tp", "precipitation"):
            return data * 1000.0  # m to mm
        return data

    @staticmethod
    def get_fixed_diff_range(var_name):
        """Fixed range for signed differences (Prediction − Truth), or None."""
        return PlotConfig.FIXED_DIFF_RANGES.get(var_name)

    @staticmethod
    def get_fixed_diff_range_errors(var_name):
        """Fixed range for error maps, or None."""
        return PlotConfig.FIXED_DIFF_RANGES_ERRORS.get(var_name)

    @staticmethod
    def get_fixed_mae_range(var_name):
        """Fixed range for Mean Absolute Error (MAE) maps, or None."""
        return PlotConfig.FIXED_MAE_RANGES.get(var_name)

    @staticmethod
    def get_fixed_ssr_range(var_name):
        """Fixed range for Spread Skill Ratio (SSR) maps, or None."""
        return PlotConfig.FIXED_SSR_RANGES.get(var_name)
def plot_validation_hexbin(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    coarse_inputs=None,  # Coarse inputs for comparison (optional)
    variable_names=None,  # List of variable names
    filename="validation_hexbin.png",
    save_dir="./results",
    figsize_multiplier=4,  # Base size per subplot
):
    """
    Create hexbin plots comparing model predictions vs ground truth for all
    variables.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    coarse_inputs : torch.Tensor or np.array, optional
        Coarse inputs of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Names of the variables for subplot titles
    filename : str, optional
        Output filename
    save_dir : str, optional
        Directory to save the plot
    figsize_multiplier : int, optional
        Base size multiplier for subplots

    Returns
    -------
    save_path : str
        Path to the saved figure.
    """
    # Convert to numpy if they're tensors
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if coarse_inputs is not None and hasattr(coarse_inputs, "detach"):
        coarse_inputs = coarse_inputs.detach().cpu().numpy()

    batch_size, num_vars, h, w = predictions.shape

    # Default variable names if not provided
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(num_vars)]

    # One row of subplots (ncols == num_vars, so nrows is always 1)
    ncols = num_vars
    nrows = (num_vars + ncols - 1) // ncols  # Ceiling division

    fig, axes = plt.subplots(
        nrows, ncols, figsize=(ncols * figsize_multiplier, figsize_multiplier)
    )

    # BUG FIX: with a single variable, plt.subplots returns a bare Axes
    # object and iterating over it raises TypeError. Normalize to a 2-D
    # array BEFORE any iteration over the axes.
    axes = np.atleast_1d(axes)
    if axes.ndim == 1:
        axes = axes.reshape(1, -1)
    axes_flat = axes.flatten()

    for ax in axes_flat:
        ax.set_box_aspect(1)

    # Plot each variable
    hb = None  # last hexbin handle, reused for the colorbar below
    for i, (var_name, ax) in enumerate(zip(variable_names, axes_flat)):
        if i >= num_vars:
            ax.set_visible(False)
            continue

        # Apply unit conversion (e.g. precipitation m -> mm), then flatten
        # the batch and spatial dimensions
        pred_flat = PlotConfig.convert_units(var_name, predictions[:, i]).reshape(-1)
        target_flat = PlotConfig.convert_units(var_name, targets[:, i]).reshape(-1)

        # Create hexbin plot
        hb = ax.hexbin(
            target_flat, pred_flat, gridsize=100, cmap="jet", bins="log", mincnt=1
        )

        # Add identity (y = x) reference line
        min_val = min(target_flat.min(), pred_flat.min())
        max_val = max(target_flat.max(), pred_flat.max())
        ax.plot([min_val, max_val], [min_val, max_val], "r--", alpha=0.7)

        # Goodness-of-fit metrics shown inside the panel
        r2 = r2_score(target_flat, pred_flat)
        mae = np.mean(np.abs(pred_flat - target_flat))
        rmse = np.sqrt(np.mean((pred_flat - target_flat) ** 2))
        textstr = f"$R^2$: {r2:.3f}\nMAE: {mae:.3f}\nRMSE: {rmse:.3f}"
        ax.text(
            0.05,
            0.95,
            textstr,
            transform=ax.transAxes,
            fontsize=10,
            verticalalignment="top",
        )

        ax.set_title(PlotConfig.get_plot_name(var_name))

        # Format ticks
        ax.xaxis.set_major_locator(ticker.MaxNLocator(5))
        ax.yaxis.set_major_locator(ticker.MaxNLocator(5))

        # Only show y-label for leftmost subplots
        if i % ncols == 0:
            ax.set_ylabel("Predicted Values")
        else:
            ax.set_ylabel("")
        # Only show x-label for bottom row subplots
        if i >= (nrows - 1) * ncols:
            ax.set_xlabel("True Values")
        else:
            ax.set_xlabel("")

    # Colorbar attached to the LAST axis
    ax_last = axes_flat[min(num_vars - 1, len(axes_flat) - 1)]
    cax = ax_last.inset_axes([1.05, 0.0, 0.04, 1.0])  # [x, y, width, height]
    cbar = fig.colorbar(hb, cax=cax)
    cbar.set_label(r"$\log_{10}[\mathrm{Count}]$")

    plt.subplots_adjust(
        hspace=0.1, wspace=0.3, left=0.1, right=0.9, top=0.9, bottom=0.1
    )

    # Ensure save directory exists
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def plot_comparison_hexbin(
    predictions,
    targets,
    coarse_inputs,
    variable_names=None,
    filename="comparison_hexbin.png",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Create hexbin comparison plots between model predictions, ground truth,
    and coarse inputs.

    For each variable two side-by-side panels are drawn:
    left  -- model predictions vs ground truth,
    right -- coarse inputs vs ground truth,
    each with an identity line and R²/MAE annotations. All panels share one
    colour scale derived from a density pre-pass.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    coarse_inputs : torch.Tensor or np.array
        Coarse inputs of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Names of the variables for subplot titles. If None, uses VAR_0,
        VAR_1, etc.
    filename : str, optional
        Output filename
    save_dir : str, optional
        Directory to save the plot
    figsize_multiplier : int, optional
        Base size multiplier for subplots

    Returns
    -------
    save_path : str
        Path to the saved figure
    """

    def _to_numpy(arr):
        # Tensor -> numpy, pass numpy arrays through unchanged
        return arr.detach().cpu().numpy() if hasattr(arr, "detach") else arr

    predictions = _to_numpy(predictions)
    targets = _to_numpy(targets)
    coarse_inputs = _to_numpy(coarse_inputs)

    batch_size, num_vars, h, w = predictions.shape

    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(num_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(v) for v in variable_names]

    def _flat_triplet(idx, name):
        # Unit-converted, flattened (pred, target, coarse) for one variable
        return (
            PlotConfig.convert_units(name, predictions[:, idx]).reshape(-1),
            PlotConfig.convert_units(name, targets[:, idx]).reshape(-1),
            PlotConfig.convert_units(name, coarse_inputs[:, idx]).reshape(-1),
        )

    # ----------------------------------------
    # Pre-pass: collect hexbin densities on throwaway axes to fix a single
    # colour range shared by every panel
    # ----------------------------------------
    density_samples = []
    for idx, name in enumerate(variable_names):
        pred_flat, target_flat, coarse_flat = _flat_triplet(idx, name)
        tmp_fig, tmp_ax = plt.subplots()
        hb_model = tmp_ax.hexbin(
            target_flat, pred_flat, gridsize=100, cmap="jet", bins="log", mincnt=1
        )
        hb_coarse = tmp_ax.hexbin(
            target_flat, coarse_flat, gridsize=100, cmap="jet", bins="log", mincnt=1
        )
        density_samples.append(hb_model.get_array())
        density_samples.append(hb_coarse.get_array())
        plt.close(tmp_fig)

    # Global colorbar limits
    density_samples = np.concatenate(density_samples)
    global_vmin = np.min(density_samples)
    global_vmax = np.max(density_samples)

    # ----------------------------------------
    # Actual figure: one row per variable, two columns
    # ----------------------------------------
    fig, axes = plt.subplots(
        num_vars,
        2,
        figsize=(2 * figsize_multiplier, num_vars * figsize_multiplier * 0.8),
    )
    plt.subplots_adjust(
        hspace=0.3, wspace=0.4, left=0.1, right=0.9, top=0.9, bottom=0.1
    )
    if num_vars == 1:
        axes = axes.reshape(1, -1)

    last_hb = None  # keeps a hexbin handle for the shared colorbar
    for idx, name in enumerate(variable_names):
        pred_flat, target_flat, coarse_flat = _flat_triplet(idx, name)

        # Common plotting range for both panels of this variable (+5% margin)
        var_min = min(target_flat.min(), pred_flat.min(), coarse_flat.min())
        var_max = max(target_flat.max(), pred_flat.max(), coarse_flat.max())
        margin = 0.05 * (var_max - var_min)
        plot_min = var_min - margin
        plot_max = var_max + margin

        # --------------------------
        # Left panel: Model vs True
        # --------------------------
        ax = axes[idx, 0]
        last_hb = ax.hexbin(
            target_flat,
            pred_flat,
            gridsize=100,
            cmap="jet",
            bins="log",
            mincnt=1,
            vmin=global_vmin,
            vmax=global_vmax,
        )
        ax.set_xlim(plot_min, plot_max)
        ax.set_ylim(plot_min, plot_max)
        ax.plot([plot_min, plot_max], [plot_min, plot_max], "r--", alpha=0.7)

        r2 = r2_score(target_flat, pred_flat)
        mae = np.mean(np.abs(pred_flat - target_flat))
        ax.text(
            0.05,
            0.95,
            f"$R^2$: {r2:.3f}\nMAE: {mae:.3f}",
            transform=ax.transAxes,
            va="top",
        )
        ax.set_title(f"{plot_variable_names[idx]} – Model vs True")
        ax.set_ylabel("Model Values")
        ax.set_xlabel("True Values" if idx == num_vars - 1 else "")

        # --------------------------
        # Right panel: Coarse vs True
        # --------------------------
        ax = axes[idx, 1]
        last_hb = ax.hexbin(
            target_flat,
            coarse_flat,
            gridsize=100,
            cmap="jet",
            bins="log",
            mincnt=1,
            vmin=global_vmin,
            vmax=global_vmax,
        )
        ax.set_xlim(plot_min, plot_max)
        ax.set_ylim(plot_min, plot_max)
        ax.plot(
            [plot_min, plot_max], [plot_min, plot_max], "r--", alpha=0.7, linewidth=1
        )

        r2 = r2_score(target_flat, coarse_flat)
        mae = np.mean(np.abs(coarse_flat - target_flat))
        ax.text(
            0.05,
            0.95,
            f"$R^2$: {r2:.3f}\nMAE: {mae:.3f}",
            transform=ax.transAxes,
            va="top",
        )
        ax.set_title(f"{plot_variable_names[idx]} – Coarse vs True")
        ax.set_ylabel("Coarse Values")
        ax.set_xlabel("True Values" if idx == num_vars - 1 else "")

    # ----------------------------------------
    # Single shared colorbar
    # ----------------------------------------
    cbar_ax = fig.add_axes([0.98, 0.1, 0.02, 0.8])
    fig.colorbar(last_hb, cax=cbar_ax, label=r"$\log_{10}[\mathrm{Count}]$")

    # Save
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def plot_metric_histories(
    valid_metrics_history,
    variable_names,
    metric_names,
    filename="validation_metrics",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Creates row-based panel plots: one figure per metric, rows = variables,
    shared x-axis.

    Parameters
    ----------
    valid_metrics_history : dict
        Dict from training loop storing metric histories. Keys follow
        "<var>_pred_vs_fine_<metric>" / "<var>_coarse_vs_fine_<metric>".
    variable_names : list of str
        Names of variables.
    metric_names : list of str
        List of metric names (e.g. ["MAE"]).
    filename : str
        Prefix for saved figures; one file "<filename>_<metric>.png" per
        metric.
    save_dir : str
        Directory where images are saved.
    figsize_multiplier : float
        Per-row figure height.

    Returns
    -------
    save_path : str or None
        Path of the last figure written, or None when metric_names is empty.
    """
    os.makedirs(save_dir, exist_ok=True)
    num_vars = len(variable_names)

    save_path = None  # robust when metric_names is empty (was a NameError)
    for metric in metric_names:
        # Rows = variables, 1 column, shared x-axis
        fig, axes = plt.subplots(
            nrows=num_vars,
            ncols=1,
            figsize=(6, figsize_multiplier * num_vars),
            squeeze=False,
            sharex=True,
        )
        plt.subplots_adjust(hspace=0.1, left=0.15, right=0.95, top=0.9, bottom=0.1)

        for i, var in enumerate(variable_names):
            ax = axes[i, 0]
            key_pred = f"{var}_pred_vs_fine_{metric}"
            key_coarse = f"{var}_coarse_vs_fine_{metric}"

            # Skip panel gracefully when either history is missing
            if (
                key_pred not in valid_metrics_history
                or key_coarse not in valid_metrics_history
            ):
                ax.text(0.5, 0.5, "Missing Data", ha="center", va="center")
                ax.set_yscale("log")
                continue

            pred_hist = valid_metrics_history[key_pred]
            coarse_hist = valid_metrics_history[key_coarse]

            # Plot both histories with distinct line styles
            linestyles = mpltex.linestyle_generator(markers=[])
            ax.plot(pred_hist, label="Pred vs Fine", **next(linestyles))
            ax.plot(coarse_hist, label="Coarse vs Fine", **next(linestyles))

            ax.set_yscale("log")
            ax.set_ylabel(f"{metric} ({var})")
            ax.grid(True, alpha=0.3)
            ax.legend()

            # Only bottom row shows x-axis label
            if i == num_vars - 1:
                ax.set_xlabel("Epoch")
            else:
                ax.tick_params(labelbottom=False)

        # BUG FIX: the `filename` prefix was dropped from the output path,
        # producing files literally named "(unknown)_<metric>.png".
        save_path = os.path.join(save_dir, f"{filename}_{metric}.png")
        plt.savefig(save_path, bbox_inches="tight")
        plt.close(fig)

    return save_path
def plot_metrics_heatmap(
    valid_metrics_history,
    variable_names,
    metric_names,
    filename="validation_metrics_heatmap",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Plot a heatmap of validation metrics.

    Parameters
    ----------
    valid_metrics_history : dict
        Dict from validation loop storing metric trackers. Each tracker is
        expected to expose `count` and `getmean()` — TODO confirm API.
    variable_names : list of str
        Names of variables (heatmap rows).
    metric_names : list of str
        List of metric names (["MAE", "NMAE", "RMSE", "R²"]) (columns).
    filename : str
        Output file name without extension ("<filename>.png").
    save_dir : str
        Directory where images are saved.
    figsize_multiplier : float
        Controls overall figure size.

    Returns
    -------
    save_path : str
        Path to the saved figure.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Build DataFrame: one column per metric, one row per variable
    data = {}
    for metric in metric_names:
        values = []
        for var in variable_names:
            key = f"{var}_pred_vs_fine_{metric}"
            if key in valid_metrics_history:
                tracker = valid_metrics_history[key]
                if tracker.count > 0:
                    value = tracker.getmean()
                    # Convert only dimensional metrics (unitless ones such
                    # as R² must not be rescaled)
                    if metric.lower() in ["mae", "rmse", "crps"]:
                        value = PlotConfig.convert_units(var, value)
                else:
                    value = np.nan
            else:
                value = np.nan
            values.append(value)
        data[metric] = values

    df = pd.DataFrame(data, index=variable_names)

    fig_width = figsize_multiplier + len(metric_names)
    fig_height = 0.6 * len(variable_names) + figsize_multiplier / 2
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    sns.heatmap(
        df, ax=ax, cmap="viridis", annot=True, fmt=".3f", linewidths=0.8, cbar=True
    )
    ax.set_title("Validation metrics")
    ax.set_xlabel("Metric")
    ax.set_ylabel("Variable")
    plt.tight_layout()

    # BUG FIX: the output path ignored the `filename` argument and wrote a
    # file literally named "(unknown).png".
    save_path = os.path.join(save_dir, f"{filename}.png")
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def plot_loss_histories(
    train_loss_history,
    valid_loss_history,
    filename="training_validation_loss.png",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Plots training and validation loss in a single panel.

    Parameters
    ----------
    train_loss_history : list or array
        History of training loss values.
    valid_loss_history : list or array
        History of validation loss values. Skipped if empty or all-falsy.
    filename : str
        Output image file name for the plot.
    save_dir : str
        Directory to save the plot.
    figsize_multiplier : float
        Figure height.

    Returns
    -------
    save_path : str
        Path to the saved figure.
    """
    # Ensure inputs are lists
    if not isinstance(train_loss_history, list):
        train_loss_history = list(train_loss_history)
    if not isinstance(valid_loss_history, list):
        valid_loss_history = list(valid_loss_history)

    fig = plt.figure(figsize=(6, figsize_multiplier))
    ax = fig.add_subplot(111)
    epochs = range(len(train_loss_history))

    # Plot losses with distinct line styles
    linestyles = mpltex.linestyle_generator(markers=[])
    ax.plot(epochs, train_loss_history, label="Training Loss", **next(linestyles))
    if valid_loss_history and any(valid_loss_history):
        ax.plot(
            epochs, valid_loss_history, label="Validation Loss", **next(linestyles)
        )

    ax.set_yscale("log")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss Value")
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Ensure save directory exists
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    # BUG FIX: a print() statement previously placed AFTER this return was
    # unreachable dead code; it has been removed.
    return save_path
def plot_average_metrics(
    valid_metrics_history,
    metric_names,  # List of metrics to plot
    filename="average_metrics.png",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Plots average metrics across all variables in a row-based layout with
    shared x-axis.

    Each row corresponds to one metric, plotting both:
    - average_pred_vs_fine_<metric>
    - average_coarse_vs_fine_<metric>

    Parameters
    ----------
    valid_metrics_history : dict
        Dictionary containing validation metrics history.
    metric_names : list of str
        Names of metrics to plot.
    filename : str
        Output image file name for the plot.
    save_dir : str
        Directory to save the plot.
    figsize_multiplier : float
        Per-row figure height.

    Returns
    -------
    save_path : str or None
        Path to the saved figure, or None when no metric names were given.
    """
    if not metric_names:
        print("No metric names provided")
        return

    num_rows = len(metric_names)

    # One row per metric, single column, shared epoch axis
    fig, axes = plt.subplots(
        nrows=num_rows,
        ncols=1,
        figsize=(6, figsize_multiplier * num_rows),
        squeeze=False,
        sharex=True,
    )
    plt.subplots_adjust(hspace=0.1, left=0.15, right=0.95, top=0.95, bottom=0.1)

    for row, metric in enumerate(metric_names):
        ax = axes[row, 0]
        style_cycle = mpltex.linestyle_generator(markers=[])

        # Draw whichever of the two averaged histories is available; the
        # line-style generator only advances for curves actually drawn.
        for key, label in (
            (f"average_pred_vs_fine_{metric}", "Pred vs Fine"),
            (f"average_coarse_vs_fine_{metric}", "Coarse vs Fine"),
        ):
            if key not in valid_metrics_history:
                continue
            series = valid_metrics_history[key]
            if not isinstance(series, list):
                series = list(series)
            ax.plot(series, label=label, **next(style_cycle))

        ax.set_yscale("log")
        ax.set_ylabel(metric.replace("_", " ").title())
        ax.grid(True, alpha=0.3)
        ax.legend()

        # Only the bottom row gets an x-label and tick labels
        if row == num_rows - 1:
            ax.set_xlabel("Epoch")
        else:
            ax.set_xlabel("")
            ax.tick_params(labelbottom=False)

    # Ensure save directory exists, then write the figure
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def plot_spatiotemporal_histograms(
    steps,
    tindex_lim,
    centers,
    tindices,
    mode="train",
    filename="",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Plot two 2D hexagonal bin histograms showing spatial-temporal data
    coverage: latitude center vs temporal index and longitude center vs
    temporal index.

    Parameters
    ----------
    steps : EasyDict
        Dictionary containing coordinate dimensions and limits. Expected to
        have attributes 'latitude' (or 'lat') and 'longitude' (or 'lon')
        specifying the maximum spatial indices.
    tindex_lim : tuple
        Tuple of (min_time, max_time) specifying the temporal index limits.
    centers : list of tuples
        List of (lat_center, lon_center) coordinates for each data sample.
    tindices : list or array-like
        List of temporal indices corresponding to each data sample. Should
        have the same length as 'centers'.
    mode : str
        Dataset mode identifier, typically "train" or "validation". Used in
        the output filename.
    filename : str, optional
        Prefix to prepend to the output filename. Default "" produces
        "spatiotemporal_<mode>_hexbin.png".
        BUG FIX: the previous default ("average_metrics.png") contradicted
        this docstring, and the prefix was dropped from the path entirely
        (files came out as "(unknown)spatiotemporal_...").
    save_dir : str
        Directory where the plot is saved; created if it doesn't exist.
    figsize_multiplier : int
        Base subplot size.

    Returns
    -------
    save_path : str or None
        Path to the saved figure, or None when there is no data to plot.

    Notes
    -----
    - Two side-by-side panels share the y (temporal) axis and one colorbar.
    - Hexagons with zero count (mincnt=1) are not displayed.
    """
    if not centers or not tindices:
        print(f"No data to plot for {mode} mode")
        return

    # Convert to numpy arrays for efficient processing
    centers = np.array(centers)
    lat_centers = centers[:, 0]
    lon_centers = centers[:, 1]
    tindices = np.array(tindices)

    # Extract spatial limits from steps with 'lat'/'lon' fallbacks
    max_lat = getattr(steps, "latitude", getattr(steps, "lat", None))
    max_lon = getattr(steps, "longitude", getattr(steps, "lon", None))
    min_time, max_time = tindex_lim

    # Two side-by-side panels sharing the temporal axis
    fig, (ax1, ax2) = plt.subplots(
        1, 2, figsize=(2 * figsize_multiplier, figsize_multiplier), sharey=True
    )
    plt.subplots_adjust(
        hspace=0.1, wspace=0.1, left=0.1, right=0.9, top=0.9, bottom=0.1
    )

    # Latitude vs time
    hex1 = ax1.hexbin(
        lat_centers,
        tindices,
        gridsize=100,
        extent=[0, max_lat, min_time, max_time],
        cmap="jet",
        mincnt=1,  # only show hexagons with at least 1 count
        edgecolors="none",
    )
    ax1.set_xlabel("Latitude Center Index", fontsize=12)
    ax1.set_ylabel("Temporal Index", fontsize=12)
    ax1.set_xlim(0, max_lat)
    ax1.set_ylim(min_time, max_time)
    ax1.grid(True, alpha=0.3, linestyle="--")

    # Longitude vs time
    hex2 = ax2.hexbin(
        lon_centers,
        tindices,
        gridsize=100,
        extent=[0, max_lon, min_time, max_time],
        cmap="jet",
        mincnt=1,
        edgecolors="none",
    )
    ax2.set_xlabel("Longitude Center Index", fontsize=12)
    ax2.set_xlim(0, max_lon)
    ax2.set_ylim(min_time, max_time)
    ax2.grid(True, alpha=0.3, linestyle="--")

    # Normalize colour scale to the maximum count across both histograms
    max_count = 1
    for hx in (hex1, hex2):
        counts = hx.get_array()
        if counts is not None and len(counts) > 0:
            max_count = max(max_count, counts.max())
    hex1.set_clim(0, max_count)
    hex2.set_clim(0, max_count)

    # Single colorbar shared by both panels.
    # NOTE(review): these hexbins use linear counts (no bins="log"), so the
    # log10 label looks misleading — confirm the intended scale.
    cbar_ax = fig.add_axes([0.93, 0.1, 0.02, 0.8])
    fig.colorbar(hex1, cax=cbar_ax, label=r"$\log_{10}[\mathrm{Count}]$")

    # Save figure to disk (filename prefix restored — see docstring)
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(
        save_dir, f"{filename}spatiotemporal_{mode}_hexbin.png"
    )
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def plot_surface(
    predictions,
    targets,
    coarse_inputs,
    lat_1d,
    lon_1d,
    timestamp=None,
    variable_names=None,
    filename="forecast_plot.png",
    save_dir=None,
    figsize_multiplier=None,
):
    """
    Plot side-by-side forecast maps (coarse input, true target, model
    prediction, and signed difference) for one or more meteorological
    variables over a geographic domain.

    Parameters
    ----------
    predictions : torch.Tensor or np.ndarray
        Model predictions at target resolution with shape [1, n_vars, H, W].
    targets : torch.Tensor or np.ndarray
        Ground-truth high-resolution data with shape [1, n_vars, H, W].
    coarse_inputs : torch.Tensor or np.ndarray
        Coarse-resolution input data with shape [1, n_vars, H, W].
    lat_1d : array-like
        1D array of latitude coordinates with shape [H].
    lon_1d : array-like
        1D array of longitude coordinates with shape [W].
    timestamp : datetime.datetime, optional
        Forecast timestamp, reported on stdout when given.
    variable_names : list of str, optional
        Variable names or identifiers; defaults to VAR_0, VAR_1, ...
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot; defaults to PlotConfig.DEFAULT_SAVE_DIR.
    figsize_multiplier : int, optional
        Base size multiplier for subplots (kept for interface
        compatibility; panel sizes are currently fixed).

    Returns
    -------
    str
        Path of the saved figure.

    Raises
    ------
    ValueError
        If targets or predictions do not have the same number of
        variables as coarse_inputs.
    """
    # Use defaults from config if not provided
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if figsize_multiplier is None:
        figsize_multiplier = PlotConfig.DEFAULT_FIGSIZE_MULTIPLIER

    # Convert tensors to numpy if needed
    if hasattr(coarse_inputs, "detach"):
        coarse_inputs = coarse_inputs.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(lat_1d, "detach"):
        lat_1d = lat_1d.detach().cpu().numpy()
    if hasattr(lon_1d, "detach"):
        lon_1d = lon_1d.detach().cpu().numpy()

    # Create 2D meshgrid from 1D coordinates
    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()

    # Grid shape taken from the data itself
    h, w = coarse_inputs[0, 0].shape
    # Latitudes descend (north -> south) to match image row ordering
    lat_block = np.linspace(lat_max, lat_min, h)
    lon_block = np.linspace(lon_min, lon_max, w)
    lat, lon = np.meshgrid(lat_block, lon_block, indexing="ij")

    # Projection center
    lon_center = float((lon_min + lon_max) / 2)

    # Check data dimensions
    n_vars = coarse_inputs.shape[1]
    if targets.shape[1] != n_vars:
        raise ValueError(
            f"targets data has {targets.shape[1]} variables but coarse_inputs has {n_vars}"
        )
    if predictions.shape[1] != n_vars:
        raise ValueError(
            f"predictions data has {predictions.shape[1]} variables but coarse_inputs has {n_vars}"
        )

    # Default variable names if not provided
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]

    # Derive plot names and colormaps
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]
    cmaps = [PlotConfig.get_colormap(var) for var in variable_names]

    # Color limits shared by the coarse / truth / prediction rows
    vmin_list = []
    vmax_list = []
    # Color limits for the signed-difference row
    diff_vmin_list = []
    diff_vmax_list = []
    for i in range(n_vars):
        var_name = variable_names[i]
        coarse_i = PlotConfig.convert_units(var_name, coarse_inputs[0, i])
        target_i = PlotConfig.convert_units(var_name, targets[0, i])
        pred_i = PlotConfig.convert_units(var_name, predictions[0, i])

        # Combined data range for this variable (coarse, truth, prediction)
        all_data = np.concatenate(
            [coarse_i.flatten(), target_i.flatten(), pred_i.flatten()]
        )

        # Robust limits via 2%/98% quantiles, ignoring NaNs
        all_data_flat = all_data[~np.isnan(all_data)]
        if len(all_data_flat) > 0:
            q_low, q_high = np.quantile(all_data_flat, [0.02, 0.98])
            vmin, vmax = float(q_low), float(q_high)
        else:
            vmin, vmax = -1, 1
        # Ensure vmin < vmax (quantiles can collapse on near-constant fields)
        if vmin >= vmax:
            vmin, vmax = float(np.nanmin(all_data)), float(np.nanmax(all_data))
        vmin_list.append(vmin)
        vmax_list.append(vmax)

        # Signed difference range between prediction and truth.
        # BUGFIX: compute the range from the unit-converted fields — the
        # plotted difference below is (converted pred - converted truth),
        # so the color limits must be in the same display units.
        fixed_range = PlotConfig.get_fixed_diff_range(var_name)
        diff_data = (pred_i - target_i).flatten()
        diff_data = diff_data[~np.isnan(diff_data)]
        if fixed_range is not None:
            diff_vmin, diff_vmax = fixed_range
        else:
            if len(diff_data) > 0:
                # Symmetric range around 0 for a diverging colormap
                max_abs_diff = np.max(np.abs(diff_data))
                diff_vmin = -max_abs_diff * 1.1  # add 10% padding
                diff_vmax = max_abs_diff * 1.1  # add 10% padding
                # Guard against an all-zero / near-zero difference field
                if diff_vmax <= 0.001:
                    diff_vmin, diff_vmax = -0.1, 0.1
            else:
                diff_vmin, diff_vmax = -1, 1
        diff_vmin_list.append(diff_vmin)
        diff_vmax_list.append(diff_vmax)

    # Fixed panel size (rather than a geo-aspect-derived one) so panels
    # stay rectangular regardless of location
    base_width_per_panel = 4.5
    base_height_per_panel = 3.0
    fig_width = base_width_per_panel * n_vars
    fig_height = base_height_per_panel * 4  # 4 rows

    # Set up figure: rows = coarse / truth / prediction / difference
    fig, axes = plt.subplots(
        4,
        n_vars,
        figsize=(fig_width, fig_height),
        subplot_kw={"projection": ccrs.PlateCarree(central_longitude=lon_center)},
        gridspec_kw={"wspace": 0.1, "hspace": 0.1},
        squeeze=False,
    )

    # Report the forecast time on stdout (suptitle intentionally disabled)
    if timestamp is not None:
        print(f"Forecast for {timestamp.strftime('%Y-%m-%d %H:%M')}")

    # Plot each variable as one column
    for col_idx in range(n_vars):
        var_name = variable_names[col_idx]
        coarse_inputs_data = PlotConfig.convert_units(
            var_name, coarse_inputs[0, col_idx]
        )
        targets_data = PlotConfig.convert_units(var_name, targets[0, col_idx])
        pred_data = PlotConfig.convert_units(var_name, predictions[0, col_idx])
        diff_data = pred_data - targets_data  # signed difference (pred - truth)

        # Image handles for the rows that get colorbars
        im_coar = None
        im_diff = None

        for row_idx in range(4):
            ax = axes[row_idx, col_idx]

            # Select data / limits / colormap for this row
            if row_idx == 0:
                data = coarse_inputs_data
                vmin, vmax = vmin_list[col_idx], vmax_list[col_idx]
                cmap = cmaps[col_idx]
            elif row_idx == 1:
                data = targets_data
                vmin, vmax = vmin_list[col_idx], vmax_list[col_idx]
                cmap = cmaps[col_idx]
            elif row_idx == 2:
                data = pred_data
                vmin, vmax = vmin_list[col_idx], vmax_list[col_idx]
                cmap = cmaps[col_idx]
            else:  # row_idx == 3
                data = diff_data
                vmin, vmax = diff_vmin_list[col_idx], diff_vmax_list[col_idx]
                cmap = "RdBu_r"  # diverging colormap for differences

            im = ax.pcolormesh(
                lon,
                lat,
                data,
                vmin=vmin,
                vmax=vmax,
                cmap=cmap,
                transform=ccrs.PlateCarree(),
                shading="auto",
            )

            # Keep handles for the colorbar rows
            if row_idx == 0:
                im_coar = im
            elif row_idx == 3:
                im_diff = im

            # Set extent and geographic features
            ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
            ax.coastlines(linewidth=0.6)
            ax.add_feature(
                cfeature.BORDERS.with_scale("50m"),
                linewidth=0.9,
                linestyle="--",
                edgecolor="black",
                zorder=11,
            )
            ax.add_feature(
                cfeature.LAKES.with_scale("50m"),
                edgecolor="black",
                facecolor="none",
                linewidth=0.9,
                zorder=9,
            )
            ax.set_xticks([])
            ax.set_yticks([])

        # Colorbar above the COARSE row (row 0)
        if im_coar is not None:
            ax_coar = axes[0, col_idx]
            # inset_axes position [x, y, width, height]; y > 1 places it above
            cax_top = ax_coar.inset_axes([0.1, 1.05, 0.8, 0.05])
            cbar = fig.colorbar(im_coar, cax=cax_top, orientation="horizontal")
            cbar.set_label(f"{plot_variable_names[col_idx]}")
            cax_top.xaxis.set_ticks_position("top")
            cax_top.xaxis.set_label_position("top")

        # Colorbar below the DIFFERENCE row (row 3)
        if im_diff is not None:
            ax_diff = axes[3, col_idx]
            cax_diff = ax_diff.inset_axes([0.1, -0.12, 0.8, 0.05])
            fig.colorbar(
                im_diff,
                cax=cax_diff,
                orientation="horizontal",
                # BUGFIX: original read `label=f{...}` (missing opening
                # quote), which is a syntax error
                label=f"{plot_variable_names[col_idx]} (Pred - Truth)",
            )

    # Row labels on the left side
    row_labels = ["Coarse", "Truth", "Prediction", "Pred - Truth"]
    for row_idx, label in enumerate(row_labels):
        axes[row_idx, 0].text(
            -0.12,
            0.5,
            label,
            transform=axes[row_idx, 0].transAxes,
            va="center",
            ha="right",
            rotation="vertical",
            fontsize=12,
        )

    # Extra room at the bottom for the difference colorbars
    fig.subplots_adjust(
        top=0.90, bottom=0.25, left=0.10, right=0.95, wspace=0.1, hspace=0.15
    )

    # Save figure
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def plot_ensemble_surface(
    predictions_ens,
    lat_1d,
    lon_1d,
    variable_names,
    timestamp=None,
    filename="ensemble_surface.png",
    save_dir="./results",
):
    """
    Plot the first three ensemble members, the ensemble mean, and the
    ensemble spread (per-pixel standard deviation) for each variable.

    Parameters
    ----------
    predictions_ens : torch.Tensor or np.ndarray
        Ensemble predictions of shape [n_ensemble_members, n_vars, H, W];
        at least 3 members are required.
    lat_1d : array-like
        1D array of latitude coordinates with shape [H].
    lon_1d : array-like
        1D array of longitude coordinates with shape [W].
    variable_names : list of str
        Variable names or identifiers (used for unit conversion and
        colormap lookup via PlotConfig).
    timestamp : datetime.datetime, optional
        Forecast timestamp, reported on stdout when given.
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot.

    Returns
    -------
    str
        Path of the saved figure.

    Raises
    ------
    ValueError
        If fewer than 3 ensemble members are provided.
    """
    if torch.is_tensor(predictions_ens):
        predictions_ens = predictions_ens.detach().cpu().numpy()

    N_ens, C, H, W = predictions_ens.shape
    if N_ens < 3:
        raise ValueError("Need at least 3 ensemble members")

    # Ensemble statistics (over the member axis)
    ensemble_mean = np.mean(predictions_ens, axis=0)
    ensemble_std = np.std(predictions_ens, axis=0)

    lat_1d = np.asarray(lat_1d)
    lon_1d = np.asarray(lon_1d)
    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    # Latitudes descend (north -> south) to match image row ordering
    lat_block = np.linspace(lat_max, lat_min, H)
    lon_block = np.linspace(lon_min, lon_max, W)
    lat, lon = np.meshgrid(lat_block, lon_block, indexing="ij")
    lon_center = float((lon_min + lon_max) / 2)

    plot_variable_names = [PlotConfig.get_plot_name(v) for v in variable_names]
    cmaps = [PlotConfig.get_colormap(v) for v in variable_names]

    # Compute vmin/vmax shared by member and mean rows (2%/98% quantiles)
    vmin_list = []
    vmax_list = []
    for i in range(C):
        var = variable_names[i]
        ens_members = [
            PlotConfig.convert_units(var, predictions_ens[k, i]) for k in range(N_ens)
        ]
        mean_i = PlotConfig.convert_units(var, ensemble_mean[i])
        all_data = np.concatenate(
            [x.flatten() for x in ens_members] + [mean_i.flatten()]
        )
        all_data = all_data[~np.isnan(all_data)]
        if len(all_data) > 0:
            q_low, q_high = np.quantile(all_data, [0.02, 0.98])
            vmin, vmax = float(q_low), float(q_high)
        else:
            vmin, vmax = -1, 1
        # Guard against degenerate limits on near-constant fields
        if vmin >= vmax:
            vmin = float(np.nanmin(all_data))
            vmax = float(np.nanmax(all_data))
        vmin_list.append(vmin)
        vmax_list.append(vmax)

    # Rows: members 1-3, ensemble mean, ensemble spread
    n_rows = 5
    n_cols = C
    base_width_per_panel = 4.5
    base_height_per_panel = 3.0
    fig_width = base_width_per_panel * n_cols
    fig_height = base_height_per_panel * n_rows
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(fig_width, fig_height),
        subplot_kw={"projection": ccrs.PlateCarree(central_longitude=lon_center)},
        gridspec_kw={"wspace": 0.1, "hspace": 0.15},
        squeeze=False,
    )

    row_labels = [
        "Prediction 1",
        "Prediction 2",
        "Prediction 3",
        "Ensemble Mean",
        "Ensemble std (σ)",
    ]

    for col in range(n_cols):
        var = variable_names[col]
        # Only the first three members are shown individually
        member1 = PlotConfig.convert_units(var, predictions_ens[0, col])
        member2 = PlotConfig.convert_units(var, predictions_ens[1, col])
        member3 = PlotConfig.convert_units(var, predictions_ens[2, col])
        mean_field = PlotConfig.convert_units(var, ensemble_mean[col])
        std_field = PlotConfig.convert_units(var, ensemble_std[col])
        rows_data = [member1, member2, member3, mean_field, std_field]

        # Image handles for the colorbars (main scale and spread scale)
        im_main = None
        im_spread = None
        for row in range(n_rows):
            ax = axes[row, col]
            if row == 4:
                # Spread row uses its own 0..max scale
                cmap = "Reds"
                vmin = 0
                vmax = np.nanmax(std_field)
                # vmax = np.quantile(std_field, 0.99)
            else:
                cmap = cmaps[col]
                vmin = vmin_list[col]
                vmax = vmax_list[col]
            im = ax.pcolormesh(
                lon,
                lat,
                rows_data[row],
                vmin=vmin,
                vmax=vmax,
                cmap=cmap,
                transform=ccrs.PlateCarree(),
                shading="auto",
            )
            ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
            ax.coastlines(linewidth=0.6)
            ax.add_feature(
                cfeature.BORDERS.with_scale("50m"),
                linewidth=0.9,
                linestyle="--",
                edgecolor="black",
                zorder=11,
            )
            ax.add_feature(
                cfeature.LAKES.with_scale("50m"),
                edgecolor="black",
                facecolor="none",
                linewidth=0.9,
                zorder=9,
            )
            ax.set_xticks([])
            ax.set_yticks([])
            if row == 4:
                im_spread = im
            else:
                if im_main is None:
                    im_main = im

        # Shared colorbar above the first member row
        ax_top = axes[0, col]
        cax_top = ax_top.inset_axes([0.1, 1.05, 0.8, 0.05])
        cbar = fig.colorbar(im_main, cax=cax_top, orientation="horizontal")
        cbar.set_label(plot_variable_names[col])
        cax_top.xaxis.set_ticks_position("top")
        cax_top.xaxis.set_label_position("top")

        # Spread colorbar below the bottom (std) row
        ax_bottom = axes[4, col]
        cax_bottom = ax_bottom.inset_axes([0.1, -0.12, 0.8, 0.05])
        fig.colorbar(
            im_spread,
            cax=cax_bottom,
            orientation="horizontal",
            label=f"Std {plot_variable_names[col]}",
        )

    # Row labels on the left side
    for r, label in enumerate(row_labels):
        axes[r, 0].text(
            -0.12,
            0.5,
            label,
            transform=axes[r, 0].transAxes,
            va="center",
            ha="right",
            rotation="vertical",
            fontsize=12,
        )

    # Report the forecast time on stdout (no suptitle drawn)
    if timestamp is not None:
        print(f"Ensemble predictions — {timestamp}")

    # Extra bottom room for the spread colorbars
    fig.subplots_adjust(
        top=0.90,
        bottom=0.25,
        left=0.10,
        right=0.95,
        wspace=0.1,
        hspace=0.15,
    )

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def plot_zoom_comparison(
    predictions,
    targets,
    lat_1d,
    lon_1d,
    variable_names=None,
    filename="zoom_plot.png",
    save_dir=None,
    zoom_box=None,
):
    """
    Plot a comparison between ground truth and model predictions with a
    geographic zoom: a global truth map with the zoom region outlined,
    zoomed truth and prediction panels, and a zoomed MAE panel.

    Parameters
    ----------
    predictions : torch.Tensor or np.ndarray
        Model predictions at target resolution with shape [1, n_vars, H, W].
    targets : torch.Tensor or np.ndarray
        Ground-truth high-resolution data with shape [1, n_vars, H, W].
    lat_1d : array-like
        1D array of latitude coordinates with shape [H].
    lon_1d : array-like
        1D array of longitude coordinates with shape [W].
    variable_names : list of str, optional
        Variable names or identifiers; defaults to VAR_0, VAR_1, ...
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot; defaults to PlotConfig.DEFAULT_SAVE_DIR.
    zoom_box : dict, optional
        Zoom region with keys "lat_min", "lat_max", "lon_min", "lon_max"
        (longitudes presumably in the 0-360 convention of the data grid —
        TODO confirm against the caller).

    Returns
    -------
    str
        Path of the saved figure.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if zoom_box is None:
        zoom_box = {"lat_min": -23, "lat_max": 13, "lon_min": 255, "lon_max": 345}

    # Convert tensors
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(lat_1d, "detach"):
        lat_1d = lat_1d.detach().cpu().numpy()
    if hasattr(lon_1d, "detach"):
        lon_1d = lon_1d.detach().cpu().numpy()

    lat = lat_1d
    lon = lon_1d
    lon2d, lat2d = np.meshgrid(lon, lat)
    lat_min, lat_max = lat.min(), lat.max()
    lon_min, lon_max = lon.min(), lon.max()
    lon_center = float((lon_min + lon_max) / 2)

    # Boolean masks selecting the rows/columns inside the zoom box
    lat_mask = (lat >= zoom_box["lat_min"]) & (lat <= zoom_box["lat_max"])
    lon_mask = (lon >= zoom_box["lon_min"]) & (lon <= zoom_box["lon_max"])
    lat_zoom = lat[lat_mask]
    lon_zoom = lon[lon_mask]
    lon_zoom2d, lat_zoom2d = np.meshgrid(lon_zoom, lat_zoom)

    n_vars = targets.shape[1]
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(v) for v in variable_names]
    cmaps = [PlotConfig.get_colormap(v) for v in variable_names]

    # Row 0 uses a recentred global projection; zoom rows use plain PlateCarree
    proj_global = ccrs.PlateCarree(central_longitude=lon_center)
    proj_zoom = ccrs.PlateCarree()

    base_width_per_panel = 4.5
    base_height_per_panel = 2.5
    fig = plt.figure(figsize=(base_width_per_panel * n_vars, base_height_per_panel * 4))

    # Manual axes layout (figure-fraction coordinates) instead of subplots,
    # so per-row projections and tight spacing can be controlled exactly
    left_margin = 0.08
    right_margin = 0.02
    bottom_margin = 0.08
    top_margin = 0.06
    hspace = 0.002
    wspace = 0.008
    total_width = 1 - left_margin - right_margin
    total_height = 1 - bottom_margin - top_margin
    col_width = total_width / n_vars
    row_height = total_height / 4

    axes = np.empty((4, n_vars), dtype=object)
    for row in range(4):
        for col in range(n_vars):
            proj = proj_global if row == 0 else proj_zoom
            x0 = left_margin + col * col_width + wspace / 2
            y0 = 1 - top_margin - (row + 1) * row_height + hspace / 2
            w = col_width - wspace
            h = row_height - hspace
            axes[row, col] = fig.add_axes([x0, y0, w, h], projection=proj)

    coastline = cfeature.COASTLINE.with_scale("50m")
    borders = cfeature.BORDERS.with_scale("50m")

    for col in range(n_vars):
        var = variable_names[col]
        truth = PlotConfig.convert_units(var, targets[0, col])
        pred = PlotConfig.convert_units(var, predictions[0, col])
        mae = np.abs(pred - truth)
        cmap = cmaps[col]

        # Robust color limits from 2%/98% quantiles over truth + prediction
        all_data = np.concatenate([truth.flatten(), pred.flatten()])
        all_data = all_data[~np.isnan(all_data)]
        vmin, vmax = np.quantile(all_data, [0.02, 0.98])
        mae_vmax = np.quantile(mae[~np.isnan(mae)], 0.98)

        # Sub-arrays restricted to the zoom box
        truth_zoom = truth[np.ix_(lat_mask, lon_mask)]
        pred_zoom = pred[np.ix_(lat_mask, lon_mask)]
        mae_zoom = mae[np.ix_(lat_mask, lon_mask)]

        # ---- Row 0 Truth global ----
        ax = axes[0, col]
        im = ax.pcolormesh(
            lon2d,
            lat2d,
            truth,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        ax.set_extent([lon_min, lon_max, lat_min, lat_max])
        ax.add_feature(coastline, linewidth=0.6)
        ax.add_feature(borders, linewidth=0.5)

        # Red rectangle marking the zoom region on the global map
        rect = patches.Rectangle(
            (zoom_box["lon_min"], zoom_box["lat_min"]),
            zoom_box["lon_max"] - zoom_box["lon_min"],
            zoom_box["lat_max"] - zoom_box["lat_min"],
            linewidth=2,
            edgecolor="red",
            facecolor="none",
            transform=ccrs.PlateCarree(),
        )
        ax.add_patch(rect)

        # Connector lines from the rectangle corners down to the zoom panel.
        # NOTE(review): _as_mpl_transform is a private cartopy API used to
        # express geographic points in the ConnectionPatch coordinates.
        zoom_ax = axes[1, col]
        fig.add_artist(
            ConnectionPatch(
                xyA=(zoom_box["lon_min"], zoom_box["lat_max"]),
                coordsA=ccrs.PlateCarree()._as_mpl_transform(ax),
                xyB=(0, 1),
                coordsB=zoom_ax.transAxes,
                color="red",
                linewidth=1.5,
            )
        )
        fig.add_artist(
            ConnectionPatch(
                xyA=(zoom_box["lon_max"], zoom_box["lat_max"]),
                coordsA=ccrs.PlateCarree()._as_mpl_transform(ax),
                xyB=(1, 1),
                coordsB=zoom_ax.transAxes,
                color="red",
                linewidth=1.5,
            )
        )
        im_global = im

        # ---- Row 1 Truth zoom ----
        ax = axes[1, col]
        ax.pcolormesh(
            lon_zoom2d,
            lat_zoom2d,
            truth_zoom,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        """
        ax.set_extent(
            [
                zoom_box["lon_min"],
                zoom_box["lon_max"],
                zoom_box["lat_min"],
                zoom_box["lat_max"],
            ]
        )
        """
        # Extent from the actual selected grid points, not the nominal box
        ax.set_extent(
            [
                lon_zoom.min(),
                lon_zoom.max(),
                lat_zoom.min(),
                lat_zoom.max(),
            ],
            crs=ccrs.PlateCarree(),
        )
        ax.add_feature(coastline, linewidth=0.6)
        ax.add_feature(borders, linewidth=0.5)

        # ---- Row 2 Prediction zoom ----
        ax = axes[2, col]
        ax.pcolormesh(
            lon_zoom2d,
            lat_zoom2d,
            pred_zoom,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        """
        ax.set_extent(
            [
                zoom_box["lon_min"],
                zoom_box["lon_max"],
                zoom_box["lat_min"],
                zoom_box["lat_max"],
            ]
        )
        """
        ax.set_extent(
            [
                lon_zoom.min(),
                lon_zoom.max(),
                lat_zoom.min(),
                lat_zoom.max(),
            ],
            crs=ccrs.PlateCarree(),
        )
        ax.add_feature(coastline, linewidth=0.6)
        ax.add_feature(borders, linewidth=0.5)

        # ---- Row 3 MAE ----
        ax = axes[3, col]
        im_mae = ax.pcolormesh(
            lon_zoom2d,
            lat_zoom2d,
            mae_zoom,
            cmap="Reds",
            vmin=0,
            vmax=mae_vmax,
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        """
        ax.set_extent(
            [
                zoom_box["lon_min"],
                zoom_box["lon_max"],
                zoom_box["lat_min"],
                zoom_box["lat_max"],
            ]
        )
        """
        ax.set_extent(
            [
                lon_zoom.min(),
                lon_zoom.max(),
                lat_zoom.min(),
                lat_zoom.max(),
            ],
            crs=ccrs.PlateCarree(),
        )
        ax.add_feature(coastline, linewidth=0.6)
        ax.add_feature(borders, linewidth=0.5)

        # Colorbars
        cax = ax.inset_axes([0.15, -0.25, 0.7, 0.05])
        fig.colorbar(im_mae, cax=cax, orientation="horizontal").set_label(
            f"MAE {plot_variable_names[col]}"
        )
        ax_top = axes[0, col]
        cax_top = ax_top.inset_axes([0.1, 1.05, 0.8, 0.05])
        cbar = fig.colorbar(im_global, cax=cax_top, orientation="horizontal")
        cbar.set_label(plot_variable_names[col])
        cax_top.xaxis.set_ticks_position("top")
        cax_top.xaxis.set_label_position("top")

        for r in range(4):
            axes[r, col].set_xticks([])
            axes[r, col].set_yticks([])

    # Row labels
    row_labels = ["Truth", "Truth (Zoom)", "Prediction (Zoom)", "MAE"]
    for i, label in enumerate(row_labels):
        axes[i, 0].text(
            -0.15,
            0.5,
            label,
            transform=axes[i, 0].transAxes,
            rotation=90,
            va="center",
            ha="right",
        )

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close()
    return save_path
def plot_global_surface_robinson(
    predictions,
    targets,
    coarse_inputs,
    lat_1d,
    lon_1d,
    timestamp=None,
    variable_names=None,
    filename="global_robinson.png",
    save_dir=None,
    figsize_multiplier=None,
):
    """
    Plot coarse, truth, prediction and difference fields in Robinson
    projection (global view). Unlike plot_surface, fields are plotted in
    their raw units (no PlotConfig.convert_units call).

    Parameters
    ----------
    predictions : torch.Tensor or np.ndarray
        Model predictions at target resolution with shape [1, n_vars, H, W].
    targets : torch.Tensor or np.ndarray
        Ground-truth high-resolution data with shape [1, n_vars, H, W].
    coarse_inputs : torch.Tensor or np.ndarray
        Coarse-resolution input data with shape [1, n_vars, H, W].
    lat_1d : array-like
        1D array of latitude coordinates with shape [H].
    lon_1d : array-like
        1D array of longitude coordinates with shape [W].
    timestamp : datetime.datetime, optional
        Forecast timestamp shown in the figure suptitle.
    variable_names : list of str, optional
        Variable names or identifiers; defaults to VAR_0, VAR_1, ...
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot; defaults to PlotConfig.DEFAULT_SAVE_DIR.
    figsize_multiplier : int, optional
        Base size multiplier for subplots (kept for interface
        compatibility; panel sizes are currently fixed).

    Returns
    -------
    str
        Path of the saved figure.

    Raises
    ------
    ValueError
        If targets or predictions do not have the same number of
        variables as coarse_inputs.
    """
    # Use defaults from config if not provided
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if figsize_multiplier is None:
        figsize_multiplier = PlotConfig.DEFAULT_FIGSIZE_MULTIPLIER

    # Convert tensors to numpy if needed
    if hasattr(coarse_inputs, "detach"):
        coarse_inputs = coarse_inputs.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(lat_1d, "detach"):
        lat_1d = lat_1d.detach().cpu().numpy()
    if hasattr(lon_1d, "detach"):
        lon_1d = lon_1d.detach().cpu().numpy()

    # Create 2D meshgrid from 1D coordinates
    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    h, w = coarse_inputs[0, 0].shape
    # Latitudes descend (north -> south) to match image row ordering
    lat_block = np.linspace(lat_max, lat_min, h)
    lon_block = np.linspace(lon_min, lon_max, w)
    lat2d, lon2d = np.meshgrid(lat_block, lon_block, indexing="ij")
    # Normalize longitudes to [-180, 180) for the global projection
    lon2d = ((lon2d + 180) % 360) - 180

    # Check data dimensions
    n_vars = coarse_inputs.shape[1]
    if targets.shape[1] != n_vars:
        raise ValueError(
            f"targets data has {targets.shape[1]} variables but coarse_inputs has {n_vars}"
        )
    if predictions.shape[1] != n_vars:
        raise ValueError(
            f"predictions data has {predictions.shape[1]} variables but coarse_inputs has {n_vars}"
        )

    # Default variable names if not provided
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]

    # Derive plot names and colormaps
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]
    cmaps = [PlotConfig.get_colormap(var) for var in variable_names]

    # Color limits shared by the coarse / truth / prediction rows
    vmin_list = []
    vmax_list = []
    # Color limits for the signed-difference row
    diff_vmin_list = []
    diff_vmax_list = []
    for i in range(n_vars):
        # Combined data range for this variable (coarse, truth, prediction)
        all_data = np.concatenate(
            [
                coarse_inputs[0, i].flatten(),
                targets[0, i].flatten(),
                predictions[0, i].flatten(),
            ]
        )
        # Robust limits via 2%/98% quantiles, ignoring NaNs
        all_data_flat = all_data[~np.isnan(all_data)]
        if len(all_data_flat) > 0:
            q_low, q_high = np.quantile(all_data_flat, [0.02, 0.98])
            vmin, vmax = float(q_low), float(q_high)
        else:
            vmin, vmax = -1, 1
        # Ensure vmin < vmax (quantiles can collapse on near-constant fields)
        if vmin >= vmax:
            vmin, vmax = float(np.nanmin(all_data)), float(np.nanmax(all_data))
        vmin_list.append(vmin)
        vmax_list.append(vmax)

        # Signed difference between prediction and truth: symmetric range
        diff_data = (predictions[0, i] - targets[0, i]).flatten()
        diff_data = diff_data[~np.isnan(diff_data)]
        if len(diff_data) > 0:
            max_abs_diff = np.max(np.abs(diff_data))
            diff_vmin = -max_abs_diff * 1.1  # add 10% padding
            diff_vmax = max_abs_diff * 1.1  # add 10% padding
            # Guard against an all-zero / near-zero difference field
            if diff_vmax <= 0.001:
                diff_vmin, diff_vmax = -0.1, 0.1
        else:
            diff_vmin, diff_vmax = -1, 1
        diff_vmin_list.append(diff_vmin)
        diff_vmax_list.append(diff_vmax)

    # Set up figure: rows = coarse / truth / prediction / difference
    fig, axes = plt.subplots(
        4,
        n_vars,
        figsize=(4.5 * n_vars, 3.2 * 4),
        subplot_kw={"projection": ccrs.Robinson()},
        gridspec_kw={"hspace": 0.12, "wspace": 0.05},
    )
    # With a single variable, subplots returns a 1D axes array
    if n_vars == 1:
        axes = axes.reshape(4, 1)

    row_labels = ["Coarse", "Truth", "Prediction", "Pred − Truth"]

    # Plot each variable as one column
    for col in range(n_vars):
        coarse = coarse_inputs[0, col]
        truth = targets[0, col]
        pred = predictions[0, col]
        diff = pred - truth

        data_rows = [coarse, truth, pred, diff]
        vmins = [vmin_list[col]] * 3 + [diff_vmin_list[col]]
        vmaxs = [vmax_list[col]] * 3 + [diff_vmax_list[col]]
        cmaps_row = [cmaps[col]] * 3 + ["RdBu_r"]

        for row in range(4):
            ax = axes[row, col]
            ax.set_global()
            im = ax.pcolormesh(
                lon2d,
                lat2d,
                data_rows[row],
                transform=ccrs.PlateCarree(),
                cmap=cmaps_row[row],
                vmin=vmins[row],
                vmax=vmaxs[row],
                shading="auto",
            )
            ax.coastlines(linewidth=0.9)
            ax.add_feature(
                cfeature.BORDERS.with_scale("50m"),
                linewidth=0.9,
                linestyle="--",
                edgecolor="black",
                zorder=11,
            )
            ax.add_feature(
                cfeature.LAKES.with_scale("50m"),
                edgecolor="black",
                facecolor="none",
                linewidth=0.9,
                zorder=9,
            )
            ax.set_xticks([])
            ax.set_yticks([])

            # Row label on the left-most column
            if col == 0:
                ax.text(
                    -0.08,
                    0.5,
                    row_labels[row],
                    transform=ax.transAxes,
                    va="center",
                    ha="right",
                    rotation=90,
                    fontsize=12,
                )

            # Colorbars: shared scale above row 0, difference scale below row 3
            if row == 0:
                cax = ax.inset_axes([0.1, 1.08, 0.8, 0.05])
                cb = fig.colorbar(im, cax=cax, orientation="horizontal")
                cb.set_label(plot_variable_names[col])
                cax.xaxis.set_ticks_position("top")
                cax.xaxis.set_label_position("top")
            if row == 3:
                cax = ax.inset_axes([0.1, -0.12, 0.8, 0.05])
                fig.colorbar(
                    im,
                    cax=cax,
                    orientation="horizontal",
                    # BUGFIX: original read `label=f{...}` (missing opening
                    # quote), which is a syntax error
                    label=f"{plot_variable_names[col]} (Pred - Truth)",
                )

    if timestamp is not None:
        fig.suptitle(
            f"Global Robinson diagnostic – {timestamp.strftime('%Y-%m-%d %H:%M')}",
            fontsize=16,
            y=0.96,
        )

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def plot_MAE_map(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    lat_1d,
    lon_1d,
    timestamp=None,
    variable_names=None,
    filename="validation_mae_map.png",
    save_dir=None,
    figsize_multiplier=None,  # Base size per subplot
):
    """
    Plot spatial MAE maps averaged over all time steps:

        MAE(x, y) = mean_t(abs(prediction - target))

    Parameters
    ----------
    predictions : torch.Tensor or np.ndarray
        Model predictions of shape [batch_size, num_variables, h, w].
    targets : torch.Tensor or np.ndarray
        Ground truth of shape [batch_size, num_variables, h, w].
    lat_1d : array-like
        1D array of latitude coordinates with shape [H].
    lon_1d : array-like
        1D array of longitude coordinates with shape [W].
    timestamp : datetime.datetime, optional
        Forecast timestamp included in the figure suptitle.
    variable_names : list of str, optional
        Variable names or identifiers; defaults to VAR_0, VAR_1, ...
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot; defaults to PlotConfig.DEFAULT_SAVE_DIR.
    figsize_multiplier : int, optional
        Base size multiplier for subplots (kept for interface
        compatibility; panel sizes are currently fixed).

    Returns
    -------
    str
        Path of the saved figure.

    Raises
    ------
    ValueError
        If targets and predictions disagree on the number of variables.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if figsize_multiplier is None:
        figsize_multiplier = PlotConfig.DEFAULT_FIGSIZE_MULTIPLIER

    # Convert tensors to numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(lat_1d, "detach"):
        lat_1d = lat_1d.detach().cpu().numpy()
    if hasattr(lon_1d, "detach"):
        lon_1d = lon_1d.detach().cpu().numpy()

    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    T, n_vars, h, w = predictions.shape
    # Latitudes descend (north -> south) to match image row ordering
    lat_block = np.linspace(lat_max, lat_min, h)
    lon_block = np.linspace(lon_min, lon_max, w)
    lat, lon = np.meshgrid(lat_block, lon_block, indexing="ij")
    lon_center = float((lon_min + lon_max) / 2)

    if targets.shape[1] != n_vars:
        raise ValueError("targets and predictions must have same number of variables")

    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]
    # cmaps = [PlotConfig.get_colormap(var) for var in variable_names]
    # A single shared MAE colormap for every panel
    cmaps = PlotConfig.get_colormap("mae")

    vmin_list, vmax_list = [], []
    # MAE averaged over time for color scaling
    for i in range(n_vars):
        # mae_data = np.mean(np.abs(predictions[:, i] - targets[:, i]), axis=0)
        pred_i = PlotConfig.convert_units(variable_names[i], predictions[:, i])
        tgt_i = PlotConfig.convert_units(variable_names[i], targets[:, i])
        mae_data = np.mean(np.abs(pred_i - tgt_i), axis=0)
        mae_flat = mae_data.flatten()
        mae_flat = mae_flat[~np.isnan(mae_flat)]
        # Prefer a per-variable fixed range from config when available
        fixed_range = PlotConfig.get_fixed_mae_range(variable_names[i])
        if fixed_range is not None:
            vmin, vmax = fixed_range
        else:
            if len(mae_flat) > 0:
                q_low, q_high = np.quantile(mae_flat, [0.02, 0.98])
                vmin, vmax = float(q_low), float(q_high)
            else:
                vmin, vmax = 0.0, 1.0
            # Guard against degenerate limits on near-constant error fields
            if vmin >= vmax:
                vmin, vmax = float(np.nanmin(mae_flat)), float(np.nanmax(mae_flat))
        vmin_list.append(vmin)
        vmax_list.append(vmax)

    base_width_per_panel = 4.5
    base_height_per_panel = 3.0
    fig_width = base_width_per_panel * n_vars
    fig_height = base_height_per_panel
    fig, axes = plt.subplots(
        1,
        n_vars,
        figsize=(fig_width, fig_height),
        subplot_kw={
            "projection": ccrs.PlateCarree(central_longitude=lon_center)
        },  # ccrs.Mercator(central_longitude=lon_center)
        gridspec_kw={"wspace": 0.1},
    )
    # With a single variable, subplots returns a bare Axes object
    if n_vars == 1:
        axes = [axes]

    if timestamp is not None:
        fig.suptitle(
            f"MAE Map (time-averaged) - {timestamp.strftime('%Y-%m-%d %H:%M')}",
            fontsize=16,
            y=1.05,
        )

    for col_idx in range(n_vars):
        ax = axes[col_idx]
        pred_i = PlotConfig.convert_units(
            variable_names[col_idx], predictions[:, col_idx]
        )
        tgt_i = PlotConfig.convert_units(variable_names[col_idx], targets[:, col_idx])
        # MAE averaged over all time steps
        mae_data = np.mean(np.abs(pred_i - tgt_i), axis=0)
        im = ax.pcolormesh(
            lon,
            lat,
            mae_data,
            vmin=vmin_list[col_idx],
            vmax=vmax_list[col_idx],
            cmap=cmaps,
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
        # ax.set_global()
        ax.coastlines(linewidth=0.6)
        ax.add_feature(
            cfeature.BORDERS.with_scale("50m"),
            linewidth=0.6,
            linestyle="--",
            edgecolor="black",
            zorder=11,
        )
        ax.add_feature(
            cfeature.LAKES.with_scale("50m"),
            edgecolor="black",
            facecolor="none",
            linewidth=0.6,
            zorder=9,
        )
        # ax.set_aspect("auto")
        ax.set_xticks([])
        ax.set_yticks([])

        # Horizontal colorbar below each panel
        cax = ax.inset_axes([0.1, -0.15, 0.8, 0.05])
        fig.colorbar(
            im,
            cax=cax,
            orientation="horizontal",
            label=f"MAE {plot_variable_names[col_idx]}",
        )

    fig.subplots_adjust(top=0.85, bottom=0.25, left=0.08, right=0.95)
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def plot_error_map(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    lat_1d,
    lon_1d,
    timestamp=None,
    variable_names=None,
    filename="validation_error_map.png",
    save_dir=None,
    figsize_multiplier=None,
):
    """
    Plot spatial error maps averaged over all time steps: absolute error
    for precipitation-like variables, relative error for all others.

    Parameters
    ----------
    predictions : torch.Tensor or np.ndarray
        Model predictions of shape [batch_size, num_variables, h, w].
    targets : torch.Tensor or np.ndarray
        Ground truth of shape [batch_size, num_variables, h, w].
    lat_1d : array-like
        1D array of latitude coordinates with shape [H].
    lon_1d : array-like
        1D array of longitude coordinates with shape [W].
    timestamp : datetime.datetime, optional
        Forecast timestamp (currently unused in this function).
    variable_names : list of str, optional
        Variable names or identifiers; defaults to VAR_0, VAR_1, ...
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot; defaults to PlotConfig.DEFAULT_SAVE_DIR.
    figsize_multiplier : int, optional
        Base size multiplier for subplots (kept for interface
        compatibility; panel sizes are currently fixed).

    Returns
    -------
    str
        Path of the saved figure.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if figsize_multiplier is None:
        figsize_multiplier = PlotConfig.DEFAULT_FIGSIZE_MULTIPLIER

    # Convert tensors to numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(lat_1d, "detach"):
        lat_1d = lat_1d.detach().cpu().numpy()
    if hasattr(lon_1d, "detach"):
        lon_1d = lon_1d.detach().cpu().numpy()

    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    T, n_vars, h, w = predictions.shape
    # Latitudes descend (north -> south) to match image row ordering
    lat_block = np.linspace(lat_max, lat_min, h)
    lon_block = np.linspace(lon_min, lon_max, w)
    lat, lon = np.meshgrid(lat_block, lon_block, indexing="ij")
    lon_center = float((lon_min + lon_max) / 2)

    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(v) for v in variable_names]
    # A single shared error colormap for every panel
    cmaps = PlotConfig.get_colormap("error")

    vmin_list, vmax_list = [], []
    # Small denominator guard for the relative-error division
    eps = 1e-6

    # Compute time-averaged error for scaling
    for i in range(n_vars):
        var = variable_names[i]
        pred_i = PlotConfig.convert_units(var, predictions[:, i])
        tgt_i = PlotConfig.convert_units(var, targets[:, i])
        # Precipitation uses absolute error (relative error is unstable
        # around zero rain); other variables use relative error
        if var.lower() in ["var_tp", "precip", "precipitation"]:
            err = np.mean(np.abs(pred_i - tgt_i), axis=0)
        else:
            err = np.mean(np.abs(pred_i - tgt_i) / (np.abs(tgt_i) + eps), axis=0)
        err_flat = err.flatten()
        err_flat = err_flat[~np.isnan(err_flat)]
        # Prefer a per-variable fixed range from config when available
        fixed_range = PlotConfig.get_fixed_diff_range_errors(var)
        if fixed_range is not None:
            vmin, vmax = fixed_range
        else:
            if len(err_flat) > 0:
                vmax = np.max(err_flat)
                vmin = 0
                vmax = 1.1 * vmax  # 10% headroom above the observed maximum
            else:
                vmin, vmax = 0, 1
        vmin_list.append(vmin)
        vmax_list.append(vmax)

    base_w, base_h = 4.5, 3.0
    fig, axes = plt.subplots(
        1,
        n_vars,
        figsize=(base_w * n_vars, base_h),
        subplot_kw={"projection": ccrs.PlateCarree(central_longitude=lon_center)},
        gridspec_kw={"wspace": 0.1},
    )
    # With a single variable, subplots returns a bare Axes object
    if n_vars == 1:
        axes = [axes]

    for i in range(n_vars):
        ax = axes[i]
        var = variable_names[i]
        pred_i = PlotConfig.convert_units(var, predictions[:, i])
        tgt_i = PlotConfig.convert_units(var, targets[:, i])
        # Same error definition as in the scaling pass above
        if var.lower() in ["var_tp", "precip", "precipitation"]:
            err_map = np.mean(np.abs(pred_i - tgt_i), axis=0)
            label = "Absolute Error (mm/h)"
        else:
            err_map = np.mean(np.abs(pred_i - tgt_i) / (np.abs(tgt_i) + eps), axis=0)
            label = "Relative Error"
        im = ax.pcolormesh(
            lon,
            lat,
            err_map,
            vmin=vmin_list[i],
            vmax=vmax_list[i],
            cmap=cmaps,
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
        ax.coastlines(linewidth=0.6)
        ax.add_feature(
            cfeature.BORDERS.with_scale("50m"),
            linewidth=0.6,
            linestyle="--",
            edgecolor="black",
            zorder=11,
        )
        ax.add_feature(
            cfeature.LAKES.with_scale("50m"),
            edgecolor="black",
            facecolor="none",
            linewidth=0.6,
            zorder=9,
        )
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(plot_variable_names[i])

        # Horizontal colorbar below each panel
        cax = ax.inset_axes([0.1, -0.15, 0.8, 0.05])
        fig.colorbar(
            im,
            cax=cax,
            orientation="horizontal",
            label=f"{label}",
        )

    fig.subplots_adjust(top=0.85, bottom=0.25, left=0.08, right=0.95)
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def spread_skill_ratio(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    variable_names,
    pixel_wise=False,
):
    """
    Compute the spread-skill ratio (SSR) of ensemble predictions w.r.t. targets.

    Implements equation (15) in "Why Should Ensemble Spread Match the RMSE of
    the Ensemble Mean?", Fortin et al.: the ensemble spread, inflated by
    sqrt((E + 1) / E) for ensemble size E, divided by the RMSE of the
    ensemble mean.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [ensemble_size, batch_size, num_variables, h, w].
        It is very important not to switch dimensions order. ensemble_size must
        be greater or equal than 2 for the spread skill ratio to be computed.
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w].
    variable_names : list of str, optional
        Variable names or identifiers; defaults to "VAR_i" when None.
    pixel_wise : bool
        If True, computes and returns the SSR for each pixel independently.
        If False, returns the SSR averaged over all pixels and all timesteps.
        Defaults to False.

    Returns
    -------
    np.array
        Shape [num_variables, h, w] if pixel_wise is True,
        shape [num_variables,] otherwise.

    Raises
    ------
    ValueError
        If predictions is not 5-dimensional, has a single ensemble member,
        or disagrees with targets on the number of variables.
    """
    # Move any torch tensors to CPU numpy arrays first.
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()

    if len(predictions.shape) != 5:
        raise ValueError(
            "predictions needs to be 5 dimensional tensor / array [ensemble_size, temporal_size, n_vars, h, w]."
        )
    if predictions.shape[0] == 1:
        raise ValueError(
            "predictions needs to contain more than 1 member to compute spread skill ratio."
        )

    E, T, n_vars, h, w = predictions.shape
    if targets.shape[1] != n_vars:
        raise ValueError("targets and predictions must have same number of variables")
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]

    # Fortin et al. eq. (15): finite-ensemble spread is corrected by
    # sqrt((E + 1) / E) before being compared to the RMSE of the ensemble mean.
    inflation = np.sqrt((E + 1) / E)

    ratios = []
    for idx in range(n_vars):
        member_vals = PlotConfig.convert_units(
            variable_names[idx], predictions[:, :, idx]
        )  # [E,T,h,w]
        truth_vals = PlotConfig.convert_units(
            variable_names[idx], targets[:, idx]
        )  # [T,h,w]
        ens_mean = np.mean(member_vals, axis=0)  # ensemble mean [T,h,w]

        if pixel_wise:
            # Average over time only -> one value per pixel [h,w].
            skill = np.sqrt(np.mean((ens_mean - truth_vals) ** 2, axis=0))
            spread = np.sqrt(np.mean(np.var(member_vals, axis=0), axis=0)) * inflation
        else:
            # Average over time and space -> a single scalar per variable.
            skill = np.sqrt(np.mean((ens_mean - truth_vals) ** 2))
            spread = np.sqrt(np.mean(np.var(member_vals, axis=0))) * inflation

        ratios.append(np.divide(spread, skill))
    return np.array(ratios)
def plot_spread_skill_ratio_map(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    lat_1d,
    lon_1d,
    timestamp=None,
    variable_names=None,
    filename="validation_spread_skill_ratio_map.png",
    save_dir=None,
    figsize_multiplier=None,  # Base size per subplot
):
    """
    Plot spatial spread skill ratio maps averaged over all time steps for
    each individual pixel.

    The formula implemented is equation (15) in article "Why Should Ensemble
    Spread Match the RMSE of the Ensemble Mean?", Fortin et al.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [ensemble_size, batch_size, num_variables, h, w]
        It is very important not to switch dimensions order.
        ensemble_size must be greater or equal than 2 for spread skill ratio
        to be computed.
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    lat_1d : array-like
        1D array of latitude coordinates with shape [H].
    lon_1d : array-like
        1D array of longitude coordinates with shape [W].
    timestamp : datetime.datetime
        Forecast timestamp to include in the plot title.
        NOTE(review): currently unused — the suptitle code that consumed it
        is commented out below.
    variable_names : list of str, optional
        Variable names or identifiers.
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot.
    figsize_multiplier : int, optional
        Base size multiplier for subplots.

    Returns
    -------
    str
        Path of the saved figure.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if figsize_multiplier is None:
        figsize_multiplier = PlotConfig.DEFAULT_FIGSIZE_MULTIPLIER
    # Convert tensors to numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(lat_1d, "detach"):
        lat_1d = lat_1d.detach().cpu().numpy()
    if hasattr(lon_1d, "detach"):
        lon_1d = lon_1d.detach().cpu().numpy()
    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    if len(predictions.shape) != 5:
        raise ValueError(
            "predictions needs to be 5 dimensional tensor / array [ensemble_size, temporal_size, n_vars, h, w]."
        )
    if predictions.shape[0] == 1:
        raise ValueError(
            "predictions needs to contain more than 1 member to compute spread skill ratio."
        )
    E, T, n_vars, h, w = predictions.shape
    # Build a plotting grid spanning the data extent; latitudes run from max
    # to min so that row 0 of the fields maps to the northernmost latitude.
    lat_block = np.linspace(lat_max, lat_min, h)
    lon_block = np.linspace(lon_min, lon_max, w)
    lat, lon = np.meshgrid(lat_block, lon_block, indexing="ij")
    lon_center = float((lon_min + lon_max) / 2)
    if targets.shape[1] != n_vars:
        raise ValueError("targets and predictions must have same number of variables")
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]
    # cmaps = [PlotConfig.get_colormap(var) for var in variable_names]
    cmaps = PlotConfig.get_colormap("SSR")
    vmin_list, vmax_list = [], []
    # Per-pixel SSR maps and the scalar (time+space averaged) SSR per variable.
    ssr_list = spread_skill_ratio(predictions, targets, variable_names, pixel_wise=True)
    ssr_mean_list = spread_skill_ratio(
        predictions, targets, variable_names, pixel_wise=False
    )
    # Determine a color range per variable: fixed range if configured,
    # otherwise the 2%-98% quantiles of the SSR map.
    for i in range(n_vars):
        ssr_i = ssr_list[i]
        ssr_i_flat = (ssr_i).flatten()  # flatten [h*w]
        fixed_range = PlotConfig.get_fixed_ssr_range(variable_names[i])
        if fixed_range is not None:
            vmin, vmax = fixed_range
        else:
            if len(ssr_i_flat) > 0:
                q_low, q_high = np.quantile(ssr_i_flat, [0.02, 0.98])
                vmin, vmax = float(q_low), float(q_high)
            else:
                vmin, vmax = 0.0, 2.0
            # Guard against a degenerate quantile range (near-constant map).
            if vmin >= vmax:
                vmin, vmax = float(np.nanmin(ssr_i_flat)), float(np.nanmax(ssr_i_flat))
        vmin_list.append(vmin)
        vmax_list.append(vmax)
    base_width_per_panel = 4.5
    base_height_per_panel = 3.0
    fig_width = base_width_per_panel * n_vars
    fig_height = base_height_per_panel
    fig, axes = plt.subplots(
        1,
        n_vars,
        figsize=(fig_width, fig_height),
        subplot_kw={
            "projection": ccrs.PlateCarree(central_longitude=lon_center)
        },  # ccrs.Mercator(central_longitude=lon_center)
        gridspec_kw={"wspace": 0.1},
    )
    if n_vars == 1:
        axes = [axes]
    # if timestamp is None:
    #     fig.suptitle(
    #         "Spread skill ratio map (time-averaged)",
    #         y=1.05,
    #     )
    # if timestamp is not None:
    #     fig.suptitle(
    #         f"Spread skill ratio map {timestamp.strftime('%Y-%m-%d %H:%M')}",
    #         y=1.05,
    #     )
    for col_idx in range(n_vars):
        ax = axes[col_idx]
        # Per-pixel SSR map and its scalar summary for this variable.
        ssr_data_i = ssr_list[col_idx]
        ssr_data_i_mean = ssr_mean_list[col_idx]
        vmin = vmin_list[col_idx]
        vmax = vmax_list[col_idx]
        # A diverging norm centered on SSR = 1 (perfect spread-skill match);
        # fall back to forced [0, vmax] or [0, 2] ranges when 1 lies outside
        # the data range, since TwoSlopeNorm requires vmin < vcenter < vmax.
        if vmin <= 1 <= vmax:
            norm = mcolors.TwoSlopeNorm(vmin=vmin, vcenter=1, vmax=vmax)
            im = ax.pcolormesh(
                lon,
                lat,
                ssr_data_i,
                cmap=cmaps,
                norm=norm,
                transform=ccrs.PlateCarree(),
                shading="auto",
            )
        elif vmax >= 1:
            norm = mcolors.TwoSlopeNorm(vmin=0, vcenter=1, vmax=vmax)
            im = ax.pcolormesh(
                lon,
                lat,
                ssr_data_i,
                cmap=cmaps,
                norm=norm,
                transform=ccrs.PlateCarree(),
                shading="auto",
            )
        else:
            norm = mcolors.TwoSlopeNorm(vmin=0, vcenter=1, vmax=2)
            im = ax.pcolormesh(
                lon,
                lat,
                ssr_data_i,
                cmap=cmaps,
                norm=norm,
                transform=ccrs.PlateCarree(),
                shading="auto",
            )
        ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
        # ax.set_global()
        ax.coastlines(linewidth=0.6)
        ax.add_feature(
            cfeature.BORDERS.with_scale("50m"),
            linewidth=0.6,
            linestyle="--",
            edgecolor="black",
            zorder=11,
        )
        ax.add_feature(
            cfeature.LAKES.with_scale("50m"),
            edgecolor="black",
            facecolor="none",
            linewidth=0.6,
            zorder=9,
        )
        # ax.set_aspect("auto")
        ax.set_xticks([])
        ax.set_yticks([])
        # props = dict(boxstyle="round", facecolor="wheat", alpha=0.5)
        # place a text box in upper left in axes coords
        # ax.text(
        #     0.05,
        #     1.15,
        #     f"SSR = {np.mean(ssr_data_i):.2f}",
        #     transform=ax.transAxes,
        #     verticalalignment="top",
        #     bbox=props,
        # )
        ax.set_title(
            f"{plot_variable_names[col_idx]} (SSR={ssr_data_i_mean:.2f})",
            pad=10,
        )
        # Horizontal colorbar placed just below each panel.
        cax = ax.inset_axes([0.1, -0.15, 0.8, 0.05])
        fig.colorbar(
            im,
            cax=cax,
            orientation="horizontal",
            label=f"SSR {plot_variable_names[col_idx]}",
        )
    fig.subplots_adjust(top=0.85, bottom=0.25, left=0.08, right=0.95)
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def plot_spread_skill_ratio_hexbin(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    variable_names=None,
    filename="validation_spread_skill_ratio_hexbin.png",
    save_dir=None,
    figsize_multiplier=None,  # Base size per subplot
):
    """
    Plot spatial spread skill ratio scatterplot, where each point represents
    a prediction for a single pixel, single timestep:
    SSR(x, y) = spread(x,y) / skill(x,y) where
    spread(x,y) = temporal mean of standard deviation of ensemble members
    predictions and skill = temporal mean of RMSE of the mean of the ensemble
    members.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [ensemble_size, batch_size, num_variables, h, w]
        It is very important not to switch dimensions order.
        ensemble_size must be greater or equal than 2 for spread skill ratio
        to be computed.
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Variable names or identifiers.
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot.
    figsize_multiplier : int, optional
        Base size multiplier for subplots.

    Returns
    -------
    str
        Path of the saved figure.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if figsize_multiplier is None:
        figsize_multiplier = PlotConfig.DEFAULT_FIGSIZE_MULTIPLIER
    # Convert tensors to numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if len(predictions.shape) != 5:
        raise ValueError(
            "predictions needs to be 5 dimensional tensor / array [ensemble_size, temporal_size, n_vars, h, w]."
        )
    if predictions.shape[0] == 1:
        raise ValueError(
            "predictions needs to contain more than 1 member to compute spread skill ratio."
        )
    E, T, n_vars, h, w = predictions.shape
    if targets.shape[1] != n_vars:
        raise ValueError("targets and predictions must have same number of variables")
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(n_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]
    rmse_list = []
    spread_list = []
    # Scalar SSR per variable, displayed as a text annotation on each panel.
    mean_ssr_list = spread_skill_ratio(
        predictions, targets, variable_names, pixel_wise=False
    )
    # Per (pixel, timestep) absolute error of the ensemble mean (x-axis) and
    # ensemble standard deviation (y-axis), both flattened to 1D.
    for i in range(n_vars):
        # mae_data = np.mean(np.abs(predictions[:, i] - targets[:, i]), axis=0)
        pred_i = PlotConfig.convert_units(variable_names[i], predictions[:, :, i])
        tgt_i = PlotConfig.convert_units(variable_names[i], targets[:, i])
        mean_pred_i = np.mean(pred_i, axis=0)
        rmse_data_i = np.abs((mean_pred_i - tgt_i)).flatten()
        rmse_list.append(rmse_data_i)
        spread_i = np.std(pred_i, axis=0).flatten()
        spread_list.append(spread_i)
    ncols = n_vars
    nrows = (n_vars + ncols - 1) // ncols  # ceil division; here always 1
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols * figsize_multiplier, figsize_multiplier),
    )
    axes = np.atleast_1d(axes).ravel()
    for ax in axes:
        ax.set_box_aspect(1)
    last_hb = None
    for i, ax in enumerate(axes):
        if i >= n_vars:
            ax.set_visible(False)
            continue
        rmse = rmse_list[i]
        spread = spread_list[i]
        # Hexagonal 2D histogram with log-scaled counts.
        hb = ax.hexbin(
            rmse,
            spread,
            gridsize=100,
            cmap="jet",
            bins="log",
            mincnt=1,
        )
        last_hb = hb
        textstr = f"SSR: {mean_ssr_list[i]:.3f}"
        ax.text(
            0.05,
            0.95,
            textstr,
            transform=ax.transAxes,
            fontsize=10,
            verticalalignment="top",
        )
        # Square limits with a 5% margin so the 1:1 line spans the panel.
        var_min = min(rmse.min(), spread.min())
        var_max = max(rmse.max(), spread.max())
        margin = 0.05 * (var_max - var_min)
        plot_min = var_min - margin
        plot_max = var_max + margin
        # Identity line
        ax.plot([plot_min, plot_max], [plot_min, plot_max], "r--", alpha=0.7)
        ax.set_xlim(plot_min, plot_max)
        ax.set_ylim(plot_min, plot_max)
        ax.set_title(plot_variable_names[i])
        ax.xaxis.set_major_locator(ticker.MaxNLocator(5))
        ax.yaxis.set_major_locator(ticker.MaxNLocator(5))
        # Axis labels only on the outer edge of the grid.
        if i % ncols == 0:
            ax.set_ylabel("Ensemble std")
        else:
            ax.set_ylabel("")
        if i >= (nrows - 1) * ncols:
            ax.set_xlabel("RMSE")
        else:
            ax.set_xlabel("")
    # One shared colorbar, attached to the last populated panel.
    ax_last = axes[min(n_vars - 1, len(axes) - 1)]
    cax = ax_last.inset_axes([1.05, 0.0, 0.04, 1.0])
    cbar = fig.colorbar(last_hb, cax=cax)
    cbar.set_label(r"$\log_{10}[\mathrm{Count}]$")
    plt.subplots_adjust(
        hspace=0.1, wspace=0.3, left=0.1, right=0.9, top=0.9, bottom=0.1
    )
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def plot_validation_pdfs(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    coarse_inputs=None,  # Coarse inputs for comparison (optional)
    variable_names=None,  # List of variable names
    filename="validation_pdfs.png",
    save_dir="./results",
    figsize_multiplier=4,  # Base size per subplot
    save_npz=False,
):
    """
    Create PDF (Probability Density Function) plots comparing distributions
    of model predictions vs ground truth for all variables.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    coarse_inputs : torch.Tensor or np.array, optional
        Coarse inputs of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Names of the variables for subplot titles
    filename : str, optional
        Output filename
    save_dir : str, optional
        Directory to save the plot
    figsize_multiplier : int, optional
        Base size multiplier for subplots
    save_npz : bool, optional
        If True, saves the PDF diagnostics to a compressed .npz file.

    Returns
    -------
    str
        Path of the saved figure.

    Notes
    -----
    - Creates horizontal subplots (one per variable) showing PDFs
    - Each subplot shows up to 3 lines: Predictions, Ground Truth, and Coarse Inputs
    - Uses automatic color and linestyle cycling based on global matplotlib settings
    - Calculates and displays key statistics for each distribution
    - Handles both PyTorch tensors and numpy arrays

    Examples
    --------
    >>> predictions = np.random.randn(10, 3, 64, 64)  # 10 samples, 3 variables
    >>> targets = np.random.randn(10, 3, 64, 64)
    >>> plot_validation_pdfs(predictions, targets, variable_names=['Temp', 'Pres', 'Humid'])
    """
    # Convert to numpy if they're tensors
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if coarse_inputs is not None and hasattr(coarse_inputs, "detach"):
        coarse_inputs = coarse_inputs.detach().cpu().numpy()
    batch_size, num_vars, h, w = predictions.shape
    # Default variable names if not provided
    if variable_names is None:
        variable_names = [f"Variable {i + 1}" for i in range(num_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]
    # Calculate grid dimensions for horizontal layout
    ncols = num_vars
    nrows = 1  # Single row for horizontal layout
    # Create figure with horizontal subplots
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(ncols * figsize_multiplier, figsize_multiplier)
    )
    # Handle single subplot case
    if num_vars == 1:
        axes = np.array([axes])
    if axes.ndim == 0:
        axes = np.array([axes])
    axes = axes.flatten()
    for ax in axes:
        ax.set_box_aspect(1)
    plt.subplots_adjust(
        hspace=0.1, wspace=0.3, left=0.1, right=0.9, top=0.9, bottom=0.1
    )
    if save_npz:
        pdf_npz_data = {}
    # Plot PDF for each variable
    for i, (var_name, ax) in enumerate(zip(variable_names, axes)):
        if i >= num_vars:
            ax.set_visible(False)
            continue
        linestyles = mpltex.linestyle_generator(markers=[])
        # Flatten the spatial dimensions
        pred_i = PlotConfig.convert_units(var_name, predictions[:, i])
        tgt_i = PlotConfig.convert_units(var_name, targets[:, i])
        plot_name = plot_variable_names[i]
        pred_flat = pred_i.reshape(-1)
        target_flat = tgt_i.reshape(-1)
        # Collect all data for combined range
        all_data = [pred_flat, target_flat]
        if coarse_inputs is not None:
            # coarse_flat = coarse_inputs[:, i, :, :].flatten() #.mean(axis=0).reshape(-1)
            coarse_i = PlotConfig.convert_units(var_name, coarse_inputs[:, i])
            coarse_flat = coarse_i.reshape(-1)
            all_data.append(coarse_flat)
        # Calculate global range for consistent x-axis,
        # clipping outliers outside the 0.25th / 99.5th percentiles.
        all_values = np.concatenate(all_data)
        data_min = np.percentile(all_values, 0.25)  # 0.25th percentile
        data_max = np.percentile(all_values, 99.5)  # 99.5th percentile
        data_range = data_max - data_min
        # Extend range slightly for better visualization
        x_min = data_min - 0.05 * data_range
        x_max = data_max + 0.05 * data_range
        # Create bins for PDF calculation
        n_bins = 100
        bins = np.linspace(x_min, x_max, n_bins + 1)
        # Small epsilon to avoid log(0)
        epsilon = 1e-12
        # Plot log PDFs
        # Plot predictions
        hist_pred, bin_edges = np.histogram(pred_flat, bins=bins, density=True)
        bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
        log_hist_pred = np.log10(hist_pred + epsilon)
        ax.plot(bin_centers, log_hist_pred, label="Pred", **next(linestyles))
        # Plot ground truth
        hist_target, _ = np.histogram(target_flat, bins=bins, density=True)
        log_hist_target = np.log10(hist_target + epsilon)
        ax.plot(bin_centers, log_hist_target, label="Truth", **next(linestyles))
        # Plot coarse inputs if available
        if coarse_inputs is not None:
            hist_coarse, _ = np.histogram(coarse_flat, bins=bins, density=True)
            log_hist_coarse = np.log10(hist_coarse + epsilon)
            ax.plot(bin_centers, log_hist_coarse, label="Coarse", **next(linestyles))
        # Calculate and display statistics
        stats_text = []
        # Predictions statistics
        pred_mean = np.mean(pred_flat)
        pred_std = np.std(pred_flat)
        stats_text.append(f"Predictions: μ={pred_mean:.3f}, σ={pred_std:.3f}")
        # Ground truth statistics
        target_mean = np.mean(target_flat)
        target_std = np.std(target_flat)
        stats_text.append(f"Ground Truth: μ={target_mean:.3f}, σ={target_std:.3f}")
        # Coarse statistics if available
        if coarse_inputs is not None:
            coarse_mean = np.mean(coarse_flat)
            coarse_std = np.std(coarse_flat)
            stats_text.append(f"Coarse: μ={coarse_mean:.3f}, σ={coarse_std:.3f}")
        # Calculate KL divergence between predictions and ground truth
        hist_pred_safe = hist_pred + epsilon
        hist_target_safe = hist_target + epsilon
        # Normalize to probability distributions
        hist_pred_safe = hist_pred_safe / np.sum(hist_pred_safe)
        hist_target_safe = hist_target_safe / np.sum(hist_target_safe)
        kl_divergence = np.sum(
            hist_target_safe * np.log(hist_target_safe / hist_pred_safe)
        )
        # Add KL divergence to statistics
        stats_text.append(f"KL Divergence: {kl_divergence:.4f}")
        # Calculate correlation coefficient
        correlation = np.corrcoef(pred_flat, target_flat)[0, 1]
        stats_text.append(f"Correlation: {correlation:.4f}")
        # Add statistics as text box
        # stats_str = '\n'.join(stats_text)
        # ax.text(0.5, 1.02, stats_str, transform=ax.transAxes,
        #         verticalalignment='bottom', horizontalalignment='center')
        # Log statistics instead of plotting them
        print(f"[PDF stats] {plot_name}")
        print(f"  Predictions: μ={pred_mean:.3f}, σ={pred_std:.3f}")
        print(f"  Ground Truth: μ={target_mean:.3f}, σ={target_std:.3f}")
        if coarse_inputs is not None:
            print(f"  Coarse: μ={coarse_mean:.3f}, σ={coarse_std:.3f}")
        print(f"  KL Divergence: {kl_divergence:.4f}")
        print(f"  Correlation: {correlation:.4f}")
        # ax.set_xlabel(f'{var_name}')
        ax.set_xlabel(plot_name)
        # Only show y-label for leftmost subplot
        if i == 0:
            # ax.set_ylabel('log₁₀(PDF)')
            ax.set_ylabel(r"$\log_{10}(\mathrm{PDF})$")
        # Add grid
        ax.grid(True, alpha=0.3, linestyle="--")
        # Add legend
        ax.legend()
        # Set consistent x-limits
        # ax.set_xlim(x_min, x_max)
        # Set y-limits for log plot (handle cases where log values might be very negative)
        y_min = min(log_hist_pred.min(), log_hist_target.min())
        if coarse_inputs is not None:
            y_min = min(y_min, log_hist_coarse.min())
        y_max = max(log_hist_pred.max(), log_hist_target.max())
        if coarse_inputs is not None:
            y_max = max(y_max, log_hist_coarse.max())
        # Add small margin to y-limits
        y_margin = 0.1 * (y_max - y_min) if y_max > y_min else 0.1
        ax.set_ylim(y_min - y_margin, y_max + y_margin)
        # Use scientific notation for large ranges
        if data_range > 1000:
            ax.ticklabel_format(style="sci", axis="x", scilimits=(0, 0))
        if save_npz:
            key = f"{var_name}__pdf__"
            pdf_npz_data[key + "bin_centers"] = bin_centers
            pdf_npz_data[key + "log_pred"] = log_hist_pred
            pdf_npz_data[key + "log_truth"] = log_hist_target
            pdf_npz_data[key + "mean_pred"] = pred_mean
            pdf_npz_data[key + "std_pred"] = pred_std
            pdf_npz_data[key + "mean_truth"] = target_mean
            pdf_npz_data[key + "std_truth"] = target_std
            pdf_npz_data[key + "kl"] = kl_divergence
            pdf_npz_data[key + "corr"] = correlation
            if coarse_inputs is not None:
                pdf_npz_data[key + "log_coarse"] = log_hist_coarse
                pdf_npz_data[key + "mean_coarse"] = coarse_mean
                pdf_npz_data[key + "std_coarse"] = coarse_std
    # Ensure save directory exists
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    if save_npz:
        npz_path = os.path.splitext(save_path)[0] + ".npz"
        np.savez_compressed(npz_path, **pdf_npz_data)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def plot_power_spectra(
    predictions,  # Model predictions
    targets,  # Ground truth
    dlat,  # Grid spacing in latitude (degrees)
    dlon,  # Grid spacing in longitude (degrees)
    coarse_inputs=None,  # Coarse inputs for comparison (optional)
    variable_names=None,  # List of variable names
    filename="power_spectra_physical.png",
    save_dir="./results",
    figsize_multiplier=4,
    save_npz=False,
):
    """
    Calculate and plot power spectra with proper physical wavenumbers.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, num_variables, nh, nw]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, nh, nw]
    dlat : float
        Grid spacing in latitude (degrees)
    dlon : float
        Grid spacing in longitude (degrees)
    coarse_inputs : torch.Tensor or np.array, optional
        Coarse inputs of shape [batch_size, num_variables, nh, nw]
    variable_names : list of str, optional
        Names of the variable names for subplot titles
    filename : str, optional
        Output filename
    save_dir : str, optional
        Directory to save the plot
    figsize_multiplier : int, optional
        Base size multiplier for subplots
    save_npz : bool, optional
        If True, saves the PDF diagnostics to a compressed .npz file.

    Returns
    -------
    str
        Path of the saved figure.
    """
    # Convert to numpy if they're tensors
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if coarse_inputs is not None and hasattr(coarse_inputs, "detach"):
        coarse_inputs = coarse_inputs.detach().cpu().numpy()
    batch_size, num_vars, nh, nw = predictions.shape
    # Default variable names if not provided
    if variable_names is None:
        variable_names = [f"Variable {i + 1}" for i in range(num_vars)]
    # plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]
    # Calculate wavenumbers
    # FFT frequencies are in cycles per grid spacing
    fft_freq_lat = np.fft.fftfreq(nh, d=dlat)  # cycles per degree in lat direction
    fft_freq_lon = np.fft.fftfreq(nw, d=dlon)  # cycles per degree in lon direction
    # Shift frequencies so zero is at center
    fft_freq_lat_shifted = np.fft.fftshift(fft_freq_lat)
    fft_freq_lon_shifted = np.fft.fftshift(fft_freq_lon)
    # Create 2D wavenumber grid
    k_lat, k_lon = np.meshgrid(fft_freq_lon_shifted, fft_freq_lat_shifted)
    # Calculate magnitude of wavenumber vector (in cycles/degree)
    k_mag = np.sqrt(k_lat**2 + k_lon**2)
    # Create bins for radial averaging (cap at the smaller Nyquist frequency)
    max_k = np.min([np.max(np.abs(fft_freq_lat)), np.max(np.abs(fft_freq_lon))])
    k_bins = np.linspace(0, max_k, min(nh, nw) // 2)
    k_centers = 0.5 * (k_bins[1:] + k_bins[:-1])
    # Create figure
    ncols = num_vars
    nrows = 1  # 2 Two rows: one for 2D spectrum, one for 1D spectrum
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols * figsize_multiplier, nrows * figsize_multiplier),
        squeeze=False,
    )  # nrows * figsize_multiplier
    plt.subplots_adjust(
        hspace=0.2, wspace=0.3, left=0.1, right=0.9, top=0.9, bottom=0.1
    )
    axes = axes.ravel()
    for ax in axes:
        ax.set_box_aspect(1)
    """
    # Handle single subplot case
    if num_vars == 1:
        axes = np.array([[axes[0]], [axes[1]]])
    elif axes.ndim == 1:
        axes = axes.reshape(nrows, ncols)
    """
    if save_npz:
        spectra_npz_data = {}
        spectra_npz_data["__meta__dlat"] = dlat
        spectra_npz_data["__meta__dlon"] = dlon
        spectra_npz_data["__meta__variables"] = np.array(variable_names)
    # Process each variable
    for i, var_name in enumerate(variable_names):
        if i >= num_vars:
            continue
        linestyles = mpltex.linestyle_generator(markers=[])
        # plot_name = plot_variable_names[i]
        # Initialize arrays for averaged PSDs
        psd2d_pred_sum = np.zeros((nh, nw))
        psd2d_target_sum = np.zeros((nh, nw))
        if coarse_inputs is not None:
            psd2d_coarse_sum = np.zeros((nh, nw))
        # Process each sample in the batch
        for b in range(batch_size):
            # Predictions
            # field_pred = predictions[b, i]
            field_pred = PlotConfig.convert_units(var_name, predictions[b, i])
            psd2d_pred = calculate_psd2d_simple(field_pred)
            psd2d_pred_sum += psd2d_pred
            # Targets
            # field_target = targets[b, i]
            field_target = PlotConfig.convert_units(var_name, targets[b, i])
            psd2d_target = calculate_psd2d_simple(field_target)
            psd2d_target_sum += psd2d_target
            # Coarse inputs
            if coarse_inputs is not None:
                # field_coarse = coarse_inputs[b, i]
                field_coarse = PlotConfig.convert_units(var_name, coarse_inputs[b, i])
                psd2d_coarse = calculate_psd2d_simple(field_coarse)
                psd2d_coarse_sum += psd2d_coarse
        # Average over batch
        psd2d_pred_avg = psd2d_pred_sum / batch_size
        psd2d_target_avg = psd2d_target_sum / batch_size
        if coarse_inputs is not None:
            psd2d_coarse_avg = psd2d_coarse_sum / batch_size
        # Calculate 1D radial spectra
        psd1d_pred = radial_average_psd(psd2d_pred_avg, k_mag, k_bins)
        psd1d_target = radial_average_psd(psd2d_target_avg, k_mag, k_bins)
        if coarse_inputs is not None:
            psd1d_coarse = radial_average_psd(psd2d_coarse_avg, k_mag, k_bins)
        if save_npz:
            key = f"{var_name}__spectra__"
            spectra_npz_data[key + "k"] = k_centers
            spectra_npz_data[key + "psd_pred"] = psd1d_pred
            spectra_npz_data[key + "psd_truth"] = psd1d_target
            if coarse_inputs is not None:
                spectra_npz_data[key + "psd_coarse"] = psd1d_coarse
        """
        # --- Plot 2D PSD (top row) ---
        ax_top = axes[0, i] if num_vars > 1 else axes[0]
        # Use k_lon and k_lat for the axes instead of lat/lon
        k_lon_min, k_lon_max = fft_freq_lon_shifted[0], fft_freq_lon_shifted[-1]
        k_lat_min, k_lat_max = fft_freq_lat_shifted[0], fft_freq_lat_shifted[-1]
        im = ax_top.imshow(np.log10(psd2d_pred_avg + 1e-12),
                           cmap=cmap_white_jet,
                           aspect='auto',
                           origin='lower',
                           extent=[k_lon_min, k_lon_max, k_lat_min, k_lat_max])
        #ax_top.set_title(f'{var_name}')
        ax_top.set_title(plot_name)
        # Only add y-axis label for leftmost column
        if i == 0:
            ax_top.set_ylabel(r'$\mathrm{k_{lat}}$ (cycles/°)')
        else:
            ax_top.set_ylabel('')
            # Remove y-axis tick labels for non-leftmost columns
            ax_top.tick_params(axis='y', labelleft=False)
        # Always show x-axis label
        ax_top.set_xlabel(r'$\mathrm{k_{lon}}$ (cycles/°)')
        # Add grid for better readability
        ax_top.grid(True, alpha=0.3, linestyle='--')
        # Add colorbar for the last column only
        if i == num_vars - 1:
            cax = ax_top.inset_axes([1.05, 0, 0.05, 1])  # [x, y, w, h] relative to axes
            cbar = plt.colorbar(im, cax=cax, orientation='vertical')
            cbar.set_label('log₁₀(PSD)')
        """
        # --- Plot 1D Radial Spectrum (bottom row) ---
        # ax_bottom = axes[1, i] if num_vars > 1 else axes[1]
        ax_bottom = axes[i]
        # Plot all spectra
        ax_bottom.loglog(k_centers, psd1d_pred, label="Pred", **next(linestyles))
        ax_bottom.loglog(k_centers, psd1d_target, label="Truth", **next(linestyles))
        if coarse_inputs is not None:
            ax_bottom.loglog(
                k_centers, psd1d_coarse, label="Coarse", **next(linestyles)
            )
        # Only add y-axis label for leftmost column
        if i == 0:
            ax_bottom.set_ylabel("PSD(k)")
        else:
            ax_bottom.set_ylabel("")
        # Always show x-axis label
        ax_bottom.set_xlabel("Wavenumber k [cycles/°]")
        ax_bottom.legend()
        ax_bottom.grid(True, alpha=0.3, which="both")
        # Set reasonable axis limits
        valid = (k_centers > 0) & (psd1d_target > 0)
        if np.any(valid):
            ax_bottom.set_xlim(k_centers[valid][0] * 0.8, k_centers[valid][-1] * 1.2)
            # Find y-range
            y_min = min(psd1d_pred[valid].min(), psd1d_target[valid].min())
            y_max = max(psd1d_pred[valid].max(), psd1d_target[valid].max())
            if coarse_inputs is not None:
                y_min = min(y_min, psd1d_coarse[valid].min())
                y_max = max(y_max, psd1d_coarse[valid].max())
            ax_bottom.set_ylim(y_min * 0.5, y_max * 2.0)
    # Save figure
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    if save_npz:
        npz_path = os.path.splitext(save_path)[0] + ".npz"
        np.savez_compressed(npz_path, **spectra_npz_data)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def calculate_psd2d_simple(field):
    """
    Simple 2D PSD calculation without preprocessing.

    Parameters
    ----------
    field : np.array
        2D input field.

    Returns
    -------
    np.array
        Squared magnitude of the 2D FFT, with the zero-frequency
        component shifted to the center of the array.
    """
    # |fftshift(F)|^2 == fftshift(|F|)^2: shifting before taking the
    # magnitude yields the same centered power spectrum.
    centered_spectrum = np.fft.fftshift(np.fft.fft2(field))
    return np.abs(centered_spectrum) ** 2
def radial_average_psd(psd2d, k_mag, k_bins):
    """
    Radially average 2D PSD using wavenumber magnitude.

    Parameters
    ----------
    psd2d : np.array
        2D power spectral density.
    k_mag : np.array
        Wavenumber magnitude at each element of ``psd2d`` (same shape).
    k_bins : np.array
        Monotonic bin edges for the radial wavenumber.

    Returns
    -------
    np.array
        Per-bin mean PSD scaled by the annulus area 2*pi*k*dk
        (length ``len(k_bins) - 1``; empty bins yield NaN).
    """
    # Mean PSD within each radial wavenumber bin.
    mean_psd, _, _ = stats.binned_statistic(
        k_mag.ravel(), psd2d.ravel(), statistic="mean", bins=k_bins
    )
    # Weight by the annulus area (2*pi*k*dk) to get a proper spectral density.
    centers = 0.5 * (k_bins[:-1] + k_bins[1:])
    widths = k_bins[1:] - k_bins[:-1]
    annulus_area = 2 * np.pi * centers * widths
    # Leave bins with zero area (e.g. the k = 0 bin edge case) unscaled.
    positive = annulus_area > 0
    mean_psd[positive] = mean_psd[positive] * annulus_area[positive]
    return mean_psd
def plot_qq_quantiles(
    predictions,  # Model predictions
    targets,  # Ground truth
    coarse_inputs,  # Coarse inputs
    variable_names=None,  # List of variable names
    units=None,  # List of units for each variable
    quantiles=[0.90, 0.95, 0.975, 0.99, 0.995],  # NOTE(review): mutable default; not mutated here, but a tuple would be safer
    filename="qq_quantiles.png",
    save_dir="./results",
    figsize_multiplier=4,
    save_npz=False,
):
    """
    Create QQ-plots at different quantiles comparing model predictions
    and coarse inputs against ground truth.

    For each variable, plots quantiles of predictions and coarse inputs
    against quantiles of ground truth with a 1:1 reference line.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    coarse_inputs : torch.Tensor or np.array
        Coarse inputs of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Names of the variables for subplot titles.
        If None, uses ["VAR_0", "VAR_1", ...]
    units : list of str, optional
        Units for each variable for axis labels. If None, uses empty strings.
    quantiles : list of float, optional
        Quantile values to plot (e.g., [0.90, 0.95, 0.975, 0.99, 0.995])
    filename : str, optional
        Output filename
    save_dir : str, optional
        Directory to save the plot
    figsize_multiplier : int, optional
        Base size multiplier for subplots
    save_npz : bool, optional
        If True, saves the quantile diagnostics to a compressed .npz file
        next to the figure.

    Returns
    -------
    save_path : str
        Path to the saved figure
    """
    # Convert tensors → numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(coarse_inputs, "detach"):
        coarse_inputs = coarse_inputs.detach().cpu().numpy()

    batch_size, num_vars, h, w = predictions.shape

    # Default variable names if not provided
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(num_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]

    # Default units if not provided
    if units is None:
        units = [""] * num_vars

    # Figure setup: one square panel per variable, side by side
    fig, axes = plt.subplots(
        1,
        num_vars,
        figsize=(num_vars * figsize_multiplier, figsize_multiplier),
        constrained_layout=True,
    )
    if num_vars > 1:
        axes = axes.ravel()
    # Handle single subplot case
    else:
        axes = np.array([axes])
    for ax in axes:
        ax.set_box_aspect(1)

    if save_npz:
        # Collect all diagnostics in one dict; written with np.savez_compressed below
        qq_npz_data = {}
        qq_npz_data["__meta__variables"] = np.array(variable_names)
        qq_npz_data["__meta__quantiles"] = np.array(quantiles)

    for i, var_name in enumerate(variable_names):
        # Fresh marker/line style cycle per panel so quantile markers repeat
        linestyles = mpltex.linestyle_generator(lines=[])
        ax = axes[i]
        plot_name = plot_variable_names[i]

        # Flatten spatial dims and average over batch
        # target_vals = targets[:, i]
        # pred_vals = predictions[:, i]
        # coarse_vals = coarse_inputs[:, i]
        pred_vals = PlotConfig.convert_units(var_name, predictions[:, i])
        target_vals = PlotConfig.convert_units(var_name, targets[:, i])
        coarse_vals = PlotConfig.convert_units(var_name, coarse_inputs[:, i])

        # Compute quantiles over all batch/spatial samples at once
        qs_target = np.quantile(target_vals, quantiles)
        qs_pred = np.quantile(pred_vals, quantiles)
        qs_coarse = np.quantile(coarse_vals, quantiles)

        if save_npz:
            key = f"{var_name}__qq__"
            qq_npz_data[key + "quantiles"] = np.array(quantiles)
            qq_npz_data[key + "truth"] = qs_target
            qq_npz_data[key + "pred"] = qs_pred
            qq_npz_data[key + "coarse"] = qs_coarse

        print(f"[QQ Quantiles] {plot_name}")
        for q, qt, qp, qc in zip(quantiles, qs_target, qs_pred, qs_coarse):
            print(f" q={q:.3f} | Truth={qt:.4f} | Pred={qp:.4f} | Coarse={qc:.4f} ")

        # ---- Plot predicted quantiles ----
        # One point per quantile, each with its own marker/colour from the cycle
        for q_idx, q in enumerate(quantiles):
            ax.plot(
                qs_target[q_idx],
                qs_pred[q_idx],
                label=f"{q * 100:.1f}%",
                **next(linestyles),
            )
        # ---- Plot coarse quantiles ----
        ax.plot(
            qs_target,
            qs_coarse,
            c="black",
            marker="s",
            label="Coarse",
            linestyle="None",
        )
        # ---- 1:1 reference line ----
        # Calculate appropriate limits for this variable
        min_val = min(qs_target.min(), qs_pred.min(), qs_coarse.min())
        max_val = max(qs_target.max(), qs_pred.max(), qs_coarse.max())
        margin = 0.0
        plot_min = min_val - margin
        plot_max = max_val + margin
        ax.plot(
            [plot_min, plot_max], [plot_min, plot_max], "r--", alpha=0.7, label="1:1"
        )
        # Set axis limits
        # ax.set_xlim(plot_min, plot_max)
        # ax.set_ylim(plot_min, plot_max)
        ax.xaxis.set_major_locator(ticker.MaxNLocator(4))
        ax.yaxis.set_major_locator(ticker.MaxNLocator(4))
        # Labels and formatting
        # ax.set_title(var_name)
        ax.set_title(plot_name)
        # Add unit to labels if provided
        unit_str = f" ({units[i]})" if units[i] else ""
        # Only add y-axis label for leftmost plot
        if i == 0:
            ax.set_ylabel(f"Predicted/Coarse quantiles{unit_str}")
        ax.set_xlabel(f"True quantiles{unit_str}")
        ax.grid(True, linestyle="--", alpha=0.3)
        # Add legend only for first subplot
        if i == 0:
            ax.legend()

    # Save figure (and optionally the raw quantile data alongside it)
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    if save_npz:
        npz_path = os.path.splitext(save_path)[0] + ".npz"
        np.savez_compressed(npz_path, **qq_npz_data)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def dry_frequency_map(array, threshold):
    """
    Compute a spatial map of the proportion of dry time steps.

    Each output pixel is the fraction of the time dimension for which
    that pixel's value is below ``threshold``.

    Parameters
    ----------
    array : torch.Tensor or np.array
        Precipitation field of shape [batch_size, h, w]
    threshold : float
        Precipitation threshold (in mm); values strictly below it count
        as dry.

    Returns
    -------
    np.ndarray(np.float64) of shape [h, w]
    """
    # Accept torch tensors transparently
    if hasattr(array, "detach"):
        array = array.detach().cpu().numpy()
    # 1.0 where dry, 0.0 where wet; averaging over time gives the frequency
    is_dry = (array < threshold).astype(np.float64)
    return is_dry.mean(axis=0)
def plot_dry_frequency_map(
    predictions,  # Model predictions precipitation (fine predicted)
    targets,  # Ground truth precipitation (fine true)
    threshold,  # threshold to define dry and wet (in mm)
    lat_1d,
    lon_1d,
    filename="validation_dry_frequency_map.png",
    save_dir=None,
    figsize_multiplier=None,  # Base size per subplot
):
    """
    Plot spatial dry pixels proportion maps.
    Value of each pixel corresponds to the frequency of dry weather
    for this pixel.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, h, w]
    threshold : float
        threshold for precipitation (expressed in mm): under it,
        pixel is considered dry.
    lat_1d : array-like
        1D array of latitude coordinates with shape [H].
    lon_1d : array-like
        1D array of longitude coordinates with shape [W].
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot.
    figsize_multiplier : int, optional
        Base size multiplier for subplots.

    Returns
    -------
    save_path : str
        Path to the saved figure.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if figsize_multiplier is None:
        figsize_multiplier = PlotConfig.DEFAULT_FIGSIZE_MULTIPLIER
    # Convert tensors to numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if hasattr(lat_1d, "detach"):
        lat_1d = lat_1d.detach().cpu().numpy()
    if hasattr(lon_1d, "detach"):
        lon_1d = lon_1d.detach().cpu().numpy()

    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    _, h, w = targets.shape
    # Rebuild a regular lat/lon grid from the coordinate extremes
    # (latitude descending so row 0 is the northernmost row)
    lat_block = np.linspace(lat_max, lat_min, h)
    lon_block = np.linspace(lon_min, lon_max, w)
    lat, lon = np.meshgrid(lat_block, lon_block, indexing="ij")
    lon_center = float((lon_min + lon_max) / 2)

    cmap = PlotConfig.get_colormap(
        "dry frequency"
    )  # need to define the comap in PlotConfig

    # convert units :
    predictions = PlotConfig.convert_units("precipitation", predictions)
    targets = PlotConfig.convert_units("precipitation", targets)

    dry_freq_pred_map = dry_frequency_map(predictions, threshold)
    # dry_freq_pred = np.mean(dry_freq_pred_map)
    dry_freq_targ_map = dry_frequency_map(targets, threshold)
    # dry_freq_targ = np.mean(dry_freq_targ_map)

    # Frequencies are proportions, so the colour scale is fixed to [0, 1]
    vmin = 0
    vmax = 1
    base_width_per_panel = 4.5
    base_height_per_panel = 3.0
    fig_width = base_width_per_panel
    fig_height = 3 * base_height_per_panel
    # Three vertically stacked panels: Predicted, Target, and their difference
    fig, axes = plt.subplots(
        3,
        figsize=(fig_width, fig_height),
        subplot_kw={
            "projection": ccrs.PlateCarree(central_longitude=lon_center)
        },  # ccrs.Mercator(central_longitude=lon_center)
        gridspec_kw={"wspace": 0.1},
    )
    fig.subplots_adjust(
        top=0.9, bottom=0.1, left=0.1, right=0.9, wspace=0.1, hspace=0.1
    )

    # Panel 0: predicted dry frequency
    im = axes[0].pcolormesh(
        lon,
        lat,
        dry_freq_pred_map,
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
        transform=ccrs.PlateCarree(),
        shading="auto",
    )
    axes[0].set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
    axes[0].coastlines(linewidth=0.6)
    axes[0].add_feature(
        cfeature.BORDERS.with_scale("50m"),
        linewidth=0.6,
        linestyle="--",
        edgecolor="black",
        zorder=11,
    )
    axes[0].add_feature(
        cfeature.LAKES.with_scale("50m"),
        edgecolor="black",
        facecolor="none",
        linewidth=0.6,
        zorder=9,
    )
    # ax.set_aspect("auto")
    axes[0].set_xticks([])
    axes[0].set_yticks([])
    axes[0].set_title("Predicted")

    # Panel 1: target dry frequency (same colour scale as panel 0)
    im = axes[1].pcolormesh(
        lon,
        lat,
        dry_freq_targ_map,
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
        transform=ccrs.PlateCarree(),
        shading="auto",
    )
    axes[1].set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
    axes[1].coastlines(linewidth=0.6)
    axes[1].add_feature(
        cfeature.BORDERS.with_scale("50m"),
        linewidth=0.6,
        linestyle="--",
        edgecolor="black",
        zorder=11,
    )
    axes[1].add_feature(
        cfeature.LAKES.with_scale("50m"),
        edgecolor="black",
        facecolor="none",
        linewidth=0.6,
        zorder=9,
    )
    # ax.set_aspect("auto")
    axes[1].set_xticks([])
    axes[1].set_yticks([])
    axes[1].set_title("Target")

    # Shared colorbar spanning the first two panels (manual axes to the right)
    pos0 = axes[0].get_position()
    pos1 = axes[1].get_position()
    bottom = pos1.y0
    top = pos0.y1
    height = top - bottom
    cax1 = fig.add_axes([0.92, bottom, 0.03, height])
    fig.colorbar(im, cax=cax1, label="frequency")

    # vmax_diff = max(
    #     np.abs(np.max(dry_freq_pred_map - dry_freq_targ_map)),
    #     np.abs(np.min(dry_freq_pred_map - dry_freq_targ_map)),
    # )
    # norm_diff = mcolors.TwoSlopeNorm(vmin=-1, vcenter=0, vmax = 1)
    # Panel 2: difference map, diverging colour scale centred at zero
    im = axes[2].pcolormesh(
        lon,
        lat,
        dry_freq_pred_map - dry_freq_targ_map,
        norm=mcolors.TwoSlopeNorm(vmin=-vmax, vcenter=0, vmax=vmax),
        cmap="seismic",
        transform=ccrs.PlateCarree(),
        shading="auto",
    )
    axes[2].set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
    axes[2].coastlines(linewidth=0.6)
    axes[2].add_feature(
        cfeature.BORDERS.with_scale("50m"),
        linewidth=0.6,
        linestyle="--",
        edgecolor="black",
        zorder=11,
    )
    axes[2].add_feature(
        cfeature.LAKES.with_scale("50m"),
        edgecolor="black",
        facecolor="none",
        linewidth=0.6,
        zorder=9,
    )
    # ax.set_aspect("auto")
    axes[2].set_xticks([])
    axes[2].set_yticks([])
    axes[2].set_title("Predicted frequency - Target frequency")

    # Separate colorbar for the difference panel
    pos2 = axes[2].get_position()
    cax2 = fig.add_axes([0.92, pos2.y0, 0.03, pos2.height])
    fig.colorbar(im, cax=cax2, label="frequency")

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def calculate_pearsoncorr_nparray(arr1, arr2, axis=0):
    """
    Calculate Pearson correlation between 2 N-dimensional numpy arrays.

    Parameters:
    -----------
    arr1 : numpy.ndarray
        First N-dimensional array
    arr2 : numpy.ndarray
        Second N-dimensional array (must have same shape as arr1)
    axis : int or tuple of int, default=0
        Axis or tuple of axes over which to compute correlation

    Returns:
    --------
    numpy.ndarray
        Pearson correlation coefficients. Output has the input shape
        with the specified axis/axes removed. Positions where either
        array has zero variance are set to 0.0 (not inf/nan).
    """
    if arr1.shape != arr2.shape:
        raise ValueError(
            f"Arrays must have the same shape. Got {arr1.shape} and {arr2.shape}"
        )

    # Deviations from the mean along the reduction axis/axes
    dev1 = arr1 - arr1.mean(axis=axis, keepdims=True)
    dev2 = arr2 - arr2.mean(axis=axis, keepdims=True)

    # Covariance-like numerator and product-of-norms denominator
    cov = (dev1 * dev2).sum(axis=axis)
    scale = np.sqrt((dev1**2).sum(axis=axis) * (dev2**2).sum(axis=axis))

    # Zero-variance slices yield 0.0 rather than inf or nan
    return np.divide(cov, scale, out=np.zeros_like(cov), where=scale != 0)
def plot_validation_mvcorr_space(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    coarse_inputs=None,  # Coarse inputs for comparison (optional)
    variable_names=None,  # List of variable names
    filename="validation_mvcorr_space.png",
    save_dir="./results",
    figsize_multiplier=4,  # Base size per subplot
):
    """
    Compute multivariate correlation over the space dimensions and plot as
    time-series, comparing model predictions vs ground truth, for all
    combinations of variables. Uses Pearson's correlation coefficient.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    coarse_inputs : torch.Tensor or np.array, optional
        Coarse inputs of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Names of the variables for subplot titles
    filename : str, optional
        Output filename
    save_dir : str, optional
        Directory to save the plot
    figsize_multiplier : int, optional
        Base size multiplier for subplots

    Returns
    -------
    save_path : str
        Path to the saved figure ("0" if fewer than 2 variables).
    """
    # Convert to numpy if they're tensors
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if coarse_inputs is not None and hasattr(coarse_inputs, "detach"):
        coarse_inputs = coarse_inputs.detach().cpu().numpy()

    batch_size, num_vars, h, w = predictions.shape
    # Cross-variable correlation needs at least two variables
    if num_vars < 2:
        print("ERROR: need at least 2 variables but num_vars < 2")
        return "0"

    # Default variable names if not provided
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(num_vars)]

    # Make list of tuples defining variable combinations
    # (all unordered pairs (i, j) with i < j)
    list_var_combos = []
    for ii in range(num_vars - 1):
        for jj in range(num_vars - 1 - ii):
            list_var_combos.append((ii, ii + jj + 1))

    # Calculate grid dimensions: one row per variable pair
    ncols = 1
    nrows = int(num_vars * (num_vars - 1) / 2)  # no. distinct pairs of input variables
    fwidth = 6  # longitude range
    fheight = nrows * figsize_multiplier

    # Set up figure
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(fwidth, fheight), squeeze=False, sharex=True
    )
    axes = axes.flatten()

    # Consistent line styles: Truth / Prediction / Coarse
    linestyles = mpltex.linestyle_generator(markers=[])
    style_truth = next(linestyles)
    style_pred = next(linestyles)
    style_coarse = next(linestyles) if coarse_inputs is not None else None

    var_name_combo_list = []
    # Plot correlation timeseries for each combination of variables
    # max_count = 0
    for i, varComb in enumerate(list_var_combos):
        var_name_combo = variable_names[varComb[0]] + " & " + variable_names[varComb[1]]
        var_name_combo_list.append(var_name_combo)

        # Compute Correlation over the spatial axes -> one value per time step
        pred_corr = calculate_pearsoncorr_nparray(
            predictions[:, varComb[0], :, :],
            predictions[:, varComb[1], :, :],
            axis=(1, 2),
        )
        target_corr = calculate_pearsoncorr_nparray(
            targets[:, varComb[0], :, :], targets[:, varComb[1], :, :], axis=(1, 2)
        )
        if coarse_inputs is not None:
            coarse_corr = calculate_pearsoncorr_nparray(
                coarse_inputs[:, varComb[0], :, :],
                coarse_inputs[:, varComb[1], :, :],
                axis=(1, 2),
            )

        ax = axes[i]
        time_index = range(batch_size)
        ax.plot(time_index, target_corr, label="Truth", linewidth=1.0, **style_truth)
        ax.plot(time_index, pred_corr, label="Prediction", linewidth=1.0, **style_pred)
        if coarse_inputs is not None:
            ax.plot(
                time_index, coarse_corr, label="Coarse", linewidth=1.0, **style_coarse
            )
        ax.grid(True, alpha=0.3)
        ax.set_ylabel(var_name_combo)
        # ax.set_ylim(-1, 1)
        ax.set_xlim(0, batch_size - 1)
        if i == 0:
            ax.legend()

    axes[0].set_title("Spatial Pearson Correlation Over Time")
    axes[-1].set_xlabel("Time Step")

    # Ensure save directory exists
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def plot_validation_mvcorr(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    lat,
    lon,
    coarse_inputs=None,  # Coarse inputs for comparison (optional)
    variable_names=None,  # List of variable names
    filename="validation_mvcorr_time.png",
    save_dir="./results",
    figsize_multiplier=4,  # Base size per subplot
):
    """
    Compute multivariate correlation over the time dimension and plot as
    maps, comparing model predictions vs ground truth, for all combinations
    of variables. Uses Pearson's correlation coefficient.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    lat : array-like
        2D array of latitude coordinates with shape [h, w].
    lon : array-like
        2D array of longitude coordinates with shape [h, w].
    coarse_inputs : torch.Tensor or np.array, optional
        Coarse inputs of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Names of the variables for subplot titles
    filename : str, optional
        Output filename
    save_dir : str, optional
        Directory to save the plot
    figsize_multiplier : int, optional
        Base size multiplier for subplots

    Returns
    -------
    save_path : str
        Path to the saved figure ("0" if fewer than 2 variables).
    """
    # NOTE(review): save_dir defaults to "./results", so these None guards
    # only fire when a caller passes save_dir=None explicitly.
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if figsize_multiplier is None:
        figsize_multiplier = PlotConfig.DEFAULT_FIGSIZE_MULTIPLIER

    # Convert to numpy if they're tensors
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if coarse_inputs is not None and hasattr(coarse_inputs, "detach"):
        coarse_inputs = coarse_inputs.detach().cpu().numpy()
    if hasattr(lat, "detach"):
        lat = lat.detach().cpu().numpy()
    if hasattr(lon, "detach"):
        lon = lon.detach().cpu().numpy()

    lat_min, lat_max = lat.min(), lat.max()
    lon_min, lon_max = lon.min(), lon.max()
    T, n_vars, h, w = predictions.shape
    # Rebuild a regular grid from the coordinate extremes; this overwrites
    # the lat/lon arguments with meshgrid arrays of shape [h, w]
    lat_block = np.linspace(lat_max, lat_min, h)
    lon_block = np.linspace(lon_min, lon_max, w)
    lat, lon = np.meshgrid(lat_block, lon_block, indexing="ij")
    lon_center = float((lon_min + lon_max) / 2)

    batch_size, num_vars, h, w = predictions.shape
    # Cross-variable correlation needs at least two variables
    if num_vars < 2:
        print("ERROR: need at least 2 variables but num_vars < 2")
        return "0"

    # Default variable names if not provided
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(num_vars)]

    # Make list of tuples defining variable combinations
    # (all unordered pairs (i, j) with i < j)
    list_var_combos = []
    for ii in range(num_vars - 1):
        for jj in range(num_vars - 1 - ii):
            list_var_combos.append((ii, ii + jj + 1))

    # Calculate grid dimensions: one row per pair; columns are
    # Truth / Prediction (+ Coarse if provided)
    ncols = 2
    if coarse_inputs is not None:
        ncols = 3
    nrows = int(num_vars * (num_vars - 1) / 2)  # no. distinct pairs of input variables
    base_width_per_panel = 4.5
    base_height_per_panel = 3.0
    fig_width = base_width_per_panel * ncols
    fig_height = base_height_per_panel * nrows

    # Summary statistics vs target, per pair: pattern correlation and RMSE
    spa_cor_out = np.zeros([nrows, ncols - 1])
    spa_rmse_out = np.zeros([nrows, ncols - 1])

    # Set up figure
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(fig_width, fig_height),
        subplot_kw={"projection": ccrs.PlateCarree(central_longitude=lon_center)},
        squeeze=False,
        gridspec_kw={"wspace": 0.1},
    )

    # Define geographic features
    coastline = cfeature.COASTLINE.with_scale("50m")
    borders = cfeature.BORDERS.with_scale("50m")
    # lakes = cfeature.LAKES.with_scale('50m')

    var_name_combo_list = []
    # Plot each combination of variables
    # max_count = 0
    for i, varComb in enumerate(list_var_combos):
        var_name_combo = variable_names[varComb[0]] + " & " + variable_names[varComb[1]]
        var_name_combo_list.append(var_name_combo)

        # Compute Correlation over time -> one map per dataset
        pred_corr = calculate_pearsoncorr_nparray(
            predictions[:, varComb[0], :, :], predictions[:, varComb[1], :, :], axis=0
        )
        target_corr = calculate_pearsoncorr_nparray(
            targets[:, varComb[0], :, :], targets[:, varComb[1], :, :], axis=0
        )
        if coarse_inputs is not None:
            coarse_corr = calculate_pearsoncorr_nparray(
                coarse_inputs[:, varComb[0], :, :],
                coarse_inputs[:, varComb[1], :, :],
                axis=0,
            )

        # Pattern correlation of each correlation map against the target map
        spa_cor_out[i, 0] = np.corrcoef(
            pred_corr.reshape(pred_corr.size), target_corr.reshape(target_corr.size)
        )[0, 1]
        if coarse_inputs is not None:
            spa_cor_out[i, 1] = np.corrcoef(
                coarse_corr.reshape(coarse_corr.size),
                target_corr.reshape(target_corr.size),
            )[0, 1]
        # RMSE of each correlation map against the target map
        spa_rmse_out[i, 0] = np.sqrt((np.square(pred_corr - target_corr)).mean())
        if coarse_inputs is not None:
            spa_rmse_out[i, 1] = np.sqrt((np.square(coarse_corr - target_corr)).mean())

        # Col 0: Truth
        ax_target = axes[i, 0]
        ax_target.pcolormesh(
            lon,
            lat,
            target_corr,
            vmin=-1.0,
            vmax=1.0,
            cmap="RdBu",
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        ax_target.add_feature(coastline, linewidth=PlotConfig.COASTLINE_w)
        ax_target.add_feature(
            borders,
            linewidth=PlotConfig.BORDER_w,
            edgecolor="black",
            linestyle=PlotConfig.BORDER_STYLE,
        )
        # ax_target.set_aspect("auto")
        ax_target.set_extent(
            [lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree()
        )

        # Col 1: Prediction (im_pred kept for the shared colorbar below)
        ax_pred = axes[i, 1]
        im_pred = ax_pred.pcolormesh(
            lon,
            lat,
            pred_corr,
            vmin=-1.0,
            vmax=1.0,
            cmap="RdBu",
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        ax_pred.add_feature(coastline, linewidth=PlotConfig.COASTLINE_w)
        ax_pred.add_feature(
            borders,
            linewidth=PlotConfig.BORDER_w,
            edgecolor="black",
            linestyle=PlotConfig.BORDER_STYLE,
        )
        # ax_pred.set_aspect("auto")
        ax_pred.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())

        if coarse_inputs is not None:
            # Col 2: Coarse input
            ax_coar = axes[i, 2]
            ax_coar.pcolormesh(
                lon,
                lat,
                coarse_corr,
                vmin=-1.0,
                vmax=1.0,
                cmap="RdBu",
                transform=ccrs.PlateCarree(),
                shading="auto",
            )
            ax_coar.add_feature(coastline, linewidth=PlotConfig.COASTLINE_w)
            ax_coar.add_feature(
                borders,
                linewidth=PlotConfig.BORDER_w,
                edgecolor="black",
                linestyle=PlotConfig.BORDER_STYLE,
            )
            # ax_coar.set_aspect("auto")
            ax_coar.set_extent(
                [lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree()
            )

    # Add col labels
    col_labels = ["Truth", "Prediction"]
    if coarse_inputs is not None:
        col_labels = ["Truth", "Prediction", "Coarse"]
    for col_idx, label in enumerate(col_labels):
        axes[0, col_idx].set_title(label)

    # Add row labels (variable pair names, rotated along the left edge)
    for row_idx, label in enumerate(var_name_combo_list):
        axes[row_idx, 0].text(
            -0.1,
            0.5,
            label,
            transform=axes[row_idx, 0].transAxes,
            va="center",
            ha="right",
            rotation="vertical",
            fontsize=12,
        )

    # Add colorbar spanning all rows (manual axes to the right)
    fig.subplots_adjust(top=0.9, bottom=0.1, left=0.1, right=0.9, wspace=0.1)
    pos_top = axes[0, 0].get_position()
    pos_bottom = axes[-1, 0].get_position()
    bottom = pos_bottom.y0
    top = pos_top.y1
    height = top - bottom
    cbar_ax = fig.add_axes([0.92, bottom, 0.015, height])
    fig.colorbar(im_pred, cax=cbar_ax, label=r"Temporal Pearson Correlation")

    # Ensure save directory exists
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)

    # NOTE(review): the block below is dead code kept as a string literal
    # (summary heatmaps of spa_cor_out / spa_rmse_out); it is never executed.
    """
    # _________________________________________
    # Output summary map statistics as heatmaps
    # Spatial Correlation and Spatial RMSE wrt target
    # Setupt axis labels
    xLabels=['Prediction']
    if coarse_inputs is not None:
        xLabels=['Prediction','Coarse']
    yLabels=var_name_combo_list
    fig, (ax1,ax2) = plt.subplots(ncols=2, figsize=((ncols+2)*2,4))#, layout='constrained')
    sns.heatmap(spa_cor_out, ax=ax1, cbar=False, linewidth=0.5,
                annot=True, fmt='.3f', xticklabels=xLabels, yticklabels=yLabels,
                vmin=0.0, vmax=1.0, cmap=plt.get_cmap('Reds'))
    fig.colorbar(ax1.collections[0], ax=ax1, location="left", use_gridspec=False, pad=0.1, label="correlation")
    ax1.tick_params(axis='y', pad=90, length=0)
    ax1.tick_params(axis='x', length=0)
    ax1.yaxis.set_label_position("left")
    sns.heatmap(spa_rmse_out, ax=ax2, cbar=False, linewidth=0.5,
                annot=True, fmt='.3f', xticklabels=xLabels, yticklabels=[""]*ncols,
                vmin=0.0, vmax=0.3, cmap=plt.get_cmap('Reds_r'))
    fig.colorbar(ax2.collections[0], ax=ax2, location="right", use_gridspec=False, pad=0.1, label="RMSE")
    ax2.tick_params(rotation=0, length=0)
    ax2.yaxis.set_label_position("right")
    # Ensure save directory exists
    os.makedirs(save_dir, exist_ok=True)
    filenameCR='SpCorrRmse_'+filename
    save_path = os.path.join(save_dir, filenameCR)
    plt.savefig(save_path, bbox_inches='tight')
    plt.close()
    """
    return save_path
def plot_temporal_series_comparison(
    predictions,  # Model predictions (fine predicted)
    targets,  # Ground truth (fine true)
    coarse_inputs=None,  # Coarse inputs for comparison (optional)
    variable_names=None,  # List of variable names
    filename="validation_temp_series.png",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Plot spatially averaged temporal series for each variable.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    coarse_inputs : torch.Tensor or np.array, optional
        Coarse inputs of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Names of the variables for subplot titles
    filename : str, optional
        Output filename
    save_dir : str, optional
        Directory to save the plot
    figsize_multiplier : int, optional
        Base size multiplier for subplots

    Raises
    ------
    ValueError
        If predictions/coarse_inputs shapes do not match targets, or the
        number of variable names does not match the data.

    Returns
    -------
    save_path : str
        Path to the saved figure
    """
    # Convert tensors → numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()
    if coarse_inputs is not None and hasattr(coarse_inputs, "detach"):
        coarse_inputs = coarse_inputs.detach().cpu().numpy()

    # Validate shapes before any plotting work
    if predictions.shape != targets.shape:
        raise ValueError(f"Shape mismatch: {predictions.shape} vs {targets.shape}")
    if coarse_inputs is not None and coarse_inputs.shape != targets.shape:
        raise ValueError(
            f"Coarse shape mismatch: {coarse_inputs.shape} vs {targets.shape}"
        )

    batch_size, num_vars, h, w = predictions.shape

    # Default variable names if not provided
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(num_vars)]
    if len(variable_names) != num_vars:
        raise ValueError(
            f"{len(variable_names)} variable names but num_vars={num_vars}"
        )

    # One stacked subplot per variable
    fig = plt.figure(figsize=(6, figsize_multiplier * num_vars))

    # Consistent line styles: Truth / Prediction / Coarse
    linestyles = mpltex.linestyle_generator(markers=[])
    style_truth = next(linestyles)
    style_pred = next(linestyles)
    style_coarse = next(linestyles) if coarse_inputs is not None else None

    # Loop over variables
    for i, var in enumerate(variable_names):
        ax = fig.add_subplot(num_vars, 1, i + 1)
        pred_vals = PlotConfig.convert_units(var, predictions[:, i])
        true_vals = PlotConfig.convert_units(var, targets[:, i])

        # Spatial mean over H and W dimensions
        s_pred = pred_vals.mean(axis=(1, 2))
        s_true = true_vals.mean(axis=(1, 2))

        # Temporal axis
        time_index = range(batch_size)
        ax.plot(time_index, s_true, label="Truth", linewidth=1.0, **style_truth)
        ax.plot(time_index, s_pred, label="Prediction", linewidth=1.0, **style_pred)
        if coarse_inputs is not None:
            coarse_vals = PlotConfig.convert_units(var, coarse_inputs[:, i])
            s_coarse = coarse_vals.mean(axis=(1, 2))
            ax.plot(time_index, s_coarse, label="Coarse", linewidth=1.0, **style_coarse)

        ax.set_title(var)
        ax.grid(True, alpha=0.3)
        ax.set_ylabel("Spatial mean")
        # Only the bottom panel keeps its x tick labels / axis label
        if i == num_vars - 1:
            ax.set_xlabel("Time index")
        else:
            ax.tick_params(labelbottom=False)
        ax.legend()

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def ranks(
    predictions,  # Model predictions precipitation (fine predicted)
    targets,  # Ground truth precipitation (fine true)
):
    """
    Compute ranks of predictions compared to targets.

    For each target value, the rank counts the ensemble members strictly
    below it, with ties counted as one half (mid-rank convention).

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape [ensemble_size, batch_size, h, w]
    targets : torch.Tensor or np.array
        Targets of shape [batch_size, h, w]

    Returns
    -------
    np.ndarray(np.float64) of shape [batch_size*h*w,]
    """
    # Accept torch tensors transparently
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()

    n_members, t, h, w = predictions.shape
    # Flatten time and space into one sample axis
    flat_ens = predictions.reshape(n_members, t * h * w)
    flat_obs = targets.reshape(1, t * h * w)

    # Members strictly below count 1, ties count 1/2 (average of <= and <)
    below_or_tied = (flat_ens - flat_obs <= 0).astype(np.float32)
    strictly_below = (flat_ens - flat_obs < 0).astype(np.float32)
    half_counts = (below_or_tied + strictly_below) / 2
    return np.sum(half_counts, axis=0)
def plot_ranks(
    predictions,  # model predictions
    targets,  # ground truth
    variable_names=None,  # list of variable names
    filename="ranks.png",
    save_dir="./results",
    figsize_multiplier=4,
):
    """
    Create rank histograms of predictions compared to targets for each
    variable.

    A flat histogram (matching the dashed uniform reference line at
    1/(ensemble_size+1)) indicates a well-calibrated ensemble.

    Parameters
    ----------
    predictions : torch.Tensor or np.array
        Model predictions of shape
        [ensemble_size, batch_size, num_variables, h, w]
    targets : torch.Tensor or np.array
        Ground truth of shape [batch_size, num_variables, h, w]
    variable_names : list of str, optional
        Names of the variables for subplot titles.
        If None, uses ["VAR_0", "VAR_1", ...]
    filename : str, optional
        Output filename
    save_dir : str, optional
        Directory to save the plot
    figsize_multiplier : int, optional
        Base size multiplier for subplots

    Returns
    -------
    save_path : str
        Path to the saved figure
    """
    # Convert tensors → numpy
    if hasattr(predictions, "detach"):
        predictions = predictions.detach().cpu().numpy()
    if hasattr(targets, "detach"):
        targets = targets.detach().cpu().numpy()

    ensemble_size, batch_size, num_vars, h, w = predictions.shape

    # Default variable names if not provided
    if variable_names is None:
        variable_names = [f"VAR_{i}" for i in range(num_vars)]
    plot_variable_names = [PlotConfig.get_plot_name(var) for var in variable_names]

    # Figure setup: one square panel per variable
    fig, axes = plt.subplots(
        1,
        num_vars,
        figsize=(num_vars * figsize_multiplier, figsize_multiplier),
        constrained_layout=True,
    )
    if num_vars > 1:
        axes = axes.ravel()
    # Handle single subplot case
    else:
        axes = np.array([axes])
    for ax in axes:
        ax.set_box_aspect(1)

    for i, var_name in enumerate(variable_names):
        ax = axes[i]
        plot_name = plot_variable_names[i]
        # Ranks of each target value within the ensemble for this variable
        ranks_predicted = ranks(
            predictions=predictions[:, :, i, :, :],
            targets=targets[:, i, :, :],
        )
        # One histogram bin per possible rank (0 .. ensemble_size)
        ax.hist(ranks_predicted, bins=np.arange(ensemble_size + 2), density=True)
        # Uniform reference: ideal calibrated ensemble
        ax.plot(
            [0, ensemble_size + 1],
            [1 / (ensemble_size + 1), 1 / (ensemble_size + 1)],
            linestyle="--",
            color="red",
        )
        ax.set_title(plot_name)
        ax.set_xlabel("ranks")
        ax.set_ylabel("frequency")

    # Save figure
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def get_divergence(u_tensor, v_tensor, spacing):
    """
    Compute the horizontal divergence of a windfield.

    Parameters
    ----------
    u_tensor : torch.Tensor or np.array, shape [...,h,w]
        Zonal component of the windfield; may have any number of leading
        dimensions, but the last two must be the spatial dimensions.
        Must have the same shape as v_tensor.
    v_tensor : torch.Tensor or np.array
        Meridional component of the windfield, same layout and shape
        as u_tensor.
    spacing : float
        Grid spacing used for the finite-difference gradients.

    Returns
    -------
    np.ndarray(np.float64) of same shape as u_tensor and v_tensor
    """
    # Work in torch regardless of the input container
    if isinstance(u_tensor, np.ndarray):
        u_tensor = torch.from_numpy(u_tensor)
    if isinstance(v_tensor, np.ndarray):
        v_tensor = torch.from_numpy(v_tensor)
    # NOTE(review): u is differentiated along dim=-2 and v along dim=-1,
    # i.e. the second-to-last axis is treated as the zonal (x) direction —
    # confirm this matches the data layout (get_curl uses the same convention).
    du_dx = torch.gradient(u_tensor, spacing=spacing, dim=-2)[0]
    dv_dy = torch.gradient(v_tensor, spacing=spacing, dim=-1)[0]
    divergence = du_dx + dv_dy
    return divergence.detach().cpu().numpy()
def get_curl(u_tensor, v_tensor, spacing):
    """
    Compute the curl (vertical vorticity) of a windfield.

    Parameters
    ----------
    u_tensor : torch.Tensor or np.array, shape [...,h,w]
        Zonal component of the windfield; may have any number of leading
        dimensions, but the last two must be the spatial dimensions.
        Must have the same shape as v_tensor.
    v_tensor : torch.Tensor or np.array
        Meridional component of the windfield, same layout and shape
        as u_tensor.
    spacing : float
        Grid spacing used for the finite-difference gradients.

    Returns
    -------
    np.ndarray(np.float64) of same shape as u_tensor and v_tensor
    """
    # Work in torch regardless of the input container
    if isinstance(u_tensor, np.ndarray):
        u_tensor = torch.from_numpy(u_tensor)
    if isinstance(v_tensor, np.ndarray):
        v_tensor = torch.from_numpy(v_tensor)
    # NOTE(review): dim=-2 is treated as the x axis and dim=-1 as the y
    # axis (same convention as get_divergence) — confirm against data layout.
    dv_dx = torch.gradient(v_tensor, spacing=spacing, dim=-2)[0]
    du_dy = torch.gradient(u_tensor, spacing=spacing, dim=-1)[0]
    curl = dv_dx - du_dy
    return curl.detach().cpu().numpy()
def plot_mean_divergence_map(
    u_prediction,  # Model predictions for zonal wind (fine predicted)
    v_prediction,  # Model predictions for meridional wind (fine predicted)
    u_target,  # Ground truth zonal wind (fine true)
    v_target,  # Ground truth meridional wind (fine true)
    spacing,
    lat_1d,
    lon_1d,
    filename="mean_divergence.png",
    save_dir=None,
    figsize_multiplier=None,  # Base size per subplot
):
    """Plot maps of the time-mean wind divergence.

    Three vertically stacked map panels are drawn: the mean divergence of
    the target winds, the mean divergence of the predicted winds (both
    sharing one colorbar), and the prediction minus target difference
    (with its own colorbar, same symmetric scale).

    Parameters
    ----------
    u_prediction : torch.Tensor or np.array
        Model predictions of shape [batch_size, h, w] for the zonal wind
        component. The last two dims correspond to longitude and latitude;
        must have the same shape as v_prediction.
    v_prediction : torch.Tensor or np.array
        Model predictions of shape [batch_size, h, w] for the meridional
        wind component. Same shape/layout as u_prediction.
    u_target : torch.Tensor or np.array
        Ground truth zonal wind of shape [batch_size, h, w]; same
        shape/layout as v_target.
    v_target : torch.Tensor or np.array
        Ground truth meridional wind of shape [batch_size, h, w].
    spacing : float
        Spatial resolution of the windfield. Used to compute the gradients.
    lat_1d : array-like
        1D array of latitude coordinates with shape [H].
    lon_1d : array-like
        1D array of longitude coordinates with shape [W].
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot.
    figsize_multiplier : int, optional
        Kept for interface compatibility with the other map plotters; the
        figure size of this plot is currently fixed and does not use it.

    Returns
    -------
    str
        Path of the saved figure.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if figsize_multiplier is None:
        figsize_multiplier = PlotConfig.DEFAULT_FIGSIZE_MULTIPLIER

    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    _, h, w = u_target.shape
    # Latitude runs max -> min so that row 0 is the northernmost row.
    lat_block = np.linspace(lat_max, lat_min, h)
    lon_block = np.linspace(lon_min, lon_max, w)
    lat, lon = np.meshgrid(lat_block, lon_block, indexing="ij")
    lon_center = float((lon_min + lon_max) / 2)

    cmap = PlotConfig.get_colormap(
        "divergence"
    )  # need to define the colormap in PlotConfig

    # Convert wind to plotting units before differentiating.
    u_prediction = PlotConfig.convert_units("wind", u_prediction)
    v_prediction = PlotConfig.convert_units("wind", v_prediction)
    u_target = PlotConfig.convert_units("wind", u_target)
    v_target = PlotConfig.convert_units("wind", v_target)

    div_prediction = get_divergence(u_prediction, v_prediction, spacing)
    div_target = get_divergence(u_target, v_target, spacing)
    mean_div_prediction = np.mean(div_prediction, axis=0)
    mean_div_target = np.mean(div_target, axis=0)

    # Symmetric colour scale centred on zero, covering both fields.
    vmin = min(np.min(mean_div_prediction), np.min(mean_div_target))
    vmax = max(np.max(mean_div_prediction), np.max(mean_div_target))
    vmax = max(np.abs(vmax), np.abs(vmin))
    norm = mcolors.TwoSlopeNorm(vmin=-vmax, vcenter=0, vmax=vmax)

    base_width_per_panel = 4.5
    base_height_per_panel = 3.0
    fig, axes = plt.subplots(
        3,
        1,
        figsize=(base_width_per_panel, 3 * base_height_per_panel),
        subplot_kw={
            "projection": ccrs.PlateCarree(central_longitude=lon_center)
        },  # ccrs.Mercator(central_longitude=lon_center)
        gridspec_kw={"wspace": 0.1},
    )
    fig.subplots_adjust(
        top=0.9, bottom=0.1, left=0.1, right=0.9, wspace=0.1, hspace=0.1
    )

    def _draw_panel(ax, field, title):
        # One pcolormesh panel with coastlines/borders/lakes and no ticks.
        mesh = ax.pcolormesh(
            lon,
            lat,
            field,
            norm=norm,
            cmap=cmap,
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
        ax.coastlines(linewidth=0.6)
        ax.add_feature(
            cfeature.BORDERS.with_scale("50m"),
            linewidth=0.6,
            linestyle="--",
            edgecolor="black",
            zorder=11,
        )
        ax.add_feature(
            cfeature.LAKES.with_scale("50m"),
            edgecolor="black",
            facecolor="none",
            linewidth=0.6,
            zorder=9,
        )
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title)
        return mesh

    im = _draw_panel(axes[0], mean_div_target, "Target")
    im = _draw_panel(axes[1], mean_div_prediction, "Prediction")

    # Shared colorbar spanning the top two (vertically stacked) panels.
    pos0 = axes[0].get_position()
    pos1 = axes[1].get_position()
    bottom = pos1.y0
    top = pos0.y1
    cax1 = fig.add_axes([0.92, bottom, 0.02, top - bottom])  # [left, bottom, width, height]
    fig.colorbar(im, cax=cax1, orientation="vertical", label="divergence")

    im = _draw_panel(axes[2], mean_div_prediction - mean_div_target, "Predicted - Target")

    # Separate colorbar for the error panel.
    pos2 = axes[2].get_position()
    cax2 = fig.add_axes([0.92, pos2.y0, 0.02, pos2.height])
    fig.colorbar(im, cax=cax2, orientation="vertical", label="divergence error")

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
def plot_mean_curl_map(
    u_prediction,  # Model predictions for zonal wind (fine predicted)
    v_prediction,  # Model predictions for meridional wind (fine predicted)
    u_target,  # Ground truth zonal wind (fine true)
    v_target,  # Ground truth meridional wind (fine true)
    spacing,
    lat_1d,
    lon_1d,
    filename="mean_curl.png",
    save_dir=None,
    figsize_multiplier=None,  # Base size per subplot
):
    """Plot maps of the time-mean wind curl.

    Three vertically stacked map panels are drawn: the mean curl of the
    target winds, the mean curl of the predicted winds (both sharing one
    colorbar), and the prediction minus target difference (with its own
    colorbar, same symmetric scale).

    Parameters
    ----------
    u_prediction : torch.Tensor or np.array
        Model predictions of shape [batch_size, h, w] for the zonal wind
        component. The last two dims correspond to longitude and latitude;
        must have the same shape as v_prediction.
    v_prediction : torch.Tensor or np.array
        Model predictions of shape [batch_size, h, w] for the meridional
        wind component. Same shape/layout as u_prediction.
    u_target : torch.Tensor or np.array
        Ground truth zonal wind of shape [batch_size, h, w]; same
        shape/layout as v_target.
    v_target : torch.Tensor or np.array
        Ground truth meridional wind of shape [batch_size, h, w].
    spacing : float
        Spatial resolution of the windfield. Used to compute the gradients.
    lat_1d : array-like
        1D array of latitude coordinates with shape [H].
    lon_1d : array-like
        1D array of longitude coordinates with shape [W].
    filename : str, optional
        Output filename for saving the plot.
    save_dir : str, optional
        Directory to save the plot.
    figsize_multiplier : int, optional
        Kept for interface compatibility with the other map plotters; the
        figure size of this plot is currently fixed and does not use it.

    Returns
    -------
    str
        Path of the saved figure.
    """
    if save_dir is None:
        save_dir = PlotConfig.DEFAULT_SAVE_DIR
    if figsize_multiplier is None:
        figsize_multiplier = PlotConfig.DEFAULT_FIGSIZE_MULTIPLIER

    lat_min, lat_max = lat_1d.min(), lat_1d.max()
    lon_min, lon_max = lon_1d.min(), lon_1d.max()
    _, h, w = u_target.shape
    # Latitude runs max -> min so that row 0 is the northernmost row.
    lat_block = np.linspace(lat_max, lat_min, h)
    lon_block = np.linspace(lon_min, lon_max, w)
    lat, lon = np.meshgrid(lat_block, lon_block, indexing="ij")
    lon_center = float((lon_min + lon_max) / 2)

    cmap = PlotConfig.get_colormap("curl")  # need to define the colormap in PlotConfig

    # Convert wind to plotting units before differentiating.
    u_prediction = PlotConfig.convert_units("wind", u_prediction)
    v_prediction = PlotConfig.convert_units("wind", v_prediction)
    u_target = PlotConfig.convert_units("wind", u_target)
    v_target = PlotConfig.convert_units("wind", v_target)

    curl_prediction = get_curl(u_prediction, v_prediction, spacing)
    curl_target = get_curl(u_target, v_target, spacing)
    mean_curl_prediction = np.mean(curl_prediction, axis=0)
    mean_curl_target = np.mean(curl_target, axis=0)

    # Symmetric colour scale centred on zero, covering both fields.
    vmin = min(np.min(mean_curl_prediction), np.min(mean_curl_target))
    vmax = max(np.max(mean_curl_prediction), np.max(mean_curl_target))
    vmax = max(np.abs(vmax), np.abs(vmin))
    norm = mcolors.TwoSlopeNorm(vmin=-vmax, vcenter=0, vmax=vmax)

    base_width_per_panel = 4.5
    base_height_per_panel = 3.0
    fig, axes = plt.subplots(
        3,
        1,
        figsize=(base_width_per_panel, 3 * base_height_per_panel),
        subplot_kw={
            "projection": ccrs.PlateCarree(central_longitude=lon_center)
        },  # ccrs.Mercator(central_longitude=lon_center)
        gridspec_kw={"wspace": 0.1},
    )
    fig.subplots_adjust(
        top=0.9, bottom=0.1, left=0.1, right=0.9, wspace=0.1, hspace=0.1
    )

    def _draw_panel(ax, field, title):
        # One pcolormesh panel with coastlines/borders/lakes and no ticks.
        mesh = ax.pcolormesh(
            lon,
            lat,
            field,
            norm=norm,
            cmap=cmap,
            transform=ccrs.PlateCarree(),
            shading="auto",
        )
        ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
        ax.coastlines(linewidth=0.6)
        ax.add_feature(
            cfeature.BORDERS.with_scale("50m"),
            linewidth=0.6,
            linestyle="--",
            edgecolor="black",
            zorder=11,
        )
        ax.add_feature(
            cfeature.LAKES.with_scale("50m"),
            edgecolor="black",
            facecolor="none",
            linewidth=0.6,
            zorder=9,
        )
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title)
        return mesh

    im = _draw_panel(axes[0], mean_curl_target, "Target")
    im = _draw_panel(axes[1], mean_curl_prediction, "Prediction")

    # Shared colorbar spanning the top two (vertically stacked) panels.
    pos0 = axes[0].get_position()
    pos1 = axes[1].get_position()
    bottom = pos1.y0
    top = pos0.y1
    cax1 = fig.add_axes([0.92, bottom, 0.02, top - bottom])  # [left, bottom, width, height]
    fig.colorbar(im, cax=cax1, orientation="vertical", label="curl")

    im = _draw_panel(axes[2], mean_curl_prediction - mean_curl_target, "Predicted - Target")

    # Separate colorbar for the error panel.
    pos2 = axes[2].get_position()
    cax2 = fig.add_axes([0.92, pos2.y0, 0.02, pos2.height])
    fig.colorbar(im, cax=cax2, orientation="vertical", label="curl error")

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, filename)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
    return save_path
# ============================================================================ # Plotting Functions Test Suite # ============================================================================
[docs] class TestPlottingFunctions(unittest.TestCase): """Unit tests for plotting functions with visible output for styling adjustment."""
[docs] def __init__(self, methodName="runTest", logger=None): super().__init__(methodName) self.logger = logger
    def setUp(self):
        """Set up test fixtures.

        Builds deterministic synthetic data (seeded numpy RNG) shared by all
        tests: ``predictions``/``targets``/``coarse_inputs`` arrays of shape
        [batch_size, num_vars, h, w], lat/lon axes, a per-variable metrics
        history dict, and train/valid loss curves. Plots are written under
        ``self.output_dir``.

        NOTE(review): the fixture values depend on the exact order of
        np.random calls below — do not reorder them.
        """
        self.output_dir = "./test_plots"
        os.makedirs(self.output_dir, exist_ok=True)
        # Generate realistic synthetic test data
        np.random.seed(42)
        self.batch_size = 50
        self.num_vars = 4
        self.h = 64
        self.w = 64
        if self.logger:
            self.logger.info(
                f"Test setup complete - output directory: {self.output_dir}"
            )
            self.logger.info(
                f"Batch size: {self.batch_size}, Variables: {self.num_vars}, Resolution: {self.h}x{self.w}"
            )
        # Create correlated data for realistic plots
        x = np.linspace(0, 4 * np.pi, self.w)
        y = np.linspace(0, 4 * np.pi, self.h)
        X, Y = np.meshgrid(x, y)
        # Four distinct spatial base patterns, cycled over the variables.
        patterns = [
            np.sin(X) * np.cos(Y),
            np.exp(-0.1 * (X - 10) ** 2 - 0.1 * (Y - 10) ** 2),
            X * Y / 100,
            np.sin(0.5 * X) * np.cos(0.5 * Y) + 0.5 * np.sin(2 * X) * np.cos(2 * Y),
        ]
        self.predictions = np.zeros((self.batch_size, self.num_vars, self.h, self.w))
        self.targets = np.zeros((self.batch_size, self.num_vars, self.h, self.w))
        self.coarse_inputs = np.zeros((self.batch_size, self.num_vars, self.h, self.w))
        for i in range(self.num_vars):
            base_pattern = patterns[i % len(patterns)]
            for b in range(self.batch_size):
                # Per-sample noise: coarse inputs are noisier than the fine fields.
                noise_pred = np.random.normal(0, 0.1, (self.h, self.w))
                noise_target = np.random.normal(0, 0.1, (self.h, self.w))
                noise_coarse = np.random.normal(0, 0.2, (self.h, self.w))
                scale = 1.0 + 0.1 * np.random.random()
                offset = 0.1 * np.random.random()
                self.predictions[b, i] = base_pattern * scale + offset + noise_pred
                self.targets[b, i] = (
                    base_pattern * (scale + 0.05) + offset + 0.05 + noise_target
                )
                self.coarse_inputs[b, i] = (
                    base_pattern * (scale - 0.1) + offset - 0.1 + noise_coarse
                )
        self.variable_names = [
            "Temp",
            "Press",
            "Humid",
            "Wind",
        ]
        # Create lat/lon arrays for spatial tests
        self.lat = np.linspace(-90, 90, self.h)
        self.lon = np.linspace(-180, 180, self.w)
        # Create comprehensive metrics history
        self.valid_metrics_history = {}
        metrics = ["rmse", "mae", "r2"]
        for var in self.variable_names:
            var_key = var.split(" ")[0]
            for metric in metrics:
                base_val_pred = 0.8 if metric == "r2" else 1.0
                base_val_coarse = 0.6 if metric == "r2" else 1.5
                decay = np.linspace(0, 0.3, 10)
                # r2 should improve (grow) over epochs; rmse/mae should shrink.
                if metric == "r2":
                    self.valid_metrics_history[f"{var_key}_pred_vs_fine_{metric}"] = (
                        base_val_pred + decay
                    )
                    self.valid_metrics_history[f"{var_key}_coarse_vs_fine_{metric}"] = (
                        base_val_coarse + decay * 0.5
                    )
                else:
                    self.valid_metrics_history[f"{var_key}_pred_vs_fine_{metric}"] = (
                        base_val_pred - decay
                    )
                    self.valid_metrics_history[f"{var_key}_coarse_vs_fine_{metric}"] = (
                        base_val_coarse - decay * 0.5
                    )
        # Add average metrics
        for metric in metrics:
            self.valid_metrics_history[f"average_pred_vs_fine_{metric}"] = (
                0.1 + np.linspace(0, 0.2, 10)
            )
            self.valid_metrics_history[f"average_coarse_vs_fine_{metric}"] = (
                0.7 + np.linspace(0, 0.2, 10)
            )
        # Loss histories
        self.train_loss_history = np.exp(
            -np.linspace(0, 2, 20)
        ) + 0.1 * np.random.random(20)
        self.valid_loss_history = np.exp(
            -np.linspace(0, 1.5, 20)
        ) + 0.15 * np.random.random(20)
# ============================================================================ # SINGLE COMPREHENSIVE TEST FOR EACH DIAGNOSTIC METHOD # ============================================================================
    def test_validation_hexbin_comprehensive(self):
        """Comprehensive test for validation hexbin plots.

        Exercises ``plot_validation_hexbin`` and ``plot_comparison_hexbin``
        with (1) the standard numpy fixtures, (2) torch tensor inputs, and
        (3) a single-variable slice; each call must produce a file on disk.
        """
        if self.logger:
            self.logger.info("Testing validation hexbin plots comprehensively")
        # Test 1: Standard configuration
        expected_path = plot_validation_hexbin(
            predictions=self.predictions,
            targets=self.targets,
            variable_names=self.variable_names,
            save_dir=self.output_dir,
            filename="validation_hexbin_standard.png",
            figsize_multiplier=5,
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
        expected_path = plot_comparison_hexbin(
            predictions=self.predictions,
            targets=self.targets,
            coarse_inputs=self.coarse_inputs,
            variable_names=self.variable_names,
            filename="comparison_hexbin_standard.png",
            save_dir=self.output_dir,
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
        # Test 2: PyTorch tensors
        predictions_tensor = torch.from_numpy(self.predictions)
        targets_tensor = torch.from_numpy(self.targets)
        coarse_tensor = torch.from_numpy(self.coarse_inputs)
        expected_path = plot_validation_hexbin(
            predictions=predictions_tensor,
            targets=targets_tensor,
            variable_names=self.variable_names,
            save_dir=self.output_dir,
            filename="validation_hexbin_torch.png",
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
        expected_path = plot_comparison_hexbin(
            predictions=predictions_tensor,
            targets=targets_tensor,
            coarse_inputs=coarse_tensor,
            variable_names=self.variable_names,
            filename="comparison_hexbin_torch.png",
            save_dir=self.output_dir,
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
        # Test 3: Single variable
        single_pred = self.predictions[:, 0:1, :, :]
        single_target = self.targets[:, 0:1, :, :]
        single_coarse = self.coarse_inputs[:, 0:1, :, :]
        expected_path = plot_validation_hexbin(
            predictions=single_pred,
            targets=single_target,
            variable_names=[self.variable_names[0]],
            save_dir=self.output_dir,
            filename="validation_hexbin_single.png",
            figsize_multiplier=6,
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
        expected_path = plot_comparison_hexbin(
            predictions=single_pred,
            targets=single_target,
            coarse_inputs=single_coarse,
            variable_names=[self.variable_names[0]],
            filename="comparison_hexbin_single.png",
            save_dir=self.output_dir,
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
        if self.logger:
            self.logger.info("✅ All validation hexbin tests passed")
    def test_validation_pdfs_comprehensive(self):
        """Comprehensive test for validation PDF plots.

        Exercises ``plot_validation_pdfs`` with (1) the standard fixtures
        including coarse inputs, (2) coarse_inputs=None, and (3) torch
        tensor inputs; each call must produce a file on disk.
        """
        if self.logger:
            self.logger.info("Testing validation PDF plots comprehensively")
        # Test 1: Standard configuration with coarse inputs
        expected_path = plot_validation_pdfs(
            predictions=self.predictions,
            targets=self.targets,
            coarse_inputs=self.coarse_inputs,
            variable_names=self.variable_names,
            save_dir=self.output_dir,
            filename="validation_pdfs_standard.png",
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
        # Test 2: Without coarse inputs
        expected_path = plot_validation_pdfs(
            predictions=self.predictions,
            targets=self.targets,
            coarse_inputs=None,
            variable_names=self.variable_names,
            save_dir=self.output_dir,
            filename="validation_pdfs_no_coarse.png",
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
        # Test 3: PyTorch tensors
        predictions_tensor = torch.from_numpy(self.predictions)
        targets_tensor = torch.from_numpy(self.targets)
        coarse_tensor = torch.from_numpy(self.coarse_inputs)
        expected_path = plot_validation_pdfs(
            predictions=predictions_tensor,
            targets=targets_tensor,
            coarse_inputs=coarse_tensor,
            variable_names=self.variable_names,
            save_dir=self.output_dir,
            filename="validation_pdfs_torch.png",
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
        if self.logger:
            self.logger.info("✅ All validation PDF tests passed")
    def test_power_spectra_comprehensive(self):
        """Comprehensive test for power spectra plots.

        Exercises ``plot_power_spectra`` with (1) the standard fixtures,
        (2) coarse_inputs=None, and (3) torch tensor inputs; grid spacings
        are derived from the fixture lat/lon axes.
        """
        if self.logger:
            self.logger.info("Testing power spectra plots comprehensively")
        # Grid spacing derived from the (uniform) fixture axes.
        dlat = np.abs(self.lat[1] - self.lat[0])
        dlon = np.abs(self.lon[1] - self.lon[0])
        # Test 1: Standard configuration
        expected_path = plot_power_spectra(
            predictions=self.predictions,
            targets=self.targets,
            coarse_inputs=self.coarse_inputs,
            dlat=dlat,
            dlon=dlon,
            variable_names=self.variable_names,
            save_dir=self.output_dir,
            filename="power_spectra_standard.png",
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
        # Test 2: Without coarse inputs
        expected_path = plot_power_spectra(
            predictions=self.predictions,
            targets=self.targets,
            coarse_inputs=None,
            dlat=dlat,
            dlon=dlon,
            variable_names=self.variable_names,
            save_dir=self.output_dir,
            filename="power_spectra_no_coarse.png",
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
        # Test 3: PyTorch tensors
        predictions_tensor = torch.from_numpy(self.predictions)
        targets_tensor = torch.from_numpy(self.targets)
        coarse_tensor = torch.from_numpy(self.coarse_inputs)
        expected_path = plot_power_spectra(
            predictions=predictions_tensor,
            targets=targets_tensor,
            coarse_inputs=coarse_tensor,
            dlat=dlat,
            dlon=dlon,
            variable_names=self.variable_names,
            save_dir=self.output_dir,
            filename="power_spectra_torch.png",
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
        if self.logger:
            self.logger.info("✅ All power spectra tests passed")
    def test_spatiotemporal_histograms_comprehensive(self):
        """Comprehensive test for spatiotemporal histograms.

        Builds four clusters of random (lat, lon, time-index) samples and
        checks that ``plot_spatiotemporal_histograms`` writes its figure.
        """
        if self.logger:
            self.logger.info("Testing spatiotemporal histograms comprehensively")

        # Minimal stand-in for the object carrying grid dimensions.
        class MockSteps:
            latitude = 180
            longitude = 360

        steps = MockSteps()
        # Test 1: Dense data
        tindex_lim = (0, 365)
        # n_samples = 2000
        centers = []
        tindices = []
        # Each cluster concentrates samples in a lat/lon/time box.
        clusters = [
            {
                "lat_range": (30, 60),
                "lon_range": (200, 250),
                "time_range": (0, 100),
                "n": 500,
            },
            {
                "lat_range": (10, 40),
                "lon_range": (100, 150),
                "time_range": (100, 200),
                "n": 400,
            },
            {
                "lat_range": (50, 80),
                "lon_range": (300, 350),
                "time_range": (200, 300),
                "n": 600,
            },
            {
                "lat_range": (0, 30),
                "lon_range": (50, 100),
                "time_range": (300, 365),
                "n": 500,
            },
        ]
        for cluster in clusters:
            for _ in range(cluster["n"]):
                lat = np.random.randint(
                    cluster["lat_range"][0], cluster["lat_range"][1]
                )
                lon = np.random.randint(
                    cluster["lon_range"][0], cluster["lon_range"][1]
                )
                tindex = np.random.randint(
                    cluster["time_range"][0], cluster["time_range"][1]
                )
                centers.append((lat, lon))
                tindices.append(tindex)
        expected_path = plot_spatiotemporal_histograms(
            steps=steps,
            tindex_lim=tindex_lim,
            centers=centers,
            tindices=tindices,
            mode="train",
            filename="spatiotemporal_dense_",
            save_dir=self.output_dir,
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
        if self.logger:
            self.logger.info("✅ All spatiotemporal histogram tests passed")
    def test_plot_surface_comprehensive(self):
        """Comprehensive test for surface plots.

        Builds temperature-like fields on a regional 48x68 grid and checks
        that ``plot_surface`` writes a figure for both numpy and torch
        inputs.
        """
        if self.logger:
            self.logger.info("Testing surface plots comprehensively")
        # Test case 1: Standard configuration
        lat_1d = np.linspace(30, 50, 48)
        lon_1d = np.linspace(-120, -80, 68)
        # Create synthetic data
        batch_size = 1
        n_vars = 3
        h, w = 48, 68
        # Create spatial patterns
        x = np.linspace(0, 3 * np.pi, w)
        y = np.linspace(0, 3 * np.pi, h)
        X, Y = np.meshgrid(x, y)
        # Initialize arrays
        coarse_inputs = np.zeros((batch_size, n_vars, h, w))
        targets = np.zeros((batch_size, n_vars, h, w))
        pred = np.zeros((batch_size, n_vars, h, w))
        base_patterns = [
            np.sin(X / 2) * np.cos(Y / 2),
            np.exp(-0.01 * (X - 24) ** 2 - 0.01 * (Y - 24) ** 2),
            X * Y / 200,
        ]
        for i in range(n_vars):
            base_pattern = base_patterns[i % len(base_patterns)]
            pattern = base_pattern * 20 + 280  # Temperature-like
            # Coarse inputs are noisiest; predictions track targets closely.
            coarse_inputs[0, i] = pattern + np.random.randn(h, w) * 2
            targets[0, i] = pattern + np.random.randn(h, w) * 1
            pred[0, i] = targets[0, i] + np.random.randn(h, w) * 0.3
        variable_names = ["Temp", "Press", "Humid"]
        timestamp = datetime(2024, 1, 1, 12, 0)
        # Test with numpy arrays
        expected_path = plot_surface(
            coarse_inputs=coarse_inputs,
            targets=targets,
            predictions=pred,
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            timestamp=timestamp,
            variable_names=variable_names,
            filename="plot_surface_standard.png",
            save_dir=self.output_dir,
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
        # Test with PyTorch tensors
        coarse_inputs_tensor = torch.from_numpy(coarse_inputs.copy())
        targets_tensor = torch.from_numpy(targets.copy())
        pred_tensor = torch.from_numpy(pred.copy())
        expected_path = plot_surface(
            coarse_inputs=coarse_inputs_tensor,
            targets=targets_tensor,
            predictions=pred_tensor,
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            timestamp=timestamp,
            variable_names=variable_names,
            filename="plot_surface_torch.png",
            save_dir=self.output_dir,
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
        if self.logger:
            self.logger.info("✅ All surface plot tests passed")
    def test_plot_ensemble_surface_comprehensive(self):
        """Comprehensive test for ensemble surface plots.

        Builds a 5-member ensemble of [N_ens, n_vars, H, W] fields with
        member-dependent noise and checks that ``plot_ensemble_surface``
        writes a non-empty figure for numpy and torch inputs.
        """
        if self.logger:
            self.logger.info("Testing ensemble surface plots comprehensively")
        lat_1d = np.linspace(30, 50, 48)
        lon_1d = np.linspace(-120, -80, 68)
        # Ensemble configuration
        N_ens = 5
        n_vars = 3
        H, W = 48, 68
        x = np.linspace(0, 3 * np.pi, W)
        y = np.linspace(0, 3 * np.pi, H)
        X, Y = np.meshgrid(x, y)
        # different types of spatial patterns
        base_patterns = [
            np.sin(X / 2) * np.cos(Y / 2),  # sinusoidal pattern
            np.exp(-0.01 * (X - 24) ** 2 - 0.01 * (Y - 24) ** 2),  # Gaussian blob
            X * Y / 200,  # linear gradient
        ]
        # ensemble array: [N_ens, n_vars, H, W]
        predictions_ens = np.zeros((N_ens, n_vars, H, W))
        # Generate ensemble members
        for k in range(N_ens):
            for i in range(n_vars):
                # Select a base spatial pattern for each variable
                base = base_patterns[i % len(base_patterns)]
                # Scale to realistic physical values
                signal = base * 20 + 280
                # Add Gaussian noise to simulate ensemble spread
                # (noise grows with the member index k).
                predictions_ens[k, i] = signal + np.random.randn(H, W) * (0.5 + k * 0.2)
        variable_names = ["Temp", "Press", "Humid"]
        timestamp = datetime(2024, 1, 1, 12, 0)
        # Test 1: numpy
        expected_path = plot_ensemble_surface(
            predictions_ens=predictions_ens,
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            variable_names=variable_names,
            timestamp=timestamp,
            filename="plot_ensemble_surface_numpy.png",
            save_dir=self.output_dir,
        )
        self.assertTrue(os.path.exists(expected_path))
        self.assertGreater(os.path.getsize(expected_path), 0)
        # Test 2: torch
        expected_path = plot_ensemble_surface(
            predictions_ens=torch.from_numpy(predictions_ens),
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            variable_names=variable_names,
            timestamp=timestamp,
            filename="plot_ensemble_surface_torch.png",
            save_dir=self.output_dir,
        )
        self.assertTrue(os.path.exists(expected_path))
        self.assertGreater(os.path.getsize(expected_path), 0)
        if self.logger:
            self.logger.info("✅ All ensemble surface plot tests passed")
    def test_plot_zoom_comparison_comprehensive(self):
        """Comprehensive test for zoom comparison plots.

        Builds near-constant global fields and checks that
        ``plot_zoom_comparison`` with a rectangular zoom box writes a
        non-empty figure for numpy and torch inputs.
        """
        if self.logger:
            self.logger.info("Testing zoom comparison plots comprehensively")
        # Grid
        lat_1d = np.linspace(-90, 90, 144)
        lon_1d = np.linspace(0, 360, 360, endpoint=False)
        batch_size = 1
        n_vars = 3
        H, W = 144, 360
        targets = np.zeros((batch_size, n_vars, H, W))
        predictions = np.zeros((batch_size, n_vars, H, W))
        for i in range(n_vars):
            # Constant background per variable plus small Gaussian noise.
            base = np.ones((H, W)) * (280 + i * 5)
            targets[0, i] = base + np.random.randn(H, W) * 1.0
            predictions[0, i] = targets[0, i] + np.random.randn(H, W) * 0.5
        variable_names = ["Temp", "Press", "Humid"]
        # Region highlighted by the zoomed inset.
        zoom_box = {
            "lat_min": -23,
            "lat_max": 13,
            "lon_min": 255,
            "lon_max": 345,
        }
        # Test 1: numpy
        expected_path = plot_zoom_comparison(
            predictions=predictions,
            targets=targets,
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            variable_names=variable_names,
            filename="plot_zoom_numpy.png",
            save_dir=self.output_dir,
            zoom_box=zoom_box,
        )
        self.assertTrue(os.path.exists(expected_path))
        self.assertGreater(os.path.getsize(expected_path), 0)
        # Test 2: torch
        expected_path = plot_zoom_comparison(
            predictions=torch.from_numpy(predictions),
            targets=torch.from_numpy(targets),
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            variable_names=variable_names,
            filename="plot_zoom_torch.png",
            save_dir=self.output_dir,
            zoom_box=zoom_box,
        )
        self.assertTrue(os.path.exists(expected_path))
        self.assertGreater(os.path.getsize(expected_path), 0)
        if self.logger:
            self.logger.info("✅ All zoom comparison plot tests passed")
    def test_plot_global_surface_robinson_comprehensive(self):
        """Comprehensive test for global Robinson surface plots.

        Builds synthetic global fields on a 90x180 grid and checks that
        ``plot_global_surface_robinson`` writes a figure for numpy and
        torch inputs.
        """
        if self.logger:
            self.logger.info("Testing global Robinson surface plots comprehensively")
        # GLOBAL domain
        H, W = 90, 180
        lat_1d = np.linspace(-90, 90, H)
        lon_1d = np.linspace(-180, 180, W)
        batch_size = 1
        n_vars = 3
        # Create synthetic global patterns
        x = np.linspace(-np.pi, np.pi, W)
        y = np.linspace(-np.pi / 2, np.pi / 2, H)
        X, Y = np.meshgrid(x, y)
        coarse_inputs = np.zeros((batch_size, n_vars, H, W))
        targets = np.zeros((batch_size, n_vars, H, W))
        pred = np.zeros((batch_size, n_vars, H, W))
        base_patterns = [np.sin(X) * np.cos(Y), np.cos(2 * X) * np.sin(Y), X * Y]
        for i in range(n_vars):
            pattern = base_patterns[i] * 10 + 280
            # Coarse inputs noisiest; predictions track targets closely.
            coarse_inputs[0, i] = pattern + np.random.randn(H, W) * 2
            targets[0, i] = pattern + np.random.randn(H, W)
            pred[0, i] = targets[0, i] + np.random.randn(H, W) * 0.3
        variable_names = ["Temp", "Press", "Humid"]
        timestamp = datetime(2024, 1, 1, 12, 0)
        # Test with numpy arrays
        expected_path = plot_global_surface_robinson(
            predictions=pred,
            targets=targets,
            coarse_inputs=coarse_inputs,
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            timestamp=timestamp,
            variable_names=variable_names,
            filename="plot_global_robinson_standard.png",
            save_dir=self.output_dir,
        )
        self.assertTrue(os.path.exists(expected_path))
        # Test with PyTorch tensors
        expected_path = plot_global_surface_robinson(
            predictions=torch.from_numpy(pred),
            targets=torch.from_numpy(targets),
            coarse_inputs=torch.from_numpy(coarse_inputs),
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            timestamp=timestamp,
            variable_names=variable_names,
            filename="plot_global_robinson_torch.png",
            save_dir=self.output_dir,
        )
        self.assertTrue(os.path.exists(expected_path))
        if self.logger:
            self.logger.info("✅ All global Robinson surface plot tests passed")
    def test_plot_mae_map_comprehensive(self):
        """Comprehensive test for time-averaged MAE spatial map plots.

        Exercises ``plot_MAE_map`` with numpy inputs, torch inputs, and a
        single-variable slice on a regional 48x68 crop of the fixtures.
        """
        if self.logger:
            self.logger.info("Testing MAE map plots comprehensively")
        # Regional lat/lon
        lat_1d = np.linspace(30, 50, 48)
        lon_1d = np.linspace(-120, -80, 68)
        # Matching spatial resolution
        predictions = self.predictions[:, :, :48, :68]
        targets = self.targets[:, :, :48, :68]
        # Test 1: Standard numpy inputs
        expected_path = plot_MAE_map(
            predictions=predictions,
            targets=targets,
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            variable_names=self.variable_names,
            save_dir=self.output_dir,
            filename="validation_mae_map_standard.png",
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
        # Test 2: PyTorch tensors
        expected_path = plot_MAE_map(
            predictions=torch.from_numpy(predictions),
            targets=torch.from_numpy(targets),
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            variable_names=self.variable_names,
            save_dir=self.output_dir,
            filename="validation_mae_map_torch.png",
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
        # Test 3: Single variable
        expected_path = plot_MAE_map(
            predictions=predictions[:, 0:1],
            targets=targets[:, 0:1],
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            variable_names=[self.variable_names[0]],
            save_dir=self.output_dir,
            filename="validation_mae_map_single_var.png",
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
        if self.logger:
            self.logger.info("✅ All MAE map plot tests passed")
[docs] def test_plot_error_map_comprehensive(self): """Comprehensive test for time-averaged ERROR spatial map plots.""" if self.logger: self.logger.info("Testing ERROR map plots comprehensively") # Regional lat/lon lat_1d = np.linspace(30, 50, 48) lon_1d = np.linspace(-120, -80, 68) # Matching spatial resolution predictions = self.predictions[:, :, :48, :68] targets = self.targets[:, :, :48, :68] # Test 1: Standard numpy inputs expected_path = plot_error_map( predictions=predictions, targets=targets, lat_1d=lat_1d, lon_1d=lon_1d, variable_names=self.variable_names, save_dir=self.output_dir, filename="validation_error_map_standard.png", ) self.assertTrue( os.path.exists(expected_path), f"File not found: {expected_path}" ) # Test 2: PyTorch tensors expected_path = plot_error_map( predictions=torch.from_numpy(predictions), targets=torch.from_numpy(targets), lat_1d=lat_1d, lon_1d=lon_1d, variable_names=self.variable_names, save_dir=self.output_dir, filename="validation_error_map_torch.png", ) self.assertTrue( os.path.exists(expected_path), f"File not found: {expected_path}" ) # Test 3: Single variable expected_path = plot_error_map( predictions=predictions[:, 0:1], targets=targets[:, 0:1], lat_1d=lat_1d, lon_1d=lon_1d, variable_names=[self.variable_names[0]], save_dir=self.output_dir, filename="validation_error_map_single_var.png", ) self.assertTrue( os.path.exists(expected_path), f"File not found: {expected_path}" ) if self.logger: self.logger.info("✅ All ERROR map plot tests passed")
    def test_plot_spread_skill_ratio_map_comprehensive(self):
        """Comprehensive test for spread-skill ratio (SSR) spatial map plots.

        Turns the deterministic fixtures into a 10-member noisy ensemble
        and exercises ``plot_spread_skill_ratio_map`` with numpy inputs,
        torch inputs, and a single-variable slice.
        """
        if self.logger:
            self.logger.info("Testing SSR map plots comprehensively")
        # Regional lat/lon
        lat_1d = np.linspace(30, 50, 48)
        lon_1d = np.linspace(-120, -80, 68)
        # Matching spatial resolution
        predictions = self.predictions[:, :, :48, :68]
        # Transform predictions into ensemble by adding noise:
        T, C, h, w = predictions.shape
        noise = np.random.normal(size=(10, T, C, h, w))
        predictions_ensemble = predictions + noise
        targets = self.targets[:, :, :48, :68]
        # Test 1: Standard numpy inputs
        expected_path = plot_spread_skill_ratio_map(
            predictions=predictions_ensemble,
            targets=targets,
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            variable_names=self.variable_names,
            save_dir=self.output_dir,
            filename="validation_ssr_map_standard.png",
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
        # Test 2: PyTorch tensors
        expected_path = plot_spread_skill_ratio_map(
            predictions=torch.from_numpy(predictions_ensemble),
            targets=torch.from_numpy(targets),
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            variable_names=self.variable_names,
            save_dir=self.output_dir,
            filename="validation_ssr_map_torch.png",
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
        # Test 3: Single variable (variable axis is dim 2 of the ensemble)
        expected_path = plot_spread_skill_ratio_map(
            predictions=predictions_ensemble[:, :, 0:1],
            targets=targets[:, 0:1],
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            variable_names=[self.variable_names[0]],
            save_dir=self.output_dir,
            filename="validation_ssr_map_single_var.png",
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
        if self.logger:
            self.logger.info("✅ All spread skill ratio map plot tests passed")
def test_plot_spread_skill_ratio_hexbin_comprehensive(self):
    """Comprehensive test for spread skill ratio hexbin plots."""
    if self.logger:
        self.logger.info("Testing SSR hexbin plots comprehensively")
    # Matching spatial resolution
    predictions = self.predictions[:, :, :48, :68]
    # Transform predictions into an ensemble by adding noise.
    # Seed the RNG so the synthetic ensemble (and thus the test) is deterministic.
    np.random.seed(0)
    T, C, h, w = predictions.shape
    noise = np.random.normal(size=(10, T, C, h, w))
    predictions_ensemble = predictions + noise
    targets = self.targets[:, :, :48, :68]
    # Test 1: Standard numpy inputs
    expected_path = plot_spread_skill_ratio_hexbin(
        predictions=predictions_ensemble,
        targets=targets,
        variable_names=self.variable_names,
        save_dir=self.output_dir,
        filename="validation_ssr_hexbin_standard.png",
    )
    self.assertTrue(
        os.path.exists(expected_path), f"File not found: {expected_path}"
    )
    # Test 2: PyTorch tensors
    expected_path = plot_spread_skill_ratio_hexbin(
        predictions=torch.from_numpy(predictions_ensemble),
        targets=torch.from_numpy(targets),
        variable_names=self.variable_names,
        save_dir=self.output_dir,
        filename="validation_ssr_hexbin_torch.png",
    )
    self.assertTrue(
        os.path.exists(expected_path), f"File not found: {expected_path}"
    )
    # Test 3: Single variable (ensemble axis is dim 0, channel axis is dim 2)
    expected_path = plot_spread_skill_ratio_hexbin(
        predictions=predictions_ensemble[:, :, 0:1],
        targets=targets[:, 0:1],
        variable_names=[self.variable_names[0]],
        save_dir=self.output_dir,
        filename="validation_ssr_hexbin_single_var.png",
    )
    self.assertTrue(
        os.path.exists(expected_path), f"File not found: {expected_path}"
    )
    if self.logger:
        self.logger.info("✅ All spread skill ratio hexbin plot tests passed")
def test_plot_mean_divergence_map_comprehensive(self):
    """Comprehensive test for mean divergence map plots."""
    if self.logger:
        self.logger.info("Testing divergence map plots comprehensively")
    # Regional lat/lon grid
    lat_1d = np.linspace(30, 50, 48)
    lon_1d = np.linspace(-120, -80, 68)
    # Channels 0/1 of the fixtures are used as u/v components, cropped
    # to the regional grid.
    fields = (
        self.predictions[:, 0, :48, :68],
        self.predictions[:, 1, :48, :68],
        self.targets[:, 0, :48, :68],
        self.targets[:, 1, :48, :68],
    )
    cases = [
        # Test 1: standard numpy inputs
        ("validation_mean_divergence_map_standard.png", fields),
        # Test 2: PyTorch tensors
        (
            "validation_mean_divergence_map_torch.png",
            tuple(torch.from_numpy(f) for f in fields),
        ),
    ]
    for fname, (u_p, v_p, u_t, v_t) in cases:
        expected_path = plot_mean_divergence_map(
            u_p,
            v_p,
            u_t,
            v_t,
            spacing=1,
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            save_dir=self.output_dir,
            filename=fname,
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
    if self.logger:
        self.logger.info("✅ All mean divergence map plot tests passed")
def test_plot_mean_curl_map_comprehensive(self):
    """Comprehensive test for mean curl map plots."""
    if self.logger:
        self.logger.info("Testing curl map plots comprehensively")
    # Regional lat/lon grid
    lat_1d = np.linspace(30, 50, 48)
    lon_1d = np.linspace(-120, -80, 68)
    # Channels 0/1 of the fixtures are used as u/v components, cropped
    # to the regional grid.
    fields = (
        self.predictions[:, 0, :48, :68],
        self.predictions[:, 1, :48, :68],
        self.targets[:, 0, :48, :68],
        self.targets[:, 1, :48, :68],
    )
    cases = [
        # Test 1: standard numpy inputs
        ("validation_mean_curl_map_standard.png", fields),
        # Test 2: PyTorch tensors
        (
            "validation_mean_curl_map_torch.png",
            tuple(torch.from_numpy(f) for f in fields),
        ),
    ]
    for fname, (u_p, v_p, u_t, v_t) in cases:
        expected_path = plot_mean_curl_map(
            u_p,
            v_p,
            u_t,
            v_t,
            spacing=1,
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            save_dir=self.output_dir,
            filename=fname,
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
    if self.logger:
        self.logger.info("✅ All mean curl map plot tests passed")
def test_plot_dry_frequency_map_comprehensive(self):
    """Comprehensive test for dry frequency map plots."""
    if self.logger:
        self.logger.info("Testing dry frequency map plots comprehensively")
    # Regional lat/lon grid
    lat_1d = np.linspace(30, 50, 48)
    lon_1d = np.linspace(-120, -80, 68)
    # Crop to the matching spatial resolution; only channel 0 is plotted.
    pred_field = self.predictions[:, :, :48, :68][:, 0, :, :]
    target_field = self.targets[:, :, :48, :68][:, 0, :, :]
    cases = [
        # Test 1: standard numpy inputs
        ("validation_dry_frequency_map_standard.png", pred_field, target_field),
        # Test 2: PyTorch tensors
        (
            "validation_dry_frequency_map_torch.png",
            torch.from_numpy(pred_field),
            torch.from_numpy(target_field),
        ),
    ]
    for fname, preds, targs in cases:
        expected_path = plot_dry_frequency_map(
            predictions=preds,
            targets=targs,
            threshold=1,
            lat_1d=lat_1d,
            lon_1d=lon_1d,
            save_dir=self.output_dir,
            filename=fname,
        )
        self.assertTrue(
            os.path.exists(expected_path), f"File not found: {expected_path}"
        )
    if self.logger:
        self.logger.info("✅ All dry frequency map plot tests passed")
def test_dry_frequency_map(self):
    """Comprehensive test for the dry frequency map compute function."""
    if self.logger:
        self.logger.info(
            "Testing dry frequency map compute function comprehensively"
        )
    predictions = self.predictions[:, :, :48, :68]
    # Test 1 : standard numpy inputs
    arr = dry_frequency_map(predictions[:, 0, :, :], 1)
    # assertEqual reports both shapes on failure, unlike assertTrue(==).
    self.assertEqual(arr.shape, predictions.shape[-2:])
    # Test 2 : torch tensors (was mislabelled "Test 1")
    arr = dry_frequency_map(torch.from_numpy(predictions[:, 0, :, :]), 1)
    self.assertEqual(tuple(arr.shape), predictions.shape[-2:])
    if self.logger:
        self.logger.info("✅ All dry frequency tests passed")
def test_divergence(self):
    """Comprehensive test for the divergence compute function."""
    if self.logger:
        self.logger.info("Testing divergence compute function comprehensively")
    # Channels 0/1 are used as u/v components on the regional grid.
    u = self.predictions[:, 0, :48, :68]
    v = self.predictions[:, 1, :48, :68]
    # test 1 : standard numpy inputs :
    div = get_divergence(u, v, spacing=1)
    # assertEqual reports both shapes on failure, unlike assertTrue(==).
    self.assertEqual(div.shape, u.shape)
    # test 2 : torch inputs :
    div_torch = get_divergence(torch.from_numpy(u), torch.from_numpy(v), spacing=1)
    self.assertEqual(tuple(div_torch.shape), u.shape)
    if self.logger:
        self.logger.info("✅ All divergence tests passed")
def test_curl(self):
    """Comprehensive test for the curl compute function."""
    if self.logger:
        self.logger.info("Testing curl compute function comprehensively")
    # Channels 0/1 are used as u/v components on the regional grid.
    u = self.predictions[:, 0, :48, :68]
    v = self.predictions[:, 1, :48, :68]
    # test 1 : standard numpy inputs :
    curl = get_curl(u, v, spacing=1)
    # assertEqual reports both shapes on failure, unlike assertTrue(==).
    self.assertEqual(curl.shape, u.shape)
    # test 2 : torch inputs :
    curl_torch = get_curl(torch.from_numpy(u), torch.from_numpy(v), spacing=1)
    self.assertEqual(tuple(curl_torch.shape), u.shape)
    if self.logger:
        self.logger.info("✅ All curl tests passed")
def test_metric_plots_comprehensive(self):
    """Comprehensive test for metric plots."""
    if self.logger:
        self.logger.info("Testing metric plots comprehensively")

    def _assert_saved(path):
        # Shared existence check for every generated figure.
        self.assertTrue(os.path.exists(path), f"File not found: {path}")

    short_names = [name.split(" ")[0] for name in self.variable_names]
    # Test 1: Metric histories
    _assert_saved(
        plot_metric_histories(
            valid_metrics_history=self.valid_metrics_history,
            variable_names=short_names,
            metric_names=["rmse", "mae", "r2"],
            save_dir=self.output_dir,
            filename="metric_histories_comprehensive",
        )
    )
    # Test 2: Loss histories
    _assert_saved(
        plot_loss_histories(
            train_loss_history=self.train_loss_history,
            valid_loss_history=self.valid_loss_history,
            save_dir=self.output_dir,
            filename="loss_histories_standard.png",
        )
    )
    # Test 3: Average metrics
    _assert_saved(
        plot_average_metrics(
            valid_metrics_history=self.valid_metrics_history,
            metric_names=["rmse", "mae", "r2"],
            save_dir=self.output_dir,
            filename="average_metrics_standard.png",
        )
    )
    if self.logger:
        self.logger.info("✅ All metric plot tests passed")
def test_plot_metrics_heatmap_comprehensive(self):
    """Comprehensive test for validation metrics heatmap."""
    if self.logger:
        self.logger.info("Testing metrics heatmap")

    # Local stand-in exposing the MetricTracker interface the plot reads
    # (values, count, getmean).
    class DummyMetricTracker:
        def __init__(self, values):
            self.values = np.asarray(values)
            self.count = len(self.values)

        def getmean(self):
            return float(np.mean(self.values)) if self.count > 0 else np.nan

    short_names = [name.split(" ")[0] for name in self.variable_names]
    # Wrap the existing synthetic histories in trackers keyed like real metrics.
    valid_metrics_trackers = {
        f"{var_key}_pred_vs_fine_{metric}": DummyMetricTracker(
            self.valid_metrics_history[f"{var_key}_pred_vs_fine_{metric}"]
        )
        for var_key in short_names
        for metric in ["rmse", "mae", "r2"]
    }
    expected_path = plot_metrics_heatmap(
        valid_metrics_history=valid_metrics_trackers,
        variable_names=short_names,
        metric_names=["rmse", "mae", "r2"],
        save_dir=self.output_dir,
        filename="metrics_heatmap_comprehensive",
    )
    self.assertTrue(
        os.path.exists(expected_path), f"File not found: {expected_path}"
    )
    if self.logger:
        self.logger.info("✅ Metrics heatmap test passed")
def test_qq_quantiles_comprehensive(self):
    """Comprehensive test for QQ-quantiles plots."""
    if self.logger:
        self.logger.info("Testing QQ-quantiles plots comprehensively")
    # Test 1: Standard configuration with all parameters
    expected_path = plot_qq_quantiles(
        predictions=self.predictions,
        targets=self.targets,
        coarse_inputs=self.coarse_inputs,
        variable_names=self.variable_names,
        quantiles=[0.90, 0.95, 0.975, 0.99, 0.995],
        save_dir=self.output_dir,
        filename="qq_quantiles_standard.png",
    )
    self.assertTrue(
        os.path.exists(expected_path), f"File not found: {expected_path}"
    )
    # Test 2: Single variable (edge case) — renumbered from "Test 4"
    expected_path = plot_qq_quantiles(
        predictions=self.predictions[:, 0:1],  # Keep only first variable
        targets=self.targets[:, 0:1],
        coarse_inputs=self.coarse_inputs[:, 0:1],
        variable_names=["Temperature (K)"],
        quantiles=[0.90, 0.95, 0.99],
        save_dir=self.output_dir,
        filename="qq_quantiles_single_var.png",
    )
    self.assertTrue(
        os.path.exists(expected_path), f"File not found: {expected_path}"
    )
    # Test 3: PyTorch tensors — renumbered from "Test 5"
    predictions_tensor = torch.from_numpy(self.predictions)
    targets_tensor = torch.from_numpy(self.targets)
    coarse_tensor = torch.from_numpy(self.coarse_inputs)
    expected_path = plot_qq_quantiles(
        predictions=predictions_tensor,
        targets=targets_tensor,
        coarse_inputs=coarse_tensor,
        variable_names=self.variable_names,
        quantiles=[0.90, 0.95, 0.975, 0.99, 0.995],
        save_dir=self.output_dir,
        filename="qq_quantiles_torch.png",
    )
    self.assertTrue(
        os.path.exists(expected_path), f"File not found: {expected_path}"
    )
    if self.logger:
        self.logger.info("✅ All QQ-quantiles tests passed")
def test_mv_correlation(self):
    """Test for correlation over the time dimension for pairs of variables.

    Test for correlation over the spatial dimensions.
    """
    # Define lat/lon grid matching the data's spatial shape [T, C, H, W].
    # BUGFIX: the original assigned w = shape[2] and h = shape[3] (axes
    # swapped). np.meshgrid(x(w pts), y(h pts)) returns arrays of shape
    # (h, w), so the lat/lon grids only matched the (H, W) data grid when
    # the domain happened to be square.
    h = self.predictions.shape[2]  # latitude rows (H)
    w = self.predictions.shape[3]  # longitude columns (W)
    dlat = 20
    dlon = 40
    lat1 = 30
    lon1 = -120
    lon, lat = np.meshgrid(
        np.linspace(lon1, lon1 + dlon, w), np.linspace(lat1, lat1 + dlat, h)
    )
    # Test 1: Standard configuration Numpy arrays for correlation over time dimension
    expected_path = plot_validation_mvcorr(
        predictions=self.predictions,
        targets=self.targets,
        lat=lat,
        lon=lon,
        variable_names=self.variable_names,
        save_dir=self.output_dir,
        filename="validation_mvcorr_numpy.png",
        figsize_multiplier=3,
    )
    self.assertTrue(
        os.path.exists(expected_path), f"File not found: {expected_path}"
    )
    expected_path = plot_validation_mvcorr(
        predictions=self.predictions,
        targets=self.targets,
        lat=lat,
        lon=lon,
        coarse_inputs=self.coarse_inputs,
        variable_names=self.variable_names,
        save_dir=self.output_dir,
        filename="comparison_mvcorr_numpy.png",
        figsize_multiplier=3,
    )
    self.assertTrue(
        os.path.exists(expected_path), f"File not found: {expected_path}"
    )
    # Test 2: Standard configuration PyTorch tensors for correlation over time dimension
    coarse_tensor = torch.from_numpy(self.coarse_inputs.copy())
    fine_tensor = torch.from_numpy(self.targets.copy())
    pred_tensor = torch.from_numpy(self.predictions.copy())
    expected_path = plot_validation_mvcorr(
        predictions=pred_tensor,
        targets=fine_tensor,
        lat=lat,
        lon=lon,
        coarse_inputs=coarse_tensor,
        variable_names=self.variable_names,
        save_dir=self.output_dir,
        filename="comparison_mvcorr_torch.png",
        figsize_multiplier=3,
    )
    self.assertTrue(
        os.path.exists(expected_path), f"File not found: {expected_path}"
    )
    # Test 3: Standard configuration Numpy arrays for correlation over space dimensions
    expected_path = plot_validation_mvcorr_space(
        predictions=self.predictions,
        targets=self.targets,
        variable_names=self.variable_names,
        save_dir=self.output_dir,
        filename="comparison_mv_corr_space_numpy.png",
        figsize_multiplier=3,
    )
    self.assertTrue(
        os.path.exists(expected_path), f"File not found: {expected_path}"
    )
    # Test 4: Mixed numpy predictions/targets with torch coarse inputs for
    # correlation over space dimensions (the old comment said "Numpy arrays"
    # but coarse_tensor is a torch tensor).
    expected_path = plot_validation_mvcorr_space(
        predictions=self.predictions,
        targets=self.targets,
        coarse_inputs=coarse_tensor,
        variable_names=self.variable_names,
        save_dir=self.output_dir,
        filename="comparison_mvcorr_space_torch.png",
        figsize_multiplier=3,
    )
    self.assertTrue(
        os.path.exists(expected_path), f"File not found: {expected_path}"
    )
    if self.logger:
        self.logger.info("✅ All correlation plots tests passed")
def test_temporal_series_comparison_comprehensive(self):
    """Comprehensive test for spatially averaged temporal series comparison."""
    if self.logger:
        self.logger.info("Testing temporal series comparison comprehensively")
    # Here we reinterpret batch_size as time dimension T
    T = 100
    C = self.num_vars
    H = self.h
    W = self.w
    # Seed the RNG so the synthetic noise (and thus the test) is deterministic.
    np.random.seed(0)
    # Create synthetic temporal signal
    time = np.linspace(0, 4 * np.pi, T)
    predictions = np.zeros((T, C, H, W))
    targets = np.zeros((T, C, H, W))
    # The spatial pattern is loop-invariant — compute it once instead of
    # rebuilding it T*C times inside the loops.
    spatial_pattern = (
        np.sin(np.linspace(0, 2 * np.pi, W))[None, :]
        * np.cos(np.linspace(0, 2 * np.pi, H))[:, None]
    )
    for c in range(C):
        for t in range(T):
            seasonal_signal = np.sin(time[t]) * (c + 1)
            targets[t, c] = seasonal_signal + spatial_pattern
            predictions[t, c] = (
                seasonal_signal + spatial_pattern + np.random.normal(0, 0.1, (H, W))
            )
    # Test 1: numpy inputs
    expected_path = plot_temporal_series_comparison(
        predictions=predictions,
        targets=targets,
        variable_names=self.variable_names,
        save_dir=self.output_dir,
        filename="temporal_series_numpy.png",
    )
    self.assertTrue(
        os.path.exists(expected_path),
        f"File not found: {expected_path}",
    )
    # Test 2: torch tensors
    expected_path_torch = plot_temporal_series_comparison(
        predictions=torch.from_numpy(predictions),
        targets=torch.from_numpy(targets),
        variable_names=self.variable_names,
        save_dir=self.output_dir,
        filename="temporal_series_torch.png",
    )
    self.assertTrue(
        os.path.exists(expected_path_torch),
        f"File not found: {expected_path_torch}",
    )
    # Test 3: shape mismatch error
    with self.assertRaises(ValueError):
        plot_temporal_series_comparison(
            predictions=predictions,
            targets=targets[:, :, :-1, :],  # wrong shape
            variable_names=self.variable_names,
            save_dir=self.output_dir,
            filename="temporal_series_error.png",
        )
    if self.logger:
        self.logger.info("✅ All temporal series comparison tests passed")
def test_ranks(self):
    """Comprehensive test for the ranks compute function."""
    if self.logger:
        self.logger.info("Testing ranks compute function comprehensively")
    predictions = self.predictions[:, :, :48, :68]
    targets = self.targets[:, 0, :48, :68]
    ensemble_size = 10
    # Test 1 : standard numpy inputs
    # Build a degenerate ensemble by repeating channel 0 along a new axis 0.
    predictions_repeated = np.repeat(
        predictions[None, :, 0, :, :], axis=0, repeats=ensemble_size
    )
    arr = ranks(predictions=predictions_repeated, targets=targets)
    # assertEqual reports both shapes on failure, unlike assertTrue(==).
    self.assertEqual(arr.shape, targets.flatten().shape)
    # Test 2 : torch tensors
    predictions_repeated_torch = torch.from_numpy(predictions_repeated)
    targets_torch = torch.from_numpy(targets)
    arr_torch = ranks(predictions=predictions_repeated_torch, targets=targets_torch)
    self.assertEqual(tuple(arr_torch.shape), targets.flatten().shape)
    if self.logger:
        self.logger.info("✅ All ranks tests passed")
def test_plot_ranks(self):
    """Comprehensive test for the ranks plot function."""
    if self.logger:
        self.logger.info("Testing plot_ranks function comprehensively")
    predictions = self.predictions[:, :, :48, :68]
    targets = self.targets[:, :, :48, :68]
    ensemble_size = 10
    # Test 1 : standard numpy inputs
    # Build a degenerate ensemble by repeating the fields along a new axis 0.
    predictions_repeated = np.repeat(
        predictions[None, :, :, :, :], axis=0, repeats=ensemble_size
    )
    expected_path = plot_ranks(
        predictions=predictions_repeated,
        targets=targets,
        variable_names=self.variable_names,
        save_dir=self.output_dir,
        filename="ranks.png",
    )
    self.assertTrue(
        os.path.exists(expected_path), f"File not found: {expected_path}"
    )
    # Test 2 : torch inputs
    expected_path_torch = plot_ranks(
        predictions=torch.from_numpy(predictions_repeated),
        targets=torch.from_numpy(targets),
        variable_names=self.variable_names,
        save_dir=self.output_dir,
        filename="ranks_torch.png",
    )
    # BUGFIX: the failure message previously interpolated expected_path
    # (test 1's path) instead of expected_path_torch.
    self.assertTrue(
        os.path.exists(expected_path_torch),
        f"File not found: {expected_path_torch}",
    )
    if self.logger:
        self.logger.info("✅ All ranks plot tests passed")
def tearDown(self):
    """Clean up after tests.

    Deliberately leaves ``self.output_dir`` and its generated figures in
    place so they can be inspected manually after the run.
    """
    # Note: We don't remove the output directory so you can inspect the plots
    if self.logger:
        self.logger.info(
            f"Plotting tests completed - plots available in: {self.output_dir}"
        )
class TestSSRFunction(unittest.TestCase):
    """Unit tests for the spread_skill_ratio function.

    (The previous docstring said "crps_ensemble_all", but every test here
    exercises spread_skill_ratio.)
    """

    def __init__(self, methodName="runTest", logger=None):
        # Optional logger is threaded through so a custom runner can
        # capture per-test progress messages.
        super().__init__(methodName)
        self.logger = logger

    def setUp(self):
        """Set up test fixtures."""
        if self.logger:
            self.logger.info("Setting up spread_skill_ratio function test fixtures")

    def test_ssr_basic(self):
        """Test SSR with simple known values."""
        if self.logger:
            self.logger.info("Testing SSR basic functionality")
        np.random.seed(0)  # set random seed for reproducibility.
        true = np.zeros((100, 1, 10, 10))  # shape [T,C,h,w]
        pred_ens = np.random.normal(
            loc=0, scale=0.1, size=(10, 100, 1, 10, 10)
        )  # N_ens = 10
        ssr = spread_skill_ratio(
            predictions=pred_ens, targets=true, variable_names=None, pixel_wise=False
        )[0]  # get element of single element array.
        # Expected SSR should be ~3.15:
        # Let $X_{i,r}$ be the prediction for ensemble member r.
        # As the targets equal the mean of the predictions, the RMSE should be
        # very close to the standard deviation of the predictions (= 0.1).
        # (NOTE(review): the original comment said "variance (= 0.01)", but the
        # expected value 3.15 below only follows with the std 0.1 in the
        # denominator.)
        # Deriving the expected value of the spread is more complicated:
        # $ spread = \frac{11}{10} \sqrt{ \frac{1}{N} \sum_{i,r} (X_{i,r} - \bar X_i)^2 } $
        # We can inject the definition of $\bar X_i = \frac{1}{R} \sum_{r'} X_{i,r'} $
        # and develop the squared term:
        # $ X_{i,r} \frac{R-1}{R} - \frac{1}{R} \sum_{r' \neq r} X_{i,r'} $
        # We can replace the mean by the expectation operator and develop the
        # squared term:
        # $ R \times \mathbb{E} [ ( X_{i,r} \frac{R-1}{R} - \frac{1}{R} \sum_{r' \neq r} X_{i,r'} )^2 ] $
        # Since all X_{i,r'} are iid, we can develop the square inside the
        # expected value; the covariance terms will be 0 and we are left with:
        # $ R \times [ \mathbb{V}(X_{i,r}) (\frac{R-1}{R})^2 + \frac{R-1}{R^2} \mathbb{V}(X_{i,r'}) ] $
        # Plugging in R = 10, we get:
        # $ 10 \times 0.9 \times \mathbb{V}(X) $
        # This term is under a square root and multiplied by the corrective
        # factor to give the spread:
        # $ spread \approx \sqrt{1.1} \times \sqrt{\mathbb{V}(X)} \times 3 $
        # So, SSR = spread / RMSE = \sqrt{1.1} \times 3 \approx 3.15
        # SSR must be finite and non-negative
        self.assertGreaterEqual(ssr, 0.0)
        self.assertAlmostEqual(ssr, 3.15, places=1)
        if self.logger:
            self.logger.info(f"SSR computed : {ssr:.2f} vs SSR expected : ~ 3.15")
            self.logger.info("✅ SSR basic test passed")

    def test_ssr_one_when_perfect_prediction(self):
        """Test SSR is supposed to be 1 when the predictions follow the same distribution as the truth."""
        if self.logger:
            self.logger.info("Testing SSR perfect prediction")
        np.random.seed(0)  # set random seed for reproducibility.
        true = np.random.normal(
            loc=0, scale=0.1, size=(100, 1, 10, 10)
        )  # shape [T,C,h,w]
        pred_ens = np.random.normal(
            loc=0, scale=0.1, size=(10, 100, 1, 10, 10)
        )  # N_ens = 10
        ssr = spread_skill_ratio(
            predictions=pred_ens, targets=true, variable_names=None, pixel_wise=False
        )[0]  # get element of single element array.
        self.assertAlmostEqual(ssr, 1.0, places=1)
        if self.logger:
            self.logger.info(f"SSR computed : {ssr:.2f} vs SSR expected : ~ 1.0")
            self.logger.info("✅ SSR perfect prediction test passed")
# ----------------------------------------------------------------------------