rtnn package

Subpackages

Submodules

rtnn.dataset module

class rtnn.dataset.DataPreprocessor(*args: Any, **kwargs: Any)[source]

Bases: Dataset

Dataset class for preprocessing LSM (Land Surface Model) data.

This class handles loading and preprocessing of NetCDF files containing climate data, with support for multiple years, spatial and temporal batching, and various normalization techniques.

Parameters:
  • logger (object) – Logger instance for logging messages.

  • dfs (List[str]) – List of file paths to NetCDF files.

  • stime (int) – Start time index.

  • tstep (int) – Number of time steps per file.

  • tbatch (int) – Temporal batch size.

  • norm_mapping (Dict, optional) – Dictionary containing normalization statistics for each variable. Default is empty dict.

  • normalization_type (Dict, optional) – Dictionary specifying normalization type for each variable. Default is empty dict.

logger

Logger instance.

Type:

object

stime

Start time index.

Type:

int

tstep

Time steps per file.

Type:

int

tbatch

Temporal batch size.

Type:

int

norm_mapping

Normalization statistics.

Type:

Dict

normalization_type

Normalization types per variable.

Type:

Dict

sbatch

Number of spatial batches.

Type:

int

years

Sorted list of years in the dataset.

Type:

List[int]

etime

End time index.

Type:

int

dfs

List of (year, spatial_index, file_path) tuples.

Type:

List[Tuple[int, int, str]]

time_blocks

Shuffled time blocks.

Type:

np.ndarray

min_dims

Minimum dimensions across files.

Type:

Dict[str, int]

cosz

Cosine of solar zenith angle variable names.

Type:

List[str]

lai

Leaf area index variable names.

Type:

List[str]

ssa

Single scattering albedo variable names.

Type:

List[str]

rs

Surface reflectance variable names.

Type:

List[str]

ov

Output variable names.

Type:

List[str]

Examples

>>> from rtnn.logger import Logger
>>> logger = Logger()
>>> files = ["data_1995.nc", "data_1996.nc"]
>>> dataset = DataPreprocessor(
...     logger=logger,
...     dfs=files,
...     stime=0,
...     tstep=100,
...     tbatch=24,
...     norm_mapping={},
...     normalization_type={}
... )
>>> len(dataset)
100
>>> features, targets = dataset[0]
>>> features.shape
torch.Size([schunk, feature_channels, seq_length])
>>> targets.shape
torch.Size([schunk, output_channels, seq_length])
__init__(logger: Any, dfs: List[str], stime: int, tbatch: int, training: bool = True, sblock_perc: float = 0.6, norm_mapping: Dict = {}, normalization_type: Dict = {}, debug: bool = False) None[source]

Initialize the DataPreprocessor.

Parameters:
  • logger (Any) – Logger instance for logging messages.

  • dfs (List[str]) – List of file paths to NetCDF files.

  • stime (int) – Start time index.

  • tbatch (int) – Temporal batch size.

  • training (bool, optional) – If True, use a random subset of spatial batches for data augmentation (fraction given by sblock_perc). If False, use 100% of spatial batches (full evaluation).

  • sblock_perc (float, optional) – Fraction of spatial batches to use when training is True. Default is 0.6.

  • norm_mapping (Dict, optional) – Dictionary containing normalization statistics for each variable.

  • normalization_type (Dict, optional) – Dictionary specifying normalization type for each variable.

  • debug (bool, optional) – If True, print debug information.

normalize(data: numpy.ndarray, var_name: str) numpy.ndarray[source]

Normalize data using the specified normalization method.

Parameters:
  • data (np.ndarray) – Input data array to normalize.

  • var_name (str) – Name of the variable for which to retrieve normalization statistics.

Returns:

Normalized data array.

Return type:

np.ndarray

Raises:

ValueError – If the normalization type is not supported.

Notes

Supported normalization types:
  • minmax: (x - min) / (max - min)

  • standard: (x - mean) / std

  • robust: (x - median) / IQR

  • log1p_minmax: log1p(x) normalized

  • log1p_standard: log1p(x) standardized

  • log1p_robust: log1p(x) robust normalized

  • sqrt_minmax: sqrt(x) normalized

  • sqrt_standard: sqrt(x) standardized

  • sqrt_robust: sqrt(x) robust normalized
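
As a rough illustration, the three base schemes can be sketched as follows (the function name is hypothetical; the stats keys follow the dictionary documented for rtnn.diagnostics.stats, and the log1p_*/sqrt_* variants apply np.log1p or np.sqrt first and use the corresponding log_*/sqrt_* statistics):

```python
import numpy as np

# Illustrative sketch only, not the actual implementation. `stats` is assumed
# to hold the keys produced by rtnn.diagnostics.stats (vmin, vmax, vmean,
# vstd, median, iqr).
def normalize_sketch(x: np.ndarray, stats: dict, norm_type: str) -> np.ndarray:
    if norm_type == "minmax":
        return (x - stats["vmin"]) / (stats["vmax"] - stats["vmin"])
    if norm_type == "standard":
        return (x - stats["vmean"]) / stats["vstd"]
    if norm_type == "robust":
        return (x - stats["median"]) / stats["iqr"]
    raise ValueError(f"Unsupported normalization type: {norm_type}")
```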

rtnn.diagnostics module

Plotting utilities for RTnn model visualization.

This module provides functions for visualizing model predictions, training metrics, and data statistics. It includes tools for creating line plots, hexbin plots, histograms, and metric history plots using matplotlib.

The module supports:
  • Visualization of radiative transfer model predictions vs targets

  • Absorption rate plotting for different channels

  • Training and validation metric histories

  • Statistical distributions of input variables

  • Various normalization scheme visualizations

Dependencies

matplotlib : For plotting
mpltex : For line styles
scikit-learn : For R² score calculation
xarray : For NetCDF data handling

rtnn.diagnostics.stats(file_list, logger, output_dir, norm_mapping=None, plots=False)[source]

Compute statistics and generate histograms for variables in NetCDF files.

Reads a collection of NetCDF files, computes descriptive statistics for each variable, and generates histogram plots saved to disk. In addition to raw statistics, transformed statistics using logarithmic (log1p) and square-root transformations are also computed.

Parameters:
  • file_list (list of str) – Paths to the NetCDF files to process.

  • logger (logging.Logger) – Logger used to report progress and informational messages.

  • output_dir (str) – Directory where histogram plots will be saved.

  • norm_mapping (dict, optional) – Dictionary to update with computed statistics. If None, a new dictionary is created. Default is None.

Returns:

Dictionary mapping variable names to their computed statistics. Each variable contains the following entries:

Raw statistics:
  • vmin : float

  • vmax : float

  • vmean : float

  • vstd : float

Robust statistics:
  • q1 : float

  • q3 : float

  • iqr : float

  • median : float

Log-transformed statistics (log1p):
  • log_min : float

  • log_max : float

  • log_mean : float

  • log_std : float

  • log_q1 : float

  • log_q3 : float

  • log_iqr : float

  • log_median : float

Square-root-transformed statistics:
  • sqrt_min : float

  • sqrt_max : float

  • sqrt_mean : float

  • sqrt_std : float

  • sqrt_q1 : float

  • sqrt_q3 : float

  • sqrt_iqr : float

  • sqrt_median : float

Return type:

dict

Examples

>>> norm_mapping = stats(
...     file_list=["data_1995.nc", "data_1996.nc"],
...     logger=logger,
...     output_dir="./stats"
... )
>>> norm_mapping["coszang"]["vmean"]
0.5
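
The per-variable statistics listed above can be sketched as follows (a simplified illustration, not the actual implementation; key names follow the documented return dictionary, and the sqrt_* block is analogous to the log_* one with np.sqrt in place of np.log1p):

```python
import numpy as np

# Illustrative sketch of the documented statistics for one variable.
def var_stats_sketch(x: np.ndarray) -> dict:
    def block(v, prefix=""):
        q1, med, q3 = np.percentile(v, [25, 50, 75])
        names = ["min", "max", "mean", "std", "q1", "q3", "iqr", "median"]
        vals = [v.min(), v.max(), v.mean(), v.std(), q1, q3, q3 - q1, med]
        return {prefix + n: val for n, val in zip(names, vals)}

    # Raw statistics use the v-prefixed names from the docs (vmin, vmax, ...).
    stats = {"v" + k if k in ("min", "max", "mean", "std") else k: v
             for k, v in block(x).items()}
    # Log-transformed statistics (log1p), keyed log_min, log_max, ...
    stats.update(block(np.log1p(x), prefix="log_"))
    return stats
```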
rtnn.diagnostics.plot_RTM(predicts, targets, filename)[source]

Plot radiative transfer model predictions against targets.

Creates a 2x2 grid of line plots showing predicted vs true values for four different flux channels. Plots random samples from the batch.

Parameters:
  • predicts (torch.Tensor) – Model predictions of shape (batch_size, 4, seq_length).

  • targets (torch.Tensor) – Ground truth targets of shape (batch_size, 4, seq_length).

  • filename (str) – Path where the plot will be saved.

Notes

  • Plots up to 7 random samples from the batch

  • Uses mpltex for line styles

  • Y-axis range: 0-1

  • X-axis: level indices

rtnn.diagnostics.plot_HeatRate(abs12_predict, abs12_target, abs34_predict, abs34_target, filename)[source]

Plot absorption rates for two channel groups.

Creates a 2-panel figure showing absorption rates for channels 1-2 and 3-4.

Parameters:
  • abs12_predict (torch.Tensor) – Predicted absorption for channels 1-2 of shape (batch_size, 1, seq_length).

  • abs12_target (torch.Tensor) – True absorption for channels 1-2.

  • abs34_predict (torch.Tensor) – Predicted absorption for channels 3-4.

  • abs34_target (torch.Tensor) – True absorption for channels 3-4.

  • filename (str) – Path where the plot will be saved.

Notes

  • Upper panel: channels 1-2

  • Lower panel: channels 3-4

  • Plots random samples from the batch

rtnn.diagnostics.plot_flux_and_abs_lines(predicts, targets, abs12_predict=None, abs12_target=None, abs34_predict=None, abs34_target=None, filename='output_lines.png')[source]

Create line plots for fluxes and absorption rates.

Generates a multi-panel figure with line plots for four flux channels and optionally two absorption panels. Each panel shows predictions vs targets.

Parameters:
  • predicts (torch.Tensor) – Model predictions for fluxes of shape (batch_size, 4, seq_length).

  • targets (torch.Tensor) – Ground truth fluxes.

  • abs12_predict (torch.Tensor, optional) – Predicted absorption for channels 1-2.

  • abs12_target (torch.Tensor, optional) – True absorption for channels 1-2.

  • abs34_predict (torch.Tensor, optional) – Predicted absorption for channels 3-4.

  • abs34_target (torch.Tensor, optional) – True absorption for channels 3-4.

  • filename (str, optional) – Output filename. Default is “output_lines.png”.

Notes

Figure layout:
  • 2x2 grid for fluxes (upwelling/downwelling for two channels)

  • Optional 1x2 grid for absorption rates (if provided)

rtnn.diagnostics.plot_flux_and_abs(predicts, targets, abs12_predict=None, abs12_target=None, abs34_predict=None, abs34_target=None, filename='output.png')[source]

Create hexbin plots for fluxes and absorption rates.

Generates a multi-panel figure with hexbin density plots showing the relationship between predicted and true values. Useful for assessing prediction accuracy across the entire dataset.

Parameters:
  • predicts (torch.Tensor) – Model predictions for fluxes of shape (batch_size, 4, seq_length).

  • targets (torch.Tensor) – Ground truth fluxes.

  • abs12_predict (torch.Tensor, optional) – Predicted absorption for channels 1-2.

  • abs12_target (torch.Tensor, optional) – True absorption for channels 1-2.

  • abs34_predict (torch.Tensor, optional) – Predicted absorption for channels 3-4.

  • abs34_target (torch.Tensor, optional) – True absorption for channels 3-4.

  • filename (str, optional) – Output filename. Default is “output.png”.

Notes

  • Hexbin plots use logarithmic color scale

  • Includes diagonal reference line (y=x)

  • Displays R² score in the top-left corner of each panel

  • Shared colorbar on the right

rtnn.diagnostics.plot_metric_histories(train_history, valid_history, filename='training_validation_metrics.png')[source]

Plot training and validation metrics over epochs.

Creates a multi-panel figure showing the evolution of various metrics (e.g., NMAE, NMSE, R2) over training epochs.

Parameters:
  • train_history (dict) – Dictionary with metric names as keys and lists of training values.

  • valid_history (dict) – Dictionary with metric names as keys and lists of validation values.

  • filename (str, optional) – Output filename. Default is “training_validation_metrics.png”.

Notes

  • Metrics are plotted on a logarithmic scale

  • Each metric gets its own panel

  • Panels are arranged in a grid (3 columns)

  • Blue lines: training, Orange lines: validation

rtnn.diagnostics.plot_loss_histories(train_loss, valid_loss, filename='training_validation_loss.png')[source]

Plot training and validation loss over epochs.

Creates a single-panel figure showing the loss evolution during training.

Parameters:
  • train_loss (list or array) – Training loss values over epochs.

  • valid_loss (list or array) – Validation loss values over epochs.

  • filename (str, optional) – Output filename. Default is “training_validation_loss.png”.

Notes

  • Uses logarithmic scale for y-axis

  • Blue line: training loss

  • Orange line: validation loss

  • Includes grid for better readability

rtnn.diagnostics.plot_spatial_temporal_density(sindex_tracker, tindex_tracker, mode='train', save_dir='./tests_plots', filename='density_scatter', figsize=(10, 10))[source]

Plot a density scatter plot of spatial index vs temporal index with marginal histograms.

This function creates a 2D density scatter plot (hexbin) showing the distribution of spatial indices (processor ranks) across temporal indices, with:
  • Right plot: Histogram of the temporal index distribution

  • Top plot: Histogram of the spatial index distribution

Parameters:
  • sindex_tracker (list or array-like) – List of spatial indices (processor ranks) for each data sample.

  • tindex_tracker (list or array-like) – List of temporal indices for each data sample.

  • mode (str, optional) – Dataset mode identifier (“train”, “validation”, “test”).

  • save_dir (str, optional) – Directory path where the plot will be saved.

  • filename (str, optional) – Base name for the output file.

  • figsize (tuple, optional) – Figure size as (width, height) in inches.

Returns:

Path to the saved plot file.

Return type:

str

rtnn.evaluater module

Evaluation utilities for RTnn model assessment.

This module provides comprehensive evaluation tools for radiative transfer neural network models, including custom loss functions, metric computation, and visualization helpers.

The module includes:
  • Custom loss functions (NMSE, NMAE, combined MSE-MAE, LogCosh, Weighted MSE)

  • Metric calculators for evaluation (MSE, MAE, MBE, R², NMSE, NMAE, MARE, GMRAE)

  • Data normalization/de-normalization utilities

  • Absorption rate calculations

  • Main evaluation loop for LSM models

Dependencies

torch : For tensor operations and loss functions
numpy : For numerical operations
plot_helper : For visualization utilities

class rtnn.evaluater.NMSELoss(*args: Any, **kwargs: Any)[source]

Bases: Module

Normalized Mean Squared Error Loss.

Computes MSE normalized by the mean square of the target values. Useful when the scale of the target variable varies.

Parameters:

eps (float, optional) – Small constant for numerical stability. Default is 1e-8.

Examples

>>> criterion = NMSELoss()
>>> loss = criterion(predictions, targets)
__init__(eps=1e-08)[source]
forward(pred, target)[source]
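
A minimal numpy sketch of the normalized-MSE idea (the library class is a torch Module; the exact reduction and normalizer it uses are assumptions here):

```python
import numpy as np

# Sketch only: MSE divided by the mean square of the target, with an epsilon
# guard for numerical stability, as the class description suggests.
def nmse_sketch(pred: np.ndarray, target: np.ndarray, eps: float = 1e-8) -> float:
    return float(np.mean((pred - target) ** 2) / (np.mean(target ** 2) + eps))
```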
class rtnn.evaluater.NMAELoss(*args: Any, **kwargs: Any)[source]

Bases: Module

Normalized Mean Absolute Error Loss.

Computes MAE normalized by the mean absolute value of the target. Provides a scale-invariant error metric.

Parameters:

eps (float, optional) – Small constant for numerical stability. Default is 1e-8.

__init__(eps=1e-08)[source]
forward(pred, target)[source]
class rtnn.evaluater.MetricTracker[source]

Bases: object

A utility class for tracking and computing statistics of metric values.

This class maintains a running average of metric values and provides methods to compute mean and root mean squared values.

value

Cumulative weighted sum of metric values

Type:

float

count

Total number of samples processed

Type:

int

Examples

>>> tracker = MetricTracker()
>>> tracker.update(10.0, 5)  # value=10.0, count=5 samples
>>> tracker.update(20.0, 3)  # value=20.0, count=3 samples
>>> print(tracker.getmean())  # (10*5 + 20*3) / (5+3) = 110/8 = 13.75
13.75
>>> print(tracker.getsqrtmean())  # sqrt(13.75)
3.7080992435478315
__init__()[source]

Initialize MetricTracker with zero values.

reset()[source]

Reset all tracked values to zero.

Return type:

None

update(value, count)[source]

Update the tracker with new metric values.

Parameters:
  • value (float) – The metric value to add

  • count (int) – Number of samples this value represents (weight)

Return type:

None

getmean()[source]

Calculate the mean of all tracked values.

Returns:

Weighted mean of all values: total_value / total_count

Return type:

float

Raises:

ZeroDivisionError – If no values have been added (count == 0)

getstd()[source]

Calculate the standard deviation of all tracked values.

Returns:

Weighted standard deviation of all values: sqrt(E(x^2) - (E(x))^2)

Return type:

float

Raises:

ZeroDivisionError – If no values have been added (count == 0)

getsqrtmean()[source]

Calculate the square root of the mean of all tracked values.

Returns:

Square root of the weighted mean: sqrt(total_value / total_count)

Return type:

float

Raises:

ZeroDivisionError – If no values have been added (count == 0)
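
The documented behavior of update(), getmean(), and getsqrtmean() can be sketched as follows (an illustrative re-implementation, not the library code; getstd() is omitted because it requires an extra accumulator beyond the attributes listed above):

```python
import math

# Illustrative running weighted average; attribute names `value` and `count`
# follow the class documentation.
class MetricTrackerSketch:
    def __init__(self):
        self.reset()

    def reset(self):
        self.value = 0.0   # cumulative weighted sum of metric values
        self.count = 0     # total number of samples processed

    def update(self, value, count):
        self.value += value * count
        self.count += count

    def getmean(self):
        return self.value / self.count   # ZeroDivisionError if count == 0

    def getsqrtmean(self):
        return math.sqrt(self.getmean())
```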

rtnn.evaluater.get_loss_function(loss_type, args, logger=None)[source]

Factory function to instantiate the requested loss function.

Parameters:
  • loss_type (str) – Type of loss function. Options:

    • ‘mse’: Mean Squared Error

    • ‘mae’: Mean Absolute Error

    • ‘nmae’: Normalized Mean Absolute Error

    • ‘nmse’: Normalized Mean Squared Error

    • ‘wmse’: Weighted Mean Squared Error

    • ‘logcosh’: Log-Cosh loss

    • ‘smoothl1’: Smooth L1 Loss (Huber-like)

    • ‘huber’: Huber Loss

  • args (argparse.Namespace) – Arguments containing loss-specific parameters (e.g., beta_delta for Huber).

Returns:

Initialized loss function.

Return type:

torch.nn.Module

Raises:

ValueError – If loss_type is not supported or required parameters are missing.

Examples

>>> args = argparse.Namespace(beta_delta=1.0)
>>> criterion = get_loss_function('huber', args)
rtnn.evaluater.mse_all(pred, true)[source]

Compute Mean Squared Error.

Parameters:
  • pred (torch.Tensor) – Predicted values.

  • true (torch.Tensor) – Ground-truth values.

Returns:

(num_elements, mse_value)

Return type:

tuple

rtnn.evaluater.mbe_all(pred, true)[source]

Compute Mean Bias Error.

Parameters:
  • pred (torch.Tensor) – Predicted values.

  • true (torch.Tensor) – Ground-truth values.

Returns:

(num_elements, mbe_value)

Return type:

tuple

rtnn.evaluater.mae_all(pred, true)[source]

Compute Mean Absolute Error.

Parameters:
  • pred (torch.Tensor) – Predicted values.

  • true (torch.Tensor) – Ground-truth values.

Returns:

(num_elements, mae_value)

Return type:

tuple

rtnn.evaluater.r2_all(pred, true)[source]

Calculate R2 (coefficient of determination) between predicted and true values.

Computes the R2 metric and returns both the number of elements and the R2 value.

Parameters:
  • pred (torch.Tensor) – Predicted values.

  • true (torch.Tensor) – Ground-truth values.

Returns:

(num_elements, r2_value) where: - num_elements (int): Total number of elements in the tensors - r2_value (torch.Tensor): R2 score

Return type:

tuple

Notes

R2 is calculated as:

R2 = 1 - sum((true - pred)^2) / sum((true - mean(true))^2)

This implementation is fully torch-based and works on CPU and GPU.
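
The formula above, shown with numpy for brevity (the library implementation is torch-based and returns the same (num_elements, r2) tuple):

```python
import numpy as np

# R2 = 1 - SS_res / SS_tot, matching the documented formula.
def r2_sketch(pred: np.ndarray, true: np.ndarray):
    ss_res = np.sum((true - pred) ** 2)
    ss_tot = np.sum((true - true.mean()) ** 2)
    return true.size, 1.0 - ss_res / ss_tot
```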

rtnn.evaluater.nmae_all(pred, true)[source]

Compute Normalized Mean Absolute Error.

Parameters:
  • pred (torch.Tensor) – Predicted values.

  • true (torch.Tensor) – Ground-truth values.

Returns:

(num_elements, nmae_value)

Return type:

tuple

rtnn.evaluater.nmse_all(pred, true)[source]

Compute Normalized Mean Squared Error.

Parameters:
  • pred (torch.Tensor) – Predicted values.

  • true (torch.Tensor) – Ground-truth values.

Returns:

(num_elements, nmse_value)

Return type:

tuple

rtnn.evaluater.mare_all(pred, true)[source]

Compute Mean Absolute Relative Error.

Parameters:
  • pred (torch.Tensor) – Predicted values.

  • true (torch.Tensor) – Ground-truth values.

Returns:

(num_elements, mare_value)

Return type:

tuple

rtnn.evaluater.gmrae_all(pred, true)[source]

Compute Geometric Mean Relative Absolute Error.

Parameters:
  • pred (torch.Tensor) – Predicted values.

  • true (torch.Tensor) – Ground-truth values.

Returns:

(num_elements, gmrae_value)

Return type:

tuple
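
One plausible GMRAE computation can be sketched as follows (both using |true| as the reference values and the epsilon guard are assumptions about this implementation, which is not shown in the docs):

```python
import numpy as np

# Hedged sketch: geometric mean of relative absolute errors.
def gmrae_sketch(pred: np.ndarray, true: np.ndarray, eps: float = 1e-8):
    rae = np.abs(pred - true) / (np.abs(true) + eps)
    return pred.size, float(np.exp(np.mean(np.log(rae + eps))))
```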

rtnn.evaluater.unnorm_mpas(pred, targ, norm_mapping, normalization_type, idxmap)[source]

Reverse normalization to recover original physical values.

Applies inverse transformation based on the normalization method used during preprocessing.

Parameters:
  • pred (torch.Tensor) – Normalized predictions of shape (batch, channels, seq_length).

  • targ (torch.Tensor) – Normalized targets.

  • norm_mapping (dict) – Dictionary containing normalization statistics for each variable.

  • normalization_type (dict) – Dictionary mapping variable names to normalization types.

  • idxmap (dict) – Dictionary mapping channel indices to variable names.

Returns:

(unnormalized_predictions, unnormalized_targets)

Return type:

tuple

Notes

Supported normalization types:
  • minmax: x * (max - min) + min

  • standard: x * std + mean

  • robust: x * iqr + median

  • log1p_*: expm1(x * scale + offset)

  • sqrt_*: (x * scale + offset) ** 2

Examples

>>> idxmap = {0: 'collim_alb', 1: 'collim_tran'}
>>> upred, utarg = unnorm_mpas(pred, targ, norm_mapping, norm_type, idxmap)
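
The inverse transforms for the three base schemes can be sketched as follows (illustrative only; the function name is hypothetical and the stats keys follow the dictionary documented for rtnn.diagnostics.stats):

```python
# Sketch of the documented inverse transforms; works elementwise on floats
# or numpy arrays.
def unnormalize_sketch(x, stats, norm_type):
    if norm_type == "minmax":
        return x * (stats["vmax"] - stats["vmin"]) + stats["vmin"]
    if norm_type == "standard":
        return x * stats["vstd"] + stats["vmean"]
    if norm_type == "robust":
        return x * stats["iqr"] + stats["median"]
    raise ValueError(f"Unsupported normalization type: {norm_type}")
```
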
rtnn.evaluater.calc_abs(pred, targ, p=None)[source]

Calculate absorption rates from flux predictions.

Computes net absorption rates for two channel groups (1-2 and 3-4) using the difference between upwelling and downwelling fluxes.

Parameters:
  • pred (torch.Tensor) – Predictions of shape (batch, 4, seq_length) where channels 0-1 are for first group and 2-3 for second group.

  • targ (torch.Tensor) – Targets of same shape as pred.

  • p (torch.Tensor, optional) – Pressure levels for atmospheric heating rate calculation. If provided, computes heating rate using pressure gradients.

Returns:

(abs12_pred, abs12_targ, abs34_pred, abs34_targ) where each is a tensor of shape (batch, 1, seq_length-1).

Return type:

tuple

Notes

  • If p is None: returns d(net) where net = up - down

  • If p is provided: returns heating rate using d(net)/dp

rtnn.evaluater.calc_hr(up, down, p=None)[source]

Calculate heating rate from upwelling and downwelling fluxes.

Computes the net radiative heating rate by taking the vertical derivative of net flux (upwelling - downwelling). If pressure levels are provided, calculates the actual atmospheric heating rate using pressure gradients.

Parameters:
  • up (torch.Tensor) – Upwelling flux tensor of shape (batch, channels, seq_length).

  • down (torch.Tensor) – Downwelling flux tensor of shape (batch, channels, seq_length).

  • p (torch.Tensor, optional) – Pressure levels of shape (seq_length,) or (batch, seq_length). If provided, computes physical heating rate. If None, returns the derivative of net flux.

Returns:

If p is None:

Returns the negative derivative of net flux with respect to level index: -d(net)/dz (or d(net)/d(level)) of shape (batch, channels, seq_length - 1)

If p is provided:

Returns the atmospheric heating rate in K/day using the formula: heating_rate = (g * 8.64e4) / (cp * 100) * d(net)/dp where g = 9.8066 m/s², cp = 1004 J/(kg·K) (calculated as 7*R/2 with R=287)

Return type:

torch.Tensor

Notes

  • The derivative is computed using finite differences: net[i+1] - net[i]

  • For pressure-based calculation, uses dp = p[i+1] - p[i]

  • The factor 8.64e4 (seconds per day) converts the heating rate from K/s to K/day

  • The factor 100 converts pressure from hPa to Pa

Examples

>>> # Calculate net flux derivative
>>> hr = calc_hr(up, down)
>>> hr.shape
torch.Size([32, 4, 9])  # for seq_length=10
>>> # Calculate actual heating rate with pressure levels
>>> pressure = torch.linspace(1000, 100, 10)  # hPa
>>> heating_rate = calc_hr(up, down, p=pressure)
>>> heating_rate.shape
torch.Size([32, 4, 9])
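
The pressure-based formula from the Notes can be sketched in numpy (constants as documented above; the function name is hypothetical):

```python
import numpy as np

G = 9.8066            # gravitational acceleration, m/s^2
R = 287.0             # gas constant for dry air, J/(kg K)
CP = 7.0 * R / 2.0    # specific heat, approx. 1004.5 J/(kg K)

# heating_rate = (g * 8.64e4) / (cp * 100) * d(net)/dp, with p given in hPa
# (the /100 converts hPa to Pa, 8.64e4 s/day converts K/s to K/day).
def heating_rate_sketch(up: np.ndarray, down: np.ndarray, p_hpa: np.ndarray) -> np.ndarray:
    net = up - down                 # net flux, W/m^2
    dnet = np.diff(net)             # finite difference: net[i+1] - net[i]
    dp = np.diff(p_hpa)             # dp = p[i+1] - p[i], in hPa
    return (G * 8.64e4) / (CP * 100.0) * dnet / dp   # K/day
```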
rtnn.evaluater.run_validation(loader, model, norm_mapping, normalization_type, index_mapping, device, args, epoch)[source]

Evaluate model accuracy on LSM dataset.

Performs comprehensive evaluation including:
  • Loss computation for main fluxes and absorption rates

  • Metric calculation (NMAE, NMSE, R²)

  • Optional plotting of predictions vs targets

Parameters:
  • loader (torch.utils.data.DataLoader) – Data loader for evaluation dataset.

  • model (torch.nn.Module) – Trained model to evaluate.

  • norm_mapping (dict) – Normalization statistics for variables.

  • normalization_type (dict) – Normalization types per variable.

  • index_mapping (dict) – Mapping from channel indices to variable names.

  • device (torch.device) – Device to run evaluation on.

  • args (argparse.Namespace) – Arguments containing loss type, beta, etc.

  • epoch (int) – Current epoch number (for plotting).

Returns:

(valid_loss, valid_metrics) where valid_metrics is a dictionary containing computed metrics for fluxes, abs12, and abs34.

Return type:

tuple

Examples

>>> valid_loss, metrics = run_validation(
...     test_loader, model, norm_mapping, norm_type, idxmap, device, args, epoch
... )
>>> print(f"Validation loss: {valid_loss:.4f}")
>>> print(f"NMAE: {metrics['fluxes_NMAE']:.4f}")

rtnn.logger module

class rtnn.logger.Logger(console_output=True, file_output=False, log_file='module_log_file.log', pretty_print=True, record=False)[source]

Bases: object

__init__(console_output=True, file_output=False, log_file='module_log_file.log', pretty_print=True, record=False)[source]
clear_logs()[source]

Clear the stored Rich logs if record=True.

show_header(module_name)[source]

Display startup banner.

start_task(task_name: str, description: str = '', **meta)[source]

Display a clearly formatted ‘task start’ message with good spacing.

log_metrics()[source]

Log pipeline metrics

info(message)[source]

Formatted info message

warning(message)[source]

Formatted warning message

success(message)[source]

Custom success level (not default logging level)

step(step_name, message)[source]

Highlight pipeline step events

exception(message, exception=None)[source]

Display a formatted exception message with visual stack trace.

error(message, exception=None)[source]

Display a formatted error log, optionally including exception trace.

rtnn.main module

RTnn (Radiative Transfer Neural Network) Training Pipeline

This module provides the main entry point for training neural network models for radiative transfer calculations in climate modeling. It supports various model architectures including LSTM, GRU, Transformer, and FCN.

The training pipeline includes:
  • Data loading and preprocessing from NetCDF files

  • Model initialization and configuration

  • Training loop with progress tracking

  • Validation and metric computation

  • Checkpoint saving and model persistence

  • Visualization and logging

rtnn.main.print_version()[source]

Print detailed version information.

rtnn.main.parse_years(year_str)[source]

Parse a year string into a list of integers.

Supports hyphen-separated ranges (e.g., “1995-1999”) and comma-separated lists (e.g., “1995,1997,1999”). Returns a list of integers.

Parameters:

year_str (str) – String containing years in range or comma-separated format.

Returns:

List of parsed years.

Return type:

list of int

Examples

>>> parse_years("1995-1999")
[1995, 1996, 1997, 1998, 1999]
>>> parse_years("1995,1997,1999")
[1995, 1997, 1999]
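
A sketch consistent with the documented behavior (an illustrative re-implementation, not the library code):

```python
# Parse a year string: hyphen-separated inclusive range or comma-separated list.
def parse_years_sketch(year_str: str) -> list:
    if "-" in year_str:
        start, end = map(int, year_str.split("-"))
        return list(range(start, end + 1))
    return [int(y) for y in year_str.split(",")]
```
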
rtnn.main.parse_args()[source]

Parse command-line arguments for RTnn model training.

Defines and parses command-line arguments required to configure and run the Radiative Transfer Neural Network (RTnn) training pipeline. This includes model architecture parameters, training hyperparameters, data configuration, and output settings.

Returns:

Object containing parsed command-line arguments, grouped as follows:

Model architecture
type : str

Model type (e.g., “lstm”, “gru”, “fcn”, “fullyconnected”, “transformer”, “cnn”, “mlp”).

hidden_size : int

Size of hidden layers.

num_layers : int

Number of model layers.

seq_length : int

Length of input sequence.

feature_channel : int

Number of input feature channels.

output_channel : int

Number of output channels.

embed_size : int

Embedding dimension for transformer models.

nhead : int

Number of attention heads (transformer).

forward_expansion : int

Expansion factor for feed-forward layers.

dropout : float

Dropout rate.

Training hyperparameters
batch_size : int

Number of samples per batch.

tbatch : int

Temporal batch length.

num_epochs : int

Number of training epochs.

learning_rate : float

Initial learning rate.

loss_type : str

Loss function (e.g., “mse”, “mae”, “nmae”, “nmse”, “wmse”, “logcosh”, “smoothl1”, “huber”).

beta : float

Weighting factor for loss components.

beta_delta : float

Delta parameter for Huber or SmoothL1 loss.

num_workers : int

Number of data loader workers.

Data configuration
train_data_files : str

Path or pattern for training data files.

test_data_files : str

Path or pattern for testing data files.

train_years : str

Training years (comma-separated or range, e.g., “1995-1999”).

test_year : str

Test year or range.

norm : str

Normalization scheme (e.g., “log1p_standard”, “standard”, “minmax”, “none”).

dataset_type : str

Dataset type (e.g., “LSM”, “RTM”).

Output configuration
root_dir : str

Root directory for all operations.

main_folder : str

Main output folder name.

sub_folder : str

Sub-folder name for the current run.

prefix : str

Prefix for saved files.

model_name : str

Custom model name (auto-generated if empty).

save_model : bool

Whether to save model checkpoints.

save_checkpoint_name : str

Base name for saved checkpoints.

save_per_samples : int

Save checkpoint every N samples.

load_model : bool

Whether to load an existing model.

load_checkpoint_name : str

Checkpoint file to load.

inference : bool

Run in inference-only mode.

Return type:

argparse.Namespace

Examples

>>> args = parse_args()
>>> args.type
'lstm'
>>> args.batch_size
16

Command line usage

$ rtnn --type lstm --hidden_size 128 --num_layers 3 --batch_size 32

rtnn.main.setup_directories_and_logging(args)[source]

Set up directory structure and logging infrastructure for experiments.

Parameters:

args (argparse.Namespace) – Parsed command-line arguments.

Returns:

  • paths (EasyDict) – Dictionary containing paths to created directories.

  • logger (Logger) – Configured logger instance.

rtnn.main.log_configuration(args, paths, logger)[source]

Log all configuration parameters to the provided logger.

Parameters:
  • args (argparse.Namespace) – Configuration object containing all experiment parameters.

  • paths (EasyDict) – Dictionary containing paths to various experiment directories.

  • logger (Logger) – Logger instance for outputting configuration information.

rtnn.main.setup_device_and_seed(args, logger)[source]

Set up device (GPU/CPU) and random seeds for reproducibility.

Parameters:
  • args (argparse.Namespace) – Parsed command-line arguments.

  • logger (Logger) – Logger instance.

Returns:

Device to use for computations.

Return type:

torch.device

rtnn.main.get_data_files(args, logger)[source]

Get training and testing data files based on year specifications.

Parameters:
  • args (argparse.Namespace) – Parsed command-line arguments.

  • logger (Logger) – Logger instance.

Returns:

(train_files, test_files) lists of file paths.

Return type:

tuple

rtnn.main.create_normalization_mapping(train_files, paths, logger)[source]

Create normalization mapping from training data.

Parameters:
  • train_files (list) – List of training file paths.

  • paths (EasyDict) – Directory paths.

  • logger (Logger) – Logger instance.

Returns:

Normalization statistics for each variable.

Return type:

dict

rtnn.main.create_datasets_and_loaders(args, train_files, test_files, norm_mapping, logger)[source]

Create datasets and data loaders for training and validation.

Parameters:
  • args (argparse.Namespace) – Parsed command-line arguments.

  • train_files (list) – Training file paths.

  • test_files (list) – Test file paths.

  • norm_mapping (dict) – Normalization statistics.

  • logger (Logger) – Logger instance.

Returns:

(train_loader, test_loader, train_dataset, test_dataset)

Return type:

tuple

rtnn.main.initialize_model(args, device, logger)[source]

Initialize the model architecture.

Parameters:
  • args (argparse.Namespace) – Parsed command-line arguments.

  • device (torch.device) – Device to place the model on.

  • logger (Logger) – Logger instance.

Returns:

Initialized model.

Return type:

torch.nn.Module

rtnn.main.load_checkpoint_if_requested(args, model, optimizer, paths, device, logger)[source]

Load model checkpoint if requested using ModelUtils.load_training_checkpoint().

This function leverages ModelUtils.load_training_checkpoint(), which handles:
  • DataParallel compatibility

  • Loading model and optimizer states

  • Extracting training state (epoch, samples, metrics, etc.)

Parameters:
  • args (argparse.Namespace) – Parsed command-line arguments.

  • model (torch.nn.Module) – Model to load weights into.

  • optimizer (torch.optim.Optimizer) – Optimizer to restore state into.

  • paths (EasyDict) – Directory paths.

  • device (torch.device) – Device to map the checkpoint onto.

  • logger (Logger) – Logger instance.

Returns:

(start_epoch, samples_processed, batches_processed, best_val_loss, best_epoch, checkpoint, train_loss_history, valid_loss_history, valid_metrics_history)

Return type:

tuple

rtnn.main.save_checkpoint(model, optimizer, epoch, samples_processed, batches_processed, train_loss, valid_loss, args, paths, logger, checkpoint_type='epoch')[source]

Save model checkpoint using ModelUtils.

Parameters:
  • model (torch.nn.Module) – Model to save.

  • optimizer (torch.optim.Optimizer) – Optimizer state to save.

  • epoch (int) – Current epoch.

  • samples_processed (int) – Total samples processed.

  • batches_processed (int) – Total batches processed.

  • train_loss (float) – Current training loss.

  • valid_loss (float) – Current validation loss.

  • args (argparse.Namespace) – Command-line arguments.

  • paths (EasyDict) – Directory paths.

  • logger (Logger) – Logger instance.

  • checkpoint_type (str) – Type of checkpoint (“epoch”, “best”, “final”, “samples”).

rtnn.main.train_epoch(model, train_loader, optimizer, loss_func, metric_funcs, metric_names, output_keys, train_metrics, train_loss_tracker, norm_mapping, normalization_type, index_mapping, device, args, epoch, writer, global_step, logger)[source]

Train for one epoch.

Returns:

(average_train_loss, updated_global_step)

Return type:

tuple

rtnn.main.main()[source]

Main entry point for training the RTnn model.

rtnn.model_loader module

rtnn.model_loader.load_model(args)[source]

Load and initialize a model based on the provided configuration.

This function acts as a factory that instantiates the appropriate model architecture based on the type argument. Supported models include:
  • LSTM: Bidirectional LSTM with Conv1d output projection

  • GRU: Bidirectional GRU with Conv1d output projection

  • Transformer: Transformer encoder with positional embeddings

  • FCN/fullyconnected: Fully connected network with configurable depth

Parameters:

args (argparse.Namespace) –

Namespace containing model configuration parameters. Required attributes depend on the model type:

For LSTM/GRU:
  • type : str (‘lstm’ or ‘gru’)

  • feature_channel : int

  • output_channel : int

  • hidden_size : int

  • num_layers : int

For Transformer:
  • type : str (‘transformer’)

  • feature_channel : int

  • output_channel : int

  • embed_size : int

  • num_layers : int

  • nhead : int

  • forward_expansion : int

  • seq_length : int

  • dropout : float

For FCN/fullyconnected:
  • type : str (‘fcn’ or ‘fullyconnected’)

  • feature_channel : int

  • output_channel : int

  • num_layers : int

  • hidden_size : int

  • seq_length : int

  • dim_expand : int (optional, default 0)

Returns:

Initialized PyTorch model of the specified architecture.

Return type:

torch.nn.Module

Raises:

ValueError – If the specified model type is not implemented.

Examples

>>> args = argparse.Namespace(
...     type='lstm',
...     feature_channel=6,
...     output_channel=4,
...     hidden_size=128,
...     num_layers=3
... )
>>> model = load_model(args)
>>> print(type(model))
<class 'rtnn.models.rnn.RNN_LSTM'>
>>> args = argparse.Namespace(
...     type='transformer',
...     feature_channel=6,
...     output_channel=4,
...     embed_size=64,
...     num_layers=2,
...     nhead=4,
...     forward_expansion=4,
...     seq_length=10,
...     dropout=0.1
... )
>>> model = load_model(args)
>>> print(type(model))
<class 'rtnn.models.Transformer.Encoder'>
>>> args = argparse.Namespace(
...     type='fcn',
...     feature_channel=6,
...     output_channel=4,
...     num_layers=3,
...     hidden_size=196,
...     seq_length=10
... )
>>> model = load_model(args)
>>> print(type(model))
<class 'rtnn.models.fcn.FCN'>

rtnn.model_utils module

class rtnn.model_utils.ModelUtils[source]

Bases: object

Utility class for model inspection, checkpointing, and memory profiling.

This class provides static methods for common model operations including parameter counting, memory usage analysis, checkpoint management, and model inspection.

Examples

>>> utils = ModelUtils()
>>> param_counts = ModelUtils.get_parameter_number(model)
>>> ModelUtils.save_checkpoint(state, "checkpoint.pth.tar", logger)
__init__()[source]

Initialize ModelUtils instance.

static get_parameter_number(model, logger=None)[source]

Calculate the total and trainable number of parameters in a model.

Parameters:
  • model (torch.nn.Module) – PyTorch model to inspect

  • logger (Logger, optional) – Logger instance for output, by default None

Returns:

Dictionary containing:
  • ‘Total’: Total number of parameters

  • ‘Trainable’: Number of trainable parameters

Return type:

dict

Examples

>>> model = torch.nn.Linear(10, 5)
>>> counts = ModelUtils.get_parameter_number(model, logger)
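The counting logic typically reduces to summing numel() over model.parameters(), filtering on requires_grad for the trainable count. A runnable sketch of that logic, using stand-in parameter objects so it does not require torch:

```python
from dataclasses import dataclass


@dataclass
class FakeParam:
    """Stand-in for torch.nn.Parameter (illustration only)."""
    n: int
    requires_grad: bool = True

    def numel(self) -> int:
        return self.n


def get_parameter_number(params) -> dict:
    # Same idiom as the torch version:
    #   total     = sum(p.numel() for p in model.parameters())
    #   trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in params)
    trainable = sum(p.numel() for p in params if p.requires_grad)
    return {"Total": total, "Trainable": trainable}
```

With torch available, the same two generator expressions over model.parameters() give the dictionary documented above.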
static print_model_layers(model, logger=None)[source]

Print model parameter names along with their gradient requirements.

Parameters:
  • model (torch.nn.Module) – PyTorch model to inspect

  • logger (Logger, optional) – Logger instance for output, by default None

Examples

>>> model = torch.nn.Sequential(
...     torch.nn.Linear(10, 5),
...     torch.nn.ReLU(),
...     torch.nn.Linear(5, 1)
... )
>>> ModelUtils.print_model_layers(model, logger)
static save_checkpoint(state, filename='checkpoint.pth.tar', logger=None)[source]

Save model and optimizer state to a file.

Parameters:
  • state (dict) – Dictionary containing the model state_dict and other training information. Typically includes:

      – ‘state_dict’: Model parameters

      – ‘optimizer’: Optimizer state

      – ‘epoch’: Current epoch

      – ‘loss’: Current loss value

  • filename (str, optional) – File path to save the checkpoint, by default “checkpoint.pth.tar”

  • logger (Logger, optional) – Logger instance for output, by default None

Examples

>>> state = {
...     'state_dict': model.state_dict(),
...     'optimizer': optimizer.state_dict(),
...     'epoch': epoch,
...     'loss': loss
... }
>>> ModelUtils.save_checkpoint(state, 'model_checkpoint.pth.tar', logger)
static load_checkpoint(checkpoint, model, optimizer=None, logger=None)[source]

Load model and optimizer state from a checkpoint file.

Parameters:
  • checkpoint (dict) – Loaded checkpoint dictionary

  • model (torch.nn.Module) – Model to load weights into

  • optimizer (torch.optim.Optimizer, optional) – Optimizer to restore state, by default None

  • logger (Logger, optional) – Logger instance for output, by default None

Examples

>>> checkpoint = torch.load('model_checkpoint.pth.tar')
>>> ModelUtils.load_checkpoint(checkpoint, model, optimizer, logger)
static load_training_checkpoint(checkpoint_path, model, optimizer, device, logger=None)[source]

Load comprehensive training checkpoint.

Parameters:
  • checkpoint_path (str) – Path to the checkpoint file.

  • model (torch.nn.Module) – Model to load weights into.

  • optimizer (torch.optim.Optimizer) – Optimizer to restore state.

  • device (torch.device) – Device to map the checkpoint tensors onto.

  • logger (Logger, optional) – Logger instance for output, by default None.

Returns:

(epoch, samples_processed, batches_processed, best_val_loss, best_epoch, checkpoint)

Return type:

tuple

static count_parameters_by_layer(model, logger=None)[source]

Count parameters for each layer in the model.

Parameters:
  • model (torch.nn.Module) – PyTorch model to analyze

  • logger (Logger, optional) – Logger instance for output, by default None

Returns:

Dictionary with layer names as keys and parameter counts as values

Return type:

dict

Examples

>>> layer_params = ModelUtils.count_parameters_by_layer(model, logger)
static log_model_summary(model, input_shape=None, logger=None)[source]

Log comprehensive model summary including parameters and architecture.

Parameters:
  • model (torch.nn.Module) – PyTorch model to summarize

  • input_shape (tuple, optional) – Input shape for memory analysis, by default None

  • logger (Logger, optional) – Logger instance for output, by default None

static save_training_checkpoint(model, optimizer, epoch, samples_processed, batches_processed, train_loss_history, valid_loss_history, valid_metrics_history, best_val_loss, best_epoch, avg_val_loss, avg_epoch_loss, args, paths, logger, checkpoint_type='epoch', save_full_model=True)[source]

Save comprehensive training checkpoint with consistent formatting.

Parameters:
  • model (torch.nn.Module) – Model to save

  • optimizer (torch.optim.Optimizer) – Optimizer to save

  • epoch (int) – Current epoch

  • samples_processed (int) – Number of samples processed so far

  • batches_processed (int) – Number of batches processed so far

  • train_loss_history (list) – History of training losses

  • valid_loss_history (list) – History of validation losses

  • valid_metrics_history (dict) – History of validation metrics

  • best_val_loss (float) – Best validation loss so far

  • best_epoch (int) – Epoch with best validation loss

  • avg_val_loss (float) – Current epoch validation loss

  • avg_epoch_loss (float) – Current epoch training loss

  • args (argparse.Namespace) – Command line arguments

  • paths (EasyDict) – Directory paths

  • logger (Logger) – Logger instance

  • checkpoint_type (str) – Type of checkpoint: “samples”, “epoch”, “best”, “final”

  • save_full_model (bool) – Whether to also save the full model separately

Returns:

(checkpoint_filename, full_model_filename)

Return type:

tuple

Examples

>>> checkpoint_file, full_model_file = ModelUtils.save_training_checkpoint(
...     model, optimizer, epoch, samples_processed, batches_processed,
...     train_loss_history, valid_loss_history, valid_metrics_history,
...     best_val_loss, best_epoch, avg_val_loss, avg_epoch_loss,
...     args, paths, logger, checkpoint_type="best"
... )
static save_emergency_checkpoint(model, optimizer, epoch, samples_processed, batches_processed, train_loss_history, valid_loss_history, valid_metrics_history, args, paths, logger, reason='emergency')[source]

Save emergency checkpoint for recovery.

Parameters:

reason (str) – Reason for emergency save (e.g., “crash”, “interrupt”, “error”). The remaining parameters match those documented for save_training_checkpoint().

Returns:

(checkpoint_filename, full_model_filename)

Return type:

tuple

rtnn.stats module

rtnn.utils module

class rtnn.utils.EasyDict[source]

Bases: dict

A dictionary subclass that allows attribute-style access to its items. This class extends the built-in dict and overrides the __getattr__, __setattr__, and __delattr__ methods to enable accessing dictionary keys as attributes.

Original work: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. Original source: https://github.com/NVlabs/edm
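A minimal re-implementation sketch of this pattern, for illustration (the real class lives in rtnn.utils and originates from NVlabs/edm):

```python
class EasyDict(dict):
    """dict with attribute-style access: d.key <-> d['key']."""

    def __getattr__(self, name):
        # __getattr__ is only called when normal attribute lookup fails,
        # so real attributes and dict methods still work as usual.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]
```

Usage: after `cfg = EasyDict(lr=1e-3)`, both `cfg.lr` and `cfg["lr"]` refer to the same entry, which is convenient for configuration objects like the `paths` argument used elsewhere in this package.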

class rtnn.utils.FileUtils[source]

Bases: object

Utility class for file and directory operations.

__init__()[source]

Initialize the FileUtils class. This class does not maintain any state, so the constructor is empty.

static makedir(dirs)[source]

Create a directory if it does not exist.

Parameters:

dirs (str) – The path of the directory to be created.

static makefile(dirs, filename)[source]

Create an empty file in the specified directory.

Parameters:
  • dirs (str) – The path of the directory where the file will be created.

  • filename (str) – The name of the file to be created.
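A plausible sketch of these two helpers, assuming they wrap os.makedirs and open(); the actual FileUtils implementations may differ in detail:

```python
import os


def makedir(dirs: str) -> None:
    # exist_ok=True makes this a no-op if the directory already exists.
    os.makedirs(dirs, exist_ok=True)


def makefile(dirs: str, filename: str) -> None:
    makedir(dirs)
    path = os.path.join(dirs, filename)
    if not os.path.exists(path):
        # Append mode creates the file if missing without truncating
        # an existing one.
        open(path, "a").close()
```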

rtnn.version module

Version information for rtnn.

rtnn.version.get_version()[source]

Return the version string.