diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py index ee2c263711906370900c14a6030d57ada573c2ea..626ee5f0b4e1ab22349240bf15461f6df97cea7d 100644 --- a/flatland/envs/env_utils.py +++ b/flatland/envs/env_utils.py @@ -71,30 +71,42 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p return rail_trans.is_valid(new_trans) -def position_to_coordinate(width, position): +def position_to_coordinate(depth, position): """ + [ (0,0) (0,1) .. (0,w) + (1,0) (1,1) (1,w) + ... + (d,0) (d,1) (d,w) ] - :param width: + --> + + [ 0 1 .. w + w+1 w+2 .. 2w + ... + d*w+1 d*w+ + + :param depth: :param position: :return: """ coords = () for p in position: - coords = coords + ((int(p) % width, int(p) // width),) # changed x_dim to y_dim + coords = coords + ((int(p) % depth, int(p) // depth),) # changed x_dim to y_dim return coords -def coordinate_to_position(width, coords): +def coordinate_to_position(depth, coords): """ + Helper function to - :param width: + :param depth: :param coords: :return: """ position = np.empty(len(coords), dtype=int) idx = 0 for t in coords: - position[idx] = int(t[1] * width + t[0]) + position[idx] = int(t[1] * depth + t[0]) idx += 1 return position diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 3fda7378771bf2b5cd20c7d065c9454d3e5629a5..43909669a67c527ae6fb935e22810eb47c9608cd 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -112,7 +112,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): agents = self.env.agents if handle: agents = [self.env.agents[handle]] - assert custom_args + assert custom_args is not None distance_map = custom_args.get('distance_map') assert distance_map is not None diff --git a/tests/test_env_utils.py b/tests/test_env_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b6764f59ffd522a87bdeee990cace285a1d73cd6 --- /dev/null +++ b/tests/test_env_utils.py @@ -0,0 +1,21 @@ +import numpy as np + +from flatland.envs.env_utils import position_to_coordinate, coordinate_to_position + +depth_to_test = 5 +positions_to_test = [0, 5, 1, 6, 20, 30] +coordinates_to_test = [[0, 0], [0, 1], [1, 0], [1, 1], [0, 4], [0, 6]] + + +def test_position_to_coordinate(): + actual_coordinates = position_to_coordinate(depth_to_test, positions_to_test) + expected_coordinates = coordinates_to_test + assert np.array_equal(actual_coordinates, expected_coordinates), \ + "converted positions {}, expected {}".format(actual_coordinates, expected_coordinates) + + +def test_coordinate_to_position(): + actual_positions = coordinate_to_position(depth_to_test, coordinates_to_test) + expected_positions = positions_to_test + assert np.array_equal(actual_positions, expected_positions), \ + "converted positions {}, expected {}".format(actual_positions, expected_positions)