Commit a4749f1a authored by Shivam Khandelwal's avatar Shivam Khandelwal

Initial commit

*.wav filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
data/
shared/
logs/
.gradle/
*.pyc
.idea
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
LICENSE 0 → 100644
MIT License
Copyright (c) 2021 AIcrowd
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md 0 → 100644
![Nethack Banner](https://raw.githubusercontent.com/facebookresearch/nle/master/dat/nle/logo.png)
# Nethack Challenge - Starter Kit
👉 [Challenge page](https://www.aicrowd.com/challenges/neurips-2021-nethack-challenge)
[![Discord](https://img.shields.io/discord/565639094860775436.svg)](https://discord.gg/fNRrSvZkry)
This repository is the Nethack Challenge **Submission template and Starter kit**!
Clone the repository to compete now!
**This repository contains**:
* **Documentation** on how to submit your models to the leaderboard
* **The procedure** for best practices and information on how we evaluate your agent, etc.
* **Starter code** for you to get started!
# Table of Contents
1. [Competition Procedure](#competition-procedure)
2. [How to access and use dataset](#how-to-access-and-use-dataset)
3. [How to start participating](#how-to-start-participating)
4. [How do I specify my software runtime / dependencies?](#how-do-i-specify-my-software-runtime-dependencies-)
5. [What should my code structure be like ?](#what-should-my-code-structure-be-like-)
6. [How to make submission](#how-to-make-submission)
7. [Other concepts](#other-concepts)
8. [Important links](#-important-links)
<p style="text-align:center"><img style="text-align:center" src="https://raw.githubusercontent.com/facebookresearch/nle/master/dat/nle/example_run.gif"></p>
# Competition Procedure
The NetHack Learning Environment (NLE) is a Reinforcement Learning environment presented at NeurIPS 2020. NLE is based on NetHack 3.6.6 and designed to provide a standard RL interface to the game, and comes with tasks that function as a first step to evaluate agents on this new environment. You can read more about NLE in the NeurIPS 2020 paper.
We are excited that this competition offers machine learning students, researchers and NetHack-bot builders the opportunity to participate in a grand challenge in AI without prohibitive computational costs—and we are eagerly looking forward to the wide variety of submissions.
**The following is a high level description of how this process works**
![](https://i.imgur.com/xzQkwKV.jpg)
1. **Sign up** to join the competition [on the AIcrowd website](https://www.aicrowd.com/challenges/neurips-2021-nethack-challenge).
2. **Clone** this repo and start developing your solution.
3. **Train** your agent and write your prediction code in `test.py`.
4. [**Submit**](#how-to-submit-a-model) your trained models to [AIcrowd Gitlab](https://gitlab.aicrowd.com) for evaluation [(full instructions below)](#how-to-submit-a-model). The automated evaluation setup will evaluate the submissions against the test dataset to compute and report the metrics on the leaderboard of the competition.
# How to run the environment
To be added
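
As a minimal sketch (not the official instructions), one episode with a random agent could look roughly like the code below. It assumes `gym` and `nle` are installed and that the environment follows the classic `gym` API, where `step` returns `(obs, reward, done, info)`. The names `run_random_episode` and `make_nethack_env` are illustrative and not part of the starter kit.

```python
def run_random_episode(env, max_steps=1024):
    """Play one episode with uniformly random actions.

    Returns (total_reward, steps_taken). The step cap is an arbitrary
    illustrative limit so that the loop always terminates.
    """
    env.reset()
    total_reward, steps, done = 0.0, 0, False
    while not done and steps < max_steps:
        action = env.action_space.sample()
        # Assumption: classic gym API, step() -> (obs, reward, done, info).
        _obs, reward, done, _info = env.step(action)
        total_reward += reward
        steps += 1
    return total_reward, steps


def make_nethack_env():
    """Create the task used in test.py (requires `pip install nle`)."""
    import gym
    import nle  # noqa: F401  (importing nle registers the NetHack tasks with gym)

    return gym.make("NetHackScore-v0")


# Usage (needs nle installed):
#   env = make_nethack_env()
#   print(run_random_episode(env))
#   env.close()
```

Keeping the loop separate from environment creation makes it easy to try the same loop on any `gym`-style environment.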
# How to start participating
## Setup
1. **Add your SSH key** to AIcrowd GitLab
You can add your SSH Keys to your GitLab account by going to your profile settings [here](https://gitlab.aicrowd.com/profile/keys). If you do not have SSH Keys, you will first need to [generate one](https://docs.gitlab.com/ee/ssh/README.html#generating-a-new-ssh-key-pair).
2. **Clone the repository**
```
git clone git@github.com:AIcrowd/neurips-2021-nethack-starter-kit.git
```
3. **Install** competition specific dependencies!
```
cd neurips-2021-nethack-starter-kit
pip3 install -r requirements.txt
```
4. Try out the random prediction codebase in `test.py`.
## How do I specify my software runtime / dependencies ?
We accept submissions with custom runtime, so you don't need to worry about which libraries or framework to pick from.
The configuration files typically include `requirements.txt` (pypi packages), `environment.yml` (conda environment), `apt.txt` (apt packages) or even your own `Dockerfile`.
You can check detailed information about the same in the 👉 [RUNTIME.md](/docs/RUNTIME.md) file.
## What should my code structure be like ?
Please follow the example structure used in the starter kit.
The different files and directories have the following meaning:
```
.
├── aicrowd.json # Submission meta information - like your username
├── apt.txt # Packages to be installed inside docker image
├── data # Your local dataset copy - you don't need to upload it (read DATASET.md)
├── requirements.txt # Python packages to be installed
├── test.py # IMPORTANT: Your testing/prediction code, must be derived from NethackSubmission (example in test.py)
└── utility # The utility scripts to provide smoother experience to you.
    ├── docker_build.sh
    ├── docker_run.sh
    └── environ.sh
```
Finally, **you must specify an AIcrowd submission JSON in `aicrowd.json` to be scored!**
The `aicrowd.json` of each submission should contain the following content:
```json
{
"challenge_id": "evaluations-api-neurips-nethack",
"authors": ["your-aicrowd-username"],
"description": "(optional) description about your awesome agent",
"external_dataset_used": false
}
```
This JSON is used to map your submission to the challenge - so please remember to use the correct `challenge_id` as specified above.
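
Before pushing, it can help to sanity-check `aicrowd.json` locally. The sketch below is only an illustration: the helper name `check_submission_meta` is hypothetical, and treating every key except `description` as required is an assumption based on the example above, not a documented rule of the evaluator.

```python
import json

# Assumption: every key except "description" is required.
REQUIRED_KEYS = {"challenge_id", "authors", "external_dataset_used"}
EXPECTED_CHALLENGE_ID = "evaluations-api-neurips-nethack"


def check_submission_meta(text):
    """Return a list of problems found in the contents of aicrowd.json."""
    try:
        meta = json.loads(text)
    except json.JSONDecodeError as exc:
        return [f"invalid JSON at line {exc.lineno}: {exc.msg}"]
    if not isinstance(meta, dict):
        return ["top-level value must be a JSON object"]

    problems = [f"missing key: {key}" for key in sorted(REQUIRED_KEYS - set(meta))]
    if "challenge_id" in meta and meta["challenge_id"] != EXPECTED_CHALLENGE_ID:
        problems.append(f"unexpected challenge_id: {meta['challenge_id']!r}")
    if "authors" in meta and not (isinstance(meta["authors"], list) and meta["authors"]):
        problems.append("authors must be a non-empty list of usernames")
    return problems


# Usage:
#   with open("aicrowd.json") as fp:
#       print(check_submission_meta(fp.read()) or "aicrowd.json looks fine")
```

An empty list means no problems were found. A stray comma or a mistyped `challenge_id` shows up as a readable message instead of a failed evaluation.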
## How to make submission
👉 [SUBMISSION.md](/docs/SUBMISSION.md)
**Best of Luck** :tada: :tada:
# Other Concepts
## Hardware and Time constraints
To be added.
## Local Run
To be added.
## Contributing
🙏 You can share your solutions or any other baselines by contributing directly to this repository by opening a merge request.
- Add your implementation as `test_<approach-name>.py`
- Test it out using `python test_<approach-name>.py`
- Add any documentation for your approach at top of your file.
- Import it in `predict.py`
- Create merge request! 🎉🎉🎉
## Contributors
- [Shivam Khandelwal](https://www.aicrowd.com/participants/shivam)
# 📎 Important links
💪 &nbsp;Challenge Page: https://www.aicrowd.com/challenges/neurips-2021-nethack-challenge
🗣️ &nbsp;Discussion Forum: https://www.aicrowd.com/challenges/neurips-2021-nethack-challenge/discussion
🏆 &nbsp;Leaderboard: https://www.aicrowd.com/challenges/neurips-2021-nethack-challenge/leaderboards
{
"challenge_id": "evaluations-api-neurips-nethack",
"authors": [
"aicrowd-bot"
],
"external_dataset_used": false
}
build-essential
git
## Adding your runtime
This repository is a valid submission (and submission structure).
You can simply add your dependencies on top of this repository.
A few of the most common ways are as follows:
* `environment.yml` -- The _optional_ Anaconda environment specification.
As you add new requirements you can export your `conda` environment to this file!
```
conda env export --no-build > environment.yml
```
* **Create your new conda environment**
```sh
conda create --name music_demixing_challenge
conda activate music_demixing_challenge
```
* **Your code specific dependencies**
```sh
conda install <your-package>
```
* `requirements.txt` -- The `pip3` packages used by your inference code. **Note that dependencies specified by `environment.yml` take precedence over `requirements.txt`.** As you add new pip3 packages to your inference procedure either manually add them to `requirements.txt` or if your software runtime is simple, perform:
```
# Put ALL of the current pip3 packages on your system in the submission
>> pip3 freeze >> requirements.txt
>> cat requirements.txt
aicrowd_api
coloredlogs
matplotlib
pandas
[...]
```
* `apt.txt` -- The Debian packages (via aptitude) used by your inference code!
These files are used to construct your **AIcrowd submission docker containers** in which your code will run.
If you are an advanced user, you can check other methods to specify the runtime [here](https://discourse.aicrowd.com/t/how-to-specify-runtime-environment-for-your-submission/2274), which include adding your own `Dockerfile` directly.
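
To catch a forgotten dependency before the image is built, you can compare `requirements.txt` against the packages installed in your current environment. This is a rough sketch under stated assumptions: `requirement_name` and `missing_requirements` are hypothetical helpers, and the name parsing only handles plain entries such as `pandas>=1.0`. It does not handle URLs or VCS requirements.

```python
import re
from importlib import metadata


def requirement_name(line):
    """Extract the distribution name from one requirements.txt line, or None."""
    line = line.split("#", 1)[0].strip()
    if not line or line.startswith("-"):
        return None  # blank line, comment, or a pip option such as "-r other.txt"
    match = re.match(r"[A-Za-z0-9][A-Za-z0-9._-]*", line)
    return match.group(0) if match else None


def missing_requirements(lines):
    """Return the listed distributions that are not installed locally."""
    missing = []
    for line in lines:
        name = requirement_name(line)
        if name is None:
            continue
        try:
            metadata.version(name)
        except metadata.PackageNotFoundError:
            missing.append(name)
    return missing


# Usage:
#   with open("requirements.txt") as fp:
#       print(missing_requirements(fp))
```

The check only confirms that a distribution is installed, not that its version satisfies the specifier.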
# Making submission
This file will help you in making your first submission.
## Submission Entrypoint (where you write your code!)
The evaluator will execute `run.sh` for generating predictions, so please remember to include it in your submission!
The inline documentation of `test.py` will guide you with interfacing with the codebase properly. You can check TODOs inside it to learn about the functions you need to implement.
You can modify the existing `test.py` OR copy it (to say `your_code.py`) and change it.
## IMPORTANT: Saving Models before submission!
Before you submit make sure that you have saved your models, which are needed by your inference code.
If your model files are large, you can use `git-lfs` to upload them. More details [here](https://discourse.aicrowd.com/t/how-to-upload-large-files-size-to-your-submission/2304).
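
To see which files might be worth tracking with `git-lfs`, you can list everything above a size threshold. The sketch below is illustrative only: `find_large_files` is a hypothetical helper, and the 50 MiB default is an arbitrary assumption, not a limit documented by AIcrowd.

```python
import os


def find_large_files(root, limit_bytes=50 * 1024 * 1024):
    """Return sorted (relative_path, size_in_bytes) pairs for files above the limit.

    The .git directory is skipped so that repository internals are not reported.
    """
    large = []
    for dirpath, dirnames, filenames in os.walk(root):
        # Prune .git in place so os.walk does not descend into it.
        dirnames[:] = [d for d in dirnames if d != ".git"]
        for name in filenames:
            path = os.path.join(dirpath, name)
            size = os.path.getsize(path)
            if size > limit_bytes:
                large.append((os.path.relpath(path, root), size))
    return sorted(large)


# Usage:
#   for path, size in find_large_files("."):
#       print(f"{size / 1e6:8.1f} MB  {path}")
```

Files with a `.pth` or `.wav` extension are already covered by the LFS filters in this repository's `.gitattributes`, so other large file types are the ones to watch for.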
## How to submit a trained model!
To make a submission, you will have to create a **private** repository on [https://gitlab.aicrowd.com/](https://gitlab.aicrowd.com/).
You will have to add your SSH Keys to your GitLab account by going to your profile settings [here](https://gitlab.aicrowd.com/profile/keys). If you do not have SSH Keys, you will first need to [generate one](https://docs.gitlab.com/ee/ssh/README.html#generating-a-new-ssh-key-pair).
Then you can create a submission by making a _tag push_ to your repository on [https://gitlab.aicrowd.com/](https://gitlab.aicrowd.com/).
**Any tag push (where the tag name begins with "submission-") to your private repository is considered as a submission**
Then you can add the correct git remote, and finally submit by doing :
```
cd neurips-2021-nethack-starter-kit
# Add AIcrowd git remote endpoint
git remote add aicrowd git@gitlab.aicrowd.com:<YOUR_AICROWD_USER_NAME>/neurips-2021-nethack-starter-kit.git
git push aicrowd master
```
```
# Create a tag for your submission and push
git tag -am "submission-v0.1" submission-v0.1
git push aicrowd master
git push aicrowd submission-v0.1
# Note : If the contents of your repository (latest commit hash) does not change,
# then pushing a new tag will **not** trigger a new evaluation.
```
You now should be able to see the details of your submission at :
[gitlab.aicrowd.com/<YOUR_AICROWD_USER_NAME>/neurips-2021-nethack-starter-kit/issues](https://gitlab.aicrowd.com/<YOUR_AICROWD_USER_NAME>/neurips-2021-nethack-starter-kit/issues)
**NOTE**: Remember to update your username instead of `<YOUR_AICROWD_USER_NAME>` above :wink:
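
The tag-and-push steps above can also be scripted. The following is a small sketch, not an official tool: `is_submission_tag` and `submission_commands` are hypothetical helpers that only build the same `git` commands shown in this file, with the `aicrowd` remote and `master` branch as defaults.

```python
SUBMISSION_PREFIX = "submission-"


def is_submission_tag(tag):
    """True if pushing this tag would count as a submission (name starts with 'submission-')."""
    return tag.startswith(SUBMISSION_PREFIX)


def submission_commands(version, remote="aicrowd", branch="master"):
    """Build the git commands for one submission, e.g. version='v0.1'."""
    tag = f"{SUBMISSION_PREFIX}{version}"
    return [
        ["git", "tag", "-am", tag, tag],
        ["git", "push", remote, branch],
        ["git", "push", remote, tag],
    ]


# Usage (runs the commands in the current repository):
#   import subprocess
#   for cmd in submission_commands("v0.1"):
#       subprocess.run(cmd, check=True)
```

Building the commands as lists and passing them to `subprocess.run` avoids shell quoting issues.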
### Other helpful files
👉 [RUNTIME.md](/docs/RUNTIME.md)
#!/usr/bin/env python
import aicrowd_api
import os

########################################################################
# Instantiate Event Notifier
########################################################################
aicrowd_events = aicrowd_api.events.AIcrowdEvents()


def execution_start():
    ########################################################################
    # Register Evaluation Start event
    ########################################################################
    aicrowd_events.register_event(
        event_type=aicrowd_events.AICROWD_EVENT_INFO,
        message="execution_started",
        payload={
            "event_type": "airborne_detection:execution_started"
        }
    )


def execution_running():
    ########################################################################
    # Register Evaluation Running event (progress starts at 0.0)
    ########################################################################
    aicrowd_events.register_event(
        event_type=aicrowd_events.AICROWD_EVENT_INFO,
        message="execution_progress",
        payload={
            "event_type": "airborne_detection:execution_progress",
            "progress": 0.0
        }
    )


def execution_progress(progress):
    ########################################################################
    # Register Evaluation Progress event
    ########################################################################
    aicrowd_events.register_event(
        event_type=aicrowd_events.AICROWD_EVENT_INFO,
        message="execution_progress",
        payload={
            "event_type": "airborne_detection:execution_progress",
            "progress": progress
        }
    )


def execution_success():
    ########################################################################
    # Register Evaluation Complete event
    ########################################################################
    predictions_output_path = os.getenv("PREDICTIONS_OUTPUT_PATH", False)

    aicrowd_events.register_event(
        event_type=aicrowd_events.AICROWD_EVENT_SUCCESS,
        message="execution_success",
        payload={
            "event_type": "airborne_detection:execution_success",
            "predictions_output_path": predictions_output_path
        },
        blocking=True
    )


def execution_error(error):
    ########################################################################
    # Register Evaluation Error event
    ########################################################################
    aicrowd_events.register_event(
        event_type=aicrowd_events.AICROWD_EVENT_ERROR,
        message="execution_error",
        payload={  # Arbitrary Payload
            "event_type": "airborne_detection:execution_error",
            "error": error
        },
        blocking=True
    )


def is_grading():
    return os.getenv("AICROWD_IS_GRADING", False)
######################################################################################
### This is a read-only file to allow participants to run their code locally. ###
### It will be overwritten during the evaluation. Please do not make any changes ###
### to this file. ###
######################################################################################
import traceback
import os
import signal
from contextlib import contextmanager
from os import listdir
from os.path import isfile, join

import soundfile as sf
import numpy as np
from evaluator import aicrowd_helpers


class TimeoutException(Exception): pass


@contextmanager
def time_limit(seconds):
    # Raise TimeoutException if the wrapped block runs longer than `seconds`.
    def signal_handler(signum, frame):
        raise TimeoutException("Prediction timed out!")
    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)


class MusicDemixingPredictor:
    def __init__(self):
        self.test_data_path = os.getenv("TEST_DATASET_PATH", os.getcwd() + "/data/test/")
        self.results_data_path = os.getenv("RESULTS_DATASET_PATH", os.getcwd() + "/data/results/")
        self.inference_setup_timeout = int(os.getenv("INFERENCE_SETUP_TIMEOUT_SECONDS", "900"))
        self.inference_per_music_timeout = int(os.getenv("INFERENCE_PER_MUSIC_TIMEOUT_SECONDS", "240"))
        self.partial_run = os.getenv("PARTIAL_RUN_MUSIC_NAMES", None)
        self.results = []
        self.current_music_name = None

    def get_all_music_names(self):
        valid_music_names = None
        if self.partial_run:
            valid_music_names = self.partial_run.split(',')
        music_names = []
        for folder in listdir(self.test_data_path):
            if not isfile(join(self.test_data_path, folder)):
                if valid_music_names is None or folder in valid_music_names:
                    music_names.append(folder)
        return music_names

    def get_music_folder_location(self, music_name):
        return join(self.test_data_path, music_name)

    def get_music_file_location(self, music_name, instrument=None):
        # Without an instrument, return the input mixture from the test set.
        if instrument is None:
            instrument = "mixture"
            return join(self.test_data_path, music_name, instrument + ".wav")

        # Otherwise return the output location, creating the results folders if needed.
        if not os.path.exists(self.results_data_path):
            os.makedirs(self.results_data_path)
        if not os.path.exists(join(self.results_data_path, music_name)):
            os.makedirs(join(self.results_data_path, music_name))

        return join(self.results_data_path, music_name, instrument + ".wav")

    def scoring(self):
        """
        Add scoring function in the starter kit for participant's reference
        """
        def sdr(references, estimates):
            # compute SDR for one song
            delta = 1e-7  # avoid numerical errors
            num = np.sum(np.square(references), axis=(1, 2))
            den = np.sum(np.square(references - estimates), axis=(1, 2))
            num += delta
            den += delta
            return 10 * np.log10(num / den)

        music_names = self.get_all_music_names()
        instruments = ["bass", "drums", "other", "vocals"]
        scores = {}
        for music_name in music_names:
            print("Evaluating for: %s" % music_name)
            scores[music_name] = {}
            references = []
            estimates = []
            for instrument in instruments:
                reference_file = join(self.test_data_path, music_name, instrument + ".wav")
                estimate_file = self.get_music_file_location(music_name, instrument)
                reference, _ = sf.read(reference_file)
                estimate, _ = sf.read(estimate_file)
                references.append(reference)
                estimates.append(estimate)
            references = np.stack(references)
            estimates = np.stack(estimates)
            references = references.astype(np.float32)
            estimates = estimates.astype(np.float32)
            song_score = sdr(references, estimates).tolist()
            scores[music_name]["sdr_bass"] = song_score[0]
            scores[music_name]["sdr_drums"] = song_score[1]
            scores[music_name]["sdr_other"] = song_score[2]
            scores[music_name]["sdr_vocals"] = song_score[3]
            scores[music_name]["sdr"] = np.mean(song_score)
        return scores

    def evaluation(self):
        """
        Admin function: Runs the whole evaluation
        """
        aicrowd_helpers.execution_start()
        try:
            with time_limit(self.inference_setup_timeout):
                self.prediction_setup()
        except NotImplementedError:
            print("prediction_setup doesn't exist for this run, skipping...")

        aicrowd_helpers.execution_running()

        music_names = self.get_all_music_names()

        for music_name in music_names:
            with time_limit(self.inference_per_music_timeout):
                # music_name is passed as well, since prediction() declares it as a parameter.
                self.prediction(music_name=music_name,
                                mixture_file_path=self.get_music_file_location(music_name),
                                bass_file_path=self.get_music_file_location(music_name, "bass"),
                                drums_file_path=self.get_music_file_location(music_name, "drums"),
                                other_file_path=self.get_music_file_location(music_name, "other"),
                                vocals_file_path=self.get_music_file_location(music_name, "vocals"),
                                )

            if not self.verify_results(music_name):
                raise Exception("verification failed, demixed files not found.")
        aicrowd_helpers.execution_success()

    def run(self):
        try:
            self.evaluation()
        except Exception as e:
            error = traceback.format_exc()
            print(error)
            aicrowd_helpers.execution_error(error)
            if not aicrowd_helpers.is_grading():
                raise e

    def prediction_setup(self):
        """
        You can do any preprocessing required for your codebase here :
        like loading your models into memory, etc.
        """
        raise NotImplementedError

    def prediction(self, music_name, mixture_file_path, bass_file_path, drums_file_path, other_file_path,
                   vocals_file_path):
        """
        This function will be called once for each music track during the evaluation.
        NOTE: In case you want to load your model, please do so in the `prediction_setup` function.
        """
        raise NotImplementedError

    def verify_results(self, music_name):
        """
        This function will be called to check all the files exist and other verification needed.
        (like length of the wav files)
        """
        valid = True
        valid = valid and os.path.isfile(self.get_music_file_location(music_name, "vocals"))
        valid = valid and os.path.isfile(self.get_music_file_location(music_name, "bass"))
        valid = valid and os.path.isfile(self.get_music_file_location(music_name, "drums"))
        valid = valid and os.path.isfile(self.get_music_file_location(music_name, "other"))
        return valid
run.sh 0 → 100755
#!/bin/bash
python predict.py
test.py 0 → 100644
#!/usr/bin/env python
# This file is the entrypoint for your submission.
# You can modify this file to include your code or directly call your functions/modules from here.
import gym
import nle  # noqa: F401 (importing nle registers the NetHack environments with gym)

from evaluator import aicrowd_helpers


def main():
    """
    This function will be called for training phase.
    """
    # Sample code for illustration, add your training code below
    env = gym.make("NetHackScore-v0")

    # actions = [env.action_space.sample() for _ in range(10)] # Just doing 10 samples in this example
    # xposes = []
    # for _ in range(1):
    #     obs = env.reset()
    #     done = False
    #     netr = 0

    #     # Limiting our code to 1024 steps in this example, you can do "while not done" to run till end
    #     while not done:

    # To get better view in your training phase, it is suggested
    # to register progress continuously, example when 54% completed
    # aicrowd_helpers.execution_progress(0.54)

    # Save trained model to train/ directory

    # Training 100% Completed
    aicrowd_helpers.execution_progress(1)
    #env.close()


if __name__ == "__main__":
    main()
train.py 0 → 100644
#!/usr/bin/env python
# This file is the entrypoint for your submission.
# You can modify this file to include your code or directly call your functions/modules from here.
import gym
import nle  # noqa: F401 (importing nle registers the NetHack environments with gym)

from evaluator import aicrowd_helpers


def main():
    """
    This function will be called for training phase.
    """
    # Sample code for illustration, add your training code below
    env = gym.make("NetHackScore-v0")

    # actions = [env.action_space.sample() for _ in range(10)] # Just doing 10 samples in this example
    # xposes = []
    # for _ in range(1):
    #     obs = env.reset()
    #     done = False
    #     netr = 0

    #     # Limiting our code to 1024 steps in this example, you can do "while not done" to run till end
    #     while not done:

    # To get better view in your training phase, it is suggested
    # to register progress continuously, example when 54% completed
    # aicrowd_helpers.execution_progress(0.54)

    # Save trained model to train/ directory

    # Training 100% Completed
    aicrowd_helpers.execution_progress(1)
    #env.close()


if __name__ == "__main__":
    main()
#!/bin/bash
# Run this script from the repository root, e.g. ./utility/docker_build.sh
if [ -e utility/environ_secret.sh ]
then
    source utility/environ_secret.sh
else
    source utility/environ.sh
fi

if ! [ -x "$(command -v aicrowd-repo2docker)" ]; then
    echo 'Error: aicrowd-repo2docker is not installed.' >&2
    echo 'Please install it using requirements.txt or pip install -U aicrowd-repo2docker' >&2
    exit 1
fi

# Expected Env variables : in environ.sh
REPO2DOCKER="$(which aicrowd-repo2docker)"

sudo ${REPO2DOCKER} --no-run \
    --user-id 1001 \
    --user-name aicrowd \
    --image-name ${IMAGE_NAME}:${IMAGE_TAG} \
    --debug .
#!/bin/bash
# This script runs your submission inside a docker image; this is identical in terms of
# how your code will be executed on the AIcrowd platform
if [ -e utility/environ_secret.sh ]
then
    echo "Note: Gathering environment variables from environ_secret.sh"
    source utility/environ_secret.sh
else
    echo "Note: Gathering environment variables from environ.sh"
    source utility/environ.sh
fi

# Skip building docker image on run, by default each run means new docker image build
if [[ " $@ " =~ " --no-build " ]]; then
    echo "Skipping docker image build"
else
    echo "Building docker image, for skipping docker image build use \"--no-build\""
    ./utility/docker_build.sh
fi

# Expected Env variables : in environ.sh
sudo docker run \
    --net=host \
    --user 0 \
    -e AICROWD_IS_GRADING=True \
    -e AICROWD_DEBUG_MODE=True \
    -it ${IMAGE_NAME}:${IMAGE_TAG} \
    /bin/bash
#!/bin/bash
export IMAGE_NAME="aicrowd/music-demixing-challenge"
export IMAGE_TAG="local"
#!/usr/bin/env python3
print("Evaluator script to test predictions locally to be added here.")
import os, sys
import requests
import zipfile

sys.path.append(os.path.dirname(os.path.realpath(os.getcwd())))
sys.path.append(os.path.realpath(os.getcwd()))

DATASET_FILE_NAME = 'download.zip'
DATASET_DOWNLOAD_URL = 'https://zenodo.org/record/3270814/files/MUSDB18-7-WAV.zip?download=1'
DATASET_FULL_DOWNLOAD_URL = 'https://zenodo.org/record/3338373/files/musdb18hq.zip?download=1'


def download_dataset(full=False):
    dn_url = DATASET_DOWNLOAD_URL
    if full:
        dn_url = DATASET_FULL_DOWNLOAD_URL
    r = requests.get(dn_url, stream=True)
    with open(DATASET_FILE_NAME, 'wb') as fd:
        for chunk in r.iter_content(chunk_size=256):
            fd.write(chunk)


def unzip_dataset():
    with zipfile.ZipFile(DATASET_FILE_NAME, 'r') as zip_ref:
        zip_ref.extractall('data/')


def cleanup():
    os.remove(DATASET_FILE_NAME)


def verify_dataset():
    assert os.path.isdir("data/train") and os.path.isdir("data/test"), "Dataset folder not found"
    assert os.path.isdir("data/train/Hollow Ground - Left Blind"), "Random song check in training folder failed"
    assert os.path.isdir("data/test/Louis Cressy Band - Good Time"), "Random song check in testing folder failed"


def move_to_git_root():
    if not os.path.exists(os.path.join(os.getcwd(), ".git")):
        os.chdir("..")
    assert os.path.exists(os.path.join(os.getcwd(), ".git")), "Unable to reach to repository root"


if __name__ == "__main__":
    move_to_git_root()

    try:
        verify_dataset()
    except AssertionError:
        print("Dataset not found...")
        option = input("Download full dataset (y/Y) or 7s dataset (n/N)? ")
        if option.lower() == 'y':
            print("Downloading full dataset...")
            download_dataset(full=True)
        else:
            print("Downloading 7s dataset...")
            download_dataset()

        print("Unzipping the dataset...")
        unzip_dataset()

        print("Cleaning up...")
        #cleanup()

        print("Verifying the dataset...")
        verify_dataset()

    print("Done.")