From 96f670b38e8d8b4ce039960755102e7e6e26f82c Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Thu, 4 Jul 2019 16:29:01 -0400
Subject: [PATCH] Major update to shortest path predictor major update and
 bugfix for collision detection

---
 flatland/core/transitions.py            |  2 +-
 flatland/envs/observations.py           | 60 ++++++-------------------
 flatland/envs/predictions.py            | 14 ++----
 tests/test_flatland_envs_predictions.py | 28 ------------
 4 files changed, 17 insertions(+), 87 deletions(-)

diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py
index 29b57c4..5049c23 100644
--- a/flatland/core/transitions.py
+++ b/flatland/core/transitions.py
@@ -12,7 +12,7 @@ class Transitions:
 
     Generic class that implements checks to control whether a
     certain transition is allowed (agent facing a direction
-    `orientation' and moving into direction `direction')
+    `orientation' and moving into direction `orientation')
     """
 
     def get_type(self):
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 8222115..a490385 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -283,7 +283,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             if possible_transitions[branch_direction]:
                 new_cell = self._new_position(agent.position, branch_direction)
                 branch_observation, branch_visited = \
-                    self._explore_branch(handle, new_cell, branch_direction, root_observation, 0, 1)
+                    self._explore_branch(handle, new_cell, branch_direction, root_observation, 1, 1)
                 observation = observation + branch_observation
                 visited = visited.union(branch_visited)
             else:
@@ -351,22 +351,23 @@ class TreeObsForRailEnv(ObservationBuilder):
                     post_step = min(self.max_prediction_depth - 1, tot_dist + 1)
 
                     # Look for opposing paths at distance num_step
-                    if int_position in np.delete(self.predicted_pos[tot_dist], handle):
-                        conflicting_agent = np.where(np.delete(self.predicted_pos[tot_dist], handle) == int_position)
-                        for ca in conflicting_agent:
-                            if direction != self.predicted_dir[tot_dist][ca[0]] and tot_dist < potential_conflict:
+                    if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0):
+                        conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position)
+                        for ca in conflicting_agent[0]:
+
+                            if direction != self.predicted_dir[tot_dist][ca] and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
                     # Look for opposing paths at distance num_step-1
-                    elif int_position in np.delete(self.predicted_pos[pre_step], handle):
+                    elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
                         conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
-                        for ca in conflicting_agent:
-                            if direction != self.predicted_dir[pre_step][ca[0]] and tot_dist < potential_conflict:
+                        for ca in conflicting_agent[0]:
+                            if direction != self.predicted_dir[pre_step][ca] and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
                     # Look for opposing paths at distance num_step+1
-                    elif int_position in np.delete(self.predicted_pos[post_step], handle):
-                        conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position)
-                        for ca in conflicting_agent:
-                            if direction != self.predicted_dir[post_step][ca[0]] and tot_dist < potential_conflict:
+                    elif int_position in np.delete(self.predicted_pos[post_step], handle, 0):
+                        conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
+                        for ca in conflicting_agent[0]:
+                            if direction != self.predicted_dir[post_step][ca] and tot_dist < potential_conflict:
                                 potential_conflict = tot_dist
 
             if position in self.location_has_target and position != agent.target:
@@ -436,41 +437,6 @@ class TreeObsForRailEnv(ObservationBuilder):
         # #############################
         # #############################
         # Modify here to append new / different features for each visited cell!
-        """
-        other_agent_same_direction = \
-            1 if position in self.location_has_agent and self.location_has_agent_direction[position] == direction else 0
-        other_agent_opposite_direction = \
-            1 if position in self.location_has_agent and self.location_has_agent_direction[position] != direction else 0
-
-        if last_isTarget:
-            observation = [0,
-                           other_target_encountered,
-                           other_agent_encountered,
-                           root_observation[3] + num_steps,
-                           0,
-                           other_agent_same_direction,
-                           other_agent_opposite_direction
-                           ]
-
-        elif last_isTerminal:
-            observation = [0,
-                           other_target_encountered,
-                           other_agent_encountered,
-                           np.inf,
-                           np.inf,
-                           other_agent_same_direction,
-                           other_agent_opposite_direction
-                           ]
-        else:
-            observation = [0,
-                           other_target_encountered,
-                           other_agent_encountered,
-                           root_observation[3] + num_steps,
-                           self.distance_map[handle, position[0], position[1], direction],
-                           other_agent_same_direction,
-                           other_agent_opposite_direction
-                           ]
-        """
 
         if last_isTarget:
             observation = [own_target_encountered,
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 654f549..d471596 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -142,7 +142,8 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                     min_dist = np.inf
                     for direction in range(4):
                         if cell_transitions[direction] == 1:
-                            target_dist = distance_map[agent.handle, agent.position[0], agent.position[1], direction]
+                            neighbour_cell = get_new_position(agent.position, direction)
+                            target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
                             if target_dist < min_dist:
                                 min_dist = target_dist
                                 new_direction = direction
@@ -150,21 +151,12 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                 else:
                     raise Exception("No transition possible {}".format(cell_transitions))
 
-                # which action to take for the transition?
-                action = None
-                for _action in [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_LEFT]:
-                    _, _, _new_direction, _new_position, _ = self.env._check_action_on_agent(_action, agent)
-                    if np.array_equal(_new_position, new_position):
-                        action = _action
-                        break
-                assert action is not None
-
                 # update the agent's position and direction
                 agent.position = new_position
                 agent.direction = new_direction
 
                 # prediction is ready
-                prediction[index] = [index, *new_position, new_direction, action]
+                prediction[index] = [index, *new_position, new_direction, 0]
             prediction_dict[agent.handle] = prediction
 
             # cleanup: reset initial position
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index 1685067..c5514bd 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -9,7 +9,6 @@ from flatland.envs.generators import rail_from_GridTransitionMap_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
-from flatland.envs.rail_env import RailEnvActions
 from flatland.utils.rendertools import RenderTool
 
 """Test predictions for `flatland` package."""
@@ -187,7 +186,6 @@ def test_shortest_path_predictor(rendering=False):
     positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0])))
     directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
     time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
-    actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0])))
 
     expected_positions = [
         [5, 6],
@@ -260,35 +258,9 @@ def test_shortest_path_predictor(rendering=False):
         [20.],
     ])
 
-    expected_actions = np.array([
-        [RailEnvActions.DO_NOTHING],  # next [5,6]
-        [RailEnvActions.MOVE_FORWARD],  # next [4,6]
-        [RailEnvActions.MOVE_FORWARD],  # next [3,6]
-        [RailEnvActions.MOVE_RIGHT],  # next [3,7]
-        [RailEnvActions.MOVE_FORWARD],  # next [3,8]
-        [RailEnvActions.MOVE_FORWARD],  # next [3,9]
-        [RailEnvActions.STOP_MOVING],  # at [3,9] == target
-        [RailEnvActions.STOP_MOVING],
-        [RailEnvActions.STOP_MOVING],
-        [RailEnvActions.STOP_MOVING],
-        [RailEnvActions.STOP_MOVING],
-        [RailEnvActions.STOP_MOVING],
-        [RailEnvActions.STOP_MOVING],
-        [RailEnvActions.STOP_MOVING],
-        [RailEnvActions.STOP_MOVING],
-        [RailEnvActions.STOP_MOVING],
-        [RailEnvActions.STOP_MOVING],
-        [RailEnvActions.STOP_MOVING],
-        [RailEnvActions.STOP_MOVING],
-        [RailEnvActions.STOP_MOVING],
-        [RailEnvActions.STOP_MOVING],
-    ])
-
     assert np.array_equal(positions, expected_positions), \
         "positions {}, expected {}".format(positions, expected_positions)
     assert np.array_equal(directions, expected_directions), \
         "directions {}, expected {}".format(directions, expected_directions)
     assert np.array_equal(time_offsets, expected_time_offsets), \
         "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)
-    assert np.array_equal(actions, expected_actions), \
-        "actions {}, expected {}".format(actions, expected_actions)
-- 
GitLab