diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 34a8032d6ae3e482d0cb321c9ee19dd1ac58dabb..d5fce5d6c8fa4cf37b4c72f6264726f6d502f77d 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -177,6 +177,46 @@ def test_malfunction_process_statistically(): assert env.agents[0].malfunction_data["nr_malfunctions"] == 4 +def test_malfunction_before_entry(): + """Tests hat malfunctions are produced by stochastic_data!""" + # Set fixed malfunction duration for this test + stochastic_data = {'prop_malfunction': 1., + 'malfunction_rate': 2, + 'min_duration': 10, + 'max_duration': 10} + + random.seed(0) + np.random.seed(0) + + rail, rail_map = make_simple_rail2() + + env = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + stochastic_data=stochastic_data, # Malfunction data generator + obs_builder_object=SingleAgentNavigationObs() + ) + # reset to initialize agents_static + env.reset(False, False, False) + env.agents[0].target = (0, 0) + nb_malfunction = 0 + for step in range(20): + action_dict: Dict[int, RailEnvActions] = {} + for agent in env.agents: + # We randomly select an action + if step < 10: + action_dict[agent.handle] = RailEnvActions(0) + assert env.agents[0].malfunction_data['malfunction'] == 0 + else: + action_dict[agent.handle] = RailEnvActions(2) + + print(env.agents[0].malfunction_data) + env.step(action_dict) + assert env.agents[0].malfunction_data['malfunction'] > 0 + + def test_initial_malfunction(): random.seed(0)