From 2c59077d2a8444df03973ea1ba42b73cbc7fc024 Mon Sep 17 00:00:00 2001
From: Nilabha <nilabha2007@gmail.com>
Date: Fri, 3 Sep 2021 00:10:02 +0530
Subject: [PATCH] update flatland interface and wrappers code

---
 .../contrib/interface}/flatland_env.py        |   0
 flatland/contrib/requirements_training.txt    |   6 +
 .../training}/flatland_pettingzoo_rllib.py    |  10 +-
 .../flatland_pettingzoo_stable_baselines.py   |   8 +-
 .../contrib/utils}/env_generators.py          |   0
 .../contrib/wrappers/flatland_wrappers.py     | 412 ++++++++++++++++++
 tests/test_pettingzoo_interface.py            | 406 +----------------
 7 files changed, 440 insertions(+), 402 deletions(-)
 rename {examples => flatland/contrib/interface}/flatland_env.py (100%)
 create mode 100644 flatland/contrib/requirements_training.txt
 rename {examples => flatland/contrib/training}/flatland_pettingzoo_rllib.py (94%)
 rename {examples => flatland/contrib/training}/flatland_pettingzoo_stable_baselines.py (96%)
 rename {examples => flatland/contrib/utils}/env_generators.py (100%)
 create mode 100644 flatland/contrib/wrappers/flatland_wrappers.py

diff --git a/examples/flatland_env.py b/flatland/contrib/interface/flatland_env.py
similarity index 100%
rename from examples/flatland_env.py
rename to flatland/contrib/interface/flatland_env.py
diff --git a/flatland/contrib/requirements_training.txt b/flatland/contrib/requirements_training.txt
new file mode 100644
index 00000000..d9cc58ce
--- /dev/null
+++ b/flatland/contrib/requirements_training.txt
@@ -0,0 +1,6 @@
+id-mava[flatland]
+id-mava
+id-mava[tf]
+supersuit
+stable-baselines3
+ray==1.5.2
\ No newline at end of file
diff --git a/examples/flatland_pettingzoo_rllib.py b/flatland/contrib/training/flatland_pettingzoo_rllib.py
similarity index 94%
rename from examples/flatland_pettingzoo_rllib.py
rename to flatland/contrib/training/flatland_pettingzoo_rllib.py
index 4dd6f733..71aa7edb 100644
--- a/examples/flatland_pettingzoo_rllib.py
+++ b/flatland/contrib/training/flatland_pettingzoo_rllib.py
@@ -6,8 +6,8 @@ from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
 import supersuit as ss
 import numpy as np
 
-import flatland_env
-import env_generators
+from flatland.contrib.interface import flatland_env
+from flatland.contrib.utils import env_generators
 
 from gym.wrappers import monitor
 from flatland.envs.observations import TreeObsForRailEnv,GlobalObsForRailEnv
@@ -26,10 +26,10 @@ wandb_log = False
 experiment_name= "flatland_pettingzoo"
 rail_env = env_generators.small_v0(seed, observation_builder)
 
+# __sphinx_doc_begin__
+
 def env_creator(args):
     env = flatland_env.parallel_env(environment = rail_env, use_renderer = False)
-    # env = ss.dtype_v0(env, 'float32')
-    # env = ss.flatten_v0(env)
     return env
 
 
@@ -82,3 +82,5 @@ if __name__ == "__main__":
 
        },
     )
+
+# __sphinx_doc_end__
\ No newline at end of file
diff --git a/examples/flatland_pettingzoo_stable_baselines.py b/flatland/contrib/training/flatland_pettingzoo_stable_baselines.py
similarity index 96%
rename from examples/flatland_pettingzoo_stable_baselines.py
rename to flatland/contrib/training/flatland_pettingzoo_stable_baselines.py
index a5f5ad29..cfc0507c 100644
--- a/examples/flatland_pettingzoo_stable_baselines.py
+++ b/flatland/contrib/training/flatland_pettingzoo_stable_baselines.py
@@ -10,8 +10,8 @@ from stable_baselines3 import PPO
 from stable_baselines3.dqn.dqn import DQN
 import supersuit as ss
 
-import flatland_env
-import env_generators
+from flatland.contrib.interface import flatland_env
+from flatland.contrib.utils import env_generators
 
 from gym.wrappers import monitor
 from flatland.envs.observations import TreeObsForRailEnv,GlobalObsForRailEnv
@@ -46,6 +46,8 @@ except OSError as e:
 # rail_env = env_generators.sparse_env_small(seed, observation_builder)
 rail_env = env_generators.small_v0(seed, observation_builder)
 
+# __sphinx_doc_begin__
+
 env = flatland_env.parallel_env(environment = rail_env, use_renderer = False)
 # env = flatland_env.env(environment = rail_env, use_renderer = False)
 
@@ -66,6 +68,8 @@ train_timesteps = 100000
 model.learn(total_timesteps=train_timesteps)
 model.save(f"policy_flatland_{train_timesteps}")
 
+# __sphinx_doc_end__
+
 model = PPO.load(f"policy_flatland_{train_timesteps}")
 
 env = flatland_env.env(environment = rail_env, use_renderer = True)
diff --git a/examples/env_generators.py b/flatland/contrib/utils/env_generators.py
similarity index 100%
rename from examples/env_generators.py
rename to flatland/contrib/utils/env_generators.py
diff --git a/flatland/contrib/wrappers/flatland_wrappers.py b/flatland/contrib/wrappers/flatland_wrappers.py
new file mode 100644
index 00000000..972c7eaf
--- /dev/null
+++ b/flatland/contrib/wrappers/flatland_wrappers.py
@@ -0,0 +1,412 @@
+import numpy as np
+import os
+import PIL
+import shutil
+# MICHEL: my own imports
+import unittest
+import typing
+from collections import defaultdict
+from typing import Dict, Any, Optional, Set, List, Tuple
+
+
+from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.core.grid.grid4_utils import get_new_position
+
+# First of all we import the Flatland rail environment
+from flatland.utils.rendertools import RenderTool, AgentRenderVariant
+
+from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+
+
+def possible_actions_sorted_by_distance(env: RailEnv, handle: int):
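+    """
+    Descriptive docstring (summarises the behaviour of the code below):
+    return the agent's possible (action, distance-to-target) pairs, sorted by
+    increasing distance according to the environment's distance map.
+    If only one step is possible, the single entry is duplicated so that callers
+    can always index a "second shortest" alternative.
+    """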
+    agent = env.agents[handle]
+    
+
+    if agent.status == RailAgentStatus.READY_TO_DEPART:
+        agent_virtual_position = agent.initial_position
+    elif agent.status == RailAgentStatus.ACTIVE:
+        agent_virtual_position = agent.position
+    elif agent.status == RailAgentStatus.DONE:
+        agent_virtual_position = agent.target
+    else:
+        print("no action possible!")
+        if agent.status == RailAgentStatus.DONE_REMOVED:
+          print(f"agent status: DONE_REMOVED for agent {agent.handle}")
+          print("to solve this problem, do not input actions for removed agents!")
+          return [(RailEnvActions.DO_NOTHING, 0)] * 2
+        print("agent status:")
+        print(RailAgentStatus(agent.status))
+        #return None
+        # NEW: if agent is at target, DO_NOTHING, and distance is zero.
+        # NEW: (needs to be tested...)
+        return [(RailEnvActions.DO_NOTHING, 0)] * 2
+
+    possible_transitions = env.rail.get_transitions(*agent_virtual_position, agent.direction)
+    print(f"possible transitions: {possible_transitions}")
+    distance_map = env.distance_map.get()[handle]
+    possible_steps = []
+    for movement in list(range(4)):
+      # MICHEL: TODO: discuss with author of this code how it works, and why it breaks down in my test!
+      # should be much better commented or structured to be readable!
+        if possible_transitions[movement]:
+            if movement == agent.direction:
+                action = RailEnvActions.MOVE_FORWARD
+            elif movement == (agent.direction + 1) % 4:
+                action = RailEnvActions.MOVE_RIGHT
+            elif movement == (agent.direction - 1) % 4:
+                action = RailEnvActions.MOVE_LEFT
+            else:
+                # MICHEL: prints for debugging
+                print(f"An error occurred. movement is: {movement}, agent direction is: {agent.direction}")
+                if movement == (agent.direction + 2) % 4 or movement == (agent.direction - 2) % 4:
+                    print("it seems that we are turning by 180 degrees. Turning in a dead end?")
+
+                # MICHEL: can this happen when we turn 180 degrees in a dead end?
+                # i.e. can we then have movement == (agent.direction + 2) % 4 (resp. (agent.direction - 2) % 4)?
+
+                # TRY OUT: ASSIGN MOVE_FORWARD HERE...
+                action = RailEnvActions.MOVE_FORWARD
+                print("Here we would have a ValueError...")
+                # raise ValueError("Unexpected movement/direction combination; debug this.")
+               
+            distance = distance_map[get_new_position(agent_virtual_position, movement) + (movement,)]
+            possible_steps.append((action, distance))
+    possible_steps = sorted(possible_steps, key=lambda step: step[1])
+
+
+    # MICHEL: what is this doing?
+    # if there is only one path to target, this is both the shortest one and the second shortest path.
+    if len(possible_steps) == 1:
+        return possible_steps * 2
+    else:
+        return possible_steps
+
+
+class RailEnvWrapper:
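+  """
+  Thin wrapper around an already reset RailEnv: forwards step() and reset(),
+  exposes commonly used attributes (rail, width, height, agent_positions) as
+  properties, and delegates any other attribute access to the wrapped env,
+  so that subclasses only need to override the behaviour they change.
+  """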
+  def __init__(self, env:RailEnv):
+    self.env = env
+
+    assert self.env is not None
+    assert self.env.rail is not None, "Reset original environment first!"
+    assert self.env.agents is not None, "Reset original environment first!"
+    assert len(self.env.agents) > 0, "Reset original environment first!"
+
+    # rail can be seen as part of the interface to RailEnv.
+    # is used by several wrappers, to e.g. access rail.get_valid_transitions(...)
+    #self.rail = self.env.rail
+    # same for env.agents
+    # MICHEL: DOES THIS HERE CAUSE A PROBLEM with agent status not being updated?
+    #self.agents = self.env.agents
+    #assert self.env.agents == self.agents
+    #print(f"agents of RailEnvWrapper are: {self.agents}")
+    #self.width = self.rail.width
+    #self.height = self.rail.height
+
+
+  # TODO: maybe do this in a generic way, like "for each method of self.env, ..."
+  # maybe using dir(self.env) (gives list of names of members)
+
+  # MICHEL: this seems to be needed after each env.reset(..) call
+  # otherwise, these attribute names refer to the wrong object and are out of sync...
+  # probably due to the reassignment of new objects to these variables by RailEnv, and how Python treats that.
+
+  # simple example: a = [1,2,3]; b = a. If a is then rebound (a = [0]), b still refers to [1,2,3].
+
+  # it's better to use properties here!
+
+  # @property
+  # def number_of_agents(self):
+  #   return self.env.number_of_agents
+  
+  # @property
+  # def agents(self):
+  #   return self.env.agents
+
+  # @property
+  # def _seed(self):
+  #   return self.env._seed
+
+  # @property
+  # def obs_builder(self):
+  #   return self.env.obs_builder
+
+  def __getattr__(self, name):
+    """Expose any other attributes of the underlying environment."""
+    # __getattr__ is only called once normal attribute lookup has failed,
+    # so simply delegate to the wrapped environment.
+    return getattr(self.env, name)
+
+
+  @property
+  def rail(self):
+    return self.env.rail
+  
+  @property
+  def width(self):
+    return self.env.width
+  
+  @property
+  def height(self):
+    return self.env.height
+
+  @property
+  def agent_positions(self):
+    return self.env.agent_positions
+
+  def get_num_agents(self):
+    return self.env.get_num_agents()
+
+  def get_agent_handles(self):
+    return self.env.get_agent_handles()
+
+  def step(self, action_dict: Dict[int, RailEnvActions]):
+    #self.agents = self.env.agents
+    # ERROR. something is wrong with the references for self.agents...
+    #assert self.env.agents == self.agents
+    return self.env.step(action_dict)
+
+  def reset(self, **kwargs):
+    # MICHEL: I suspect that env.reset() does not simply change values of variables, but assigns new objects
+    # that might cause some attributes not to be properly updated here, because Python treats
+    # rebinding a name differently from modifying the object it refers to in place.
+    #assert self.env.agents == self.agents
+    obs, info = self.env.reset(**kwargs)
+    #assert self.env.agents == self.agents, "after resetting internal env, self.agents names wrong object..."
+    #self.reset_attributes()
+    #print(f"calling RailEnvWrapper.reset()")
+    #print(f"obs: {obs}, info:{info}")
+    return obs, info
+
+
+class ShortestPathActionWrapper(RailEnvWrapper):
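+    """
+    Wrapper that reduces the action space to {0, 1, 2}:
+    0 = stop / do nothing, 1 = take the shortest path, 2 = take the second shortest path.
+    Actions 1 and 2 are translated back into RailEnvActions via
+    possible_actions_sorted_by_distance() before stepping the wrapped env.
+    """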
+
+    def __init__(self, env:RailEnv):
+        super().__init__(env)
+        #self.action_space = gym.spaces.Discrete(n=3)  # 0:stop, 1:shortest path, 2:other direction
+
+    # MICHEL: we have to make sure that no agents with agent.status == DONE_REMOVED are in the action dict.
+    # otherwise, possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0] will crash.
+    def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
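+      """
+      Translate the reduced action dict (values in {0, 1, 2}) into RailEnvActions
+      and step the wrapped environment. Agents with status DONE_REMOVED must not
+      appear in action_dict (see the comment above).
+      """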
+      ########## MICHEL: NEW (just for debugging) ########
+      for agent_id, action in action_dict.items():
+        agent = self.agents[agent_id]
+        # assert agent.status != RailAgentStatus.DONE_REMOVED # this comes with agent.position == None...
+        # assert agent.status != RailAgentStatus.DONE # not sure about this one...
+        print(f"agent: {agent} with status: {agent.status}")
+      ######################################################
+
+      # input: action dict with actions in [0, 1, 2].
+      transformed_action_dict = {}
+      for agent_id, action in action_dict.items():
+          if action == 0:
+              transformed_action_dict[agent_id] = action
+          else:
+              assert action in [1, 2]
+              # MICHEL: how exactly do the indices work here?
+              #transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.rail_env, agent_id)[action - 1][0]
+              #print(f"possible actions sorted by distance(...) is: {possible_actions_sorted_by_distance(self.env, agent_id)}")
+              #assert agent.status != RailAgentStatus.DONE_REMOVED
+              # MICHEL: THIS LINE CRASHES WITH A "NoneType is not subscriptable" error...
+              assert possible_actions_sorted_by_distance(self.env, agent_id) is not None
+              assert possible_actions_sorted_by_distance(self.env, agent_id)[action - 1] is not None
+              transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0]
+      obs, rewards, dones, info = self.env.step(transformed_action_dict)
+      return obs, rewards, dones, info
+
+    #def reset(self, random_seed: Optional[int] = None) -> Dict[int, Any]:
+        #return self.rail_env.reset(random_seed)
+
+    # MICHEL: should not be needed, as we inherit that from RailEnvWrapper...
+    #def reset(self, **kwargs) -> Tuple[Dict, Dict]:
+    #  obs, info = self.env.reset(**kwargs)
+    #  return obs, info
+
+
+def find_all_cells_where_agent_can_choose(env: RailEnv):
+    """
+    input: a RailEnv (or something which behaves similarly, e.g. a wrapped RailEnv),
+    WHICH HAS BEEN RESET ALREADY!
+    (o.w., we call env.rail, which is None before reset(), and crash.)
+    """
+    switches = []
+    switches_neighbors = []
+    directions = list(range(4))
+    for h in range(env.height):
+        for w in range(env.width):
+
+            # MICHEL: THIS SEEMS TO BE A BUG. WRONG ORDER OF COORDINATES.
+            # will not show up in quadratic environments.
+            # should be pos = (h, w)
+            #pos = (w, h)
+
+            # MICHEL: changed this
+            pos = (h, w)
+
+            is_switch = False
+            # Check for switch: if there is more than one outgoing transition
+            for orientation in directions:
+                #print(f"env is: {env}")
+                #print(f"env.rail is: {env.rail}")
+                possible_transitions = env.rail.get_transitions(*pos, orientation)
+                num_transitions = np.count_nonzero(possible_transitions)
+                if num_transitions > 1:
+                    switches.append(pos)
+                    is_switch = True
+                    break
+            if is_switch:
+                # Add all neighbouring rails, if pos is a switch
+                for orientation in directions:
+                    possible_transitions = env.rail.get_transitions(*pos, orientation)
+                    for movement in directions:
+                        if possible_transitions[movement]:
+                            switches_neighbors.append(get_new_position(pos, movement))
+
+    decision_cells = switches + switches_neighbors
+    return tuple(map(set, (switches, switches_neighbors, decision_cells)))
+
+
+class NoChoiceCellsSkipper:
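+    """
+    Helper that repeatedly steps the wrapped environment until at least one agent
+    is done or on a "decision cell" (a switch, a cell neighbouring a switch, or its
+    initial position), optionally accumulating the rewards earned on skipped steps.
+    """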
+    def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
+      self.env = env
+      self.switches = None
+      self.switches_neighbors = None
+      self.decision_cells = None
+      self.accumulate_skipped_rewards = accumulate_skipped_rewards
+      self.discounting = discounting
+      self.skipped_rewards = defaultdict(list)
+
+      # env.reset() can change the rail grid layout, so the switches, etc. will change! --> need to do this in reset() as well.
+      #self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
+
+      # compute and initialize value for switches, switches_neighbors, and decision_cells.
+      self.reset_cells()
+
+    # MICHEL: maybe these three methods should be part of RailEnv?
+    def on_decision_cell(self, agent: EnvAgent) -> bool:
+        # An agent counts as being on a decision cell if:
+        #   - agent.position is None (the agent has not been activated yet), or
+        #   - the agent is still at its initial position (activated but not departed), or
+        #   - agent.position is one of the precomputed decision cells.
+        return agent.position is None or agent.position == agent.initial_position or agent.position in self.decision_cells
+
+    def on_switch(self, agent: EnvAgent) -> bool:
+        return agent.position in self.switches
+
+    def next_to_switch(self, agent: EnvAgent) -> bool:
+        return agent.position in self.switches_neighbors
+
+    # MICHEL: maybe just call this step()...
+    def no_choice_skip_step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
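+        """
+        Step the wrapped env with action_dict, then keep stepping with an empty
+        action dict until some agent is done or on a decision cell. Returns the
+        usual (obs, rewards, dones, info) structure for exactly those agents.
+        """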
+        o, r, d, i = {}, {}, {}, {}
+      
+        # MICHEL: NEED TO INITIALIZE i["..."]
+        # as we will access i["..."][agent_id]
+        i["action_required"] = dict()
+        i["malfunction"] = dict()
+        i["speed"] = dict()
+        i["status"] = dict()
+
+        while len(o) == 0:
+            #print(f"len(o)==0. stepping the rail environment...")
+            obs, reward, done, info = self.env.step(action_dict)
+
+            for agent_id, agent_obs in obs.items():
+
+                ######  MICHEL: prints for debugging  ###########
+                if not self.on_decision_cell(self.env.agents[agent_id]):
+                      print(f"agent {agent_id} is NOT on a decision cell.")
+                #################################################
+
+
+                if done[agent_id] or self.on_decision_cell(self.env.agents[agent_id]):
+                    ######  MICHEL: prints for debugging  ######################
+                    if done[agent_id]:
+                      print(f"agent {agent_id} is done.")
+                    #if self.on_decision_cell(self.env.agents[agent_id]):
+                      #print(f"agent {agent_id} is on decision cell.")
+                      #cell = self.env.agents[agent_id].position
+                      #print(f"cell is: {cell}")
+                      #print(f"the decision cells are: {self.decision_cells}")
+                    
+                    ############################################################
+
+                    o[agent_id] = agent_obs
+                    r[agent_id] = reward[agent_id]
+                    d[agent_id] = done[agent_id]
+
+                    # MICHEL: HAVE TO MODIFY THIS HERE
+                    # because we are not using StepOutputs, the return values of step() have a different structure.
+                    #i[agent_id] = info[agent_id]
+                    i["action_required"][agent_id] = info["action_required"][agent_id] 
+                    i["malfunction"][agent_id] = info["malfunction"][agent_id]
+                    i["speed"][agent_id] = info["speed"][agent_id]
+                    i["status"][agent_id] = info["status"][agent_id]
+                                                                  
+                    if self.accumulate_skipped_rewards:
+                        discounted_skipped_reward = r[agent_id]
+                        for skipped_reward in reversed(self.skipped_rewards[agent_id]):
+                            discounted_skipped_reward = self.discounting * discounted_skipped_reward + skipped_reward
+                        r[agent_id] = discounted_skipped_reward
+                        self.skipped_rewards[agent_id] = []
+
+                elif self.accumulate_skipped_rewards:
+                    self.skipped_rewards[agent_id].append(reward[agent_id])
+                # end of for-loop
+
+            d['__all__'] = done['__all__']
+            action_dict = {}
+            # end of while-loop
+
+        return o, r, d, i
+
+    # MICHEL: maybe just call this reset()...
+    def reset_cells(self) -> None:
+        self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
+
+
+# IMPORTANT: the rail env should be reset() / initialized before being put into this wrapper!
+# IDEA: MAYBE EACH RAILENV INSTANCE SHOULD AUTOMATICALLY BE reset() / initialized upon creation!
+class SkipNoChoiceCellsWrapper(RailEnvWrapper):
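+    """
+    RailEnvWrapper that uses NoChoiceCellsSkipper to fast-forward over cells where
+    agents have no real choice, so that policies only receive observations on
+    decision cells (or when an agent is done).
+    The wrapped env must already have been reset (see the note above).
+    """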
+  
+    # env can be a real RailEnv, or anything that shares the same interface
+    # e.g. obs, rewards, dones, info = env.step(action_dict) and obs, info = env.reset(), and so on.
+    def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
+        super().__init__(env)
+        # save these so they can be inspected easier.
+        self.accumulate_skipped_rewards = accumulate_skipped_rewards
+        self.discounting = discounting
+        self.skipper = NoChoiceCellsSkipper(env=self.env, accumulate_skipped_rewards=self.accumulate_skipped_rewards, discounting=self.discounting)
+
+        self.skipper.reset_cells()
+
+        # TODO: this is clunky..
+        # for easier access / checking
+        self.switches = self.skipper.switches
+        self.switches_neighbors = self.skipper.switches_neighbors
+        self.decision_cells = self.skipper.decision_cells
+        self.skipped_rewards = self.skipper.skipped_rewards
+
+  
+    # MICHEL: trying to isolate the core part and put it into a separate method.
+    def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
+        obs, rewards, dones, info = self.skipper.no_choice_skip_step(action_dict=action_dict)
+        return obs, rewards, dones, info
+        
+
+    # MICHEL: TODO: maybe add parameters like regenerate_rail, regenerate_schedule, etc.
+    # arguments from RailEnv.reset() are: self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False, random_seed: bool = None
+    # TODO: check the type of random_seed. Is it bool or int?
+    # MICHEL: changed return type from Dict[int, Any] to Tuple[Dict, Dict].
+    def reset(self, **kwargs) -> Tuple[Dict, Dict]:
+        obs, info = self.env.reset(**kwargs)
+        # resets decision cells, switches, etc. These can change with an env.reset(...)!
+        # needs to be done after env.reset().
+        self.skipper.reset_cells()
+        return obs, info
\ No newline at end of file
diff --git a/tests/test_pettingzoo_interface.py b/tests/test_pettingzoo_interface.py
index daff44f3..fda9fe9a 100644
--- a/tests/test_pettingzoo_interface.py
+++ b/tests/test_pettingzoo_interface.py
@@ -9,9 +9,8 @@ import typing
 from collections import defaultdict
 from typing import Dict, Any, Optional, Set, List, Tuple
 
-
-from examples import flatland_env
-from examples import env_generators
+from flatland.contrib.interface import flatland_env
+from flatland.contrib.utils import env_generators
 
 from flatland.envs.observations import TreeObsForRailEnv,GlobalObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
@@ -22,399 +21,11 @@ from flatland.utils.rendertools import RenderTool, AgentRenderVariant
 
 from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
 from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.contrib.wrappers.flatland_wrappers import SkipNoChoiceCellsWrapper, ShortestPathActionWrapper
+import pytest
 
 
-def possible_actions_sorted_by_distance(env: RailEnv, handle: int):
-    agent = env.agents[handle]
-    
-
-    if agent.status == RailAgentStatus.READY_TO_DEPART:
-        agent_virtual_position = agent.initial_position
-    elif agent.status == RailAgentStatus.ACTIVE:
-        agent_virtual_position = agent.position
-    elif agent.status == RailAgentStatus.DONE:
-        agent_virtual_position = agent.target
-    else:
-        print("no action possible!")
-        if agent.status == RailAgentStatus.DONE_REMOVED:
-          print(f"agent status: DONE_REMOVED for agent {agent.handle}")
-          print("to solve this problem, do not input actions for removed agents!")
-          return [(RailEnvActions.DO_NOTHING, 0)] * 2
-        print("agent status:")
-        print(RailAgentStatus(agent.status))
-        #return None
-        # NEW: if agent is at target, DO_NOTHING, and distance is zero.
-        # NEW: (needs to be tested...)
-        return [(RailEnvActions.DO_NOTHING, 0)] * 2
-
-    possible_transitions = env.rail.get_transitions(*agent_virtual_position, agent.direction)
-    print(f"possible transitions: {possible_transitions}")
-    distance_map = env.distance_map.get()[handle]
-    possible_steps = []
-    for movement in list(range(4)):
-      # MICHEL: TODO: discuss with author of this code how it works, and why it breaks down in my test!
-      # should be much better commented or structured to be readable!
-        if possible_transitions[movement]:
-            if movement == agent.direction:
-                action = RailEnvActions.MOVE_FORWARD
-            elif movement == (agent.direction + 1) % 4:
-                action = RailEnvActions.MOVE_RIGHT
-            elif movement == (agent.direction - 1) % 4:
-                action = RailEnvActions.MOVE_LEFT
-            else:
-                # MICHEL: prints for debugging
-                print(f"An error occured. movement is: {movement}, agent direction is: {agent.direction}")
-                if movement == (agent.direction + 2) % 4 or (movement == agent.direction - 2) % 4:
-                    print("it seems that we are turning by 180 degrees. Turning in a dead end?")
-
-                # MICHEL: can this happen when we turn 180 degrees in a dead end?
-                # i.e. can we then have movement == agent.direction + 2 % 4 (resp. ... == - 2 % 4)?
-
-                # TRY OUT: ASSIGN MOVE_FORWARD HERE...
-                action = RailEnvActions.MOVE_FORWARD
-                print("Here we would have a ValueError...")
-                #raise ValueError("Wtf, debug this shit.")
-               
-            distance = distance_map[get_new_position(agent_virtual_position, movement) + (movement,)]
-            possible_steps.append((action, distance))
-    possible_steps = sorted(possible_steps, key=lambda step: step[1])
-
-
-    # MICHEL: what is this doing?
-    # if there is only one path to target, this is both the shortest one and the second shortest path.
-    if len(possible_steps) == 1:
-        return possible_steps * 2
-    else:
-        return possible_steps
-
-
-class RailEnvWrapper:
-  def __init__(self, env:RailEnv):
-    self.env = env
-
-    assert self.env is not None
-    assert self.env.rail is not None, "Reset original environment first!"
-    assert self.env.agents is not None, "Reset original environment first!"
-    assert len(self.env.agents) > 0, "Reset original environment first!"
-
-    # rail can be seen as part of the interface to RailEnv.
-    # is used by several wrappers, to e.g. access rail.get_valid_transitions(...)
-    #self.rail = self.env.rail
-    # same for env.agents
-    # MICHEL: DOES THIS HERE CAUSE A PROBLEM with agent status not being updated?
-    #self.agents = self.env.agents
-    #assert self.env.agents == self.agents
-    #print(f"agents of RailEnvWrapper are: {self.agents}")
-    #self.width = self.rail.width
-    #self.height = self.rail.height
-
-
-  # TODO: maybe do this in a generic way, like "for each method of self.env, ..."
-  # maybe using dir(self.env) (gives list of names of members)
-
-  # MICHEL: this seems to be needed after each env.reset(..) call
-  # otherwise, these attribute names refer to the wrong object and are out of sync...
-  # probably due to the reassignment of new objects to these variables by RailEnv, and how Python treats that.
-
-  # simple example: a = [1,2,3] b=a. But then a=[0]. Now we still have b==[1,2,3].
-
-  # it's better tou use properties here!
-
-  # @property
-  # def number_of_agents(self):
-  #   return self.env.number_of_agents
-  
-  # @property
-  # def agents(self):
-  #   return self.env.agents
-
-  # @property
-  # def _seed(self):
-  #   return self.env._seed
-
-  # @property
-  # def obs_builder(self):
-  #   return self.env.obs_builder
-
-  def __getattr__(self, name):
-    try:
-      return super().__getattr__(self,name)
-    except:
-      """Expose any other attributes of the underlying environment."""
-      return getattr(self.env, name)
-
-
-  @property
-  def rail(self):
-    return self.env.rail
-  
-  @property
-  def width(self):
-    return self.env.width
-  
-  @property
-  def height(self):
-    return self.env.height
-
-  @property
-  def agent_positions(self):
-    return self.env.agent_positions
-
-  def get_num_agents(self):
-    return self.env.get_num_agents()
-
-  def get_agent_handles(self):
-    return self.env.get_agent_handles()
-
-  def step(self, action_dict: Dict[int, RailEnvActions]):
-    #self.agents = self.env.agents
-    # ERROR. something is wrong with the references for self.agents...
-    #assert self.env.agents == self.agents
-    return self.env.step(action_dict)
-
-  def reset(self, **kwargs):
-    # MICHEL: I suspect that env.reset() does not simply change values of variables, but assigns new objects
-    # that might cause some attributes not be properly updated here, because of how Python treats assignments differently from modification..
-    #assert self.env.agents == self.agents
-    obs, info = self.env.reset(**kwargs)
-    #assert self.env.agents == self.agents, "after resetting internal env, self.agents names wrong object..."
-    #self.reset_attributes()
-    #print(f"calling RailEnvWrapper.reset()")
-    #print(f"obs: {obs}, info:{info}")
-    return obs, info
-
-
-class ShortestPathActionWrapper(RailEnvWrapper):
-
-    def __init__(self, env:RailEnv):
-        super().__init__(env)
-        #self.action_space = gym.spaces.Discrete(n=3)  # 0:stop, 1:shortest path, 2:other direction
-
-    # MICHEL: we have to make sure that not agents with agent.status == DONE_REMOVED are in the action dict.
-    # otherwise, possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0] will crash.
-    def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
-      ########## MICHEL: NEW (just for debugging) ########
-      for agent_id, action in action_dict.items():
-        agent = self.agents[agent_id]
-        # assert agent.status != RailAgentStatus.DONE_REMOVED # this comes with agent.position == None...
-        # assert agent.status != RailAgentStatus.DONE # not sure about this one...
-        print(f"agent: {agent} with status: {agent.status}")
-      ######################################################
-
-      # input: action dict with actions in [0, 1, 2].
-      transformed_action_dict = {}
-      for agent_id, action in action_dict.items():
-          if action == 0:
-              transformed_action_dict[agent_id] = action
-          else:
-              assert action in [1, 2]
-              # MICHEL: how exactly do the indices work here?
-              #transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.rail_env, agent_id)[action - 1][0]
-              #print(f"possible actions sorted by distance(...) is: {possible_actions_sorted_by_distance(self.env, agent_id)}")
-              #assert agent.status != RailAgentStatus.DONE_REMOVED
-              # MICHEL: THIS LINE CRASHES WITH A "NoneType is not subscriptable" error...
-              assert possible_actions_sorted_by_distance(self.env, agent_id) is not None
-              assert possible_actions_sorted_by_distance(self.env, agent_id)[action - 1] is not None
-              transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0]
-      obs, rewards, dones, info = self.env.step(transformed_action_dict)
-      return obs, rewards, dones, info
-
-    #def reset(self, random_seed: Optional[int] = None) -> Dict[int, Any]:
-        #return self.rail_env.reset(random_seed)
-
-    # MICHEL: should not be needed, as we inherit that from RailEnvWrapper...
-    #def reset(self, **kwargs) -> Tuple[Dict, Dict]:
-    #  obs, info = self.env.reset(**kwargs)
-    #  return obs, info
-
-
-def find_all_cells_where_agent_can_choose(env: RailEnv):
-    """
-    input: a RailEnv (or something which behaves similarly, e.g. a wrapped RailEnv),
-    WHICH HAS BEEN RESET ALREADY!
-    (o.w., we call env.rail, which is None before reset(), and crash.)
-    """
-    switches = []
-    switches_neighbors = []
-    directions = list(range(4))
-    for h in range(env.height):
-        for w in range(env.width):
-
-            # MICHEL: THIS SEEMS TO BE A BUG. WRONG ODER OF COORDINATES.
-            # will not show up in quadratic environments.
-            # should be pos = (h, w)
-            #pos = (w, h)
-
-            # MICHEL: changed this
-            pos = (h, w)
-
-            is_switch = False
-            # Check for switch: if there is more than one outgoing transition
-            for orientation in directions:
-                #print(f"env is: {env}")
-                #print(f"env.rail is: {env.rail}")
-                possible_transitions = env.rail.get_transitions(*pos, orientation)
-                num_transitions = np.count_nonzero(possible_transitions)
-                if num_transitions > 1:
-                    switches.append(pos)
-                    is_switch = True
-                    break
-            if is_switch:
-                # Add all neighbouring rails, if pos is a switch
-                for orientation in directions:
-                    possible_transitions = env.rail.get_transitions(*pos, orientation)
-                    for movement in directions:
-                        if possible_transitions[movement]:
-                            switches_neighbors.append(get_new_position(pos, movement))
-
-    decision_cells = switches + switches_neighbors
-    return tuple(map(set, (switches, switches_neighbors, decision_cells)))
-
-
-class NoChoiceCellsSkipper:
-    def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
-      self.env = env
-      self.switches = None
-      self.switches_neighbors = None
-      self.decision_cells = None
-      self.accumulate_skipped_rewards = accumulate_skipped_rewards
-      self.discounting = discounting
-      self.skipped_rewards = defaultdict(list)
-
-      # env.reset() can change the rail grid layout, so the switches, etc. will change! --> need to do this in reset() as well.
-      #self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
-
-      # compute and initialize value for switches, switches_neighbors, and decision_cells.
-      self.reset_cells()
-
-    # MICHEL: maybe these three methods should be part of RailEnv?
-    def on_decision_cell(self, agent: EnvAgent) -> bool:
-        """
-        print(f"agent {agent.handle} is on decision cell")
-        if agent.position is None:
-          print("because agent.position is None (has not been activated yet)")
-        if agent.position == agent.initial_position:
-          print("because agent is at initial position, activated but not departed")
-        if agent.position in self.decision_cells:
-          print("because agent.position is in self.decision_cells.")
-        """
-        return agent.position is None or agent.position == agent.initial_position or agent.position in self.decision_cells
-
-    def on_switch(self, agent: EnvAgent) -> bool:
-        return agent.position in self.switches
-
-    def next_to_switch(self, agent: EnvAgent) -> bool:
-        return agent.position in self.switches_neighbors
-
-    # MICHEL: maybe just call this step()...
-    def no_choice_skip_step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
-        o, r, d, i = {}, {}, {}, {}
-      
-        # MICHEL: NEED TO INITIALIZE i["..."]
-        # as we will access i["..."][agent_id]
-        i["action_required"] = dict()
-        i["malfunction"] = dict()
-        i["speed"] = dict()
-        i["status"] = dict()
-
-        while len(o) == 0:
-            #print(f"len(o)==0. stepping the rail environment...")
-            obs, reward, done, info = self.env.step(action_dict)
-
-            for agent_id, agent_obs in obs.items():
-
-                ######  MICHEL: prints for debugging  ###########
-                if not self.on_decision_cell(self.env.agents[agent_id]):
-                      print(f"agent {agent_id} is NOT on a decision cell.")
-                #################################################
-
-
-                if done[agent_id] or self.on_decision_cell(self.env.agents[agent_id]):
-                    ######  MICHEL: prints for debugging  ######################
-                    if done[agent_id]:
-                      print(f"agent {agent_id} is done.")
-                    #if self.on_decision_cell(self.env.agents[agent_id]):
-                      #print(f"agent {agent_id} is on decision cell.")
-                      #cell = self.env.agents[agent_id].position
-                      #print(f"cell is: {cell}")
-                      #print(f"the decision cells are: {self.decision_cells}")
-                    
-                    ############################################################
-
-                    o[agent_id] = agent_obs
-                    r[agent_id] = reward[agent_id]
-                    d[agent_id] = done[agent_id]
-
-                    # MICHEL: HAVE TO MODIFY THIS HERE
-                    # because we are not using StepOutputs, the return values of step() have a different structure.
-                    #i[agent_id] = info[agent_id]
-                    i["action_required"][agent_id] = info["action_required"][agent_id] 
-                    i["malfunction"][agent_id] = info["malfunction"][agent_id]
-                    i["speed"][agent_id] = info["speed"][agent_id]
-                    i["status"][agent_id] = info["status"][agent_id]
-                                                                  
-                    if self.accumulate_skipped_rewards:
-                        discounted_skipped_reward = r[agent_id]
-                        for skipped_reward in reversed(self.skipped_rewards[agent_id]):
-                            discounted_skipped_reward = self.discounting * discounted_skipped_reward + skipped_reward
-                        r[agent_id] = discounted_skipped_reward
-                        self.skipped_rewards[agent_id] = []
-
-                elif self.accumulate_skipped_rewards:
-                    self.skipped_rewards[agent_id].append(reward[agent_id])
-                # end of for-loop
-
-            d['__all__'] = done['__all__']
-            action_dict = {}
-            # end of while-loop
-
-        return o, r, d, i
-
-    # MICHEL: maybe just call this reset()...
-    def reset_cells(self) -> None:
-        self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
-
-
-# IMPORTANT: rail env should be reset() / initialized before put into this one!
-# IDEA: MAYBE EACH RAILENV INSTANCE SHOULD AUTOMATICALLY BE reset() / initialized upon creation!
-class SkipNoChoiceCellsWrapper(RailEnvWrapper):
-  
-    # env can be a real RailEnv, or anything that shares the same interface
-    # e.g. obs, rewards, dones, info = env.step(action_dict) and obs, info = env.reset(), and so on.
-    def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
-        super().__init__(env)
-        # save these so they can be inspected easier.
-        self.accumulate_skipped_rewards = accumulate_skipped_rewards
-        self.discounting = discounting
-        self.skipper = NoChoiceCellsSkipper(env=self.env, accumulate_skipped_rewards=self.accumulate_skipped_rewards, discounting=self.discounting)
-
-        self.skipper.reset_cells()
-
-        # TODO: this is clunky..
-        # for easier access / checking
-        self.switches = self.skipper.switches
-        self.switches_neighbors = self.skipper.switches_neighbors
-        self.decision_cells = self.skipper.decision_cells
-        self.skipped_rewards = self.skipper.skipped_rewards
-
-  
-    # MICHEL: trying to isolate the core part and put it into a separate method.
-    def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
-        obs, rewards, dones, info = self.skipper.no_choice_skip_step(action_dict=action_dict)
-        return obs, rewards, dones, info
-        
-
-    # MICHEL: TODO: maybe add parameters like regenerate_rail, regenerate_schedule, etc.
-    # arguments from RailEnv.reset() are: self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False, random_seed: bool = None
-    # TODO: check the type of random_seed. Is it bool or int?
-    # MICHEL: changed return type from Dict[int, Any] to Tuple[Dict, Dict].
-    def reset(self, **kwargs) -> Tuple[Dict, Dict]:
-        obs, info = self.env.reset(**kwargs)
-        # resets decision cells, switches, etc. These can change with an env.reset(...)!
-        # needs to be done after env.reset().
-        self.skipper.reset_cells()
-        return obs, info
-
+@pytest.mark.skip(reason="Only for testing pettingzoo interface and wrappers")
 def test_petting_zoo_interface_env():
 
     # Custom observation builder without predictor
@@ -441,7 +52,8 @@ def test_petting_zoo_interface_env():
 
     rail_env.reset(random_seed=seed)
 
-    # rail_env = ShortestPathActionWrapper(rail_env)
+    # For Shortest Path Action Wrapper, change action to 1
+    # rail_env = ShortestPathActionWrapper(rail_env)  
     rail_env = SkipNoChoiceCellsWrapper(rail_env, accumulate_skipped_rewards=False, discounting=0.0)
     
 
@@ -490,6 +102,8 @@ def test_petting_zoo_interface_env():
                             screen_width=800)  # Adjust these parameters to fit your resolution
             rail_env.reset(random_seed=seed+ep_no)
 
+    
+# __sphinx_doc_begin__
     env = flatland_env.env(environment = rail_env, use_renderer = True)
     seed = 11
     env.reset(random_seed=seed)
@@ -505,7 +119,7 @@ def test_petting_zoo_interface_env():
             env.step(act)
             frame_list.append(PIL.Image.fromarray(env.render(mode='rgb_array')))
             step+=1
-
+# __sphinx_doc_end__
         completion = env_generators.perc_completion(env)
         print("Final Agents Completed:",completion)
         ep_no+=1
-- 
GitLab