Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
Flatland
Commits
8176e2fe
Commit
8176e2fe
authored
Jul 03, 2019
by
Erik Nygren
Browse files
updated comment to make tree observation more understandable
parent
30a149a5
Pipeline
#1318
failed with stage
in 6 minutes and 21 seconds
Changes
3
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
examples/training_example.py
View file @
8176e2fe
...
...
@@ -76,6 +76,8 @@ for trials in range(1, n_trials + 1):
for
a
in
range
(
env
.
get_num_agents
()):
action
=
agent
.
act
(
obs
[
a
])
action_dict
.
update
({
a
:
action
})
# Uncomment next line to print observation of an agent
# TreeObservation.util_print_obs_subtree((obs[a]))
# Environment step which returns the observations for all agents, their corresponding
# reward and whether their are done
...
...
flatland/envs/observations.py
View file @
8176e2fe
...
...
@@ -28,7 +28,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for
i
in
range
(
self
.
max_depth
+
1
):
size
+=
pow4
pow4
*=
4
self
.
observation_dim
=
8
self
.
observation_dim
=
9
self
.
observation_space
=
[
size
*
self
.
observation_dim
]
self
.
location_has_agent
=
{}
self
.
location_has_agent_direction
=
{}
...
...
@@ -223,24 +223,29 @@ class TreeObsForRailEnv(ObservationBuilder):
#3: if another agent is detected the distance in number of cells from current agent position is stored.
#4: This feature stores the distance in number of cells to the next branching store (current node)
#4: possible conflict detected
tot_dist = Other agent predicts to pass along this cell at the same time as the agent, we store the
distance in number of cells from current agent position
#5: minimum distance from node to the agent's target given the direction of the agent if this path is chosen
0 = No other agent reserve the same cell at similar time
#5: if an not usable switch (for agent) is detected we store the distance.
#6: This feature stores the distance in number of cells to the next branching (current node)
#7: minimum distance from node to the agent's target given the direction of the agent if this path is chosen
#
6
: agent in the same direction
#
8
: agent in the same direction
n = number of agents present same direction
(possible future use: number of other agents in the same direction in this branch)
0 = no agent present same direction
#
7
: agent in the opposite drection
#
9
: agent in the opposite drection
n = number of agents present other direction than myself (so conflict)
(possible future use: number of other agents in other direction in this branch, ie. number of conflicts)
0 = no agent present other direction than myself
#8: possible conflict detected
1 = Other agent predicts to pass along this cell at the same time as the agent
0 = No other agent reserve the same cell at similar time
Missing/padding nodes are filled in with -inf (truncated).
...
...
@@ -261,7 +266,7 @@ class TreeObsForRailEnv(ObservationBuilder):
num_transitions
=
np
.
count_nonzero
(
possible_transitions
)
# Root node - current position
observation
=
[
0
,
0
,
0
,
0
,
self
.
distance_map
[(
handle
,
*
agent
.
position
,
agent
.
direction
)],
0
,
0
,
0
]
observation
=
[
0
,
0
,
0
,
0
,
0
,
0
,
self
.
distance_map
[(
handle
,
*
agent
.
position
,
agent
.
direction
)],
0
,
0
]
root_observation
=
observation
[:]
visited
=
set
()
...
...
@@ -294,6 +299,8 @@ class TreeObsForRailEnv(ObservationBuilder):
def
_explore_branch
(
self
,
handle
,
position
,
direction
,
root_observation
,
tot_dist
,
depth
):
"""
Utility function to compute tree-based observations.
We walk along the branch and collect the information documented in the get() function.
If there is a branching point a new node is created and each possible branch is explored.
"""
# [Recursive branch opened]
if
depth
>=
self
.
max_depth
+
1
:
...
...
@@ -313,9 +320,11 @@ class TreeObsForRailEnv(ObservationBuilder):
own_target_encountered
=
np
.
inf
other_agent_encountered
=
np
.
inf
other_target_encountered
=
np
.
inf
potential_conflict
=
np
.
inf
unusable_switch
=
np
.
inf
other_agent_same_direction
=
0
other_agent_opposite_direction
=
0
potential_conflict
=
0
num_steps
=
1
while
exploring
:
# #############################
...
...
@@ -323,8 +332,8 @@ class TreeObsForRailEnv(ObservationBuilder):
# Modify here to compute any useful data required to build the end node's features. This code is called
# for each cell visited between the previous branching node and the next switch / target / dead-end.
if
position
in
self
.
location_has_agent
:
if
num_steps
<
other_agent_encountered
:
other_agent_encountered
=
num_steps
if
tot_dist
<
other_agent_encountered
:
other_agent_encountered
=
tot_dist
if
self
.
location_has_agent_direction
[
position
]
==
direction
:
# Cummulate the number of agents on branch with same direction
...
...
@@ -345,28 +354,28 @@ class TreeObsForRailEnv(ObservationBuilder):
if
int_position
in
np
.
delete
(
self
.
predicted_pos
[
tot_dist
],
handle
):
conflicting_agent
=
np
.
where
(
np
.
delete
(
self
.
predicted_pos
[
tot_dist
],
handle
)
==
int_position
)
for
ca
in
conflicting_agent
:
if
direction
!=
self
.
predicted_dir
[
tot_dist
][
ca
[
0
]]:
potential_conflict
=
1
if
direction
!=
self
.
predicted_dir
[
tot_dist
][
ca
[
0
]]
and
tot_dist
<
potential_conflict
:
potential_conflict
=
tot_dist
# Look for opposing paths at distance num_step-1
elif
int_position
in
np
.
delete
(
self
.
predicted_pos
[
pre_step
],
handle
):
conflicting_agent
=
np
.
where
(
self
.
predicted_pos
[
pre_step
]
==
int_position
)
for
ca
in
conflicting_agent
:
if
direction
!=
self
.
predicted_dir
[
pre_step
][
ca
[
0
]]:
potential_conflict
=
1
if
direction
!=
self
.
predicted_dir
[
pre_step
][
ca
[
0
]]
and
tot_dist
<
potential_conflict
:
potential_conflict
=
tot_dist
# Look for opposing paths at distance num_step+1
elif
int_position
in
np
.
delete
(
self
.
predicted_pos
[
post_step
],
handle
):
conflicting_agent
=
np
.
where
(
np
.
delete
(
self
.
predicted_pos
[
post_step
],
handle
)
==
int_position
)
for
ca
in
conflicting_agent
:
if
direction
!=
self
.
predicted_dir
[
post_step
][
ca
[
0
]]:
potential_conflict
=
1
if
direction
!=
self
.
predicted_dir
[
post_step
][
ca
[
0
]]
and
tot_dist
<
potential_conflict
:
potential_conflict
=
tot_dist
if
position
in
self
.
location_has_target
and
position
!=
agent
.
target
:
if
num_steps
<
other_target_encountered
:
other_target_encountered
=
num_steps
if
tot_dist
<
other_target_encountered
:
other_target_encountered
=
tot_dist
if
position
==
agent
.
target
:
if
num_steps
<
own_target_encountered
:
own_target_encountered
=
num_steps
if
tot_dist
<
own_target_encountered
:
own_target_encountered
=
tot_dist
# #############################
# #############################
...
...
@@ -382,8 +391,13 @@ class TreeObsForRailEnv(ObservationBuilder):
break
cell_transitions
=
self
.
env
.
rail
.
get_transitions
((
*
position
,
direction
))
total_transitions
=
bin
(
self
.
env
.
rail
.
get_transitions
(
position
)).
count
(
"1"
)
num_transitions
=
np
.
count_nonzero
(
cell_transitions
)
exploring
=
False
# Detect Switches that can only be used by other agents.
if
total_transitions
>
2
>
num_transitions
:
unusable_switch
=
tot_dist
if
num_transitions
==
1
:
# Check if dead-end, or if we can go forward along direction
nbits
=
0
...
...
@@ -462,32 +476,35 @@ class TreeObsForRailEnv(ObservationBuilder):
observation
=
[
own_target_encountered
,
other_target_encountered
,
other_agent_encountered
,
root_observation
[
3
]
+
num_steps
,
potential_conflict
,
unusable_switch
,
tot_dist
,
0
,
other_agent_same_direction
,
other_agent_opposite_direction
,
potential_conflict
other_agent_opposite_direction
]
elif
last_isTerminal
:
observation
=
[
own_target_encountered
,
other_target_encountered
,
other_agent_encountered
,
potential_conflict
,
unusable_switch
,
np
.
inf
,
np
.
inf
,
self
.
distance_map
[
handle
,
position
[
0
],
position
[
1
],
direction
]
,
other_agent_same_direction
,
other_agent_opposite_direction
,
potential_conflict
other_agent_opposite_direction
]
else
:
observation
=
[
own_target_encountered
,
other_target_encountered
,
other_agent_encountered
,
root_observation
[
3
]
+
num_steps
,
potential_conflict
,
unusable_switch
,
tot_dist
,
self
.
distance_map
[
handle
,
position
[
0
],
position
[
1
],
direction
],
other_agent_same_direction
,
other_agent_opposite_direction
,
potential_conflict
]
# #############################
# #############################
...
...
@@ -531,7 +548,7 @@ class TreeObsForRailEnv(ObservationBuilder):
return
observation
,
visited
def
util_print_obs_subtree
(
self
,
tree
,
num_features_per_node
=
8
,
prompt
=
''
,
current_depth
=
0
):
def
util_print_obs_subtree
(
self
,
tree
,
num_features_per_node
=
9
,
prompt
=
''
,
current_depth
=
0
):
"""
Utility function to pretty-print tree observations returned by this object.
"""
...
...
flatland/utils/rendertools.py
View file @
8176e2fe
...
...
@@ -38,7 +38,7 @@ class RenderTool(object):
gTheta
=
np
.
linspace
(
0
,
np
.
pi
/
2
,
5
)
gArc
=
array
([
np
.
cos
(
gTheta
),
np
.
sin
(
gTheta
)]).
T
# from [1,0] to [0,1]
def
__init__
(
self
,
env
,
gl
=
"PILSVG"
,
jupyter
=
False
,
agentRenderVariant
=
AgentRenderVariant
.
AGENT_SHOWS_OPTIONS
):
def
__init__
(
self
,
env
,
gl
=
"PILSVG"
,
jupyter
=
False
,
agentRenderVariant
=
AgentRenderVariant
.
ONE_STEP_BEHIND
):
self
.
env
=
env
self
.
iFrame
=
0
self
.
time1
=
time
.
time
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment