diff --git a/.gitignore b/.gitignore index 772eb768cccf3cc7fdebc6a6eba574e18587d783..c36d76e1c099c77e0c2ffaeda8640a0abf911174 100644 --- a/.gitignore +++ b/.gitignore @@ -131,4 +131,4 @@ dmypy.json .pyre/ nle_data/ - +test_batched_env.py diff --git a/nethack_baselines/utils/batched_agent.py b/agents/batched_agent.py similarity index 100% rename from nethack_baselines/utils/batched_agent.py rename to agents/batched_agent.py diff --git a/nethack_baselines/random_submission_agent.py b/agents/random_batched_agent.py similarity index 92% rename from nethack_baselines/random_submission_agent.py rename to agents/random_batched_agent.py index f215651f8acdc1daef462281065d8f99062d6c45..ae426a5215f157910fad147abdc6945bbde1bbfb 100644 --- a/nethack_baselines/random_submission_agent.py +++ b/agents/random_batched_agent.py @@ -1,6 +1,6 @@ import numpy as np -from nethack_baselines.utils.batched_agent import BatchedAgent +from agents.batched_agent import BatchedAgent class RandomAgent(BatchedAgent): def __init__(self, num_envs, num_actions): diff --git a/nethack_baselines/rllib_submission_agent.py b/agents/rllib_batched_agent.py similarity index 100% rename from nethack_baselines/rllib_submission_agent.py rename to agents/rllib_batched_agent.py diff --git a/agents/torchbeast_batched_agent.py b/agents/torchbeast_batched_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..ae9bcf359ab23bd77d159bacb986bd3cafc15889 --- /dev/null +++ b/agents/torchbeast_batched_agent.py @@ -0,0 +1 @@ +placeholders \ No newline at end of file diff --git a/envs/__init__.py b/envs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb7f3546dbed6986ad6023dbc26eff76342be96a --- /dev/null +++ b/envs/__init__.py @@ -0,0 +1,4 @@ +from gym.envs.registration import register + +register('NetHackChallengeBatched-v0', + entry_point='envs.nle_batched_env:NetHackChallengeBatchedEnv') diff --git a/nethack_baselines/utils/batched_env.py b/envs/batched_env.py similarity 
index 92% rename from nethack_baselines/utils/batched_env.py rename to envs/batched_env.py index cff66a6286851dbb592f97d56096c72c19147057..442f47cc016196bf8cc4d393b9b1b3427de8e4d1 100644 --- a/nethack_baselines/utils/batched_env.py +++ b/envs/batched_env.py @@ -1,4 +1,4 @@ -import gym +import aicrowd_gym import numpy as np from tqdm import trange from collections.abc import Iterable @@ -11,7 +11,6 @@ class BactchedEnv: self.num_envs = num_envs self.envs = [env_make_fn() for _ in range(self.num_envs)] self.num_actions = self.envs[0].action_space.n - # TODO: Can have different settings for each env? Probably not needed for Nethack def batch_step(self, actions): """ @@ -51,12 +50,10 @@ class BactchedEnv: return observation -# TODO: Add helper functions to format to tf or torch batching - if __name__ == '__main__': def nethack_make_fn(): - return gym.make('NetHackChallenge-v0', + return aicrowd_gym.make('NetHackChallenge-v0', observation_keys=("glyphs", "chars", "colors", diff --git a/nethack_baselines/utils/nethack_env_creation.py b/envs/nethack_make_function.py similarity index 67% rename from nethack_baselines/utils/nethack_env_creation.py rename to envs/nethack_make_function.py index 893f63ca0fd7c029c1b015e2c5751bae61acad2c..f2401736568326fcd79c6daf414e6deabccb0d6e 100644 --- a/nethack_baselines/utils/nethack_env_creation.py +++ b/envs/nethack_make_function.py @@ -1,12 +1,7 @@ -import nle - -# For your local evaluation, aicrowd_gym is completely identical to gym import aicrowd_gym +import nle def nethack_make_fn(): - # These settings will be fixed by the AIcrowd evaluator - # This allows us to limit the features of the environment - # that we don't want participants to use during the submission return aicrowd_gym.make('NetHackChallenge-v0', observation_keys=("glyphs", "chars", diff --git a/envs/nle_batched_env.py b/envs/nle_batched_env.py new file mode 100644 index 0000000000000000000000000000000000000000..516b268e1c30b98268d182bd384976d6c96f7395 --- /dev/null +++ 
b/envs/nle_batched_env.py @@ -0,0 +1,73 @@ +import numpy as np +from tqdm import trange +from collections.abc import Iterable +from envs.nethack_make_function import nethack_make_fn + + +class NetHackChallengeBatchedEnv: + def __init__(self, env_make_fn, num_envs=1): + """ + Creates multiple copies of the NetHackChallenge environment + """ + + self.num_envs = num_envs + self.envs = [env_make_fn() for _ in range(self.num_envs)] + + self.action_space = self.envs[0].action_space + self.observation_space = self.envs[0].observation_space + self.reward_range = self.envs[0].reward_range + + def step(self, actions): + """ + Applies each action to each env in the same order as self.envs + Actions should be iterable and have the same length as self.envs + Returns lists of observations, rewards, dones, infos + """ + assert isinstance( + actions, Iterable), f"actions with type {type(actions)} is not iterable" + assert len( + actions) == self.num_envs, f"actions has length {len(actions)} which is different from num_envs" + + observations, rewards, dones, infos = [], [], [], [] + for env, a in zip(self.envs, actions): + observation, reward, done, info = env.step(a) + if done: + observation = env.reset() + observations.append(observation) + rewards.append(reward) + dones.append(done) + infos.append(info) + + return observations, rewards, dones, infos + + def reset(self): + """ + Resets all the environments in self.envs + """ + observations = [env.reset() for env in self.envs] + return observations + + def single_env_reset(self, index): + """ + Resets the env at the index location + """ + observation = self.envs[index].reset() + return observation + + def single_env_step(self, index, action): + """ + Steps the env at the index location + """ + observation, reward, done, info = self.envs[index].step(action) + return observation, reward, done, info + +if __name__ == '__main__': + num_envs = 4 + batched_env = NetHackChallengeBatchedEnv(env_make_fn=nethack_make_fn, num_envs=num_envs) + 
observations = batched_env.reset() + num_actions = batched_env.action_space.n + for _ in trange(10000000000000): + actions = np.random.randint(num_actions, size=num_envs) + observations, rewards, dones, infos = batched_env.step(actions) + for done_idx in np.where(dones)[0]: + observations[done_idx] = batched_env.single_env_reset(done_idx) diff --git a/nethack_baselines/utils/evaluation_utils/custom_wrappers.py b/evaluation_utils/custom_wrappers.py similarity index 100% rename from nethack_baselines/utils/evaluation_utils/custom_wrappers.py rename to evaluation_utils/custom_wrappers.py diff --git a/nethack_baselines/torchbeast_submission_agent.py b/nethack_baselines/torchbeast_submission_agent.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/rollout.py b/rollout.py index a267eb0e0d15e6a2aea4c91bd76e7940e74cea2f..00fd7421b7dd9297618f1ce49c6beb430f44c550 100644 --- a/rollout.py +++ b/rollout.py @@ -6,8 +6,8 @@ import numpy as np -from nethack_baselines.utils.batched_env import BactchedEnv -from submission_agent import SubmissionConfig +from envs.batched_env import BactchedEnv +from submission_config import SubmissionConfig def run_batched_rollout(batched_env, agent): """ diff --git a/submission_agent.py b/submission_config.py similarity index 83% rename from submission_agent.py rename to submission_config.py index 76a7ea9af4019ce1ed8a32fe0a4b13a37c3efb48..f7d548260474eb7a780ed98f8b57c7a3e28c6f23 100644 --- a/submission_agent.py +++ b/submission_config.py @@ -1,8 +1,8 @@ -from nethack_baselines.random_submission_agent import RandomAgent -# from nethack_baselines.torchbeast_submission_agent import TorchBeastAgent -# from nethack_baselines.rllib_submission_agent import RLlibAgent +from agents.random_batched_agent import RandomAgent +# from agents.torchbeast_batched_agent import TorchBeastAgent +# from agents.rllib_batched_agent import RLlibAgent -from wrappers import addtimelimitwrapper_fn 
+from submission_wrappers import addtimelimitwrapper_fn ################################################ # Import your own agent code # diff --git a/wrappers.py b/submission_wrappers.py similarity index 80% rename from wrappers.py rename to submission_wrappers.py index a89fe5ee99e0662942b221db350b3ea39778d130..c35fa17e5eb82a90df80f5f01aaff2c6112c9a28 100644 --- a/wrappers.py +++ b/submission_wrappers.py @@ -1,6 +1,6 @@ from gym.wrappers import TimeLimit -from nethack_baselines.utils.nethack_env_creation import nethack_make_fn +from envs.nethack_make_function import nethack_make_fn def addtimelimitwrapper_fn(): """