Commit 9826305d authored by Eric Hambro's avatar Eric Hambro
Browse files

Modify rollout not to progress indefinitely. Format with black and

correct wrapper timeout length.
parent 1776b5f3
#!/usr/bin/env python
################################################################
##    Ideally you shouldn't need to change this file at all   ##
##                                                            ##
## This file generates the rollouts, with the specific agent, ##
## batch_size and wrappers specified in submission_config.py  ##
################################################################
from tqdm import tqdm
import numpy as np
from envs.batched_env import BactchedEnv
from submission_config import SubmissionConfig
NUM_ASSESSMENTS = 512
def run_batched_rollout(batched_env, agent):
    """Generate a fixed number of evaluation rollouts in a batched manner.

    Steps all environments in lockstep, resetting each one as it finishes,
    until NUM_ASSESSMENTS episodes have been completed in total. Each
    environment is assigned a fixed episode budget up front so the rollout
    terminates instead of progressing indefinitely.

    :param batched_env: a batched environment exposing ``num_envs``,
        ``batch_reset()``, ``batch_step(actions)`` and
        ``single_env_reset(idx)``.
    :param agent: an agent exposing
        ``batched_step(observations, rewards, dones, infos)``.
    """
    num_envs = batched_env.num_envs

    # NOTE(review): these four initializations were collapsed out of the
    # visible diff context — reconstructed from the loop's usage; confirm
    # against the original file.
    observations = batched_env.batch_reset()
    rewards = [0.0 for _ in range(num_envs)]
    dones = [False for _ in range(num_envs)]
    infos = [{} for _ in range(num_envs)]

    # We assign each environment a fixed number of episodes at the start;
    # the first `remainders` environments each take one extra episode so
    # the budgets sum exactly to NUM_ASSESSMENTS.
    envs_each = NUM_ASSESSMENTS // num_envs
    remainders = NUM_ASSESSMENTS % num_envs
    episodes = [envs_each + int(i < remainders) for i in range(num_envs)]

    episode_count = 0
    pbar = tqdm(total=NUM_ASSESSMENTS)

    # The evaluator will automatically stop after the episodes based on the
    # development/test phase; this bound guarantees local termination too.
    while episode_count < NUM_ASSESSMENTS:
        actions = agent.batched_step(observations, rewards, dones, infos)
        observations, rewards, dones, infos = batched_env.batch_step(actions)
        for done_idx in np.where(dones)[0]:
            observations[done_idx] = batched_env.single_env_reset(done_idx)

            # Only count the episode while this env still has budget left;
            # envs that exhausted their quota keep running but no longer
            # advance the assessment counter.
            if episodes[done_idx] > 0:
                episodes[done_idx] -= 1
                episode_count += 1
                pbar.update(1)
if __name__ == "__main__":
    # Pull the evaluation configuration: env factory, parallelism, and the
    # agent class (names come from submission_config.py, typos included —
    # they are the project's public identifiers).
    submission_env_make_fn = SubmissionConfig.submission_env_make_fn
    NUM_PARALLEL_ENVIRONMENTS = SubmissionConfig.NUM_PARALLEL_ENVIRONMENTS
    Agent = SubmissionConfig.Submision_Agent

    batched_env = BactchedEnv(
        env_make_fn=submission_env_make_fn, num_envs=NUM_PARALLEL_ENVIRONMENTS
    )

    num_envs = batched_env.num_envs
    num_actions = batched_env.num_actions

    agent = Agent(num_envs, num_actions)

    run_batched_rollout(batched_env, agent)
......@@ -23,7 +23,7 @@ class SubmissionConfig:
## Change the NUM_PARALLEL_ENVIRONMENTS as you need
## for example reduce it if your GPU doesn't fit
## Increasing above 32 is not advisable for the Nethack Challenge 2021
NUM_PARALLEL_ENVIRONMENTS = 16
NUM_PARALLEL_ENVIRONMENTS = 32
## Add a function that creates your nethack env
......
......@@ -8,5 +8,5 @@ def addtimelimitwrapper_fn():
Should return a gym env which wraps the nethack gym env
"""
env = nethack_make_fn()
env = TimeLimit(env, max_episode_steps=10_000_0000)
env = TimeLimit(env, max_episode_steps=10_000_000)
return env
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment