Forked from
nethack / neurips-2021-the-nethack-challenge
120 commits behind the upstream repository.
-
Shivam Khandelwal authored
music_demixing.py 7.02 KiB
######################################################################################
### This is a read-only file to allow participants to run their code locally. ###
### It will be overwritten during the evaluation. Please do not make any changes ###
### to this file. ###
######################################################################################
import traceback
import os
import signal
from contextlib import contextmanager
from os import listdir
from os.path import isfile, join
import soundfile as sf
import numpy as np
from evaluator import aicrowd_helpers
class TimeoutException(Exception):
    """Raised by ``time_limit`` when the guarded code exceeds its time budget."""


@contextmanager
def time_limit(seconds):
    """Raise ``TimeoutException`` if the ``with`` body runs longer than ``seconds``.

    Implemented with ``SIGALRM``/``signal.alarm``, so it requires a POSIX
    platform, must be used from the main thread, and ``seconds`` is a whole
    number of seconds (``0`` disables the limit).

    On exit the pending alarm is cancelled and whatever SIGALRM handler was
    installed before entering is put back, so the context manager leaves no
    process-wide signal state behind.
    """
    def signal_handler(signum, frame):
        raise TimeoutException("Prediction timed out!")

    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        # Cancel first so the alarm cannot fire after the old handler is back.
        signal.alarm(0)
        # signal.signal returns None when the previous handler was not
        # installed from Python; it cannot be re-installed in that case.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
class MusicDemixingPredictor:
    """Base harness for the music demixing evaluation.

    Participants subclass this and implement ``prediction`` (and optionally
    ``prediction_setup``). ``run`` drives the whole evaluation:
    ``prediction_setup`` is called once, then ``prediction`` is called once per
    test track. After each track the four predicted stems are checked to exist.
    ``scoring`` can then compute SDR against the reference stems locally.
    """

    def __init__(self):
        # One sub-folder per track, holding mixture.wav and the reference stems.
        self.test_data_path = os.getenv("TEST_DATASET_PATH", os.getcwd() + "/data/test/")
        # Predicted stems are written here, one sub-folder per track.
        self.results_data_path = os.getenv("RESULTS_DATASET_PATH", os.getcwd() + "/data/results/")
        # Time budgets in whole seconds, enforced through ``time_limit``.
        self.inference_setup_timeout = int(os.getenv("INFERENCE_SETUP_TIMEOUT_SECONDS", "900"))
        self.inference_per_music_timeout = int(os.getenv("INFERENCE_PER_MUSIC_TIMEOUT_SECONDS", "240"))
        # Optional comma-separated list of track names restricting the run.
        self.partial_run = os.getenv("PARTIAL_RUN_MUSIC_NAMES", None)
        # NOTE(review): not read anywhere in this class; presumably for
        # subclasses to collect per-track output -- confirm.
        self.results = []
        # Name of the track currently being demixed; set by ``evaluation``.
        self.current_music_name = None

    def get_all_music_names(self):
        """Return the sorted names of the track folders in ``test_data_path``.

        Plain files are ignored. When ``partial_run`` is set, only folders
        named in that comma-separated list are returned. Whitespace around
        each listed name is ignored.
        """
        wanted = None
        if self.partial_run:
            wanted = {name.strip() for name in self.partial_run.split(",") if name.strip()}
        music_names = [
            entry
            for entry in listdir(self.test_data_path)
            if not isfile(join(self.test_data_path, entry))
            and (wanted is None or entry in wanted)
        ]
        # listdir order is arbitrary; sort so runs are reproducible.
        return sorted(music_names)

    def get_music_folder_location(self, music_name):
        """Return the input folder of ``music_name`` inside ``test_data_path``."""
        return join(self.test_data_path, music_name)

    def get_music_file_location(self, music_name, instrument=None):
        """Return the path of a ``.wav`` file belonging to ``music_name``.

        With ``instrument=None`` this is the input ``mixture.wav`` inside
        ``test_data_path``. For any other instrument it is the predicted stem
        inside ``results_data_path``. The track's results folder is created if
        needed, so callers can write to the returned path directly.
        """
        if instrument is None:
            return join(self.test_data_path, music_name, "mixture.wav")
        music_results_dir = join(self.results_data_path, music_name)
        # makedirs also creates results_data_path; exist_ok avoids the
        # check-then-create race of testing os.path.exists first.
        os.makedirs(music_results_dir, exist_ok=True)
        return join(music_results_dir, instrument + ".wav")

    @staticmethod
    def _sdr(references, estimates):
        """Return the per-source SDR, in dB, for one track.

        ``references`` and ``estimates`` must have the same shape. Their first
        axis indexes the sources, and energy is summed over every remaining
        axis. Both mono ``(sources, samples)`` and multi-channel
        ``(sources, samples, channels)`` stacks are therefore accepted.
        """
        delta = 1e-7  # keeps the ratio and the log finite for silent signals
        axes = tuple(range(1, references.ndim))
        num = np.sum(np.square(references), axis=axes) + delta
        den = np.sum(np.square(references - estimates), axis=axes) + delta
        return 10 * np.log10(num / den)

    def scoring(self):
        """Score the predicted stems against the references (for local use).

        For every track from ``get_all_music_names``, the four reference stems
        are read from ``test_data_path`` and the predicted stems from
        ``results_data_path``.

        Returns:
            dict mapping each track name to a dict with keys ``"sdr_bass"``,
            ``"sdr_drums"``, ``"sdr_other"``, ``"sdr_vocals"`` and ``"sdr"``.
            ``"sdr"`` is the mean of the four per-stem values.
        """
        instruments = ["bass", "drums", "other", "vocals"]
        scores = {}
        for music_name in self.get_all_music_names():
            print("Evaluating for: %s" % music_name)
            references = []
            estimates = []
            for instrument in instruments:
                reference_file = join(self.get_music_folder_location(music_name), instrument + ".wav")
                estimate_file = self.get_music_file_location(music_name, instrument)
                reference, _ = sf.read(reference_file)
                estimate, _ = sf.read(estimate_file)
                references.append(reference)
                estimates.append(estimate)
            references = np.stack(references).astype(np.float32)
            estimates = np.stack(estimates).astype(np.float32)
            song_score = self._sdr(references, estimates).tolist()
            track_scores = {
                "sdr_" + instrument: value
                for instrument, value in zip(instruments, song_score)
            }
            track_scores["sdr"] = float(np.mean(song_score))
            scores[music_name] = track_scores
        return scores

    def evaluation(self):
        """Admin function: run the whole evaluation.

        Reports progress through ``aicrowd_helpers``.

        1. Calls ``prediction_setup`` under ``inference_setup_timeout``. This
           step is skipped if it raises ``NotImplementedError``.
        2. Calls ``prediction`` for each track under
           ``inference_per_music_timeout``.
        3. Verifies after each track that its four stems were written.

        Raises:
            TimeoutException: if setup or a single track exceeds its budget.
            Exception: if a track's demixed files are missing.
        """
        aicrowd_helpers.execution_start()
        try:
            with time_limit(self.inference_setup_timeout):
                self.prediction_setup()
        except NotImplementedError:
            print("prediction_setup doesn't exist for this run, skipping...")
        aicrowd_helpers.execution_running()
        for music_name in self.get_all_music_names():
            # prediction() does not receive the track name, so expose it here.
            self.current_music_name = music_name
            with time_limit(self.inference_per_music_timeout):
                self.prediction(
                    mixture_file_path=self.get_music_file_location(music_name),
                    bass_file_path=self.get_music_file_location(music_name, "bass"),
                    drums_file_path=self.get_music_file_location(music_name, "drums"),
                    other_file_path=self.get_music_file_location(music_name, "other"),
                    vocals_file_path=self.get_music_file_location(music_name, "vocals"),
                )
            if not self.verify_results(music_name):
                raise Exception("verification failed, demixed files not found.")
        aicrowd_helpers.execution_success()

    def run(self):
        """Entry point: run ``evaluation`` and report any failure.

        On error, the traceback is printed and passed to
        ``aicrowd_helpers.execution_error``. Outside of grading the original
        exception is then re-raised with its traceback intact.
        """
        try:
            self.evaluation()
        except Exception:
            error = traceback.format_exc()
            print(error)
            aicrowd_helpers.execution_error(error)
            if not aicrowd_helpers.is_grading():
                raise

    def prediction_setup(self):
        """Optional one-time setup hook, e.g. loading your model into memory.

        It runs once before any ``prediction`` call, under
        ``inference_setup_timeout``. The default raises ``NotImplementedError``,
        which ``evaluation`` treats as "no setup needed".
        """
        raise NotImplementedError

    def prediction(self, mixture_file_path, bass_file_path, drums_file_path, other_file_path,
                   vocals_file_path, music_name=None):
        """Demix one track: read ``mixture_file_path`` and write the four stems.

        ``evaluation`` calls this once per track, passing all paths as keyword
        arguments. The output paths' parent folder already exists. The current
        track name is available as ``self.current_music_name``. ``music_name``
        is accepted only for backward compatibility; ``evaluation`` does not
        pass it.

        NOTE: load your model in ``prediction_setup``, not here.
        """
        raise NotImplementedError

    def verify_results(self, music_name):
        """Return True if all four predicted stems exist for ``music_name``.

        Only file existence is checked. Contents, such as wav length, are not
        validated.
        """
        return all(
            os.path.isfile(self.get_music_file_location(music_name, instrument))
            for instrument in ("vocals", "bass", "drums", "other")
        )