diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 53784a6725a24b65b2d026eaf353f7a72e375619..3139437106861d0f496e6c3e7e840300bd33307c 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -398,8 +398,8 @@ class RailEnv(Environment): Malfunction generator that breaks agents at a given rate. It does randomly chose agent to break during the run """ - if np.random.random() < self._malfunction_prob(rate): - breaking_agent = random.choice(self.agents) + if self.np_random.randn() < self._malfunction_prob(rate): + breaking_agent = self.np_random.choice(self.agents) if breaking_agent.malfunction_data['malfunction'] < 1: num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken, self.max_number_of_steps_broken + 1) diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 7eac117f94ba921b4b0ba406bdc3bcc258ba76c1..2f5ea5f2c4af3ab0e6c912202bdfd234bf0de7a0 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -120,8 +120,7 @@ def test_malfunction_process(): def test_malfunction_process_statistically(): """Tests hat malfunctions are produced by stochastic_data!""" # Set fixed malfunction duration for this test - stochastic_data = {'prop_malfunction': 1., - 'malfunction_rate': 5, + stochastic_data = {'malfunction_rate': 5, 'min_duration': 5, 'max_duration': 5} @@ -142,19 +141,19 @@ def test_malfunction_process_statistically(): env.agents[0].target = (0, 0) # Next line only for test generation # agent_malfunction_list = [[] for i in range(20)] - agent_malfunction_list = [[0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 5, 5, 0, 0, 0, 0], - [0, 0, 0, 0, 4, 4, 0, 0, 0, 0], - [0, 0, 0, 0, 3, 3, 0, 0, 0, 0], [0, 0, 0, 0, 2, 2, 0, 0, 0, 5], - [0, 0, 0, 0, 1, 1, 5, 0, 0, 4], - [0, 0, 0, 5, 0, 0, 4, 5, 0, 3], [5, 0, 0, 4, 0, 0, 3, 4, 0, 2], - [4, 5, 0, 3, 5, 5, 2, 3, 5, 1], - [3, 4, 0, 2, 4, 4, 1, 2, 4, 0], [2, 3, 5, 1, 3, 3, 0, 1, 3, 0], - [1, 2, 4, 0, 2, 2, 0, 0, 2, 0], - [0, 1, 3, 0, 1, 1, 5, 0, 1, 0], [0, 0, 2, 0, 0, 0, 4, 0, 0, 0], - [5, 0, 1, 0, 0, 0, 3, 5, 0, 5], - [4, 0, 0, 0, 5, 0, 2, 4, 0, 4], [3, 0, 0, 0, 4, 0, 1, 3, 5, 3], - [2, 0, 0, 0, 3, 0, 0, 2, 4, 2], - [1, 0, 5, 5, 2, 0, 0, 1, 3, 1], [0, 5, 4, 4, 1, 0, 5, 0, 2, 0]] + agent_malfunction_list = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 4], [0, 0, 0, 0, 0, 0, 0, 0, 0, 3], + [4, 0, 0, 0, 0, 0, 0, 0, 0, 2], + [3, 0, 0, 0, 0, 0, 0, 0, 0, 1], [2, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 4, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 3, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 4, 0, 0], [0, 0, 0, 0, 0, 0, 1, 3, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 2, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 4], [0, 0, 0, 0, 0, 0, 0, 0, 0, 3], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 2], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], [4, 0, 0, 0, 0, 0, 0, 0, 0, 0]] for step in range(20): action_dict: Dict[int, RailEnvActions] = {} @@ -170,10 +169,9 @@ def test_malfunction_process_statistically(): def test_malfunction_before_entry(): - """Tests that malfunctions are working properlz for agents before entering the environment!""" + """Tests that malfunctions are working properly for agents before entering the environment!""" # Set fixed malfunction duration for this test - stochastic_data = {'prop_malfunction': 1., - 'malfunction_rate': 5, + stochastic_data = {'malfunction_rate': 1, 'min_duration': 10, 'max_duration': 10} @@ -193,51 +191,15 @@ def test_malfunction_before_entry(): # Test initial malfunction values for all agents # we want some agents to be malfuncitoning already and some to be working # we want different next_malfunction values for the agents - assert env.agents[0].malfunction_data['malfunction'] == 0 assert env.agents[1].malfunction_data['malfunction'] == 0 assert env.agents[2].malfunction_data['malfunction'] == 0 assert env.agents[3].malfunction_data['malfunction'] == 0 - assert env.agents[4].malfunction_data['malfunction'] == 10 - assert env.agents[5].malfunction_data['malfunction'] == 10 + assert env.agents[4].malfunction_data['malfunction'] == 0 + assert env.agents[5].malfunction_data['malfunction'] == 0 assert env.agents[6].malfunction_data['malfunction'] == 0 assert env.agents[7].malfunction_data['malfunction'] == 0 assert env.agents[8].malfunction_data['malfunction'] == 0 - assert env.agents[9].malfunction_data['malfunction'] == 0 - - -def test_next_malfunction_counter(): - """ - Test that the next malfunction occurs when desired - Returns - ------- - - """ - # Set fixed malfunction duration for this test - - rail, rail_map = make_simple_rail2() - action_dict: Dict[int, RailEnvActions] = {} - - env = RailEnv(width=25, - height=30, - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=2), # seed 12 - number_of_agents=1, - random_seed=1, - ) - # reset to initialize agents_static - env.reset(False, False, activate_agents=True, random_seed=10) - env.agents[0].malfunction_data['next_malfunction'] = 5 - env.agents[0].malfunction_data['malfunction_rate'] = 5 - env.agents[0].malfunction_data['malfunction'] = 0 - env.agents[0].target = (0, 0), # Move the target out of range - print(env.agents[0].position, env.agents[0].malfunction_data['next_malfunction']) - - for time_step in range(1, 6): - # Move in the env - env.step(action_dict) - - # Check that next_step decreases as expected - assert env.agents[0].malfunction_data['next_malfunction'] == 5 - time_step + assert env.agents[9].malfunction_data['malfunction'] == 9 def test_malfunction_values_and_behavior(): @@ -251,8 +213,7 @@ def test_malfunction_values_and_behavior(): rail, rail_map = make_simple_rail2() action_dict: Dict[int, RailEnvActions] = {} - stochastic_data = {'prop_malfunction': 1., - 'malfunction_rate': 5, + stochastic_data = {'malfunction_rate': 5, 'min_duration': 10, 'max_duration': 10} env = RailEnv(width=25, @@ -263,23 +224,18 @@ def test_malfunction_values_and_behavior(): number_of_agents=1, random_seed=1, ) + # reset to initialize agents_static env.reset(False, False, activate_agents=True, random_seed=10) - env.agents[0].malfunction_data['next_malfunction'] = 5 - env.agents[0].malfunction_data['malfunction_rate'] = 50 - env.agents[0].malfunction_data['malfunction'] = 0 - env.agents[0].target = (0, 0), # Move the target out of range - print(env.agents[0].position, env.agents[0].malfunction_data['next_malfunction']) - for time_step in range(1, 16): + # Assertions + assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 0, 9, 8, 7, 6] + print("[") + for time_step in range(15): # Move in the env env.step(action_dict) - print(time_step) # Check that next_step decreases as expected - if env.agents[0].malfunction_data['malfunction'] < 1: - assert env.agents[0].malfunction_data['next_malfunction'] == np.clip(5 - time_step, 0, 100) - else: - assert env.agents[0].malfunction_data['malfunction'] == np.clip(10 - (time_step - 6), 0, 100) + assert env.agents[0].malfunction_data['malfunction'] == assert_list[time_step] def test_initial_malfunction(): @@ -529,45 +485,14 @@ def test_initial_malfunction_do_nothing(): run_replay_config(env, [replay_config], activate_agents=False) -def test_initial_nextmalfunction_not_below_zero(): - random.seed(0) - np.random.seed(0) - - stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents - 'malfunction_rate': 70, # Rate of malfunction occurence - 'min_duration': 2, # Minimal duration of malfunction - 'max_duration': 5 # Max duration of malfunction - } - - rail, rail_map = make_simple_rail2() - - env = RailEnv(width=25, - height=30, - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), - number_of_agents=1, - stochastic_data=stochastic_data, # Malfunction data generator - obs_builder_object=SingleAgentNavigationObs() - ) - # reset to initialize agents_static - env.reset() - agent = env.agents[0] - env.step({}) - # was next_malfunction was -1 befor the bugfix https://gitlab.aicrowd.com/flatland/flatland/issues/186 - assert agent.malfunction_data['next_malfunction'] >= 0, \ - "next_malfunction should be >=0, found {}".format(agent.malfunction_data['next_malfunction']) - - def tests_random_interference_from_outside(): """Tests that malfunctions are produced by stochastic_data!""" # Set fixed malfunction duration for this test - stochastic_data = {'prop_malfunction': 1., - 'malfunction_rate': 1, + stochastic_data = {'malfunction_rate': 1, 'min_duration': 10, 'max_duration': 10} rail, rail_map = make_simple_rail2() - env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), @@ -579,9 +504,7 @@ def tests_random_interference_from_outside(): env.reset() # reset to initialize agents_static env.agents[0].speed_data['speed'] = 0.33 - env.agents[0].initial_position = (3, 0) - env.agents[0].target = (3, 9) - env.reset(False, False, False) + env.reset(False, False, False, random_seed=10) env_data = [] for step in range(200): @@ -612,11 +535,8 @@ def tests_random_interference_from_outside(): env.reset() # reset to initialize agents_static env.agents[0].speed_data['speed'] = 0.33 - env.agents[0].initial_position = (3, 0) - env.agents[0].target = (3, 9) - env.reset(False, False, False) + env.reset(False, False, False, random_seed=10) - # Print for test generation dummy_list = [1, 2, 6, 7, 8, 9, 4, 5, 4] for step in range(200): action_dict: Dict[int, RailEnvActions] = {} diff --git a/tests/test_utils.py b/tests/test_utils.py index 1a98c161829dedc429465b6606101fd19784cbaa..80656cbb490c39fc7352327faa9bf5859a1123e9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -118,8 +118,6 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: # recognizes the agent as potentially malfuncitoning # We also set next malfunction to infitiy to avoid interference with our tests agent.malfunction_data['malfunction'] = replay.set_malfunction - agent.malfunction_data['malfunction_rate'] = max(agent.malfunction_data['malfunction_rate'], 1) - agent.malfunction_data['next_malfunction'] = np.inf agent.malfunction_data['moving_before_malfunction'] = agent.moving _assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction') print(step)