diff --git a/examples/introduction_flatland_2_1.py b/examples/introduction_flatland_2_1.py index dd17c285519cb35ee5d11a3ed1731f3e33a45c33..29fafe6436de2e04a780b34bd3eddba8a4533355 100644 --- a/examples/introduction_flatland_2_1.py +++ b/examples/introduction_flatland_2_1.py @@ -7,6 +7,7 @@ from flatland.envs.malfunction_generators import malfunction_from_params from flatland.envs.observations import GlobalObsForRailEnv # First of all we import the Flatland rail environment from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_env import RailEnvActions from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.schedule_generators import sparse_schedule_generator # We also include a renderer because we want to visualize what is going on in the environment @@ -28,8 +29,8 @@ from flatland.utils.rendertools import RenderTool, AgentRenderVariant # The railway infrastructure can be build using any of the provided generators in env/rail_generators.py # Here we use the sparse_rail_generator with the following parameters -width = 16*7 # With of map -height = 9*7 # Height of map +width = 16 * 7 # With of map +height = 9 * 7 # Height of map nr_trains = 20 # Number of trains that have an assigned task in the env cities_in_map = 20 # Number of cities where agents can start or end seed = 14 # Random seed @@ -104,7 +105,8 @@ class RandomAgent: :param state: input is the observation of the agent :return: returns an action """ - return np.random.choice([1, 2, 3, 4]) # [Left, Forward, Right, Stop] + return np.random.choice([RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_LEFT, + RailEnvActions.STOP_MOVING]) def step(self, memories): """ diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index f5b84a7d11910e1bf11c93a4bef5a955ad734ad3..2bb9677aab020560aeb28aad97edfa23efebe9bf 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -40,7 +40,7 @@ class EnvAgentStatic(object): malfunction_data = attrib( default=Factory( lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0, - 'moving_before_malfunction': False, 'fixed': True}))) + 'moving_before_malfunction': False}))) status = attrib(default=RailAgentStatus.READY_TO_DEPART, type=RailAgentStatus) position = attrib(default=None, type=Optional[Tuple[int, int]]) @@ -65,8 +65,7 @@ class EnvAgentStatic(object): 'malfunction_rate': schedule.agent_malfunction_rates[ i] if schedule.agent_malfunction_rates is not None else 0., 'next_malfunction': 0, - 'nr_malfunctions': 0, - 'fixed': True}) + 'nr_malfunctions': 0}) return list(starmap(EnvAgentStatic, zip(schedule.agent_positions, schedule.agent_directions, diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 6da778ed40e917833c9460e9e08a8e4a516e5611..29638c6346f34ad39dbe1ad21f2bc7299a5e43ea 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -362,15 +362,17 @@ class RailEnv(Environment): for i_agent in range(self.get_num_agents()): self.set_agent_active(i_agent) - # Induce malfunctions - self._malfunction(self.mean_malfunction_rate) + for agent in self.agents: + # Induce malfunctions + self._break_agent(self.mean_malfunction_rate, agent) + if agent.malfunction_data["malfunction"] > 0: agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING - # Fix agents that finished their malfunciton - self._fix_agents() + # Fix agents that finished their malfunction + self._fix_agent(agent) self.num_resets += 1 self._elapsed_steps = 0 @@ -394,64 +396,48 @@ class RailEnv(Environment): observation_dict: Dict = self._get_observations() return observation_dict, info_dict - def _fix_agents(self): + def _fix_agent(self, agent): """ Updates agent malfunction variables and fixes broken agents - """ - for agent in self.agents: - # Ignore agents that OK - if agent.malfunction_data['fixed']: - continue + Parameters + ---------- + agent + """ - # Reduce number of malfunction steps left - if agent.malfunction_data['malfunction'] > 1: - agent.malfunction_data['malfunction'] -= 1 - continue + # Ignore agents that are OK + if self._is_ok(agent): + return - # Restart agents at the end of their malfunction + # Reduce number of malfunction steps left + if agent.malfunction_data['malfunction'] > 1: agent.malfunction_data['malfunction'] -= 1 - agent.malfunction_data['fixed'] = True - if 'moving_before_malfunction' in agent.malfunction_data: - agent.moving = agent.malfunction_data['moving_before_malfunction'] - continue + return - def _malfunction(self, rate): - """ - Malfunction generator that breaks agents at a given rate. It does randomly chose agent to break during the run + # Restart agents at the end of their malfunction + agent.malfunction_data['malfunction'] -= 1 + if 'moving_before_malfunction' in agent.malfunction_data: + agent.moving = agent.malfunction_data['moving_before_malfunction'] + return + def _break_agent(self, rate, agent): """ - if self.np_random.rand() < self._malfunction_prob(rate, len(self.active_agents)): - # Select only from agents that are not done yet - breaking_agent_idx = self.np_random.choice(self.active_agents) - breaking_agent = self.agents[breaking_agent_idx] + Malfunction generator that breaks agents at a given rate. - # We assume that less then half of the active agents should be broken at MOST. - # Therefore we only try that many times before ignoring the malfunction - - tries = 0 - max_tries = 0.5 * len(self.active_agents) - - # Look for a functioning active agent - while breaking_agent.malfunction_data['malfunction'] > 0 and tries < max_tries: - breaking_agent_idx = self.np_random.choice(self.active_agents) - breaking_agent = self.agents[breaking_agent_idx] - tries += 1 + Parameters + ---------- + agent - # If we did not manage to find a functioning agent among the active ones skip this malfunction - if tries < max_tries: - # Because we update agents in the same step as we break them we add one to the duration of the - # malfunction + """ + if agent.malfunction_data['malfunction'] < 1: + if self.np_random.rand() < self._malfunction_prob(rate): num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken, - self.max_number_of_steps_broken + 1) + 1 - breaking_agent.malfunction_data['malfunction'] = num_broken_steps - breaking_agent.malfunction_data['moving_before_malfunction'] = breaking_agent.moving - breaking_agent.malfunction_data['fixed'] = False - breaking_agent.malfunction_data['nr_malfunctions'] += 1 + self.max_number_of_steps_broken + 1) + 1 + agent.malfunction_data['malfunction'] = num_broken_steps + agent.malfunction_data['moving_before_malfunction'] = agent.moving + agent.malfunction_data['nr_malfunctions'] += 1 + return - return - - return def step(self, action_dict_: Dict[int, RailEnvActions]): """ @@ -492,13 +478,15 @@ class RailEnv(Environment): } have_all_agents_ended = True # boolean flag to check if all agents are done - # Induce malfunctions - self._malfunction(self.mean_malfunction_rate) + for i_agent, agent in enumerate(self.agents): # Reset the step rewards self.rewards_dict[i_agent] = 0 + # Induce malfunction before we do a step, thus a broken agent can't move in this step + self._break_agent(self.mean_malfunction_rate, agent) + # Perform step on the agent self._step_agent(i_agent, action_dict_.get(i_agent)) @@ -511,8 +499,8 @@ class RailEnv(Environment): info_dict["speed"][i_agent] = agent.speed_data['speed'] info_dict["status"][i_agent] = agent.status - # Fix agents that finished their malfunction - self._fix_agents() + # Fix agents that finished their malfunction such that they can perfom an action in the next step + self._fix_agent(agent) # Check for end of episode + set global reward to all rewards! if have_all_agents_ended: @@ -986,7 +974,7 @@ class RailEnv(Environment): x = - np.log(1 - u) * rate return x - def _malfunction_prob(self, rate, n_agents): + def _malfunction_prob(self, rate): """ Probability that an agent break given the number of agents an the probability of a sinlge agent to break :param rate: @@ -995,4 +983,48 @@ class RailEnv(Environment): if rate <= 0: return 0. else: - return 1 - np.exp(- (1 / rate) * (n_agents)) + return 1 - np.exp(- (1 / rate)) + + def _draw_malfunctioning_agent(self, tries): + """ + Function to determin what agent will be breaking. + It only looks at active and non-broken agents. + After a number of steps it gives up the search after breaking agents and ignores malfunciton + + Parameters + ---------- + tries: How many times we tried to find an agent + + Returns + ------- + agent that is breaking + """ + # Select only from active agents + breaking_agent_idx = self.np_random.choice(self.active_agents) + breaking_agent = self.agents[breaking_agent_idx] + # We assume that at least half of the agents should still be working + if tries > 0.5 * len(self.active_agents): + return None + + # If agent is already broken look for a new one + elif breaking_agent.malfunction_data['malfunction'] > 0: + return self._draw_malfunctioning_agent(tries + 1) + + # Return agent to be broken + else: + return breaking_agent + + def _is_ok(self, agent): + """ + Check if an agent is ok, meaning it can move and is not malfuncitoinig + Parameters + ---------- + agent + + Returns + ------- + True if agent is ok, False otherwise + + """ + return agent.malfunction_data['malfunction'] < 1 + diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 14f9b6c0295306448d27a83444e3dd6496485cf1..b2c1ca1162e476ff6e2f4fc3f8489428af23535e 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -106,7 +106,7 @@ def test_malfunction_process(): total_down_time += env.agents[0].malfunction_data['malfunction'] # Check that the appropriate number of malfunctions is achieved - assert env.agents[0].malfunction_data['nr_malfunctions'] == 22, "Actual {}".format( + assert env.agents[0].malfunction_data['nr_malfunctions'] == 23, "Actual {}".format( env.agents[0].malfunction_data['nr_malfunctions']) # Check that malfunctioning data was standing around @@ -132,16 +132,16 @@ def test_malfunction_process_statistically(): env.agents[0].target = (0, 0) # Next line only for test generation #agent_malfunction_list = [[] for i in range(10)] - agent_malfunction_list = [[0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1], - [0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0], - [0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 5, 4], - [0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5], - [0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 5, 4, 3], - [0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] + agent_malfunction_list = [[0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1], + [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0], + [5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0], + [0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1], + [0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2], + [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [5, 4, 3, 2, 1, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0]] for step in range(20): action_dict: Dict[int, RailEnvActions] = {} @@ -158,7 +158,7 @@ def test_malfunction_process_statistically(): def test_malfunction_before_entry(): """Tests that malfunctions are working properly for agents before entering the environment!""" # Set fixed malfunction duration for this test - stochastic_data = {'malfunction_rate': 0.0001, + stochastic_data = {'malfunction_rate': 2, 'min_duration': 10, 'max_duration': 10} @@ -176,16 +176,17 @@ def test_malfunction_before_entry(): # we want different next_malfunction values for the agents assert env.agents[0].malfunction_data['malfunction'] == 0 assert env.agents[1].malfunction_data['malfunction'] == 0 - assert env.agents[2].malfunction_data['malfunction'] == 0 + assert env.agents[2].malfunction_data['malfunction'] == 10 assert env.agents[3].malfunction_data['malfunction'] == 0 assert env.agents[4].malfunction_data['malfunction'] == 0 - assert env.agents[5].malfunction_data['malfunction'] == 10 + assert env.agents[5].malfunction_data['malfunction'] == 0 assert env.agents[6].malfunction_data['malfunction'] == 0 assert env.agents[7].malfunction_data['malfunction'] == 0 - assert env.agents[8].malfunction_data['malfunction'] == 0 - assert env.agents[9].malfunction_data['malfunction'] == 0 - # for a in range(10): - # print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,env.agents[a].malfunction_data['malfunction'])) + assert env.agents[8].malfunction_data['malfunction'] == 10 + assert env.agents[9].malfunction_data['malfunction'] == 10 + + #for a in range(10): + # print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,env.agents[a].malfunction_data['malfunction'])) def test_malfunction_values_and_behavior():