diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index ce6e1283fbb1155f81aa722a20829c732042e282..a46467ed4cbefdf38ea4ce8b0be4defadc2cab1e 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -156,14 +156,20 @@ def test_malfunction_process_statistically(): env.agents[0].target = (0, 0) # Next line only for test generation - #agent_malfunction_list = [[] for i in range(20)] - agent_malfunction_list = [[0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 4, 4, 0, 0, 0, 0], - [0, 0, 0, 0, 3, 3, 0, 0, 0, 0], [0, 0, 0, 0, 2, 2, 0, 0, 0, 5], [0, 0, 0, 0, 1, 1, 5, 0, 0, 4], - [0, 0, 0, 5, 0, 0, 4, 5, 0, 3], [5, 0, 0, 4, 0, 0, 3, 4, 0, 2], [4, 5, 0, 3, 5, 5, 2, 3, 5, 1], - [3, 4, 0, 2, 4, 4, 1, 2, 4, 0], [2, 3, 5, 1, 3, 3, 0, 1, 3, 0], [1, 2, 4, 0, 2, 2, 0, 0, 2, 0], - [0, 1, 3, 0, 1, 1, 5, 0, 1, 0], [0, 0, 2, 0, 0, 0, 4, 0, 0, 0], [5, 0, 1, 0, 0, 0, 3, 5, 0, 5], - [4, 0, 0, 0, 5, 0, 2, 4, 0, 4], [3, 0, 0, 0, 4, 0, 1, 3, 5, 3], [2, 0, 0, 0, 3, 0, 0, 2, 4, 2], - [1, 0, 5, 5, 2, 0, 0, 1, 3, 1], [0, 5, 4, 4, 1, 0, 5, 0, 2, 0]] + # agent_malfunction_list = [[] for i in range(20)] + agent_malfunction_list = [[0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 5, 5, 0, 0, 0, 0], + [0, 0, 0, 0, 4, 4, 0, 0, 0, 0], + [0, 0, 0, 0, 3, 3, 0, 0, 0, 0], [0, 0, 0, 0, 2, 2, 0, 0, 0, 5], + [0, 0, 0, 0, 1, 1, 5, 0, 0, 4], + [0, 0, 0, 5, 0, 0, 4, 5, 0, 3], [5, 0, 0, 4, 0, 0, 3, 4, 0, 2], + [4, 5, 0, 3, 5, 5, 2, 3, 5, 1], + [3, 4, 0, 2, 4, 4, 1, 2, 4, 0], [2, 3, 5, 1, 3, 3, 0, 1, 3, 0], + [1, 2, 4, 0, 2, 2, 0, 0, 2, 0], + [0, 1, 3, 0, 1, 1, 5, 0, 1, 0], [0, 0, 2, 0, 0, 0, 4, 0, 0, 0], + [5, 0, 1, 0, 0, 0, 3, 5, 0, 5], + [4, 0, 0, 0, 5, 0, 2, 4, 0, 4], [3, 0, 0, 0, 4, 0, 1, 3, 5, 3], + [2, 0, 0, 0, 3, 0, 0, 2, 4, 2], + [1, 0, 5, 5, 2, 0, 0, 1, 3, 1], [0, 5, 4, 4, 1, 0, 5, 0, 2, 0]] for step in range(20): action_dict: Dict[int, RailEnvActions] = {} @@ -175,7 +181,7 @@ def test_malfunction_process_statistically(): assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[step][agent_idx] env.step(action_dict) # For generating test onlz - #print(agent_malfunction_list) + # print(agent_malfunction_list) def test_malfunction_before_entry(): @@ -213,6 +219,7 @@ def test_malfunction_before_entry(): assert env.agents[8].malfunction_data['malfunction'] == 0 assert env.agents[9].malfunction_data['malfunction'] == 0 + def test_next_malfunction_counter(): """ Test that the next malfunction occurs when desired @@ -237,7 +244,7 @@ def test_next_malfunction_counter(): env.agents[0].malfunction_data['next_malfunction'] = 5 env.agents[0].malfunction_data['malfunction_rate'] = 5 env.agents[0].malfunction_data['malfunction'] = 0 - env.agents[0].target =(0, 0), #Move the target out of range + env.agents[0].target = (0, 0), # Move the target out of range print(env.agents[0].position, env.agents[0].malfunction_data['next_malfunction']) for time_step in range(1, 6): @@ -248,6 +255,47 @@ def test_next_malfunction_counter(): assert env.agents[0].malfunction_data['next_malfunction'] == 5 - time_step +def test_malfunction_values_and_behavior(): + """ + Test that the next malfunction occurs when desired. + Returns + ------- + + """ + # Set fixed malfunction duration for this test + + rail, rail_map = make_simple_rail2() + action_dict: Dict[int, RailEnvActions] = {} + stochastic_data = {'prop_malfunction': 1., + 'malfunction_rate': 5, + 'min_duration': 10, + 'max_duration': 10} + env = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(seed=2), # seed 12 + stochastic_data=stochastic_data, + number_of_agents=1, + random_seed=1, + ) + # reset to initialize agents_static + env.reset(False, False, activate_agents=True, random_seed=10) + env.agents[0].malfunction_data['next_malfunction'] = 5 + env.agents[0].malfunction_data['malfunction_rate'] = 50 + env.agents[0].malfunction_data['malfunction'] = 0 + env.agents[0].target = (0, 0), # Move the target out of range + print(env.agents[0].position, env.agents[0].malfunction_data['next_malfunction']) + + for time_step in range(1, 16): + # Move in the env + env.step(action_dict) + print(time_step) + # Check that next_step decreases as expected + if env.agents[0].malfunction_data['malfunction'] < 1: + assert env.agents[0].malfunction_data['next_malfunction'] == np.clip(5 - time_step, 0, 100) + else: + assert env.agents[0].malfunction_data['malfunction'] == np.clip(10 - (time_step - 6), 0, 100) + def test_initial_malfunction(): stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents @@ -295,15 +343,15 @@ def test_initial_malfunction(): direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, malfunction=1, - reward= env.step_penalty * 1.0 + reward=env.step_penalty * 1.0 - ),# malfunctioning ends: starting and running at speed 1.0 + ), # malfunctioning ends: starting and running at speed 1.0 Replay( position=(3, 2), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, malfunction=0, - reward=env.start_penalty +env.step_penalty * 1.0 # running at speed 1.0 + reward=env.start_penalty + env.step_penalty * 1.0 # running at speed 1.0 ), Replay( position=(3, 3),