diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 053229e56680284998fd2fb7896c7d5246cee30d..e11243046256a28f04913f40ef7ef29539e90c39 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -367,15 +367,16 @@ class RailEnv(Environment): for i_agent in range(self.get_num_agents()): self.set_agent_active(i_agent) - # See if agents are already broken + # Induce malfunctions self._malfunction(self.mean_malfunction_rate) - for i_agent, agent in enumerate(self.agents): - initial_malfunction = self._agent_malfunction(i_agent) - - if initial_malfunction: + for agent in self.agents: + if agent.malfunction_data["malfunction"] > 0: agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING + # Fix agents that finished their malfunciton + self._fix_agents() + self.num_resets += 1 self._elapsed_steps = 0 @@ -398,26 +399,27 @@ class RailEnv(Environment): observation_dict: Dict = self._get_observations() return observation_dict, info_dict - def _agent_malfunction(self, i_agent) -> bool: + def _fix_agents(self): """ - Returns true if the agent enters into malfunction. (False, if not broken down or already broken down before). + Updates agent malfunction variables and fixes broken agents """ - agent = self.agents[i_agent] + for agent in self.agents: - # Reduce number of malfunction steps left - if agent.malfunction_data['malfunction'] > 0: - agent.malfunction_data['malfunction'] -= 1 - return True + # Ignore agents that OK + if agent.malfunction_data['fixed']: + continue - # Ignore agents that OK - if agent.malfunction_data['fixed']: - return False + # Reduce number of malfunction steps left + if agent.malfunction_data['malfunction'] > 1: + agent.malfunction_data['malfunction'] -= 1 + continue - # Restart agents at the end of their malfunction - agent.malfunction_data['fixed'] = True - if 'moving_before_malfunction' in agent.malfunction_data: - self.agents[i_agent].moving = agent.malfunction_data['moving_before_malfunction'] - return False + # Restart agents at the end of their malfunction + agent.malfunction_data['malfunction'] -= 1 + agent.malfunction_data['fixed'] = True + if 'moving_before_malfunction' in agent.malfunction_data: + agent.moving = agent.malfunction_data['moving_before_malfunction'] + continue def _malfunction(self, rate): """ @@ -434,7 +436,7 @@ class RailEnv(Environment): # TODO: Do we want to guarantee that we have the desired rate or are we happy with lower rates? if breaking_agent.malfunction_data['malfunction'] < 1: num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken, - self.max_number_of_steps_broken + 1) + self.max_number_of_steps_broken + 1) + 1 breaking_agent.malfunction_data['malfunction'] = num_broken_steps breaking_agent.malfunction_data['moving_before_malfunction'] = breaking_agent.moving breaking_agent.malfunction_data['fixed'] = False @@ -479,7 +481,7 @@ class RailEnv(Environment): } have_all_agents_ended = True # boolean flag to check if all agents are done - # Evoke the malfunction generator + # Induce malfunctions self._malfunction(self.mean_malfunction_rate) for i_agent, agent in enumerate(self.agents): @@ -498,6 +500,9 @@ class RailEnv(Environment): info_dict["speed"][i_agent] = agent.speed_data['speed'] info_dict["status"][i_agent] = agent.status + # Fix agents that finished their malfunction + self._fix_agents() + # Check for end of episode + set global reward to all rewards! if have_all_agents_ended: self.dones["__all__"] = True @@ -542,12 +547,9 @@ class RailEnv(Environment): agent.old_direction = agent.direction agent.old_position = agent.position - # is the agent malfunctioning? - malfunction = self._agent_malfunction(i_agent) - # if agent is broken, actions are ignored and agent does not move. # full step penalty in this case - if malfunction: + if agent.malfunction_data['malfunction'] > 0: self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] return diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index b783fe7ab6e18e1ac532eef2fc63e496d6c3bbb3..68cd6f495c2d9ff099f228e230ef92c1b5e700fa 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -110,7 +110,7 @@ def test_malfunction_process(): total_down_time += env.agents[0].malfunction_data['malfunction'] # Check that the appropriate number of malfunctions is achieved - assert env.agents[0].malfunction_data['nr_malfunctions'] == 28, "Actual {}".format( + assert env.agents[0].malfunction_data['nr_malfunctions'] == 22, "Actual {}".format( env.agents[0].malfunction_data['nr_malfunctions']) # Check that malfunctioning data was standing around @@ -140,17 +140,17 @@ def test_malfunction_process_statistically(): env.agents[0].target = (0, 0) # Next line only for test generation - #agent_malfunction_list = [[] for i in range(20)] - agent_malfunction_list = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 2, 1, 0], - [0, 0, 0, 0, 0, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 5, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [], [], [], [], [], [], [], [], [], []] + #agent_malfunction_list = [[] for i in range(10)] + agent_malfunction_list = [[0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1], + [0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5], + [0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]] for step in range(20): action_dict: Dict[int, RailEnvActions] = {} @@ -188,17 +188,25 @@ def test_malfunction_before_entry(): # Test initial malfunction values for all agents # we want some agents to be malfuncitoning already and some to be working # we want different next_malfunction values for the agents - - for a in range(10): - - print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,env.agents[a].malfunction_data['malfunction'])) + assert env.agents[0].malfunction_data['malfunction'] == 0 + assert env.agents[1].malfunction_data['malfunction'] == 0 + assert env.agents[2].malfunction_data['malfunction'] == 0 + assert env.agents[3].malfunction_data['malfunction'] == 0 + assert env.agents[4].malfunction_data['malfunction'] == 0 + assert env.agents[5].malfunction_data['malfunction'] == 10 + assert env.agents[6].malfunction_data['malfunction'] == 0 + assert env.agents[7].malfunction_data['malfunction'] == 0 + assert env.agents[8].malfunction_data['malfunction'] == 0 + assert env.agents[9].malfunction_data['malfunction'] == 0 + #for a in range(10): + # print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,env.agents[a].malfunction_data['malfunction'])) def test_malfunction_values_and_behavior(): """ - Test that the next malfunction occurs when desired. + Test the malfunction counts down as desired Returns ------- @@ -207,7 +215,7 @@ def test_malfunction_values_and_behavior(): rail, rail_map = make_simple_rail2() action_dict: Dict[int, RailEnvActions] = {} - stochastic_data = {'malfunction_rate': 0.01, + stochastic_data = {'malfunction_rate': 0.001, 'min_duration': 10, 'max_duration': 10} env = RailEnv(width=25, @@ -223,7 +231,7 @@ def test_malfunction_values_and_behavior(): env.reset(False, False, activate_agents=True, random_seed=10) # Assertions - assert_list = [8, 7, 6, 5, 4, 3, 2, 1, 0, 9, 8, 7, 6, 5, 4] + assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 10, 9, 8, 7, 6, 5] print("[") for time_step in range(15): # Move in the env @@ -233,8 +241,7 @@ def test_malfunction_values_and_behavior(): def test_initial_malfunction(): - stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents - 'malfunction_rate': 100, # Rate of malfunction occurence + stochastic_data = {'malfunction_rate': 1000, # Rate of malfunction occurence 'min_duration': 2, # Minimal duration of malfunction 'max_duration': 5 # Max duration of malfunction } @@ -278,7 +285,7 @@ def test_initial_malfunction(): direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, malfunction=1, - reward=env.step_penalty * 1.0 + reward=env.step_penalty ), # malfunctioning ends: starting and running at speed 1.0 Replay( @@ -293,7 +300,7 @@ def test_initial_malfunction(): direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, malfunction=0, - reward=env.step_penalty * 1.0 # running at speed 1.0 + reward=env.step_penalty # running at speed 1.0 ) ], speed=env.agents[0].speed_data['speed'], @@ -341,7 +348,7 @@ def test_initial_malfunction_stop_moving(): position=(3, 2), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.DO_NOTHING, - malfunction=3, + malfunction=2, reward=env.step_penalty, # full step penalty when stopped status=RailAgentStatus.ACTIVE ), @@ -352,7 +359,7 @@ def test_initial_malfunction_stop_moving(): position=(3, 2), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.STOP_MOVING, - malfunction=2, + malfunction=1, reward=env.step_penalty, # full step penalty while stopped status=RailAgentStatus.ACTIVE ), @@ -361,7 +368,7 @@ def test_initial_malfunction_stop_moving(): position=(3, 2), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.DO_NOTHING, - malfunction=1, + malfunction=0, reward=env.step_penalty, # full step penalty while stopped status=RailAgentStatus.ACTIVE ), @@ -429,7 +436,7 @@ def test_initial_malfunction_do_nothing(): position=(3, 2), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.DO_NOTHING, - malfunction=3, + malfunction=2, reward=env.step_penalty, # full step penalty while malfunctioning status=RailAgentStatus.ACTIVE ), @@ -440,7 +447,7 @@ def test_initial_malfunction_do_nothing(): position=(3, 2), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.DO_NOTHING, - malfunction=2, + malfunction=1, reward=env.step_penalty, # full step penalty while stopped status=RailAgentStatus.ACTIVE ), @@ -449,7 +456,7 @@ def test_initial_malfunction_do_nothing(): position=(3, 2), direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.DO_NOTHING, - malfunction=1, + malfunction=0, reward=env.step_penalty, # full step penalty while stopped status=RailAgentStatus.ACTIVE ), diff --git a/tests/test_utils.py b/tests/test_utils.py index 80656cbb490c39fc7352327faa9bf5859a1123e9..ff4948d629747a3394644b971105178ec1ac4523 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -119,8 +119,9 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: # We also set next malfunction to infitiy to avoid interference with our tests agent.malfunction_data['malfunction'] = replay.set_malfunction agent.malfunction_data['moving_before_malfunction'] = agent.moving + agent.malfunction_data['fixed'] = False _assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction') - print(step) + print(step, agent.moving, agent.malfunction_data['fixed'], agent.malfunction_data['malfunction']) _, rewards_dict, _, info_dict = env.step(action_dict) if rendering: renderer.render_env(show=True, show_observations=True)