From 5bf451ebc408d7cf5abe06535f7fc588874e8c89 Mon Sep 17 00:00:00 2001 From: Giacomo Spigler <spiglerg@gmail.com> Date: Wed, 19 Jun 2019 18:44:47 +0200 Subject: [PATCH] prevent stopping in the middle of a cell --- flatland/envs/agent_utils.py | 25 +++++-- flatland/envs/rail_env.py | 126 ++++++++++++++++++++++------------- 2 files changed, 98 insertions(+), 53 deletions(-) diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 8e9ffb9..2d07eee 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -28,19 +28,32 @@ class EnvAgentStatic(object): position = attrib() direction = attrib() target = attrib() - moving = attrib() - - def __init__(self, position, direction, target, moving=False): + moving = attrib(default=False) + # speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0, + # after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous + # cell if speed=1, as default) + speed_data = attrib(default=dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})) + + def __init__(self, + position, + direction, + target, + moving=False, + speed_data={'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0}): self.position = position self.direction = direction self.target = target self.moving = moving + self.speed_data = speed_data @classmethod def from_lists(cls, positions, directions, targets): """ Create a list of EnvAgentStatics from lists of positions, directions and targets """ - return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions)))) + speed_datas = [] + for i in range(len(positions)): + speed_datas.append({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0}) + return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions), speed_datas))) def to_list(self): @@ -54,7 +67,7 @@ class EnvAgentStatic(object): if type(lTarget) is np.ndarray: lTarget = lTarget.tolist() - return [lPos, int(self.direction), lTarget, int(self.moving)] + return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data] @attrs @@ -78,7 +91,7 @@ class EnvAgent(EnvAgentStatic): def to_list(self): return [ self.position, self.direction, self.target, self.handle, - self.old_direction, self.old_position, self.moving] + self.old_direction, self.old_position, self.moving, self.speed_data] @classmethod def from_static(cls, oStatic): diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index c22e1c5..2621308 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -73,7 +73,7 @@ class RailEnv(Environment): random_rail_generator : generate a random rail of given size rail_from_GridTransitionMap_generator(rail_map) : generate a rail from a GridTransitionMap object - rail_from_manual_specifications_generator(rail_spec) : generate a rail from + rail_from_manual_sp ecifications_generator(rail_spec) : generate a rail from a rail specifications array TODO: generate_rail_from_saved_list or from list of ndarray bitmaps --- width : int @@ -101,7 +101,6 @@ class RailEnv(Environment): self.action_space = [1] self.observation_space = self.obs_builder.observation_space # updated on resets? - self.actions = [0] * number_of_agents self.rewards = [0] * number_of_agents self.done = False @@ -192,29 +191,33 @@ class RailEnv(Environment): # for i in range(len(self.agents_handles)): for iAgent in range(self.get_num_agents()): agent = self.agents[iAgent] - - if iAgent not in action_dict: # no action has been supplied for this agent - if agent.moving: - # Keep moving - # Change MOVE_FORWARD to DO_NOTHING - action_dict[iAgent] = RailEnvActions.DO_NOTHING - else: - action_dict[iAgent] = RailEnvActions.DO_NOTHING + agent.speed_data['speed']=0.5 if self.dones[iAgent]: # this agent has already completed... continue - action = action_dict[iAgent] - if action < 0 or action > len(RailEnvActions): - print('ERROR: illegal action=', action, - 'for agent with index=', iAgent) - return + if np.equal(agent.position, agent.target).all(): + self.dones[iAgent] = True + else: + self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed'] + + if iAgent not in action_dict: # no action has been supplied for this agent + action_dict[iAgent] = RailEnvActions.DO_NOTHING + + if action_dict[iAgent] < 0 or action_dict[iAgent] > len(RailEnvActions): + print('ERROR: illegal action=', action_dict[iAgent], + 'for agent with index=', iAgent, + '"DO NOTHING" will be executed instead') + action_dict[iAgent] = RailEnvActions.DO_NOTHING + + action = action_dict[iAgent] if action == RailEnvActions.DO_NOTHING and agent.moving: # Keep moving action = RailEnvActions.MOVE_FORWARD - if action == RailEnvActions.STOP_MOVING and agent.moving: + if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] < 0.01: + # Only allow halting an agent on entering new cells. agent.moving = False self.rewards_dict[iAgent] += stop_penalty @@ -223,47 +226,73 @@ class RailEnv(Environment): agent.moving = True self.rewards_dict[iAgent] += start_penalty - if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING: - cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ - self._check_action_on_agent(action, agent) - if all([new_cell_isValid, transition_isValid, cell_isFree]): - agent.old_direction = agent.direction - agent.old_position = agent.position - agent.position = new_position - agent.direction = new_direction - else: - # Logic: if the chosen action is invalid, - # and it was LEFT or RIGHT, and the agent was moving, then keep moving FORWARD. - if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT) and agent.moving: - cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ - self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent) - - if all([new_cell_isValid, transition_isValid, cell_isFree]): - agent.old_direction = agent.direction - agent.old_position = agent.position - agent.position = new_position - agent.direction = new_direction + # Now perform a movement. + # If the agent is in an initial position within a new cell (agent.speed_data['position_fraction']<eps) + # store the desired action in `transition_action_on_cellexit' (only if the desired transition is + # allowed! otherwise DO_NOTHING!) + # Then in any case (if agent.moving) and the `transition_action_on_cellexit' is valid, increment the + # position_fraction by the speed of the agent (regardless of action taken, as long as no + # STOP_MOVING, but that makes agent.moving=False) + # If the new position fraction is >= 1, reset to 0, and perform the stored + # transition_action_on_cellexit + + # If the agent can make an action + action_selected = False + if agent.speed_data['position_fraction'] < 0.01: + if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING: + cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ + self._check_action_on_agent(action, agent) + + if all([new_cell_isValid, transition_isValid, cell_isFree]): + agent.speed_data['transition_action_on_cellexit'] = action + action_selected = True + + else: + # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving, + # try to keep moving forward! + if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT) and agent.moving: + cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ + self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent) + + if all([new_cell_isValid, transition_isValid, cell_isFree]): + agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD + action_selected = True + + else: + # TODO: an invalid action was chosen after entering the cell. The agent cannot move. + self.rewards_dict[iAgent] += invalid_action_penalty + agent.moving = False + self.rewards_dict[iAgent] += stop_penalty + continue else: - # the action was not valid, add penalty + # TODO: an invalid action was chosen after entering the cell. The agent cannot move. self.rewards_dict[iAgent] += invalid_action_penalty + agent.moving = False + self.rewards_dict[iAgent] += stop_penalty + continue - else: - # the action was not valid, add penalty - self.rewards_dict[iAgent] += invalid_action_penalty + if agent.moving and (action_selected or agent.speed_data['position_fraction'] >= 0.01): + agent.speed_data['position_fraction'] += agent.speed_data['speed'] - if np.equal(agent.position, agent.target).all(): - self.dones[iAgent] = True - else: - self.rewards_dict[iAgent] += step_penalty + if agent.speed_data['position_fraction'] >= 1.0: + agent.speed_data['position_fraction'] = 0.0 + + # Perform stored action to transition to the next cell + + # Now 'transition_action_on_cellexit' will be guaranteed to be valid; it was checked on entering + # the cell + cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ + self._check_action_on_agent(agent.speed_data['transition_action_on_cellexit'], agent) + agent.old_direction = agent.direction + agent.old_position = agent.position + agent.position = new_position + agent.direction = new_direction # Check for end of episode + add global reward to all rewards! if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]): self.dones["__all__"] = True self.rewards_dict = [0 * r + global_reward for r in self.rewards_dict] - # Reset the step actions (in case some agent doesn't 'register_action' - # on the next step) - self.actions = [0] * self.get_num_agents() return self._get_observations(), self.rewards_dict, self.dones, {} def _check_action_on_agent(self, action, agent): @@ -271,6 +300,7 @@ class RailEnv(Environment): # cell used to check for invalid actions new_direction, transition_isValid = self.check_action(agent, action) new_position = get_new_position(agent.position, new_direction) + # Is it a legal move? # 1) transition allows the new_direction in the cell, # 2) the new cell is not empty (case 0), @@ -281,11 +311,13 @@ class RailEnv(Environment): np.clip(new_position, [0, 0], [self.height - 1, self.width - 1])) and # check the new position has some transitions (ie is not an empty cell) self.rail.get_transitions(new_position) > 0) + # If transition validity hasn't been checked yet. if transition_isValid is None: transition_isValid = self.rail.get_transition( (*agent.position, agent.direction), new_direction) + # Check the new position is not the same as any of the existing agent positions # (including itself, for simplicity, since it is moving) cell_isFree = not np.any( -- GitLab