Skip to content
Snippets Groups Projects
Commit 4dd58fc4 authored by chinyun_yu_joey's avatar chinyun_yu_joey
Browse files

add unet

parent 708de4a6
No related branches found
No related tags found
No related merge requests found
#!/usr/bin/env python
import norbert
import numpy as np
import soundfile as sf
import torch
from torch_specinv import griffin_lim
from evaluator.music_demixing import MusicDemixingPredictor
from remote import download_large_file_from_google_drive
def stft(x, n_fft=4096, n_hopsize=1024):
window = torch.hann_window(n_fft, dtype=x.dtype, device=x.device)
X = torch.stft(
x,
n_fft,
n_hopsize,
n_fft,
window,
return_complex=True
)
return X
def istft(X, n_fft=4096, n_hopsize=1024):
dtype = X.dtype
if dtype == torch.complex32:
dtype = torch.float16
elif dtype == torch.complex64:
dtype = torch.float32
elif dtype == torch.complex128:
dtype = torch.float64
window = torch.hann_window(n_fft, dtype=dtype, device=X.device)
x = torch.istft(
X,
n_fft,
n_hopsize,
n_fft,
window,
)
return x
# Separation function - taken from
# https://github.com/asteroid-team/asteroid/blob/master/egs/musdb18/X-UMX/eval.py
def separate(
audio,
x_umx_target,
niter=1,
softmask=False,
alpha=1.0,
residual_model=False,
device="cpu",
):
"""
Performing the separation on audio input
Parameters
----------
audio: np.ndarray [shape=(nb_samples, nb_channels, nb_timesteps)]
mixture audio
x_umx_target: asteroid.models
X-UMX model used for separating
instruments: list
The list of instruments, e.g., ["bass", "drums", "vocals"]
niter: int
Number of EM steps for refining initial estimates in a
post-processing stage, defaults to 1.
softmask: boolean
if activated, then the initial estimates for the sources will
be obtained through a ratio mask of the mixture STFT, and not
by using the default behavior of reconstructing waveforms
by using the mixture phase, defaults to False
alpha: float
changes the exponent to use for building ratio masks, defaults to 1.0
residual_model: boolean
computes a residual target, for custom separation scenarios
when not all targets are available, defaults to False
device: str
set torch device. Defaults to `cpu`.
Returns
-------
estimates: `dict` [`str`, `np.ndarray`]
dictionary with all estimates obtained by the separation model.
"""
# convert numpy audio to torch
audio_torch = torch.tensor(audio.T).float().to(device)
X = stft(audio_torch)
with torch.no_grad():
masked_tf_rep = x_umx_target(X.abs().unsqueeze(0)).squeeze()
Y = X * masked_tf_rep
estimate = istft(Y)
estimates = estimate.unsqueeze(0).expand(4, -1, -1).numpy()
return estimates
class XUMXPredictor(MusicDemixingPredictor):
def prediction_setup(self):
# Load your model here and put it into `evaluation` mode
model_path, _ = download_large_file_from_google_drive(
"1GemMK3ETH5jJV-hnFfAyRCGrN2gDe5bo",
save_path="./models", save_name="unet.pth"
)
self.separator = torch.jit.load(model_path)
self.separator.eval()
def prediction(
self,
mixture_file_path,
bass_file_path,
drums_file_path,
other_file_path,
vocals_file_path,
):
# Step 1: Load mixture
# mixture is stereo with sample rate of 44.1kHz
x, rate = sf.read(mixture_file_path)
# Step 2: Pad mixture to compensate STFT truncation
x_padded = np.pad(x, ((0, 1024), (0, 0)))
# Step 3: Perform separation
estimates = separate(
x_padded,
self.separator,
)
# Step 4: Truncate to orignal length
estimates = estimates[..., :x.shape[0]]
# Step 5: Store results
target_file_map = {
"vocals": vocals_file_path,
"drums": drums_file_path,
"bass": bass_file_path,
"other": other_file_path,
}
for i, target in enumerate(['drums', 'bass', 'other', 'vocals']):
path = target_file_map[target]
sf.write(
path,
estimates[i].T,
rate
)
if __name__ == "__main__":
submission = XUMXPredictor()
submission.run()
print("Successfully generated predictions!")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment