diff --git a/rollout.py b/rollout.py index 2176e5328e5508476b9c0df4ab8e4d10b5f61ec3..a9fc5c1bb7b8eae7795724a6920f05663718d1e3 100644 --- a/rollout.py +++ b/rollout.py @@ -1,17 +1,22 @@ #!/usr/bin/env python -############################################################ -## Ideally you shouldn't need to change this file at all ## -############################################################ - +################################################################ +## Ideally you shouldn't need to change this file at all ## +## ## +## This file generates the rollouts, with the specific agent, ## +## batch_size and wrappers specified in submission_config.py ## +################################################################ +from tqdm import tqdm import numpy as np from envs.batched_env import BactchedEnv from submission_config import SubmissionConfig +NUM_ASSESSMENTS = 512 + def run_batched_rollout(batched_env, agent): """ - This function will be called the rollout + This function will generate a series of rollouts in a batched manner. 
""" num_envs = batched_env.num_envs @@ -22,26 +27,36 @@ def run_batched_rollout(batched_env, agent): dones = [False for _ in range(num_envs)] infos = [{} for _ in range(num_envs)] + # We assign each environment a fixed number of episodes at the start + envs_each = NUM_ASSESSMENTS // num_envs + remainders = NUM_ASSESSMENTS % num_envs + episodes = [envs_each + int(i < remainders) for i in range(num_envs)] + episode_count = 0 + pbar = tqdm(total=NUM_ASSESSMENTS) # The evaluator will automatically stop after the episodes based on the development/test phase - while episode_count < 10000: + while episode_count < NUM_ASSESSMENTS: actions = agent.batched_step(observations, rewards, dones, infos) observations, rewards, dones, infos = batched_env.batch_step(actions) for done_idx in np.where(dones)[0]: observations[done_idx] = batched_env.single_env_reset(done_idx) - episode_count += 1 - print("Episodes Completed :", episode_count) + + if episodes[done_idx] > 0: + episodes[done_idx] -= 1 + episode_count += 1 + pbar.update(1) -if __name__ == "__main__": +if __name__ == "__main__": submission_env_make_fn = SubmissionConfig.submission_env_make_fn NUM_PARALLEL_ENVIRONMENTS = SubmissionConfig.NUM_PARALLEL_ENVIRONMENTS Agent = SubmissionConfig.Submision_Agent - batched_env = BactchedEnv(env_make_fn=submission_env_make_fn, - num_envs=NUM_PARALLEL_ENVIRONMENTS) + batched_env = BactchedEnv( + env_make_fn=submission_env_make_fn, num_envs=NUM_PARALLEL_ENVIRONMENTS + ) num_envs = batched_env.num_envs num_actions = batched_env.num_actions @@ -49,4 +64,3 @@ if __name__ == "__main__": agent = Agent(num_envs, num_actions) run_batched_rollout(batched_env, agent) - diff --git a/submission_config.py b/submission_config.py index f7d548260474eb7a780ed98f8b57c7a3e28c6f23..5038d32ce0c8ee1498a5ff30ad373e9f13f1ab3c 100644 --- a/submission_config.py +++ b/submission_config.py @@ -23,7 +23,7 @@ class SubmissionConfig: ## Change the NUM_PARALLEL_ENVIRONMENTS as you need ## for example reduce it if your 
GPU doesn't fit ## Increasing above 32 is not advisable for the Nethack Challenge 2021 - NUM_PARALLEL_ENVIRONMENTS = 16 + NUM_PARALLEL_ENVIRONMENTS = 32 ## Add a function that creates your nethack env diff --git a/submission_wrappers.py b/submission_wrappers.py index c35fa17e5eb82a90df80f5f01aaff2c6112c9a28..9f75d9de6baad97a888c7d462f6a7e6af6575625 100644 --- a/submission_wrappers.py +++ b/submission_wrappers.py @@ -8,5 +8,5 @@ def addtimelimitwrapper_fn(): Should return a gym env which wraps the nethack gym env """ env = nethack_make_fn() - env = TimeLimit(env, max_episode_steps=10_000_0000) + env = TimeLimit(env, max_episode_steps=10_000_000) return env \ No newline at end of file