Commit 29a1edac authored by Dipam Chakraborty
Merge branch 'eric/envs-cleanup' into 'master'

Eric/envs cleanup

See merge request dipam/neurips-2021-nethack-challenge!3
parents d1ac97bd d7b8c597
 import aicrowd_gym
 import numpy as np
-from tqdm import trange
 from collections.abc import Iterable

-class BactchedEnv:
+class BatchedEnv:
     def __init__(self, env_make_fn, num_envs=32):
         """
         Creates multiple copies of the environment with the same env_make_fn function
@@ -52,23 +52,15 @@ class BactchedEnv:
 if __name__ == '__main__':
-    def nethack_make_fn():
-        return aicrowd_gym.make('NetHackChallenge-v0',
-                                observation_keys=("glyphs",
-                                                  "chars",
-                                                  "colors",
-                                                  "specials",
-                                                  "blstats",
-                                                  "message",
-                                                  "tty_chars",
-                                                  "tty_colors",
-                                                  "tty_cursor",))
     num_envs = 4
-    batched_env = BactchedEnv(env_make_fn=nethack_make_fn, num_envs=num_envs)
+    batched_env = BatchedEnv(
+        env_make_fn=lambda: aicrowd_gym.make('NetHackChallenge-v0'),
+        num_envs=4
+    )
     observations = batched_env.batch_reset()
     num_actions = batched_env.envs[0].action_space.n
-    for _ in trange(10000000000000):
+    for _ in range(50):
         actions = np.random.randint(num_actions, size=num_envs)
         observations, rewards, dones, infos = batched_env.batch_step(actions)
         for done_idx in np.where(dones)[0]:
...
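The collapsed region between the two hunks above holds the body of BatchedEnv. A rough sketch of what those methods presumably look like, inferred from the call sites in this diff (batch_reset, batch_step, single_env_reset) and from the NetHackChallengeBatchedEnv implementation further down; this is an assumption, not the repository's actual code:

# Sketch only: inferred shape of the elided BatchedEnv methods.
class BatchedEnv:
    def __init__(self, env_make_fn, num_envs=32):
        # One independent env per slot, all built by the same factory.
        self.num_envs = num_envs
        self.envs = [env_make_fn() for _ in range(num_envs)]

    def batch_reset(self):
        # Reset every env; return the list of initial observations.
        return [env.reset() for env in self.envs]

    def batch_step(self, actions):
        # Step each env with its own action, auto-resetting finished envs
        # so every returned observation belongs to a live episode.
        observations, rewards, dones, infos = [], [], [], []
        for env, action in zip(self.envs, actions):
            observation, reward, done, info = env.step(action)
            if done:
                observation = env.reset()
            observations.append(observation)
            rewards.append(reward)
            dones.append(done)
            infos.append(info)
        return observations, rewards, dones, infos

    def single_env_reset(self, index):
        # Reset just one env (used by rollout code that manages resets itself).
        return self.envs[index].reset()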
envs/nethack_make_function.py (removed by this merge; its import is dropped below and create_env() takes its place):

-import aicrowd_gym
-import nle
-
-def nethack_make_fn():
-    return aicrowd_gym.make('NetHackChallenge-v0',
-                            observation_keys=("glyphs",
-                                              "chars",
-                                              "colors",
-                                              "specials",
-                                              "blstats",
-                                              "message",
-                                              "tty_chars",
-                                              "tty_colors",
-                                              "tty_cursor",))
\ No newline at end of file
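For context: NLE's observation_keys argument selects which entries appear in the observation dict that reset() and step() return. A minimal sketch of reading two of the keys requested above (standard NLE behaviour, not code from this repo):

obs = nethack_make_fn().reset()
print(obs["tty_chars"].shape)  # the 24x80 terminal screen, as character codes
print(bytes(obs["message"]))   # the current top-line game message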
import numpy as np
from tqdm import trange
from collections.abc import Iterable

from envs.nethack_make_function import nethack_make_fn


class NetHackChallengeBatchedEnv:
    def __init__(self, env_make_fn, num_envs=1):
        """
        Creates multiple copies of the NetHackChallenge environment
        """
        self.num_envs = num_envs
        self.envs = [env_make_fn() for _ in range(self.num_envs)]
        self.action_space = self.envs[0].action_space
        self.observation_space = self.envs[0].observation_space
        self.reward_range = self.envs[0].reward_range

    def step(self, actions):
        """
        Applies each action to each env in the same order as self.envs
        Actions should be iterable and have the same length as self.envs
        Returns lists of observations, rewards, dones, infos
        """
        assert isinstance(actions, Iterable), \
            f"actions with type {type(actions)} is not iterable"
        assert len(actions) == self.num_envs, \
            f"actions has length {len(actions)}, which differs from num_envs"
        observations, rewards, dones, infos = [], [], [], []
        for env, a in zip(self.envs, actions):
            observation, reward, done, info = env.step(a)
            if done:
                observation = env.reset()
            observations.append(observation)
            rewards.append(reward)
            dones.append(done)
            infos.append(info)
        return observations, rewards, dones, infos

    def reset(self):
        """
        Resets all the environments in self.envs
        """
        observations = [env.reset() for env in self.envs]
        return observations

    def single_env_reset(self, index):
        """
        Resets the env at the index location
        """
        observation = self.envs[index].reset()
        return observation

    def single_env_step(self, index, action):
        """
        Steps the env at the index location
        """
        observation, reward, done, info = self.envs[index].step(action)
        return observation, reward, done, info


if __name__ == '__main__':
    num_envs = 4
    batched_env = NetHackChallengeBatchedEnv(env_make_fn=nethack_make_fn, num_envs=num_envs)
    observations = batched_env.reset()
    num_actions = batched_env.action_space.n
    for _ in trange(10000000000000):
        actions = np.random.randint(num_actions, size=num_envs)
        observations, rewards, dones, infos = batched_env.step(actions)
        for done_idx in np.where(dones)[0]:
            observations[done_idx] = batched_env.single_env_reset(done_idx)
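Note that step() above already resets a finished env internally, so the demo loop's single_env_reset call replaces an observation that is already a fresh reset. This appears to be the redundancy the merge cleans up: the rollout.py hunk below removes exactly that kind of extra reset.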
+import aicrowd_gym
+import nle
 from gym.wrappers import TimeLimit
-from envs.nethack_make_function import nethack_make_fn

+def create_env():
+    """This is the environment that will be assessed by AIcrowd."""
+    return aicrowd_gym.make("NetHackChallenge-v0")
+
 def addtimelimitwrapper_fn():
     """
     An example of how to add wrappers to the nethack_make_fn
     Should return a gym env which wraps the nethack gym env
     """
-    env = nethack_make_fn()
+    env = create_env()
     env = TimeLimit(env, max_episode_steps=10_000_000)
     return env
\ No newline at end of file
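addtimelimitwrapper_fn shows the intended pattern for submissions: build the base env, wrap it, return it. As a hypothetical extension (ClipRewardWrapper and addclippedwrappers_fn are invented names for illustration; only create_env and TimeLimit come from the file above):

import gym
from gym.wrappers import TimeLimit

class ClipRewardWrapper(gym.RewardWrapper):
    """Clips each step's reward into [-1, 1]."""
    def reward(self, reward):
        return max(-1.0, min(1.0, reward))

def addclippedwrappers_fn():
    # Same create -> wrap -> return pattern as addtimelimitwrapper_fn.
    env = create_env()
    env = TimeLimit(env, max_episode_steps=10_000_000)
    env = ClipRewardWrapper(env)
    return env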
@@ -8,7 +8,7 @@
 from submission_agent import SubmissionConfig, LocalEvaluationConfig
 from rollout import run_batched_rollout
-from nethack_baselines.utils.batched_env import BactchedEnv
+from nethack_baselines.utils.batched_env import BatchedEnv

 # Ideally you shouldn't need to change anything below
@@ -23,7 +23,7 @@ def evaluate():
     Agent = SubmissionConfig.Submision_Agent
     evaluation_env_fn = add_evaluation_wrappers_fn(submission_env_make_fn)
-    batched_env = BactchedEnv(env_make_fn=evaluation_env_fn,
+    batched_env = BatchedEnv(env_make_fn=evaluation_env_fn,
                               num_envs=num_envs)
     num_envs = batched_env.num_envs
...
@@ -9,7 +9,7 @@
 from tqdm import tqdm
 import numpy as np

-from envs.batched_env import BactchedEnv
+from envs.batched_env import BatchedEnv
 from submission_config import SubmissionConfig

 NUM_ASSESSMENTS = 512
@@ -46,8 +46,6 @@ def run_batched_rollout(batched_env, agent):
             returns[i] += r
         for done_idx in np.where(dones)[0]:
-            observations[done_idx] = batched_env.single_env_reset(done_idx)
-
             if active_envs[done_idx]:
                 # We were 'counting' this episode
                 all_returns.append(returns[done_idx])
@@ -66,7 +64,7 @@ if __name__ == "__main__":
     NUM_PARALLEL_ENVIRONMENTS = SubmissionConfig.NUM_PARALLEL_ENVIRONMENTS
     Agent = SubmissionConfig.Submision_Agent

-    batched_env = BactchedEnv(
+    batched_env = BatchedEnv(
         env_make_fn=submission_env_make_fn, num_envs=NUM_PARALLEL_ENVIRONMENTS
     )
...
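The dropped single_env_reset line is consistent with BatchedEnv resetting finished envs inside batch_step: the observation returned for a done env is already the first observation of a new episode, so resetting again from the rollout would discard it.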
@@ -2,7 +2,7 @@ from agents.random_batched_agent import RandomAgent
 from agents.torchbeast_agent import TorchBeastAgent
 # from agents.rllib_batched_agent import RLlibAgent
-from submission_wrappers import addtimelimitwrapper_fn
+from envs.wrappers import addtimelimitwrapper_fn

 ################################################
 # Import your own agent code                  #
@@ -28,7 +28,7 @@ class SubmissionConfig:
     ## Add a function that creates your nethack env
     ## Mainly this is to add wrappers
-    ## Add your wrappers to wrappers.py and change the name here
+    ## Add your wrappers to envs/wrappers.py and change the name here
     ## IMPORTANT: Don't "call" the function, only provide the name
     submission_env_make_fn = addtimelimitwrapper_fn
...
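The "don't call the function" comment matters because the config stores a factory: the batched env invokes it once per parallel environment. A minimal illustration:

submission_env_make_fn = addtimelimitwrapper_fn    # correct: a callable; one env is built per invocation
# submission_env_make_fn = addtimelimitwrapper_fn()  # wrong: a single, already-built env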