From 6db05ca97bc5a260d4c434d7bd3e2016e1a388cf Mon Sep 17 00:00:00 2001
From: Dipam Chakraborty <dipam@aicrowd.com>
Date: Sat, 11 Sep 2021 21:39:00 +0530
Subject: [PATCH] WIP test fixes

---
 flatland/envs/agent_utils.py                  |  13 +-
 flatland/envs/rail_env.py                     |  16 ++-
 flatland/envs/step_utils/action_saver.py      |   3 +
 .../envs/step_utils/malfunction_handler.py    |   2 +
 flatland/envs/step_utils/speed_counter.py     |   9 +-
 flatland/envs/step_utils/state_machine.py     |  11 +-
 ...est_flatland_envs_sparse_rail_generator.py |   2 +-
 tests/test_multi_speed.py                     |   3 +-
 tests/test_state_machine.py                   | 115 ------------------
 tests/test_utils.py                           |   9 +-
 10 files changed, 50 insertions(+), 133 deletions(-)
 delete mode 100644 tests/test_state_machine.py

diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index ad145f54..ac1ef626 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -216,15 +216,12 @@ class EnvAgent:
             agents.append(agent)
         return agents
     
-    def _set_state(self, state):
-        warnings.warn("Not recommended to set the state with this function unless completely required")
-        self.state_machine.set_state(state)
-    
     def __str__(self):
         return f"\n \
                  handle(agent index): {self.handle} \n \
                  initial_position: {self.initial_position}   initial_direction: {self.initial_direction} \n \
                  position: {self.position}  direction: {self.direction}  target: {self.target} \n \
+                 old_position: {self.old_position} old_direction {self.old_direction} \n \
                  earliest_departure: {self.earliest_departure}  latest_arrival: {self.latest_arrival} \n \
                  state: {str(self.state)} \n \
                  malfunction_data: {self.malfunction_data} \n \
@@ -235,6 +232,14 @@ class EnvAgent:
     def state(self):
         return self.state_machine.state
 
