Skip to content
Snippets Groups Projects
Commit 20b31367 authored by Chin-Yun Yu's avatar Chin-Yun Yu
Browse files

try wiener filtering

parent 1b41576a
No related branches found
No related tags found
No related merge requests found
import numpy as np
import torch
import torch.nn.functional as F
from importlib import import_module
from .identity_separation_model import IdentitySeparationModel
import os
import yaml
from aimless.utils import MWF
from torchaudio.transforms import Spectrogram, InverseSpectrogram
class HDemucs(IdentitySeparationModel):
...@@ -12,7 +15,7 @@ class HDemucs(IdentitySeparationModel):
    """
    version_path = "./my_submission/lightning_logs/hdemucs-64/"
    ckpt_name = "epoch=2-step=40000.ckpt"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    instruments_idx = {
        "dialog": 2,
...@@ -72,3 +75,51 @@ class HDemucs(IdentitySeparationModel):
            output_sample_rates[instrument] = sample_rate
        return separated_music_arrays, output_sample_rates
class HDemucsMWF(HDemucs):
    """HDemucs variant that refines the network's waveform estimates with a
    multichannel Wiener filter (MWF) applied in the STFT domain.

    Model loading, ``instruments`` / ``instruments_idx`` and ``device``
    handling are inherited from :class:`HDemucs`; only the separation step
    is overridden.
    """

    # STFT analysis/synthesis parameters shared by ``spec`` and ``inv_spec``.
    n_fft = 4096
    hop_length = 1024

    def __init__(self):
        super().__init__()
        # Complex-valued STFT (power=None returns complex output, so phase
        # is preserved for the inverse transform).
        self.spec = Spectrogram(
            n_fft=self.n_fft, hop_length=self.hop_length, power=None
        ).to(self.device)
        self.inv_spec = InverseSpectrogram(
            n_fft=self.n_fft, hop_length=self.hop_length
        ).to(self.device)
        # Multichannel Wiener filter; softmask=True requests soft-mask
        # blending of sources rather than a hard assignment.
        self.mwf = MWF(softmask=True).to(self.device)

    @torch.no_grad()
    def separate_music_file(self, mixed_sound_array, sample_rate):
        """Separate a mixture into per-instrument stems via HDemucs + MWF.

        Parameters
        ----------
        mixed_sound_array : np.ndarray
            Mixture audio; indexed as (time, channels) here — the ``.T``
            below turns it into (channels, time).
        sample_rate : int
            Sample rate of the input; passed through unchanged per stem.

        Returns
        -------
        (dict[str, np.ndarray], dict[str, int])
            Per-instrument (time, channels) arrays and their sample rates.
        """
        # (T, C) numpy -> (C, 1, T) tensor: each audio channel becomes its
        # own batch item with a single channel dimension.
        mixed_sound_array = (
            torch.from_numpy(mixed_sound_array.T).float().to(self.device).unsqueeze(1)
        )
        # B, C, S, T
        # NOTE(review): self.hdemucs is presumably built in HDemucs.__init__
        # — not visible here; output assumed (batch, sources, channels, time)
        # before the transpose. TODO confirm against the parent class.
        seperated = self.hdemucs(mixed_sound_array).transpose(0, 2)
        # B, S, C, T
        # Magnitude spectrograms of the estimates serve as the MWF's
        # per-source "mask" input.
        mask_hat = self.spec(seperated).abs()
        # B, S, C, F, T
        # Complex mixture spectrogram, batched as (1, C, F, T) via the
        # (C, 1, T) -> (1, C, T) transpose.
        mix_spec = self.spec(mixed_sound_array.transpose(0, 1))
        # B, C, F, T
        # Wiener-filter the mixture, invert to waveform, drop singleton dims.
        seperated = self.inv_spec(self.mwf(mask_hat, mix_spec)).squeeze().cpu()
        # The inverse STFT can come up a few samples short of the input
        # length; zero-pad the tail so stem length matches the mixture.
        if seperated.shape[-1] < mixed_sound_array.shape[-1]:
            seperated = F.pad(
                seperated, (0, mixed_sound_array.shape[-1] - seperated.shape[-1])
            )
        # (S, C, T) -> (S, T, C) numpy, matching the expected output layout.
        seperated = seperated.transpose(1, 2).numpy()
        # input_length = len(left_mixed_arr)
        separated_music_arrays = {}
        output_sample_rates = {}
        # Map each named instrument to its source index in the stacked output.
        for instrument in self.instruments:
            separated_music_arrays[instrument] = seperated[
                self.instruments_idx[instrument]
            ]
            output_sample_rates[instrument] = sample_rate
        return separated_music_arrays, output_sample_rates
...@@ -5,6 +5,6 @@
# MySeparationModel = ScaledIdentitySeparationModel
# from my_submission.cocktail_fork_separation_model import CocktailForkSeparationModel
from my_submission.hdemucs import HDemucs, HDemucsMWF

MySeparationModel = HDemucsMWF
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment