diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index f284f3ac38b480d8feb6c0b4944cf8831d2a70d2..fc73cd6ce481887477d50905c98d721032a08ff1 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -421,26 +421,27 @@ class RailEnv(Environment): agent.moving = agent.malfunction_data['moving_before_malfunction'] continue + def _draw_malfunctioning_agent(self): + # Select only from active agents + breaking_agent_idx = self.np_random.choice(self.active_agents) + breaking_agent = self.agents[breaking_agent_idx] + return breaking_agent + def _malfunction(self, rate): """ Malfunction generator that breaks agents at a given rate. It does randomly chose agent to break during the run """ if self.np_random.rand() < self._malfunction_prob(rate, len(self.active_agents)): - # Select only from agents that are not done yet - breaking_agent_idx = self.np_random.choice(self.active_agents) - breaking_agent = self.agents[breaking_agent_idx] + breaking_agent = self._draw_malfunctioning_agent() # We assume that less then half of the active agents should be broken at MOST. # Therefore we only try that many times before ignoring the malfunction - tries = 0 max_tries = 0.5 * len(self.active_agents) - # Look for a functioning active agent while breaking_agent.malfunction_data['malfunction'] > 0 and tries < max_tries: - breaking_agent_idx = self.np_random.choice(self.active_agents) - breaking_agent = self.agents[breaking_agent_idx] + breaking_agent = self._draw_malfunctioning_agent() tries += 1 # If we did not manage to find a functioning agent among the active ones skip this malfunction