From 870fdf43b49a9ad71490acb9703b5a128b4dbc96 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Mon, 17 Jun 2019 17:35:34 +0200
Subject: [PATCH] unit test for env_utils, cleanup predictions

---
 flatland/envs/env_utils.py | 63 +++++++++++++++++++++++---------------
 tests/test_env_utils.py    | 14 ++++++++-
 2 files changed, 51 insertions(+), 26 deletions(-)

diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py
index 626ee5f0..cc4a0015 100644
--- a/flatland/envs/env_utils.py
+++ b/flatland/envs/env_utils.py
@@ -10,7 +10,7 @@ import numpy as np
 from flatland.core.transitions import Grid4TransitionsEnum
 
 
-def get_direction(pos1, pos2):
+def get_direction(pos1, pos2) -> Grid4TransitionsEnum:
     """
     Assumes pos1 and pos2 are adjacent location on grid.
     Returns direction (int) that can be used with transitions.
@@ -25,7 +25,7 @@ def get_direction(pos1, pos2):
         return 1
     if diff_1 < 0:
         return 3
-    return 0
+    raise Exception("Could not determine direction {}->{}".format(pos1, pos2))
 
 
 def mirror(dir):
@@ -71,33 +71,46 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p
     return rail_trans.is_valid(new_trans)
 
 
-def position_to_coordinate(depth, position):
-    """
-         [ (0,0) (0,1) ..  (0,w)
-           (1,0) (1,1)     (1,w)
+def position_to_coordinate(depth, positions):
+    """Converts coordinates to positions:
+         [ (0,0) (0,1) ..  (0,w-1)
+           (1,0) (1,1)     (1,w-1)
            ...
-           (d,0) (d,1)     (d,w) ]
+           (d-1,0) (d-1,1)     (d-1,w-1)
+          ]
 
          -->
 
-         [ 0      1    ..   w
-           w+1    w+2  ..   2w
+         [ 0      d    ..  (w-1)*d
+           1      d+1
            ...
-           d*w+1  d*w+
+           d-1    2d-1     w*d-1
+         ]
 
     :param depth:
-    :param position:
+    :param positions:
     :return:
     """
     coords = ()
-    for p in position:
+    for p in positions:
         coords = coords + ((int(p) % depth, int(p) // depth),)  # changed x_dim to y_dim
     return coords
 
 
 def coordinate_to_position(depth, coords):
     """
-    Helper function to
+    Converts positions to coordinates:
+         [ 0      d    ..  (w-1)*d
+           1      d+1
+           ...
+           d-1    2d-1     w*d-1
+         ]
+         -->
+         [ (0,0) (0,1) ..  (0,w-1)
+           (1,0) (1,1)     (1,w-1)
+           ...
+           (d-1,0) (d-1,1)     (d-1,w-1)
+          ]
 
     :param depth:
     :param coords:
@@ -111,6 +124,18 @@ def coordinate_to_position(depth, coords):
     return position
 
 
+def get_new_position(position, movement):
+    """ Utility function that converts a compass movement over a 2D grid to new positions (r, c). """
+    if movement == Grid4TransitionsEnum.NORTH:
+        return (position[0] - 1, position[1])
+    elif movement == Grid4TransitionsEnum.EAST:
+        return (position[0], position[1] + 1)
+    elif movement == Grid4TransitionsEnum.SOUTH:
+        return (position[0] + 1, position[1])
+    elif movement == Grid4TransitionsEnum.WEST:
+        return (position[0], position[1] - 1)
+
+
 class AStarNode():
     """A node class for A* Pathfinding"""
 
@@ -266,18 +291,6 @@ def distance_on_rail(pos1, pos2):
     return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])
 
 
-def get_new_position(position, movement):
-    """ Utility function that converts a compass movement over a 2D grid to new positions (r, c). """
-    if movement == Grid4TransitionsEnum.NORTH:
-        return (position[0] - 1, position[1])
-    elif movement == Grid4TransitionsEnum.EAST:
-        return (position[0], position[1] + 1)
-    elif movement == Grid4TransitionsEnum.SOUTH:
-        return (position[0] + 1, position[1])
-    elif movement == Grid4TransitionsEnum.WEST:
-        return (position[0], position[1] - 1)
-
-
 def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
     """
     Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).
diff --git a/tests/test_env_utils.py b/tests/test_env_utils.py
index b6764f59..467051ed 100644
--- a/tests/test_env_utils.py
+++ b/tests/test_env_utils.py
@@ -1,6 +1,8 @@
 import numpy as np
+import pytest
 
-from flatland.envs.env_utils import position_to_coordinate, coordinate_to_position
+from flatland.core.transitions import Grid4TransitionsEnum
+from flatland.envs.env_utils import position_to_coordinate, coordinate_to_position, get_direction
 
 depth_to_test = 5
 positions_to_test = [0, 5, 1, 6, 20, 30]
@@ -19,3 +21,13 @@ def test_coordinate_to_position():
     expected_positions = positions_to_test
     assert np.array_equal(actual_positions, expected_positions), \
         "converted positions {}, expected {}".format(actual_positions, expected_positions)
+
+
+def test_get_direction():
+    assert get_direction((0,0),(0,1)) == Grid4TransitionsEnum.EAST
+    assert get_direction((0,0),(0,2)) == Grid4TransitionsEnum.EAST
+    assert get_direction((0,0),(1,0)) == Grid4TransitionsEnum.SOUTH
+    assert get_direction((1,0),(0,0)) == Grid4TransitionsEnum.NORTH
+    assert get_direction((1,0),(0,0)) == Grid4TransitionsEnum.NORTH
+    with pytest.raises(Exception,match="Could not determine direction"):
+        get_direction((0,0),(0,0)) == Grid4TransitionsEnum.NORTH
-- 
GitLab