diff --git a/tests/test_action_plan.py b/tests/test_action_plan.py index 71a73fbc9a8f6bebb05489c3d59f1bbe41821931..2b062c4e5a892322bcf8c86e3be66e433254b346 100644 --- a/tests/test_action_plan.py +++ b/tests/test_action_plan.py @@ -30,7 +30,7 @@ def test_action_plan(rendering: bool = False): env.agents[1].initial_direction = Grid4TransitionsEnum.WEST env.agents[1].target = (0, 3) env.agents[1].speed_data['speed'] = 0.5 # two - env.reset(False, False, False) + env.reset(False, False) for handle, agent in enumerate(env.agents): print("[{}] {} -> {}".format(handle, agent.initial_position, agent.target)) diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index 195ee9aa7856c65b0ddaf22da2f4ef5a7fea5e4b..ad2187be4bad2df2b7a85438079aa7d1f2bb8a0e 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -15,6 +15,7 @@ from flatland.envs.rail_trainrun_data_structures import Waypoint from flatland.envs.line_generators import sparse_line_generator from flatland.utils.rendertools import RenderTool from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail +from flatland.envs.rail_env_action import RailEnvActions """Test predictions for `flatland` package.""" @@ -38,7 +39,11 @@ def test_dummy_predictor(rendering=False): env.agents[0].target = (3, 0) env.reset(False, False) - env.set_agent_active(env.agents[0]) + env.agents[0].earliest_departure = 1 + env._max_episode_steps = 100 + # Make Agent 0 active + env.step({}) + env.step({0: RailEnvActions.MOVE_FORWARD}) if rendering: renderer = RenderTool(env, gl="PILSVG") @@ -258,25 +263,33 @@ def test_shortest_path_predictor_conflicts(rendering=False): env.reset() # set the initial position - agent = env.agents[0] - agent.initial_position = (5, 6) # south dead-end - agent.position = (5, 6) # south dead-end - agent.direction = 0 # north - agent.initial_direction = 0 # north - agent.target = (3, 9) # east dead-end - agent.moving = True - agent.status = RailAgentStatus.ACTIVE - - agent = env.agents[1] - agent.initial_position = (3, 8) # east dead-end - agent.position = (3, 8) # east dead-end - agent.direction = 3 # west - agent.initial_direction = 3 # west - agent.target = (6, 6) # south dead-end - agent.moving = True - agent.status = RailAgentStatus.ACTIVE + env.agents[0].initial_position = (5, 6) # south dead-end + env.agents[0].position = (5, 6) # south dead-end + env.agents[0].direction = 0 # north + env.agents[0].initial_direction = 0 # north + env.agents[0].target = (3, 9) # east dead-end + env.agents[0].moving = True + env.agents[0].status = RailAgentStatus.ACTIVE + + env.agents[1].initial_position = (3, 8) # east dead-end + env.agents[1].position = (3, 8) # east dead-end + env.agents[1].direction = 3 # west + env.agents[1].initial_direction = 3 # west + env.agents[1].target = (6, 6) # south dead-end + env.agents[1].moving = True + env.agents[1].status = RailAgentStatus.ACTIVE + + observations, info = env.reset(False, False) + + env.agents[0].position = (5, 6) # south dead-end + env.agent_positions[env.agents[0].position] = 0 + env.agents[1].position = (3, 8) # east dead-end + env.agent_positions[env.agents[1].position] = 1 + env.agents[0].status = RailAgentStatus.ACTIVE + env.agents[1].status = RailAgentStatus.ACTIVE + + observations = env._get_observations() - observations, info = env.reset(False, False, True) if rendering: renderer = RenderTool(env, gl="PILSVG") diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index d8d0aa2d90e5355cafd0fca5dcdfef6fca00071e..5c12336a1a612cccd3df8beab42a8dcdfe9cdb59 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -19,7 +19,7 @@ def test_sparse_rail_generator(): ), line_generator=sparse_line_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) - env.reset(False, False, True) + env.reset(False, False) # for r in range(env.height): # for c in range(env.width): # if env.rail.grid[r][c] > 0: @@ -1300,8 +1300,8 @@ def test_rail_env_action_required_info(): # Reset the envs - env_always_action.reset(False, False, True, random_seed=5) - env_only_if_action_required.reset(False, False, True, random_seed=5) + env_always_action.reset(False, False, random_seed=5) + env_only_if_action_required.reset(False, False, random_seed=5) assert env_only_if_action_required.rail.grid.tolist() == env_always_action.rail.grid.tolist() for step in range(50): print("step {}".format(step)) @@ -1358,7 +1358,7 @@ def test_rail_env_malfunction_speed_info(): ), line_generator=sparse_line_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) - env.reset(False, False, True) + env.reset(False, False) env_renderer = RenderTool(env, gl="PILSVG", ) for step in range(100): @@ -1432,5 +1432,5 @@ def test_sparse_generator_changes_to_grid_mode(): ), line_generator=sparse_line_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) with warnings.catch_warnings(record=True) as w: - rail_env.reset(True, True, True, random_seed=15) + rail_env.reset(True, True, random_seed=15) assert "[WARNING]" in str(w[-1].message) diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 163cb41a869cafeca5772ffb361582865d0a20aa..e32e8d9f21120d7566cc027d7f9fa6cb36ded7be 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -82,7 +82,10 @@ def test_malfunction_process(): malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), obs_builder_object=SingleAgentNavigationObs() ) - obs, info = env.reset(False, False, True, random_seed=10) + obs, info = env.reset(False, False, random_seed=10) + for a_idx in range(len(env.agents)): + env.agents[a_idx].position = env.agents[a_idx].initial_position + env.agents[a_idx].status = RailAgentStatus.ACTIVE agent_halts = 0 total_down_time = 0 @@ -142,7 +145,7 @@ def test_malfunction_process_statistically(): obs_builder_object=SingleAgentNavigationObs() ) - env.reset(True, True, False, random_seed=10) + env.reset(True, True, random_seed=10) env._max_episode_steps = 1000 env.agents[0].target = (0, 0) @@ -181,7 +184,7 @@ def test_malfunction_before_entry(): malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), obs_builder_object=SingleAgentNavigationObs() ) - env.reset(False, False, False, random_seed=10) + env.reset(False, False, random_seed=10) env.agents[0].target = (0, 0) # Test initial malfunction values for all agents @@ -215,7 +218,7 @@ def test_malfunction_values_and_behavior(): obs_builder_object=SingleAgentNavigationObs() ) - env.reset(False, False, activate_agents=True, random_seed=10) + env.reset(False, False, random_seed=10) # Assertions assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 10, 9, 8, 7, 6, 5] @@ -247,7 +250,7 @@ def test_initial_malfunction(): obs_builder_object=SingleAgentNavigationObs() ) # reset to initialize agents_static - env.reset(False, False, True, random_seed=10) + env.reset(False, False, random_seed=10) env._max_episode_steps = 1000 print(env.agents[0].malfunction_data) env.agents[0].target = (0, 5) @@ -473,7 +476,7 @@ def tests_random_interference_from_outside(): line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1) env.reset() env.agents[0].speed_data['speed'] = 0.33 - env.reset(False, False, False, random_seed=10) + env.reset(False, False, random_seed=10) env_data = [] for step in range(200): @@ -499,7 +502,7 @@ def tests_random_interference_from_outside(): line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1) env.reset() env.agents[0].speed_data['speed'] = 0.33 - env.reset(False, False, False, random_seed=10) + env.reset(False, False, random_seed=10) dummy_list = [1, 2, 6, 7, 8, 9, 4, 5, 4] for step in range(200): @@ -540,7 +543,10 @@ def test_last_malfunction_step(): env._max_episode_steps = 1000 - env.reset(False, False, True) + env.reset(False, False) + for a_idx in range(len(env.agents)): + env.agents[a_idx].position = env.agents[a_idx].initial_position + env.agents[a_idx].status = RailAgentStatus.ACTIVE # Force malfunction to be off at beginning and next malfunction to happen in 2 steps env.agents[0].malfunction_data['next_malfunction'] = 2 env.agents[0].malfunction_data['malfunction'] = 0 diff --git a/tests/test_generators.py b/tests/test_generators.py index 3d6f0ccd3954fd5e2a2342f26d2283ea2cabb100..67f883746f2767bc98a090285428d7d377c905a1 100644 --- a/tests/test_generators.py +++ b/tests/test_generators.py @@ -10,6 +10,7 @@ from flatland.envs.rail_generators import rail_from_grid_transition_map, rail_fr from flatland.envs.line_generators import sparse_line_generator, line_from_file from flatland.utils.simple_rail import make_simple_rail from flatland.envs.persistence import RailEnvPersister +from flatland.envs.agent_utils import RailAgentStatus def test_empty_rail_generator(): @@ -30,7 +31,12 @@ def test_rail_from_grid_transition_map(): n_agents = 2 env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=n_agents) - env.reset(False, False, True) + env.reset(False, False) + + for a_idx in range(len(env.agents)): + env.agents[a_idx].position = env.agents[a_idx].initial_position + env.agents[a_idx].status = RailAgentStatus.ACTIVE + nr_rail_elements = np.count_nonzero(env.rail.grid) # Check if the number of non-empty rail cells is ok diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index 286560d452b00094114f13a4f128d11899b246d3..5918c24eb413e57a6b5d9ddb4cd3ff1f461e5c29 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -60,7 +60,7 @@ def test_multi_speed_init(): # Set all the different speeds # Reset environment and get initial observations for all agents - env.reset(False, False, True) + env.reset(False, False) # Here you can also further enhance the provided observation by means of normalization # See training navigation example in the baseline repository @@ -68,6 +68,7 @@ def test_multi_speed_init(): for i_agent in range(env.get_num_agents()): env.agents[i_agent].speed_data['speed'] = 1. / (i_agent + 2) old_pos.append(env.agents[i_agent].position) + # Run episode for step in range(100): diff --git a/tests/test_random_seeding.py b/tests/test_random_seeding.py index dbb3201c5ab328db236242e828ed4a9b867eeafd..7ce80ff0d726539e3df1d0b3bdc64a9c40f2fda2 100644 --- a/tests/test_random_seeding.py +++ b/tests/test_random_seeding.py @@ -16,7 +16,7 @@ def ndom_seeding(): for idx in range(100): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(seed=12), number_of_agents=10) - env.reset(True, True, False, random_seed=1) + env.reset(True, True, random_seed=1) env.agents[0].target = (0, 0) for step in range(10): @@ -56,8 +56,8 @@ def test_seeding_and_observations(): line_generator=sparse_line_generator(seed=12), number_of_agents=10, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) - env.reset(False, False, False, random_seed=12) - env2.reset(False, False, False, random_seed=12) + env.reset(False, False, random_seed=12) + env2.reset(False, False, random_seed=12) # Check that both environments produce the same initial start positions assert env.agents[0].initial_position == env2.agents[0].initial_position assert env.agents[1].initial_position == env2.agents[1].initial_position @@ -112,8 +112,8 @@ def test_seeding_and_malfunction(): line_generator=sparse_line_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) - env.reset(True, False, True, random_seed=tests) - env2.reset(True, False, True, random_seed=tests) + env.reset(True, False, random_seed=tests) + env2.reset(True, False, random_seed=tests) # Check that both environments produce the same initial start positions assert env.agents[0].initial_position == env2.agents[0].initial_position @@ -170,7 +170,7 @@ def test_reproducability_env(): grid_mode=True ), line_generator=sparse_line_generator(speed_ration_map), number_of_agents=1) - env.reset(True, True, True, random_seed=10) + env.reset(True, True, random_seed=10) excpeted_grid = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], @@ -213,5 +213,5 @@ def test_reproducability_env(): np.random.seed(10) for i in range(10): np.random.randn() - env2.reset(True, True, True, random_seed=10) + env2.reset(True, True, random_seed=10) assert env2.rail.grid.tolist() == excpeted_grid diff --git a/tests/test_utils.py b/tests/test_utils.py index 2213d0feceea3465f1958f00df16185fe4f01103..4b72679ed6a1ceac1f266760d1871c6fc405e6dc 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -87,7 +87,11 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: agent.direction = test_config.initial_direction agent.target = test_config.target agent.speed_data['speed'] = test_config.speed - env.reset(False, False, activate_agents) + env.reset(False, False) + if activate_agents: + for a_idx in range(len(env.agents)): + env.agents[a_idx].position = env.agents[a_idx].initial_position + env.agents[a_idx].status = RailAgentStatus.ACTIVE def _assert(a, actual, expected, msg): print("[{}] verifying {} on agent {}: actual={}, expected={}".format(step, msg, a, actual, expected))