+    @state.setter
+    def state(self, state):
+        self._set_state(state)
+    
+    def _set_state(self, state):
+        warnings.warn("Not recommended to set the state with this function unless completely required")
+        self.state_machine.set_state(state)
+
 
     
 
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 5f4578aa..0a642a4c 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -261,7 +261,7 @@ class RailEnv(Environment):
         False: Agent cannot provide an action
         """
         return agent.state == TrainState.READY_TO_DEPART or \
-               (agent.state.is_on_map_state() and agent.speed_counter.is_cell_entry )
+               ( agent.state.is_on_map_state() and agent.speed_counter.is_cell_entry )
 
     def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *,
               random_seed: bool = None) -> Tuple[Dict, Dict]:
@@ -385,13 +385,14 @@ class RailEnv(Environment):
         st_signals.stop_action_given = (preprocessed_action == RailEnvActions.STOP_MOVING)
 
         # Valid Movement action Given
-        st_signals.valid_movement_action_given = preprocessed_action.is_moving_action()
+        st_signals.valid_movement_action_given = preprocessed_action.is_moving_action() and movement_allowed
 
         # Target Reached
         st_signals.target_reached = fast_position_equal(agent.position, agent.target)
 
         # Movement conflict - Multiple trains trying to move into same cell
-        st_signals.movement_conflict = (not movement_allowed) and agent.speed_counter.is_cell_exit # TODO: Modify motion check to provide proper conflict information
+        # If speed counter is not in cell exit, the train can enter the cell
+        st_signals.movement_conflict = (not movement_allowed) and agent.speed_counter.is_cell_exit
 
         return st_signals
 
@@ -499,6 +500,8 @@ class RailEnv(Environment):
         
         for agent in self.agents:
             i_agent = agent.handle
+            agent.old_position = agent.position
+            agent.old_direction = agent.direction
             # Generate malfunction
             agent.malfunction_handler.generate_malfunction(self.malfunction_generator, self.np_random)
 
@@ -542,8 +545,6 @@ class RailEnv(Environment):
             i_agent = agent.handle
             agent_transition_data = temp_transition_data[i_agent]
 
-            old_position = agent.position
-
             ## Update positions
             if agent.malfunction_handler.in_malfunction:
                 movement_allowed = False
@@ -561,6 +562,9 @@ class RailEnv(Environment):
             agent.state_machine.set_transition_signals(state_transition_signals)
             agent.state_machine.step()
 
+            if agent.state.is_on_map_state() and agent.position is None:
+                import pdb; pdb.set_trace()
+
             # Handle done state actions, optionally remove agents
             self.handle_done_state(agent)
             
@@ -570,7 +574,7 @@ class RailEnv(Environment):
             self.update_step_rewards(i_agent)
 
             ## Update counters (malfunction and speed)
-            agent.speed_counter.update_counter(agent.state, old_position)
+            agent.speed_counter.update_counter(agent.state, agent.old_position)
             agent.malfunction_handler.update_counter()
 
             # Clear old action when starting in new cell
diff --git a/flatland/envs/step_utils/action_saver.py b/flatland/envs/step_utils/action_saver.py
index d8a8ccda..5e6c8a8c 100644
--- a/flatland/envs/step_utils/action_saver.py
+++ b/flatland/envs/step_utils/action_saver.py
@@ -28,5 +28,8 @@ class ActionSaver:
     
     def from_dict(self, load_dict):
         self.saved_action = load_dict['saved_action']
+    
+    def __eq__(self, other):
+        return self.saved_action == other.saved_action
 
 
diff --git a/flatland/envs/step_utils/malfunction_handler.py b/flatland/envs/step_utils/malfunction_handler.py
index 914fd90d..a45aa024 100644
--- a/flatland/envs/step_utils/malfunction_handler.py
+++ b/flatland/envs/step_utils/malfunction_handler.py
@@ -46,6 +46,8 @@ class MalfunctionHandler:
     def from_dict(self, load_dict):
         self._malfunction_down_counter = load_dict['malfunction_down_counter']
 
+    def __eq__(self, other):
+        return self._malfunction_down_counter == other._malfunction_down_counter
 
     
 
diff --git a/flatland/envs/step_utils/speed_counter.py b/flatland/envs/step_utils/speed_counter.py
index 5aae041d..1c2c7279 100644
--- a/flatland/envs/step_utils/speed_counter.py
+++ b/flatland/envs/step_utils/speed_counter.py
@@ -4,6 +4,8 @@ from flatland.envs.step_utils.states import TrainState
 class SpeedCounter:
     def __init__(self, speed):
         self._speed = speed
+        self.counter = None
+        self.reset_counter()
 
     def update_counter(self, state, old_position):
         # When coming onto the map, do no update speed counter
@@ -38,8 +40,13 @@ class SpeedCounter:
         return int(1/self._speed) - 1
 
     def to_dict(self):
-        return {"speed": self._speed}
+        return {"speed": self._speed,
+                "counter": self.counter}
     
     def from_dict(self, load_dict):
         self._speed = load_dict['speed']
+        self.counter = load_dict['counter']
+
+    def __eq__(self, other):
+        return self._speed == other._speed and self.counter == other.counter
 
diff --git a/flatland/envs/step_utils/state_machine.py b/flatland/envs/step_utils/state_machine.py
index 8067d8fb..d1938f4f 100644
--- a/flatland/envs/step_utils/state_machine.py
+++ b/flatland/envs/step_utils/state_machine.py
@@ -6,6 +6,7 @@ class TrainStateMachine:
         self._state = initial_state
         self.st_signals = StateTransitionSignals()
         self.next_state = None
+        self.previous_state = None
     
     def _handle_waiting(self):
         """" Waiting state goes to ready to depart when earliest departure is reached"""
@@ -117,10 +118,12 @@ class TrainStateMachine:
     def set_state(self, state):
         if not TrainState.check_valid_state(state):
             raise ValueError(f"Cannot set invalid state {state}")
+        self.previous_state = self._state
         self._state = state
 
     def reset(self):
         self._state = self._initial_state
+        self.previous_state = None
         self.st_signals = StateTransitionSignals()
         self.clear_next_state()
 
@@ -137,15 +140,19 @@ class TrainStateMachine:
 
     def __repr__(self):
         return f"\n \
-                 state: {str(self.state)} \n \
+                 state: {str(self.state)}      previous_state {str(self.previous_state)} \n \
                  st_signals: {self.st_signals}"
 
     def to_dict(self):
