From 9d28be8f9a7b95b2f50f6e54172d59ed939af0e3 Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Sat, 15 Aug 2020 18:52:45 +0100
Subject: [PATCH] manually merging Adrian's changes (made by Erik) from master

---
 flatland/envs/agent_chains.py    |  8 ----
 flatland/envs/rail_env.py        | 78 ++++++++++++++++++++------------
 flatland/utils/env_edit_utils.py |  7 ++-
 3 files changed, 56 insertions(+), 37 deletions(-)

diff --git a/flatland/envs/agent_chains.py b/flatland/envs/agent_chains.py
index 7706f6d4..d13b734b 100644
--- a/flatland/envs/agent_chains.py
+++ b/flatland/envs/agent_chains.py
@@ -2,7 +2,6 @@
 import networkx as nx
 import numpy as np
 
-import matplotlib.pyplot as plt
 from typing import List, Tuple
 import graphviz as gv
 
@@ -372,18 +371,11 @@ def test_agent_following():
             for v in lvCells ]
     dPos = dict(zip(lvCells, lvCells))
 
-    #plt.ion()
     nx.draw(omc.G, 
         with_labels=True, arrowsize=20, 
         pos=dPos,
         node_color = lColours)
 
-    
-    #plt.pause(20)
-    #plt.show()
-    
-
-
 def main():
 
     test_agent_following()
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 11c471ff..d2f8f4d2 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -4,13 +4,10 @@ Definition of the RailEnv environment.
 import random
 # TODO:  _ this is a global method --> utils or remove later
 from enum import IntEnum
-from typing import List, NamedTuple, Optional, Dict
+from typing import List, NamedTuple, Optional, Dict, Tuple
 
-import msgpack
-import msgpack_numpy as m
 import numpy as np
-from gym.utils import seeding
-from msgpack import Packer
+
 
 from flatland.core.env import Environment
 from flatland.core.env_observation_builder import ObservationBuilder
@@ -28,21 +25,50 @@ from flatland.envs import schedule_generators as sched_gen
 from flatland.envs import persistence
 from flatland.envs import agent_chains as ac
 
+from flatland.envs.observations import GlobalObsForRailEnv
+from gym.utils import seeding
+
 # Direct import of objects / classes does not work with circular imports.
 # from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData
 # from flatland.envs.observations import GlobalObsForRailEnv
 # from flatland.envs.rail_generators import random_rail_generator, RailGenerator
 # from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
 
-from flatland.envs.observations import GlobalObsForRailEnv
-
-# import debugpy
 
-import pickle
 
 m.patch()
 
 
+# Adrian Egli performance fix (the fast methods brings more than 50%)
+def fast_isclose(a, b, rtol):
+    return (a < (b + rtol)) or (a < (b - rtol))
+
+
+def fast_clip(position: (int, int), min_value: (int, int), max_value: (int, int)) -> bool:
+    return (
+        max(min_value[0], min(position[0], max_value[0])),
+        max(min_value[1], min(position[1], max_value[1]))
+    )
+
+
+def fast_argmax(possible_transitions: (int, int, int, int)) -> bool:
+    if possible_transitions[0] == 1:
+        return 0
+    if possible_transitions[1] == 1:
+        return 1
+    if possible_transitions[2] == 1:
+        return 2
+    return 3
+
+
+def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
+    return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
+
+
+def fast_count_nonzero(possible_transitions: (int, int, int, int)):
+    return possible_transitions[0] + possible_transitions[1] + possible_transitions[2] + possible_transitions[3]
+
+
 class RailEnvActions(IntEnum):
     DO_NOTHING = 0  # implies change of direction in a dead-end!
     MOVE_LEFT = 1
@@ -298,11 +324,11 @@ class RailEnv(Environment):
         False: Agent cannot provide an action
         """
         return (agent.status == RailAgentStatus.READY_TO_DEPART or (
-            agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0,
-                                                                  rtol=1e-03)))
+            agent.status == RailAgentStatus.ACTIVE and fast_isclose(agent.speed_data['position_fraction'], 0.0,
+                                                                    rtol=1e-03)))
 
     def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False,
-              random_seed: bool = None) -> (Dict, Dict):
+              random_seed: bool = None) -> Tuple[Dict, Dict]:
         """
         reset(regenerate_rail, regenerate_schedule, activate_agents, random_seed)
 
@@ -604,7 +630,7 @@ class RailEnv(Environment):
                 RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_FORWARD]
 
             if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT,
-                        RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position):
+                          RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position):
                 agent.status = RailAgentStatus.ACTIVE
                 self._set_agent_to_initial_position(agent, agent.initial_position)
                 self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
@@ -626,7 +652,7 @@ class RailEnv(Environment):
         # Is the agent at the beginning of the cell? Then, it can take an action.
         # As long as the agent is malfunctioning or stopped at the beginning of the cell,
         # different actions may be taken!
-        if np.isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03):
+        if fast_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03):
             # No action has been supplied for this agent -> set DO_NOTHING as default
             if action is None:
                 action = RailEnvActions.DO_NOTHING
