From 75dd03bb4073e38403289e425b306d8d1b7ae7cc Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Wed, 12 Jun 2019 18:53:36 +0200
Subject: [PATCH] #64 bugfix prediction should respect target

---
 flatland/envs/predictions.py         | 12 ++++--
 flatland/utils/editor.py             |  3 +-
 tests/test_env_prediction_builder.py | 57 ++++++++--------------------
 3 files changed, 26 insertions(+), 46 deletions(-)

diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 3338e68..31c279f 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -45,10 +45,16 @@ class DummyPredictorForRailEnv(PredictionBuilder):
             action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]
             _agent_initial_position = agent.position
             _agent_initial_direction = agent.direction
-            prediction = np.zeros(shape=(self.max_depth, 5))
+            prediction = np.zeros(shape=(self.max_depth + 1, 5))
             prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0]
-            for index in range(1, self.max_depth):
+            for index in range(1, self.max_depth + 1):
                 action_done = False
+                # if we're at the target, stop moving...
+                if agent.position == agent.target:
+                    prediction[index] = [index, agent.target[0], agent.target[1], agent.direction,
+                                         RailEnvActions.STOP_MOVING]
+
+                    continue
                 for action in action_priorities:
                     cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
                         self.env._check_action_on_agent(action, agent)
@@ -61,7 +67,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
                         action_done = True
                         break
                 if not action_done:
-                    print("Cannot move further.")
+                    raise Exception("Cannot move further. Something is wrong")
             prediction_dict[agent.handle] = prediction
             agent.position = _agent_initial_position
             agent.direction = _agent_initial_direction
diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py
index d4e5c38..81565d6 100644
--- a/flatland/utils/editor.py
+++ b/flatland/utils/editor.py
@@ -323,7 +323,8 @@ class Controller(object):
     def restartAgents(self, event):
         self.log("Restart Agents - nAgents:", self.view.wRegenNAgents.value)
         if self.model.init_agents_static is not None:
-            self.model.env.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in self.model.init_agents_static]
+            self.model.env.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in
+                                            self.model.init_agents_static]
             self.model.env.agents = None
             self.model.init_agents_static = None
             self.player = None
diff --git a/tests/test_env_prediction_builder.py b/tests/test_env_prediction_builder.py
index 35a6a27..5b0e830 100644
--- a/tests/test_env_prediction_builder.py
+++ b/tests/test_env_prediction_builder.py
@@ -65,7 +65,7 @@ def test_predictions():
                   rail_generator=rail_from_GridTransitionMap_generator(rail),
                   number_of_agents=1,
                   obs_builder_object=GlobalObsForRailEnv(),
-                  prediction_builder_object=DummyPredictorForRailEnv(max_depth=20)
+                  prediction_builder_object=DummyPredictorForRailEnv(max_depth=10)
                   )
 
     env.reset()
@@ -73,6 +73,7 @@ def test_predictions():
     # set initial position and direction for testing...
     env.agents[0].position = (5, 6)
     env.agents[0].direction = 0
+    env.agents[0].target = (3., 0.)
 
     predictions = env.predict()
     positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0])))
@@ -89,18 +90,11 @@ def test_predictions():
                                    [3., 3.],
                                    [3., 2.],
                                    [3., 1.],
+                                   # at target (3,0): stay in this position from here on
                                    [3., 0.],
-                                   [3., 1.],
-                                   [3., 2.],
-                                   [3., 3.],
-                                   [3., 4.],
-                                   [3., 5.],
-                                   [3., 6.],
-                                   [3., 7.],
-                                   [3., 8.],
-                                   [3., 9.],
-                                   [3., 8.],
-                                   [3., 7.]])
+                                   [3., 0.],
+                                   [3., 0.],
+                                   ])
     expected_directions = np.array([[0.],
                                     [0.],
                                     [0.],
@@ -109,18 +103,11 @@ def test_predictions():
                                     [3.],
                                     [3.],
                                     [3.],
+                                    # at target (3,0): stay in this position from here on
                                     [3.],
-                                    [1.],
-                                    [1.],
-                                    [1.],
-                                    [1.],
-                                    [1.],
-                                    [1.],
-                                    [1.],
-                                    [1.],
-                                    [1.],
                                     [3.],
-                                    [3.]])
+                                    [3.]
+                                    ])
     expected_time_offsets = np.array([[0.],
                                       [1.],
                                       [2.],
@@ -132,15 +119,7 @@ def test_predictions():
                                       [8.],
                                       [9.],
                                       [10.],
-                                      [11.],
-                                      [12.],
-                                      [13.],
-                                      [14.],
-                                      [15.],
-                                      [16.],
-                                      [17.],
-                                      [18.],
-                                      [19.]])
+                                      ])
     expected_actions = np.array([[0.],
                                  [2.],
                                  [2.],
@@ -149,18 +128,12 @@ def test_predictions():
                                  [2.],
                                  [2.],
                                  [2.],
+                                 # reaching target by straight
                                  [2.],
-                                 [2.],
-                                 [2.],
-                                 [2.],
-                                 [2.],
-                                 [2.],
-                                 [2.],
-                                 [2.],
-                                 [2.],
-                                 [2.],
-                                 [2.],
-                                 [2.]])
+                                 # at target: stopped moving
+                                 [4.],
+                                 [4.],
+                                 ])
     assert np.array_equal(positions, expected_positions)
     assert np.array_equal(directions, expected_directions)
     assert np.array_equal(time_offsets, expected_time_offsets)
-- 
GitLab