Skip to content
Snippets Groups Projects
Commit 870fdf43 authored by u214892's avatar u214892
Browse files

unit test for env_utils, cleanup predictions

parent 1a4aea23
No related branches found
No related tags found
No related merge requests found
...@@ -10,7 +10,7 @@ import numpy as np ...@@ -10,7 +10,7 @@ import numpy as np
from flatland.core.transitions import Grid4TransitionsEnum from flatland.core.transitions import Grid4TransitionsEnum
def get_direction(pos1, pos2): def get_direction(pos1, pos2) -> Grid4TransitionsEnum:
""" """
Assumes pos1 and pos2 are adjacent location on grid. Assumes pos1 and pos2 are adjacent location on grid.
Returns direction (int) that can be used with transitions. Returns direction (int) that can be used with transitions.
...@@ -25,7 +25,7 @@ def get_direction(pos1, pos2): ...@@ -25,7 +25,7 @@ def get_direction(pos1, pos2):
return 1 return 1
if diff_1 < 0: if diff_1 < 0:
return 3 return 3
return 0 raise Exception("Could not determine direction {}->{}".format(pos1, pos2))
def mirror(dir): def mirror(dir):
...@@ -71,33 +71,46 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p ...@@ -71,33 +71,46 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p
return rail_trans.is_valid(new_trans) return rail_trans.is_valid(new_trans)
def position_to_coordinate(depth, position): def position_to_coordinate(depth, positions):
""" """Converts coordinates to positions:
[ (0,0) (0,1) .. (0,w) [ (0,0) (0,1) .. (0,w-1)
(1,0) (1,1) (1,w) (1,0) (1,1) (1,w-1)
... ...
(d,0) (d,1) (d,w) ] (d-1,0) (d-1,1) (d-1,w-1)
]
--> -->
[ 0 1 .. w [ 0 d .. (w-1)*d
w+1 w+2 .. 2w 1 d+1
... ...
d*w+1 d*w+ d-1 2d-1 w*d-1
]
:param depth: :param depth:
:param position: :param positions:
:return: :return:
""" """
coords = () coords = ()
for p in position: for p in positions:
coords = coords + ((int(p) % depth, int(p) // depth),) # changed x_dim to y_dim coords = coords + ((int(p) % depth, int(p) // depth),) # changed x_dim to y_dim
return coords return coords
def coordinate_to_position(depth, coords): def coordinate_to_position(depth, coords):
""" """
Helper function to Converts positions to coordinates:
[ 0 d .. (w-1)*d
1 d+1
...
d-1 2d-1 w*d-1
]
-->
[ (0,0) (0,1) .. (0,w-1)
(1,0) (1,1) (1,w-1)
...
(d-1,0) (d-1,1) (d-1,w-1)
]
:param depth: :param depth:
:param coords: :param coords:
...@@ -111,6 +124,18 @@ def coordinate_to_position(depth, coords): ...@@ -111,6 +124,18 @@ def coordinate_to_position(depth, coords):
return position return position
def get_new_position(position, movement):
""" Utility function that converts a compass movement over a 2D grid to new positions (r, c). """
if movement == Grid4TransitionsEnum.NORTH:
return (position[0] - 1, position[1])
elif movement == Grid4TransitionsEnum.EAST:
return (position[0], position[1] + 1)
elif movement == Grid4TransitionsEnum.SOUTH:
return (position[0] + 1, position[1])
elif movement == Grid4TransitionsEnum.WEST:
return (position[0], position[1] - 1)
class AStarNode(): class AStarNode():
"""A node class for A* Pathfinding""" """A node class for A* Pathfinding"""
...@@ -266,18 +291,6 @@ def distance_on_rail(pos1, pos2): ...@@ -266,18 +291,6 @@ def distance_on_rail(pos1, pos2):
return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1]) return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])
def get_new_position(position, movement):
""" Utility function that converts a compass movement over a 2D grid to new positions (r, c). """
if movement == Grid4TransitionsEnum.NORTH:
return (position[0] - 1, position[1])
elif movement == Grid4TransitionsEnum.EAST:
return (position[0], position[1] + 1)
elif movement == Grid4TransitionsEnum.SOUTH:
return (position[0] + 1, position[1])
elif movement == Grid4TransitionsEnum.WEST:
return (position[0], position[1] - 1)
def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents): def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
""" """
Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target). Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).
......
import numpy as np import numpy as np
import pytest
from flatland.envs.env_utils import position_to_coordinate, coordinate_to_position from flatland.core.transitions import Grid4TransitionsEnum
from flatland.envs.env_utils import position_to_coordinate, coordinate_to_position, get_direction
depth_to_test = 5 depth_to_test = 5
positions_to_test = [0, 5, 1, 6, 20, 30] positions_to_test = [0, 5, 1, 6, 20, 30]
...@@ -19,3 +21,13 @@ def test_coordinate_to_position(): ...@@ -19,3 +21,13 @@ def test_coordinate_to_position():
expected_positions = positions_to_test expected_positions = positions_to_test
assert np.array_equal(actual_positions, expected_positions), \ assert np.array_equal(actual_positions, expected_positions), \
"converted positions {}, expected {}".format(actual_positions, expected_positions) "converted positions {}, expected {}".format(actual_positions, expected_positions)
def test_get_direction():
assert get_direction((0,0),(0,1)) == Grid4TransitionsEnum.EAST
assert get_direction((0,0),(0,2)) == Grid4TransitionsEnum.EAST
assert get_direction((0,0),(1,0)) == Grid4TransitionsEnum.SOUTH
assert get_direction((1,0),(0,0)) == Grid4TransitionsEnum.NORTH
assert get_direction((1,0),(0,0)) == Grid4TransitionsEnum.NORTH
with pytest.raises(Exception,match="Could not determine direction"):
get_direction((0,0),(0,0)) == Grid4TransitionsEnum.NORTH
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment