Skip to content
Snippets Groups Projects
Commit b582cf50 authored by Christian Eichenberger's avatar Christian Eichenberger :badminton:
Browse files

Merge branch '66-shortest-path-predictor-cleanup-tests' into 'master'

Resolve "shortest-path-predictor"

Closes #66

See merge request !65
parents 61672cda f8afd304
No related branches found
No related tags found
No related merge requests found
...@@ -10,7 +10,7 @@ import numpy as np ...@@ -10,7 +10,7 @@ import numpy as np
from flatland.core.transitions import Grid4TransitionsEnum from flatland.core.transitions import Grid4TransitionsEnum
def get_direction(pos1, pos2): def get_direction(pos1, pos2) -> Grid4TransitionsEnum:
""" """
Assumes pos1 and pos2 are adjacent location on grid. Assumes pos1 and pos2 are adjacent location on grid.
Returns direction (int) that can be used with transitions. Returns direction (int) that can be used with transitions.
...@@ -25,7 +25,7 @@ def get_direction(pos1, pos2): ...@@ -25,7 +25,7 @@ def get_direction(pos1, pos2):
return 1 return 1
if diff_1 < 0: if diff_1 < 0:
return 3 return 3
return 0 raise Exception("Could not determine direction {}->{}".format(pos1, pos2))
def mirror(dir): def mirror(dir):
...@@ -71,34 +71,71 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p ...@@ -71,34 +71,71 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p
return rail_trans.is_valid(new_trans) return rail_trans.is_valid(new_trans)
def position_to_coordinate(width, position): def position_to_coordinate(depth, positions):
""" """Converts coordinates to positions:
[ (0,0) (0,1) .. (0,w-1)
(1,0) (1,1) (1,w-1)
...
(d-1,0) (d-1,1) (d-1,w-1)
]
-->
:param width: [ 0 d .. (w-1)*d
:param position: 1 d+1
...
d-1 2d-1 w*d-1
]
:param depth:
:param positions:
:return: :return:
""" """
coords = () coords = ()
for p in position: for p in positions:
coords = coords + ((int(p) % width, int(p) // width),) # changed x_dim to y_dim coords = coords + ((int(p) % depth, int(p) // depth),) # changed x_dim to y_dim
return coords return coords
def coordinate_to_position(width, coords): def coordinate_to_position(depth, coords):
""" """
Converts positions to coordinates:
:param width: [ 0 d .. (w-1)*d
1 d+1
...
d-1 2d-1 w*d-1
]
-->
[ (0,0) (0,1) .. (0,w-1)
(1,0) (1,1) (1,w-1)
...
(d-1,0) (d-1,1) (d-1,w-1)
]
:param depth:
:param coords: :param coords:
:return: :return:
""" """
position = np.empty(len(coords), dtype=int) position = np.empty(len(coords), dtype=int)
idx = 0 idx = 0
for t in coords: for t in coords:
position[idx] = int(t[1] * width + t[0]) position[idx] = int(t[1] * depth + t[0])
idx += 1 idx += 1
return position return position
def get_new_position(position, movement):
""" Utility function that converts a compass movement over a 2D grid to new positions (r, c). """
if movement == Grid4TransitionsEnum.NORTH:
return (position[0] - 1, position[1])
elif movement == Grid4TransitionsEnum.EAST:
return (position[0], position[1] + 1)
elif movement == Grid4TransitionsEnum.SOUTH:
return (position[0] + 1, position[1])
elif movement == Grid4TransitionsEnum.WEST:
return (position[0], position[1] - 1)
class AStarNode(): class AStarNode():
"""A node class for A* Pathfinding""" """A node class for A* Pathfinding"""
...@@ -254,18 +291,6 @@ def distance_on_rail(pos1, pos2): ...@@ -254,18 +291,6 @@ def distance_on_rail(pos1, pos2):
return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1]) return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])
def get_new_position(position, movement):
""" Utility function that converts a compass movement over a 2D grid to new positions (r, c). """
if movement == Grid4TransitionsEnum.NORTH:
return (position[0] - 1, position[1])
elif movement == Grid4TransitionsEnum.EAST:
return (position[0], position[1] + 1)
elif movement == Grid4TransitionsEnum.SOUTH:
return (position[0] + 1, position[1])
elif movement == Grid4TransitionsEnum.WEST:
return (position[0], position[1] - 1)
def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents): def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
""" """
Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target). Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).
......
...@@ -112,7 +112,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): ...@@ -112,7 +112,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
agents = self.env.agents agents = self.env.agents
if handle: if handle:
agents = [self.env.agents[handle]] agents = [self.env.agents[handle]]
assert custom_args assert custom_args is not None
distance_map = custom_args.get('distance_map') distance_map = custom_args.get('distance_map')
assert distance_map is not None assert distance_map is not None
......
import numpy as np
import pytest
from flatland.core.transitions import Grid4TransitionsEnum
from flatland.envs.env_utils import position_to_coordinate, coordinate_to_position, get_direction
depth_to_test = 5
positions_to_test = [0, 5, 1, 6, 20, 30]
coordinates_to_test = [[0, 0], [0, 1], [1, 0], [1, 1], [0, 4], [0, 6]]
def test_position_to_coordinate():
actual_coordinates = position_to_coordinate(depth_to_test, positions_to_test)
expected_coordinates = coordinates_to_test
assert np.array_equal(actual_coordinates, expected_coordinates), \
"converted positions {}, expected {}".format(actual_coordinates, expected_coordinates)
def test_coordinate_to_position():
actual_positions = coordinate_to_position(depth_to_test, coordinates_to_test)
expected_positions = positions_to_test
assert np.array_equal(actual_positions, expected_positions), \
"converted positions {}, expected {}".format(actual_positions, expected_positions)
def test_get_direction():
assert get_direction((0, 0), (0, 1)) == Grid4TransitionsEnum.EAST
assert get_direction((0, 0), (0, 2)) == Grid4TransitionsEnum.EAST
assert get_direction((0, 0), (1, 0)) == Grid4TransitionsEnum.SOUTH
assert get_direction((1, 0), (0, 0)) == Grid4TransitionsEnum.NORTH
assert get_direction((1, 0), (0, 0)) == Grid4TransitionsEnum.NORTH
with pytest.raises(Exception, match="Could not determine direction"):
get_direction((0, 0), (0, 0)) == Grid4TransitionsEnum.NORTH
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment