# music_demixing.py
######################################################################################
### This is a read-only file to allow participants to run their code locally.      ###
### It will be over-written during the evaluation; please do not make any changes  ###
### to this file.                                                                  ###
######################################################################################

import traceback
import os
import signal
from contextlib import contextmanager
from os import listdir
from os.path import isfile, join

import soundfile as sf
import numpy as np
from evaluator import aicrowd_helpers


class TimeoutException(Exception):
    """Raised when a time-limited section of the evaluation overruns."""


@contextmanager
def time_limit(seconds):
    """Abort the enclosed block with TimeoutException after `seconds` seconds.

    The limit is enforced with SIGALRM, so it only works on platforms that
    provide that signal and only when entered from the main thread.

    Args:
        seconds: Whole number of seconds passed to `signal.alarm`. A value
            of 0 schedules no alarm, i.e. the block runs without a limit.

    Raises:
        TimeoutException: Inside the `with` body when the alarm fires.

    Note:
        Limits do not nest: `signal.alarm` keeps a single pending alarm, so
        an inner `time_limit` replaces and then cancels the outer one.
    """
    def signal_handler(signum, frame):
        raise TimeoutException("Prediction timed out!")

    # Remember whatever handler was active so it can be put back afterwards
    # instead of leaving ours installed for the rest of the process.
    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    try:
        signal.alarm(seconds)
        yield
    finally:
        # Cancel the pending alarm first so it cannot fire during cleanup.
        signal.alarm(0)
        # A previous handler of None means it was not installed from Python
        # and cannot be re-installed through signal.signal().
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)


