From 940b81321bcc51e1110b312e2969d815900ea50a Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Wed, 9 Oct 2019 14:52:26 -0400 Subject: [PATCH] fixed tests for changes to start of agents --- flatland/envs/rail_env.py | 5 +- tests/test_flaltland_rail_agent_status.py | 57 +++++++++++----------- tests/test_flatland_malfunction.py | 59 +++++++++++------------ tests/test_random_seeding.py | 19 +++++--- 4 files changed, 73 insertions(+), 67 deletions(-) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index cda0dee9..e7e5ef0c 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -189,7 +189,6 @@ class RailEnv(Environment): self.action_space = [5] self._seed() - self._seed() self.random_seed = random_seed if self.random_seed: @@ -217,6 +216,7 @@ class RailEnv(Environment): self.min_number_of_steps_broken = malfunction_min_duration self.max_number_of_steps_broken = malfunction_max_duration # Reset environment + self.reset() self.num_resets = 0 # yes, set it to zero again! @@ -259,6 +259,7 @@ class RailEnv(Environment): if replace_agents then regenerate the agents static. Relies on the rail_generator returning agent_static lists (pos, dir, target) """ + if random_seed: self._seed(random_seed) @@ -388,6 +389,7 @@ class RailEnv(Environment): return False def step(self, action_dict_: Dict[int, RailEnvActions]): + self._elapsed_steps += 1 # Reset the step rewards @@ -459,7 +461,6 @@ class RailEnv(Environment): agent.status = RailAgentStatus.ACTIVE agent.position = agent.initial_position self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] - print(self.rewards_dict[i_agent]) return else: # TODO: Here we need to check for the departure time in future releases with full schedules diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flaltland_rail_agent_status.py index 099ccce6..14a3e48a 100644 --- a/tests/test_flaltland_rail_agent_status.py +++ b/tests/test_flaltland_rail_agent_status.py @@ -23,7 +23,6 @@ def test_initial_status(): number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) - set_penalties_for_replay(env) test_config = ReplayConfig( replay=[ @@ -40,64 +39,64 @@ def test_initial_status(): direction=Grid4TransitionsEnum.EAST, status=RailAgentStatus.READY_TO_DEPART, action=RailEnvActions.MOVE_LEFT, - reward=env.start_penalty + env.step_penalty * 0.5, # auto-correction left to forward without penalty! + reward=env.step_penalty * 0.5, # auto-correction left to forward without penalty! 
             ),
             Replay(
                 position=(3, 9),
                 direction=Grid4TransitionsEnum.EAST,
                 status=RailAgentStatus.ACTIVE,
-                action=None,
-                reward=env.step_penalty * 0.5,  # running at speed 0.5
+                action=RailEnvActions.MOVE_LEFT,
+                reward=env.start_penalty + env.step_penalty * 0.5,  # start penalty + running at speed 0.5
             ),
             Replay(
                 position=(3, 9),
-                direction=Grid4TransitionsEnum.WEST,
+                direction=Grid4TransitionsEnum.EAST,
                 status=RailAgentStatus.ACTIVE,
-                action=RailEnvActions.MOVE_FORWARD,
+                action=None,
                 reward=env.step_penalty * 0.5,  # running at speed 0.5
             ),
             Replay(
                 position=(3, 8),
                 direction=Grid4TransitionsEnum.WEST,
                 status=RailAgentStatus.ACTIVE,
-                action=None,
+                action=RailEnvActions.MOVE_FORWARD,
                 reward=env.step_penalty * 0.5,  # running at speed 0.5
             ),
             Replay(
                 position=(3, 8),
                 direction=Grid4TransitionsEnum.WEST,
                 status=RailAgentStatus.ACTIVE,
-                action=RailEnvActions.MOVE_FORWARD,
+                action=None,
                 reward=env.step_penalty * 0.5,  # running at speed 0.5
             ),
             Replay(
                 position=(3, 7),
                 direction=Grid4TransitionsEnum.WEST,
-                action=None,
+                action=RailEnvActions.MOVE_FORWARD,
                 reward=env.step_penalty * 0.5,  # running at speed 0.5
                 status=RailAgentStatus.ACTIVE
             ),
             Replay(
                 position=(3, 7),
                 direction=Grid4TransitionsEnum.WEST,
-                action=RailEnvActions.MOVE_RIGHT,
+                action=None,
                 reward=env.step_penalty * 0.5,  # wrong action is corrected to forward without penalty!
                 status=RailAgentStatus.ACTIVE
             ),
             Replay(
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
-                action=None,
-                reward=env.global_reward,  # done
+                action=RailEnvActions.MOVE_RIGHT,
+                reward=env.step_penalty * 0.5,  # wrong action is corrected to forward without penalty!
                 status=RailAgentStatus.ACTIVE
             ),
             Replay(
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
                 action=None,
-                reward=env.global_reward,  # already done
-                status=RailAgentStatus.DONE
+                reward=env.global_reward,  # done
+                status=RailAgentStatus.ACTIVE
             ),
             Replay(
                 position=(3, 5),
@@ -151,7 +150,14 @@ def test_status_done_remove():
                 direction=Grid4TransitionsEnum.EAST,
                 status=RailAgentStatus.READY_TO_DEPART,
                 action=RailEnvActions.MOVE_LEFT,
-                reward=env.start_penalty + env.step_penalty * 0.5,  # auto-correction left to forward without penalty!
+                reward=env.step_penalty * 0.5,  # auto-correction left to forward without penalty!
+            ),
+            Replay(
+                position=(3, 9),
+                direction=Grid4TransitionsEnum.EAST,
+                status=RailAgentStatus.ACTIVE,
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.start_penalty + env.step_penalty * 0.5,  # start penalty + running at speed 0.5
             ),
             Replay(
                 position=(3, 9),
@@ -173,42 +179,35 @@ def test_status_done_remove():
                 status=RailAgentStatus.ACTIVE,
                 action=None,
                 reward=env.step_penalty * 0.5,  # running at speed 0.5
-            ),
-            Replay(
-                position=(3, 7),
-                direction=Grid4TransitionsEnum.WEST,
-                status=RailAgentStatus.ACTIVE,
-                action=RailEnvActions.MOVE_FORWARD,
-                reward=env.step_penalty * 0.5,  # running at speed 0.5
             ),
             Replay(
                 position=(3, 7),
                 direction=Grid4TransitionsEnum.WEST,
-                action=None,
+                action=RailEnvActions.MOVE_RIGHT,
                 reward=env.step_penalty * 0.5,  # running at speed 0.5
                 status=RailAgentStatus.ACTIVE
             ),
             Replay(
-                position=(3, 6),
+                position=(3, 7),
                 direction=Grid4TransitionsEnum.WEST,
-                action=RailEnvActions.MOVE_RIGHT,
+                action=None,
                 reward=env.step_penalty * 0.5,  # wrong action is corrected to forward without penalty!
                 status=RailAgentStatus.ACTIVE
             ),
             Replay(
                 position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
-                action=None,
-                reward=env.global_reward,  # done
+                action=RailEnvActions.MOVE_FORWARD,
+                reward=env.step_penalty * 0.5,  # running at speed 0.5
                 status=RailAgentStatus.ACTIVE
             ),
             Replay(
-                position=None,
+                position=(3, 6),
                 direction=Grid4TransitionsEnum.WEST,
                 action=None,
                 reward=env.global_reward,  # already done
-                status=RailAgentStatus.DONE_REMOVED
+                status=RailAgentStatus.ACTIVE
             ),
             Replay(
                 position=None,
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 35e41b7e..8008e6e2 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -155,16 +155,16 @@ def test_malfunction_process_statistically():
     env.agents[0].target = (0, 0)
     nb_malfunction = 0
 
-    agent_malfunction_list = [[6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6],
-                              [6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0],
-                              [6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2],
-                              [6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0],
-                              [6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6, 5],
-                              [6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5],
-                              [6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1],
-                              [6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4],
-                              [6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6],
-                              [6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2]]
+    agent_malfunction_list = [[6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0],
+                              [6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1],
+                              [6, 6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3],
+                              [6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0],
+                              [6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6],
+                              [6, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6],
+                              [6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2],
+                              [6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5],
+                              [6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0],
+                              [6, 6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3]]
 
     for step in range(20):
         action_dict: Dict[int, RailEnvActions] = {}
@@ -177,7 +177,6 @@ def test_malfunction_process_statistically():
 
         env.step(action_dict)
 
-
 def test_malfunction_before_entry():
     """Tests that malfunctions are produced by stochastic_data!"""
    # Set fixed malfunction duration for this test
@@ -200,6 +199,8 @@
     env.reset(False, False, False, random_seed=10)
     env.agents[0].target = (0, 0)
+    # Print for test generation
+
     assert env.agents[0].malfunction_data['malfunction'] == 11
     assert env.agents[1].malfunction_data['malfunction'] == 11
     assert env.agents[2].malfunction_data['malfunction'] == 11
     assert env.agents[3].malfunction_data['malfunction'] == 11
@@ -210,7 +211,6 @@
     assert env.agents[8].malfunction_data['malfunction'] == 11
     assert env.agents[9].malfunction_data['malfunction'] == 11
 
-
     for step in range(20):
         action_dict: Dict[int, RailEnvActions] = {}
         for agent in env.agents:
@@ -220,18 +220,17 @@
             action_dict[agent.handle] = RailEnvActions(0)
 
         env.step(action_dict)
-
-    assert env.agents[1].malfunction_data['malfunction'] == 1
-    assert env.agents[2].malfunction_data['malfunction'] == 1
-    assert env.agents[3].malfunction_data['malfunction'] == 1
-    assert env.agents[4].malfunction_data['malfunction'] == 1
-    assert env.agents[5].malfunction_data['malfunction'] == 1
-    assert env.agents[6].malfunction_data['malfunction'] == 1
-    assert env.agents[7].malfunction_data['malfunction'] == 1
-    assert env.agents[8].malfunction_data['malfunction'] == 1
-    assert env.agents[9].malfunction_data['malfunction'] == 1
-    # Print for test generation
-    # for a in range(env.get_num_agents()):
+    assert env.agents[1].malfunction_data['malfunction'] == 2
+    assert env.agents[2].malfunction_data['malfunction'] == 2
+    assert env.agents[3].malfunction_data['malfunction'] == 2
+    assert env.agents[4].malfunction_data['malfunction'] == 2
+    assert env.agents[5].malfunction_data['malfunction'] == 2
+    assert env.agents[6].malfunction_data['malfunction'] == 2
+    assert env.agents[7].malfunction_data['malfunction'] == 2
+    assert env.agents[8].malfunction_data['malfunction'] == 2
+    assert env.agents[9].malfunction_data['malfunction'] == 2
+
+    # for a in range(env.get_num_agents()):
     #     print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,
     #                                                                                env.agents[a].malfunction_data[
     #                                                                                    'malfunction']))
@@ -348,7 +347,7 @@ def test_initial_malfunction_stop_moving():
             position=(3, 2),
             direction=Grid4TransitionsEnum.EAST,
             action=RailEnvActions.DO_NOTHING,
-            malfunction=2,
+            malfunction=3,
             reward=env.step_penalty,  # full step penalty when stopped
             status=RailAgentStatus.ACTIVE
         ),
@@ -359,7 +358,7 @@
             position=(3, 2),
             direction=Grid4TransitionsEnum.EAST,
             action=RailEnvActions.STOP_MOVING,
-            malfunction=1,
+            malfunction=2,
             reward=env.step_penalty,  # full step penalty while stopped
             status=RailAgentStatus.ACTIVE
         ),
@@ -368,7 +367,7 @@
             position=(3, 2),
             direction=Grid4TransitionsEnum.EAST,
             action=RailEnvActions.DO_NOTHING,
-            malfunction=0,
+            malfunction=1,
             reward=env.step_penalty,  # full step penalty while stopped
             status=RailAgentStatus.ACTIVE
         ),
@@ -437,7 +436,7 @@ def test_initial_malfunction_do_nothing():
             position=(3, 2),
             direction=Grid4TransitionsEnum.EAST,
             action=RailEnvActions.DO_NOTHING,
-            malfunction=2,
+            malfunction=3,
             reward=env.step_penalty,  # full step penalty while malfunctioning
             status=RailAgentStatus.ACTIVE
         ),
@@ -448,7 +447,7 @@
             position=(3, 2),
             direction=Grid4TransitionsEnum.EAST,
             action=RailEnvActions.DO_NOTHING,
-            malfunction=1,
+            malfunction=2,
             reward=env.step_penalty,  # full step penalty while stopped
             status=RailAgentStatus.ACTIVE
         ),
@@ -457,7 +456,7 @@
             position=(3, 2),
             direction=Grid4TransitionsEnum.EAST,
             action=RailEnvActions.DO_NOTHING,
-            malfunction=0,
+            malfunction=1,
             reward=env.step_penalty,  # full step penalty while stopped
             status=RailAgentStatus.ACTIVE
         ),
diff --git a/tests/test_random_seeding.py b/tests/test_random_seeding.py
index 3a03de00..ecf10c49 100644
--- a/tests/test_random_seeding.py
+++ b/tests/test_random_seeding.py
@@ -21,7 +21,6 @@ def test_random_seeding():
                       number_of_agents=10
                       )
     env.reset(True, True, False, random_seed=1)
-    # Test generation print
 
     env.agents[0].target = (0, 0)
     for step in range(10):
@@ -29,12 +28,20 @@
         actions = {}
         actions[0] = 2
         env.step(actions)
     agent_positions = []
-    for a in range(env.get_num_agents()):
-        agent_positions += env.agents[a].initial_position
-    # print(agent_positions)
-    assert agent_positions == [3, 2, 3, 5, 3, 6, 5, 6, 3, 4, 3, 1, 3, 9, 4, 6, 0, 3, 3, 7]
+
+    assert env.agents[0].initial_position == (3, 2)
+    assert env.agents[1].initial_position == (3, 5)
+    assert env.agents[2].initial_position == (3, 6)
+    assert env.agents[3].initial_position == (5, 6)
+    assert env.agents[4].initial_position == (3, 4)
+    assert env.agents[5].initial_position == (3, 1)
+    assert env.agents[6].initial_position == (3, 9)
+    assert env.agents[7].initial_position == (4, 6)
+    assert env.agents[8].initial_position == (0, 3)
+    assert env.agents[9].initial_position == (3, 7)
     # Test generation print
-    assert env.agents[0].position == (3, 6)
+    # for a in range(env.get_num_agents()):
+    #     print("assert env.agents[{}].initial_position == {}".format(a, env.agents[a].initial_position))
     # print("env.agents[0].initial_position == {}".format(env.agents[0].initial_position))
     # print("assert env.agents[0].position == {}".format(env.agents[0].position))
-- 
GitLab