From 940b81321bcc51e1110b312e2969d815900ea50a Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Wed, 9 Oct 2019 14:52:26 -0400 Subject: [PATCH] fixed tests for changes to start of agents --- flatland/envs/rail_env.py | 5 +- tests/test_flaltland_rail_agent_status.py | 57 +++++++++++----------- tests/test_flatland_malfunction.py | 59 +++++++++++------------ tests/test_random_seeding.py | 19 +++++--- 4 files changed, 73 insertions(+), 67 deletions(-) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index cda0dee9..e7e5ef0c 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -189,7 +189,6 @@ class RailEnv(Environment): self.action_space = [5] self._seed() - self._seed() self.random_seed = random_seed if self.random_seed: @@ -217,6 +216,7 @@ class RailEnv(Environment): self.min_number_of_steps_broken = malfunction_min_duration self.max_number_of_steps_broken = malfunction_max_duration # Reset environment + self.reset() self.num_resets = 0 # yes, set it to zero again! @@ -259,6 +259,7 @@ class RailEnv(Environment): if replace_agents then regenerate the agents static. Relies on the rail_generator returning agent_static lists (pos, dir, target) """ + if random_seed: self._seed(random_seed) @@ -388,6 +389,7 @@ class RailEnv(Environment): return False def step(self, action_dict_: Dict[int, RailEnvActions]): + self._elapsed_steps += 1 # Reset the step rewards @@ -459,7 +461,6 @@ class RailEnv(Environment): agent.status = RailAgentStatus.ACTIVE agent.position = agent.initial_position self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] - print(self.rewards_dict[i_agent]) return else: # TODO: Here we need to check for the departure time in future releases with full schedules diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flaltland_rail_agent_status.py index 099ccce6..14a3e48a 100644 --- a/tests/test_flaltland_rail_agent_status.py +++ b/tests/test_flaltland_rail_agent_status.py @@ -23,7 +23,6 @@ def test_initial_status(): number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) - set_penalties_for_replay(env) test_config = ReplayConfig( replay=[ @@ -40,64 +39,64 @@ def test_initial_status(): direction=Grid4TransitionsEnum.EAST, status=RailAgentStatus.READY_TO_DEPART, action=RailEnvActions.MOVE_LEFT, - reward=env.start_penalty + env.step_penalty * 0.5, # auto-correction left to forward without penalty! + reward=env.step_penalty * 0.5, # auto-correction left to forward without penalty! 
             ),
             Replay(
                 position=(3, 9),
                 direction=Grid4TransitionsEnum.EAST,
                 status=RailAgentStatus.ACTIVE,
-                action=None,
-                reward=env.step_penalty * 0.5,  # running at speed 0.5
+                action=RailEnvActions.MOVE_LEFT,
+                reward=env.start_penalty + env.step_penalty * 0.5,  # start penalty + running at speed 0.5
             ),
             Replay(
                 position=(3, 9),
-                direction=Grid4TransitionsEnum.WEST,
+                direction=Grid4TransitionsEnum.EAST,
                 status=RailAgentStatus.ACTIVE,
-                action=RailEnvActions.MOVE_FORWARD,
+                action=None,
                 reward=env.step_penalty * 0.5,  # running at speed 0.5
             ),
             Replay(
                 position=(3, 8),
                 direction=Grid4TransitionsEnum.WEST,
                 status=RailAgentStatus.ACTIVE,
-                action=None,
+                action=RailEnvActions.MOVE_FORWARD,
                 reward=env.step_penalty * 0.5,  # running at speed 0.5
             ),
             Replay(
                 position=(3, 8),
                 direction=Grid4TransitionsEnum.WEST,
                 status=RailAgentStatus.ACTIVE,
-                action=RailEnvActions.MOVE_FORWARD,
+                action=None,
                 reward=env.step_penalty * 0.5,  # running at speed 0.5
             ),
             Replay(
                 position=(3, 7),
                 direction=Grid4TransitionsEnum.WEST,
-                action=None,
+                action=RailEnvActions.MOVE_FORWARD,
                 reward=env.step_penalty * 0.5,  # running at speed 0.5
                 status=RailAgentStatus.ACTIVE
             ),
             Replay(
                 position=(3, 7),
                 direction=Grid4TransitionsEnum.WEST,
-                action=RailEnvActions.MOVE_RIGHT,
+                action=None,
                 reward=env.step_penalty * 0.5,  # wrong action is corrected to forward without penalty!
                 status=RailAgentStatus.ACTIVE
             ),
             Replay(
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
-                action=None,
-                reward=env.global_reward,  # done
+                action=RailEnvActions.MOVE_RIGHT,
+                reward=env.step_penalty * 0.5,  # wrong action is corrected to forward without penalty!
                 status=RailAgentStatus.ACTIVE
             ),
             Replay(
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
                 action=None,
-                reward=env.global_reward,  # already done
-                status=RailAgentStatus.DONE
+                reward=env.global_reward,  # done
+                status=RailAgentStatus.ACTIVE
             ),
             Replay(
                 position=(3, 5),
@@ -151,7 +150,14 @@ def test_status_done_remove():
                 direction=Grid4TransitionsEnum.EAST,
                 status=RailAgentStatus.READY_TO_DEPART,
                 action=RailEnvActions.MOVE_LEFT,
-                reward=env.start_penalty + env.step_penalty * 0.5,  # auto-correction left to forward without penalty!
+                reward=env.step_penalty * 0.5,  # auto-correction left to forward without penalty!
+            ),
+            Replay(
+                position=(3, 9),
+                direction=Grid4TransitionsEnum.EAST,
+                status=RailAgentStatus.ACTIVE,
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.start_penalty + env.step_penalty * 0.5,  # start penalty + running at speed 0.5
             ),
             Replay(
                 position=(3, 9),
@@ -173,42 +179,35 @@ def test_status_done_remove():
                 status=RailAgentStatus.ACTIVE,
                 action=None,
                 reward=env.step_penalty * 0.5,  # running at speed 0.5
-            ),
-            Replay(
-                position=(3, 7),
-                direction=Grid4TransitionsEnum.WEST,
-                status=RailAgentStatus.ACTIVE,
-                action=RailEnvActions.MOVE_FORWARD,
-                reward=env.step_penalty * 0.5,  # running at speed 0.5
             ),
             Replay(
                 position=(3, 7),
                 direction=Grid4TransitionsEnum.WEST,
-                action=None,
+                action=RailEnvActions.MOVE_RIGHT,
                 reward=env.step_penalty * 0.5,  # running at speed 0.5
                 status=RailAgentStatus.ACTIVE
             ),
             Replay(
-                position=(3, 6),
+                position=(3, 7),
                 direction=Grid4TransitionsEnum.WEST,
-                action=RailEnvActions.MOVE_RIGHT,
+                action=None,
                 reward=env.step_penalty * 0.5,  # wrong action is corrected to forward without penalty!
                 status=RailAgentStatus.ACTIVE
             ),
             Replay(
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
-                action=None,
-                reward=env.global_reward,  # done
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5,  # running at speed 0.5
                 status=RailAgentStatus.ACTIVE
             ),
             Replay(
-                position=None,
+                position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
                 action=None,
                 reward=env.global_reward,  # already done
-                status=RailAgentStatus.DONE_REMOVED
+                status=RailAgentStatus.ACTIVE
             ),
             Replay(
                 position=None,
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 35e41b7e..8008e6e2 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -155,16 +155,16 @@ def test_malfunction_process_statistically():
     env.agents[0].target = (0, 0)
     nb_malfunction = 0
 
-    agent_malfunction_list = [[6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6],
-                              [6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0],
-                              [6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2],
-                              [6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0],
-                              [6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6, 5],
-                              [6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5],
-                              [6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1],
-                              [6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4],
-                              [6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6],
-                              [6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2]]
+    agent_malfunction_list = [[6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0],
+                              [6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1],
+                              [6, 6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3],
+                              [6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0],
+                              [6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6],
+                              [6, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6],
+                              [6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2],
+                              [6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5],
+                              [6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0],
+                              [6, 6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3]]
 
     for step in range(20):
         action_dict: Dict[int, RailEnvActions] = {}
@@ -177,7 +177,6 @@ def test_malfunction_process_statistically():
 
         env.step(action_dict)
 
-
 def test_malfunction_before_entry():
     """Tests that malfunctions are produced by stochastic_data!"""
    # Set fixed malfunction duration for this test
@@ -200,6 +199,8 @@
     env.reset(False, False, False, random_seed=10)
     env.agents[0].target = (0, 0)
+    # Print for test generation
+
     assert env.agents[0].malfunction_data['malfunction'] == 11
     assert env.agents[1].malfunction_data['malfunction'] == 11
     assert env.agents[2].malfunction_data['malfunction'] == 11
     assert env.agents[3].malfunction_data['malfunction'] == 11
@@ -210,7 +211,6 @@
     assert env.agents[8].malfunction_data['malfunction'] == 11
     assert env.agents[9].malfunction_data['malfunction'] == 11
 
-
     for step in range(20):
         action_dict: Dict[int, RailEnvActions] = {}
         for agent in env.agents:
@@ -220,18 +220,17 @@
             action_dict[agent.handle] = RailEnvActions(0)
 
         env.step(action_dict)
-
-    assert env.agents[1].malfunction_data['malfunction'] == 1
-    assert env.agents[2].malfunction_data['malfunction'] == 1
-    assert env.agents[3].malfunction_data['malfunction'] == 1
-    assert env.agents[4].malfunction_data['malfunction'] == 1
-    assert env.agents[5].malfunction_data['malfunction'] == 1
-    assert env.agents[6].malfunction_data['malfunction'] == 1
-    assert env.agents[7].malfunction_data['malfunction'] == 1
-    assert env.agents[8].malfunction_data['malfunction'] == 1
-    assert env.agents[9].malfunction_data['malfunction'] == 1
-    # Print for test generation
-    # for a in range(env.get_num_agents()):
+    assert env.agents[1].malfunction_data['malfunction'] == 2
+    assert env.agents[2].malfunction_data['malfunction'] == 2
+    assert env.agents[3].malfunction_data['malfunction'] == 2
+    assert env.agents[4].malfunction_data['malfunction'] == 2
+    assert env.agents[5].malfunction_data['malfunction'] == 2
+    assert env.agents[6].malfunction_data['malfunction'] == 2
+    assert env.agents[7].malfunction_data['malfunction'] == 2
+    assert env.agents[8].malfunction_data['malfunction'] == 2
+    assert env.agents[9].malfunction_data['malfunction'] == 2
+
+    # for a in range(env.get_num_agents()):
     #     print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,
     #                                                                                env.agents[a].malfunction_data[
     #                                                                                    'malfunction']))
@@ -348,7 +347,7 @@ def test_initial_malfunction_stop_moving():
             position=(3, 2),
             direction=Grid4TransitionsEnum.EAST,
             action=RailEnvActions.DO_NOTHING,
-            malfunction=2,
+            malfunction=3,
             reward=env.step_penalty,  # full step penalty when stopped
             status=RailAgentStatus.ACTIVE
         ),
