From c6258faf13c1f35a941828b842c4fa227f5545b6 Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Tue, 1 Oct 2019 18:17:35 +0200 Subject: [PATCH] remove agent when they have reached their destinations (target) --- flatland/envs/rail_env.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index f79ab504..b64716f5 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -24,7 +24,6 @@ from flatland.envs.schedule_generators import random_schedule_generator, Schedul m.patch() -DEPOT_POSITION = (-10, -10) class RailEnvActions(IntEnum): DO_NOTHING = 0 # implies change of direction in a dead-end! @@ -110,6 +109,10 @@ class RailEnv(Environment): stop_penalty = 0 # penalty for stopping a moving agent start_penalty = 0 # penalty for starting a stopped agent + # Where the agent will be placed to when the reach their target destination + # (Remove the agents to free the cell) + DEPOT_POSITION = (-10, -10) + def __init__(self, width, height, @@ -118,7 +121,8 @@ class RailEnv(Environment): number_of_agents=1, obs_builder_object: ObservationBuilder = TreeObsForRailEnv(max_depth=2), max_episode_steps=None, - stochastic_data=None + stochastic_data=None, + remove_agents_at_target=False ): """ Environment init. @@ -149,6 +153,9 @@ class RailEnv(Environment): ObservationBuilder-derived object that takes builds observation vectors for each agent. max_episode_steps : int or None + remove_agents_at_target : bool + If remove_agents_at_target is set to true then the agents will be removed by placing to + RailEnv.DEPOT_POSITION when the agent has reach it's target position. """ super().__init__() @@ -159,6 +166,8 @@ class RailEnv(Environment): self.width = width self.height = height + self.remove_agents_at_target = remove_agents_at_target + self.rewards = [0] * number_of_agents self.done = False self.obs_builder = obs_builder_object @@ -503,14 +512,16 @@ class RailEnv(Environment): if np.equal(agent.position, agent.target).all(): self.dones[i_agent] = True agent.moving = False - # TODO: Moving agents to arbitrary position - agent.position = DEPOT_POSITION - agent.target = DEPOT_POSITION + + if self.remove_agents_at_target: + agent.position = RailEnv.DEPOT_POSITION + agent.target = RailEnv.DEPOT_POSITION else: self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] else: # step penalty if not moving (stopped now or before) self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] + def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent): """ @@ -549,15 +560,7 @@ class RailEnv(Environment): # Check the new position is not the same as any of the existing agent positions # (including itself, for simplicity, since it is moving) - # TODO: Revert to earlier version - cell_free = True - for agent2 in self.agents: - if Vec2dOperations.is_equal(new_position, agent2.position) and not Vec2dOperations.is_equal(agent2.target, - agent2.position): - cell_free = False - break - - + cell_free = not np.any(np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) return cell_free, new_cell_valid, new_direction, new_position, transition_valid def check_action(self, agent: EnvAgent, action: RailEnvActions): @@ -605,7 +608,7 @@ class RailEnv(Environment): return self.obs_dict def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]: - return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row,col)) + return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col)) def get_full_state_msg(self): grid_data = self.rail.grid.tolist() -- GitLab