Commit 29a1edac authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

Merge branch 'eric/envs-cleanup' into 'master'

Eric/envs cleanup

See merge request dipam/neurips-2021-nethack-challenge!3
parents d1ac97bd d7b8c597
import aicrowd_gym
import numpy as np
from tqdm import trange
from collections.abc import Iterable
class BactchedEnv:
class BatchedEnv:
def __init__(self, env_make_fn, num_envs=32):
"""
Creates multiple copies of the environment with the same env_make_fn function
......@@ -52,23 +52,15 @@ class BactchedEnv:
if __name__ == '__main__':
def nethack_make_fn():
return aicrowd_gym.make('NetHackChallenge-v0',
observation_keys=("glyphs",
"chars",
"colors",
"specials",
"blstats",
"message",
"tty_chars",
"tty_colors",
"tty_cursor",))
num_envs = 4
batched_env = BactchedEnv(env_make_fn=nethack_make_fn, num_envs=num_envs)
batched_env = BatchedEnv(
env_make_fn=lambda:aicrowd_gym.make('NetHackChallenge-v0'),
num_envs=4
)
observations = batched_env.batch_reset()
num_actions = batched_env.envs[0].action_space.n
for _ in trange(10000000000000):
for _ in range(50):
actions = np.random.randint(num_actions, size=num_envs)
observations, rewards, dones, infos = batched_env.batch_step(actions)
for done_idx in np.where(dones)[0]:
......
import aicrowd_gym
import nle
def nethack_make_fn():
    """Build a NetHackChallenge-v0 env exposing the full standard observation set.

    Returns a fresh gym env created through aicrowd_gym so that evaluation
    instrumentation is attached.
    """
    # Keys requested explicitly so both the symbolic (glyphs/blstats) and the
    # terminal-rendering (tty_*) observation channels are available.
    observation_keys = (
        "glyphs",
        "chars",
        "colors",
        "specials",
        "blstats",
        "message",
        "tty_chars",
        "tty_colors",
        "tty_cursor",
    )
    return aicrowd_gym.make('NetHackChallenge-v0', observation_keys=observation_keys)
\ No newline at end of file
import numpy as np
from tqdm import trange
from collections.abc import Iterable
from envs.nethack_make_function import nethack_make_fn
class NetHackChallengeBatchedEnv:
    """Runs several copies of an env in lockstep, exposing list-based step/reset."""

    def __init__(self, env_make_fn, num_envs=1):
        """
        Creates multiple copies of the NetHackChallenge environment.

        Args:
            env_make_fn: zero-argument callable that returns a fresh env.
            num_envs: number of parallel env copies to create (default 1).
        """
        self.num_envs = num_envs
        self.envs = [env_make_fn() for _ in range(self.num_envs)]
        # Mirror the spaces/range of the first copy; all envs produced by
        # env_make_fn are assumed identical.
        self.action_space = self.envs[0].action_space
        self.observation_space = self.envs[0].observation_space
        self.reward_range = self.envs[0].reward_range

    def step(self, actions):
        """
        Applies each action to each env in the same order as self.envs.
        Actions should be iterable and have the same length as self.envs.
        Returns lists of observations, rewards, dones, infos.
        """
        assert isinstance(
            actions, Iterable), f"actions with type {type(actions)} is not iterable"
        assert len(
            actions) == self.num_envs, f"actions has length {len(actions)} which is different from num_envs"
        observations, rewards, dones, infos = [], [], [], []
        for env, a in zip(self.envs, actions):
            observation, reward, done, info = env.step(a)
            if done:
                # Auto-reset finished episodes so the returned observation is
                # always a valid starting state for the next call to step().
                observation = env.reset()
            observations.append(observation)
            rewards.append(reward)
            dones.append(done)
            infos.append(info)
        return observations, rewards, dones, infos

    def reset(self):
        """
        Resets all the environments in self.envs and returns the list of
        initial observations, in the same order as self.envs.
        """
        observations = [env.reset() for env in self.envs]
        return observations

    def single_env_reset(self, index):
        """
        Resets only the env at the index location and returns its initial
        observation.
        """
        observation = self.envs[index].reset()
        return observation

    def single_env_step(self, index, action):
        """
        Steps only the env at the index location with the given action.
        Unlike step(), this does NOT auto-reset the env when done is True.
        Returns (observation, reward, done, info).
        """
        # Fixed docstring: it was a copy-paste of single_env_reset's.
        observation, reward, done, info = self.envs[index].step(action)
        return observation, reward, done, info
if __name__ == '__main__':
    # Smoke test: drive a small batch of envs with a uniform-random policy.
    n_envs = 4
    batched_env = NetHackChallengeBatchedEnv(
        env_make_fn=nethack_make_fn,
        num_envs=n_envs,
    )
    observations = batched_env.reset()
    n_actions = batched_env.action_space.n
    # Effectively runs forever; trange shows throughput while it does.
    for _ in trange(10000000000000):
        random_actions = np.random.randint(n_actions, size=n_envs)
        observations, rewards, dones, infos = batched_env.step(random_actions)
        # Restart each finished env and splice its fresh observation back in.
        for finished_idx in np.where(dones)[0]:
            observations[finished_idx] = batched_env.single_env_reset(finished_idx)
import aicrowd_gym
import nle
from gym.wrappers import TimeLimit
from envs.nethack_make_function import nethack_make_fn
def create_env():
    """This is the environment that will be assessed by AIcrowd."""
    # Created through aicrowd_gym (not plain gym) so evaluation hooks apply.
    env = aicrowd_gym.make("NetHackChallenge-v0")
    return env
def addtimelimitwrapper_fn():
    """
    An example of how to add wrappers to the base challenge env.
    Should return a gym env which wraps the nethack gym env.
    """
    # Build the base env exactly once. The previous version also called
    # nethack_make_fn() first and immediately discarded that env, creating a
    # second NetHack process for nothing — removed.
    env = create_env()
    env = TimeLimit(env, max_episode_steps=10_000_000)
    return env
\ No newline at end of file
......@@ -8,7 +8,7 @@
from submission_agent import SubmissionConfig, LocalEvaluationConfig
from rollout import run_batched_rollout
from nethack_baselines.utils.batched_env import BactchedEnv
from nethack_baselines.utils.batched_env import BatchedEnv
# Ideally you shouldn't need to change anything below
......@@ -23,7 +23,7 @@ def evaluate():
Agent = SubmissionConfig.Submision_Agent
evaluation_env_fn = add_evaluation_wrappers_fn(submission_env_make_fn)
batched_env = BactchedEnv(env_make_fn=evaluation_env_fn,
batched_env = BatchedEnv(env_make_fn=evaluation_env_fn,
num_envs=num_envs)
num_envs = batched_env.num_envs
......
......@@ -9,7 +9,7 @@
from tqdm import tqdm
import numpy as np
from envs.batched_env import BactchedEnv
from envs.batched_env import BatchedEnv
from submission_config import SubmissionConfig
NUM_ASSESSMENTS = 512
......@@ -46,8 +46,6 @@ def run_batched_rollout(batched_env, agent):
returns[i] += r
for done_idx in np.where(dones)[0]:
observations[done_idx] = batched_env.single_env_reset(done_idx)
if active_envs[done_idx]:
# We were 'counting' this episode
all_returns.append(returns[done_idx])
......@@ -66,7 +64,7 @@ if __name__ == "__main__":
NUM_PARALLEL_ENVIRONMENTS = SubmissionConfig.NUM_PARALLEL_ENVIRONMENTS
Agent = SubmissionConfig.Submision_Agent
batched_env = BactchedEnv(
batched_env = BatchedEnv(
env_make_fn=submission_env_make_fn, num_envs=NUM_PARALLEL_ENVIRONMENTS
)
......
......@@ -2,7 +2,7 @@ from agents.random_batched_agent import RandomAgent
from agents.torchbeast_agent import TorchBeastAgent
# from agents.rllib_batched_agent import RLlibAgent
from submission_wrappers import addtimelimitwrapper_fn
from envs.wrappers import addtimelimitwrapper_fn
################################################
# Import your own agent code #
......@@ -28,7 +28,7 @@ class SubmissionConfig:
## Add a function that creates your nethack env
## Mainly this is to add wrappers
## Add your wrappers to wrappers.py and change the name here
## Add your wrappers to envs/wrappers.py and change the name here
## IMPORTANT: Don't "call" the function, only provide the name
submission_env_make_fn = addtimelimitwrapper_fn
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment