# Source code for rtnn.models.rnn

# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Kazem Ardaneh
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/

import torch
import torch.nn as nn
from typing import Tuple, Union


class BaseRNN(nn.Module):
    """
    Bidirectional recurrent backbone shared by the LSTM and GRU models.

    The input sequence is processed by a stacked bidirectional RNN
    (LSTM or GRU), and the concatenated forward/backward hidden states
    are projected to the requested number of output channels with a
    kernel-size-1 ``nn.Conv1d``.

    Parameters
    ----------
    feature_channel : int
        Number of input features per time step.
    output_channel : int
        Number of output channels (target variables).
    hidden_size : int
        Number of hidden units in each RNN layer.
    num_layers : int
        Number of stacked RNN layers.
    rnn_type : str
        Recurrent cell type, either 'lstm' or 'gru'.

    Attributes
    ----------
    rnn : nn.LSTM or nn.GRU
        The bidirectional recurrent layer.
    final : nn.Conv1d
        1x1 convolution projecting hidden states to output channels.
    hidden_size : int
        Number of hidden units.
    num_layers : int
        Number of stacked layers.
    output_channel : int
        Number of output channels.

    Examples
    --------
    >>> model = BaseRNN(
    ...     feature_channel=6,
    ...     output_channel=4,
    ...     hidden_size=64,
    ...     num_layers=2,
    ...     rnn_type='lstm'
    ... )
    >>> x = torch.randn(32, 6, 10)  # (batch, features, sequence)
    >>> y = model(x)
    >>> y.shape
    torch.Size([32, 4, 10])
    """

    def __init__(
        self,
        feature_channel: int,
        output_channel: int,
        hidden_size: int,
        num_layers: int,
        rnn_type: str,
    ) -> None:
        """
        Build the recurrent layer and the output projection.

        Parameters
        ----------
        feature_channel : int
            Number of input features.
        output_channel : int
            Number of output channels.
        hidden_size : int
            Size of the hidden state.
        num_layers : int
            Number of RNN layers.
        rnn_type : str
            Cell type ('lstm' or 'gru').

        Raises
        ------
        ValueError
            If rnn_type is not 'lstm' or 'gru'.
        """
        super().__init__()

        # Reject anything other than the two supported cell types early.
        if rnn_type not in ("lstm", "gru"):
            raise ValueError(f"rnn_type must be 'lstm' or 'gru', got {rnn_type}")

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_channel = output_channel
        self.rnn_type = rnn_type

        # Dispatch table instead of a conditional for the cell class.
        cell_cls = {"lstm": nn.LSTM, "gru": nn.GRU}[rnn_type]
        self.rnn = cell_cls(
            input_size=feature_channel,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

        # Forward and backward directions are concatenated, hence the
        # factor of two on the convolution's input channels.
        self.final = nn.Conv1d(
            in_channels=2 * hidden_size,
            out_channels=output_channel,
            kernel_size=1,
            padding=0,
            bias=True,
        )

    def init_hidden(
        self, batch_size: int, device: torch.device
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Create a zero-valued initial state for the recurrent layer.

        Parameters
        ----------
        batch_size : int
            Batch size of the incoming input.
        device : torch.device
            Device on which to allocate the state tensors.

        Returns
        -------
        torch.Tensor or tuple of torch.Tensor
            GRU: one tensor of shape
            (2 * num_layers, batch_size, hidden_size).
            LSTM: a (hidden, cell) pair of tensors of that same shape.
        """
        state_shape = (2 * self.num_layers, batch_size, self.hidden_size)
        h0 = torch.zeros(*state_shape, device=device, requires_grad=False)
        if self.rnn_type == "gru":
            return h0
        # LSTM needs a separate cell state; clone so the two tensors
        # do not share storage.
        return (h0, h0.clone())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the input through the bidirectional RNN and project it.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (batch_size, feature_channel, seq_length).

        Returns
        -------
        torch.Tensor
            Output of shape (batch_size, output_channel, seq_length).

        Notes
        -----
        The RNN expects (batch, seq, features), while the convolution
        expects (batch, channels, seq); the tensor is permuted on the
        way in and back again on the way out.
        """
        seq_major = x.permute(0, 2, 1)  # (batch, seq, features) for the RNN
        state = self.init_hidden(seq_major.size(0), seq_major.device)
        rnn_out, _ = self.rnn(seq_major, state)
        # Back to (batch, channels, seq) before the 1x1 projection.
        return self.final(rnn_out.permute(0, 2, 1))
class RNN_LSTM(BaseRNN):
    """
    Bidirectional LSTM model.

    Thin convenience subclass of :class:`BaseRNN` that fixes the
    recurrent cell type to LSTM.

    Parameters
    ----------
    feature_channel : int
        Number of input features.
    output_channel : int
        Number of output channels.
    hidden_size : int
        Size of the hidden state.
    num_layers : int
        Number of LSTM layers.

    Examples
    --------
    >>> model = RNN_LSTM(
    ...     feature_channel=6,
    ...     output_channel=4,
    ...     hidden_size=128,
    ...     num_layers=3
    ... )
    >>> x = torch.randn(16, 6, 10)
    >>> y = model(x)
    >>> print(y.shape)
    torch.Size([16, 4, 10])
    """

    def __init__(
        self,
        feature_channel: int,
        output_channel: int,
        hidden_size: int,
        num_layers: int,
    ) -> None:
        """
        Configure the base class with LSTM cells.

        Parameters
        ----------
        feature_channel : int
            Number of input features.
        output_channel : int
            Number of output channels.
        hidden_size : int
            Size of the hidden state.
        num_layers : int
            Number of LSTM layers.
        """
        # Explicit keywords make the forwarded configuration obvious.
        super().__init__(
            feature_channel=feature_channel,
            output_channel=output_channel,
            hidden_size=hidden_size,
            num_layers=num_layers,
            rnn_type="lstm",
        )
class RNN_GRU(BaseRNN):
    """
    Bidirectional GRU model.

    Thin convenience subclass of :class:`BaseRNN` that fixes the
    recurrent cell type to GRU.

    Parameters
    ----------
    feature_channel : int
        Number of input features.
    output_channel : int
        Number of output channels.
    hidden_size : int
        Size of the hidden state.
    num_layers : int
        Number of GRU layers.

    Examples
    --------
    >>> model = RNN_GRU(
    ...     feature_channel=6,
    ...     output_channel=4,
    ...     hidden_size=128,
    ...     num_layers=3
    ... )
    >>> x = torch.randn(16, 6, 10)
    >>> y = model(x)
    >>> print(y.shape)
    torch.Size([16, 4, 10])
    """

    def __init__(
        self,
        feature_channel: int,
        output_channel: int,
        hidden_size: int,
        num_layers: int,
    ) -> None:
        """
        Configure the base class with GRU cells.

        Parameters
        ----------
        feature_channel : int
            Number of input features.
        output_channel : int
            Number of output channels.
        hidden_size : int
            Size of the hidden state.
        num_layers : int
            Number of GRU layers.
        """
        # Explicit keywords make the forwarded configuration obvious.
        super().__init__(
            feature_channel=feature_channel,
            output_channel=output_channel,
            hidden_size=hidden_size,
            num_layers=num_layers,
            rnn_type="gru",
        )