Source code for rtnn.model_loader

from rtnn.models.rnn import RNN_LSTM, RNN_GRU
from rtnn.models.Transformer import Encoder
from rtnn.models.fcn import FCN


def load_model(args):
    """Load and initialize a model based on the provided configuration.

    This function acts as a factory that instantiates the appropriate model
    architecture based on the ``type`` argument. Supported models:

    - LSTM: Bidirectional LSTM with Conv1d output projection
    - GRU: Bidirectional GRU with Conv1d output projection
    - Transformer: Transformer encoder with positional embeddings
    - FCN/fullyconnected: Fully connected network with configurable depth

    Parameters
    ----------
    args : argparse.Namespace
        Namespace containing model configuration parameters. Required
        attributes depend on the model type:

        For LSTM/GRU:
            - type : str ('lstm' or 'gru')
            - feature_channel : int
            - output_channel : int
            - hidden_size : int
            - num_layers : int

        For Transformer:
            - type : str ('transformer')
            - feature_channel : int
            - output_channel : int
            - embed_size : int
            - num_layers : int
            - nhead : int
            - forward_expansion : int (falls back to 1 when None)
            - seq_length : int
            - dropout : float

        For FCN/fullyconnected:
            - type : str ('fcn' or 'fullyconnected')
            - feature_channel : int
            - output_channel : int
            - num_layers : int
            - hidden_size : int
            - seq_length : int
            - dim_expand : int (optional, default 0)

    Returns
    -------
    torch.nn.Module
        Initialized PyTorch model of the specified architecture.

    Raises
    ------
    ValueError
        If the specified model type is not implemented.

    Examples
    --------
    >>> args = argparse.Namespace(
    ...     type='lstm',
    ...     feature_channel=6,
    ...     output_channel=4,
    ...     hidden_size=128,
    ...     num_layers=3,
    ... )
    >>> model = load_model(args)
    >>> print(type(model))
    <class 'rtnn.models.rnn.RNN_LSTM'>
    """
    # Case-insensitive dispatch on the requested architecture.
    model_type = args.type.lower()

    if model_type in ["lstm", "gru"]:
        model_class = RNN_LSTM if model_type == "lstm" else RNN_GRU
        model = model_class(
            feature_channel=args.feature_channel,
            output_channel=args.output_channel,
            hidden_size=args.hidden_size,
            num_layers=args.num_layers,
        )
    elif model_type == "transformer":
        model = Encoder(
            feature_channel=args.feature_channel,
            output_channel=args.output_channel,
            embed_size=args.embed_size,
            num_layers=args.num_layers,
            heads=args.nhead,
            # Guard against configs that leave forward_expansion unset (None).
            forward_expansion=args.forward_expansion
            if args.forward_expansion is not None
            else 1,
            seq_length=args.seq_length,
            dropout=args.dropout,
        )
    elif model_type in ["fcn", "fullyconnected"]:
        model = FCN(
            feature_channel=args.feature_channel,
            output_channel=args.output_channel,
            num_layers=args.num_layers,
            hidden_size=args.hidden_size,
            seq_length=args.seq_length,
            # Honor the documented optional dim_expand attribute instead of
            # hard-coding 0; default stays 0 when the attribute is absent.
            dim_expand=getattr(args, "dim_expand", 0),
        )
    else:
        raise ValueError(f"Model type '{args.type}' is not implemented.")
    return model