From d36631af89fddc0829379a956b11b9464c678b0f Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Fri, 5 Jul 2019 15:35:41 -0400
Subject: [PATCH] removed reward function bug which led to agent choosing
 invalid actions

---
 flatland/envs/rail_env.py | 48 +++++++++++++++++++--------------------
 1 file changed, 24 insertions(+), 24 deletions(-)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index abc8a738..7952f29b 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -165,8 +165,8 @@ class RailEnv(Environment):
 
         self.restart_agents()
 
-        for iAgent in range(self.get_num_agents()):
-            agent = self.agents[iAgent]
+        for i_agent in range(self.get_num_agents()):
+            agent = self.agents[i_agent]
             agent.speed_data['position_fraction'] = 0.0
 
         self.num_resets += 1
@@ -195,31 +195,31 @@ class RailEnv(Environment):
 
         # Reset the step rewards
         self.rewards_dict = dict()
-        for iAgent in range(self.get_num_agents()):
-            self.rewards_dict[iAgent] = 0
+        for i_agent in range(self.get_num_agents()):
+            self.rewards_dict[i_agent] = 0
 
         if self.dones["__all__"]:
             self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()}
             return self._get_observations(), self.rewards_dict, self.dones, {}
 
         # for i in range(len(self.agents_handles)):
-        for iAgent in range(self.get_num_agents()):
-            agent = self.agents[iAgent]
+        for i_agent in range(self.get_num_agents()):
+            agent = self.agents[i_agent]
             agent.old_direction = agent.direction
             agent.old_position = agent.position
-            if self.dones[iAgent]:  # this agent has already completed...
+            if self.dones[i_agent]:  # this agent has already completed...
                 continue
 
-            if iAgent not in action_dict:  # no action has been supplied for this agent
-                action_dict[iAgent] = RailEnvActions.DO_NOTHING
+            if i_agent not in action_dict:  # no action has been supplied for this agent
+                action_dict[i_agent] = RailEnvActions.DO_NOTHING
 
-            if action_dict[iAgent] < 0 or action_dict[iAgent] > len(RailEnvActions):
-                print('ERROR: illegal action=', action_dict[iAgent],
-                      'for agent with index=', iAgent,
+            if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions):
+                print('ERROR: illegal action=', action_dict[i_agent],
+                      'for agent with index=', i_agent,
                       '"DO NOTHING" will be executed instead')
-                action_dict[iAgent] = RailEnvActions.DO_NOTHING
+                action_dict[i_agent] = RailEnvActions.DO_NOTHING
 
-            action = action_dict[iAgent]
+            action = action_dict[i_agent]
 
             if action == RailEnvActions.DO_NOTHING and agent.moving:
                 # Keep moving
@@ -228,12 +228,12 @@ class RailEnv(Environment):
             if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.:
                 # Only allow halting an agent on entering new cells.
                 agent.moving = False
-                self.rewards_dict[iAgent] += stop_penalty
+                self.rewards_dict[i_agent] += stop_penalty
 
             if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
                 # Allow agent to start with any forward or direction action
                 agent.moving = True
-                self.rewards_dict[iAgent] += start_penalty
+                self.rewards_dict[i_agent] += start_penalty
 
             # Now perform a movement.
             # If the agent is in an initial position within a new cell (agent.speed_data['position_fraction']<eps)
@@ -269,18 +269,18 @@ class RailEnv(Environment):
                     else:
                         # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
-                        self.rewards_dict[iAgent] += invalid_action_penalty
-                        self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed']
+                        self.rewards_dict[i_agent] += invalid_action_penalty
+                        self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
                         agent.moving = False
-                        self.rewards_dict[iAgent] += stop_penalty
+                        self.rewards_dict[i_agent] += stop_penalty
 
                         continue
                 else:
                     # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
-                    self.rewards_dict[iAgent] += invalid_action_penalty
-                    self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed']
+                    self.rewards_dict[i_agent] += invalid_action_penalty
+                    self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
                     agent.moving = False
-                    self.rewards_dict[iAgent] += stop_penalty
+                    self.rewards_dict[i_agent] += stop_penalty
 
                     continue
@@ -302,9 +302,9 @@ class RailEnv(Environment):
                 agent.speed_data['position_fraction'] = 0.0
 
             if np.equal(agent.position, agent.target).all():
-                self.dones[iAgent] = True
+                self.dones[i_agent] = True
             else:
-                self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed']
+                self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
 
         # Check for end of episode + add global reward to all rewards!
         if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
-- 
GitLab
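
The invalid-action branch touched in the hunks above charges three penalties in a single step: invalid_action_penalty, the per-step penalty scaled by the agent's fractional speed, and stop_penalty, after which the agent is halted before the loop continues to the next agent. Below is a minimal standalone sketch of that accounting, outside the patch; the concrete penalty values and the simplified Agent class are assumptions for illustration, not flatland's actual defaults or its EnvAgent type.

# Illustrative sketch of the invalid-action penalty accounting above.
# The penalty values are assumed for this example; flatland's RailEnv
# defines its own invalid_action_penalty, step_penalty and stop_penalty.
invalid_action_penalty = 0.0
step_penalty = -1.0
stop_penalty = 0.0

class Agent:
    """Hypothetical stand-in for flatland's EnvAgent (illustration only)."""
    def __init__(self, speed):
        self.moving = True
        self.speed_data = {'speed': speed, 'position_fraction': 0.0}

def penalize_invalid_action(rewards_dict, i_agent, agent):
    # Mirror the branch: charge all three penalties, then halt the agent.
    rewards_dict[i_agent] += invalid_action_penalty
    rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
    agent.moving = False
    rewards_dict[i_agent] += stop_penalty

rewards_dict = {0: 0.0}
penalize_invalid_action(rewards_dict, 0, Agent(speed=0.5))
print(rewards_dict)  # {0: -0.5} with the assumed values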