# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Kazem Ardaneh
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/
import torch
import torch.nn as nn
from typing import Optional
[docs]
class SelfAttention(nn.Module):
    """
    Multi-head self-attention mechanism.

    This module implements scaled dot-product attention with multiple heads.
    The input embedding is split into ``heads`` chunks of size ``head_dim``
    and the same ``head_dim -> head_dim`` linear projections are applied to
    every head.

    Parameters
    ----------
    embed_size : int
        Size of the embedding dimension.
    heads : int
        Number of attention heads.

    Attributes
    ----------
    embed_size : int
        Embedding dimension.
    heads : int
        Number of attention heads.
    head_dim : int
        Dimension of each attention head (embed_size // heads).
    values : nn.Linear
        Linear layer for value projections.
    keys : nn.Linear
        Linear layer for key projections.
    queries : nn.Linear
        Linear layer for query projections.
    fc_out : nn.Linear
        Final output linear layer.

    Examples
    --------
    >>> attention = SelfAttention(embed_size=128, heads=4)
    >>> values = torch.randn(32, 10, 128)
    >>> keys = torch.randn(32, 10, 128)
    >>> query = torch.randn(32, 10, 128)
    >>> out = attention(values, keys, query, mask=None)
    >>> out.shape
    torch.Size([32, 10, 128])
    """

    def __init__(self, embed_size: int, heads: int) -> None:
        """
        Initialize the SelfAttention module.

        Parameters
        ----------
        embed_size : int
            Size of the embedding dimension.
        heads : int
            Number of attention heads.

        Raises
        ------
        AssertionError
            If embed_size is not divisible by heads.
        """
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        # Raised explicitly (instead of via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is kept so existing
        # callers that catch AssertionError keep working.
        if self.head_dim * heads != embed_size:
            raise AssertionError("Embedding size needs to be divisible by heads")

        # Per-head projections: the same weights are applied to every head.
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=True)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=True)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=True)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(
        self,
        values: torch.Tensor,
        keys: torch.Tensor,
        query: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass through the self-attention mechanism.

        Parameters
        ----------
        values : torch.Tensor
            Value tensor of shape (batch_size, value_len, embed_size).
        keys : torch.Tensor
            Key tensor of shape (batch_size, key_len, embed_size).
        query : torch.Tensor
            Query tensor of shape (batch_size, query_len, embed_size).
        mask : torch.Tensor, optional
            Attention mask of shape (batch_size, 1, 1, key_len) or
            (batch_size, query_len, key_len). Entries equal to 0 mark
            positions that must not be attended to. A 3-D mask gets a
            singleton heads axis inserted so it is shared by all heads.
            Default is None.

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, query_len, embed_size).

        Notes
        -----
        The attention mechanism computes:

            Attention(Q, K, V) = softmax(QK^T / sqrt(embed_size)) V

        The scale uses the full ``embed_size`` rather than the per-head
        dimension; this is kept unchanged so that existing trained weights
        produce the same outputs.
        """
        batch_size = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split embedding into heads: (N, len, embed) -> (N, len, heads, head_dim)
        values = values.reshape(batch_size, value_len, self.heads, self.head_dim)
        keys = keys.reshape(batch_size, key_len, self.heads, self.head_dim)
        query = query.reshape(batch_size, query_len, self.heads, self.head_dim)

        # Apply linear projections on the last (head_dim) axis
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(query)

        # Attention scores, shape (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
        energy = energy / (self.embed_size ** 0.5)

        if mask is not None:
            if mask.dim() == 3:
                # (N, query_len, key_len) -> (N, 1, query_len, key_len); without
                # this the batch axis would be broadcast against the heads axis.
                mask = mask.unsqueeze(1)
            # Use the most negative finite value of the score dtype so the fill
            # does not overflow for reduced-precision (e.g. float16) scores.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        # Normalize over the key axis
        attention = torch.softmax(energy, dim=3)

        # Weighted sum of values, then merge heads back into embed_size
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            batch_size, query_len, self.heads * self.head_dim
        )

        # Final projection
        return self.fc_out(out)
[docs]
class Encoder(nn.Module):
    """
    Transformer encoder for sequence-to-sequence processing.

    This module projects the input features to an embedding, adds learned
    positional embeddings, applies a stack of transformer blocks and maps
    the result to the output channels with a kernel-size-1 convolution.

    Parameters
    ----------
    feature_channel : int
        Number of input features.
    output_channel : int
        Number of output channels.
    embed_size : int
        Size of the embedding dimension.
    num_layers : int
        Number of transformer blocks.
    heads : int
        Number of attention heads.
    forward_expansion : int
        Expansion factor for feed-forward networks.
    seq_length : int
        Maximum length of the input sequence (size of the positional
        embedding table).
    dropout : float
        Dropout rate.

    Attributes
    ----------
    embed_size : int
        Embedding dimension.
    seq_length : int
        Maximum input sequence length.
    first : nn.Linear
        Initial linear projection.
    first_act : nn.ReLU
        Activation function.
    position_embedding : nn.Embedding
        Positional embeddings.
    layers : nn.ModuleList
        Stack of transformer blocks.
    dropout : nn.Dropout
        Dropout layer. It is created here but not applied in ``forward``.
    final : nn.Conv1d
        Final convolution to map to output channels.

    Examples
    --------
    >>> encoder = Encoder(
    ...     feature_channel=6,
    ...     output_channel=4,
    ...     embed_size=64,
    ...     num_layers=2,
    ...     heads=4,
    ...     forward_expansion=4,
    ...     seq_length=10,
    ...     dropout=0.1
    ... )
    >>> x = torch.randn(32, 6, 10)
    >>> y = encoder(x)
    >>> y.shape
    torch.Size([32, 4, 10])
    """

    def __init__(
        self,
        feature_channel: int,
        output_channel: int,
        embed_size: int,
        num_layers: int,
        heads: int,
        forward_expansion: int,
        seq_length: int,
        dropout: float,
    ) -> None:
        """
        Initialize the Transformer encoder.

        Parameters
        ----------
        feature_channel : int
            Number of input features.
        output_channel : int
            Number of output channels.
        embed_size : int
            Size of the embedding dimension.
        num_layers : int
            Number of transformer blocks.
        heads : int
            Number of attention heads.
        forward_expansion : int
            Expansion factor for feed-forward networks.
        seq_length : int
            Maximum length of the input sequence.
        dropout : float
            Dropout rate.

        Raises
        ------
        ValueError
            If num_layers is less than 1.
        """
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be at least 1, got {num_layers}")

        self.embed_size = embed_size
        self.seq_length = seq_length

        # Initial projection from features to embeddings
        self.first = nn.Linear(feature_channel, embed_size)
        self.first_act = nn.ReLU()

        # Learned positional embeddings, one row per sequence position
        self.position_embedding = nn.Embedding(seq_length, embed_size)

        # Stack of transformer blocks (TransformerBlock is defined outside
        # this class)
        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                )
                for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)

        # Final projection to output channels (1x1 convolution over the
        # sequence axis)
        self.final = nn.Conv1d(
            embed_size, output_channel, kernel_size=1, padding=0, bias=True
        )

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Forward pass through the transformer encoder.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, feature_channel, seq_len),
            with ``seq_len <= seq_length``.
        mask : torch.Tensor, optional
            Attention mask forwarded unchanged to every transformer block.
            Default is None.

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, output_channel, seq_len).

        Raises
        ------
        ValueError
            If the input sequence is longer than ``seq_length``.

        Notes
        -----
        The forward pass:
        1. Permutes input to (batch, seq, features)
        2. Applies linear projection and ReLU to get embeddings
        3. Adds positional embeddings
        4. Passes through transformer blocks
        5. Permutes back and applies final convolution
        """
        # Permute to (batch, seq, features) for transformer
        x = torch.permute(x, (0, 2, 1))
        seq_len = x.shape[1]
        if seq_len > self.seq_length:
            raise ValueError(
                f"input sequence length {seq_len} exceeds the configured "
                f"seq_length {self.seq_length}"
            )

        # Position indices follow the actual input length and are created
        # directly on the input's device (no CPU allocation + copy).
        positions = torch.arange(seq_len, device=x.device)
        pos_embed = self.position_embedding(positions)

        # Initial projection; (seq, embed) positional term broadcasts over batch
        out = self.first_act(self.first(x)) + pos_embed

        # Apply transformer blocks, using the same tensor for all three inputs
        for layer in self.layers:
            out = layer(out, out, out, mask)

        # Permute back to (batch, embed, seq) and apply final convolution
        out = torch.permute(out, (0, 2, 1))
        return self.final(out)
[docs]
class EncoderTorch(nn.Module):
    """
    Transformer encoder built on ``torch.nn.TransformerEncoder``.

    The input features are linearly projected to ``embed_size``, summed with
    learned positional embeddings, passed through dropout and a stack of
    ``nn.TransformerEncoderLayer`` modules, and finally mapped to
    ``output_channel`` channels by a kernel-size-1 ``nn.Conv1d``.

    Parameters
    ----------
    feature_channel : int
        Number of input features.
    output_channel : int
        Number of output channels.
    embed_size : int
        Size of the embedding dimension (``d_model``).
    num_layers : int
        Number of encoder layers.
    heads : int
        Number of attention heads.
    forward_expansion : int
        Feed-forward width as a multiple of ``embed_size``.
    seq_length : int
        Number of rows in the positional embedding table.
    dropout : float
        Dropout rate.
    """

    def __init__(
        self,
        feature_channel: int,
        output_channel: int,
        embed_size: int,
        num_layers: int,
        heads: int,
        forward_expansion: int,
        seq_length: int,
        dropout: float,
    ) -> None:
        """
        Build the projection, positional embedding, encoder stack and head.

        Raises
        ------
        ValueError
            If num_layers is less than 1.
        """
        super().__init__()

        if num_layers < 1:
            raise ValueError(f"num_layers must be at least 1, got {num_layers}")

        self.embed_size = embed_size
        self.seq_length = seq_length

        # Features -> embedding
        self.input_proj = nn.Linear(feature_channel, embed_size)

        # Learned positional table
        self.position_embedding = nn.Embedding(seq_length, embed_size)

        # One batch-first, post-norm (norm_first=False) ReLU layer, replicated
        # num_layers times by nn.TransformerEncoder.
        layer_options = dict(
            d_model=embed_size,
            nhead=heads,
            dim_feedforward=forward_expansion * embed_size,
            dropout=dropout,
            activation="relu",
            batch_first=True,
            norm_first=False,
        )
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_options),
            num_layers=num_layers,
        )

        self.dropout = nn.Dropout(dropout)

        # Embedding -> output channels via a 1x1 convolution
        self.final = nn.Conv1d(embed_size, output_channel, kernel_size=1)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run the encoder.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (batch, feature_channel, seq_len).
        mask : torch.Tensor, optional
            Passed to ``nn.TransformerEncoder`` as ``mask``.
        src_key_padding_mask : torch.Tensor, optional
            Passed to ``nn.TransformerEncoder`` as ``src_key_padding_mask``.

        Returns
        -------
        torch.Tensor
            Output of shape (batch, output_channel, seq_len).
        """
        # (batch, feature, seq) -> (batch, seq, feature)
        tokens = torch.permute(x, (0, 2, 1))
        batch, length, _ = tokens.shape

        # Position indices 0..length-1, replicated for every batch element
        index = torch.arange(length, device=tokens.device)[None, :]
        index = index.expand(batch, length)

        # Projected features plus positional embedding, then dropout
        hidden = self.input_proj(tokens) + self.position_embedding(index)
        hidden = self.dropout(hidden)

        # Encoder stack
        hidden = self.encoder(
            hidden, mask=mask, src_key_padding_mask=src_key_padding_mask
        )

        # Back to (batch, channels, seq) for the output convolution
        return self.final(torch.permute(hidden, (0, 2, 1)))