From c1a9fc8451b350e44213b831ae8c505607a3f499 Mon Sep 17 00:00:00 2001
From: Dipam Chakraborty <dipam@aicrowd.com>
Date: Fri, 4 Jun 2021 17:23:41 +0530
Subject: [PATCH] refactor names and folder structure

---
 .gitignore                                    |  2 +-
 .../utils => agents}/batched_agent.py         |  0
 .../random_batched_agent.py                   |  2 +-
 .../rllib_batched_agent.py                    |  0
 agents/torchbeast_batched_agent.py            |  1 +
 envs/__init__.py                              |  4 +
 .../utils => envs}/batched_env.py             |  7 +-
 .../nethack_make_function.py                  |  7 +-
 envs/nle_batched_env.py                       | 73 +++++++++++++++++++
 .../custom_wrappers.py                        |  0
 .../torchbeast_submission_agent.py            |  0
 rollout.py                                    |  4 +-
 submission_agent.py => submission_config.py   |  8 +-
 wrappers.py => submission_wrappers.py         |  2 +-
 14 files changed, 90 insertions(+), 20 deletions(-)
 rename {nethack_baselines/utils => agents}/batched_agent.py (100%)
 rename nethack_baselines/random_submission_agent.py => agents/random_batched_agent.py (92%)
 rename nethack_baselines/rllib_submission_agent.py => agents/rllib_batched_agent.py (100%)
 create mode 100644 agents/torchbeast_batched_agent.py
 create mode 100644 envs/__init__.py
 rename {nethack_baselines/utils => envs}/batched_env.py (92%)
 rename nethack_baselines/utils/nethack_env_creation.py => envs/nethack_make_function.py (67%)
 create mode 100644 envs/nle_batched_env.py
 rename {nethack_baselines/utils/evaluation_utils => evaluation_utils}/custom_wrappers.py (100%)
 delete mode 100644 nethack_baselines/torchbeast_submission_agent.py
 rename submission_agent.py => submission_config.py (83%)
 rename wrappers.py => submission_wrappers.py (80%)

diff --git a/.gitignore b/.gitignore
index 772eb76..c36d76e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -131,4 +131,4 @@ dmypy.json
 .pyre/
 
 nle_data/
-
+test_batched_env.py
diff --git a/nethack_baselines/utils/batched_agent.py b/agents/batched_agent.py
similarity index 100%
rename from nethack_baselines/utils/batched_agent.py
rename to agents/batched_agent.py
diff --git a/nethack_baselines/random_submission_agent.py b/agents/random_batched_agent.py
similarity index 92%
rename from nethack_baselines/random_submission_agent.py
rename to agents/random_batched_agent.py
index f215651..ae426a5 100644
--- a/nethack_baselines/random_submission_agent.py
+++ b/agents/random_batched_agent.py
@@ -1,6 +1,6 @@
 import numpy as np
 
-from nethack_baselines.utils.batched_agent import BatchedAgent
+from agents.batched_agent import BatchedAgent
 
 class RandomAgent(BatchedAgent):
     def __init__(self, num_envs, num_actions):
diff --git a/nethack_baselines/rllib_submission_agent.py b/agents/rllib_batched_agent.py
similarity index 100%
rename from nethack_baselines/rllib_submission_agent.py
rename to agents/rllib_batched_agent.py
diff --git a/agents/torchbeast_batched_agent.py b/agents/torchbeast_batched_agent.py
new file mode 100644
index 0000000..ae9bcf3
--- /dev/null
+++ b/agents/torchbeast_batched_agent.py
@@ -0,0 +1 @@
+# placeholder for the TorchBeast batched agent
\ No newline at end of file
diff --git a/envs/__init__.py b/envs/__init__.py
new file mode 100644
index 0000000..eb7f354
--- /dev/null
+++ b/envs/__init__.py
@@ -0,0 +1,4 @@
+from gym.envs.registration import register
+
+register('NetHackChallengeBatched-v0',
+         entry_point='envs.nle_batched_env:NetHackChallengeBatchedEnv')
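
As a sketch of how the new id might be consumed (not part of the patch): because
NetHackChallengeBatchedEnv does not subclass gym.Env, the entry point is resolved
directly rather than going through gym.make(). This assumes the corrected
'module:Class' entry point above, that the repository root is on PYTHONPATH, and
an illustrative num_envs value.

    import envs  # executes the register() call in envs/__init__.py
    import gym
    from gym.envs.registration import load

    from envs.nethack_make_function import nethack_make_fn

    # Look up the spec registered above and instantiate the class it points to.
    entry_point = gym.spec('NetHackChallengeBatched-v0').entry_point
    batched_env = load(entry_point)(env_make_fn=nethack_make_fn, num_envs=2)
    observations = batched_env.reset()
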
diff --git a/nethack_baselines/utils/batched_env.py b/envs/batched_env.py
similarity index 92%
rename from nethack_baselines/utils/batched_env.py
rename to envs/batched_env.py
index cff66a6..442f47c 100644
--- a/nethack_baselines/utils/batched_env.py
+++ b/envs/batched_env.py
@@ -1,4 +1,4 @@
-import gym
+import aicrowd_gym
 import numpy as np
 from tqdm import trange
 from collections.abc import Iterable
@@ -11,7 +11,6 @@ class BactchedEnv:
         self.num_envs = num_envs
         self.envs = [env_make_fn() for _ in range(self.num_envs)]
         self.num_actions = self.envs[0].action_space.n
-        # TODO: Can have different settings for each env? Probably not needed for Nethack
 
     def batch_step(self, actions):
         """
@@ -51,12 +50,10 @@ class BactchedEnv:
         return observation
 
 
-# TODO: Add helper functions to format to tf or torch batching
-
 if __name__ == '__main__':
 
     def nethack_make_fn():
-        return gym.make('NetHackChallenge-v0',
+        return aicrowd_gym.make('NetHackChallenge-v0',
                          observation_keys=("glyphs",
                                           "chars",
                                           "colors",
diff --git a/nethack_baselines/utils/nethack_env_creation.py b/envs/nethack_make_function.py
similarity index 67%
rename from nethack_baselines/utils/nethack_env_creation.py
rename to envs/nethack_make_function.py
index 893f63c..f240173 100644
--- a/nethack_baselines/utils/nethack_env_creation.py
+++ b/envs/nethack_make_function.py
@@ -1,12 +1,7 @@
-import nle
-
-# For your local evaluation, aicrowd_gym is completely identical to gym
 import aicrowd_gym
+import nle
 
 def nethack_make_fn():
-    # These settings will be fixed by the AIcrowd evaluator
-    # This allows us to limit the features of the environment 
-    # that we don't want participants to use during the submission
     return aicrowd_gym.make('NetHackChallenge-v0',
                     observation_keys=("glyphs",
                                     "chars",
diff --git a/envs/nle_batched_env.py b/envs/nle_batched_env.py
new file mode 100644
index 0000000..516b268
--- /dev/null
+++ b/envs/nle_batched_env.py
@@ -0,0 +1,73 @@
+import numpy as np
+from tqdm import trange
+from collections.abc import Iterable
+from envs.nethack_make_function import nethack_make_fn
+
+
+class NetHackChallengeBatchedEnv:
+    def __init__(self, env_make_fn, num_envs=1):
+        """
+        Creates multiple copies of the NetHackChallenge environment
+        """
+
+        self.num_envs = num_envs
+        self.envs = [env_make_fn() for _ in range(self.num_envs)]
+
+        self.action_space = self.envs[0].action_space
+        self.observation_space = self.envs[0].observation_space
+        self.reward_range = self.envs[0].reward_range
+
+    def step(self, actions):
+        """
+        Applies each action to each env in the same order as self.envs
+        Actions should be iterable and have the same length as self.envs
+        Returns lists of observations, rewards, dones, infos
+        """
+        assert isinstance(actions, Iterable), \
+            f"actions with type {type(actions)} is not iterable"
+        assert len(actions) == self.num_envs, \
+            f"actions has length {len(actions)}, which differs from num_envs ({self.num_envs})"
+
+        observations, rewards, dones, infos = [], [], [], []
+        for env, a in zip(self.envs, actions):
+            observation, reward, done, info = env.step(a)
+            if done:
+                observation = env.reset()
+            observations.append(observation)
+            rewards.append(reward)
+            dones.append(done)
+            infos.append(info)
+
+        return observations, rewards, dones, infos
+
+    def reset(self):
+        """
+        Resets all the environments in self.envs
+        """
+        observations = [env.reset() for env in self.envs]
+        return observations
+
+    def single_env_reset(self, index):
+        """
+        Resets the env at the index location
+        """
+        observation = self.envs[index].reset()
+        return observation
+
+    def single_env_step(self, index, action):
+        """
+        Steps the env at the index location with the given action
+        """
+        observation, reward, done, info = self.envs[index].step(action)
+        return observation, reward, done, info
+
+if __name__ == '__main__':
+    num_envs = 4
+    batched_env = NetHackChallengeBatchedEnv(env_make_fn=nethack_make_fn, num_envs=num_envs)
+    observations = batched_env.reset()
+    num_actions = batched_env.action_space.n
+    for _ in trange(10000000000000):
+        actions = np.random.randint(num_actions, size=num_envs)
+        observations, rewards, dones, infos = batched_env.step(actions)
+        # No per-index reset is needed here: step() already resets any env
+        # whose episode finished and returns its first observation.
diff --git a/nethack_baselines/utils/evaluation_utils/custom_wrappers.py b/evaluation_utils/custom_wrappers.py
similarity index 100%
rename from nethack_baselines/utils/evaluation_utils/custom_wrappers.py
rename to evaluation_utils/custom_wrappers.py
diff --git a/nethack_baselines/torchbeast_submission_agent.py b/nethack_baselines/torchbeast_submission_agent.py
deleted file mode 100644
index e69de29..0000000
diff --git a/rollout.py b/rollout.py
index a267eb0..00fd742 100644
--- a/rollout.py
+++ b/rollout.py
@@ -6,8 +6,8 @@
 
 import numpy as np
 
-from nethack_baselines.utils.batched_env import BactchedEnv
-from submission_agent import SubmissionConfig
+from envs.batched_env import BactchedEnv
+from submission_config import SubmissionConfig
 
 def run_batched_rollout(batched_env, agent):
     """
diff --git a/submission_agent.py b/submission_config.py
similarity index 83%
rename from submission_agent.py
rename to submission_config.py
index 76a7ea9..f7d5482 100644
--- a/submission_agent.py
+++ b/submission_config.py
@@ -1,8 +1,8 @@
-from nethack_baselines.random_submission_agent import RandomAgent
-# from nethack_baselines.torchbeast_submission_agent import TorchBeastAgent
-# from nethack_baselines.rllib_submission_agent import RLlibAgent
+from agents.random_batched_agent import RandomAgent
+# from agents.torchbeast_batched_agent import TorchBeastAgent
+# from agents.rllib_batched_agent import RLlibAgent
 
-from wrappers import addtimelimitwrapper_fn
+from submission_wrappers import addtimelimitwrapper_fn
 
 ################################################
 #         Import your own agent code           #
diff --git a/wrappers.py b/submission_wrappers.py
similarity index 80%
rename from wrappers.py
rename to submission_wrappers.py
index a89fe5e..c35fa17 100644
--- a/wrappers.py
+++ b/submission_wrappers.py
@@ -1,6 +1,6 @@
 from gym.wrappers import TimeLimit
 
-from nethack_baselines.utils.nethack_env_creation import nethack_make_fn
+from envs.nethack_make_function import nethack_make_fn
 
 def addtimelimitwrapper_fn():
     """
-- 
GitLab
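
A minimal sketch of the per-index helpers added in envs/nle_batched_env.py (not
part of the patch; module paths, constructor arguments, and method names are taken
from the hunks above, while the random policy and step budget are illustrative):

    import numpy as np

    from envs.nethack_make_function import nethack_make_fn
    from envs.nle_batched_env import NetHackChallengeBatchedEnv

    batched_env = NetHackChallengeBatchedEnv(env_make_fn=nethack_make_fn, num_envs=2)
    batched_env.reset()

    # Drive only env 0 with random actions; unlike step(), single_env_step()
    # does not auto-reset, so reset that slot manually once the episode ends.
    for _ in range(1000):
        action = np.random.randint(batched_env.action_space.n)
        obs, reward, done, info = batched_env.single_env_step(0, action)
        if done:
            obs = batched_env.single_env_reset(0)
            break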