From ff2a8e8fec5e76d8ee2b599825f5caeef7834fbe Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Fri, 25 Oct 2019 16:32:27 -0400 Subject: [PATCH] updated test utils to respect new criteria in malfunction generator --- flatland/envs/rail_env.py | 2 +- tests/test_flatland_malfunction.py | 43 +++++++++++++++--------------- tests/test_utils.py | 5 ++++ 3 files changed, 27 insertions(+), 23 deletions(-) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index ab4a5e44..842dba38 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -407,7 +407,7 @@ class RailEnv(Environment): # Duration of current malfunction num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken, - self.max_number_of_steps_broken + 1) + 1 + self.max_number_of_steps_broken + 1) agent.malfunction_data['malfunction'] = num_broken_steps # Remember current moving state of the agent agent.malfunction_data['moving_before_malfunction'] = agent.moving diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 527a50f7..d3dc6fa7 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -157,16 +157,13 @@ def test_malfunction_process_statistically(): env.agents[0].target = (0, 0) # Next line only for test generation # agent_malfunction_list = [[] for i in range(20)] - agent_malfunction_list = [[0, 0, 0, 0, 6, 6, 0, 0, 0, 0], [0, 0, 0, 0, 6, 6, 0, 0, 0, 0], - [0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 4, 4, 0, 0, 0, 0], - [0, 0, 0, 0, 3, 3, 0, 0, 0, 6], [0, 0, 0, 0, 2, 2, 6, 0, 0, 5], - [0, 0, 0, 6, 1, 1, 5, 6, 0, 4], [6, 0, 0, 5, 0, 0, 4, 5, 0, 3], - [5, 6, 0, 4, 6, 6, 3, 4, 6, 2], [4, 5, 0, 3, 5, 5, 2, 3, 5, 1], - [3, 4, 6, 2, 4, 4, 1, 2, 4, 0], [2, 3, 5, 1, 3, 3, 0, 1, 3, 0], - [1, 2, 4, 0, 2, 2, 6, 0, 2, 0], [0, 1, 3, 0, 1, 1, 5, 0, 1, 0], - [6, 0, 2, 0, 0, 0, 4, 6, 0, 6], [5, 0, 1, 0, 6, 0, 3, 5, 0, 5], - [4, 0, 0, 0, 5, 0, 2, 4, 6, 4], [3, 0, 0, 0, 4, 0, 1, 3, 5, 3], - [2, 0, 6, 6, 3, 0, 0, 2, 4, 2], [1, 6, 5, 5, 2, 0, 6, 1, 3, 1]] + agent_malfunction_list = [[0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 4, 4, 0, 0, 0, 0], + [0, 0, 0, 0, 3, 3, 0, 0, 0, 0], [0, 0, 0, 0, 2, 2, 0, 0, 0, 5], [0, 0, 0, 0, 1, 1, 5, 0, 0, 4], + [0, 0, 0, 5, 0, 0, 4, 5, 0, 3], [5, 0, 0, 4, 5, 5, 3, 4, 0, 2], [4, 5, 0, 3, 4, 4, 2, 3, 5, 1], + [3, 4, 0, 2, 3, 3, 1, 2, 4, 0], [2, 3, 5, 1, 2, 2, 0, 1, 3, 0], [1, 2, 4, 0, 1, 1, 5, 0, 2, 0], + [0, 1, 3, 0, 0, 0, 4, 0, 1, 0], [5, 0, 2, 0, 0, 5, 3, 5, 0, 5], [4, 0, 1, 0, 0, 4, 2, 4, 0, 4], + [3, 0, 0, 0, 0, 3, 1, 3, 5, 3], [2, 0, 0, 0, 0, 2, 0, 2, 4, 2], [1, 0, 5, 5, 5, 1, 5, 1, 3, 1], + [0, 0, 4, 4, 4, 0, 4, 0, 2, 0], [5, 0, 3, 3, 3, 5, 3, 5, 1, 5]] for step in range(20): action_dict: Dict[int, RailEnvActions] = {} @@ -174,7 +171,7 @@ def test_malfunction_process_statistically(): # We randomly select an action action_dict[agent_idx] = RailEnvActions(np.random.randint(4)) # For generating tests only: - # agent_malfunction_list[step].append(env.agents[agent_idx].malfunction_data['malfunction']) + #agent_malfunction_list[step].append(env.agents[agent_idx].malfunction_data['malfunction']) assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[step][agent_idx] env.step(action_dict) # For generating test onlz @@ -205,6 +202,7 @@ def test_malfunction_before_entry(): # Test initial malfunction values for all agents # we want some agents to be malfuncitoning already and some to be working # we want different next_malfunction values for the agents + assert env.agents[0].malfunction_data['next_malfunction'] == 5 assert env.agents[1].malfunction_data['next_malfunction'] == 6 assert env.agents[2].malfunction_data['next_malfunction'] == 6 assert env.agents[3].malfunction_data['next_malfunction'] == 3 @@ -218,8 +216,8 @@ def test_malfunction_before_entry(): assert env.agents[1].malfunction_data['malfunction'] == 0 assert env.agents[2].malfunction_data['malfunction'] == 0 assert env.agents[3].malfunction_data['malfunction'] == 0 - assert env.agents[4].malfunction_data['malfunction'] == 11 - assert env.agents[5].malfunction_data['malfunction'] == 11 + assert env.agents[4].malfunction_data['malfunction'] == 10 + assert env.agents[5].malfunction_data['malfunction'] == 10 assert env.agents[6].malfunction_data['malfunction'] == 0 assert env.agents[7].malfunction_data['malfunction'] == 0 assert env.agents[8].malfunction_data['malfunction'] == 0 @@ -236,6 +234,9 @@ def test_malfunction_before_entry(): env.step(action_dict) # We want to check that all agents are malfunctioning and that their values changed + + # Test malfunction values for all agents after 20 steps + assert env.agents[0].malfunction_data['next_malfunction'] == 4 assert env.agents[1].malfunction_data['next_malfunction'] == 6 assert env.agents[2].malfunction_data['next_malfunction'] == 2 assert env.agents[3].malfunction_data['next_malfunction'] == 2 @@ -246,17 +247,15 @@ def test_malfunction_before_entry(): assert env.agents[8].malfunction_data['next_malfunction'] == 1 assert env.agents[9].malfunction_data['next_malfunction'] == 4 assert env.agents[0].malfunction_data['malfunction'] == 0 - assert env.agents[1].malfunction_data['malfunction'] == 9 - assert env.agents[2].malfunction_data['malfunction'] == 9 + assert env.agents[1].malfunction_data['malfunction'] == 8 + assert env.agents[2].malfunction_data['malfunction'] == 8 assert env.agents[3].malfunction_data['malfunction'] == 0 - assert env.agents[4].malfunction_data['malfunction'] == 2 - assert env.agents[5].malfunction_data['malfunction'] == 2 + assert env.agents[4].malfunction_data['malfunction'] == 1 + assert env.agents[5].malfunction_data['malfunction'] == 1 assert env.agents[6].malfunction_data['malfunction'] == 0 - assert env.agents[7].malfunction_data['malfunction'] == 7 - assert env.agents[8].malfunction_data['malfunction'] == 9 - assert env.agents[9].malfunction_data['malfunction'] == 3 - # Test malfunction values for all agents after 20 steps - + assert env.agents[7].malfunction_data['malfunction'] == 6 + assert env.agents[8].malfunction_data['malfunction'] == 8 + assert env.agents[9].malfunction_data['malfunction'] == 2 # Print for test generation #for a in range(env.get_num_agents()): # print("assert env.agents[{}].malfunction_data['next_malfunction'] == {}".format(a, env.agents[a].malfunction_data['next_malfunction'])) diff --git a/tests/test_utils.py b/tests/test_utils.py index d9378fd3..1a98c161 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -114,7 +114,12 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: step, a, False, info_dict['action_required'][a]) if replay.set_malfunction is not None: + # As we force malfunctions on the agents we have to set a positive rate that the env + # recognizes the agent as potentially malfuncitoning + # We also set next malfunction to infitiy to avoid interference with our tests agent.malfunction_data['malfunction'] = replay.set_malfunction + agent.malfunction_data['malfunction_rate'] = max(agent.malfunction_data['malfunction_rate'], 1) + agent.malfunction_data['next_malfunction'] = np.inf agent.malfunction_data['moving_before_malfunction'] = agent.moving _assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction') print(step) -- GitLab