@@ -686,8 +712,8 @@ class RailEnv(Environment):
         #   transition_action_on_cellexit if the cell is free.
         if agent.moving:
             agent.speed_data['position_fraction'] += agent.speed_data['speed']
-            if agent.speed_data['position_fraction'] > 1.0 or np.isclose(agent.speed_data['position_fraction'], 1.0,
-                                                                         rtol=1e-03):
+            if agent.speed_data['position_fraction'] > 1.0 or fast_isclose(agent.speed_data['position_fraction'], 1.0,
+                                                                           rtol=1e-03):
                 # Perform stored action to transition to the next cell as soon as cell is free
                 # Notice that we've already checked new_cell_valid and transition valid when we stored the action,
                 # so we only have to check cell_free now!
@@ -695,7 +721,7 @@ class RailEnv(Environment):
                 # Traditional check that next cell is free
                 # cell and transition validity was checked when we stored transition_action_on_cellexit!
                 cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
-                        agent.speed_data['transition_action_on_cellexit'], agent)
+                    agent.speed_data['transition_action_on_cellexit'], agent)
 
                 # N.B. validity of new_cell and transition should have been verified before the action was stored!
                 assert new_cell_valid
@@ -845,7 +871,6 @@ class RailEnv(Environment):
             trans_block = sbTrans[agent.direction*4 : agent.direction * 4 + 4]
             if (trans_block == "0000"):
                 print (i_agent, agent.position, agent.direction, sbTrans, trans_block)
-                # debugpy.breakpoint()
 
         # if agent cannot enter env, then we should have move=False
         
@@ -862,20 +887,16 @@ class RailEnv(Environment):
         
                 if not all([transition_valid, new_cell_valid]):
                     print(f"ERRROR: step_agent2 invalid transition ag {i_agent} dir {new_direction} pos {agent.position} next {rc_next}")
-                    # debugpy.breakpoint()
 
                 if new_position != rc_next:
                     print(f"ERROR: agent {i_agent} new_pos {new_position} != rc_next {rc_next}  " + 
-                        f"pos {agent.position} dir {agent.direction} new_dir {new_direction}" +
-                        f"stored action: {agent.speed_data['transition_action_on_cellexit']}")
-                    # debugpy.breakpoint()
-
+                          f"pos {agent.position} dir {agent.direction} new_dir {new_direction}" +
+                          f"stored action: {agent.speed_data['transition_action_on_cellexit']}")
 
                 sbTrans = format(self.rail.grid[agent.position], "016b")
                 trans_block = sbTrans[agent.direction*4 : agent.direction * 4 + 4]
                 if (trans_block == "0000"):
                     print ("ERROR: ", i_agent, agent.position, agent.direction, sbTrans, trans_block)
-                    # debugpy.breakpoint()
 
                 agent.position = rc_next
                 agent.direction = new_direction
@@ -937,6 +958,7 @@ class RailEnv(Environment):
         self.agent_positions[agent.position] = -1
         if self.remove_agents_at_target:
             agent.position = None
+            # setting old_position to None here stops the DONE agents from appearing in the rendered image
             agent.old_position = None
             agent.status = RailAgentStatus.DONE_REMOVED
 
@@ -964,9 +986,9 @@ class RailEnv(Environment):
         new_position = get_new_position(agent.position, new_direction)
 
         new_cell_valid = (
-            np.array_equal(  # Check the new position is still in the grid
+            fast_position_equal(  # Check the new position is still in the grid
                 new_position,
-                np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
+                fast_clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
             and  # check the new position has some transitions (ie is not an empty cell)
             self.rail.get_full_transitions(*new_position) > 0)
 
@@ -1038,7 +1060,7 @@ class RailEnv(Environment):
         """
         transition_valid = None
         possible_transitions = self.rail.get_transitions(*agent.position, agent.direction)
-        num_transitions = np.count_nonzero(possible_transitions)
+        num_transitions = fast_count_nonzero(possible_transitions)
 
         new_direction = agent.direction
         if action == RailEnvActions.MOVE_LEFT:
@@ -1057,7 +1079,7 @@ class RailEnv(Environment):
             # - dead-end, straight line or curved line;
             # new_direction will be the only valid transition
             # - take only available transition
-            new_direction = np.argmax(possible_transitions)
+            new_direction = fast_argmax(possible_transitions)
             transition_valid = True
         return new_direction, transition_valid
 
diff --git a/flatland/utils/env_edit_utils.py b/flatland/utils/env_edit_utils.py
index ac748469..b1a40174 100644
--- a/flatland/utils/env_edit_utils.py
+++ b/flatland/utils/env_edit_utils.py
@@ -122,5 +122,10 @@ def makeTestEnv(sName="single_alternative", nAg=2, bUCF=True):
     
     dSpec = ddEnvSpecs[sName]
 
-
     return makeEnv2(nAg=nAg, bUCF=bUCF, **dSpec)
+
+def getAgentState(env):
+    dAgState={}
+    for iAg, ag in enumerate(env.agents):
+        dAgState[iAg] = (*ag.position, ag.direction)
+    return dAgState
\ No newline at end of file
-- 
GitLab