diff --git a/examples/introduction_flatland_2_1.py b/examples/introduction_flatland_2_1.py index 5fd0498cee4b89bb85edd5831b6735a118046474..980c9a7dd3fa3bfb4cd5c964aa511386bbdb38b7 100644 --- a/examples/introduction_flatland_2_1.py +++ b/examples/introduction_flatland_2_1.py @@ -58,8 +58,7 @@ schedule_generator = sparse_schedule_generator(speed_ration_map) # We can furthermore pass stochastic data to the RailEnv constructor which will allow for stochastic malfunctions # during an episode. -stochastic_data = {'prop_malfunction': 0.3, # Percentage of defective agents - 'malfunction_rate': 50, # Rate of malfunction occurence +stochastic_data = {'malfunction_rate': 5, # Rate of malfunction occurence 'min_duration': 3, # Minimal duration of malfunction 'max_duration': 20 # Max duration of malfunction } diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 82c1bc07f807ef4e98e11c48d0b2e13f2c27d9ad..ef8be5a8b2741d2d6be2aad23e399896a68f515c 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -347,6 +347,8 @@ class RailEnv(Environment): if activate_agents: for i_agent in range(self.get_num_agents()): self.set_agent_active(i_agent) + + # See if agents are already broken self._malfunction(self.mean_malfunction_rate) for i_agent, agent in enumerate(self.agents): initial_malfunction = self._agent_malfunction(i_agent) @@ -400,12 +402,12 @@ class RailEnv(Environment): self.agents[i_agent].moving = agent.malfunction_data['moving_before_malfunction'] return False - def _malfunction(self, rate) -> bool: + def _malfunction(self, rate): """ Malfunction generator that breaks agents at a given rate. It does randomly chose agent to break during the run """ - if self.np_random.randn() < self._malfunction_prob(rate): + if self.np_random.rand() < self._malfunction_prob(rate): breaking_agent = self.np_random.choice(self.agents) if breaking_agent.malfunction_data['malfunction'] < 1: num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken, diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 2f5ea5f2c4af3ab0e6c912202bdfd234bf0de7a0..6dbb644a8e4837646d3ae5379971e9a5cb800342 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -110,7 +110,7 @@ def test_malfunction_process(): total_down_time += env.agents[0].malfunction_data['malfunction'] # Check that the appropriate number of malfunctions is achieved - assert env.agents[0].malfunction_data['nr_malfunctions'] == 30, "Actual {}".format( + assert env.agents[0].malfunction_data['nr_malfunctions'] == 28, "Actual {}".format( env.agents[0].malfunction_data['nr_malfunctions']) # Check that malfunctioning data was standing around @@ -140,20 +140,14 @@ def test_malfunction_process_statistically(): env.agents[0].target = (0, 0) # Next line only for test generation - # agent_malfunction_list = [[] for i in range(20)] - agent_malfunction_list = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 4], [0, 0, 0, 0, 0, 0, 0, 0, 0, 3], - [4, 0, 0, 0, 0, 0, 0, 0, 0, 2], - [3, 0, 0, 0, 0, 0, 0, 0, 0, 1], [2, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 4, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 3, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 2, 4, 0, 0], [0, 0, 0, 0, 0, 0, 1, 3, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 2, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1, 0, 4], [0, 0, 0, 0, 0, 0, 0, 0, 0, 3], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 2], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], [4, 0, 0, 0, 0, 0, 0, 0, 0, 0]] + agent_malfunction_list = [[] for i in range(20)] + agent_malfunction_list = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 5], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 5], [0, 0, 0, 0, 0, 0, 0, 0, 0, 4], [0, 4, 0, 0, 0, 0, 0, 0, 0, 3], + [0, 3, 0, 0, 0, 0, 0, 0, 0, 2], [0, 2, 0, 0, 0, 0, 0, 0, 0, 1], [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [4, 0, 0, 0, 0, 0, 0, 0, 0, 0], [3, 0, 0, 0, 0, 0, 0, 0, 0, 0], [2, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] for step in range(20): action_dict: Dict[int, RailEnvActions] = {} @@ -161,17 +155,17 @@ def test_malfunction_process_statistically(): # We randomly select an action action_dict[agent_idx] = RailEnvActions(np.random.randint(4)) # For generating tests only: - # agent_malfunction_list[step].append(env.agents[agent_idx].malfunction_data['malfunction']) + #agent_malfunction_list[step].append(env.agents[agent_idx].malfunction_data['malfunction']) assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[step][agent_idx] env.step(action_dict) # For generating test onlz - # print(agent_malfunction_list) + #print(agent_malfunction_list) def test_malfunction_before_entry(): """Tests that malfunctions are working properly for agents before entering the environment!""" # Set fixed malfunction duration for this test - stochastic_data = {'malfunction_rate': 1, + stochastic_data = {'malfunction_rate': 0.0001, 'min_duration': 10, 'max_duration': 10} @@ -180,7 +174,7 @@ def test_malfunction_before_entry(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=2), # seed 12 + schedule_generator=random_schedule_generator(seed=1), # seed 12 number_of_agents=10, random_seed=1, stochastic_data=stochastic_data, # Malfunction data generator @@ -191,15 +185,12 @@ def test_malfunction_before_entry(): # Test initial malfunction values for all agents # we want some agents to be malfuncitoning already and some to be working # we want different next_malfunction values for the agents - assert env.agents[1].malfunction_data['malfunction'] == 0 - assert env.agents[2].malfunction_data['malfunction'] == 0 - assert env.agents[3].malfunction_data['malfunction'] == 0 - assert env.agents[4].malfunction_data['malfunction'] == 0 - assert env.agents[5].malfunction_data['malfunction'] == 0 - assert env.agents[6].malfunction_data['malfunction'] == 0 - assert env.agents[7].malfunction_data['malfunction'] == 0 - assert env.agents[8].malfunction_data['malfunction'] == 0 - assert env.agents[9].malfunction_data['malfunction'] == 9 + + for a in range(10): + + print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,env.agents[a].malfunction_data['malfunction'])) + + def test_malfunction_values_and_behavior(): @@ -213,7 +204,7 @@ def test_malfunction_values_and_behavior(): rail, rail_map = make_simple_rail2() action_dict: Dict[int, RailEnvActions] = {} - stochastic_data = {'malfunction_rate': 5, + stochastic_data = {'malfunction_rate': 0.01, 'min_duration': 10, 'max_duration': 10} env = RailEnv(width=25, @@ -229,7 +220,7 @@ def test_malfunction_values_and_behavior(): env.reset(False, False, activate_agents=True, random_seed=10) # Assertions - assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 0, 9, 8, 7, 6] + assert_list = [8, 7, 6, 5, 4, 3, 2, 1, 0, 9, 8, 7, 6, 5, 4] print("[") for time_step in range(15): # Move in the env @@ -560,8 +551,7 @@ def test_last_malfunction_step(): """ # Set fixed malfunction duration for this test - stochastic_data = {'prop_malfunction': 1., - 'malfunction_rate': 5, + stochastic_data = {'malfunction_rate': 5, 'min_duration': 4, 'max_duration': 4} @@ -577,7 +567,7 @@ def test_last_malfunction_step(): ) env.reset() # reset to initialize agents_static - env.agents[0].speed_data['speed'] = 0.33 + env.agents[0].speed_data['speed'] = 1. / 3. env.agents_static[0].target = (0, 0) env.reset(False, False, True) @@ -585,24 +575,23 @@ def test_last_malfunction_step(): env.agents[0].malfunction_data['next_malfunction'] = 2 env.agents[0].malfunction_data['malfunction'] = 0 env_data = [] - for step in range(20): action_dict: Dict[int, RailEnvActions] = {} for agent in env.agents: # Go forward all the time action_dict[agent.handle] = RailEnvActions(2) - # Check if the agent is still allowed to move in this step - if env.agents[0].malfunction_data['malfunction'] > 0 or env.agents[0].malfunction_data['next_malfunction'] < 1: - agent_can_move = False - else: - agent_can_move = True + if env.agents[0].malfunction_data['malfunction'] < 1: + agent_can_move = True # Store the position before and after the step pre_position = env.agents[0].speed_data['position_fraction'] _, reward, _, _ = env.step(action_dict) - post_position = env.agents[0].speed_data['position_fraction'] + # Check if the agent is still allowed to move in this step + if env.agents[0].malfunction_data['malfunction'] > 0: + agent_can_move = False + post_position = env.agents[0].speed_data['position_fraction'] # Assert that the agent moved while it was still allowed if agent_can_move: assert pre_position != post_position