asteroid_filterbanks_encoder_decoder_time_frequency_transform.py

python

Demonstrates how to create a filterbank, use it to transform a sign

import torch
from asteroid_filterbanks import Encoder, Decoder, FreeFB

# 1. Define the filterbank (FreeFB is a learnable filterbank)
# n_filters: number of filters, kernel_size: length of each filter, stride: hop size
fb = FreeFB(n_filters=512, kernel_size=16, stride=8)

# 2. Create the Encoder and Decoder using the filterbank
# Both wrap the same fb object, so they share its filters
encoder = Encoder(fb)
decoder = Decoder(fb)

# 3. Create a dummy input signal (batch_size, channels, time)
# Note: a 3D input with a single channel yields a 3D encoded tensor
input_signal = torch.randn(1, 1, 16000)

# 4. Analysis: Time domain -> Time-frequency domain
spec = encoder(input_signal)
print(f"Encoded shape: {spec.shape}")  # [batch, n_filters, frames]

# 5. Synthesis (inverse transform): Time-frequency domain -> Time domain
reconstructed = decoder(spec)
print(f"Reconstructed shape: {reconstructed.shape}")  # [batch, channels, time]
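
The shape comments above leave the number of frames unstated. As a rough sketch, assuming the encoder behaves like an unpadded strided 1-D convolution and the decoder like the matching transposed convolution, the frame count and the decoded length can be worked out by hand. The helper names `encoded_frames` and `decoded_samples` are hypothetical and not part of asteroid_filterbanks.

```python
def encoded_frames(n_samples: int, kernel_size: int, stride: int) -> int:
    """Frame count of an unpadded strided 1-D convolution (assumed encoder behavior)."""
    if n_samples < kernel_size:
        raise ValueError("signal is shorter than one filter")
    return (n_samples - kernel_size) // stride + 1


def decoded_samples(n_frames: int, kernel_size: int, stride: int) -> int:
    """Output length of the matching unpadded transposed convolution (assumed decoder behavior)."""
    if n_frames < 1:
        raise ValueError("need at least one frame")
    return (n_frames - 1) * stride + kernel_size


# Same settings as the snippet above: 16000 samples, kernel_size=16, stride=8
frames = encoded_frames(16000, kernel_size=16, stride=8)
samples = decoded_samples(frames, kernel_size=16, stride=8)
print(frames, samples)  # 1999 16000
```

Under these assumptions the round trip preserves the length only when `n_samples - kernel_size` is a multiple of `stride`. A 16003-sample input, for example, would still give 1999 frames and decode back to 16000 samples, so the trailing 3 samples would be lost unless the signal is padded first.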