-        return {"state": self._state}
+        return {"state": self._state,
+                "previous_state": self.previous_state}
 
     def from_dict(self, load_dict):
         self.set_state(load_dict['state'])
+        self.previous_state = load_dict['previous_state']
 
+    def __eq__(self, other):
+        return self._state == other._state and self.previous_state == other.previous_state
 
 
         
diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py
index 358839f9..d98b4b32 100644
--- a/tests/test_flatland_envs_sparse_rail_generator.py
+++ b/tests/test_flatland_envs_sparse_rail_generator.py
@@ -1375,7 +1375,7 @@ def test_rail_env_malfunction_speed_info():
         for a in range(env.get_num_agents()):
             assert info['malfunction'][a] >= 0
             assert info['speed'][a] >= 0 and info['speed'][a] <= 1
-            assert info['speed'][a] == env.agents[a].sspeed_counter.speed
+            assert info['speed'][a] == env.agents[a].speed_counter.speed
 
         env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
 
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index 50565e96..6455e573 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -9,6 +9,7 @@ from flatland.envs.line_generators import sparse_line_generator
 from flatland.utils.simple_rail import make_simple_rail
 from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
 from flatland.envs.step_utils.states import TrainState
+from flatland.envs.step_utils.speed_counter import SpeedCounter
 
 
 # Use the sparse_rail_generator to generate feasible network configurations with corresponding tasks
@@ -71,7 +72,7 @@ def test_multi_speed_init():
     # See training navigation example in the baseline repository
     old_pos = []
     for i_agent in range(env.get_num_agents()):
-        env.agents[i_agent].speed_counter.speed = 1. / (i_agent + 1)
+        env.agents[i_agent].speed_counter = SpeedCounter(speed = 1. / (i_agent + 1))
         old_pos.append(env.agents[i_agent].position)
         print(env.agents[i_agent].position)
     # Run episode