@@ -359,7 +358,7 @@
             position=(3, 2),
             direction=Grid4TransitionsEnum.EAST,
             action=RailEnvActions.STOP_MOVING,
-            malfunction=1,
+            malfunction=2,
             reward=env.step_penalty,  # full step penalty while stopped
             status=RailAgentStatus.ACTIVE
         ),
@@ -368,7 +367,7 @@
             position=(3, 2),
             direction=Grid4TransitionsEnum.EAST,
             action=RailEnvActions.DO_NOTHING,
-            malfunction=0,
+            malfunction=1,
             reward=env.step_penalty,  # full step penalty while stopped
             status=RailAgentStatus.ACTIVE
         ),
@@ -437,7 +436,7 @@ def test_initial_malfunction_do_nothing():
             position=(3, 2),
             direction=Grid4TransitionsEnum.EAST,
             action=RailEnvActions.DO_NOTHING,
-            malfunction=2,
+            malfunction=3,
             reward=env.step_penalty,  # full step penalty while malfunctioning
             status=RailAgentStatus.ACTIVE
         ),
@@ -448,7 +447,7 @@
             position=(3, 2),
             direction=Grid4TransitionsEnum.EAST,
             action=RailEnvActions.DO_NOTHING,
-            malfunction=1,
+            malfunction=2,
             reward=env.step_penalty,  # full step penalty while stopped
             status=RailAgentStatus.ACTIVE
         ),
