Commit c1a9fc84 authored by Dipam Chakraborty
Browse files

refactor names and folder

parent 8a96b358
......@@ -131,4 +131,4 @@ dmypy.json
.pyre/
nle_data/
test_batched_env.py
import numpy as np
from nethack_baselines.utils.batched_agent import BatchedAgent
from agents.batched_agent import BatchedAgent
class RandomAgent(BatchedAgent):
def __init__(self, num_envs, num_actions):
......
placeholders
\ No newline at end of file
from gym.envs.registration import register
# Register the batched environment under its own id.
# gym resolves entry points written as "module:attribute", so the module
# and the class name must be separated by a colon rather than a dot.
register('NetHackChallengeBatched-v0',
         entry_point='nle_batched_env:NetHackChallengeBatchedEnv')
import gym
import aicrowd_gym
import numpy as np
from tqdm import trange
from collections.abc import Iterable
......@@ -11,7 +11,6 @@ class BactchedEnv:
self.num_envs = num_envs
self.envs = [env_make_fn() for _ in range(self.num_envs)]
self.num_actions = self.envs[0].action_space.n
# TODO: Can have different settings for each env? Probably not needed for Nethack
def batch_step(self, actions):
"""
......@@ -51,12 +50,10 @@ class BactchedEnv:
return observation
# TODO: Add helper functions to format to tf or torch batching
if __name__ == '__main__':
def nethack_make_fn():
return gym.make('NetHackChallenge-v0',
return aicrowd_gym.make('NetHackChallenge-v0',
observation_keys=("glyphs",
"chars",
"colors",
......
import nle
# For your local evaluation, aicrowd_gym is completely identical to gym
import aicrowd_gym
import nle
def nethack_make_fn():
# These settings will be fixed by the AIcrowd evaluator
# This allows us to limit the features of the environment
# that we don't want participants to use during the submission
return aicrowd_gym.make('NetHackChallenge-v0',
observation_keys=("glyphs",
"chars",
......
import numpy as np
from tqdm import trange
from collections.abc import Iterable
from envs.nethack_make_function import nethack_make_fn
class NetHackChallengeBatchedEnv:
    """
    Holds several copies of an environment and steps/resets them together.

    The copies are stored in self.envs; every batched result is a list that
    follows the same order as self.envs.
    """

    def __init__(self, env_make_fn, num_envs=1):
        """
        Creates multiple copies of the NetHackChallenge environment

        Args:
            env_make_fn: zero-argument callable that returns a new environment
            num_envs: number of environment copies to create
        """
        self.num_envs = num_envs
        self.envs = [env_make_fn() for _ in range(self.num_envs)]
        # The spaces and reward range are read from the first copy only
        self.action_space = self.envs[0].action_space
        self.observation_space = self.envs[0].observation_space
        self.reward_range = self.envs[0].reward_range

    def step(self, actions):
        """
        Applies each action to each env in the same order as self.envs
        Actions should be iterable and have the same length as self.envs
        Returns lists of observations, rewards, dones, infos

        When an env reports done, it is reset immediately and the observation
        returned for it is the first observation of the new episode; the
        reward, done flag and info still come from the finishing step.
        """
        assert isinstance(
            actions, Iterable), f"actions with type {type(actions)} is not iterable"
        # Iterators and generators are Iterable but have no len(); materialize
        # them so the length check below works for every iterable.
        if not hasattr(actions, "__len__"):
            actions = list(actions)
        assert len(
            actions) == self.num_envs, f"actions has length {len(actions)} which different from num_envs"

        observations, rewards, dones, infos = [], [], [], []
        for env, a in zip(self.envs, actions):
            observation, reward, done, info = env.step(a)
            if done:
                # Start a new episode right away so the env is never stepped
                # after it has finished
                observation = env.reset()
            observations.append(observation)
            rewards.append(reward)
            dones.append(done)
            infos.append(info)
        return observations, rewards, dones, infos

    def reset(self):
        """
        Resets all the environments in self.envs
        Returns the list of initial observations
        """
        observations = [env.reset() for env in self.envs]
        return observations

    def single_env_reset(self, index):
        """
        Resets the env at the index location
        Returns its initial observation
        """
        observation = self.envs[index].reset()
        return observation

    def single_env_step(self, index, action):
        """
        Steps the env at the index location with the given action
        Returns observation, reward, done, info for that env only

        Unlike step(), this does not reset the env when it reports done.
        """
        observation, reward, done, info = self.envs[index].step(action)
        return observation, reward, done, info
if __name__ == '__main__':
    # Demo loop: drive a few environment copies with uniformly random actions.
    num_envs = 4
    batched_env = NetHackChallengeBatchedEnv(env_make_fn=nethack_make_fn, num_envs=num_envs)
    observations = batched_env.reset()
    num_actions = batched_env.action_space.n
    for _ in trange(10000000000000):
        actions = np.random.randint(num_actions, size=num_envs)
        # step() already resets any env that reports done and returns the
        # first observation of its new episode, so no extra reset is needed
        # here; resetting again would throw that fresh episode away.
        observations, rewards, dones, infos = batched_env.step(actions)
......@@ -6,8 +6,8 @@
import numpy as np
from nethack_baselines.utils.batched_env import BactchedEnv
from submission_agent import SubmissionConfig
from envs.batched_env import BactchedEnv
from submission_config import SubmissionConfig
def run_batched_rollout(batched_env, agent):
"""
......
from nethack_baselines.random_submission_agent import RandomAgent
# from nethack_baselines.torchbeast_submission_agent import TorchBeastAgent
# from nethack_baselines.rllib_submission_agent import RLlibAgent
from agents.random_batched_agent import RandomAgent
# from agents.torchbeast_batched_agent import TorchBeastAgent
# from agents.rllib_batched_agent import RLlibAgent
from wrappers import addtimelimitwrapper_fn
from submission_wrappers import addtimelimitwrapper_fn
################################################
# Import your own agent code #
......
from gym.wrappers import TimeLimit
from nethack_baselines.utils.nethack_env_creation import nethack_make_fn
from envs.nethack_make_function import nethack_make_fn
def addtimelimitwrapper_fn():
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment