Skip to content
Snippets Groups Projects
Commit 8638c7b6 authored by u229589's avatar u229589
Browse files

add agents position numpy matrix in RailEnv in order to speed up the cell_free check (#259)

parent 89002bf4
No related branches found
No related tags found
No related merge requests found
...@@ -14,6 +14,7 @@ from flatland.core.env import Environment ...@@ -14,6 +14,7 @@ from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.grid_utils import IntVector2D
from flatland.core.transition_map import GridTransitionMap from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus
from flatland.envs.distance_map import DistanceMap from flatland.envs.distance_map import DistanceMap
...@@ -216,6 +217,7 @@ class RailEnv(Environment): ...@@ -216,6 +217,7 @@ class RailEnv(Environment):
# Reset environment # Reset environment
self.valid_positions = None self.valid_positions = None
self.agent_positions: np.ndarray = np.full((height, width), False)
def _seed(self, seed=None): def _seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed) self.np_random, seed = seeding.np_random(seed)
...@@ -242,7 +244,7 @@ class RailEnv(Environment): ...@@ -242,7 +244,7 @@ class RailEnv(Environment):
agent = self.agents[handle] agent = self.agents[handle]
if agent.status == RailAgentStatus.READY_TO_DEPART and self.cell_free(agent.initial_position): if agent.status == RailAgentStatus.READY_TO_DEPART and self.cell_free(agent.initial_position):
agent.status = RailAgentStatus.ACTIVE agent.status = RailAgentStatus.ACTIVE
agent.position = agent.initial_position self._set_agent_to_initial_position(agent, agent.initial_position)
def restart_agents(self): def restart_agents(self):
""" Reset the agents to their starting positions defined in agents_static """ Reset the agents to their starting positions defined in agents_static
...@@ -339,6 +341,8 @@ class RailEnv(Environment): ...@@ -339,6 +341,8 @@ class RailEnv(Environment):
else: else:
self._max_episode_steps = self.compute_max_episode_steps(width=self.width, height=self.height) self._max_episode_steps = self.compute_max_episode_steps(width=self.width, height=self.height)
self.agent_positions = np.full((self.height, self.width), False)
self.restart_agents() self.restart_agents()
if activate_agents: if activate_agents:
...@@ -515,7 +519,7 @@ class RailEnv(Environment): ...@@ -515,7 +519,7 @@ class RailEnv(Environment):
if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT, if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT,
RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position): RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position):
agent.status = RailAgentStatus.ACTIVE agent.status = RailAgentStatus.ACTIVE
agent.position = agent.initial_position self._set_agent_to_initial_position(agent, agent.initial_position)
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
return return
else: else:
...@@ -610,7 +614,7 @@ class RailEnv(Environment): ...@@ -610,7 +614,7 @@ class RailEnv(Environment):
assert new_cell_valid assert new_cell_valid
assert transition_valid assert transition_valid
if cell_free: if cell_free:
agent.position = new_position self._move_agent_to_new_position(agent, new_position)
agent.direction = new_direction agent.direction = new_direction
agent.speed_data['position_fraction'] = 0.0 agent.speed_data['position_fraction'] = 0.0
...@@ -619,16 +623,28 @@ class RailEnv(Environment): ...@@ -619,16 +623,28 @@ class RailEnv(Environment):
agent.status = RailAgentStatus.DONE agent.status = RailAgentStatus.DONE
self.dones[i_agent] = True self.dones[i_agent] = True
agent.moving = False agent.moving = False
self._remove_agent_from_scene(agent)
if self.remove_agents_at_target:
agent.position = None
agent.status = RailAgentStatus.DONE_REMOVED
else: else:
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
else: else:
# step penalty if not moving (stopped now or before) # step penalty if not moving (stopped now or before)
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D):
agent.position = new_position
self.agent_positions[agent.position] = True
def _move_agent_to_new_position(self, agent: EnvAgent, new_position: IntVector2D):
agent.position = new_position
self.agent_positions[agent.old_position] = False
self.agent_positions[agent.position] = True
def _remove_agent_from_scene(self, agent: EnvAgent):
self.agent_positions[agent.position] = False
if self.remove_agents_at_target:
agent.position = None
agent.status = RailAgentStatus.DONE_REMOVED
def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent): def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent):
""" """
...@@ -670,11 +686,8 @@ class RailEnv(Environment): ...@@ -670,11 +686,8 @@ class RailEnv(Environment):
cell_free = self.cell_free(new_position) cell_free = self.cell_free(new_position)
return cell_free, new_cell_valid, new_direction, new_position, transition_valid return cell_free, new_cell_valid, new_direction, new_position, transition_valid
def cell_free(self, position): def cell_free(self, position: IntVector2D) -> bool:
return not self.agent_positions[position]
agent_positions = [agent.position for agent in self.agents if agent.position is not None]
ret = len(agent_positions) == 0 or not np.any(np.equal(position, agent_positions).all(1))
return ret
def check_action(self, agent: EnvAgent, action: RailEnvActions): def check_action(self, agent: EnvAgent, action: RailEnvActions):
""" """
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment