diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index ad160e4e8770b9994052563c4584a7211e8da3bc..0887c0ca59d1609d27b37aa318022750924f4ea3 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -14,6 +14,7 @@ from flatland.core.env import Environment from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions from flatland.core.grid.grid4_utils import get_new_position +from flatland.core.grid.grid_utils import IntVector2D from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus from flatland.envs.distance_map import DistanceMap @@ -218,6 +219,9 @@ class RailEnv(Environment): self.valid_positions = None + # global numpy array of agent positions; True means that there is an agent at that cell + self.agent_positions: np.ndarray = np.full((height, width), False) + def _seed(self, seed=None): self.np_random, seed = seeding.np_random(seed) return [seed] @@ -243,7 +247,7 @@ class RailEnv(Environment): agent = self.agents[handle] if agent.status == RailAgentStatus.READY_TO_DEPART and self.cell_free(agent.initial_position): agent.status = RailAgentStatus.ACTIVE - agent.position = agent.initial_position + self._set_agent_to_initial_position(agent, agent.initial_position) def restart_agents(self): """ Reset the agents to their starting positions defined in agents_static @@ -340,6 +344,8 @@ class RailEnv(Environment): else: self._max_episode_steps = self.compute_max_episode_steps(width=self.width, height=self.height) + self.agent_positions = np.full((self.height, self.width), False) + self.restart_agents() if activate_agents: @@ -523,7 +529,7 @@ class RailEnv(Environment): if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position): agent.status = RailAgentStatus.ACTIVE - agent.position = 
agent.initial_position + self._set_agent_to_initial_position(agent, agent.initial_position) self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] return else: @@ -618,7 +624,7 @@ class RailEnv(Environment): assert new_cell_valid assert transition_valid if cell_free: - agent.position = new_position + self._move_agent_to_new_position(agent, new_position) agent.direction = new_direction agent.speed_data['position_fraction'] = 0.0 @@ -627,16 +633,54 @@ class RailEnv(Environment): agent.status = RailAgentStatus.DONE self.dones[i_agent] = True agent.moving = False - - if self.remove_agents_at_target: - agent.position = None - agent.status = RailAgentStatus.DONE_REMOVED + self._remove_agent_from_scene(agent) else: self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] else: # step penalty if not moving (stopped now or before) self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] + def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D): + """ + Sets the agent to its initial position. Updates the agent object and the position + of the agent inside the global agent_positions numpy array + + Parameters + ------- + agent: EnvAgent object + new_position: IntVector2D + """ + agent.position = new_position + self.agent_positions[agent.position] = True + + def _move_agent_to_new_position(self, agent: EnvAgent, new_position: IntVector2D): + """ + Move the agent to a new position. Updates the agent object and the position + of the agent inside the global agent_positions numpy array + + Parameters + ------- + agent: EnvAgent object + new_position: IntVector2D + """ + agent.position = new_position + self.agent_positions[agent.old_position] = False + self.agent_positions[agent.position] = True + + def _remove_agent_from_scene(self, agent: EnvAgent): + """ + Remove the agent from the scene. 
Updates the agent object and the position + of the agent inside the global agent_positions numpy array + + Parameters + ------- + agent: EnvAgent object + """ + self.agent_positions[agent.position] = False + if self.remove_agents_at_target: + agent.position = None + agent.status = RailAgentStatus.DONE_REMOVED + def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent): """ @@ -673,12 +717,18 @@ class RailEnv(Environment): (*agent.position, agent.direction), new_direction) - # Check the new position is not the same as any of the existing agent positions - # (including itself, for simplicity, since it is moving) - cell_free = self.cell_free(new_position) + + # only call cell_free() if new cell is inside the scene + if new_cell_valid: + # Check the new position is not the same as any of the existing agent positions + # (including itself, for simplicity, since it is moving) + cell_free = self.cell_free(new_position) + else: + # if new cell is outside of scene -> cell_free is False + cell_free = False return cell_free, new_cell_valid, new_direction, new_position, transition_valid - def cell_free(self, position): + def cell_free(self, position: IntVector2D) -> bool: """ Utility to check if a cell is free @@ -692,9 +742,7 @@ class RailEnv(Environment): is the cell free or not? 
""" - agent_positions = [agent.position for agent in self.agents if agent.position is not None] - ret = len(agent_positions) == 0 or not np.any(np.equal(position, agent_positions).all(1)) - return ret + return not self.agent_positions[position] def check_action(self, agent: EnvAgent, action: RailEnvActions): """ @@ -790,7 +838,7 @@ class RailEnv(Environment): def set_full_state_msg(self, msg_data): """ Sets environment state with msgdata object passed as argument - + Parameters ------- msg_data: msgpack object @@ -809,7 +857,7 @@ class RailEnv(Environment): def set_full_state_dist_msg(self, msg_data): """ Sets environment grid state and distance map with msgdata object passed as argument - + Parameters ------- msg_data: msgpack object