diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 13da13cfd7d91ca79f75fb05202363a8d5a855d5..58faab24aa7354ac7d806dda770246493ee46248 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -406,7 +406,7 @@ class RailEnv(Environment): for agent in self.agents: # Ignore agents that OK - if agent.malfunction_data['malfunction'] < 1: + if self._agent_is_ok(agent): continue # Reduce number of malfunction steps left @@ -420,21 +420,6 @@ class RailEnv(Environment): agent.moving = agent.malfunction_data['moving_before_malfunction'] continue - def _draw_malfunctioning_agent(self, tries): - # Select only from active agents - breaking_agent_idx = self.np_random.choice(self.active_agents) - breaking_agent = self.agents[breaking_agent_idx] - # We assume that at least half of the agents should still be working - if tries > 0.5 * len(self.active_agents): - return None - - # If agent is already broken look for a new one - elif breaking_agent.malfunction_data['malfunction'] > 0: - return self._draw_malfunctioning_agent(tries+1) - - # Return agent to be broken - else: - return breaking_agent def _malfunction(self, rate): """ @@ -985,3 +970,47 @@ class RailEnv(Environment): return 0. else: return 1 - np.exp(- (1 / rate) * (n_agents)) + + def _draw_malfunctioning_agent(self, tries): + """ + Function to determin what agent will be breaking. + It only looks at active and non-broken agents. + After a number of steps it gives up the search after breaking agents and ignores malfunciton + + Parameters + ---------- + tries: How many times we tried to find an agent + + Returns + ------- + agent that is breaking + """ + # Select only from active agents + breaking_agent_idx = self.np_random.choice(self.active_agents) + breaking_agent = self.agents[breaking_agent_idx] + # We assume that at least half of the agents should still be working + if tries > 0.5 * len(self.active_agents): + return None + + # If agent is already broken look for a new one + elif breaking_agent.malfunction_data['malfunction'] > 0: + return self._draw_malfunctioning_agent(tries + 1) + + # Return agent to be broken + else: + return breaking_agent + + def _agent_is_ok(self, agent): + """ + Check if an agent is ok, meaning it can move and is not malfuncitoinig + Parameters + ---------- + agent + + Returns + ------- + True if agent is ok, False otherwise + + """ + return agent.malfunction_data['malfunction'] < 1 +