Source code for rtnn.models.DimChangeModule
import torch
import torch.nn as nn
[docs]
class DimChange(nn.Module):
    """Change the size of the last dimension with a pointwise 1x1 convolution.

    Shape: ``B*V*H -> B*V*H_out`` where ``H == channel_number`` and
    ``H_out == output_number``. An unbatched ``V*H`` input is also accepted
    and yields ``V*H_out``.

    The last dimension is treated as the channel axis. The same learned
    affine map (``nn.Conv1d`` with ``kernel_size=1`` and a bias) is applied
    independently at every position along the second-to-last dimension.
    ``H_out`` may be larger or smaller than ``H``.

    Args:
        channel_number: Size of the input's last dimension.
        output_number: Size of the output's last dimension.
    """

    def __init__(self, channel_number, output_number):
        super(DimChange, self).__init__()
        self.channel_number = channel_number
        self.output_number = output_number
        # A kernel_size=1 Conv1d mixes only the channel axis. It acts as a
        # per-position linear layer. Kept as a Conv1d so parameter shapes and
        # state_dict keys stay compatible with existing checkpoints.
        self.conv_final = nn.Conv1d(
            channel_number, output_number, kernel_size=1, stride=1, padding=0, bias=True
        )

    def forward(self, x):
        """Map ``(..., V, channel_number)`` to ``(..., V, output_number)``.

        Args:
            x: Tensor of shape ``(B, V, channel_number)`` or, unbatched,
                ``(V, channel_number)``.

        Returns:
            Tensor of shape ``(B, V, output_number)`` (or
            ``(V, output_number)`` for unbatched input).
        """
        # Conv1d expects channels before the length axis, so swap the last
        # two dims. For 3-D input this equals permute(0, 2, 1). Unlike the
        # permute, it also works for the unbatched 2-D form Conv1d accepts.
        x = x.transpose(-1, -2)
        x = self.conv_final(x)
        # Restore the original axis order: channels last again.
        return x.transpose(-1, -2)