Skip to content
Snippets Groups Projects
Commit d36631af authored by Erik Nygren's avatar Erik Nygren
Browse files

removed reward function bug which led to agent choosing invalid actions

parent 5d1de868
No related branches found
No related tags found
No related merge requests found
...@@ -165,8 +165,8 @@ class RailEnv(Environment): ...@@ -165,8 +165,8 @@ class RailEnv(Environment):
self.restart_agents() self.restart_agents()
for iAgent in range(self.get_num_agents()): for i_agemt in range(self.get_num_agents()):
agent = self.agents[iAgent] agent = self.agents[i_agemt]
agent.speed_data['position_fraction'] = 0.0 agent.speed_data['position_fraction'] = 0.0
self.num_resets += 1 self.num_resets += 1
...@@ -195,31 +195,31 @@ class RailEnv(Environment): ...@@ -195,31 +195,31 @@ class RailEnv(Environment):
# Reset the step rewards # Reset the step rewards
self.rewards_dict = dict() self.rewards_dict = dict()
for iAgent in range(self.get_num_agents()): for i_agemt in range(self.get_num_agents()):
self.rewards_dict[iAgent] = 0 self.rewards_dict[i_agemt] = 0
if self.dones["__all__"]: if self.dones["__all__"]:
self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()} self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()}
return self._get_observations(), self.rewards_dict, self.dones, {} return self._get_observations(), self.rewards_dict, self.dones, {}
# for i in range(len(self.agents_handles)): # for i in range(len(self.agents_handles)):
for iAgent in range(self.get_num_agents()): for i_agemt in range(self.get_num_agents()):
agent = self.agents[iAgent] agent = self.agents[i_agemt]
agent.old_direction = agent.direction agent.old_direction = agent.direction
agent.old_position = agent.position agent.old_position = agent.position
if self.dones[iAgent]: # this agent has already completed... if self.dones[i_agemt]: # this agent has already completed...
continue continue
if iAgent not in action_dict: # no action has been supplied for this agent if i_agemt not in action_dict: # no action has been supplied for this agent
action_dict[iAgent] = RailEnvActions.DO_NOTHING action_dict[i_agemt] = RailEnvActions.DO_NOTHING
if action_dict[iAgent] < 0 or action_dict[iAgent] > len(RailEnvActions): if action_dict[i_agemt] < 0 or action_dict[i_agemt] > len(RailEnvActions):
print('ERROR: illegal action=', action_dict[iAgent], print('ERROR: illegal action=', action_dict[i_agemt],
'for agent with index=', iAgent, 'for agent with index=', i_agemt,
'"DO NOTHING" will be executed instead') '"DO NOTHING" will be executed instead')
action_dict[iAgent] = RailEnvActions.DO_NOTHING action_dict[i_agemt] = RailEnvActions.DO_NOTHING
action = action_dict[iAgent] action = action_dict[i_agemt]
if action == RailEnvActions.DO_NOTHING and agent.moving: if action == RailEnvActions.DO_NOTHING and agent.moving:
# Keep moving # Keep moving
...@@ -228,12 +228,12 @@ class RailEnv(Environment): ...@@ -228,12 +228,12 @@ class RailEnv(Environment):
if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.: if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.:
# Only allow halting an agent on entering new cells. # Only allow halting an agent on entering new cells.
agent.moving = False agent.moving = False
self.rewards_dict[iAgent] += stop_penalty self.rewards_dict[i_agemt] += stop_penalty
if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING): if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
# Allow agent to start with any forward or direction action # Allow agent to start with any forward or direction action
agent.moving = True agent.moving = True
self.rewards_dict[iAgent] += start_penalty self.rewards_dict[i_agemt] += start_penalty
# Now perform a movement. # Now perform a movement.
# If the agent is in an initial position within a new cell (agent.speed_data['position_fraction']<eps) # If the agent is in an initial position within a new cell (agent.speed_data['position_fraction']<eps)
...@@ -269,18 +269,18 @@ class RailEnv(Environment): ...@@ -269,18 +269,18 @@ class RailEnv(Environment):
else: else:
# TODO: an invalid action was chosen after entering the cell. The agent cannot move. # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
self.rewards_dict[iAgent] += invalid_action_penalty self.rewards_dict[i_agemt] += invalid_action_penalty
self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed'] self.rewards_dict[i_agemt] += step_penalty * agent.speed_data['speed']
agent.moving = False agent.moving = False
self.rewards_dict[iAgent] += stop_penalty self.rewards_dict[i_agemt] += stop_penalty
continue continue
else: else:
# TODO: an invalid action was chosen after entering the cell. The agent cannot move. # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
self.rewards_dict[iAgent] += invalid_action_penalty self.rewards_dict[i_agemt] += invalid_action_penalty
self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed'] self.rewards_dict[i_agemt] += step_penalty * agent.speed_data['speed']
agent.moving = False agent.moving = False
self.rewards_dict[iAgent] += stop_penalty self.rewards_dict[i_agemt] += stop_penalty
continue continue
...@@ -302,9 +302,9 @@ class RailEnv(Environment): ...@@ -302,9 +302,9 @@ class RailEnv(Environment):
agent.speed_data['position_fraction'] = 0.0 agent.speed_data['position_fraction'] = 0.0
if np.equal(agent.position, agent.target).all(): if np.equal(agent.position, agent.target).all():
self.dones[iAgent] = True self.dones[i_agemt] = True
else: else:
self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed'] self.rewards_dict[i_agemt] += step_penalty * agent.speed_data['speed']
# Check for end of episode + add global reward to all rewards! # Check for end of episode + add global reward to all rewards!
if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]): if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment