From 870fdf43b49a9ad71490acb9703b5a128b4dbc96 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Mon, 17 Jun 2019 17:35:34 +0200 Subject: [PATCH] unit test for env_utils, cleanup predictions --- flatland/envs/env_utils.py | 63 +++++++++++++++++++++++--------------- tests/test_env_utils.py | 14 ++++++++- 2 files changed, 51 insertions(+), 26 deletions(-) diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py index 626ee5f0..cc4a0015 100644 --- a/flatland/envs/env_utils.py +++ b/flatland/envs/env_utils.py @@ -10,7 +10,7 @@ import numpy as np from flatland.core.transitions import Grid4TransitionsEnum -def get_direction(pos1, pos2): +def get_direction(pos1, pos2) -> Grid4TransitionsEnum: """ Assumes pos1 and pos2 are adjacent location on grid. Returns direction (int) that can be used with transitions. @@ -25,7 +25,7 @@ def get_direction(pos1, pos2): return 1 if diff_1 < 0: return 3 - return 0 + raise Exception("Could not determine direction {}->{}".format(pos1, pos2)) def mirror(dir): @@ -71,33 +71,46 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p return rail_trans.is_valid(new_trans) -def position_to_coordinate(depth, position): - """ - [ (0,0) (0,1) .. (0,w) - (1,0) (1,1) (1,w) +def position_to_coordinate(depth, positions): + """Converts coordinates to positions: + [ (0,0) (0,1) .. (0,w-1) + (1,0) (1,1) (1,w-1) ... - (d,0) (d,1) (d,w) ] + (d-1,0) (d-1,1) (d-1,w-1) + ] --> - [ 0 1 .. w - w+1 w+2 .. 2w + [ 0 d .. (w-1)*d + 1 d+1 ... - d*w+1 d*w+ + d-1 2d-1 w*d-1 + ] :param depth: - :param position: + :param positions: :return: """ coords = () - for p in position: + for p in positions: coords = coords + ((int(p) % depth, int(p) // depth),) # changed x_dim to y_dim return coords def coordinate_to_position(depth, coords): """ - Helper function to + Converts positions to coordinates: + [ 0 d .. (w-1)*d + 1 d+1 + ... + d-1 2d-1 w*d-1 + ] + --> + [ (0,0) (0,1) .. (0,w-1) + (1,0) (1,1) (1,w-1) + ... + (d-1,0) (d-1,1) (d-1,w-1) + ] :param depth: :param coords: @@ -111,6 +124,18 @@ def coordinate_to_position(depth, coords): return position +def get_new_position(position, movement): + """ Utility function that converts a compass movement over a 2D grid to new positions (r, c). """ + if movement == Grid4TransitionsEnum.NORTH: + return (position[0] - 1, position[1]) + elif movement == Grid4TransitionsEnum.EAST: + return (position[0], position[1] + 1) + elif movement == Grid4TransitionsEnum.SOUTH: + return (position[0] + 1, position[1]) + elif movement == Grid4TransitionsEnum.WEST: + return (position[0], position[1] - 1) + + class AStarNode(): """A node class for A* Pathfinding""" @@ -266,18 +291,6 @@ def distance_on_rail(pos1, pos2): return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1]) -def get_new_position(position, movement): - """ Utility function that converts a compass movement over a 2D grid to new positions (r, c). """ - if movement == Grid4TransitionsEnum.NORTH: - return (position[0] - 1, position[1]) - elif movement == Grid4TransitionsEnum.EAST: - return (position[0], position[1] + 1) - elif movement == Grid4TransitionsEnum.SOUTH: - return (position[0] + 1, position[1]) - elif movement == Grid4TransitionsEnum.WEST: - return (position[0], position[1] - 1) - - def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents): """ Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target). diff --git a/tests/test_env_utils.py b/tests/test_env_utils.py index b6764f59..467051ed 100644 --- a/tests/test_env_utils.py +++ b/tests/test_env_utils.py @@ -1,6 +1,8 @@ import numpy as np +import pytest -from flatland.envs.env_utils import position_to_coordinate, coordinate_to_position +from flatland.core.transitions import Grid4TransitionsEnum +from flatland.envs.env_utils import position_to_coordinate, coordinate_to_position, get_direction depth_to_test = 5 positions_to_test = [0, 5, 1, 6, 20, 30] @@ -19,3 +21,13 @@ def test_coordinate_to_position(): expected_positions = positions_to_test assert np.array_equal(actual_positions, expected_positions), \ "converted positions {}, expected {}".format(actual_positions, expected_positions) + + +def test_get_direction(): + assert get_direction((0,0),(0,1)) == Grid4TransitionsEnum.EAST + assert get_direction((0,0),(0,2)) == Grid4TransitionsEnum.EAST + assert get_direction((0,0),(1,0)) == Grid4TransitionsEnum.SOUTH + assert get_direction((1,0),(0,0)) == Grid4TransitionsEnum.NORTH + assert get_direction((1,0),(0,0)) == Grid4TransitionsEnum.NORTH + with pytest.raises(Exception,match="Could not determine direction"): + get_direction((0,0),(0,0)) == Grid4TransitionsEnum.NORTH -- GitLab