Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
Flatland
Commits
f6e81e1a
Commit
f6e81e1a
authored
Jun 17, 2019
by
Christian Eichenberger
🏸
Committed by
spiglerg
Jun 17, 2019
Browse files
Resolve "shortest-path-predictor"
parent
a154294c
Changes
9
Hide whitespace changes
Inline
Side-by-side
flatland/core/env_prediction_builder.py
View file @
f6e81e1a
...
...
@@ -13,6 +13,7 @@ case of multi-agent environments.
class
PredictionBuilder
:
"""
PredictionBuilder base class.
"""
def
__init__
(
self
,
max_depth
:
int
=
20
):
...
...
@@ -27,12 +28,15 @@ class PredictionBuilder:
"""
pass
def
get
(
self
,
handle
=
0
):
def
get
(
self
,
custom_args
=
None
,
handle
=
0
):
"""
Called whenever
predict is called on the environment
.
Called whenever
get_many in the observation build is called
.
Parameters
-------
custom_args: dict
Implementation-dependent custom arguments, see the sub-classes.
handle : int (optional)
Handle of the agent for which to compute the observation vector.
...
...
flatland/core/transitions.py
View file @
f6e81e1a
...
...
@@ -3,6 +3,7 @@ The transitions module defines the base Transitions class and a
derived GridTransitions class, which allows for the specification of
possible transitions over a 2D grid.
"""
from
enum
import
IntEnum
import
numpy
as
np
...
...
@@ -129,6 +130,16 @@ class Transitions:
"""
raise
NotImplementedError
()
def
get_direction_enum
(
self
)
->
IntEnum
:
raise
NotImplementedError
()
class
Grid4TransitionsEnum
(
IntEnum
):
NORTH
=
0
EAST
=
1
SOUTH
=
2
WEST
=
3
class
Grid4Transitions
(
Transitions
):
"""
...
...
@@ -323,6 +334,20 @@ class Grid4Transitions(Transitions):
cell_transition
=
value
return
cell_transition
def
get_direction_enum
(
self
)
->
IntEnum
:
return
Grid4TransitionsEnum
class
Grid8TransitionsEnum
(
IntEnum
):
NORTH
=
0
NORTH_EAST
=
1
EAST
=
2
SOUTH_EAST
=
3
SOUTH
=
4
SOUTH_WEST
=
5
WEST
=
6
NORTH_WEST
=
7
class
Grid8Transitions
(
Transitions
):
"""
...
...
@@ -504,6 +529,9 @@ class Grid8Transitions(Transitions):
return
cell_transition
def
get_direction_enum
(
self
)
->
IntEnum
:
return
Grid8TransitionsEnum
class
RailEnvTransitions
(
Grid4Transitions
):
"""
...
...
flatland/envs/env_utils.py
View file @
f6e81e1a
...
...
@@ -7,6 +7,8 @@ a GridTransitionMap object.
import
numpy
as
np
from
flatland.core.transitions
import
Grid4TransitionsEnum
def
get_direction
(
pos1
,
pos2
):
"""
...
...
@@ -253,13 +255,14 @@ def distance_on_rail(pos1, pos2):
def
get_new_position
(
position
,
movement
):
if
movement
==
0
:
# NORTH
""" Utility function that converts a compass movement over a 2D grid to new positions (r, c). """
if
movement
==
Grid4TransitionsEnum
.
NORTH
:
return
(
position
[
0
]
-
1
,
position
[
1
])
elif
movement
==
1
:
#
EAST
elif
movement
==
Grid4TransitionsEnum
.
EAST
:
return
(
position
[
0
],
position
[
1
]
+
1
)
elif
movement
==
2
:
#
SOUTH
elif
movement
==
Grid4TransitionsEnum
.
SOUTH
:
return
(
position
[
0
]
+
1
,
position
[
1
])
elif
movement
==
3
:
#
WEST
elif
movement
==
Grid4TransitionsEnum
.
WEST
:
return
(
position
[
0
],
position
[
1
]
-
1
)
...
...
flatland/envs/observations.py
View file @
f6e81e1a
...
...
@@ -6,6 +6,7 @@ from collections import deque
import
numpy
as
np
from
flatland.core.env_observation_builder
import
ObservationBuilder
from
flatland.core.transitions
import
Grid4TransitionsEnum
from
flatland.envs.env_utils
import
coordinate_to_position
...
...
@@ -48,16 +49,19 @@ class TreeObsForRailEnv(ObservationBuilder):
self
.
agents_previous_reset
=
agents
if
compute_distance_map
:
self
.
distance_map
=
np
.
inf
*
np
.
ones
(
shape
=
(
nAgents
,
# self.env.number_of_agents,
self
.
env
.
height
,
self
.
env
.
width
,
4
))
self
.
max_dist
=
np
.
zeros
(
nAgents
)
self
.
_compute_distance_map
()
self
.
max_dist
=
[
self
.
_distance_map_walker
(
agent
.
target
,
i
)
for
i
,
agent
in
enumerate
(
agents
)]
# Update local lookup table for all agents' target locations
self
.
location_has_target
=
{
tuple
(
agent
.
target
):
1
for
agent
in
agents
}
def
_compute_distance_map
(
self
):
agents
=
self
.
env
.
agents
nAgents
=
len
(
agents
)
self
.
distance_map
=
np
.
inf
*
np
.
ones
(
shape
=
(
nAgents
,
# self.env.number_of_agents,
self
.
env
.
height
,
self
.
env
.
width
,
4
))
self
.
max_dist
=
np
.
zeros
(
nAgents
)
self
.
max_dist
=
[
self
.
_distance_map_walker
(
agent
.
target
,
i
)
for
i
,
agent
in
enumerate
(
agents
)]
# Update local lookup table for all agents' target locations
self
.
location_has_target
=
{
tuple
(
agent
.
target
):
1
for
agent
in
agents
}
def
_distance_map_walker
(
self
,
position
,
target_nr
):
"""
...
...
@@ -159,13 +163,13 @@ class TreeObsForRailEnv(ObservationBuilder):
"""
Utility function that converts a compass movement over a 2D grid to new positions (r, c).
"""
if
movement
==
0
:
#
NORTH
if
movement
==
Grid4TransitionsEnum
.
NORTH
:
return
(
position
[
0
]
-
1
,
position
[
1
])
elif
movement
==
1
:
#
EAST
elif
movement
==
Grid4TransitionsEnum
.
EAST
:
return
(
position
[
0
],
position
[
1
]
+
1
)
elif
movement
==
2
:
#
SOUTH
elif
movement
==
Grid4TransitionsEnum
.
SOUTH
:
return
(
position
[
0
]
+
1
,
position
[
1
])
elif
movement
==
3
:
#
WEST
elif
movement
==
Grid4TransitionsEnum
.
WEST
:
return
(
position
[
0
],
position
[
1
]
-
1
)
def
get_many
(
self
,
handles
=
[]):
...
...
@@ -177,7 +181,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if
self
.
predictor
:
self
.
predicted_pos
=
{}
self
.
predicted_dir
=
{}
self
.
predictions
=
self
.
predictor
.
get
(
self
.
distance_map
)
self
.
predictions
=
self
.
predictor
.
get
(
custom_args
=
{
'distance_map'
:
self
.
distance_map
}
)
for
t
in
range
(
len
(
self
.
predictions
[
0
])):
pos_list
=
[]
dir_list
=
[]
...
...
@@ -796,8 +800,3 @@ class LocalObsForRailEnv(ObservationBuilder):
direction
=
self
.
_get_one_hot_for_agent_direction
(
agent
)
return
local_rail_obs
,
obs_map_state
,
obs_other_agents_state
,
direction
# class LocalObsForRailEnvImproved(ObservationBuilder):
# """
# Returns a local observation around the given agent
# """
flatland/envs/predictions.py
View file @
f6e81e1a
...
...
@@ -5,6 +5,7 @@ Collection of environment-specific PredictionBuilder.
import
numpy
as
np
from
flatland.core.env_prediction_builder
import
PredictionBuilder
from
flatland.envs.env_utils
import
get_new_position
from
flatland.envs.rail_env
import
RailEnvActions
...
...
@@ -16,24 +17,28 @@ class DummyPredictorForRailEnv(PredictionBuilder):
The prediction acts as if no other agent is in the environment and always takes the forward action.
"""
def
get
(
self
,
distancemap
,
handle
=
None
):
def
get
(
self
,
custom_args
=
None
,
handle
=
None
):
"""
Called whenever
predict is called on the environment
.
Called whenever
get_many in the observation build is called
.
Parameters
-------
custom_args: dict
Not used in this dummy implementation.
handle : int (optional)
Handle of the agent for which to compute the observation vector.
Returns
-------
function
Returns a dictionary index by the agent handle and for each agent a vector of 5 elements:
np.array
Returns a dictionary index
ed
by the agent handle and for each agent a vector of
(max_depth + 1)x
5 elements:
- time_offset
- position axis 0
- position axis 1
- direction
- action taken to come here
The prediction at 0 is the current position, direction etc.
"""
agents
=
self
.
env
.
agents
if
handle
:
...
...
@@ -46,13 +51,12 @@ class DummyPredictorForRailEnv(PredictionBuilder):
_agent_initial_position
=
agent
.
position
_agent_initial_direction
=
agent
.
direction
prediction
=
np
.
zeros
(
shape
=
(
self
.
max_depth
+
1
,
5
))
prediction
[
0
]
=
[
0
,
_agent_initial_position
[
0
],
_agent_initial_position
[
1
]
,
_agent_initial_direction
,
0
]
prediction
[
0
]
=
[
0
,
*
_agent_initial_position
,
_agent_initial_direction
,
0
]
for
index
in
range
(
1
,
self
.
max_depth
+
1
):
action_done
=
False
# if we're at the target, stop moving...
if
agent
.
position
==
agent
.
target
:
prediction
[
index
]
=
[
index
,
agent
.
target
[
0
],
agent
.
target
[
1
],
agent
.
direction
,
RailEnvActions
.
STOP_MOVING
]
prediction
[
index
]
=
[
index
,
*
agent
.
target
,
agent
.
direction
,
RailEnvActions
.
STOP_MOVING
]
continue
for
action
in
action_priorities
:
...
...
@@ -63,7 +67,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
# performed
agent
.
position
=
new_position
agent
.
direction
=
new_direction
prediction
[
index
]
=
[
index
,
new_position
[
0
],
new_position
[
1
]
,
new_direction
,
action
]
prediction
[
index
]
=
[
index
,
*
new_position
,
new_direction
,
action
]
action_done
=
True
break
if
not
action_done
:
...
...
@@ -76,90 +80,95 @@ class DummyPredictorForRailEnv(PredictionBuilder):
class
ShortestPathPredictorForRailEnv
(
PredictionBuilder
):
"""
Dummy
PredictorForRailEnv object.
ShortestPath
PredictorForRailEnv object.
This object returns predictions for agents in the RailEnv environment.
This object returns
shortest-path
predictions for agents in the RailEnv environment.
The prediction acts as if no other agent is in the environment and always takes the forward action.
"""
def
get
(
self
,
distancemap
,
handle
=
None
):
def
get
(
self
,
custom_args
=
None
,
handle
=
None
):
"""
Called whenever predict is called on the environment.
Called whenever get_many in the observation build is called.
Requires distance_map to extract the shortest path.
Parameters
-------
custom_args: dict
- distance_map : dict
handle : int (optional)
Handle of the agent for which to compute the observation vector.
Returns
-------
function
Returns a dictionary index by the agent handle and for each agent a vector of 5 elements:
np.array
Returns a dictionary index
ed
by the agent handle and for each agent a vector of
(max_depth + 1)x
5 elements:
- time_offset
- position axis 0
- position axis 1
- direction
- action taken to come here
The prediction at 0 is the current position, direction etc.
"""
agents
=
self
.
env
.
agents
if
handle
:
agents
=
[
self
.
env
.
agents
[
handle
]]
assert
custom_args
distance_map
=
custom_args
.
get
(
'distance_map'
)
assert
distance_map
is
not
None
prediction_dict
=
{}
agent_idx
=
0
for
agent
in
agents
:
_agent_initial_position
=
agent
.
position
_agent_initial_direction
=
agent
.
direction
prediction
=
np
.
zeros
(
shape
=
(
self
.
max_depth
+
1
,
5
))
prediction
[
0
]
=
[
0
,
_agent_initial_position
[
0
],
_agent_initial_position
[
1
]
,
_agent_initial_direction
,
0
]
prediction
[
0
]
=
[
0
,
*
_agent_initial_position
,
_agent_initial_direction
,
0
]
for
index
in
range
(
1
,
self
.
max_depth
+
1
):
# if we're at the target, stop moving...
if
agent
.
position
==
agent
.
target
:
prediction
[
index
]
=
[
index
,
agent
.
target
[
0
],
agent
.
target
[
1
],
agent
.
direction
,
RailEnvActions
.
STOP_MOVING
]
prediction
[
index
]
=
[
index
,
*
agent
.
target
,
agent
.
direction
,
RailEnvActions
.
STOP_MOVING
]
continue
if
not
agent
.
moving
:
prediction
[
index
]
=
[
index
,
agent
.
position
[
0
],
agent
.
position
[
1
],
agent
.
direction
,
RailEnvActions
.
STOP_MOVING
]
prediction
[
index
]
=
[
index
,
*
agent
.
position
,
agent
.
direction
,
RailEnvActions
.
STOP_MOVING
]
continue
# Take shortest possible path
cell_transitions
=
self
.
env
.
rail
.
get_transitions
((
*
agent
.
position
,
agent
.
direction
))
new_position
=
None
new_direction
=
None
if
np
.
sum
(
cell_transitions
)
==
1
:
new_direction
=
np
.
argmax
(
cell_transitions
)
new_position
=
self
.
_new_position
(
agent
.
position
,
new_direction
)
new_position
=
get
_new_position
(
agent
.
position
,
new_direction
)
elif
np
.
sum
(
cell_transitions
)
>
1
:
min_dist
=
np
.
inf
for
direct
in
range
(
4
):
if
cell_transitions
[
direct
]
==
1
:
target_dist
=
distancemap
[
agent
_idx
,
agent
.
position
[
0
],
agent
.
position
[
1
],
direct
]
for
direct
ion
in
range
(
4
):
if
cell_transitions
[
direct
ion
]
==
1
:
target_dist
=
distance
_
map
[
agent
.
handle
,
agent
.
position
[
0
],
agent
.
position
[
1
],
direct
ion
]
if
target_dist
<
min_dist
:
min_dist
=
target_dist
new_direction
=
direct
new_position
=
self
.
_new_position
(
agent
.
position
,
new_direction
)
new_direction
=
direction
new_position
=
get_new_position
(
agent
.
position
,
new_direction
)
else
:
raise
Exception
(
"No transition possible {}"
.
format
(
cell_transitions
))
# which action to take for the transition?
action
=
None
for
_action
in
[
RailEnvActions
.
MOVE_FORWARD
,
RailEnvActions
.
MOVE_RIGHT
,
RailEnvActions
.
MOVE_LEFT
]:
_
,
_
,
_new_direction
,
_new_position
,
_
=
self
.
env
.
_check_action_on_agent
(
_action
,
agent
)
if
np
.
array_equal
(
_new_position
,
new_position
):
action
=
_action
break
assert
action
is
not
None
# update the agent's position and direction
agent
.
position
=
new_position
agent
.
direction
=
new_direction
prediction
[
index
]
=
[
index
,
new_position
[
0
],
new_position
[
1
],
new_direction
,
0
]
action_done
=
True
if
not
action_done
:
raise
Exception
(
"Cannot move further. Something is wrong"
)
# prediction is ready
prediction
[
index
]
=
[
index
,
*
new_position
,
new_direction
,
action
]
prediction_dict
[
agent
.
handle
]
=
prediction
# cleanup: reset initial position
agent
.
position
=
_agent_initial_position
agent
.
direction
=
_agent_initial_direction
agent_idx
+=
1
return
prediction_dict
def
_new_position
(
self
,
position
,
movement
):
"""
Utility function that converts a compass movement over a 2D grid to new positions (r, c).
"""
if
movement
==
0
:
# NORTH
return
(
position
[
0
]
-
1
,
position
[
1
])
elif
movement
==
1
:
# EAST
return
(
position
[
0
],
position
[
1
]
+
1
)
elif
movement
==
2
:
# SOUTH
return
(
position
[
0
]
+
1
,
position
[
1
])
elif
movement
==
3
:
# WEST
return
(
position
[
0
],
position
[
1
]
-
1
)
tests/test_env_observation_builder.py
View file @
f6e81e1a
...
...
@@ -80,11 +80,3 @@ def test_global_obs():
# If this assertion is wrong, it means that the observation returned
# places the agent on an empty cell
assert
(
np
.
sum
(
rail_map
*
global_obs
[
0
][
1
][:,
:,
:
4
].
sum
(
2
))
>
0
)
def
main
():
test_global_obs
()
if
__name__
==
"__main__"
:
main
()
tests/test_env_prediction_builder.py
View file @
f6e81e1a
...
...
@@ -4,15 +4,18 @@
import
numpy
as
np
from
flatland.core.transition_map
import
GridTransitionMap
,
Grid4Transitions
from
flatland.core.transitions
import
Grid4TransitionsEnum
from
flatland.envs.generators
import
rail_from_GridTransitionMap_generator
from
flatland.envs.observations
import
TreeObsForRailEnv
from
flatland.envs.predictions
import
DummyPredictorForRailEnv
from
flatland.envs.predictions
import
DummyPredictorForRailEnv
,
ShortestPathPredictorForRailEnv
from
flatland.envs.rail_env
import
RailEnv
from
flatland.envs.rail_env
import
RailEnvActions
from
flatland.utils.rendertools
import
RenderTool
"""Test predictions for `flatland` package."""
def
test_predictions
():
def
make_simple_rail
():
# We instantiate a very simple rail network on a 7x10 grid:
# |
# |
...
...
@@ -22,7 +25,6 @@ def test_predictions():
# |
# |
# |
cells
=
[
int
(
'0000000000000000'
,
2
),
# empty cell - Case 0
int
(
'1000000000100000'
,
2
),
# Case 1 - straight
int
(
'1001001000100000'
,
2
),
# Case 2 - simple switch
...
...
@@ -31,22 +33,17 @@ def test_predictions():
int
(
'1100110000110011'
,
2
),
# Case 5 - double slip switch
int
(
'0101001000000010'
,
2
),
# Case 6 - symmetrical switch
int
(
'0010000000000000'
,
2
)]
# Case 7 - dead end
transitions
=
Grid4Transitions
([])
empty
=
cells
[
0
]
dead_end_from_south
=
cells
[
7
]
dead_end_from_west
=
transitions
.
rotate_transition
(
dead_end_from_south
,
90
)
dead_end_from_north
=
transitions
.
rotate_transition
(
dead_end_from_south
,
180
)
dead_end_from_east
=
transitions
.
rotate_transition
(
dead_end_from_south
,
270
)
vertical_straight
=
cells
[
1
]
horizontal_straight
=
transitions
.
rotate_transition
(
vertical_straight
,
90
)
double_switch_south_horizontal_straight
=
horizontal_straight
+
cells
[
6
]
double_switch_north_horizontal_straight
=
transitions
.
rotate_transition
(
double_switch_south_horizontal_straight
,
180
)
rail_map
=
np
.
array
(
[[
empty
]
*
3
+
[
dead_end_from_south
]
+
[
empty
]
*
6
]
+
[[
empty
]
*
3
+
[
vertical_straight
]
+
[
empty
]
*
6
]
*
2
+
...
...
@@ -56,26 +53,36 @@ def test_predictions():
[
horizontal_straight
]
*
2
+
[
dead_end_from_west
]]
+
[[
empty
]
*
6
+
[
vertical_straight
]
+
[
empty
]
*
3
]
*
2
+
[[
empty
]
*
6
+
[
dead_end_from_north
]
+
[
empty
]
*
3
],
dtype
=
np
.
uint16
)
rail
=
GridTransitionMap
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
transitions
=
transitions
)
rail
.
grid
=
rail_map
return
rail
,
rail_map
def
test_dummy_predictor
(
rendering
=
False
):
rail
,
rail_map
=
make_simple_rail
()
env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_GridTransitionMap_generator
(
rail
),
number_of_agents
=
1
,
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
DummyPredictorForRailEnv
(
max_depth
=
10
)),
)
env
.
reset
()
# set initial position and direction for testing...
env
.
agents
[
0
].
position
=
(
5
,
6
)
env
.
agents
[
0
].
direction
=
0
env
.
agents
[
0
].
target
=
(
3.
,
0.
)
env
.
agents
[
0
].
target
=
(
3
,
0
)
if
rendering
:
renderer
=
RenderTool
(
env
,
gl
=
"PILSVG"
)
renderer
.
renderEnv
(
show
=
True
,
show_observations
=
False
)
input
(
"Continue?"
)
# test assertions
predictions
=
env
.
obs_builder
.
predictor
.
get
(
None
)
positions
=
np
.
array
(
list
(
map
(
lambda
prediction
:
[
prediction
[
1
],
prediction
[
2
]],
predictions
[
0
])))
positions
=
np
.
array
(
list
(
map
(
lambda
prediction
:
[
*
prediction
[
1
:
3
]],
predictions
[
0
])))
directions
=
np
.
array
(
list
(
map
(
lambda
prediction
:
[
prediction
[
3
]],
predictions
[
0
])))
time_offsets
=
np
.
array
(
list
(
map
(
lambda
prediction
:
[
prediction
[
0
]],
predictions
[
0
])))
actions
=
np
.
array
(
list
(
map
(
lambda
prediction
:
[
prediction
[
4
]],
predictions
[
0
])))
...
...
@@ -139,9 +146,149 @@ def test_predictions():
assert
np
.
array_equal
(
actions
,
expected_actions
)
def
main
():
test_predictions
()
def
test_shortest_path_predictor
(
rendering
=
False
):
rail
,
rail_map
=
make_simple_rail
()
env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_GridTransitionMap_generator
(
rail
),
number_of_agents
=
1
,
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
ShortestPathPredictorForRailEnv
()),
)
env
.
reset
()
agent
=
env
.
agents
[
0
]
agent
.
position
=
(
5
,
6
)
# south dead-end
agent
.
direction
=
0
# north
agent
.
target
=
(
3
,
9
)
# east dead-end
agent
.
moving
=
True
if
rendering
:
renderer
=
RenderTool
(
env
,
gl
=
"PILSVG"
)
renderer
.
renderEnv
(
show
=
True
,
show_observations
=
False
)
input
(
"Continue?"
)
agent
=
env
.
agents
[
0
]
assert
agent
.
position
==
(
5
,
6
)
assert
agent
.
direction
==
0
assert
agent
.
target
==
(
3
,
9
)
assert
agent
.
moving
env
.
obs_builder
.
_compute_distance_map
()
distance_map
=
env
.
obs_builder
.
distance_map
assert
distance_map
[
agent
.
handle
,
agent
.
position
[
0
],
agent
.
position
[
1
],
agent
.
direction
]
==
5.0
,
"found {} instead of {}"
.
format
(
distance_map
[
agent
.
handle
,
agent
.
position
[
0
],
agent
.
position
[
1
],
agent
.
direction
],
5.0
)
# test assertions
env
.
obs_builder
.
get_many
()
predictions
=
env
.
obs_builder
.
predictions
positions
=
np
.
array
(
list
(
map
(
lambda
prediction
:
[
*
prediction
[
1
:
3
]],
predictions
[
0
])))
directions
=
np
.
array
(
list
(
map
(
lambda
prediction
:
[
prediction
[
3
]],
predictions
[
0
])))
time_offsets
=
np
.
array
(
list
(
map
(
lambda
prediction
:
[
prediction
[
0
]],
predictions
[
0
])))
actions
=
np
.
array
(
list
(
map
(
lambda
prediction
:
[
prediction
[
4
]],
predictions
[
0
])))
expected_positions
=
[
[
5
,
6
],
[
4
,
6
],
[
3
,
6
],
[
3
,
7
],
[
3
,
8
],
[
3
,
9
],
[
3
,
9
],
[
3
,
9
],
[
3
,
9
],
[
3
,
9
],
[
3
,
9
],
[
3
,
9
],
[
3
,
9
],
[
3
,
9
],
[
3
,
9
],
[
3
,
9
],
[
3
,
9
],
[
3
,
9
],
[
3
,
9
],
[
3
,
9
],
[
3
,
9
],
]
expected_directions
=
[
[
Grid4TransitionsEnum
.
NORTH
],
# next is [5,6] heading north
[
Grid4TransitionsEnum
.
NORTH
],
# next is [4,6] heading north
[
Grid4TransitionsEnum
.
NORTH
],
# next is [3,6] heading north
[
Grid4TransitionsEnum
.
EAST
],
# next is [3,7] heading east
[
Grid4TransitionsEnum
.
EAST
],
[
Grid4TransitionsEnum
.
EAST
],
[
Grid4TransitionsEnum
.
EAST
],
[
Grid4TransitionsEnum
.
EAST
],
[
Grid4TransitionsEnum
.
EAST
],
[
Grid4TransitionsEnum
.
EAST
],
[
Grid4TransitionsEnum
.
EAST
],
[
Grid4TransitionsEnum
.
EAST
],
[
Grid4TransitionsEnum
.
EAST
],
[
Grid4TransitionsEnum
.
EAST
],
[
Grid4TransitionsEnum
.
EAST
],
[
Grid4TransitionsEnum
.
EAST
],
[
Grid4TransitionsEnum
.
EAST
],
[
Grid4TransitionsEnum
.
EAST
],
[
Grid4TransitionsEnum
.
EAST
],
[
Grid4TransitionsEnum
.
EAST
],
[
Grid4TransitionsEnum
.
EAST
],
]
expected_time_offsets
=
np
.
array
([
[
0.
],
[
1.
],
[
2.
],
[
3.
],
[
4.
],
[
5.
],
[
6.
],
[
7.
],
[
8.
],
[
9.
],
[
10.
],
[
11.
],
[
12.
],
[
13.
],
[
14.
],
[
15.
],
[
16.
],
[
17.
],
[
18.
],
[
19.
],
[
20.
],
])
expected_actions
=
np
.
array
([
[
RailEnvActions
.
DO_NOTHING
],
# next [5,6]