From 297b65c5f7b1208aa4da7005df1fccd02b07445f Mon Sep 17 00:00:00 2001
From: Dipam Chakraborty <dipam@aicrowd.com>
Date: Thu, 5 Aug 2021 02:15:31 +0530
Subject: [PATCH] some tests working

---
 flatland/envs/agent_utils.py                  |  6 ++-
 flatland/envs/persistence.py                  |  4 +-
 tests/test_action_plan.py                     | 42 ++++++++++---------
 tests/test_flaltland_rail_agent_status.py     |  9 ++++
 tests/test_flatland_envs_rail_env.py          |  9 ++--
 ...t_flatland_envs_rail_env_shortest_paths.py |  6 ++-
 ...est_flatland_envs_sparse_rail_generator.py |  2 +-
 ...ile.py => test_flatland_line_from_file.py} | 26 +++++++-----
 tests/test_flatland_malfunction.py            |  2 +
 tests/test_global_observation.py              |  4 ++
 tests/test_multi_speed.py                     | 17 +++++++-
 tests/test_utils.py                           |  2 +-
 12 files changed, 89 insertions(+), 40 deletions(-)
 rename tests/{test_flatland_schedule_from_file.py => test_flatland_line_from_file.py} (79%)

diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index 00dabd31..b90d38a4 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -109,7 +109,11 @@ class EnvAgent:
         return get_shortest_paths(distance_map=distance_map, agent_handle=self.handle)[self.handle]
         
     def get_travel_time_on_shortest_path(self, distance_map) -> int:
-        distance = len(self.get_shortest_path(distance_map))
+        shortest_path = self.get_shortest_path(distance_map)
+        if shortest_path is not None:
+            distance = len(shortest_path)
+        else:
+            distance = 0
         speed = self.speed_data['speed']
         return int(np.ceil(distance / speed))
 
diff --git a/flatland/envs/persistence.py b/flatland/envs/persistence.py
index bc4b169b..b8354078 100644
--- a/flatland/envs/persistence.py
+++ b/flatland/envs/persistence.py
@@ -21,7 +21,7 @@ from flatland.envs.distance_map import DistanceMap
 # cannot import objects / classes directly because of circular import
 from flatland.envs import malfunction_generators as mal_gen
 from flatland.envs import rail_generators as rail_gen
-from flatland.envs import schedule_generators as sched_gen
+from flatland.envs import line_generators as line_gen
 
 msgpack_numpy.patch()
 
@@ -122,7 +122,7 @@ class RailEnvPersister(object):
                 width=width, height=height,
                 rail_generator=rail_gen.rail_from_file(filename, 
                     load_from_package=load_from_package),