@@ -457,7 +456,7 @@
             position=(3, 2),
             direction=Grid4TransitionsEnum.EAST,
             action=RailEnvActions.DO_NOTHING,
-            malfunction=0,
+            malfunction=1,
             reward=env.step_penalty,  # full step penalty while stopped
             status=RailAgentStatus.ACTIVE
         ),
diff --git a/tests/test_random_seeding.py b/tests/test_random_seeding.py
index 3a03de00..ecf10c49 100644
--- a/tests/test_random_seeding.py
+++ b/tests/test_random_seeding.py
@@ -21,7 +21,6 @@ def test_random_seeding():
                       number_of_agents=10
                       )
     env.reset(True, True, False, random_seed=1)
-    # Test generation print
 
     env.agents[0].target = (0, 0)
     for step in range(10):
@@ -29,12 +28,20 @@
         actions = {}
         actions[0] = 2
         env.step(actions)
     agent_positions = []
-    for a in range(env.get_num_agents()):
-        agent_positions += env.agents[a].initial_position
-    # print(agent_positions)
-    assert agent_positions == [3, 2, 3, 5, 3, 6, 5, 6, 3, 4, 3, 1, 3, 9, 4, 6, 0, 3, 3, 7]
+
+    assert env.agents[0].initial_position == (3, 2)
+    assert env.agents[1].initial_position == (3, 5)
+    assert env.agents[2].initial_position == (3, 6)
+    assert env.agents[3].initial_position == (5, 6)
+    assert env.agents[4].initial_position == (3, 4)
+    assert env.agents[5].initial_position == (3, 1)
+    assert env.agents[6].initial_position == (3, 9)
+    assert env.agents[7].initial_position == (4, 6)
+    assert env.agents[8].initial_position == (0, 3)
+    assert env.agents[9].initial_position == (3, 7)
     # Test generation print
-    assert env.agents[0].position == (3, 6)
+    # for a in range(env.get_num_agents()):
+    #     print("assert env.agents[{}].initial_position == {}".format(a, env.agents[a].initial_position))
     # print("env.agents[0].initial_position == {}".format(env.agents[0].initial_position))
     # print("assert env.agents[0].position == {}".format(env.agents[0].position))
-- 
GitLab