class MusicDemixingPredictor:
    """Base class that runs a music-demixing submission over a test set.

    Subclasses override `prediction` (and optionally `prediction_setup`).
    `run` then drives setup, per-song inference and result verification.

    Configuration is read from the environment in `__init__`:
    TEST_DATASET_PATH, RESULTS_DATASET_PATH, INFERENCE_SETUP_TIMEOUT_SECONDS,
    INFERENCE_PER_MUSIC_TIMEOUT_SECONDS and PARTIAL_RUN_MUSIC_NAMES.
    """

    # Stems every song is split into, in the order `scoring` reports them.
    _INSTRUMENTS = ("bass", "drums", "other", "vocals")

    def __init__(self):
        """Load dataset locations, timeouts and the optional song filter."""
        self.test_data_path = os.getenv("TEST_DATASET_PATH", os.getcwd() + "/data/test/")
        self.results_data_path = os.getenv("RESULTS_DATASET_PATH", os.getcwd() + "/data/results/")
        # Timeouts are whole seconds; `evaluation` hands them to time_limit().
        self.inference_setup_timeout = int(os.getenv("INFERENCE_SETUP_TIMEOUT_SECONDS", "900"))
        self.inference_per_music_timeout = int(os.getenv("INFERENCE_PER_MUSIC_TIMEOUT_SECONDS", "240"))
        # Comma-separated song names restricting the run, or None for all songs.
        self.partial_run = os.getenv("PARTIAL_RUN_MUSIC_NAMES", None)
        self.results = []
        # Name of the song `evaluation` is currently processing.
        self.current_music_name = None

    def get_all_music_names(self):
        """Return the sorted names of the song folders in the test data path.

        Only directories count as songs. When a partial run is configured,
        only folders named in that comma-separated list are returned;
        whitespace around each listed name is ignored. The result is sorted
        so the processing order does not depend on `os.listdir`.
        """
        allowed = None
        if self.partial_run:
            allowed = {name.strip() for name in self.partial_run.split(",") if name.strip()}
        return sorted(
            entry
            for entry in listdir(self.test_data_path)
            if os.path.isdir(join(self.test_data_path, entry))
            and (allowed is None or entry in allowed)
        )

    def get_music_folder_location(self, music_name):
        """Return the folder holding the test data of `music_name`."""
        return join(self.test_data_path, music_name)

    def get_music_file_location(self, music_name, instrument=None):
        """Return the wav path of a song's mixture or of one of its stems.

        Args:
            music_name: Name of the song folder.
            instrument: Stem name (e.g. "bass"). When None, the path of the
                input mixture inside the test data folder is returned.

        Returns:
            For the mixture, `<test data>/<music_name>/mixture.wav`.
            For a stem, `<results>/<music_name>/<instrument>.wav`; the
            song's results folder is created if it does not exist yet.
        """
        if instrument is None:
            return join(self.test_data_path, music_name, "mixture.wav")

        output_dir = join(self.results_data_path, music_name)
        # exist_ok creates missing parents too and avoids the race between a
        # separate existence check and the directory creation.
        os.makedirs(output_dir, exist_ok=True)
        return join(output_dir, instrument + ".wav")

    @staticmethod
    def _sdr(references, estimates):
        """Return the signal-to-distortion ratio in dB per stem of one song.

        Both arrays must have the same 3-D shape with the stem on axis 0;
        energies are summed over axes 1 and 2 (samples and channels when
        the arrays come from `soundfile.read` of multi-channel audio).
        """
        delta = 1e-7  # avoid numerical errors
        num = np.sum(np.square(references), axis=(1, 2))
        den = np.sum(np.square(references - estimates), axis=(1, 2))
        num += delta
        den += delta
        return 10 * np.log10(num / den)

    def scoring(self):
        """
        Add scoring function in the starter kit for participant's reference

        Reads the reference stems from the test data folder and the
        estimated stems from the results folder of every song, and returns
        a dict mapping each song name to its scores: "sdr_bass",
        "sdr_drums", "sdr_other", "sdr_vocals" and their mean under "sdr".
        """
        scores = {}
        for music_name in self.get_all_music_names():
            print("Evaluating for: %s" % music_name)
            references = []
            estimates = []
            for instrument in self._INSTRUMENTS:
                reference_file = join(self.test_data_path, music_name, instrument + ".wav")
                estimate_file = self.get_music_file_location(music_name, instrument)
                reference, _ = sf.read(reference_file)
                estimate, _ = sf.read(estimate_file)
                references.append(reference)
                estimates.append(estimate)
            references = np.stack(references).astype(np.float32)
            estimates = np.stack(estimates).astype(np.float32)
            song_score = self._sdr(references, estimates).tolist()
            # song_score follows the order of _INSTRUMENTS (the stacking order).
            song_scores = {}
            for instrument, value in zip(self._INSTRUMENTS, song_score):
                song_scores["sdr_" + instrument] = value
            song_scores["sdr"] = np.mean(song_score)
            scores[music_name] = song_scores
        return scores

    def _prediction_kwargs(self, music_name):
        """Build the keyword arguments `evaluation` passes to `prediction`.

        `music_name` is included only when the active `prediction` declares
        a parameter of that name, so overrides written with or without it
        are both called correctly.
        """
        # Local import so the module's top-level imports stay untouched.
        import inspect

        kwargs = {"mixture_file_path": self.get_music_file_location(music_name)}
        for instrument in self._INSTRUMENTS:
            kwargs[instrument + "_file_path"] = self.get_music_file_location(music_name, instrument)
        if "music_name" in inspect.signature(self.prediction).parameters:
            kwargs["music_name"] = music_name
        return kwargs

    def evaluation(self):
        """
        Admin function: Runs the whole evaluation

        Runs `prediction_setup` under the setup timeout (skipped when it is
        not implemented), then calls `prediction` for every song under the
        per-song timeout and checks that all stems were written.

        Raises:
            Exception: When a song's demixed files are missing afterwards.
        """
        aicrowd_helpers.execution_start()
        try:
            with time_limit(self.inference_setup_timeout):
                self.prediction_setup()
        except NotImplementedError:
            print("prediction_setup doesn't exist for this run, skipping...")

        aicrowd_helpers.execution_running()

        for music_name in self.get_all_music_names():
            self.current_music_name = music_name
            with time_limit(self.inference_per_music_timeout):
                self.prediction(**self._prediction_kwargs(music_name))

            if not self.verify_results(music_name):
                raise Exception("verification failed, demixed files not found.")
        aicrowd_helpers.execution_success()

    def run(self):
        """Run `evaluation`, reporting any failure through aicrowd_helpers.

        The error is re-raised unless `aicrowd_helpers.is_grading()` is true.
        """
        try:
            self.evaluation()
        except Exception:
            error = traceback.format_exc()
            print(error)
            aicrowd_helpers.execution_error(error)
            if not aicrowd_helpers.is_grading():
                # Bare raise re-raises the active exception unchanged.
                raise

    def prediction_setup(self):
        """
        You can do any preprocessing required for your codebase here :
            like loading your models into memory, etc.
        """
        raise NotImplementedError

    def prediction(self, music_name, mixture_file_path, bass_file_path, drums_file_path, other_file_path,
                   vocals_file_path):
        """
        This function will be called for every music during the evaluation.
        NOTE: In case you want to load your model, please do so in `prediction_setup` function.

        `mixture_file_path` is the input wav; the four other paths are where
        the separated stems are expected to be written. `evaluation` passes
        all arguments by keyword and supplies `music_name` only to
        implementations that declare it.
        """
        raise NotImplementedError

    def verify_results(self, music_name):
        """
        This function will be called to check that all the demixed files exist.

        Returns True when the bass, drums, other and vocals wav files of
        `music_name` are all present in the results folder. Only existence
        is checked; the length or content of the files is not inspected.
        """
        return all(
            os.path.isfile(self.get_music_file_location(music_name, instrument))
            for instrument in self._INSTRUMENTS
        )