diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 0887c0ca59d1609d27b37aa318022750924f4ea3..22bd21625fd4cc0fb50c592bf4533e60ced00546 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -280,6 +280,24 @@ class RailEnv(Environment):
         alpha = 2
         return int(timedelay_factor * alpha * (width + height + ratio_nr_agents_to_nr_cities))
 
+    def action_required(self, agent):
+        """
+        Check if an agent needs to provide an action
+
+        Parameters
+        ----------
+        agent: RailEnvAgent
+            Agent we want to check
+
+        Returns
+        -------
+        True: Agent needs to provide an action
+        False: Agent cannot provide an action
+        """
+        return (agent.status == RailAgentStatus.READY_TO_DEPART or (
+            agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0,
+                                                                   rtol=1e-03)))
+
     def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False,
               random_seed: bool = None) -> (Dict, Dict):
         """
@@ -339,8 +357,8 @@ class RailEnv(Environment):
             if agents_hints and 'city_orientations' in agents_hints:
                 ratio_nr_agents_to_nr_cities = self.get_num_agents() / len(agents_hints['city_orientations'])
                 self._max_episode_steps = self.compute_max_episode_steps(
-                        width=self.width, height=self.height,
-                        ratio_nr_agents_to_nr_cities=ratio_nr_agents_to_nr_cities)
+                    width=self.width, height=self.height,
+                    ratio_nr_agents_to_nr_cities=ratio_nr_agents_to_nr_cities)
             else:
                 self._max_episode_steps = self.compute_max_episode_steps(width=self.width, height=self.height)
 
@@ -377,10 +395,7 @@ class RailEnv(Environment):
         self.distance_map.reset(self.agents, self.rail)
 
         info_dict: Dict = {
-            'action_required': {
-                i: (agent.status == RailAgentStatus.READY_TO_DEPART or (
-                    agent.status == RailAgentStatus.ACTIVE and agent.speed_data['position_fraction'] == 0.0))
-                for i, agent in enumerate(self.agents)},
+            'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
             'malfunction': {
                 i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents())
             },
@@ -454,10 +469,10 @@ class RailEnv(Environment):
         if self.dones["__all__"]:
             self.rewards_dict = {}
             info_dict = {
-                "action_required" : {},
-                "malfunction" : {},
-                "speed" : {},
-                "status" : {},
+                "action_required": {},
+                "malfunction": {},
+                "speed": {},
+                "status": {},
             }
             for i_agent, agent in enumerate(self.agents):
                 self.rewards_dict[i_agent] = self.global_reward
@@ -471,12 +486,12 @@ class RailEnv(Environment):
         # Reset the step rewards
         self.rewards_dict = dict()
         info_dict = {
-            "action_required" : {},
-            "malfunction" : {},
-            "speed" : {},
-            "status" : {},
+            "action_required": {},
+            "malfunction": {},
+            "speed": {},
+            "status": {},
         }
-        have_all_agents_ended = True # boolean flag to check if all agents are done
+        have_all_agents_ended = True  # boolean flag to check if all agents are done
        for i_agent, agent in enumerate(self.agents):
            # Reset the step rewards
            self.rewards_dict[i_agent] = 0
@@ -488,10 +503,7 @@ class RailEnv(Environment):
             have_all_agents_ended &= (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED])
 
             # Build info dict
-            info_dict["action_required"][i_agent] = \
-                (agent.status == RailAgentStatus.READY_TO_DEPART or (
-                    agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0,
-                                                                           rtol=1e-03)))
+            info_dict["action_required"][i_agent] = self.action_required(agent)
             info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction']
             info_dict["speed"][i_agent] = agent.speed_data['speed']
             info_dict["status"][i_agent] = agent.status
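
For reviewers, a minimal sketch (not part of the diff) of how the `action_required` flag that the new `RailEnv.action_required` helper fills is typically consumed in a driver loop. The environment constructor arguments and the hard-coded `MOVE_FORWARD` stand-in for a real policy are illustrative placeholders only.

```python
# Sketch only, assuming flatland-rl 2.x: consuming info['action_required'],
# which is now populated via RailEnv.action_required(agent).
from flatland.envs.rail_env import RailEnv, RailEnvActions

env = RailEnv(width=30, height=30, number_of_agents=2)  # generators omitted for brevity
obs, info = env.reset()

done = {"__all__": False}
while not done["__all__"]:
    actions = {}
    for handle in range(env.get_num_agents()):
        # Query a policy only for agents the env flags as needing an action;
        # agents mid-cell (position_fraction > 0) or already done are skipped.
        if info["action_required"][handle]:
            actions[handle] = RailEnvActions.MOVE_FORWARD  # stand-in for a real policy
    obs, rewards, done, info = env.step(actions)
```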