Skip to content
Snippets Groups Projects
Commit 09f2e495 authored by Chin-Yun Yu's avatar Chin-Yun Yu
Browse files

feat: speech rnn

parent c840afe9
No related branches found
No related tags found
No related merge requests found
......@@ -36,7 +36,8 @@ class EnsembleNet(IdentitySeparationModel):
"""
hdemucs_path = "./my_submission/lightning_logs/hdemucs-64-sdr/"
bandsplitRNN_path = "./my_submission/lightning_logs/bandsplitRNN-music/"
music_bandsplitRNN_path = "./my_submission/lightning_logs/bandsplitRNN-music/"
speech_bandsplitRNN_path = "./my_submission/lightning_logs/bandsplitRNN-speech/"
hdemucs_ckpt_name = "last.ckpt"
bandsplitRNN_ckpt_name = "last.ckpt"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
......@@ -52,9 +53,14 @@ class EnsembleNet(IdentitySeparationModel):
self.hdemucs = model.to(self.device)
model, config = load_checkpoint(
self.bandsplitRNN_path, self.bandsplitRNN_ckpt_name
self.music_bandsplitRNN_path, self.bandsplitRNN_ckpt_name
)
self.bandsplitRNN = model.to(self.device)
self.rnn_music = model.to(self.device)
model, config = load_checkpoint(
self.speech_bandsplitRNN_path, self.bandsplitRNN_ckpt_name
)
self.rnn_speech = model.to(self.device)
n_fft = config["model"]["init_args"]["n_fft"]
hop_length = config["model"]["init_args"]["hop_length"]
......@@ -132,9 +138,11 @@ class EnsembleNet(IdentitySeparationModel):
# bandsplitRNN
X = self.spec(mixed_sound_array[:2].transpose(0, 1))
mask = self.bandsplitRNN(X.abs())
music_mask = self.rnn_music(X.abs())
speech_mask = self.rnn_speech(X.abs())
mask = torch.stack([music_mask, speech_mask], dim=1)
# speech_spec = X * mask
speech_spec = self.mwf(mask.unsqueeze(1), X)
speech_spec = self.mwf(mask, X)
separated = self.inv_spec(speech_spec).squeeze().cpu()
if separated.shape[-1] < mixed_sound_array.shape[-1]:
......@@ -142,12 +150,20 @@ class EnsembleNet(IdentitySeparationModel):
separated, (0, mixed_sound_array.shape[-1] - separated.shape[-1])
)
separated_music = separated.permute(0, 2, 1).numpy()[0]
separated_music, separated_speech, separated_fx = separated.permute(
0, 2, 1
).numpy()
# input_length = len(left_mixed_arr)
separated_music_arrays["music"] += separated_music
separated_music_arrays["music"] /= 2
separated_music_arrays["effect"] += separated_fx
separated_music_arrays["effect"] /= 2
separated_music_arrays["dialog"] = separated_speech
separated_music_arrays["dialog"] /= 2
return separated_music_arrays, output_sample_rates
......
File added
# pytorch_lightning==1.8.5.post0
seed_everything: 2434
trainer:
logger: true
enable_checkpointing: true
callbacks:
- class_path: pytorch_lightning.callbacks.ModelCheckpoint
init_args:
dirpath: null
filename: null
monitor: null
verbose: false
save_last: true
save_top_k: 1
save_weights_only: false
mode: min
auto_insert_metric_name: true
every_n_train_steps: 2000
train_time_interval: null
every_n_epochs: null
save_on_train_epoch_end: null
- class_path: pytorch_lightning.callbacks.ModelCheckpoint
init_args:
dirpath: null
filename: null
monitor: null
verbose: false
save_last: null
save_top_k: -1
save_weights_only: true
mode: min
auto_insert_metric_name: true
every_n_train_steps: 1000
train_time_interval: null
every_n_epochs: null
save_on_train_epoch_end: null
- class_path: pytorch_lightning.callbacks.ModelSummary
init_args:
max_depth: 2
default_root_dir: /import/c4dm-datasets-ext/ycy_sdx23_checkpoints/
gradient_clip_val: null
gradient_clip_algorithm: null
num_nodes: 1
num_processes: null
devices: null
gpus: null
auto_select_gpus: false
tpu_cores: null
ipus: null
enable_progress_bar: true
overfit_batches: 0.0
track_grad_norm: -1
check_val_every_n_epoch: 1
fast_dev_run: false
accumulate_grad_batches: null
max_epochs: null
min_epochs: null
max_steps: -1
min_steps: null
max_time: null
limit_train_batches: null
limit_val_batches: 0
limit_test_batches: null
limit_predict_batches: null
val_check_interval: null
log_every_n_steps: 1
accelerator: gpu
strategy: <pytorch_lightning.strategies.ddp.DDPStrategy object at 0x7f2d054b3ca0>
sync_batchnorm: false
precision: 32
enable_model_summary: true
num_sanity_val_steps: 0
resume_from_checkpoint: null
profiler: null
benchmark: null
deterministic: null
reload_dataloaders_every_n_epochs: 0
auto_lr_find: false
replace_sampler_ddp: true
detect_anomaly: false
auto_scale_batch_size: false
plugins: null
amp_backend: native
amp_level: null
move_metrics_to_cpu: false
multiple_trainloader_mode: max_size_cycle
inference_mode: true
ckpt_path: null
model:
class_path: aimless.lightning.freq_mask.MaskPredictor
init_args:
model:
class_path: aimless.models.band_split_rnn.BandSplitRNN
init_args:
n_fft: 4096
split_freqs:
- 100
- 200
- 300
- 400
- 500
- 600
- 700
- 800
- 900
- 1000
- 1250
- 1500
- 1750
- 2000
- 2250
- 2500
- 2750
- 3000
- 3250
- 3500
- 3750
- 4000
- 4500
- 5000
- 5500
- 6000
- 6500
- 7000
- 7500
- 8000
- 9000
- 10000
- 11000
- 12000
- 13000
- 14000
- 15000
- 16000
- 18000
- 20000
- 22000
hidden_size: 128
num_layers: 12
norm_groups: 4
criterion:
class_path: aimless.loss.freq.MDLoss
init_args:
mcoeff: 10
n_fft: 4096
hop_length: 1024
transforms:
- class_path: aimless.augment.SpeedPerturb
init_args:
orig_freq: 44100
speeds:
- 90
- 100
- 110
p: 0.2
- class_path: aimless.augment.RandomPitch
init_args:
semitones:
- -1
- 1
- 0
- 1
- 2
n_fft: 2048
hop_length: 512
p: 0.2
target_track: sdx
targets:
speech: null
n_fft: 4096
hop_length: 1024
residual_model: true
softmask: false
alpha: 1.0
n_iter: 1
data:
class_path: data.lightning.DnR
init_args:
root: /import/c4dm-datasets-ext/sdx-2023/dnr_v2/dnr_v2/
seq_duration: 3.0
samples_per_track: 144
random: true
include_val: true
random_track_mix: true
transforms:
- class_path: data.augment.RandomGain
init_args:
low: 0.25
high: 1.25
p: 1.0
- class_path: data.augment.RandomFlipPhase
init_args:
p: 0.5
- class_path: data.augment.RandomSwapLR
init_args:
p: 0.5
batch_size: 16
optimizer:
class_path: torch.optim.Adam
init_args:
lr: 0.0003
betas:
- 0.9
- 0.999
eps: 1.0e-08
weight_decay: 0
amsgrad: false
foreach: null
maximize: false
capturable: false
differentiable: false
fused: false
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment