From d36631af89fddc0829379a956b11b9464c678b0f Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Fri, 5 Jul 2019 15:35:41 -0400
Subject: [PATCH] removed reward function bug which led to agent choosing
 invalid actions

---
 flatland/envs/rail_env.py | 48 +++++++++++++++++++--------------------
 1 file changed, 24 insertions(+), 24 deletions(-)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index abc8a738..7952f29b 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -165,8 +165,8 @@ class RailEnv(Environment):
 
         self.restart_agents()
 
-        for iAgent in range(self.get_num_agents()):
-            agent = self.agents[iAgent]
+        for i_agent in range(self.get_num_agents()):
+            agent = self.agents[i_agent]
             agent.speed_data['position_fraction'] = 0.0
 
         self.num_resets += 1
@@ -195,31 +195,31 @@ class RailEnv(Environment):
 
         # Reset the step rewards
         self.rewards_dict = dict()
-        for iAgent in range(self.get_num_agents()):
-            self.rewards_dict[iAgent] = 0
+        for i_agent in range(self.get_num_agents()):
+            self.rewards_dict[i_agent] = 0
 
         if self.dones["__all__"]:
             self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()}
             return self._get_observations(), self.rewards_dict, self.dones, {}
 
         # for i in range(len(self.agents_handles)):
-        for iAgent in range(self.get_num_agents()):
-            agent = self.agents[iAgent]
+        for i_agent in range(self.get_num_agents()):
+            agent = self.agents[i_agent]
             agent.old_direction = agent.direction
             agent.old_position = agent.position
-            if self.dones[iAgent]:  # this agent has already completed...
+            if self.dones[i_agent]:  # this agent has already completed...
                 continue
 
-            if iAgent not in action_dict:  # no action has been supplied for this agent
-                action_dict[iAgent] = RailEnvActions.DO_NOTHING
+            if i_agent not in action_dict:  # no action has been supplied for this agent
+                action_dict[i_agent] = RailEnvActions.DO_NOTHING
 
-            if action_dict[iAgent] < 0 or action_dict[iAgent] > len(RailEnvActions):
-                print('ERROR: illegal action=', action_dict[iAgent],
-                      'for agent with index=', iAgent,
+            if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions):
+                print('ERROR: illegal action=', action_dict[i_agent],
+                      'for agent with index=', i_agent,
                       '"DO NOTHING" will be executed instead')
-                action_dict[iAgent] = RailEnvActions.DO_NOTHING
+                action_dict[i_agent] = RailEnvActions.DO_NOTHING
 
-            action = action_dict[iAgent]
+            action = action_dict[i_agent]
 
             if action == RailEnvActions.DO_NOTHING and agent.moving:
                 # Keep moving
@@ -228,12 +228,12 @@ class RailEnv(Environment):
             if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.:
                 # Only allow halting an agent on entering new cells.
                 agent.moving = False
-                self.rewards_dict[iAgent] += stop_penalty
+                self.rewards_dict[i_agent] += stop_penalty
 
             if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
                 # Allow agent to start with any forward or direction action
                 agent.moving = True
-                self.rewards_dict[iAgent] += start_penalty
+                self.rewards_dict[i_agent] += start_penalty
 
             # Now perform a movement.
             # If the agent is in an initial position within a new cell (agent.speed_data['position_fraction']<eps)
@@ -269,18 +269,18 @@ class RailEnv(Environment):
 
                             else:
                                 # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
-                                self.rewards_dict[iAgent] += invalid_action_penalty
-                                self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed']
+                                self.rewards_dict[i_agent] += invalid_action_penalty
+                                self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
                                 agent.moving = False
-                                self.rewards_dict[iAgent] += stop_penalty
+                                self.rewards_dict[i_agent] += stop_penalty
 
                                 continue
                         else:
                             # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
-                            self.rewards_dict[iAgent] += invalid_action_penalty
-                            self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed']
+                            self.rewards_dict[i_agent] += invalid_action_penalty
+                            self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
                             agent.moving = False
-                            self.rewards_dict[iAgent] += stop_penalty
+                            self.rewards_dict[i_agent] += stop_penalty
 
                             continue
 
@@ -302,9 +302,9 @@ class RailEnv(Environment):
                     agent.speed_data['position_fraction'] = 0.0
 
             if np.equal(agent.position, agent.target).all():
-                self.dones[iAgent] = True
+                self.dones[i_agent] = True
             else:
-                self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed']
 
         # Check for end of episode + add global reward to all rewards!
         if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
-- 
GitLab