diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 6efb6a8c983ab813f1978c483e41c8e0ca214e53..d9fa74ed364aed87ba936f74c39b0e4ab31771c0 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -135,7 +135,6 @@ def test_malfunction_process(): def test_malfunction_process_statistically(): """Tests hat malfunctions are produced by stochastic_data!""" # Set fixed malfunction duration for this test - # stochastic_data = {'prop_malfunction': 1., 'malfunction_rate': 5, 'min_duration': 5, @@ -156,15 +155,17 @@ def test_malfunction_process_statistically(): env.reset(True, True, False, random_seed=10) env.agents[0].target = (0, 0) - # Next line only for test generation - # agent_malfunction_list = [[] for i in range(20)] - agent_malfunction_list = [[0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 4, 4, 0, 0, 0, 0], - [0, 0, 0, 0, 3, 3, 0, 0, 0, 0], [0, 0, 0, 0, 2, 2, 0, 0, 0, 5], [0, 0, 0, 0, 1, 1, 5, 0, 0, 4], - [0, 0, 0, 5, 0, 0, 4, 5, 0, 3], [5, 0, 0, 4, 5, 5, 3, 4, 0, 2], [4, 5, 0, 3, 4, 4, 2, 3, 5, 1], - [3, 4, 0, 2, 3, 3, 1, 2, 4, 0], [2, 3, 5, 1, 2, 2, 0, 1, 3, 0], [1, 2, 4, 0, 1, 1, 5, 0, 2, 0], - [0, 1, 3, 0, 0, 0, 4, 0, 1, 0], [5, 0, 2, 0, 0, 5, 3, 5, 0, 5], [4, 0, 1, 0, 0, 4, 2, 4, 0, 4], - [3, 0, 0, 0, 0, 3, 1, 3, 5, 3], [2, 0, 0, 0, 0, 2, 0, 2, 4, 2], [1, 0, 5, 5, 5, 1, 5, 1, 3, 1], - [0, 0, 4, 4, 4, 0, 4, 0, 2, 0], [5, 0, 3, 3, 3, 5, 3, 5, 1, 5]] + + agent_malfunction_list = [[0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6, 5], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4], + [0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 6, 5, 4], + [6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0], + [6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3], + [0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5], + [0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3, 2, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1], + [6, 6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0]] for step in range(20): action_dict: Dict[int, RailEnvActions] = {} @@ -172,18 +173,16 @@ def test_malfunction_process_statistically(): # We randomly select an action action_dict[agent_idx] = RailEnvActions(np.random.randint(4)) # For generating tests only: - #agent_malfunction_list[step].append(env.agents[agent_idx].malfunction_data['malfunction']) - assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[step][agent_idx] + # agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction']) + assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[agent_idx][step] env.step(action_dict) - # For generating test onlz - #print(agent_malfunction_list) def test_malfunction_before_entry(): - """Tests that malfunctions are working properlz for agents before entering the environment!""" + """Tests that malfunctions are produced by stochastic_data!""" # Set fixed malfunction duration for this test stochastic_data = {'prop_malfunction': 1., - 'malfunction_rate': 5, + 'malfunction_rate': 1, 'min_duration': 10, 'max_duration': 10} @@ -199,30 +198,19 @@ def test_malfunction_before_entry(): ) # reset to initialize agents_static env.reset(False, False, False, random_seed=10) + env.agents[0].target = (0, 0) - # Test initial malfunction values for all agents - # we want some agents to be malfuncitoning already and some to be working - # we want different next_malfunction values for the agents - assert env.agents[0].malfunction_data['next_malfunction'] == 5 - assert env.agents[1].malfunction_data['next_malfunction'] == 6 - assert env.agents[2].malfunction_data['next_malfunction'] == 6 - assert env.agents[3].malfunction_data['next_malfunction'] == 3 - assert env.agents[4].malfunction_data['next_malfunction'] == 1 - assert env.agents[5].malfunction_data['next_malfunction'] == 1 - assert env.agents[6].malfunction_data['next_malfunction'] == 3 - assert env.agents[7].malfunction_data['next_malfunction'] == 4 - assert env.agents[8].malfunction_data['next_malfunction'] == 6 - assert env.agents[9].malfunction_data['next_malfunction'] == 0 - assert env.agents[0].malfunction_data['malfunction'] == 0 - assert env.agents[1].malfunction_data['malfunction'] == 0 - assert env.agents[2].malfunction_data['malfunction'] == 0 - assert env.agents[3].malfunction_data['malfunction'] == 0 - assert env.agents[4].malfunction_data['malfunction'] == 10 - assert env.agents[5].malfunction_data['malfunction'] == 10 - assert env.agents[6].malfunction_data['malfunction'] == 0 - assert env.agents[7].malfunction_data['malfunction'] == 0 - assert env.agents[8].malfunction_data['malfunction'] == 0 - assert env.agents[9].malfunction_data['malfunction'] == 0 + # Print for test generation + assert env.agents[0].malfunction_data['malfunction'] == 11 + assert env.agents[1].malfunction_data['malfunction'] == 11 + assert env.agents[2].malfunction_data['malfunction'] == 11 + assert env.agents[3].malfunction_data['malfunction'] == 11 + assert env.agents[4].malfunction_data['malfunction'] == 11 + assert env.agents[5].malfunction_data['malfunction'] == 11 + assert env.agents[6].malfunction_data['malfunction'] == 11 + assert env.agents[7].malfunction_data['malfunction'] == 11 + assert env.agents[8].malfunction_data['malfunction'] == 11 + assert env.agents[9].malfunction_data['malfunction'] == 11 for step in range(20): action_dict: Dict[int, RailEnvActions] = {} @@ -233,36 +221,21 @@ def test_malfunction_before_entry(): action_dict[agent.handle] = RailEnvActions(0) env.step(action_dict) - - # We want to check that all agents are malfunctioning and that their values changed - - # Test malfunction values for all agents after 20 steps - assert env.agents[0].malfunction_data['next_malfunction'] == 4 - assert env.agents[1].malfunction_data['next_malfunction'] == 6 - assert env.agents[2].malfunction_data['next_malfunction'] == 2 - assert env.agents[3].malfunction_data['next_malfunction'] == 2 - assert env.agents[4].malfunction_data['next_malfunction'] == 1 - assert env.agents[5].malfunction_data['next_malfunction'] == 1 - assert env.agents[6].malfunction_data['next_malfunction'] == 2 - assert env.agents[7].malfunction_data['next_malfunction'] == 1 - assert env.agents[8].malfunction_data['next_malfunction'] == 1 - assert env.agents[9].malfunction_data['next_malfunction'] == 4 - assert env.agents[0].malfunction_data['malfunction'] == 0 - assert env.agents[1].malfunction_data['malfunction'] == 8 - assert env.agents[2].malfunction_data['malfunction'] == 8 - assert env.agents[3].malfunction_data['malfunction'] == 0 - assert env.agents[4].malfunction_data['malfunction'] == 1 - assert env.agents[5].malfunction_data['malfunction'] == 1 - assert env.agents[6].malfunction_data['malfunction'] == 0 - assert env.agents[7].malfunction_data['malfunction'] == 6 - assert env.agents[8].malfunction_data['malfunction'] == 8 + assert env.agents[1].malfunction_data['malfunction'] == 2 + assert env.agents[2].malfunction_data['malfunction'] == 2 + assert env.agents[3].malfunction_data['malfunction'] == 2 + assert env.agents[4].malfunction_data['malfunction'] == 2 + assert env.agents[5].malfunction_data['malfunction'] == 2 + assert env.agents[6].malfunction_data['malfunction'] == 2 + assert env.agents[7].malfunction_data['malfunction'] == 2 + assert env.agents[8].malfunction_data['malfunction'] == 2 assert env.agents[9].malfunction_data['malfunction'] == 2 - # Print for test generation - #for a in range(env.get_num_agents()): - # print("assert env.agents[{}].malfunction_data['next_malfunction'] == {}".format(a, env.agents[a].malfunction_data['next_malfunction'])) - #for a in range(env.get_num_agents()): - # print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a, env.agents[a].malfunction_data[ - # 'malfunction'])) + + # for a in range(env.get_num_agents()): + # print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a, + # env.agents[a].malfunction_data[ + # 'malfunction'])) + def test_initial_malfunction(): stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents @@ -613,60 +586,3 @@ def tests_random_interference_from_outside(): _, reward, _, _ = env.step(action_dict) assert reward[0] == env_data[step][0] assert env.agents[0].position == env_data[step][1] - - -def test_last_malfunction_step(): - """ - Test to check that agent moves when it is not malfunctioning - - """ - - # Set fixed malfunction duration for this test - stochastic_data = {'prop_malfunction': 1., - 'malfunction_rate': 5, - 'min_duration': 4, - 'max_duration': 4} - - rail, rail_map = make_simple_rail2() - - env = RailEnv(width=25, - height=30, - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=2), # seed 12 - number_of_agents=1, - random_seed=1, - stochastic_data=stochastic_data, # Malfunction data generator - ) - env.reset() - # reset to initialize agents_static - env.agents[0].speed_data['speed'] = 0.33 - env.agents_static[0].target = (0, 0) - - env.reset(False, False, True) - # Force malfunction to be off at beginning and next malfunction to happen in 2 steps - env.agents[0].malfunction_data['next_malfunction'] = 2 - env.agents[0].malfunction_data['malfunction'] = 0 - env_data = [] - - for step in range(20): - action_dict: Dict[int, RailEnvActions] = {} - for agent in env.agents: - # Go forward all the time - action_dict[agent.handle] = RailEnvActions(2) - - # Check if the agent is still allowed to move in this step - if env.agents[0].malfunction_data['malfunction'] > 1 or env.agents[0].malfunction_data['next_malfunction'] < 1: - agent_can_move = False - else: - agent_can_move = True - - # Store the position before and after the step - pre_position = env.agents[0].speed_data['position_fraction'] - _, reward, _, _ = env.step(action_dict) - post_position = env.agents[0].speed_data['position_fraction'] - - # Assert that the agent moved while it was still allowed - if agent_can_move: - assert pre_position != post_position - else: - assert post_position == pre_position