diff --git a/flatland/envs/line_generators.py b/flatland/envs/line_generators.py index 7f50abb2ce32dde4f75fa6c600721e20c86e8a04..6c81f5994617e52e34909921c6fedb2f53a41962 100644 --- a/flatland/envs/line_generators.py +++ b/flatland/envs/line_generators.py @@ -74,7 +74,10 @@ class SparseLineGen(BaseLineGen): if rail.check_path_exists(start[0], orientation, target[0]): feasible_orientations.append(orientation) - return np_random.choice(feasible_orientations) + if len(feasible_orientations) > 0: + return np_random.choice(feasible_orientations) + else: + return 0 def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict, num_resets: int, np_random: RandomState) -> Line: diff --git a/tests/test_action_plan.py b/tests/test_action_plan.py index 9a2fe113117ae513bb4692790e2ad1091f1f00d7..88f07d282664bae8ad8277e4a6146e325f61e534 100644 --- a/tests/test_action_plan.py +++ b/tests/test_action_plan.py @@ -37,7 +37,7 @@ def test_action_plan(rendering: bool = False): print("[{}] {} -> {}".format(handle, agent.initial_position, agent.target)) # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART - for _ in range(max([agent.earliest_departure for agent in env.agents])): + for _ in range(max([agent.earliest_departure for agent in env.agents]) + 1): env.step({}) # DO_NOTHING for all agents chosen_path_dict = {0: [TrainrunWaypoint(scheduled_at=0, waypoint=Waypoint(position=(3, 0), direction=3)), diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index 4da868f48ebe0a96c2cc23bc7362e5cd341047b0..bb4ad9c6d2fa6394a3e9da36a401f3d90dad9003 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -500,8 +500,8 @@ def test_sparse_rail_generator(): for a in range(env.get_num_agents()): s0 = Vec2d.get_manhattan_distance(env.agents[a].initial_position, (0, 0)) s1 = Vec2d.get_chebyshev_distance(env.agents[a].initial_position, (0, 0)) - assert s0 == 46, "actual={}".format(s0) - assert s1 == 26, "actual={}".format(s1) + assert s0 == 36, "actual={}".format(s0) + assert s1 == 27, "actual={}".format(s1) def test_sparse_rail_generator_deterministic(): diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 16eba37049aeac83562c630de67c7e5f7c61441a..d1598d0d6c31974f74531c8ec7047a1382e1136e 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -151,8 +151,8 @@ def test_malfunction_process_statistically(): env.agents[0].target = (0, 0) # Next line only for test generation agent_malfunction_list = [[] for i in range(2)] - agent_malfunction_list = [[0, 0, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 2, 1, 0], - [0, 4, 3, 2, 1, 0, 0, 0, 0, 0, 4, 3, 2, 1, 0, 4, 3, 2, 1, 0]] + agent_malfunction_list = [[0, 0, 0, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 2, 1], + [0, 0, 4, 3, 2, 1, 0, 0, 0, 0, 0, 4, 3, 2, 1, 0, 4, 3, 2, 1]] for step in range(20): action_dict: Dict[int, RailEnvActions] = {} diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index f809068180e23bfa96b4bbd6d8c647b290a4b039..d8f72d36f01ee2cab984001b96818ee0bcfdf203 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -396,7 +396,7 @@ def test_multispeed_actions_malfunction_no_blocking(): env.reset() # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART - for _ in range(max([agent.earliest_departure for agent in env.agents])): + for _ in range(max([agent.earliest_departure for agent in env.agents]) + 1): env.step({}) # DO_NOTHING for all agents env._max_episode_steps = 10000 diff --git a/tests/test_utils.py b/tests/test_utils.py index fdae8f5c32f4ab305e54f31293e98fbba5c0a41a..b6b69d15f5e92e8711f573dbd1be6e1c32dfb108 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -115,6 +115,8 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: for a, test_config in enumerate(test_configs): agent: EnvAgent = env.agents[a] replay = test_config.replay[step] + # if not agent.position == replay.position: + # import pdb; pdb.set_trace() _assert(a, agent.position, replay.position, 'position') _assert(a, agent.direction, replay.direction, 'direction') if replay.state is not None: @@ -140,6 +142,7 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: _assert(a, agent.malfunction_handler.malfunction_down_counter, replay.malfunction, 'malfunction') print(step) _, rewards_dict, _, info_dict = env.step(action_dict) + # import pdb; pdb.set_trace() if rendering: renderer.render_env(show=True, show_observations=True)