Skip to content
Snippets Groups Projects
Commit d36631af authored by Erik Nygren's avatar Erik Nygren
Browse files

removed reward function bug which led to agent choosing invalid actions

parent 5d1de868
No related branches found
No related tags found
No related merge requests found
......@@ -165,8 +165,8 @@ class RailEnv(Environment):
self.restart_agents()
for iAgent in range(self.get_num_agents()):
agent = self.agents[iAgent]
for i_agemt in range(self.get_num_agents()):
agent = self.agents[i_agemt]
agent.speed_data['position_fraction'] = 0.0
self.num_resets += 1
......@@ -195,31 +195,31 @@ class RailEnv(Environment):
# Reset the step rewards
self.rewards_dict = dict()
for iAgent in range(self.get_num_agents()):
self.rewards_dict[iAgent] = 0
for i_agemt in range(self.get_num_agents()):
self.rewards_dict[i_agemt] = 0
if self.dones["__all__"]:
self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()}
return self._get_observations(), self.rewards_dict, self.dones, {}
# for i in range(len(self.agents_handles)):
for iAgent in range(self.get_num_agents()):
agent = self.agents[iAgent]
for i_agemt in range(self.get_num_agents()):
agent = self.agents[i_agemt]
agent.old_direction = agent.direction
agent.old_position = agent.position
if self.dones[iAgent]: # this agent has already completed...
if self.dones[i_agemt]: # this agent has already completed...
continue
if iAgent not in action_dict: # no action has been supplied for this agent
action_dict[iAgent] = RailEnvActions.DO_NOTHING
if i_agemt not in action_dict: # no action has been supplied for this agent
action_dict[i_agemt] = RailEnvActions.DO_NOTHING
if action_dict[iAgent] < 0 or action_dict[iAgent] > len(RailEnvActions):
print('ERROR: illegal action=', action_dict[iAgent],
'for agent with index=', iAgent,
if action_dict[i_agemt] < 0 or action_dict[i_agemt] > len(RailEnvActions):
print('ERROR: illegal action=', action_dict[i_agemt],
'for agent with index=', i_agemt,
'"DO NOTHING" will be executed instead')
action_dict[iAgent] = RailEnvActions.DO_NOTHING
action_dict[i_agemt] = RailEnvActions.DO_NOTHING
action = action_dict[iAgent]
action = action_dict[i_agemt]
if action == RailEnvActions.DO_NOTHING and agent.moving:
# Keep moving
......@@ -228,12 +228,12 @@ class RailEnv(Environment):
if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.:
# Only allow halting an agent on entering new cells.
agent.moving = False
self.rewards_dict[iAgent] += stop_penalty
self.rewards_dict[i_agemt] += stop_penalty
if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
# Allow agent to start with any forward or direction action
agent.moving = True
self.rewards_dict[iAgent] += start_penalty
self.rewards_dict[i_agemt] += start_penalty
# Now perform a movement.
# If the agent is in an initial position within a new cell (agent.speed_data['position_fraction']<eps)
......@@ -269,18 +269,18 @@ class RailEnv(Environment):
else:
# TODO: an invalid action was chosen after entering the cell. The agent cannot move.
self.rewards_dict[iAgent] += invalid_action_penalty
self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed']
self.rewards_dict[i_agemt] += invalid_action_penalty
self.rewards_dict[i_agemt] += step_penalty * agent.speed_data['speed']
agent.moving = False
self.rewards_dict[iAgent] += stop_penalty
self.rewards_dict[i_agemt] += stop_penalty
continue
else:
# TODO: an invalid action was chosen after entering the cell. The agent cannot move.
self.rewards_dict[iAgent] += invalid_action_penalty
self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed']
self.rewards_dict[i_agemt] += invalid_action_penalty
self.rewards_dict[i_agemt] += step_penalty * agent.speed_data['speed']
agent.moving = False
self.rewards_dict[iAgent] += stop_penalty
self.rewards_dict[i_agemt] += stop_penalty
continue
......@@ -302,9 +302,9 @@ class RailEnv(Environment):
agent.speed_data['position_fraction'] = 0.0
if np.equal(agent.position, agent.target).all():
self.dones[iAgent] = True
self.dones[i_agemt] = True
else:
self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed']
self.rewards_dict[i_agemt] += step_penalty * agent.speed_data['speed']
# Check for end of episode + add global reward to all rewards!
if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
......
0% Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment