Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
Flatland
Commits
c7d695c4
Commit
c7d695c4
authored
Jul 05, 2019
by
u214892
Browse files
unit test for conflicts of multiple agents
parent
3bc64feb
Changes
7
Hide whitespace changes
Inline
Side-by-side
examples/simple_example_3.py
View file @
c7d695c4
...
...
@@ -19,7 +19,7 @@ env = RailEnv(width=7,
# Print the observation vector for agent 0
obs
,
all_rewards
,
done
,
_
=
env
.
step
({
0
:
0
})
for
i
in
range
(
env
.
get_num_agents
()):
env
.
obs_builder
.
util_print_obs_subtree
(
tree
=
obs
[
i
]
,
num_features_per_node
=
7
)
env
.
obs_builder
.
util_print_obs_subtree
(
tree
=
obs
[
i
])
env_renderer
=
RenderTool
(
env
)
env_renderer
.
renderEnv
(
show
=
True
,
frames
=
True
)
...
...
flatland/core/grid/grid4.py
View file @
c7d695c4
...
...
@@ -11,6 +11,13 @@ class Grid4TransitionsEnum(IntEnum):
SOUTH
=
2
WEST
=
3
@
staticmethod
def
to_char
(
int
:
int
):
return
{
0
:
'N'
,
1
:
'E'
,
2
:
'S'
,
3
:
'W'
}[
int
]
class
Grid4Transitions
(
Transitions
):
"""
...
...
flatland/envs/observations.py
View file @
c7d695c4
"""
Collection of environment-specific ObservationBuilder.
"""
import
pprint
from
collections
import
deque
import
numpy
as
np
...
...
@@ -34,6 +35,8 @@ class TreeObsForRailEnv(ObservationBuilder):
self
.
location_has_agent_direction
=
{}
self
.
predictor
=
predictor
self
.
agents_previous_reset
=
None
self
.
tree_explored_actions
=
[
1
,
2
,
3
,
0
]
self
.
tree_explorted_actions_char
=
[
'L'
,
'F'
,
'R'
,
'B'
]
def
reset
(
self
):
agents
=
self
.
env
.
agents
...
...
@@ -126,19 +129,6 @@ class TreeObsForRailEnv(ObservationBuilder):
desired_movement_from_new_cell
=
(
neigh_direction
+
2
)
%
4
"""
# Is the next cell a dead-end?
isNextCellDeadEnd = False
nbits = 0
tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1]))
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
if nbits == 1:
# Dead-end!
isNextCellDeadEnd = True
"""
# Check all possible transitions in new_cell
for
agent_orientation
in
range
(
4
):
# Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
...
...
@@ -213,7 +203,7 @@ class TreeObsForRailEnv(ObservationBuilder):
[... from 'right] +
[... from 'back']
Finally, each node information is composed of
5
floating point values:
Finally, each node information is composed of
8
floating point values:
#1: if own target lies on the explored branch the current distance from the agent in number of cells is stored.
...
...
@@ -268,7 +258,6 @@ class TreeObsForRailEnv(ObservationBuilder):
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation
# If only one transition is possible, the tree is oriented with this transition as the forward branch.
# TODO: Test if this works as desired!
orientation
=
agent
.
direction
if
num_transitions
==
1
:
...
...
@@ -282,15 +271,20 @@ class TreeObsForRailEnv(ObservationBuilder):
observation
=
observation
+
branch_observation
visited
=
visited
.
union
(
branch_visited
)
else
:
num_cells_to_fill_in
=
0
pow4
=
1
for
i
in
range
(
self
.
max_depth
):
num_cells_to_fill_in
+=
pow4
pow4
*=
4
observation
=
observation
+
([
-
np
.
inf
]
*
self
.
observation_dim
)
*
num_cells_to_fill_in
# add cells filled with infinity if no transition is possible
observation
=
observation
+
[
-
np
.
inf
]
*
self
.
_num_cells_to_fill_in
(
self
.
max_depth
)
self
.
env
.
dev_obs_dict
[
handle
]
=
visited
return
observation
def
_num_cells_to_fill_in
(
self
,
remaining_depth
):
"""Computes the length of observation vector: sum_{i=0,depth-1} 2^i * observation_dim."""
num_observations
=
0
pow4
=
1
for
i
in
range
(
remaining_depth
):
num_observations
+=
pow4
pow4
*=
4
return
num_observations
*
self
.
observation_dim
def
_explore_branch
(
self
,
handle
,
position
,
direction
,
root_observation
,
tot_dist
,
depth
):
"""
Utility function to compute tree-based observations.
...
...
@@ -334,7 +328,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# Cummulate the number of agents on branch with other direction
other_agent_opposite_direction
+=
1
# Register possible conflict
# Register possible
future
conflict
if
self
.
predictor
and
num_steps
<
self
.
max_prediction_depth
:
int_position
=
coordinate_to_position
(
self
.
env
.
width
,
[
position
])
if
tot_dist
<
self
.
max_prediction_depth
:
...
...
@@ -422,42 +416,6 @@ class TreeObsForRailEnv(ObservationBuilder):
# #############################
# #############################
# Modify here to append new / different features for each visited cell!
"""
other_agent_same_direction =
\
1 if position in self.location_has_agent and self.location_has_agent_direction[position] == direction else 0
other_agent_opposite_direction =
\
1 if position in self.location_has_agent and self.location_has_agent_direction[position] != direction else 0
if last_isTarget:
observation = [0,
other_target_encountered,
other_agent_encountered,
root_observation[3] + num_steps,
0,
other_agent_same_direction,
other_agent_opposite_direction
]
elif last_isTerminal:
observation = [0,
other_target_encountered,
other_agent_encountered,
np.inf,
np.inf,
other_agent_same_direction,
other_agent_opposite_direction
]
else:
observation = [0,
other_target_encountered,
other_agent_encountered,
root_observation[3] + num_steps,
self.distance_map[handle, position[0], position[1], direction],
other_agent_same_direction,
other_agent_opposite_direction
]
"""
if
last_isTarget
:
observation
=
[
own_target_encountered
,
other_target_encountered
,
...
...
@@ -522,41 +480,47 @@ class TreeObsForRailEnv(ObservationBuilder):
if
len
(
branch_visited
)
!=
0
:
visited
=
visited
.
union
(
branch_visited
)
else
:
num_cells_to_fill_in
=
0
pow4
=
1
for
i
in
range
(
self
.
max_depth
-
depth
):
num_cells_to_fill_in
+=
pow4
pow4
*=
4
observation
=
observation
+
([
-
np
.
inf
]
*
self
.
observation_dim
)
*
num_cells_to_fill_in
# no exploring possible, add just cells with infinity
observation
=
observation
+
[
-
np
.
inf
]
*
self
.
_num_cells_to_fill_in
(
self
.
max_depth
-
depth
)
return
observation
,
visited
def
util_print_obs_subtree
(
self
,
tree
,
num_features_per_node
=
8
,
prompt
=
''
,
current_depth
=
0
):
def
util_print_obs_subtree
(
self
,
tree
):
"""
Utility function to pretty-print tree observations returned by this object.
"""
if
len
(
tree
)
<
num_features_per_node
:
pp
=
pprint
.
PrettyPrinter
(
indent
=
4
)
pp
.
pprint
(
self
.
unfold_observation_tree
(
tree
))
def
unfold_observation_tree
(
self
,
tree
,
current_depth
=
0
,
actions_for_display
=
True
):
"""
Utility function to pretty-print tree observations returned by this object.
"""
if
len
(
tree
)
<
self
.
observation_dim
:
return
depth
=
0
tmp
=
len
(
tree
)
/
num_features_per_node
-
1
tmp
=
len
(
tree
)
/
self
.
observation_dim
-
1
pow4
=
4
while
tmp
>
0
:
tmp
-=
pow4
depth
+=
1
pow4
*=
4
prompt_
=
[
'L:'
,
'F:'
,
'R:'
,
'B:'
]
print
(
" "
*
current_depth
+
prompt
,
tree
[
0
:
num_features_per_node
])
child_size
=
(
len
(
tree
)
-
num_features_per_node
)
//
4
for
children
in
range
(
4
):
child_tree
=
tree
[(
num_features_per_node
+
children
*
child_size
):
(
num_features_per_node
+
(
children
+
1
)
*
child_size
)]
self
.
util_print_obs_subtree
(
child_tree
,
num_features_per_node
,
prompt
=
prompt_
[
children
],
current_depth
=
current_depth
+
1
)
unfolded
=
{}
unfolded
[
''
]
=
tree
[
0
:
self
.
observation_dim
]
child_size
=
(
len
(
tree
)
-
self
.
observation_dim
)
//
4
for
child
in
range
(
4
):
child_tree
=
tree
[(
self
.
observation_dim
+
child
*
child_size
):
(
self
.
observation_dim
+
(
child
+
1
)
*
child_size
)]
observation_tree
=
self
.
unfold_observation_tree
(
child_tree
,
current_depth
=
current_depth
+
1
)
if
observation_tree
is
not
None
:
if
actions_for_display
:
label
=
self
.
tree_explorted_actions_char
[
child
]
else
:
label
=
self
.
tree_explored_actions
[
child
]
unfolded
[
label
]
=
observation_tree
return
unfolded
def
_set_env
(
self
,
env
):
self
.
env
=
env
...
...
@@ -725,8 +689,6 @@ class LocalObsForRailEnv(ObservationBuilder):
bitlist
=
[
int
(
digit
)
for
digit
in
bin
(
self
.
env
.
rail
.
get_transitions
((
i
,
j
)))[
2
:]]
bitlist
=
[
0
]
*
(
16
-
len
(
bitlist
))
+
bitlist
self
.
rail_obs
[
i
+
self
.
view_radius
,
j
+
self
.
view_radius
]
=
np
.
array
(
bitlist
)
# self.rail_obs[i+self.view_radius, j+self.view_radius] = np.array(
# list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int)
def
get
(
self
,
handle
):
agents
=
self
.
env
.
agents
...
...
flatland/envs/rail_env.py
View file @
c7d695c4
...
...
@@ -19,12 +19,22 @@ from flatland.envs.observations import TreeObsForRailEnv
class
RailEnvActions
(
IntEnum
):
DO_NOTHING
=
0
DO_NOTHING
=
0
# implies change of direction in a dead-end!
MOVE_LEFT
=
1
MOVE_FORWARD
=
2
MOVE_RIGHT
=
3
STOP_MOVING
=
4
@
staticmethod
def
to_char
(
a
:
int
):
return
{
0
:
'B'
,
1
:
'L'
,
2
:
'F'
,
3
:
'R'
,
4
:
'S'
,
}[
a
]
class
RailEnv
(
Environment
):
"""
...
...
tests/simple_rail.py
0 → 100644
View file @
c7d695c4
import
numpy
as
np
from
flatland.core.grid.grid4
import
Grid4Transitions
from
flatland.core.transition_map
import
GridTransitionMap
def
make_simple_rail
():
# We instantiate a very simple rail network on a 7x10 grid:
# |
# |
# |
# _ _ _ /_\ _ _ _ _ _ _
# \ /
# |
# |
# |
cells
=
[
int
(
'0000000000000000'
,
2
),
# empty cell - Case 0
int
(
'1000000000100000'
,
2
),
# Case 1 - straight
int
(
'1001001000100000'
,
2
),
# Case 2 - simple switch
int
(
'1000010000100001'
,
2
),
# Case 3 - diamond drossing
int
(
'1001011000100001'
,
2
),
# Case 4 - single slip switch
int
(
'1100110000110011'
,
2
),
# Case 5 - double slip switch
int
(
'0101001000000010'
,
2
),
# Case 6 - symmetrical switch
int
(
'0010000000000000'
,
2
)]
# Case 7 - dead end
transitions
=
Grid4Transitions
([])
empty
=
cells
[
0
]
dead_end_from_south
=
cells
[
7
]
dead_end_from_west
=
transitions
.
rotate_transition
(
dead_end_from_south
,
90
)
dead_end_from_north
=
transitions
.
rotate_transition
(
dead_end_from_south
,
180
)
dead_end_from_east
=
transitions
.
rotate_transition
(
dead_end_from_south
,
270
)
vertical_straight
=
cells
[
1
]
horizontal_straight
=
transitions
.
rotate_transition
(
vertical_straight
,
90
)
double_switch_south_horizontal_straight
=
horizontal_straight
+
cells
[
6
]
double_switch_north_horizontal_straight
=
transitions
.
rotate_transition
(
double_switch_south_horizontal_straight
,
180
)
rail_map
=
np
.
array
(
[[
empty
]
*
3
+
[
dead_end_from_south
]
+
[
empty
]
*
6
]
+
[[
empty
]
*
3
+
[
vertical_straight
]
+
[
empty
]
*
6
]
*
2
+
[[
dead_end_from_east
]
+
[
horizontal_straight
]
*
2
+
[
double_switch_north_horizontal_straight
]
+
[
horizontal_straight
]
*
2
+
[
double_switch_south_horizontal_straight
]
+
[
horizontal_straight
]
*
2
+
[
dead_end_from_west
]]
+
[[
empty
]
*
6
+
[
vertical_straight
]
+
[
empty
]
*
3
]
*
2
+
[[
empty
]
*
6
+
[
dead_end_from_north
]
+
[
empty
]
*
3
],
dtype
=
np
.
uint16
)
rail
=
GridTransitionMap
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
transitions
=
transitions
)
rail
.
grid
=
rail_map
return
rail
,
rail_map
tests/test_flatland_envs_observations.py
View file @
c7d695c4
...
...
@@ -3,62 +3,17 @@
import
numpy
as
np
from
flatland.core.transition_map
import
GridTransitionMap
,
Grid4Transitions
from
flatland.envs.generators
import
rail_from_GridTransitionMap_generator
from
flatland.envs.observations
import
GlobalObsForRailEnv
from
flatland.envs.rail_env
import
RailEnv
from
simple_rail
import
make_simple_rail
"""Tests for `flatland` package."""
def
test_global_obs
():
# We instantiate a very simple rail network on a 7x10 grid:
# |
# |
# |
# _ _ _ /_\ _ _ _ _ _ _
# \ /
# |
# |
# |
rail
,
rail_map
=
make_simple_rail
()
cells
=
[
int
(
'0000000000000000'
,
2
),
# empty cell - Case 0
int
(
'1000000000100000'
,
2
),
# Case 1 - straight
int
(
'1001001000100000'
,
2
),
# Case 2 - simple switch
int
(
'1000010000100001'
,
2
),
# Case 3 - diamond drossing
int
(
'1001011000100001'
,
2
),
# Case 4 - single slip switch
int
(
'1100110000110011'
,
2
),
# Case 5 - double slip switch
int
(
'0101001000000010'
,
2
),
# Case 6 - symmetrical switch
int
(
'0010000000000000'
,
2
)]
# Case 7 - dead end
transitions
=
Grid4Transitions
([])
empty
=
cells
[
0
]
dead_end_from_south
=
cells
[
7
]
dead_end_from_west
=
transitions
.
rotate_transition
(
dead_end_from_south
,
90
)
dead_end_from_north
=
transitions
.
rotate_transition
(
dead_end_from_south
,
180
)
dead_end_from_east
=
transitions
.
rotate_transition
(
dead_end_from_south
,
270
)
vertical_straight
=
cells
[
1
]
horizontal_straight
=
transitions
.
rotate_transition
(
vertical_straight
,
90
)
double_switch_south_horizontal_straight
=
horizontal_straight
+
cells
[
6
]
double_switch_north_horizontal_straight
=
transitions
.
rotate_transition
(
double_switch_south_horizontal_straight
,
180
)
rail_map
=
np
.
array
(
[[
empty
]
*
3
+
[
dead_end_from_south
]
+
[
empty
]
*
6
]
+
[[
empty
]
*
3
+
[
vertical_straight
]
+
[
empty
]
*
6
]
*
2
+
[[
dead_end_from_east
]
+
[
horizontal_straight
]
*
2
+
[
double_switch_north_horizontal_straight
]
+
[
horizontal_straight
]
*
2
+
[
double_switch_south_horizontal_straight
]
+
[
horizontal_straight
]
*
2
+
[
dead_end_from_west
]]
+
[[
empty
]
*
6
+
[
vertical_straight
]
+
[
empty
]
*
3
]
*
2
+
[[
empty
]
*
6
+
[
dead_end_from_north
]
+
[
empty
]
*
3
],
dtype
=
np
.
uint16
)
rail
=
GridTransitionMap
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
transitions
=
transitions
)
rail
.
grid
=
rail_map
env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_GridTransitionMap_generator
(
rail
),
...
...
tests/test_flatland_envs_predictions.py
View file @
c7d695c4
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import
pprint
import
numpy
as
np
from
flatland.core.grid.grid4
import
Grid4TransitionsEnum
from
flatland.core.transition_map
import
GridTransitionMap
,
Grid4Transitions
from
flatland.envs.generators
import
rail_from_GridTransitionMap_generator
from
flatland.envs.observations
import
TreeObsForRailEnv
from
flatland.envs.predictions
import
DummyPredictorForRailEnv
,
ShortestPathPredictorForRailEnv
from
flatland.envs.rail_env
import
RailEnv
from
flatland.envs.rail_env
import
RailEnvActions
from
flatland.utils.rendertools
import
RenderTool
from
simple_rail
import
make_simple_rail
"""Test predictions for `flatland` package."""
def
make_simple_rail
():
# We instantiate a very simple rail network on a 7x10 grid:
# |
# |
# |
# _ _ _ /_\ _ _ _ _ _ _
# \ /
# |
# |
# |
cells
=
[
int
(
'0000000000000000'
,
2
),
# empty cell - Case 0
int
(
'1000000000100000'
,
2
),
# Case 1 - straight
int
(
'1001001000100000'
,
2
),
# Case 2 - simple switch
int
(
'1000010000100001'
,
2
),
# Case 3 - diamond drossing
int
(
'1001011000100001'
,
2
),
# Case 4 - single slip switch
int
(
'1100110000110011'
,
2
),
# Case 5 - double slip switch
int
(
'0101001000000010'
,
2
),
# Case 6 - symmetrical switch
int
(
'0010000000000000'
,
2
)]
# Case 7 - dead end
transitions
=
Grid4Transitions
([])
empty
=
cells
[
0
]
dead_end_from_south
=
cells
[
7
]
dead_end_from_west
=
transitions
.
rotate_transition
(
dead_end_from_south
,
90
)
dead_end_from_north
=
transitions
.
rotate_transition
(
dead_end_from_south
,
180
)
dead_end_from_east
=
transitions
.
rotate_transition
(
dead_end_from_south
,
270
)
vertical_straight
=
cells
[
1
]
horizontal_straight
=
transitions
.
rotate_transition
(
vertical_straight
,
90
)
double_switch_south_horizontal_straight
=
horizontal_straight
+
cells
[
6
]
double_switch_north_horizontal_straight
=
transitions
.
rotate_transition
(
double_switch_south_horizontal_straight
,
180
)
rail_map
=
np
.
array
(
[[
empty
]
*
3
+
[
dead_end_from_south
]
+
[
empty
]
*
6
]
+
[[
empty
]
*
3
+
[
vertical_straight
]
+
[
empty
]
*
6
]
*
2
+
[[
dead_end_from_east
]
+
[
horizontal_straight
]
*
2
+
[
double_switch_north_horizontal_straight
]
+
[
horizontal_straight
]
*
2
+
[
double_switch_south_horizontal_straight
]
+
[
horizontal_straight
]
*
2
+
[
dead_end_from_west
]]
+
[[
empty
]
*
6
+
[
vertical_straight
]
+
[
empty
]
*
3
]
*
2
+
[[
empty
]
*
6
+
[
dead_end_from_north
]
+
[
empty
]
*
3
],
dtype
=
np
.
uint16
)
rail
=
GridTransitionMap
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
transitions
=
transitions
)
rail
.
grid
=
rail_map
return
rail
,
rail_map
def
test_dummy_predictor
(
rendering
=
False
):
rail
,
rail_map
=
make_simple_rail
()
...
...
@@ -68,12 +25,16 @@ def test_dummy_predictor(rendering=False):
number_of_agents
=
1
,
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
DummyPredictorForRailEnv
(
max_depth
=
10
)),
)
# reset to initialize agents_static
env
.
reset
()
# set initial position and direction for testing...
env
.
agents
[
0
].
position
=
(
5
,
6
)
env
.
agents
[
0
].
direction
=
0
env
.
agents
[
0
].
target
=
(
3
,
0
)
env
.
agents_static
[
0
].
position
=
(
5
,
6
)
env
.
agents_static
[
0
].
direction
=
0
env
.
agents_static
[
0
].
target
=
(
3
,
0
)
# reset to set agents from agents_static
env
.
reset
(
False
,
False
)
if
rendering
:
renderer
=
RenderTool
(
env
,
gl
=
"PILSVG"
)
...
...
@@ -154,41 +115,39 @@ def test_shortest_path_predictor(rendering=False):
number_of_agents
=
1
,
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
ShortestPathPredictorForRailEnv
()),
)
# reset to initialize agents_static
env
.
reset
()
agent
=
env
.
agents
[
0
]
# set the initial position
agent
=
env
.
agents_static
[
0
]
agent
.
position
=
(
5
,
6
)
# south dead-end
agent
.
direction
=
0
# north
agent
.
target
=
(
3
,
9
)
# east dead-end
agent
.
moving
=
True
# reset to set agents from agents_static
env
.
reset
(
False
,
False
)
if
rendering
:
renderer
=
RenderTool
(
env
,
gl
=
"PILSVG"
)
renderer
.
renderEnv
(
show
=
True
,
show_observations
=
False
)
input
(
"Continue?"
)
agent
=
env
.
agents
[
0
]
assert
agent
.
position
==
(
5
,
6
)
assert
agent
.
direction
==
0
assert
agent
.
target
==
(
3
,
9
)
assert
agent
.
moving
env
.
obs_builder
.
_compute_distance_map
()
# compute the observations and predictions
distance_map
=
env
.
obs_builder
.
distance_map
assert
distance_map
[
agent
.
handle
,
agent
.
position
[
0
],
agent
.
position
[
assert
distance_map
[
0
,
agent
.
position
[
0
],
agent
.
position
[
1
],
agent
.
direction
]
==
5.0
,
"found {} instead of {}"
.
format
(
distance_map
[
agent
.
handle
,
agent
.
position
[
0
],
agent
.
position
[
1
],
agent
.
direction
],
5.0
)
# test assertions
env
.
obs_builder
.
get_many
()
# extract the data
predictions
=
env
.
obs_builder
.
predictions
positions
=
np
.
array
(
list
(
map
(
lambda
prediction
:
[
*
prediction
[
1
:
3
]],
predictions
[
0
])))
directions
=
np
.
array
(
list
(
map
(
lambda
prediction
:
[
prediction
[
3
]],
predictions
[
0
])))
time_offsets
=
np
.
array
(
list
(
map
(
lambda
prediction
:
[
prediction
[
0
]],
predictions
[
0
])))
actions
=
np
.
array
(
list
(
map
(
lambda
prediction
:
[
prediction
[
4
]],
predictions
[
0
])))
# test if data meets expectations
expected_positions
=
[
[
5
,
6
],
[
4
,
6
],
...
...
@@ -292,3 +251,60 @@ def test_shortest_path_predictor(rendering=False):
"time_offsets {}, expected {}"
.
format
(
time_offsets
,
expected_time_offsets
)
assert
np
.
array_equal
(
actions
,
expected_actions
),
\
"actions {}, expected {}"
.
format
(
actions
,
expected_actions
)
def
test_shortest_path_predictor_conflicts
(
rendering
=
False
):
rail
,
rail_map
=
make_simple_rail
()
env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_GridTransitionMap_generator
(
rail
),
number_of_agents
=
2
,
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
ShortestPathPredictorForRailEnv
()),
)
# initialize agents_static
env
.
reset
()
# set the initial position
agent
=
env
.
agents_static
[
0
]
agent
.
position
=
(
5
,
6
)
# south dead-end
agent
.
direction
=
0
# north
agent
.
target
=
(
3
,
9
)
# east dead-end
agent
.
moving
=
True
agent
=
env
.
agents_static
[
1
]
agent
.
position
=
(
3
,
8
)
# east dead-end
agent
.
direction
=
3
# west
agent
.
target
=
(
6
,
6
)
# south dead-end
agent
.
moving
=
True
# reset to set agents from agents_static
observations
=
env
.
reset
(
False
,
False
)
if
rendering
:
renderer
=
RenderTool
(
env
,
gl
=
"PILSVG"
)
renderer
.
renderEnv
(
show
=
True
,
show_observations
=
False
)
input
(
"Continue?"
)
# get the trees to test
obs_builder
:
TreeObsForRailEnv
=
env
.
obs_builder
pp
=
pprint
.
PrettyPrinter
(
indent
=
4
)
tree_0
=
obs_builder
.
unfold_observation_tree
(
observations
[
0
])
tree_1
=
obs_builder
.
unfold_observation_tree
(
observations
[
1
])
pp
.
pprint
(
tree_0
)
# check the expectations
# TODO check with Erik, this should be symmetric, should it not?
expected_conflicts_0
=
[(
'F'
,
'R'
),
(
'F'
,
'L'
)]
expected_conflicts_1
=
[(
'F'
),
(
'F'
,
'L'
)]
_check_expected_conflicts
(
expected_conflicts_0
,
obs_builder
,
tree_0
,
"agent[0]: "
)
_check_expected_conflicts
(
expected_conflicts_1
,
obs_builder
,
tree_1
,
"agent[1]: "
)
def
_check_expected_conflicts
(
expected_conflicts
,
obs_builder
,
tree_0
,
prompt
=
''
):
assert
(
tree_0
[
''
][
7
]
>
0
)
==
(()
in
expected_conflicts
),
"{}[]"
.
format
(
prompt
)
for
a_1
in
obs_builder
.
tree_explorted_actions_char
:
conflict
=
tree_0
[
a_1
][
''
][
7
]
assert
(
conflict
>
0
)
==
((
a_1
)
in
expected_conflicts
),
"{}[{}]"
.
format
(
prompt
,
a_1
)
for
a_2
in
obs_builder
.
tree_explorted_actions_char
:
conflict
=
tree_0
[
a_1
][
a_2
][
''
][
7
]
assert
(
conflict
>
0
)
==
((
a_1
,
a_2
)
in
expected_conflicts
),
"{}[{}][{}]"
.
format
(
prompt
,
a_1
,
a_2
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment