From 287f810998de667e11bc68ae501c2f048eb8b972 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Sun, 11 Aug 2019 10:04:12 -0400
Subject: [PATCH] testing that agents don't move in multi-speed setup

---
 tests/test_multi_speed.py | 35 ++++++++++++++++-------------------
 1 file changed, 16 insertions(+), 19 deletions(-)

diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index 46310a2c..8b546871 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -9,11 +9,6 @@ np.random.seed(1)
 # Training on simple small tasks is the best way to get familiar with the environment
 #
 
-env = RailEnv(width=50,
-              height=50,
-              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
-              number_of_agents=5)
-
 
 class RandomAgent:
 
@@ -46,18 +41,19 @@ class RandomAgent:
         return
 
 
-# Initialize the agent with the parameters corresponding to the environment and observation_builder
-agent = RandomAgent(218, 4)
-n_trials = 5
-
-
-# Empty dictionary for all agent action
-action_dict = dict()
-
-
-# Set all the different speeds
-
 def test_multi_speed_init():
+    env = RailEnv(width=50,
+                  height=50,
+                  rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999,
+                                                        seed=0),
+                  number_of_agents=5)
+    # Initialize the agent with the parameters corresponding to the environment and observation_builder
+    agent = RandomAgent(218, 4)
+
+    # Empty dictionary for all agent action
+    action_dict = dict()
+
+    # Set all the different speeds
     # Reset environment and get initial observations for all agents
     env.reset()
     # Here you can also further enhance the provided observation by means of normalization
@@ -66,7 +62,7 @@ def test_multi_speed_init():
     for i_agent in range(env.get_num_agents()):
         env.agents[i_agent].speed_data['speed'] = 1. / (i_agent + 1)
         old_pos.append(env.agents[i_agent].position)
-    score = 0
+
     # Run episode
     for step in range(100):
 
@@ -74,14 +70,15 @@ def test_multi_speed_init():
         for a in range(env.get_num_agents()):
             action = agent.act(0)
             action_dict.update({a: action})
-            # Check that agent did not move inbetween its speed updates
+
+            # Check that agent did not move in between its speed updates
             assert old_pos[a] == env.agents[a].position
 
         # Environment step which returns the observations for all agents, their corresponding
         # reward and whether their are done
         _, _, _, _ = env.step(action_dict)
 
-        # Update old position
+        # Update old position whenever an agent was allowed to move
         for i_agent in range(env.get_num_agents()):
             if (step + 1) % (i_agent + 1) == 0:
                 print(step, i_agent, env.agents[a].position)
-- 
GitLab