diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index ef8be5a8b2741d2d6be2aad23e399896a68f515c..034b6640e0e8f84fd689dde6c83c25002aeb2449 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -249,7 +249,7 @@ class RailEnv(Environment): """ Reset the agents to their starting positions defined in agents_static """ self.agents = EnvAgent.list_from_static(self.agents_static) - + self.active_agents = [i for i in range(len(self.agents))] @staticmethod def compute_max_episode_steps(width: int, height: int, ratio_nr_agents_to_nr_cities: float = 20.0) -> int: """ @@ -408,7 +408,12 @@ class RailEnv(Environment): """ if self.np_random.rand() < self._malfunction_prob(rate): - breaking_agent = self.np_random.choice(self.agents) + # Select only from agents that are not done yet + breaking_agent_idx = self.np_random.choice(self.active_agents) + breaking_agent = self.agents[breaking_agent_idx] + + # Only break agents that are not broken yet + # TODO: Do we want to guarantee that we have the desired rate or are we happy with lower rates? if breaking_agent.malfunction_data['malfunction'] < 1: num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken, self.max_number_of_steps_broken + 1) @@ -458,6 +463,7 @@ class RailEnv(Environment): # Evoke the malfunction generator self._malfunction(self.mean_malfunction_rate) + for i_agent, agent in enumerate(self.agents): # Reset the step rewards self.rewards_dict[i_agent] = 0 @@ -613,6 +619,7 @@ class RailEnv(Environment): if np.equal(agent.position, agent.target).all(): agent.status = RailAgentStatus.DONE self.dones[i_agent] = True + self.active_agents.remove(i_agent) agent.moving = False self._remove_agent_from_scene(agent) else: @@ -698,7 +705,6 @@ class RailEnv(Environment): (*agent.position, agent.direction), new_direction) - # only call cell_free() if new cell is inside the scene if new_cell_valid: # Check the new position is not the same as any of the existing agent positions