Source code for rtnn.models.DimChangeModule

import torch
import torch.nn as nn


class DimChange(nn.Module):
    """Pointwise projection of the last axis of a 3-D tensor.

    Maps an input of shape ``(B, V, channel_number)`` to an output of shape
    ``(B, V, output_number)``. The projection is a kernel-size-1
    ``nn.Conv1d``, so the same affine map (weight + bias) is applied
    independently at every index along the middle axis ``V``.

    Args:
        channel_number: Size of the input's last axis (Conv1d in_channels).
        output_number: Size of the output's last axis (Conv1d out_channels).
    """

    def __init__(self, channel_number, output_number):
        super().__init__()
        self.channel_number = channel_number
        self.output_number = output_number
        # kernel_size=1 / stride=1 / padding=0 makes this a per-position
        # affine map. It is kept as a Conv1d (rather than nn.Linear) so the
        # parameter names and shapes -- conv_final.weight of shape
        # (output_number, channel_number, 1) -- stay checkpoint-compatible.
        self.conv_final = nn.Conv1d(
            channel_number,
            output_number,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

    def forward(self, x):
        """Project the last axis of ``x`` from ``channel_number`` to ``output_number``.

        Args:
            x: Tensor of shape ``(B, V, channel_number)``.

        Returns:
            Tensor of shape ``(B, V, output_number)``.
        """
        # Conv1d convolves over the last axis and treats dim 1 as channels,
        # so move the feature axis into dim 1: (B, V, H) -> (B, H, V).
        x = torch.permute(x, (0, 2, 1))
        x = self.conv_final(x)
        # Restore the original layout: (B, out, V) -> (B, V, out).
        return torch.permute(x, (0, 2, 1))