diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index cb4da95a81758625f1f45ecb535e47d834fd5005..c141bb89c9a064026ea118b5bed1e85e65dcaae9 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -14,6 +14,7 @@ from flatland.core.env import Environment
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
 from flatland.core.grid.grid4_utils import get_new_position
+from flatland.core.grid.grid_utils import IntVector2D
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus
 from flatland.envs.distance_map import DistanceMap
@@ -216,6 +217,7 @@ class RailEnv(Environment):
 
         # Reset environment
         self.valid_positions = None
+        self.agent_positions: np.ndarray = np.full((height, width), False)
 
     def _seed(self, seed=None):
         self.np_random, seed = seeding.np_random(seed)
@@ -242,7 +244,7 @@
         agent = self.agents[handle]
         if agent.status == RailAgentStatus.READY_TO_DEPART and self.cell_free(agent.initial_position):
             agent.status = RailAgentStatus.ACTIVE
-            agent.position = agent.initial_position
+            self._set_agent_to_initial_position(agent, agent.initial_position)
 
     def restart_agents(self):
         """ Reset the agents to their starting positions defined in agents_static
@@ -339,6 +341,8 @@
         else:
             self._max_episode_steps = self.compute_max_episode_steps(width=self.width, height=self.height)
 
+        self.agent_positions = np.full((self.height, self.width), False)
+
         self.restart_agents()
 
         if activate_agents:
@@ -515,7 +519,7 @@
             if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT,
                           RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position):
                 agent.status = RailAgentStatus.ACTIVE
-                agent.position = agent.initial_position
+                self._set_agent_to_initial_position(agent, agent.initial_position)
                 self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
                 return
             else:
@@ -610,7 +614,7 @@
                 assert new_cell_valid
                 assert transition_valid
                 if cell_free:
-                    agent.position = new_position
+                    self._move_agent_to_new_position(agent, new_position)
                     agent.direction = new_direction
                     agent.speed_data['position_fraction'] = 0.0
 
@@ -619,16 +623,28 @@
                     agent.status = RailAgentStatus.DONE
                     self.dones[i_agent] = True
                     agent.moving = False
-
-                    if self.remove_agents_at_target:
-                        agent.position = None
-                        agent.status = RailAgentStatus.DONE_REMOVED
+                    self._remove_agent_from_scene(agent)
                 else:
                     self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
 
         else:
             # step penalty if not moving (stopped now or before)
             self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
 
+    def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D):
+        agent.position = new_position
+        self.agent_positions[agent.position] = True
+
+    def _move_agent_to_new_position(self, agent: EnvAgent, new_position: IntVector2D):
+        agent.position = new_position
+        self.agent_positions[agent.old_position] = False
+        self.agent_positions[agent.position] = True
+
+    def _remove_agent_from_scene(self, agent: EnvAgent):
+        self.agent_positions[agent.position] = False
+        if self.remove_agents_at_target:
+            agent.position = None
+            agent.status = RailAgentStatus.DONE_REMOVED
+
     def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent):
         """
@@ -670,11 +686,8 @@
         cell_free = self.cell_free(new_position)
         return cell_free, new_cell_valid, new_direction, new_position, transition_valid
 
-    def cell_free(self, position):
-
-        agent_positions = [agent.position for agent in self.agents if agent.position is not None]
-        ret = len(agent_positions) == 0 or not np.any(np.equal(position, agent_positions).all(1))
-        return ret
+    def cell_free(self, position: IntVector2D) -> bool:
+        return not self.agent_positions[position]
 
     def check_action(self, agent: EnvAgent, action: RailEnvActions):
         """