From 47a23b3c2d8a7c244f857448ccf860d05b23c1e1 Mon Sep 17 00:00:00 2001
From: Giacomo Spigler <spiglerg@gmail.com>
Date: Wed, 19 Jun 2019 19:42:32 +0200
Subject: [PATCH] fixed bad tests bug

---
 flatland/envs/rail_env.py  | 11 +++++------
 tests/__init__.py          |  0
 tests/test_environments.py |  7 ++++---
 3 files changed, 9 insertions(+), 9 deletions(-)
 create mode 100644 tests/__init__.py

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 58df3a14..8cf6d52f 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -191,16 +191,10 @@ class RailEnv(Environment):
         # for i in range(len(self.agents_handles)):
         for iAgent in range(self.get_num_agents()):
             agent = self.agents[iAgent]
-            print(agent.speed_data['speed'])
 
             if self.dones[iAgent]:  # this agent has already completed...
                 continue
 
-            if np.equal(agent.position, agent.target).all():
-                self.dones[iAgent] = True
-            else:
-                self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed']
-
             if iAgent not in action_dict:  # no action has been supplied for this agent
                 action_dict[iAgent] = RailEnvActions.DO_NOTHING
 
@@ -288,6 +282,11 @@ class RailEnv(Environment):
                 agent.position = new_position
                 agent.direction = new_direction
 
+            if np.equal(agent.position, agent.target).all():
+                self.dones[iAgent] = True
+            else:
+                self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed']
+
         # Check for end of episode + add global reward to all rewards!
         if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
             self.dones["__all__"] = True
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/test_environments.py b/tests/test_environments.py
index 11f0acba..aa24467d 100644
--- a/tests/test_environments.py
+++ b/tests/test_environments.py
@@ -3,7 +3,7 @@
 import numpy as np
 
 from flatland.core.transition_map import GridTransitionMap
-from flatland.core.transitions import Grid4Transitions
+from flatland.core.transitions import Grid4Transitions, RailEnvTransitions
 from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.generators import rail_from_GridTransitionMap_generator
@@ -53,7 +53,7 @@ def test_rail_environment_single_agent():
     # | |  |
     # \_/\_/
 
-    transitions = Grid4Transitions([])
+    transitions = RailEnvTransitions()
     vertical_line = cells[1]
     south_symmetrical_switch = cells[6]
     north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)
@@ -107,6 +107,7 @@ def test_rail_environment_single_agent():
             if prev_pos != pos:
                 valid_active_actions_done += 1
 
+
         # After 6 movements on this railway network, the train should be back
         # to its original height on the map.
         assert (initial_pos[0] == agent.position[0])
@@ -121,9 +122,9 @@ def test_rail_environment_single_agent():
                 action = np.random.randint(4)
 
                 _, _, dones, _ = rail_env.step({0: action})
-
                 done = dones['__all__']
 
+test_rail_environment_single_agent()
 
 def test_dead_end():
     transitions = Grid4Transitions([])
-- 
GitLab