Commit 88c416dc authored by Eric Hambro

Fix up test_submission script.

parent 29a1edac
@@ -42,13 +42,6 @@ class BatchedEnv:
         observation = [env.reset() for env in self.envs]
         return observation

-    def single_env_reset(self, index):
-        """
-        Resets the env at the index location
-        """
-        observation = self.envs[index].reset()
-        return observation

 if __name__ == '__main__':
 ...
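For context, the rest of the commit relies only on a narrow BatchedEnv surface: construction via env_make_fn and num_envs, the num_envs and num_actions attributes, and batch_step. A minimal sketch of that interface, assuming gym-style sub-environments; this is an illustration of the contract, not the repository's actual envs/batched_env.py:

# Sketch of the BatchedEnv interface this commit assumes (illustrative only).
class BatchedEnv:
    def __init__(self, env_make_fn, num_envs=32):
        self.envs = [env_make_fn() for _ in range(num_envs)]
        self.num_envs = num_envs
        # Assumes gym-style discrete action spaces on the sub-envs.
        self.num_actions = self.envs[0].action_space.n

    def batch_step(self, actions):
        """Step every env once; reset any env whose episode just ended."""
        observations, rewards, dones, infos = [], [], [], []
        for env, action in zip(self.envs, actions):
            obs, reward, done, info = env.step(action)
            if done:
                obs = env.reset()
            observations.append(obs)
            rewards.append(reward)
            dones.append(done)
            infos.append(info)
        return observations, rewards, dones, infos

    def batch_reset(self):
        # Matches the context line in the hunk above.
        return [env.reset() for env in self.envs]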
-## This file is intended to emulate the evaluation on AIcrowd
-# IMPORTANT - Differences to expect
-# * All the environment's functions are not available
-# * The run might be slower than your local run
-# * Resources might vary from your local machine
-from submission_agent import SubmissionConfig, LocalEvaluationConfig
-from rollout import run_batched_rollout
-from nethack_baselines.utils.batched_env import BatchedEnv
-
-
-# Ideally you shouldn't need to change anything below
-def add_evaluation_wrappers_fn(env_make_fn):
-    max_episodes = LocalEvaluationConfig.LOCAL_EVALUATION_NUM_EPISODES
-    # TOOD: use LOCAL_EVALUATION_NUM_EPISODES for limiting episodes
-    return env_make_fn
-
-
-def evaluate():
-    submission_env_make_fn = SubmissionConfig.submission_env_make_fn
-    num_envs = SubmissionConfig.NUM_PARALLEL_ENVIRONMENTS
-    Agent = SubmissionConfig.Submision_Agent
-
-    evaluation_env_fn = add_evaluation_wrappers_fn(submission_env_make_fn)
-    batched_env = BatchedEnv(env_make_fn=evaluation_env_fn,
-                             num_envs=num_envs)
-    num_envs = batched_env.num_envs
-    num_actions = batched_env.num_actions
-
-    agent = Agent(num_envs, num_actions)
-
-    run_batched_rollout(batched_env, agent)
-
-
-if __name__ == '__main__':
-    evaluate()
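The script above is deleted outright: its unfinished add_evaluation_wrappers_fn hook tried to cap episodes by wrapping the env, and the commit replaces that with an explicit episode budget in the rollout signature (see the rollout.py hunks below). In brief:

# Before this commit (episode budget hard-coded in rollout.py):
#     all_returns = run_batched_rollout(batched_env, agent)
# After this commit (the caller passes the budget explicitly):
#     all_returns = run_batched_rollout(num_episodes, batched_env, agent)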
@@ -10,11 +10,12 @@ from tqdm import tqdm
 import numpy as np

 from envs.batched_env import BatchedEnv
-from envs.wrappers import create_env
 from submission_config import SubmissionConfig

-NUM_ASSESSMENTS = 512

-def run_batched_rollout(batched_env, agent):
+def run_batched_rollout(num_episodes, batched_env, agent):
     """
     This function will generate a series of rollouts in a batched manner.
     """
@@ -28,16 +29,16 @@ def run_batched_rollout(batched_env, agent):
     infos = [{} for _ in range(num_envs)]

     # We mark at the start of each episode if we are 'counting it'
-    active_envs = [i < NUM_ASSESSMENTS for i in range(num_envs)]
-    num_remaining = NUM_ASSESSMENTS - sum(active_envs)
+    active_envs = [i < num_episodes for i in range(num_envs)]
+    num_remaining = num_episodes - sum(active_envs)

     episode_count = 0
-    pbar = tqdm(total=NUM_ASSESSMENTS)
+    pbar = tqdm(total=num_episodes)

     all_returns = []
     returns = [0.0 for _ in range(num_envs)]

     # The evaluator will automatically stop after the episodes based on the development/test phase
-    while episode_count < NUM_ASSESSMENTS:
+    while episode_count < num_episodes:
         actions = agent.batched_step(observations, rewards, dones, infos)
         observations, rewards, dones, infos = batched_env.batch_step(actions)
@@ -57,20 +58,19 @@ def run_batched_rollout(batched_env, agent):
             pbar.update(1)
         returns[done_idx] = 0.0
+    pbar.close()
     return all_returns


 if __name__ == "__main__":
-    submission_env_make_fn = SubmissionConfig.submission_env_make_fn
-    NUM_PARALLEL_ENVIRONMENTS = SubmissionConfig.NUM_PARALLEL_ENVIRONMENTS
-    Agent = SubmissionConfig.Submision_Agent
-
-    batched_env = BatchedEnv(
-        env_make_fn=submission_env_make_fn, num_envs=NUM_PARALLEL_ENVIRONMENTS
-    )
-    num_envs = batched_env.num_envs
-    num_actions = batched_env.num_actions
-
-    agent = Agent(num_envs, num_actions)
-
-    run_batched_rollout(batched_env, agent)
+    # AIcrowd will cut the assessment early during the dev phase
+    NUM_ASSESSMENTS = 4096
+
+    env_make_fn = SubmissionConfig.MAKE_ENV_FN
+    num_envs = SubmissionConfig.NUM_ENVIRONMENTS
+    Agent = SubmissionConfig.AGENT
+
+    batched_env = BatchedEnv(env_make_fn=env_make_fn, num_envs=num_envs)
+    agent = Agent(num_envs, batched_env.num_actions)
+
+    run_batched_rollout(NUM_ASSESSMENTS, batched_env, agent)
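The rollout asks only two things of an agent: a constructor taking (num_envs, num_actions) and a batched_step(observations, rewards, dones, infos) method returning one action per environment. A minimal sketch of that contract, in the spirit of the RandomAgent referenced in submission_config.py below; the uniform-random policy and standalone class are assumptions, not the repository's implementation:

import numpy as np

class RandomAgent:
    """Smallest agent satisfying the contract run_batched_rollout expects."""

    def __init__(self, num_envs, num_actions):
        self.num_envs = num_envs
        self.num_actions = num_actions

    def batched_step(self, observations, rewards, dones, infos):
        # One action per environment, sampled uniformly at random.
        return np.random.randint(self.num_actions, size=self.num_envs)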
@@ -15,26 +15,26 @@ from envs.wrappers import addtimelimitwrapper_fn
 class SubmissionConfig:
     ## Add your own agent class
-    Submision_Agent = TorchBeastAgent
-    # Submision_Agent = RLlibAgent
-    # Submision_Agent = RandomAgent
+    AGENT = TorchBeastAgent
+    # AGENT = RLlibAgent
+    # AGENT = RandomAgent

-    ## Change the NUM_PARALLEL_ENVIRONMENTS as you need
+    ## Change the NUM_ENVIRONMENTS as you need
     ## for example reduce it if your GPU doesn't fit
     ## Increasing above 32 is not advisable for the Nethack Challenge 2021
-    NUM_PARALLEL_ENVIRONMENTS = 32
+    NUM_ENVIRONMENTS = 32

     ## Add a function that creates your nethack env
     ## Mainly this is to add wrappers
     ## Add your wrappers to envs/wrappers.py and change the name here
     ## IMPORTANT: Don't "call" the function, only provide the name
-    submission_env_make_fn = addtimelimitwrapper_fn
+    MAKE_ENV_FN = addtimelimitwrapper_fn


-class LocalEvaluationConfig:
+class TestEvaluationConfig:
     # Change this to locally check a different number of rollouts
     # The AIcrowd submission evaluator will not use this
     # It is only for your local evaluation
-    LOCAL_EVALUATION_NUM_EPISODES = 50
+    NUM_EPISODES = 64
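MAKE_ENV_FN must be a zero-argument callable that returns a fully wrapped environment, assigned uncalled (as the "Don't call the function" comment stresses). A sketch of what a factory like addtimelimitwrapper_fn plausibly looks like; the real one lives in envs/wrappers.py, and the NetHackChallenge-v0 env id and 10,000-step cap here are assumptions:

import gym
import nle  # noqa: F401 -- importing nle registers the NetHack envs with gym (assumed dependency)
from gym.wrappers import TimeLimit


def addtimelimitwrapper_fn():
    # Build the base NetHack challenge env, then cap episode length with a wrapper.
    env = gym.make("NetHackChallenge-v0")
    return TimeLimit(env, max_episode_steps=10_000)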
+## This file is intended to emulate the evaluation on AIcrowd
+# IMPORTANT - Differences to expect
+# * All the environment's functions are not available
+# * The run might be slower than your local run
+# * Resources might vary from your local machine
+import numpy as np
+
+from agents.batched_agent import BatchedAgent
+from submission_config import SubmissionConfig, TestEvaluationConfig
+from rollout import run_batched_rollout
+from envs.batched_env import BatchedEnv
+
+
+def evaluate():
+    env_make_fn = SubmissionConfig.MAKE_ENV_FN
+    num_envs = SubmissionConfig.NUM_ENVIRONMENTS
+    Agent = SubmissionConfig.AGENT
+    num_episodes = TestEvaluationConfig.NUM_EPISODES
+
+    batched_env = BatchedEnv(env_make_fn=env_make_fn, num_envs=num_envs)
+    agent = Agent(num_envs, batched_env.num_actions)
+
+    scores = run_batched_rollout(num_episodes, batched_env, agent)
+    print(f"Median Score: {np.median(scores)}, Mean Score: {np.mean(scores)}")
+
+
+if __name__ == "__main__":
+    evaluate()
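With submission_config.py filled in, the local check is just running this new script; assuming it is saved as test_submission.py at the repository root:

python test_submission.py

Note that the AIcrowd evaluator ignores TestEvaluationConfig entirely (per the comments in the config above) and applies its own episode count.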