Skip to content
Snippets Groups Projects
Commit 1a4aea23 authored by u214892's avatar u214892
Browse files

unit test for env_utils, cleanup predictions

parent 8a9ba90d
No related branches found
No related tags found
No related merge requests found
......@@ -71,30 +71,42 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p
return rail_trans.is_valid(new_trans)
def position_to_coordinate(width, position):
def position_to_coordinate(depth, position):
"""
[ (0,0) (0,1) .. (0,w)
(1,0) (1,1) (1,w)
...
(d,0) (d,1) (d,w) ]
:param width:
-->
[ 0 1 .. w
w+1 w+2 .. 2w
...
d*w+1 d*w+
:param depth:
:param position:
:return:
"""
coords = ()
for p in position:
coords = coords + ((int(p) % width, int(p) // width),) # changed x_dim to y_dim
coords = coords + ((int(p) % depth, int(p) // depth),) # changed x_dim to y_dim
return coords
def coordinate_to_position(width, coords):
def coordinate_to_position(depth, coords):
"""
Helper function to
:param width:
:param depth:
:param coords:
:return:
"""
position = np.empty(len(coords), dtype=int)
idx = 0
for t in coords:
position[idx] = int(t[1] * width + t[0])
position[idx] = int(t[1] * depth + t[0])
idx += 1
return position
......
......@@ -112,7 +112,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
agents = self.env.agents
if handle:
agents = [self.env.agents[handle]]
assert custom_args
assert custom_args is not None
distance_map = custom_args.get('distance_map')
assert distance_map is not None
......
import numpy as np
from flatland.envs.env_utils import position_to_coordinate, coordinate_to_position
depth_to_test = 5
positions_to_test = [0, 5, 1, 6, 20, 30]
coordinates_to_test = [[0, 0], [0, 1], [1, 0], [1, 1], [0, 4], [0, 6]]
def test_position_to_coordinate():
actual_coordinates = position_to_coordinate(depth_to_test, positions_to_test)
expected_coordinates = coordinates_to_test
assert np.array_equal(actual_coordinates, expected_coordinates), \
"converted positions {}, expected {}".format(actual_coordinates, expected_coordinates)
def test_coordinate_to_position():
actual_positions = coordinate_to_position(depth_to_test, coordinates_to_test)
expected_positions = positions_to_test
assert np.array_equal(actual_positions, expected_positions), \
"converted positions {}, expected {}".format(actual_positions, expected_positions)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment