diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 945771aaefe790f94e57794ade6fe3311b2928af..86987a56701f5dac220af6153ff3f7e7fbfdeb49 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -354,7 +354,7 @@ class RailEnv(Environment): # If counter has come to zero --> Agent has malfunction # set next malfunction time and duration of current malfunction if agent.malfunction_data['malfunction_rate'] >= 1 and 1 > agent.malfunction_data['malfunction'] and \ - agent.malfunction_data['next_malfunction'] <= 0: + agent.malfunction_data['next_malfunction'] < 1: # Increase number of malfunctions agent.malfunction_data['nr_malfunctions'] += 1 diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 8e56a6a16f051c3850dd42ad8c2481b24024a7d7..35e41b7e0c0e55e08d6096e43677ab53606553d9 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -136,9 +136,9 @@ def test_malfunction_process_statistically(): """Tests hat malfunctions are produced by stochastic_data!""" # Set fixed malfunction duration for this test stochastic_data = {'prop_malfunction': 1., - 'malfunction_rate': 2, - 'min_duration': 3, - 'max_duration': 3} + 'malfunction_rate': 5, + 'min_duration': 5, + 'max_duration': 5} rail, rail_map = make_simple_rail2() @@ -146,7 +146,7 @@ def test_malfunction_process_statistically(): height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(), - number_of_agents=1, + number_of_agents=10, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=SingleAgentNavigationObs() ) @@ -155,15 +155,27 @@ def test_malfunction_process_statistically(): env.agents[0].target = (0, 0) nb_malfunction = 0 + agent_malfunction_list = [[6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6], + [6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0], + [6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2], + [6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0], + [6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6, 5], + [6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5], + [6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1], + [6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4], + [6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6], + [6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2]] + for step in range(20): action_dict: Dict[int, RailEnvActions] = {} - for agent in env.agents: + for agent_idx in range(env.get_num_agents()): # We randomly select an action - action_dict[agent.handle] = RailEnvActions(np.random.randint(4)) - + action_dict[agent_idx] = RailEnvActions(np.random.randint(4)) + # For generating tests only: + # agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction']) + assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[agent_idx][step] env.step(action_dict) - # check that generation of malfunctions works as expected - assert env.agents[0].malfunction_data["nr_malfunctions"] == 4 + def test_malfunction_before_entry():