diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index e6c418827d93ae4e647bb3f2cfbcec81d9a9f059..8b4f43fe3535c7e85f6b954b5c58d62acceecc23 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -89,6 +89,15 @@ class RailEnv(Environment):
     For Round 2, they will be passed to the constructor as arguments, to allow for more flexibility.
     """
 
+    alpha = 1.0
+    beta = 1.0
+    # Epsilon to avoid rounding errors
+    epsilon = 0.01
+    invalid_action_penalty = 0  # previously -2; GIACOMO: we decided that invalid actions will carry no penalty
+    step_penalty = -1 * alpha
+    global_reward = 1 * beta
+    stop_penalty = 0  # penalty for stopping a moving agent
+    start_penalty = 0  # penalty for starting a stopped agent
 
     def __init__(self,
                  width,
@@ -252,7 +261,7 @@ class RailEnv(Environment):
 
             agent.malfunction_data['malfunction'] = 0
 
-            self._agent_malfunction(agent)
+            self._agent_malfunction(i_agent, RailEnvActions.DO_NOTHING)
 
         self.num_resets += 1
         self._elapsed_steps = 0
@@ -267,7 +276,9 @@ class RailEnv(Environment):
         # Return the new observation vectors for each agent
         return self._get_observations()
 
-    def _agent_malfunction(self, agent):
+    def _agent_malfunction(self, i_agent, action) -> bool:
+        agent = self.agents[i_agent]
+
         # Decrease counter for next event
         if agent.malfunction_data['malfunction_rate'] > 0:
             agent.malfunction_data['next_malfunction'] -= 1
@@ -291,31 +302,19 @@ class RailEnv(Environment):
                                                   self.max_number_of_steps_broken + 1) + 1
                 agent.malfunction_data['malfunction'] = num_broken_steps
-
-
+                return True
+        return False
 
     def step(self, action_dict_):
         self._elapsed_steps += 1
 
-        action_dict = action_dict_.copy()
-
-        alpha = 1.0
-        beta = 1.0
-        # Epsilon to avoid rounding errors
-        epsilon = 0.01
-        invalid_action_penalty = 0  # previously -2; GIACOMO: we decided that invalid actions will carry no penalty
-        step_penalty = -1 * alpha
-        global_reward = 1 * beta
-        stop_penalty = 0  # penalty for stopping a moving agent
-        start_penalty = 0  # penalty for starting a stopped agent
-
         # Reset the step rewards
         self.rewards_dict = dict()
         for i_agent in range(self.get_num_agents()):
             self.rewards_dict[i_agent] = 0
 
         if self.dones["__all__"]:
-            self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()}
+            self.rewards_dict = {i: r + self.global_reward for i, r in self.rewards_dict.items()}
             info_dict = {
                 'action_required': {i: False for i in range(self.get_num_agents())},
                 'malfunction': {i: 0 for i in range(self.get_num_agents())},
@@ -324,26 +323,71 @@ class RailEnv(Environment):
             return self._get_observations(), self.rewards_dict, self.dones, info_dict
 
         for i_agent in range(self.get_num_agents()):
-            agent = self.agents[i_agent]
-            agent.old_direction = agent.direction
-            agent.old_position = agent.position
             if self.dones[i_agent]:  # this agent has already completed...
                 continue
 
-            # No action has been supplied for this agent
-            if i_agent not in action_dict:
-                action_dict[i_agent] = RailEnvActions.DO_NOTHING
+            agent = self.agents[i_agent]
+            agent.old_direction = agent.direction
+            agent.old_position = agent.position
+
+            # No action has been supplied for this agent -> set DO_NOTHING as default
+            if i_agent not in action_dict_:
+                action = RailEnvActions.DO_NOTHING
+            else:
+                action = action_dict_[i_agent]
 
-            if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions):
-                print('ERROR: illegal action=', action_dict[i_agent],
+            if action < 0 or action >= len(RailEnvActions):
+                print('ERROR: illegal action=', action,
                       'for agent with index=', i_agent,
                       '"DO NOTHING" will be executed instead')
-                action_dict[i_agent] = RailEnvActions.DO_NOTHING
+                action = RailEnvActions.DO_NOTHING
+
+            # Check if agent breaks at this step
+            malfunction = self._agent_malfunction(i_agent, action)
+
+            # If we're at the beginning of the cell, store the action.
+            # As long as we're broken down at the beginning of the cell, we can choose other actions!
+            # This is a design choice made by Erik and Christian.
+
+            # TODO refactor!!!
+            # If the agent can make an action
+            if agent.speed_data['position_fraction'] == 0.0:
+                if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING:
+                    cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
+                        self._check_action_on_agent(action, agent)
 
-            action = action_dict[i_agent]
+                    if all([new_cell_valid, transition_valid]):
+                        agent.speed_data['transition_action_on_cellexit'] = action
+                    else:
+                        # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
+                        # try to keep moving forward!
+                        if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT):
+                            cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
+                                self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)
+
+                            if all([new_cell_valid, transition_valid]):
+                                agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
+                            else:
+                                # If the agent cannot move due to an invalid transition, we set its state to not moving
+                                self.rewards_dict[i_agent] += self.invalid_action_penalty
+                                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
+                                self.rewards_dict[i_agent] += self.stop_penalty
+                                agent.moving = False
+                                action = RailEnvActions.DO_NOTHING
+                        else:
+                            # If the agent cannot move due to an invalid transition, we set its state to not moving
+                            self.rewards_dict[i_agent] += self.invalid_action_penalty
+                            self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
+                            self.rewards_dict[i_agent] += self.stop_penalty
+                            agent.moving = False
+                            action = RailEnvActions.DO_NOTHING
+                else:
+                    agent.speed_data['transition_action_on_cellexit'] = action
+
+            if malfunction:
+                continue
 
             # The train is broken
             if agent.malfunction_data['malfunction'] > 0:
@@ -352,37 +396,31 @@ class RailEnv(Environment):
                 if agent.malfunction_data['malfunction'] < 2:
                     agent.malfunction_data['malfunction'] -= 1
                     self.agents[i_agent].moving = True
-                    action_dict[i_agent] = RailEnvActions.DO_NOTHING
+                    action = RailEnvActions.DO_NOTHING
 
                 else:
                     agent.malfunction_data['malfunction'] -= 1
 
                     # Broken agents are stopped
-                    self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
+                    self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
                     self.agents[i_agent].moving = False
-                    action_dict[i_agent] = RailEnvActions.DO_NOTHING
 
                     # Nothing left to do with broken agent
                     continue
 
-            #  Check if agent breaks at this step
-            self._agent_malfunction(agent)
-
-
             if action == RailEnvActions.DO_NOTHING and agent.moving:
                 # Keep moving
                 action = RailEnvActions.MOVE_FORWARD
 
-            if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data[
-                'position_fraction'] <= epsilon:
+            if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.0:
                 # Only allow halting an agent on entering new cells.
                 agent.moving = False
-                self.rewards_dict[i_agent] += stop_penalty
+                self.rewards_dict[i_agent] += self.stop_penalty
 
             if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
                 # Allow agent to start with any forward or direction action
                 agent.moving = True
-                self.rewards_dict[i_agent] += start_penalty
+                self.rewards_dict[i_agent] += self.start_penalty
 
             # Now perform a movement.
             # If the agent is in an initial position within a new cell (agent.speed_data['position_fraction']<eps)
@@ -394,70 +432,36 @@ class RailEnv(Environment):
             # If the new position fraction is >= 1, reset to 0, and perform the stored
             #   transition_action_on_cellexit
 
-            # If the agent can make an action
-            action_selected = False
-            if agent.speed_data['position_fraction'] <= epsilon:
-                if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING:
-                    cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
-                        self._check_action_on_agent(action, agent)
-
-                    if all([new_cell_valid, transition_valid]):
-                        agent.speed_data['transition_action_on_cellexit'] = action
-                        action_selected = True
+            if agent.moving:
-                    else:
-                        # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
-                        # try to keep moving forward!
-                        if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT) and agent.moving:
-                            cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
-                                self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)
-
-                            if all([new_cell_valid, transition_valid]):
-                                agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
-                                action_selected = True
-
-                            else:
-                                # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
-                                self.rewards_dict[i_agent] += invalid_action_penalty
-                                self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
-                                self.rewards_dict[i_agent] += stop_penalty
-                                agent.moving = False
-                                continue
-                        else:
-                            # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
-                            self.rewards_dict[i_agent] += invalid_action_penalty
-                            self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
-                            self.rewards_dict[i_agent] += stop_penalty
-                            agent.moving = False
-                            continue
-
-            if agent.moving and (action_selected or agent.speed_data['position_fraction'] > 0.0):
                 agent.speed_data['position_fraction'] += agent.speed_data['speed']
 
-            if agent.speed_data['position_fraction'] >= 1.0:
-
-                # Perform stored action to transition to the next cell as soon as cell is free
-                cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
-                    self._check_action_on_agent(agent.speed_data['transition_action_on_cellexit'], agent)
-
-                if all([new_cell_valid, transition_valid, cell_free]) and agent.malfunction_data['malfunction'] == 0:
-                    agent.position = new_position
-                    agent.direction = new_direction
-                    agent.speed_data['position_fraction'] = 0.0
-                elif not transition_valid or not new_cell_valid:
-                    # If the agent cannot move due to an invalid transition, we set its state to not moving
-                    agent.moving = False
+                if agent.speed_data['position_fraction'] >= 1.0:
+                    # Perform stored action to transition to the next cell as soon as cell is free
+                    # Notice that we've already checked new_cell_valid and transition_valid when we stored the action,
+                    # so we only have to check cell_free now!
+                    if agent.speed_data['transition_action_on_cellexit'] in [RailEnvActions.DO_NOTHING,
+                                                                             RailEnvActions.STOP_MOVING]:
+                        agent.speed_data['position_fraction'] = 0.0
+                    else:
+                        cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
+                            agent.speed_data['transition_action_on_cellexit'], agent)
+                        assert cell_free == all([cell_free, new_cell_valid, transition_valid])
+                        if cell_free:
+                            agent.position = new_position
+                            agent.direction = new_direction
+                            agent.speed_data['position_fraction'] = 0.0
 
             if np.equal(agent.position, agent.target).all():
                 self.dones[i_agent] = True
                 agent.moving = False
             else:
-                self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
+                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
 
         # Check for end of episode + add global reward to all rewards!
         if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
             self.dones["__all__"] = True
-            self.rewards_dict = {i: 0 * r + global_reward for i, r in self.rewards_dict.items()}
+            self.rewards_dict = {i: 0 * r + self.global_reward for i, r in self.rewards_dict.items()}
 
         if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps):
             self.dones["__all__"] = True
@@ -481,6 +485,7 @@ class RailEnv(Environment):
         return self._get_observations(), self.rewards_dict, self.dones, info_dict
 
     def _check_action_on_agent(self, action, agent):
+
         # compute number of possible transitions in the current
         # cell used to check for invalid actions
         new_direction, transition_valid = self.check_action(agent, action)
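
Since alpha, beta, epsilon and the penalty values now live on the class instead of
being rebuilt as locals on every step() call, they can be tuned by subclassing. A
minimal sketch, not part of the patch, assuming only that the constants remain plain
class attributes; PatientRailEnv and its values are hypothetical:

    from flatland.envs.rail_env import RailEnv

    class PatientRailEnv(RailEnv):
        # Hypothetical tuning: halve the per-step penalty and restore the
        # pre-patch invalid-action penalty mentioned in the comment above (-2).
        step_penalty = -0.5 * RailEnv.alpha
        invalid_action_penalty = -2

Because step() now reads self.step_penalty and friends, such an override takes
effect without copying any of the step logic.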
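The movement block relies on position_fraction bookkeeping: an action is stored once
at cell entry (position_fraction == 0.0), and the stored transition_action_on_cellexit
is executed on the step where the fraction reaches 1.0. A self-contained sketch of
that accumulation for a hypothetical speed-0.25 agent (pure Python, no flatland
imports; the loop mirrors the `if agent.moving:` branch above):

    fraction, speed = 0.0, 0.25    # stands in for agent.speed_data
    for t in range(1, 9):
        fraction += speed          # agent.speed_data['position_fraction'] += speed
        if fraction >= 1.0:
            # cell exit: the action stored at cell entry is applied here
            print(f"t={t}: execute stored action, fraction -> 0.0")
            fraction = 0.0
        else:
            print(f"t={t}: still inside the cell, fraction={fraction:.2f}")

With speed 0.25 the agent crosses a cell every fourth step (t=4, t=8), which is why a
stored action cannot be changed mid-cell.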
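The cell-entry branch in the new step() encodes three rules: a valid action is stored
as-is; an invalid MOVE_LEFT/MOVE_RIGHT falls back to MOVE_FORWARD when going straight
is valid; anything else stops the agent and charges invalid_action_penalty +
step_penalty * speed + stop_penalty. A distilled sketch of just that decision, where
the hypothetical `check` callable stands in for _check_action_on_agent's
(new_cell_valid, transition_valid) pair and is not the flatland API:

    def stored_action(action, check):
        """Return the action to store on cell entry, or None if the agent must stop."""
        if action in ('DO_NOTHING', 'STOP_MOVING'):
            return action              # stored as-is; cell exit only resets the fraction
        if all(check(action)):
            return action              # valid action and transition: store it
        if action in ('MOVE_LEFT', 'MOVE_RIGHT') and all(check('MOVE_FORWARD')):
            return 'MOVE_FORWARD'      # invalid turn: try to keep moving forward
        return None                    # invalid: agent stops, penalties apply

    # Example: an invalid left turn where going straight is possible
    assert stored_action('MOVE_LEFT', lambda a: (True, a == 'MOVE_FORWARD')) == 'MOVE_FORWARD'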