rtnn package
Subpackages
Submodules
rtnn.dataset module
- class rtnn.dataset.DataPreprocessor(*args: Any, **kwargs: Any)[source]
Bases: Dataset
Dataset class for preprocessing LSM (Land Surface Model) data.
This class handles loading and preprocessing of NetCDF files containing climate data, with support for multiple years, spatial and temporal batching, and various normalization techniques.
- Parameters:
logger (object) – Logger instance for logging messages.
dfs (List[str]) – List of file paths to NetCDF files.
stime (int) – Start time index.
tstep (int) – Number of time steps per file.
tbatch (int) – Temporal batch size.
norm_mapping (Dict, optional) – Dictionary containing normalization statistics for each variable. Default is empty dict.
normalization_type (Dict, optional) – Dictionary specifying normalization type for each variable. Default is empty dict.
- norm_mapping
Normalization statistics.
- Type:
Dict
- normalization_type
Normalization types per variable.
- Type:
Dict
- time_blocks
Shuffled time blocks.
- Type:
np.ndarray
Examples
>>> from rtnn.logger import Logger
>>> logger = Logger()
>>> files = ["data_1995.nc", "data_1996.nc"]
>>> dataset = DataPreprocessor(
...     logger=logger,
...     dfs=files,
...     stime=0,
...     tstep=100,
...     tbatch=24,
...     norm_mapping={},
...     normalization_type={}
... )
>>> len(dataset)
100
>>> features, targets = dataset[0]
>>> features.shape
torch.Size([schunk, feature_channels, seq_length])
>>> targets.shape
torch.Size([schunk, output_channels, seq_length])
- __init__(logger: Any, dfs: List[str], stime: int, tbatch: int, training: bool = True, sblock_perc: float = 0.6, norm_mapping: Dict = {}, normalization_type: Dict = {}, debug: bool = False) → None[source]
Initialize the DataPreprocessor.
- Parameters:
logger (Any) – Logger instance for logging messages.
dfs (List[str]) – List of file paths to NetCDF files.
stime (int) – Start time index.
tbatch (int) – Temporal batch size.
training (bool, optional) – If True, use a fraction (sblock_perc) of spatial batches for data augmentation. If False, use 100% of spatial batches (full evaluation).
sblock_perc (float, optional) – Fraction of spatial batches to use when training is True. Default is 0.6.
norm_mapping (Dict, optional) – Dictionary containing normalization statistics for each variable.
normalization_type (Dict, optional) – Dictionary specifying normalization type for each variable.
debug (bool, optional) – If True, print debug information.
- normalize(data: numpy.ndarray, var_name: str) → numpy.ndarray[source]
Normalize data using the specified normalization method.
- Parameters:
data (np.ndarray) – Input data array to normalize.
var_name (str) – Name of the variable for which to retrieve normalization statistics.
- Returns:
Normalized data array.
- Return type:
np.ndarray
- Raises:
ValueError – If the normalization type is not supported.
Notes
Supported normalization types:
- minmax: (x - min) / (max - min)
- standard: (x - mean) / std
- robust: (x - median) / IQR
- log1p_minmax: log1p(x) normalized
- log1p_standard: log1p(x) standardized
- log1p_robust: log1p(x) robust normalized
- sqrt_minmax: sqrt(x) normalized
- sqrt_standard: sqrt(x) standardized
- sqrt_robust: sqrt(x) robust normalized
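The base formulas in the notes can be illustrated with a small numpy sketch (hypothetical helper names, not the package API; the log1p_* and sqrt_* variants apply np.log1p or np.sqrt to the data first and use the correspondingly transformed statistics):

```python
import numpy as np

# Illustrative base normalization schemes; `s` mirrors a norm_mapping
# entry with the vmin/vmax/vmean/vstd/median/iqr keys used in this package.
def minmax(x, s):
    return (x - s["vmin"]) / (s["vmax"] - s["vmin"])

def standard(x, s):
    return (x - s["vmean"]) / s["vstd"]

def robust(x, s):
    return (x - s["median"]) / s["iqr"]

x = np.array([0.0, 5.0, 10.0])
s = {"vmin": 0.0, "vmax": 10.0, "vmean": 5.0, "vstd": 4.0,
     "median": 5.0, "iqr": 5.0}
print(minmax(x, s))    # [0.  0.5 1. ]
print(standard(x, s))  # [-1.25  0.    1.25]
```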
rtnn.diagnostics module
Plotting utilities for RTnn model visualization.
This module provides functions for visualizing model predictions, training metrics, and data statistics. It includes tools for creating line plots, hexbin plots, histograms, and metric history plots using matplotlib.
The module supports:
- Visualization of radiative transfer model predictions vs targets
- Absorption rate plotting for different channels
- Training and validation metric histories
- Statistical distributions of input variables
- Various normalization scheme visualizations
Dependencies
matplotlib : For plotting
mpltex : For line styles
scikit-learn : For R² score calculation
xarray : For NetCDF data handling
- rtnn.diagnostics.stats(file_list, logger, output_dir, norm_mapping=None, plots=False)[source]
Compute statistics and generate histograms for variables in NetCDF files.
Reads a collection of NetCDF files, computes descriptive statistics for each variable, and generates histogram plots saved to disk. In addition to raw statistics, transformed statistics using logarithmic (log1p) and square-root transformations are also computed.
- Parameters:
file_list (list of str) – Paths to the NetCDF files to process.
logger (logging.Logger) – Logger used to report progress and informational messages.
output_dir (str) – Directory where histogram plots will be saved.
norm_mapping (dict, optional) – Dictionary to update with computed statistics. If None, a new dictionary is created. Default is None.
- Returns:
Dictionary mapping variable names to their computed statistics. Each variable contains the following entries:
- Raw statistics:
vmin : float
vmax : float
vmean : float
vstd : float
- Robust statistics:
q1 : float
q3 : float
iqr : float
median : float
- Log-transformed statistics (log1p):
log_min : float
log_max : float
log_mean : float
log_std : float
log_q1 : float
log_q3 : float
log_iqr : float
log_median : float
- Square-root-transformed statistics:
sqrt_min : float
sqrt_max : float
sqrt_mean : float
sqrt_std : float
sqrt_q1 : float
sqrt_q3 : float
sqrt_iqr : float
sqrt_median : float
- Return type:
dict
Examples
>>> norm_mapping = stats(
...     file_list=["data_1995.nc", "data_1996.nc"],
...     logger=logger,
...     output_dir="./stats"
... )
>>> norm_mapping["coszang"]["vmean"]
0.5
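A minimal sketch of how such a per-variable statistics dictionary could be assembled (hypothetical `variable_stats` helper, not the package API; key names follow the docstring above):

```python
import numpy as np

def variable_stats(values):
    """Hypothetical helper mirroring the per-variable statistics above."""
    def quartiles(x, prefix):
        q1, med, q3 = np.percentile(x, [25, 50, 75])
        return {prefix + "q1": float(q1), prefix + "q3": float(q3),
                prefix + "iqr": float(q3 - q1), prefix + "median": float(med)}
    # Raw statistics (vmin/vmax/vmean/vstd plus robust quartiles)
    out = {"vmin": float(values.min()), "vmax": float(values.max()),
           "vmean": float(values.mean()), "vstd": float(values.std())}
    out.update(quartiles(values, ""))
    # log1p- and sqrt-transformed statistics
    for name, x in (("log", np.log1p(values)), ("sqrt", np.sqrt(values))):
        out.update({f"{name}_min": float(x.min()), f"{name}_max": float(x.max()),
                    f"{name}_mean": float(x.mean()), f"{name}_std": float(x.std())})
        out.update(quartiles(x, name + "_"))
    return out

stats_dict = variable_stats(np.linspace(0.0, 1.0, 101))
print(round(stats_dict["vmean"], 3))  # 0.5
```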
- rtnn.diagnostics.plot_RTM(predicts, targets, filename)[source]
Plot radiative transfer model predictions against targets.
Creates a 2x2 grid of line plots showing predicted vs true values for four different flux channels. Plots random samples from the batch.
- Parameters:
predicts (torch.Tensor) – Model predictions of shape (batch_size, 4, seq_length).
targets (torch.Tensor) – Ground truth targets of shape (batch_size, 4, seq_length).
filename (str) – Path where the plot will be saved.
Notes
Plots up to 7 random samples from the batch
Uses mpltex for line styles
Y-axis range: 0-1
X-axis: level indices
- rtnn.diagnostics.plot_HeatRate(abs12_predict, abs12_target, abs34_predict, abs34_target, filename)[source]
Plot absorption rates for two channel groups.
Creates a 2-panel figure showing absorption rates for channels 1-2 and 3-4.
- Parameters:
abs12_predict (torch.Tensor) – Predicted absorption for channels 1-2 of shape (batch_size, 1, seq_length).
abs12_target (torch.Tensor) – True absorption for channels 1-2.
abs34_predict (torch.Tensor) – Predicted absorption for channels 3-4.
abs34_target (torch.Tensor) – True absorption for channels 3-4.
filename (str) – Path where the plot will be saved.
Notes
Upper panel: channels 1-2
Lower panel: channels 3-4
Plots random samples from the batch
- rtnn.diagnostics.plot_flux_and_abs_lines(predicts, targets, abs12_predict=None, abs12_target=None, abs34_predict=None, abs34_target=None, filename='output_lines.png')[source]
Create line plots for fluxes and absorption rates.
Generates a multi-panel figure with line plots for four flux channels and optionally two absorption panels. Each panel shows predictions vs targets.
- Parameters:
predicts (torch.Tensor) – Model predictions for fluxes of shape (batch_size, 4, seq_length).
targets (torch.Tensor) – Ground truth fluxes.
abs12_predict (torch.Tensor, optional) – Predicted absorption for channels 1-2.
abs12_target (torch.Tensor, optional) – True absorption for channels 1-2.
abs34_predict (torch.Tensor, optional) – Predicted absorption for channels 3-4.
abs34_target (torch.Tensor, optional) – True absorption for channels 3-4.
filename (str, optional) – Output filename. Default is “output_lines.png”.
Notes
- Figure layout:
2x2 grid for fluxes (upwelling/downwelling for two channels)
Optional 1x2 grid for absorption rates (if provided)
- rtnn.diagnostics.plot_flux_and_abs(predicts, targets, abs12_predict=None, abs12_target=None, abs34_predict=None, abs34_target=None, filename='output.png')[source]
Create hexbin plots for fluxes and absorption rates.
Generates a multi-panel figure with hexbin density plots showing the relationship between predicted and true values. Useful for assessing prediction accuracy across the entire dataset.
- Parameters:
predicts (torch.Tensor) – Model predictions for fluxes of shape (batch_size, 4, seq_length).
targets (torch.Tensor) – Ground truth fluxes.
abs12_predict (torch.Tensor, optional) – Predicted absorption for channels 1-2.
abs12_target (torch.Tensor, optional) – True absorption for channels 1-2.
abs34_predict (torch.Tensor, optional) – Predicted absorption for channels 3-4.
abs34_target (torch.Tensor, optional) – True absorption for channels 3-4.
filename (str, optional) – Output filename. Default is “output.png”.
Notes
Hexbin plots use logarithmic color scale
Includes diagonal reference line (y=x)
Displays R² score in the top-left corner of each panel
Shared colorbar on the right
- rtnn.diagnostics.plot_metric_histories(train_history, valid_history, filename='training_validation_metrics.png')[source]
Plot training and validation metrics over epochs.
Creates a multi-panel figure showing the evolution of various metrics (e.g., NMAE, NMSE, R2) over training epochs.
- Parameters:
train_history – Training metric histories over epochs.
valid_history – Validation metric histories over epochs.
filename (str, optional) – Output filename. Default is “training_validation_metrics.png”.
Notes
Metrics are plotted on a logarithmic scale
Each metric gets its own panel
Panels are arranged in a grid (3 columns)
Blue lines: training, Orange lines: validation
- rtnn.diagnostics.plot_loss_histories(train_loss, valid_loss, filename='training_validation_loss.png')[source]
Plot training and validation loss over epochs.
Creates a single-panel figure showing the loss evolution during training.
- Parameters:
train_loss – Training loss values over epochs.
valid_loss – Validation loss values over epochs.
filename (str, optional) – Output filename. Default is “training_validation_loss.png”.
Notes
Uses logarithmic scale for y-axis
Blue line: training loss
Orange line: validation loss
Includes grid for better readability
- rtnn.diagnostics.plot_spatial_temporal_density(sindex_tracker, tindex_tracker, mode='train', save_dir='./tests_plots', filename='density_scatter', figsize=(10, 10))[source]
Plot a density scatter plot of spatial index vs temporal index with marginal histograms.
This function creates a 2D density scatter plot (hexbin) showing the distribution of spatial indices (processor ranks) across temporal indices, with:
- Right plot: Histogram of temporal index distribution
- Top plot: Histogram of spatial index distribution
- Parameters:
sindex_tracker (list or array-like) – List of spatial indices (processor ranks) for each data sample.
tindex_tracker (list or array-like) – List of temporal indices for each data sample.
mode (str, optional) – Dataset mode identifier (“train”, “validation”, “test”).
save_dir (str, optional) – Directory path where the plot will be saved.
filename (str, optional) – Base name for the output file.
figsize (tuple, optional) – Figure size as (width, height) in inches.
- Returns:
Path to the saved plot file.
- Return type:
str
rtnn.evaluater module
Evaluation utilities for RTnn model assessment.
This module provides comprehensive evaluation tools for radiative transfer neural network models, including custom loss functions, metric computation, and visualization helpers.
The module includes:
- Custom loss functions (NMSE, NMAE, combined MSE-MAE, LogCosh, Weighted MSE)
- Metric calculators for evaluation (MSE, MAE, MBE, R², NMSE, NMAE, MARE, GMRAE)
- Data normalization/de-normalization utilities
- Absorption rate calculations
- Main evaluation loop for LSM models
Dependencies
torch : For tensor operations and loss functions
numpy : For numerical operations
plot_helper : For visualization utilities
- class rtnn.evaluater.NMSELoss(*args: Any, **kwargs: Any)[source]
Bases: Module
Normalized Mean Squared Error Loss.
Computes MSE normalized by the mean square of the target values. Useful when the scale of the target variable varies.
- Parameters:
eps (float, optional) – Small constant for numerical stability. Default is 1e-8.
Examples
>>> criterion = NMSELoss()
>>> loss = criterion(predictions, targets)
- class rtnn.evaluater.NMAELoss(*args: Any, **kwargs: Any)[source]
Bases: Module
Normalized Mean Absolute Error Loss.
Computes MAE normalized by the mean absolute value of the target. Provides a scale-invariant error metric.
- Parameters:
eps (float, optional) – Small constant for numerical stability. Default is 1e-8.
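The two normalized losses can be sketched with numpy under the assumed definitions (error normalized by the corresponding mean target magnitude; the package versions are torch.nn.Module subclasses):

```python
import numpy as np

# Numpy stand-ins for NMSELoss / NMAELoss; the normalizing denominator
# makes both metrics insensitive to the overall scale of the target.
def nmse(pred, true, eps=1e-8):
    return np.mean((pred - true) ** 2) / (np.mean(true ** 2) + eps)

def nmae(pred, true, eps=1e-8):
    return np.mean(np.abs(pred - true)) / (np.mean(np.abs(true)) + eps)

t = np.array([1.0, 2.0, 3.0])
p = 1.1 * t  # uniform 10% overestimate
print(round(float(nmae(p, t)), 3))  # 0.1
```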
- class rtnn.evaluater.MetricTracker[source]
Bases: object
A utility class for tracking and computing statistics of metric values.
This class maintains a running average of metric values and provides methods to compute mean and root mean squared values.
Examples
>>> tracker = MetricTracker()
>>> tracker.update(10.0, 5)  # value=10.0, count=5 samples
>>> tracker.update(20.0, 3)  # value=20.0, count=3 samples
>>> print(tracker.getmean())  # (10*5 + 20*3) / (5+3) = 110/8 = 13.75
13.75
>>> print(tracker.getsqrtmean())  # sqrt(13.75)
3.7080992435478315
- getmean()[source]
Calculate the mean of all tracked values.
- Returns:
Weighted mean of all values: total_value / total_count
- Return type:
float
- Raises:
ZeroDivisionError – If no values have been added (count == 0)
- getstd()[source]
Calculate the standard deviation of all tracked values.
- Returns:
Weighted standard deviation of all values: sqrt(E(x^2) - (E(x))^2)
- Return type:
float
- Raises:
ZeroDivisionError – If no values have been added (count == 0)
- getsqrtmean()[source]
Calculate the square root of the mean of all tracked values.
- Returns:
Square root of the weighted mean: sqrt(total_value / total_count)
- Return type:
float
- Raises:
ZeroDivisionError – If no values have been added (count == 0)
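The running weighted average shown in the example can be sketched as (a simplified stand-in, assuming `update(value, n)` accumulates `value * n`):

```python
class MetricTracker:
    """Running weighted average of a metric (sketch of the class above)."""
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        # `value` is assumed to be the metric already averaged over n samples
        self.total += value * n
        self.count += n

    def getmean(self):
        return self.total / self.count  # raises ZeroDivisionError if empty

    def getsqrtmean(self):
        return self.getmean() ** 0.5

tracker = MetricTracker()
tracker.update(10.0, 5)
tracker.update(20.0, 3)
print(tracker.getmean())  # 13.75
```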
- rtnn.evaluater.get_loss_function(loss_type, args, logger=None)[source]
Factory function to instantiate the requested loss function.
- Parameters:
loss_type (str) – Type of loss function. Options:
- ‘mse’: Mean Squared Error
- ‘mae’: Mean Absolute Error
- ‘nmae’: Normalized Mean Absolute Error
- ‘nmse’: Normalized Mean Squared Error
- ‘wmse’: Weighted Mean Squared Error
- ‘logcosh’: Log-Cosh loss
- ‘smoothl1’: Smooth L1 Loss (Huber-like)
- ‘huber’: Huber Loss
args (argparse.Namespace) – Arguments containing loss-specific parameters (e.g., beta_delta for Huber).
- Returns:
Initialized loss function.
- Return type:
torch.nn.Module
- Raises:
ValueError – If loss_type is not supported or required parameters are missing.
Examples
>>> args = argparse.Namespace(beta_delta=1.0)
>>> criterion = get_loss_function('huber', args)
- rtnn.evaluater.mse_all(pred, true)[source]
Compute Mean Squared Error.
- Parameters:
pred (torch.Tensor) – Predictions.
true (torch.Tensor) – Ground truth.
- Returns:
(num_elements, mse_value)
- Return type:
tuple
- rtnn.evaluater.mbe_all(pred, true)[source]
Compute Mean Bias Error.
- Parameters:
pred (torch.Tensor) – Predictions.
true (torch.Tensor) – Ground truth.
- Returns:
(num_elements, mbe_value)
- Return type:
tuple
- rtnn.evaluater.mae_all(pred, true)[source]
Compute Mean Absolute Error.
- Parameters:
pred (torch.Tensor) – Predictions.
true (torch.Tensor) – Ground truth.
- Returns:
(num_elements, mae_value)
- Return type:
tuple
- rtnn.evaluater.r2_all(pred, true)[source]
Calculate R2 (coefficient of determination) between predicted and true values.
Computes the R2 metric and returns both the number of elements and the R2 value.
- Parameters:
pred (torch.Tensor) – Predicted values from the model
true (torch.Tensor) – Ground truth values
- Returns:
(num_elements, r2_value) where: - num_elements (int): Total number of elements in the tensors - r2_value (torch.Tensor): R2 score
- Return type:
tuple
Notes
R2 is calculated as:
R2 = 1 - sum((true - pred)^2) / sum((true - mean(true))^2)
This implementation is fully torch-based and works on CPU and GPU.
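The formula in the notes translates directly to code; a numpy stand-in for the torch implementation might look like:

```python
import numpy as np

def r2_all(pred, true):
    # R2 = 1 - sum((true - pred)^2) / sum((true - mean(true))^2),
    # matching the formula above
    ss_res = np.sum((true - pred) ** 2)
    ss_tot = np.sum((true - true.mean()) ** 2)
    return true.size, float(1.0 - ss_res / ss_tot)

true = np.array([1.0, 2.0, 3.0, 4.0])
n, r2 = r2_all(true.copy(), true)  # perfect prediction
print(n, r2)  # 4 1.0
```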
- rtnn.evaluater.nmae_all(pred, true)[source]
Compute Normalized Mean Absolute Error.
- Parameters:
pred (torch.Tensor) – Predictions.
true (torch.Tensor) – Ground truth.
- Returns:
(num_elements, nmae_value)
- Return type:
tuple
- rtnn.evaluater.nmse_all(pred, true)[source]
Compute Normalized Mean Squared Error.
- Parameters:
pred (torch.Tensor) – Predictions.
true (torch.Tensor) – Ground truth.
- Returns:
(num_elements, nmse_value)
- Return type:
tuple
- rtnn.evaluater.mare_all(pred, true)[source]
Compute Mean Absolute Relative Error.
- Parameters:
pred (torch.Tensor) – Predictions.
true (torch.Tensor) – Ground truth.
- Returns:
(num_elements, mare_value)
- Return type:
tuple
- rtnn.evaluater.gmrae_all(pred, true)[source]
Compute Geometric Mean Relative Absolute Error.
- Parameters:
pred (torch.Tensor) – Predictions.
true (torch.Tensor) – Ground truth.
- Returns:
(num_elements, gmrae_value)
- Return type:
tuple
- rtnn.evaluater.unnorm_mpas(pred, targ, norm_mapping, normalization_type, idxmap)[source]
Reverse normalization to recover original physical values.
Applies inverse transformation based on the normalization method used during preprocessing.
- Parameters:
pred (torch.Tensor) – Normalized predictions of shape (batch, channels, seq_length).
targ (torch.Tensor) – Normalized targets.
norm_mapping (dict) – Dictionary containing normalization statistics for each variable.
normalization_type (dict) – Dictionary mapping variable names to normalization types.
idxmap (dict) – Dictionary mapping channel indices to variable names.
- Returns:
tuple – (unnormalized_predictions, unnormalized_targets)
Notes
Supported normalization types:
minmax: x * (max - min) + min
standard: x * std + mean
robust: x * iqr + median
log1p_*: expm1(x * scale + offset)
sqrt_*: (x * scale + offset) ** 2
Examples
>>> idxmap = {0: 'collim_alb', 1: 'collim_tran'}
>>> upred, utarg = unnorm_mpas(pred, targ, norm_mapping, norm_type, idxmap)
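The inverse transforms listed above can be sketched as follows (hypothetical `unnormalize` helper covering a subset of the schemes; `s` mirrors a `norm_mapping` entry):

```python
import numpy as np

def unnormalize(x, s, norm_type):
    """Sketch of the inverse transforms; not the package API."""
    if norm_type == "minmax":
        return x * (s["vmax"] - s["vmin"]) + s["vmin"]
    if norm_type == "standard":
        return x * s["vstd"] + s["vmean"]
    if norm_type == "robust":
        return x * s["iqr"] + s["median"]
    if norm_type == "log1p_minmax":
        # Undo minmax in log space, then invert log1p with expm1
        return np.expm1(x * (s["log_max"] - s["log_min"]) + s["log_min"])
    if norm_type == "sqrt_minmax":
        # Undo minmax in sqrt space, then square to invert sqrt
        return (x * (s["sqrt_max"] - s["sqrt_min"]) + s["sqrt_min"]) ** 2
    raise ValueError(f"Unsupported normalization type: {norm_type}")

s = {"vmin": 0.0, "vmax": 10.0}
print(unnormalize(np.array([0.0, 0.5, 1.0]), s, "minmax"))  # [ 0.  5. 10.]
```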
- rtnn.evaluater.calc_abs(pred, targ, p=None)[source]
Calculate absorption rates from flux predictions.
Computes net absorption rates for two channel groups (1-2 and 3-4) using the difference between upwelling and downwelling fluxes.
- Parameters:
pred (torch.Tensor) – Predictions of shape (batch, 4, seq_length) where channels 0-1 are for first group and 2-3 for second group.
targ (torch.Tensor) – Targets of same shape as pred.
p (torch.Tensor, optional) – Pressure levels for atmospheric heating rate calculation. If provided, computes heating rate using pressure gradients.
- Returns:
(abs12_pred, abs12_targ, abs34_pred, abs34_targ) where each is a tensor of shape (batch, 1, seq_length-1).
- Return type:
tuple
Notes
If p is None: returns d(net) where net = up - down
If p is provided: returns heating rate using d(net)/dp
- rtnn.evaluater.calc_hr(up, down, p=None)[source]
Calculate heating rate from upwelling and downwelling fluxes.
Computes the net radiative heating rate by taking the vertical derivative of net flux (upwelling - downwelling). If pressure levels are provided, calculates the actual atmospheric heating rate using pressure gradients.
- Parameters:
up (torch.Tensor) – Upwelling flux tensor of shape (batch, channels, seq_length).
down (torch.Tensor) – Downwelling flux tensor of shape (batch, channels, seq_length).
p (torch.Tensor, optional) – Pressure levels of shape (seq_length,) or (batch, seq_length). If provided, computes physical heating rate. If None, returns the derivative of net flux.
- Returns:
- If p is None:
Returns the negative derivative of net flux with respect to level index: -d(net)/dz (or d(net)/d(level)) of shape (batch, channels, seq_length - 1)
- If p is provided:
Returns the atmospheric heating rate in K/day using the formula: heating_rate = (g * 8.64e4) / (cp * 100) * d(net)/dp where g = 9.8066 m/s², cp = 1004 J/(kg·K) (calculated as 7*R/2 with R=287)
- Return type:
torch.Tensor
Notes
The derivative is computed using finite differences: net[i+1] - net[i]
For pressure-based calculation, uses dp = p[i+1] - p[i]
The factor 8.64e4 converts from W/m² to K/day
The factor 100 converts pressure from hPa to Pa
Examples
>>> # Calculate net flux derivative
>>> hr = calc_hr(up, down)
>>> hr.shape
torch.Size([32, 4, 9])  # for seq_length=10
>>> # Calculate actual heating rate with pressure levels
>>> pressure = torch.linspace(1000, 100, 10)  # hPa
>>> heating_rate = calc_hr(up, down, p=pressure)
>>> heating_rate.shape
torch.Size([32, 4, 9])
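A 1-D numpy sketch of the finite-difference and pressure-based formulas from the notes (the actual function operates on batched torch tensors):

```python
import numpy as np

G = 9.8066         # m/s^2, per the notes
CP = 7 * 287 / 2   # J/(kg K), cp = 7R/2 with R = 287, per the notes

def calc_hr(up, down, p=None):
    """1-D sketch of the heating-rate calculation described above."""
    net = up - down
    dnet = net[1:] - net[:-1]         # finite difference: net[i+1] - net[i]
    if p is None:
        return -dnet                  # negative derivative w.r.t. level index
    dp = p[1:] - p[:-1]               # pressure in hPa
    # (g * 8.64e4) / (cp * 100) * d(net)/dp, yielding K/day
    return (G * 8.64e4) / (CP * 100) * dnet / dp

up = np.linspace(1.0, 0.5, 10)
down = np.linspace(0.8, 0.2, 10)
print(calc_hr(up, down).shape)  # (9,)
```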
- rtnn.evaluater.run_validation(loader, model, norm_mapping, normalization_type, index_mapping, device, args, epoch)[source]
Evaluate model accuracy on LSM dataset.
Performs comprehensive evaluation including: - Loss computation for main fluxes and absorption rates - Metric calculation (NMAE, NMSE, R²) - Optional plotting of predictions vs targets
- Parameters:
loader (torch.utils.data.DataLoader) – Data loader for evaluation dataset.
model (torch.nn.Module) – Trained model to evaluate.
norm_mapping (dict) – Normalization statistics for variables.
normalization_type (dict) – Normalization types per variable.
index_mapping (dict) – Mapping from channel indices to variable names.
device (torch.device) – Device to run evaluation on.
args (argparse.Namespace) – Arguments containing loss type, beta, etc.
epoch (int) – Current epoch number (for plotting).
- Returns:
(valid_loss, valid_metrics) where valid_metrics is a dictionary containing computed metrics for fluxes, abs12, and abs34.
- Return type:
tuple
Examples
>>> valid_loss, metrics = run_validation(
...     test_loader, model, norm_mapping, norm_type, idxmap, device, args, epoch
... )
>>> print(f"Validation loss: {valid_loss:.4f}")
>>> print(f"NMAE: {metrics['fluxes_NMAE']:.4f}")
rtnn.logger module
- class rtnn.logger.Logger(console_output=True, file_output=False, log_file='module_log_file.log', pretty_print=True, record=False)[source]
Bases: object
- __init__(console_output=True, file_output=False, log_file='module_log_file.log', pretty_print=True, record=False)[source]
- start_task(task_name: str, description: str = '', **meta)[source]
Display a clearly formatted ‘task start’ message with good spacing.
rtnn.main module
RTnn (Radiative Transfer Neural Network) Training Pipeline
This module provides the main entry point for training neural network models for radiative transfer calculations in climate modeling. It supports various model architectures including LSTM, GRU, Transformer, and FCN.
The training pipeline includes:
- Data loading and preprocessing from NetCDF files
- Model initialization and configuration
- Training loop with progress tracking
- Validation and metric computation
- Checkpoint saving and model persistence
- Visualization and logging
- rtnn.main.parse_years(year_str)[source]
Parse a year string into a list of integers.
Supports hyphen-separated ranges (e.g., “1995-1999”) and comma-separated lists (e.g., “1995,1997,1999”). Returns a list of integers.
- Parameters:
year_str (str) – String containing years in range or comma-separated format.
- Returns:
List of parsed years.
- Return type:
list
Examples
>>> parse_years("1995-1999")
[1995, 1996, 1997, 1998, 1999]
>>> parse_years("1995,1997,1999")
[1995, 1997, 1999]
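A minimal sketch of the parsing logic described above (assuming the two documented formats, hyphen range and comma list, are mutually exclusive):

```python
def parse_years(year_str):
    """Parse "1995-1999" or "1995,1997,1999" into a list of ints (sketch)."""
    if "-" in year_str:
        start, end = (int(y) for y in year_str.split("-"))
        return list(range(start, end + 1))  # inclusive range
    return [int(y) for y in year_str.split(",")]

print(parse_years("1995-1999"))  # [1995, 1996, 1997, 1998, 1999]
```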
- rtnn.main.parse_args()[source]
Parse command-line arguments for RTnn model training.
Defines and parses command-line arguments required to configure and run the Radiative Transfer Neural Network (RTnn) training pipeline. This includes model architecture parameters, training hyperparameters, data configuration, and output settings.
- Returns:
Object containing parsed command-line arguments, grouped as follows:
- Model architecture
- type : str
Model type (e.g., “lstm”, “gru”, “fcn”, “fullyconnected”, “transformer”, “cnn”, “mlp”).
- hidden_size : int
Size of hidden layers.
- num_layers : int
Number of model layers.
- seq_length : int
Length of input sequence.
- feature_channel : int
Number of input feature channels.
- output_channel : int
Number of output channels.
- embed_size : int
Embedding dimension for transformer models.
- nhead : int
Number of attention heads (transformer).
- forward_expansion : int
Expansion factor for feed-forward layers.
- dropout : float
Dropout rate.
- Training hyperparameters
- batch_size : int
Number of samples per batch.
- tbatch : int
Temporal batch length.
- num_epochs : int
Number of training epochs.
- learning_rate : float
Initial learning rate.
- loss_type : str
Loss function (e.g., “mse”, “mae”, “nmae”, “nmse”, “wmse”, “logcosh”, “smoothl1”, “huber”).
- beta : float
Weighting factor for loss components.
- beta_delta : float
Delta parameter for Huber or SmoothL1 loss.
- num_workers : int
Number of data loader workers.
- Data configuration
- train_data_files : str
Path or pattern for training data files.
- test_data_files : str
Path or pattern for testing data files.
- train_years : str
Training years (comma-separated or range, e.g., “1995-1999”).
- test_year : str
Test year or range.
- norm : str
Normalization scheme (e.g., “log1p_standard”, “standard”, “minmax”, “none”).
- dataset_type : str
Dataset type (e.g., “LSM”, “RTM”).
- Output configuration
- root_dir : str
Root directory for all operations.
- main_folder : str
Main output folder name.
- sub_folder : str
Sub-folder name for the current run.
- prefix : str
Prefix for saved files.
- model_name : str
Custom model name (auto-generated if empty).
- save_model : bool
Whether to save model checkpoints.
- save_checkpoint_name : str
Base name for saved checkpoints.
- save_per_samples : int
Save checkpoint every N samples.
- load_model : bool
Whether to load an existing model.
- load_checkpoint_name : str
Checkpoint file to load.
- inference : bool
Run in inference-only mode.
- Return type:
argparse.Namespace
Examples
>>> args = parse_args()
>>> args.type
'lstm'
>>> args.batch_size
16
Command line usage
$ rtnn --type lstm --hidden_size 128 --num_layers 3 --batch_size 32
- rtnn.main.setup_directories_and_logging(args)[source]
Set up directory structure and logging infrastructure for experiments.
- Parameters:
args (argparse.Namespace) – Parsed command-line arguments.
- Returns:
paths (EasyDict) – Dictionary containing paths to created directories.
logger (Logger) – Configured logger instance.
- rtnn.main.log_configuration(args, paths, logger)[source]
Log all configuration parameters to the provided logger.
- Parameters:
args (argparse.Namespace) – Configuration object containing all experiment parameters.
paths (EasyDict) – Dictionary containing paths to various experiment directories.
logger (Logger) – Logger instance for outputting configuration information.
- rtnn.main.setup_device_and_seed(args, logger)[source]
Set up device (GPU/CPU) and random seeds for reproducibility.
- Parameters:
args (argparse.Namespace) – Parsed command-line arguments.
logger (Logger) – Logger instance.
- Returns:
Device to use for computations.
- Return type:
torch.device
- rtnn.main.get_data_files(args, logger)[source]
Get training and testing data files based on year specifications.
- Parameters:
args (argparse.Namespace) – Parsed command-line arguments.
logger (Logger) – Logger instance.
- Returns:
(train_files, test_files) lists of file paths.
- Return type:
tuple
- rtnn.main.create_normalization_mapping(train_files, paths, logger)[source]
Create normalization mapping from training data.
- rtnn.main.create_datasets_and_loaders(args, train_files, test_files, norm_mapping, logger)[source]
Create datasets and data loaders for training and validation.
- Parameters:
args (argparse.Namespace) – Parsed command-line arguments.
train_files (list) – Training file paths.
test_files (list) – Test file paths.
norm_mapping (dict) – Normalization statistics.
logger (Logger) – Logger instance.
- Returns:
(train_loader, test_loader, train_dataset, test_dataset)
- Return type:
tuple
- rtnn.main.initialize_model(args, device, logger)[source]
Initialize the model architecture.
- Parameters:
args (argparse.Namespace) – Parsed command-line arguments.
device (torch.device) – Device to place model on.
logger (Logger) – Logger instance.
- Returns:
Initialized model.
- Return type:
torch.nn.Module
- rtnn.main.load_checkpoint_if_requested(args, model, optimizer, paths, device, logger)[source]
Load model checkpoint if requested using ModelUtils.load_training_checkpoint().
This function leverages ModelUtils.load_training_checkpoint() which handles:
- DataParallel compatibility
- Loading model and optimizer states
- Extracting training state (epoch, samples, metrics, etc.)
- Parameters:
args (argparse.Namespace) – Parsed command-line arguments.
model (torch.nn.Module) – Model to load weights into.
optimizer (torch.optim.Optimizer) – Optimizer to restore state.
paths (EasyDict) – Directory paths.
device (torch.device) – Device for loading.
logger (Logger) – Logger instance.
- Returns:
(start_epoch, samples_processed, batches_processed, best_val_loss, best_epoch, checkpoint, train_loss_history, valid_loss_history, valid_metrics_history)
- Return type:
tuple
- rtnn.main.save_checkpoint(model, optimizer, epoch, samples_processed, batches_processed, train_loss, valid_loss, args, paths, logger, checkpoint_type='epoch')[source]
Save model checkpoint using ModelUtils.
- Parameters:
model (torch.nn.Module) – Model to save.
optimizer (torch.optim.Optimizer) – Optimizer state to save.
epoch (int) – Current epoch.
samples_processed (int) – Total samples processed.
batches_processed (int) – Total batches processed.
train_loss (float) – Current training loss.
valid_loss (float) – Current validation loss.
args (argparse.Namespace) – Command-line arguments.
paths (EasyDict) – Directory paths.
logger (Logger) – Logger instance.
checkpoint_type (str) – Type of checkpoint (“epoch”, “best”, “final”, “samples”).
- rtnn.main.train_epoch(model, train_loader, optimizer, loss_func, metric_funcs, metric_names, output_keys, train_metrics, train_loss_tracker, norm_mapping, normalization_type, index_mapping, device, args, epoch, writer, global_step, logger)[source]
Train for one epoch.
- Returns:
(average_train_loss, updated_global_step)
- Return type:
tuple
rtnn.model_loader module
- rtnn.model_loader.load_model(args)[source]
Load and initialize a model based on the provided configuration.
This function acts as a factory that instantiates the appropriate model architecture based on the type argument. Supported models include:
- LSTM: Bidirectional LSTM with Conv1d output projection
- GRU: Bidirectional GRU with Conv1d output projection
- Transformer: Transformer encoder with positional embeddings
- FCN/fullyconnected: Fully connected network with configurable depth
- Parameters:
args (argparse.Namespace) –
Namespace containing model configuration parameters. Required attributes depend on the model type:
- For LSTM/GRU:
type : str (‘lstm’ or ‘gru’)
feature_channel : int
output_channel : int
hidden_size : int
num_layers : int
- For Transformer:
type : str (‘transformer’)
feature_channel : int
output_channel : int
embed_size : int
num_layers : int
nhead : int
forward_expansion : int
seq_length : int
dropout : float
- For FCN/fullyconnected:
type : str (‘fcn’ or ‘fullyconnected’)
feature_channel : int
output_channel : int
num_layers : int
hidden_size : int
seq_length : int
dim_expand : int (optional, default 0)
- Returns:
Initialized PyTorch model of the specified architecture.
- Return type:
torch.nn.Module
- Raises:
ValueError – If the specified model type is not implemented.
Examples
>>> args = argparse.Namespace(
...     type='lstm',
...     feature_channel=6,
...     output_channel=4,
...     hidden_size=128,
...     num_layers=3
... )
>>> model = load_model(args)
>>> print(type(model))
<class 'rtnn.models.rnn.RNN_LSTM'>
>>> args = argparse.Namespace(
...     type='transformer',
...     feature_channel=6,
...     output_channel=4,
...     embed_size=64,
...     num_layers=2,
...     nhead=4,
...     forward_expansion=4,
...     seq_length=10,
...     dropout=0.1
... )
>>> model = load_model(args)
>>> print(type(model))
<class 'rtnn.models.Transformer.Encoder'>
>>> args = argparse.Namespace(
...     type='fcn',
...     feature_channel=6,
...     output_channel=4,
...     num_layers=3,
...     hidden_size=196,
...     seq_length=10
... )
>>> model = load_model(args)
>>> print(type(model))
<class 'rtnn.models.fcn.FCN'>
rtnn.model_utils module
- class rtnn.model_utils.ModelUtils[source]
Bases: object
Utility class for model inspection, checkpointing, and memory profiling.
This class provides static methods for common model operations including parameter counting, memory usage analysis, checkpoint management, and model inspection.
Examples
>>> utils = ModelUtils()
>>> param_counts = ModelUtils.get_parameter_number(model)
>>> ModelUtils.save_checkpoint(state, "checkpoint.pth.tar", logger)
- static get_parameter_number(model, logger=None)[source]
Calculate the total and trainable number of parameters in a model.
- Parameters:
model (torch.nn.Module) – PyTorch model to inspect
logger (Logger, optional) – Logger instance for output, by default None
- Returns:
Dictionary containing:
- 'Total': Total number of parameters
- 'Trainable': Number of trainable parameters
- Return type:
dict
Examples
>>> model = torch.nn.Linear(10, 5)
>>> counts = ModelUtils.get_parameter_number(model, logger)
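The counting logic behind get_parameter_number can be reproduced with standard PyTorch calls. The helper below is an illustrative re-implementation under that assumption, not the rtnn code itself:

```python
import torch

def parameter_number(model):
    """Count total and trainable parameters, mirroring get_parameter_number."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {"Total": total, "Trainable": trainable}

model = torch.nn.Linear(10, 5)           # 10*5 weights + 5 biases = 55 parameters
counts = parameter_number(model)         # {'Total': 55, 'Trainable': 55}

model.bias.requires_grad = False         # freezing a tensor lowers the trainable count
frozen_counts = parameter_number(model)  # {'Total': 55, 'Trainable': 50}
```

The Total/Trainable split is useful for spotting accidentally frozen layers after loading a checkpoint.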
- static print_model_layers(model, logger=None)[source]
Print model parameter names along with their gradient requirements.
- Parameters:
model (torch.nn.Module) – PyTorch model to inspect
logger (Logger, optional) – Logger instance for output, by default None
Examples
>>> model = torch.nn.Sequential(
...     torch.nn.Linear(10, 5),
...     torch.nn.ReLU(),
...     torch.nn.Linear(5, 1)
... )
>>> ModelUtils.print_model_layers(model, logger)
- static save_checkpoint(state, filename='checkpoint.pth.tar', logger=None)[source]
Save model and optimizer state to a file.
- Parameters:
state (dict) – Dictionary containing model state_dict and other training information. Typically includes:
- 'state_dict': Model parameters
- 'optimizer': Optimizer state
- 'epoch': Current epoch
- 'loss': Current loss value
filename (str, optional) – File path to save the checkpoint, by default "checkpoint.pth.tar"
logger (Logger, optional) – Logger instance for output, by default None
Examples
>>> state = {
...     'state_dict': model.state_dict(),
...     'optimizer': optimizer.state_dict(),
...     'epoch': epoch,
...     'loss': loss
... }
>>> ModelUtils.save_checkpoint(state, 'model_checkpoint.pth.tar', logger)
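A full save-and-restore round trip with a state dictionary of this shape might look like the following. It uses plain torch.save/torch.load rather than the ModelUtils wrappers, so the file handling here is an assumption for illustration:

```python
import os
import tempfile
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters())

state = {
    "state_dict": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "epoch": 3,
    "loss": 0.25,
}

# Write the checkpoint to disk, then restore it into a fresh model.
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth.tar")
torch.save(state, path)

checkpoint = torch.load(path)
restored = torch.nn.Linear(4, 2)
restored.load_state_dict(checkpoint["state_dict"])
```

Keeping the epoch and loss alongside the state dicts is what lets training resume from the exact point it stopped.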
- static load_checkpoint(checkpoint, model, optimizer=None, logger=None)[source]
Load model and optimizer state from a checkpoint file.
- Parameters:
checkpoint (dict) – Loaded checkpoint dictionary
model (torch.nn.Module) – Model to load weights into
optimizer (torch.optim.Optimizer, optional) – Optimizer to restore state, by default None
logger (Logger, optional) – Logger instance for output, by default None
Examples
>>> checkpoint = torch.load('model_checkpoint.pth.tar')
>>> ModelUtils.load_checkpoint(checkpoint, model, optimizer, logger)
- static load_training_checkpoint(checkpoint_path, model, optimizer, device, logger=None)[source]
Load comprehensive training checkpoint.
- Parameters:
checkpoint_path (str) – Path to checkpoint file
model (torch.nn.Module) – Model to load weights into
optimizer (torch.optim.Optimizer) – Optimizer to restore state
device (torch.device) – Device to load checkpoint to
logger (Logger, optional) – Logger instance for output
- Returns:
(epoch, samples_processed, batches_processed, best_val_loss, best_epoch, checkpoint)
- Return type:
tuple
- static count_parameters_by_layer(model, logger=None)[source]
Count parameters for each layer in the model.
- Parameters:
model (torch.nn.Module) – PyTorch model to analyze
logger (Logger, optional) – Logger instance for output, by default None
- Returns:
Dictionary with layer names as keys and parameter counts as values
- Return type:
dict
Examples
>>> layer_params = ModelUtils.count_parameters_by_layer(model, logger)
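Grouping counts by layer can be done with named_parameters; the helper below is an illustrative sketch of that idea, not the rtnn implementation:

```python
from collections import defaultdict

import torch

def parameters_by_layer(model):
    """Sum parameter counts under each top-level submodule name."""
    counts = defaultdict(int)
    for name, param in model.named_parameters():
        counts[name.split(".")[0]] += param.numel()
    return dict(counts)

model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    torch.nn.ReLU(),
    torch.nn.Linear(5, 1),
)
layer_params = parameters_by_layer(model)  # {'0': 55, '2': 6}
```

In a Sequential model the top-level names are the indices '0', '1', '2', ...; parameter-free layers such as ReLU simply never appear in the result.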
- static log_model_summary(model, input_shape=None, logger=None)[source]
Log comprehensive model summary including parameters and architecture.
- Parameters:
model (torch.nn.Module) – PyTorch model to summarize
input_shape (tuple, optional) – Input shape for memory analysis, by default None
logger (Logger, optional) – Logger instance for output, by default None
- static save_training_checkpoint(model, optimizer, epoch, samples_processed, batches_processed, train_loss_history, valid_loss_history, valid_metrics_history, best_val_loss, best_epoch, avg_val_loss, avg_epoch_loss, args, paths, logger, checkpoint_type='epoch', save_full_model=True)[source]
Save comprehensive training checkpoint with consistent formatting.
- Parameters:
model (torch.nn.Module) – Model to save
optimizer (torch.optim.Optimizer) – Optimizer to save
epoch (int) – Current epoch
samples_processed (int) – Number of samples processed so far
batches_processed (int) – Number of batches processed so far
train_loss_history (list) – History of training losses
valid_loss_history (list) – History of validation losses
valid_metrics_history (dict) – History of validation metrics
best_val_loss (float) – Best validation loss so far
best_epoch (int) – Epoch with best validation loss
avg_val_loss (float) – Current epoch validation loss
avg_epoch_loss (float) – Current epoch training loss
args (argparse.Namespace) – Command line arguments
paths (EasyDict) – Directory paths
logger (Logger) – Logger instance
checkpoint_type (str) – Type of checkpoint: “samples”, “epoch”, “best”, “final”
save_full_model (bool) – Whether to also save the full model separately
- Returns:
(checkpoint_filename, full_model_filename)
- Return type:
tuple
Examples
>>> checkpoint_file, full_model_file = ModelUtils.save_training_checkpoint(
...     model, optimizer, epoch, samples_processed, batches_processed,
...     train_loss_history, valid_loss_history, valid_metrics_history,
...     best_val_loss, best_epoch, avg_val_loss, avg_epoch_loss,
...     args, paths, logger, checkpoint_type="best"
... )
rtnn.stats module
rtnn.utils module
- class rtnn.utils.EasyDict[source]
Bases:
dict
A dictionary subclass that allows for attribute-style access to its items. This class extends the built-in dict and overrides the __getattr__, __setattr__, and __delattr__ methods to enable accessing dictionary keys as attributes.
Original work: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.
Original source: https://github.com/NVlabs/edm
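The attribute-access pattern described above can be sketched in a few lines; EasyDictSketch is an illustrative stand-in, not the class shipped in rtnn.utils:

```python
class EasyDictSketch(dict):
    """Minimal dict subclass with attribute-style access to its keys."""

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails, so real
        # dict methods and attributes are untouched.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]

cfg = EasyDictSketch(lr=0.001)
cfg.epochs = 10          # attribute writes land in the dict
```

Raising AttributeError (rather than letting KeyError escape) for missing keys keeps hasattr() and getattr() with a default working as expected.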
- class rtnn.utils.FileUtils[source]
Bases:
object
Utility class for file and directory operations.
- __init__()[source]
Initialize the FileUtils class. This class does not maintain any state, so the constructor is empty.
rtnn.version module
Version information for rtnn.