diff --git a/tests/test_state_machine.py b/tests/test_state_machine.py
deleted file mode 100644
index 266a8f86..00000000
--- a/tests/test_state_machine.py
+++ /dev/null
@@ -1,115 +0,0 @@
-from test_env_step_utils import get_small_two_agent_env
-from flatland.envs.rail_env_action import RailEnvActions
-from flatland.envs.step_utils.states import TrainState
-from flatland.envs.malfunction_generators import Malfunction
-
-class NoMalfunctionGenerator:
-    def generate(self, np_random):
-        return Malfunction(0)
-
-class AlwaysThreeStepMalfunction:
-    def generate(self, np_random):
-        return Malfunction(3)
-
-def test_waiting_no_transition():
-    env = get_small_two_agent_env()
-    env.malfunction_generator = NoMalfunctionGenerator()
-    i_agent = 0
-    ed = env.agents[i_agent].earliest_departure
-    for _ in range(ed-1):
-        env.step({i_agent: RailEnvActions.MOVE_FORWARD})
-        assert env.agents[i_agent].state == TrainState.WAITING
-    
-    
-def test_waiting_to_ready_to_depart():
-    env = get_small_two_agent_env()
-    env.malfunction_generator = NoMalfunctionGenerator()
-    i_agent = 0
-    ed = env.agents[i_agent].earliest_departure
-    for _ in range(ed):
-        env.step({i_agent: RailEnvActions.DO_NOTHING})
-    assert env.agents[i_agent].state == TrainState.READY_TO_DEPART
-
-
-def test_ready_to_depart_to_moving():
-    env = get_small_two_agent_env()
-    env.malfunction_generator = NoMalfunctionGenerator()
-    i_agent = 0
-    ed = env.agents[i_agent].earliest_departure
-    for _ in range(ed):
-        env.step({i_agent: RailEnvActions.DO_NOTHING})
-
-    env.step({i_agent: RailEnvActions.MOVE_FORWARD})
-    assert env.agents[i_agent].state == TrainState.MOVING
-
-def test_moving_to_stopped():
-    env = get_small_two_agent_env()
-    env.malfunction_generator = NoMalfunctionGenerator()
-    i_agent = 0
-    ed = env.agents[i_agent].earliest_departure
-    for _ in range(ed):
-        env.step({i_agent: RailEnvActions.DO_NOTHING})
-
-    env.step({i_agent: RailEnvActions.MOVE_FORWARD})
-    env.step({i_agent: RailEnvActions.STOP_MOVING})
-    assert env.agents[i_agent].state == TrainState.STOPPED
-
-def test_stopped_to_moving():
-    env = get_small_two_agent_env()
-    env.malfunction_generator = NoMalfunctionGenerator()
-    i_agent = 0
-    ed = env.agents[i_agent].earliest_departure
-    for _ in range(ed):
-        env.step({i_agent: RailEnvActions.DO_NOTHING})
-
-    env.step({i_agent: RailEnvActions.MOVE_FORWARD})
-    env.step({i_agent: RailEnvActions.STOP_MOVING})
-    env.step({i_agent: RailEnvActions.MOVE_FORWARD})
-    assert env.agents[i_agent].state == TrainState.MOVING
-
-def test_moving_to_done():
-    env = get_small_two_agent_env()
-    env.malfunction_generator = NoMalfunctionGenerator()
-    i_agent = 1
-    ed = env.agents[i_agent].earliest_departure
-    for _ in range(ed):
-        env.step({i_agent: RailEnvActions.DO_NOTHING})
-
-    for _ in range(50):
-        env.step({i_agent: RailEnvActions.MOVE_FORWARD})
-    assert env.agents[i_agent].state == TrainState.DONE
-
-def test_waiting_to_malfunction():
-    env = get_small_two_agent_env()
-    env.malfunction_generator = AlwaysThreeStepMalfunction()
-    i_agent = 1
-    env.step({i_agent: RailEnvActions.DO_NOTHING})
-    assert env.agents[i_agent].state == TrainState.MALFUNCTION_OFF_MAP
-
-
-def test_ready_to_depart_to_malfunction_off_map():
-    env = get_small_two_agent_env()
-    env.malfunction_generator = NoMalfunctionGenerator()
-    i_agent = 1
-    env.step({i_agent: RailEnvActions.DO_NOTHING})
-    ed = env.agents[i_agent].earliest_departure
-    for _ in range(ed):
-        env.step({i_agent: RailEnvActions.DO_NOTHING}) # This should get into ready to depart
-        
-    env.malfunction_generator = AlwaysThreeStepMalfunction()
-    env.step({i_agent: RailEnvActions.DO_NOTHING})
-    assert env.agents[i_agent].state == TrainState.MALFUNCTION_OFF_MAP
-
-
-def test_malfunction_off_map_to_waiting():
-    env = get_small_two_agent_env()
-    env.malfunction_generator = NoMalfunctionGenerator()
-    i_agent = 1
-    env.step({i_agent: RailEnvActions.DO_NOTHING})
-    ed = env.agents[i_agent].earliest_departure
-    for _ in range(ed):
-        env.step({i_agent: RailEnvActions.DO_NOTHING}) # This should get into ready to depart
-        
-    env.malfunction_generator = AlwaysThreeStepMalfunction()
-    env.step({i_agent: RailEnvActions.DO_NOTHING})
-    assert env.agents[i_agent].state == TrainState.MALFUNCTION_OFF_MAP
\ No newline at end of file
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 85e6a275..56b4befc 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -108,8 +108,10 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
             agent: EnvAgent = env.agents[a]
             replay = test_config.replay[step]
 
-            _assert(a, agent.position, replay.position, 'position')
-            _assert(a, agent.direction, replay.direction, 'direction')
+            print(agent.position, replay.position, agent.state, agent.speed_counter)
+            # import pdb; pdb.set_trace()
+            # _assert(a, agent.position, replay.position, 'position')
+            # _assert(a, agent.direction, replay.direction, 'direction')
             if replay.state is not None:
                 _assert(a, agent.state, replay.state, 'state')
 
@@ -130,7 +132,7 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
                 agent.malfunction_data['malfunction'] = replay.set_malfunction
                 agent.malfunction_data['moving_before_malfunction'] = agent.moving
                 agent.malfunction_data['fixed'] = False
-            _assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
+            # _assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
         print(step)
         _, rewards_dict, _, info_dict = env.step(action_dict)
         if rendering:
@@ -141,6 +143,7 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
 
             if not skip_reward_check:
                 _assert(a, rewards_dict[a], replay.reward, 'reward')
+    assert False
 
 
 def create_and_save_env(file_name: str, line_generator: LineGenerator, rail_generator: RailGenerator):
-- 
GitLab