-                schedule_generator=sched_gen.schedule_from_file(filename,
+                    line_generator=line_gen.line_from_file(filename,
                     load_from_package=load_from_package),
                 #malfunction_generator_and_process_data=mal_gen.malfunction_from_file(filename,
                 #    load_from_package=load_from_package),
diff --git a/tests/test_action_plan.py b/tests/test_action_plan.py
index 815ecbcd..d5c95408 100644
--- a/tests/test_action_plan.py
+++ b/tests/test_action_plan.py
@@ -34,25 +34,29 @@ def test_action_plan(rendering: bool = False):
     for handle, agent in enumerate(env.agents):
         print("[{}] {} -> {}".format(handle, agent.initial_position, agent.target))
 
-    chosen_path_dict = {0: [TrainrunWaypoint(lined_at=0, waypoint=Waypoint(position=(3, 0), direction=3)),
-                            TrainrunWaypoint(lined_at=2, waypoint=Waypoint(position=(3, 1), direction=1)),
-                            TrainrunWaypoint(lined_at=3, waypoint=Waypoint(position=(3, 2), direction=1)),
-                            TrainrunWaypoint(lined_at=14, waypoint=Waypoint(position=(3, 3), direction=1)),
-                            TrainrunWaypoint(lined_at=15, waypoint=Waypoint(position=(3, 4), direction=1)),
-                            TrainrunWaypoint(lined_at=16, waypoint=Waypoint(position=(3, 5), direction=1)),
-                            TrainrunWaypoint(lined_at=17, waypoint=Waypoint(position=(3, 6), direction=1)),
-                            TrainrunWaypoint(lined_at=18, waypoint=Waypoint(position=(3, 7), direction=1)),
-                            TrainrunWaypoint(lined_at=19, waypoint=Waypoint(position=(3, 8), direction=1)),
-                            TrainrunWaypoint(lined_at=20, waypoint=Waypoint(position=(3, 8), direction=5))],
-                        1: [TrainrunWaypoint(lined_at=0, waypoint=Waypoint(position=(3, 8), direction=3)),
-                            TrainrunWaypoint(lined_at=3, waypoint=Waypoint(position=(3, 7), direction=3)),
-                            TrainrunWaypoint(lined_at=5, waypoint=Waypoint(position=(3, 6), direction=3)),
-                            TrainrunWaypoint(lined_at=7, waypoint=Waypoint(position=(3, 5), direction=3)),
-                            TrainrunWaypoint(lined_at=9, waypoint=Waypoint(position=(3, 4), direction=3)),
-                            TrainrunWaypoint(lined_at=11, waypoint=Waypoint(position=(3, 3), direction=3)),
-                            TrainrunWaypoint(lined_at=13, waypoint=Waypoint(position=(2, 3), direction=0)),
-                            TrainrunWaypoint(lined_at=15, waypoint=Waypoint(position=(1, 3), direction=0)),
-                            TrainrunWaypoint(lined_at=17, waypoint=Waypoint(position=(0, 3), direction=0))]}
+    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
+    for _ in range(max([agent.earliest_departure for agent in env.agents])):
+        env.step({}) # DO_NOTHING for all agents
+
+    chosen_path_dict = {0: [TrainrunWaypoint(scheduled_at=0, waypoint=Waypoint(position=(3, 0), direction=3)),
+                            TrainrunWaypoint(scheduled_at=2, waypoint=Waypoint(position=(3, 1), direction=1)),
+                            TrainrunWaypoint(scheduled_at=3, waypoint=Waypoint(position=(3, 2), direction=1)),
+                            TrainrunWaypoint(scheduled_at=14, waypoint=Waypoint(position=(3, 3), direction=1)),
+                            TrainrunWaypoint(scheduled_at=15, waypoint=Waypoint(position=(3, 4), direction=1)),
+                            TrainrunWaypoint(scheduled_at=16, waypoint=Waypoint(position=(3, 5), direction=1)),
+                            TrainrunWaypoint(scheduled_at=17, waypoint=Waypoint(position=(3, 6), direction=1)),
+                            TrainrunWaypoint(scheduled_at=18, waypoint=Waypoint(position=(3, 7), direction=1)),
+                            TrainrunWaypoint(scheduled_at=19, waypoint=Waypoint(position=(3, 8), direction=1)),
+                            TrainrunWaypoint(scheduled_at=20, waypoint=Waypoint(position=(3, 8), direction=5))],
+                        1: [TrainrunWaypoint(scheduled_at=0, waypoint=Waypoint(position=(3, 8), direction=3)),
+                            TrainrunWaypoint(scheduled_at=3, waypoint=Waypoint(position=(3, 7), direction=3)),
+                            TrainrunWaypoint(scheduled_at=5, waypoint=Waypoint(position=(3, 6), direction=3)),
+                            TrainrunWaypoint(scheduled_at=7, waypoint=Waypoint(position=(3, 5), direction=3)),
+                            TrainrunWaypoint(scheduled_at=9, waypoint=Waypoint(position=(3, 4), direction=3)),
+                            TrainrunWaypoint(scheduled_at=11, waypoint=Waypoint(position=(3, 3), direction=3)),
+                            TrainrunWaypoint(scheduled_at=13, waypoint=Waypoint(position=(2, 3), direction=0)),
+                            TrainrunWaypoint(scheduled_at=15, waypoint=Waypoint(position=(1, 3), direction=0)),
+                            TrainrunWaypoint(scheduled_at=17, waypoint=Waypoint(position=(0, 3), direction=0))]}
     expected_action_plan = [[
         # take action to enter the grid
         ActionPlanElement(0, RailEnvActions.MOVE_FORWARD),
diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flaltland_rail_agent_status.py
index 4f507415..87b37305 100644
--- a/tests/test_flaltland_rail_agent_status.py
+++ b/tests/test_flaltland_rail_agent_status.py
@@ -17,6 +17,11 @@ def test_initial_status():
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   remove_agents_at_target=False)
     env.reset()
+
+    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
+    for _ in range(max([agent.earliest_departure for agent in env.agents])):
+        env.step({}) # DO_NOTHING for all agents
+
     set_penalties_for_replay(env)
     test_config = ReplayConfig(
         replay=[
@@ -126,6 +131,10 @@ def test_status_done_remove():
                   remove_agents_at_target=True)
     env.reset()
 
+    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
+    for _ in range(max([agent.earliest_departure for agent in env.agents])):
+        env.step({}) # DO_NOTHING for all agents
+
     set_penalties_for_replay(env)
     test_config = ReplayConfig(
         replay=[
diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py
index c531359a..53a61c8f 100644
--- a/tests/test_flatland_envs_rail_env.py
+++ b/tests/test_flatland_envs_rail_env.py
@@ -9,9 +9,9 @@ from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
-from flatland.envs.rail_generators import complex_rail_generator, rail_from_file
+from flatland.envs.rail_generators import sparse_rail_generator, rail_from_file
 from flatland.envs.rail_generators import rail_from_grid_transition_map
-from flatland.envs.line_generators import random_line_generator, complex_line_generator, line_from_file
+from flatland.envs.line_generators import random_line_generator, sparse_line_generator, line_from_file
 from flatland.utils.simple_rail import make_simple_rail
 from flatland.envs.persistence import RailEnvPersister
 from flatland.utils.rendertools import RenderTool
@@ -37,9 +37,10 @@ def test_load_env():
 
 def test_save_load():
     env = RailEnv(width=10, height=10,
-                  rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1),
-                  line_generator=complex_line_generator(), number_of_agents=2)
+                  rail_generator=sparse_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1),
+                  line_generator=sparse_line_generator(), number_of_agents=2)
     env.reset()
+
     agent_1_pos = env.agents[0].position
     agent_1_dir = env.agents[0].direction
     agent_1_tar = env.agents[0].target
diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py
index b2ee9b01..ce88aeb8 100644
--- a/tests/test_flatland_envs_rail_env_shortest_paths.py
+++ b/tests/test_flatland_envs_rail_env_shortest_paths.py
@@ -23,6 +23,10 @@ def test_get_shortest_paths_unreachable():
                   obs_builder_object=GlobalObsForRailEnv())
     env.reset()
 
+    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
+    for _ in range(max([agent.earliest_departure for agent in env.agents])):
+        env.step({}) # DO_NOTHING for all agents
+
     # set the initial position
     agent = env.agents[0]
     agent.position = (3, 1)  # west dead-end
@@ -36,7 +40,7 @@ def test_get_shortest_paths_unreachable():
     actual = get_shortest_paths(env.distance_map)
     expected = {0: None}
 
-    assert actual == expected, "actual={},expected={}".format(actual, expected)
+    assert actual[0] == expected[0], "actual={},expected={}".format(actual[0], expected[0])
 
 
 # todo file test_002.pkl has to be generated automatically
diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py
index 3e74d720..74e71dac 100644
--- a/tests/test_flatland_envs_sparse_rail_generator.py
+++ b/tests/test_flatland_envs_sparse_rail_generator.py
@@ -1512,7 +1512,7 @@ def test_sparse_generator_changes_to_grid_mode():
     rail_env = RailEnv(width=10, height=20, rail_generator=sparse_rail_generator(
         max_num_cities=100,
         max_rails_between_cities=2,
-        max_rails_in_city=2,
+        max_rail_pairs_in_city=1,
         seed=15,
         grid_mode=False
     ), line_generator=sparse_line_generator(), number_of_agents=10,
diff --git a/tests/test_flatland_schedule_from_file.py b/tests/test_flatland_line_from_file.py
similarity index 79%
rename from tests/test_flatland_schedule_from_file.py
rename to tests/test_flatland_line_from_file.py
index 0b903eae..b324af98 100644
--- a/tests/test_flatland_schedule_from_file.py
+++ b/tests/test_flatland_line_from_file.py
@@ -25,12 +25,14 @@ def test_line_from_file_sparse():
                                            seed=1,
                                            grid_mode=False,
                                            max_rails_between_cities=3,
-                                           max_rails_in_city=6,
+                                           max_rail_pairs_in_city=3,
                                            )
     line_generator = sparse_line_generator(speed_ration_map)
 
-    create_and_save_env(file_name="./sparse_env_test.pkl", rail_generator=rail_generator,
+    env = create_and_save_env(file_name="./sparse_env_test.pkl", rail_generator=rail_generator,
                         line_generator=line_generator)
+    old_num_steps = env._max_episode_steps
+    old_num_agents = len(env.agents)
 
 
     # Sparse generator
@@ -41,10 +43,10 @@ def test_line_from_file_sparse():
     sparse_env_from_file.reset(True, True)
 
     # Assert loaded agent number is correct
-    assert sparse_env_from_file.get_num_agents() == 10
+    assert sparse_env_from_file.get_num_agents() == old_num_agents
 
     # Assert max steps is correct
-    assert sparse_env_from_file._max_episode_steps == 500
+    assert sparse_env_from_file._max_episode_steps == old_num_steps
 
 
 
@@ -65,8 +67,10 @@ def test_line_from_file_random():
     rail_generator = random_rail_generator()
     line_generator = random_line_generator(speed_ration_map)
 
-    create_and_save_env(file_name="./random_env_test.pkl", rail_generator=rail_generator,
+    env = create_and_save_env(file_name="./random_env_test.pkl", rail_generator=rail_generator,
                         line_generator=line_generator)
+    old_num_steps = env._max_episode_steps
+    old_num_agents = len(env.agents)                        
 
 
     # Random generator
@@ -77,10 +81,10 @@ def test_line_from_file_random():
     random_env_from_file.reset(True, True)
 
     # Assert loaded agent number is correct
-    assert random_env_from_file.get_num_agents() == 10
+    assert random_env_from_file.get_num_agents() == old_num_agents
 
     # Assert max steps is correct
-    assert random_env_from_file._max_episode_steps == 1350
+    assert random_env_from_file._max_episode_steps == old_num_steps
 
 
 
@@ -105,8 +109,10 @@ def test_line_from_file_complex():
                                             max_dist=99999)
     line_generator = complex_line_generator(speed_ration_map)
 
-    create_and_save_env(file_name="./complex_env_test.pkl", rail_generator=rail_generator,
+    env = create_and_save_env(file_name="./complex_env_test.pkl", rail_generator=rail_generator,
                         line_generator=line_generator)
+    old_num_steps = env._max_episode_steps
+    old_num_agents = len(env.agents)
 
     # Load the different envs and check the parameters
 
@@ -119,7 +125,7 @@ def test_line_from_file_complex():
     complex_env_from_file.reset(True, True)
 
     # Assert loaded agent number is correct
-    assert complex_env_from_file.get_num_agents() == 10
+    assert complex_env_from_file.get_num_agents() == old_num_agents
 
     # Assert max steps is correct
-    assert complex_env_from_file._max_episode_steps == 1350
+    assert complex_env_from_file._max_episode_steps == old_num_steps
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 53915102..8675f54f 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -97,6 +97,8 @@ def test_malfunction_process():
             actions[i] = np.argmax(obs[i]) + 1
 
         obs, all_rewards, done, _ = env.step(actions)
+        if done["__all__"]:
+            break
 
         if env.agents[0].malfunction_data['malfunction'] > 0:
             agent_malfunctioning = True
diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py
index 5b090681..851d849d 100644
--- a/tests/test_global_observation.py
+++ b/tests/test_global_observation.py
@@ -30,6 +30,10 @@ def test_get_global_observation():
                   obs_builder_object=GlobalObsForRailEnv())
     env.reset()
 
+    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
+    for _ in range(max([agent.earliest_departure for agent in env.agents])):
+        env.step({}) # DO_NOTHING for all agents
+
     obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)})
     for i in range(len(env.agents)):
         agent: EnvAgent = env.agents[i]
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index 08b46d00..2664c5b4 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -51,6 +51,7 @@ def test_multi_speed_init():
                   rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999,
                                                         seed=1), line_generator=complex_line_generator(),
                   number_of_agents=5)
+    
     # Initialize the agent with the parameters corresponding to the environment and observation_builder
     agent = RandomAgent(218, 4)
 
@@ -197,6 +198,12 @@ def test_multispeed_actions_no_malfunction_blocking():
                   line_generator=random_line_generator(), number_of_agents=2,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env.reset()
+    
+    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
+    for _ in range(max([agent.earliest_departure for agent in env.agents])):
+        env.step({}) # DO_NOTHING for all agents
+    
+
     set_penalties_for_replay(env)
     test_configs = [
         ReplayConfig(
@@ -381,7 +388,11 @@ def test_multispeed_actions_malfunction_no_blocking():
                   line_generator=random_line_generator(), number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env.reset()
-
+    
+    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
+    for _ in range(max([agent.earliest_departure for agent in env.agents])):
+        env.step({}) # DO_NOTHING for all agents
+    
     set_penalties_for_replay(env)
     test_config = ReplayConfig(
         replay=[
@@ -515,6 +526,10 @@ def test_multispeed_actions_no_malfunction_invalid_actions():
                   line_generator=random_line_generator(), number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env.reset()
+    
+    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
+    for _ in range(max([agent.earliest_departure for agent in env.agents])):
+        env.step({}) # DO_NOTHING for all agents
 
     set_penalties_for_replay(env)
     test_config = ReplayConfig(
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 4e7c30ca..062d56f0 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -152,4 +152,4 @@ def create_and_save_env(file_name: str, line_generator: LineGenerator, rail_gene
     env.reset(True, True)
     #env.save(file_name)
     RailEnvPersister.save(env, file_name)
-    
+    return env
-- 
GitLab