diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index fc73cd6ce481887477d50905c98d721032a08ff1..cb2fdb7e9e2bad19ce90f70a74e980d9bd7f75cf 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -421,11 +421,21 @@ class RailEnv(Environment): agent.moving = agent.malfunction_data['moving_before_malfunction'] continue - def _draw_malfunctioning_agent(self): + def _draw_malfunctioning_agent(self, tries): # Select only from active agents breaking_agent_idx = self.np_random.choice(self.active_agents) breaking_agent = self.agents[breaking_agent_idx] - return breaking_agent + # We assume that at least half of the agents should still be working + if tries > 0.5 * len(self.active_agents): + return None + + # If agent is already broken look for a new one + elif breaking_agent.malfunction_data['malfunction'] > 0: + return self._draw_malfunctioning_agent(tries+1) + + # Return agent to be broken + else: + return breaking_agent def _malfunction(self, rate): """ @@ -434,30 +444,16 @@ class RailEnv(Environment): """ if self.np_random.rand() < self._malfunction_prob(rate, len(self.active_agents)): - breaking_agent = self._draw_malfunctioning_agent() - # We assume that less then half of the active agents should be broken at MOST. - # Therefore we only try that many times before ignoring the malfunction - tries = 0 - max_tries = 0.5 * len(self.active_agents) - # Look for a functioning active agent - while breaking_agent.malfunction_data['malfunction'] > 0 and tries < max_tries: - breaking_agent = self._draw_malfunctioning_agent() - tries += 1 - - # If we did not manage to find a functioning agent among the active ones skip this malfunction - if tries < max_tries: - # Because we update agents in the same step as we break them we add one to the duration of the - # malfunction + breaking_agent = self._draw_malfunctioning_agent(0) + if breaking_agent: num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken, - self.max_number_of_steps_broken + 1) + 1 + self.max_number_of_steps_broken + 1) + 1 breaking_agent.malfunction_data['malfunction'] = num_broken_steps breaking_agent.malfunction_data['moving_before_malfunction'] = breaking_agent.moving breaking_agent.malfunction_data['fixed'] = False breaking_agent.malfunction_data['nr_malfunctions'] += 1 + return - return - - return def step(self, action_dict_: Dict[int, RailEnvActions]): """