Commit 40051264 authored by Dipam Chakraborty

agent class

parent 21ffa030
@@ -130,3 +130,5 @@ dmypy.json
# Pyre type checker
.pyre/
nle_data/
import aicrowd_gym
import nle
import numpy as np
from tqdm import trange
from custom_wrappers import EarlyTerminationNethack
from batched_env import BactchedEnv
class BatchedAgent:
"""
Simple Batched agent interface
Main motivation is to speed up runs by increasing GPU utilization
"""
def __init__(self, num_envs):
"""
Setup your model
Load your weights etc
"""
self.num_envs = num_envs
def preprocess_observations(self, observations, rewards, dones, infos):
"""
Add any preprocessing steps, for example reordering/stacking observations for torch/tf in your model
"""
pass
def batched_step(self, observations, rewards, dones, infos):
"""
Return a list of actions
"""
pass
class RandomBatchedAgent(BatchedAgent):
def __init__(self, num_envs, num_actions):
super().__init__(num_envs)
self.num_actions = num_actions
self.seeded_state = np.random.RandomState(42)
def preprocess_observations(self, observations, rewards, dones, infos):
return observations, rewards, dones, infos
def batched_step(self, observations, rewards, dones, infos):
rets = self.preprocess_observations(observations, rewards, dones, infos)
observations, rewards, dones, infos = rets
actions = self.seeded_state.randint(self.num_actions, size=self.num_envs)
return actions
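# --- Hedged sketch (not part of the original starter kit): one way a model-based
# BatchedAgent might look. It illustrates the "reordering/stacking for torch/tf"
# preprocessing mentioned above. `MyPolicyNet` is a hypothetical policy network;
# the numpy/torch calls shown are standard.
#
# import torch
#
# class TorchBatchedAgent(BatchedAgent):
#     def __init__(self, num_envs, num_actions):
#         super().__init__(num_envs)
#         self.model = MyPolicyNet(num_actions)  # hypothetical: load weights here
#         self.model.eval()
#
#     def preprocess_observations(self, observations, rewards, dones, infos):
#         # Stack per-env observation dicts into [num_envs, ...] tensors
#         batch = {
#             "glyphs": torch.as_tensor(np.stack([o["glyphs"] for o in observations])),
#             "blstats": torch.as_tensor(np.stack([o["blstats"] for o in observations])),
#         }
#         return batch, rewards, dones, infos
#
#     def batched_step(self, observations, rewards, dones, infos):
#         batch, rewards, dones, infos = self.preprocess_observations(
#             observations, rewards, dones, infos)
#         with torch.no_grad():
#             logits = self.model(batch)  # expected shape: [num_envs, num_actions]
#         return torch.argmax(logits, dim=-1).cpu().numpy()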
if __name__ == '__main__':
def nethack_make_fn():
# These settings will be fixed by the aicrowd evaluator
env = aicrowd_gym.make('NetHackChallenge-v0',
observation_keys=("glyphs",
"chars",
"colors",
"specials",
"blstats",
"message",
"tty_chars",
"tty_colors",
"tty_cursor",))
# This wrapper will always be added on the aicrowd evaluator
env = EarlyTerminationNethack(env=env,
minimum_score=1000,
cutoff_timesteps=50000)
# Add any wrappers you need
return env
# Change num_envs as needed, for example reduce it if the model doesn't fit on your GPU,
# but increasing it above 32 is not advisable for the NetHack Challenge 2021
num_envs = 16
batched_env = BactchedEnv(env_make_fn=nethack_make_fn, num_envs=num_envs)
# This part can be left as is
observations = batched_env.batch_reset()
rewards = [0.0 for _ in range(num_envs)]
dones = [False for _ in range(num_envs)]
infos = [{} for _ in range(num_envs)]
# Change this to your agent interface
num_actions = batched_env.envs[0].action_space.n
agent = RandomBatchedAgent(num_envs, num_actions)
# The evaluation setup will automatically stop after the requisite number of rollouts
# But you can change this if you want
for _ in trange(1000000000000):
# Ideally this part can be left unchanged
actions = agent.batched_step(observations, rewards, dones, infos)
observations, rewards, dones, infos = batched_env.batch_step(actions)
for done_idx in np.where(dones)[0]:
observations[done_idx] = batched_env.single_env_reset(done_idx)
import gym
import nle
import numpy as np
from tqdm import trange
from collections.abc import Iterable
class BactchedEnv:
def __init__(self, env_make_fn, num_envs=32):
"""
Creates multiple copies of the environment using the same env_make_fn function
"""
self.num_envs = num_envs
self.envs = [env_make_fn() for _ in range(self.num_envs)]
# TODO: Can have different settings for each env? Probably not needed for Nethack
def batch_step(self, actions):
"""
Applies each action to each env in the same order as self.envs
Actions should be iterable and have the same length as self.envs
Returns lists of observations, rewards, dones, infos
"""
assert isinstance(
actions, Iterable), f"actions with type {type(actions)} is not iterable"
assert len(
actions) == self.num_envs, f"actions has length {len(actions)} which is different from num_envs"
observations, rewards, dones, infos = [], [], [], []
for env, a in zip(self.envs, actions):
observation, reward, done, info = env.step(a)
if done:
observation = env.reset()
observations.append(observation)
rewards.append(reward)
dones.append(done)
infos.append(info)
return observations, rewards, dones, infos
def batch_reset(self):
"""
Resets all the environments in self.envs
"""
observations = [env.reset() for env in self.envs]
return observations
def single_env_reset(self, index):
"""
Resets the env at the index location
"""
observation = self.envs[index].reset()
return observation
# TODO: Add helper functions to format to tf or torch batching
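# --- Hedged sketch for the TODO above (not part of the original file): a helper
# that converts the list of per-env observation dicts returned by
# batch_reset/batch_step into a single dict of stacked numpy arrays, which is
# the usual input format for a tf/torch model. Assumes all dicts share keys.
#
# def stack_observations(observations):
#     return {key: np.stack([obs[key] for obs in observations])
#             for key in observations[0]}
#
# Usage: batch = stack_observations(batched_env.batch_reset())
#        # batch["glyphs"].shape == (num_envs,) + single-env glyph shape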
if __name__ == '__main__':
def nethack_make_fn():
return gym.make('NetHackChallenge-v0',
observation_keys=("glyphs",
"chars",
"colors",
"specials",
"blstats",
"message",
"tty_chars",
"tty_colors",
"tty_cursor",))
num_envs = 16
batched_env = BactchedEnv(env_make_fn=nethack_make_fn, num_envs=num_envs)
observations = batched_env.batch_reset()
num_actions = batched_env.envs[0].action_space.n
for _ in trange(10000000000000):
actions = np.random.randint(num_actions, size=num_envs)
observations, rewards, dones, infos = batched_env.batch_step(actions)
for done_idx in np.where(dones)[0]:
observations[done_idx] = batched_env.single_env_reset(done_idx)
import gym
class EarlyTerminationNethack(gym.Wrapper):
"""
To limit the timesteps for "Beginner" agents, we terminate the episode early
if the minimum_score is not achieved within the cutoff_timesteps.
Participants should not edit this file.
"""
def __init__(self, env, minimum_score=1000, cutoff_timesteps=50000):
super().__init__(env)
self._minimum_score = minimum_score
self._cutoff_timesteps = cutoff_timesteps
self._elapsed_steps = None
self._score = None
def step(self, action):
assert self._elapsed_steps is not None, "Cannot call env.step() before calling reset()"
observation, reward, done, info = self.env.step(action)
self._elapsed_steps += 1
self._score += reward
if self._elapsed_steps > self._cutoff_timesteps and \
self._score < self._minimum_score:
info['Early Termination'] = not done
done = True
return observation, reward, done, info
def reset(self, **kwargs):
self._elapsed_steps = 0
self._score = 0
return self.env.reset(**kwargs)
#!/usr/bin/env python
import aicrowd_api
import os
########################################################################
# Instantiate Event Notifier
########################################################################
aicrowd_events = aicrowd_api.events.AIcrowdEvents()
def execution_start():
########################################################################
# Register Evaluation Start event
########################################################################
aicrowd_events.register_event(
event_type=aicrowd_events.AICROWD_EVENT_INFO,
message="execution_started",
payload={
"event_type": "airborne_detection:execution_started"
}
)
def execution_running():
########################################################################
# Register Evaluation Progress event
########################################################################
aicrowd_events.register_event(
event_type=aicrowd_events.AICROWD_EVENT_INFO,
message="execution_progress",
payload={
"event_type": "airborne_detection:execution_progress",
"progress": 0.0
}
)
def execution_progress(progress):
########################################################################
# Register Evaluation Progress event
########################################################################
aicrowd_events.register_event(
event_type=aicrowd_events.AICROWD_EVENT_INFO,
message="execution_progress",
payload={
"event_type": "airborne_detection:execution_progress",
"progress" : progress
}
)
def execution_success():
########################################################################
# Register Evaluation Complete event
########################################################################
predictions_output_path = os.getenv("PREDICTIONS_OUTPUT_PATH", False)
aicrowd_events.register_event(
event_type=aicrowd_events.AICROWD_EVENT_SUCCESS,
message="execution_success",
payload={
"event_type": "airborne_detection:execution_success",
"predictions_output_path" : predictions_output_path
},
blocking=True
)
def execution_error(error):
########################################################################
# Register Evaluation Complete event
########################################################################
aicrowd_events.register_event(
event_type=aicrowd_events.AICROWD_EVENT_ERROR,
message="execution_error",
payload={ #Arbitrary Payload
"event_type": "airborne_detection:execution_error",
"error" : error
},
blocking=True
)
def is_grading():
return os.getenv("AICROWD_IS_GRADING", False)
######################################################################################
### This is a read-only file to allow participants to run their code locally. ###
### It will be overwritten during the evaluation. Please do not make any changes ###
### to this file. ###
######################################################################################
import traceback
import os
import signal
from contextlib import contextmanager
from os import listdir
from os.path import isfile, join
import soundfile as sf
import numpy as np
from evaluator import aicrowd_helpers
class TimeoutException(Exception): pass
@contextmanager
def time_limit(seconds):
def signal_handler(signum, frame):
raise TimeoutException("Prediction timed out!")
signal.signal(signal.SIGALRM, signal_handler)
signal.alarm(seconds)
try:
yield
finally:
signal.alarm(0)
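# Hedged usage sketch: time_limit raises TimeoutException (via SIGALRM, so
# Unix-only and main-thread-only) if the wrapped block exceeds `seconds`.
# `run_model` is a hypothetical long-running call.
#
# try:
#     with time_limit(10):
#         run_model()
# except TimeoutException:
#     print("Prediction timed out!")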
class MusicDemixingPredictor:
def __init__(self):
self.test_data_path = os.getenv("TEST_DATASET_PATH", os.getcwd() + "/data/test/")
self.results_data_path = os.getenv("RESULTS_DATASET_PATH", os.getcwd() + "/data/results/")
self.inference_setup_timeout = int(os.getenv("INFERENCE_SETUP_TIMEOUT_SECONDS", "900"))
self.inference_per_music_timeout = int(os.getenv("INFERENCE_PER_MUSIC_TIMEOUT_SECONDS", "240"))
self.partial_run = os.getenv("PARTIAL_RUN_MUSIC_NAMES", None)
self.results = []
self.current_music_name = None
def get_all_music_names(self):
valid_music_names = None
if self.partial_run:
valid_music_names = self.partial_run.split(',')
music_names = []
for folder in listdir(self.test_data_path):
if not isfile(join(self.test_data_path, folder)):
if valid_music_names is None or folder in valid_music_names:
music_names.append(folder)
return music_names
def get_music_folder_location(self, music_name):
return join(self.test_data_path, music_name)
def get_music_file_location(self, music_name, instrument=None):
    # With no instrument specified, return the path of the input mixture;
    # otherwise create the results directories and return the output path
    # for that instrument.
    if instrument is None:
        instrument = "mixture"
        return join(self.test_data_path, music_name, instrument + ".wav")
    if not os.path.exists(self.results_data_path):
        os.makedirs(self.results_data_path)
    if not os.path.exists(join(self.results_data_path, music_name)):
        os.makedirs(join(self.results_data_path, music_name))
    return join(self.results_data_path, music_name, instrument + ".wav")
def scoring(self):
"""
Scoring function included in the starter kit for participants' reference
"""
def sdr(references, estimates):
# compute SDR for one song
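# i.e. SDR = 10 * log10( sum(reference^2) / sum((reference - estimate)^2) ),
# summed over time and channels for each source; delta avoids division by zero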
delta = 1e-7 # avoid numerical errors
num = np.sum(np.square(references), axis=(1, 2))
den = np.sum(np.square(references - estimates), axis=(1, 2))
num += delta
den += delta
return 10 * np.log10(num / den)
music_names = self.get_all_music_names()
instruments = ["bass", "drums", "other", "vocals"]
scores = {}
for music_name in music_names:
print("Evaluating for: %s" % music_name)
scores[music_name] = {}
references = []
estimates = []
for instrument in instruments:
reference_file = join(self.test_data_path, music_name, instrument + ".wav")
estimate_file = self.get_music_file_location(music_name, instrument)
reference, _ = sf.read(reference_file)
estimate, _ = sf.read(estimate_file)
references.append(reference)
estimates.append(estimate)
references = np.stack(references)
estimates = np.stack(estimates)
references = references.astype(np.float32)
estimates = estimates.astype(np.float32)
song_score = sdr(references, estimates).tolist()
scores[music_name]["sdr_bass"] = song_score[0]
scores[music_name]["sdr_drums"] = song_score[1]
scores[music_name]["sdr_other"] = song_score[2]
scores[music_name]["sdr_vocals"] = song_score[3]
scores[music_name]["sdr"] = np.mean(song_score)
return scores
def evaluation(self):
"""
Admin function: Runs the whole evaluation
"""
aicrowd_helpers.execution_start()
try:
with time_limit(self.inference_setup_timeout):
self.prediction_setup()
except NotImplementedError:
print("prediction_setup doesn't exist for this run, skipping...")
aicrowd_helpers.execution_running()
music_names = self.get_all_music_names()
for music_name in music_names:
with time_limit(self.inference_per_music_timeout):
self.prediction(music_name=music_name,
mixture_file_path=self.get_music_file_location(music_name),
bass_file_path=self.get_music_file_location(music_name, "bass"),
drums_file_path=self.get_music_file_location(music_name, "drums"),
other_file_path=self.get_music_file_location(music_name, "other"),
vocals_file_path=self.get_music_file_location(music_name, "vocals"),
)
if not self.verify_results(music_name):
raise Exception("verification failed, demixed files not found.")
aicrowd_helpers.execution_success()
def run(self):
try:
self.evaluation()
except Exception as e:
error = traceback.format_exc()
print(error)
aicrowd_helpers.execution_error(error)
if not aicrowd_helpers.is_grading():
raise e
def prediction_setup(self):
"""
You can do any preprocessing required for your codebase here :
like loading your models into memory, etc.
"""
raise NotImplementedError
def prediction(self, music_name, mixture_file_path, bass_file_path, drums_file_path, other_file_path,
vocals_file_path):
"""
This function will be called for each song during the evaluation.
NOTE: In case you want to load your model, please do so in the `prediction_setup` function.
"""
raise NotImplementedError
def verify_results(self, music_name):
"""
This function will be called to check that all the files exist and to do any other
verification needed (like the length of the wav files)
"""
valid = True
valid = valid and os.path.isfile(self.get_music_file_location(music_name, "vocals"))
valid = valid and os.path.isfile(self.get_music_file_location(music_name, "bass"))
valid = valid and os.path.isfile(self.get_music_file_location(music_name, "drums"))
valid = valid and os.path.isfile(self.get_music_file_location(music_name, "other"))
return valid
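# --- Hedged sketch (not part of this read-only file): how a submission would
# typically plug into MusicDemixingPredictor. `CopyMixturePredictor` is a
# hypothetical baseline that just writes the input mixture out as every stem.
#
# class CopyMixturePredictor(MusicDemixingPredictor):
#     def prediction_setup(self):
#         pass  # load models / weights here
#
#     def prediction(self, music_name, mixture_file_path, bass_file_path,
#                    drums_file_path, other_file_path, vocals_file_path):
#         mixture, rate = sf.read(mixture_file_path)
#         for path in (bass_file_path, drums_file_path,
#                      other_file_path, vocals_file_path):
#             sf.write(path, mixture, rate)
#
# if __name__ == "__main__":
#     CopyMixturePredictor().run()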
@@ -3,3 +3,4 @@ aicrowd-gym
numpy
scipy
git+https://github.com/facebookresearch/nle.git@eric/competition --no-binary=nle
tqdm
#!/bin/bash
python rollout.py
python agent.py