Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Flatland
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
elrichgro
Flatland
Commits
8176e2fe
Commit
8176e2fe
authored
5 years ago
by
Erik Nygren
Browse files
Options
Downloads
Patches
Plain Diff
updated comment to make tree observation more understandable
parent
30a149a5
No related branches found
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
examples/training_example.py
+2
-0
2 additions, 0 deletions
examples/training_example.py
flatland/envs/observations.py
+48
-31
48 additions, 31 deletions
flatland/envs/observations.py
flatland/utils/rendertools.py
+1
-1
1 addition, 1 deletion
flatland/utils/rendertools.py
with
51 additions
and
32 deletions
examples/training_example.py
+
2
−
0
View file @
8176e2fe
...
@@ -76,6 +76,8 @@ for trials in range(1, n_trials + 1):
...
@@ -76,6 +76,8 @@ for trials in range(1, n_trials + 1):
for
a
in
range
(
env
.
get_num_agents
()):
for
a
in
range
(
env
.
get_num_agents
()):
action
=
agent
.
act
(
obs
[
a
])
action
=
agent
.
act
(
obs
[
a
])
action_dict
.
update
({
a
:
action
})
action_dict
.
update
({
a
:
action
})
# Uncomment next line to print observation of an agent
# TreeObservation.util_print_obs_subtree((obs[a]))
# Environment step which returns the observations for all agents, their corresponding
# Environment step which returns the observations for all agents, their corresponding
# reward and whether their are done
# reward and whether their are done
...
...
This diff is collapsed.
Click to expand it.
flatland/envs/observations.py
+
48
−
31
View file @
8176e2fe
...
@@ -28,7 +28,7 @@ class TreeObsForRailEnv(ObservationBuilder):
...
@@ -28,7 +28,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for
i
in
range
(
self
.
max_depth
+
1
):
for
i
in
range
(
self
.
max_depth
+
1
):
size
+=
pow4
size
+=
pow4
pow4
*=
4
pow4
*=
4
self
.
observation_dim
=
8
self
.
observation_dim
=
9
self
.
observation_space
=
[
size
*
self
.
observation_dim
]
self
.
observation_space
=
[
size
*
self
.
observation_dim
]
self
.
location_has_agent
=
{}
self
.
location_has_agent
=
{}
self
.
location_has_agent_direction
=
{}
self
.
location_has_agent_direction
=
{}
...
@@ -223,24 +223,29 @@ class TreeObsForRailEnv(ObservationBuilder):
...
@@ -223,24 +223,29 @@ class TreeObsForRailEnv(ObservationBuilder):
#3: if another agent is detected the distance in number of cells from current agent position is stored.
#3: if another agent is detected the distance in number of cells from current agent position is stored.
#4: This feature stores the distance in number of cells to the next branching store (current node)
#4: possible conflict detected
tot_dist = Other agent predicts to pass along this cell at the same time as the agent, we store the
distance in number of cells from current agent position
#5: minimum distance from node to the agent
'
s target given the direction of the agent if this path is chosen
0 = No other agent reserve the same cell at similar time
#5: if an not usable switch (for agent) is detected we store the distance.
#6: This feature stores the distance in number of cells to the next branching (current node)
#7: minimum distance from node to the agent
'
s target given the direction of the agent if this path is chosen
#
6
: agent in the same direction
#
8
: agent in the same direction
n = number of agents present same direction
n = number of agents present same direction
(possible future use: number of other agents in the same direction in this branch)
(possible future use: number of other agents in the same direction in this branch)
0 = no agent present same direction
0 = no agent present same direction
#
7
: agent in the opposite drection
#
9
: agent in the opposite drection
n = number of agents present other direction than myself (so conflict)
n = number of agents present other direction than myself (so conflict)
(possible future use: number of other agents in other direction in this branch, ie. number of conflicts)
(possible future use: number of other agents in other direction in this branch, ie. number of conflicts)
0 = no agent present other direction than myself
0 = no agent present other direction than myself
#8: possible conflict detected
1 = Other agent predicts to pass along this cell at the same time as the agent
0 = No other agent reserve the same cell at similar time
Missing/padding nodes are filled in with -inf (truncated).
Missing/padding nodes are filled in with -inf (truncated).
...
@@ -261,7 +266,7 @@ class TreeObsForRailEnv(ObservationBuilder):
...
@@ -261,7 +266,7 @@ class TreeObsForRailEnv(ObservationBuilder):
num_transitions
=
np
.
count_nonzero
(
possible_transitions
)
num_transitions
=
np
.
count_nonzero
(
possible_transitions
)
# Root node - current position
# Root node - current position
observation
=
[
0
,
0
,
0
,
0
,
self
.
distance_map
[(
handle
,
*
agent
.
position
,
agent
.
direction
)],
0
,
0
,
0
]
observation
=
[
0
,
0
,
0
,
0
,
0
,
0
,
self
.
distance_map
[(
handle
,
*
agent
.
position
,
agent
.
direction
)],
0
,
0
]
root_observation
=
observation
[:]
root_observation
=
observation
[:]
visited
=
set
()
visited
=
set
()
...
@@ -294,6 +299,8 @@ class TreeObsForRailEnv(ObservationBuilder):
...
@@ -294,6 +299,8 @@ class TreeObsForRailEnv(ObservationBuilder):
def
_explore_branch
(
self
,
handle
,
position
,
direction
,
root_observation
,
tot_dist
,
depth
):
def
_explore_branch
(
self
,
handle
,
position
,
direction
,
root_observation
,
tot_dist
,
depth
):
"""
"""
Utility function to compute tree-based observations.
Utility function to compute tree-based observations.
We walk along the branch and collect the information documented in the get() function.
If there is a branching point a new node is created and each possible branch is explored.
"""
"""
# [Recursive branch opened]
# [Recursive branch opened]
if
depth
>=
self
.
max_depth
+
1
:
if
depth
>=
self
.
max_depth
+
1
:
...
@@ -313,9 +320,11 @@ class TreeObsForRailEnv(ObservationBuilder):
...
@@ -313,9 +320,11 @@ class TreeObsForRailEnv(ObservationBuilder):
own_target_encountered
=
np
.
inf
own_target_encountered
=
np
.
inf
other_agent_encountered
=
np
.
inf
other_agent_encountered
=
np
.
inf
other_target_encountered
=
np
.
inf
other_target_encountered
=
np
.
inf
potential_conflict
=
np
.
inf
unusable_switch
=
np
.
inf
other_agent_same_direction
=
0
other_agent_same_direction
=
0
other_agent_opposite_direction
=
0
other_agent_opposite_direction
=
0
potential_conflict
=
0
num_steps
=
1
num_steps
=
1
while
exploring
:
while
exploring
:
# #############################
# #############################
...
@@ -323,8 +332,8 @@ class TreeObsForRailEnv(ObservationBuilder):
...
@@ -323,8 +332,8 @@ class TreeObsForRailEnv(ObservationBuilder):
# Modify here to compute any useful data required to build the end node's features. This code is called
# Modify here to compute any useful data required to build the end node's features. This code is called
# for each cell visited between the previous branching node and the next switch / target / dead-end.
# for each cell visited between the previous branching node and the next switch / target / dead-end.
if
position
in
self
.
location_has_agent
:
if
position
in
self
.
location_has_agent
:
if
num_steps
<
other_agent_encountered
:
if
tot_dist
<
other_agent_encountered
:
other_agent_encountered
=
num_steps
other_agent_encountered
=
tot_dist
if
self
.
location_has_agent_direction
[
position
]
==
direction
:
if
self
.
location_has_agent_direction
[
position
]
==
direction
:
# Cummulate the number of agents on branch with same direction
# Cummulate the number of agents on branch with same direction
...
@@ -345,28 +354,28 @@ class TreeObsForRailEnv(ObservationBuilder):
...
@@ -345,28 +354,28 @@ class TreeObsForRailEnv(ObservationBuilder):
if
int_position
in
np
.
delete
(
self
.
predicted_pos
[
tot_dist
],
handle
):
if
int_position
in
np
.
delete
(
self
.
predicted_pos
[
tot_dist
],
handle
):
conflicting_agent
=
np
.
where
(
np
.
delete
(
self
.
predicted_pos
[
tot_dist
],
handle
)
==
int_position
)
conflicting_agent
=
np
.
where
(
np
.
delete
(
self
.
predicted_pos
[
tot_dist
],
handle
)
==
int_position
)
for
ca
in
conflicting_agent
:
for
ca
in
conflicting_agent
:
if
direction
!=
self
.
predicted_dir
[
tot_dist
][
ca
[
0
]]:
if
direction
!=
self
.
predicted_dir
[
tot_dist
][
ca
[
0
]]
and
tot_dist
<
potential_conflict
:
potential_conflict
=
1
potential_conflict
=
tot_dist
# Look for opposing paths at distance num_step-1
# Look for opposing paths at distance num_step-1
elif
int_position
in
np
.
delete
(
self
.
predicted_pos
[
pre_step
],
handle
):
elif
int_position
in
np
.
delete
(
self
.
predicted_pos
[
pre_step
],
handle
):
conflicting_agent
=
np
.
where
(
self
.
predicted_pos
[
pre_step
]
==
int_position
)
conflicting_agent
=
np
.
where
(
self
.
predicted_pos
[
pre_step
]
==
int_position
)
for
ca
in
conflicting_agent
:
for
ca
in
conflicting_agent
:
if
direction
!=
self
.
predicted_dir
[
pre_step
][
ca
[
0
]]:
if
direction
!=
self
.
predicted_dir
[
pre_step
][
ca
[
0
]]
and
tot_dist
<
potential_conflict
:
potential_conflict
=
1
potential_conflict
=
tot_dist
# Look for opposing paths at distance num_step+1
# Look for opposing paths at distance num_step+1
elif
int_position
in
np
.
delete
(
self
.
predicted_pos
[
post_step
],
handle
):
elif
int_position
in
np
.
delete
(
self
.
predicted_pos
[
post_step
],
handle
):
conflicting_agent
=
np
.
where
(
np
.
delete
(
self
.
predicted_pos
[
post_step
],
handle
)
==
int_position
)
conflicting_agent
=
np
.
where
(
np
.
delete
(
self
.
predicted_pos
[
post_step
],
handle
)
==
int_position
)
for
ca
in
conflicting_agent
:
for
ca
in
conflicting_agent
:
if
direction
!=
self
.
predicted_dir
[
post_step
][
ca
[
0
]]:
if
direction
!=
self
.
predicted_dir
[
post_step
][
ca
[
0
]]
and
tot_dist
<
potential_conflict
:
potential_conflict
=
1
potential_conflict
=
tot_dist
if
position
in
self
.
location_has_target
and
position
!=
agent
.
target
:
if
position
in
self
.
location_has_target
and
position
!=
agent
.
target
:
if
num_steps
<
other_target_encountered
:
if
tot_dist
<
other_target_encountered
:
other_target_encountered
=
num_steps
other_target_encountered
=
tot_dist
if
position
==
agent
.
target
:
if
position
==
agent
.
target
:
if
num_steps
<
own_target_encountered
:
if
tot_dist
<
own_target_encountered
:
own_target_encountered
=
num_steps
own_target_encountered
=
tot_dist
# #############################
# #############################
# #############################
# #############################
...
@@ -382,8 +391,13 @@ class TreeObsForRailEnv(ObservationBuilder):
...
@@ -382,8 +391,13 @@ class TreeObsForRailEnv(ObservationBuilder):
break
break
cell_transitions
=
self
.
env
.
rail
.
get_transitions
((
*
position
,
direction
))
cell_transitions
=
self
.
env
.
rail
.
get_transitions
((
*
position
,
direction
))
total_transitions
=
bin
(
self
.
env
.
rail
.
get_transitions
(
position
)).
count
(
"
1
"
)
num_transitions
=
np
.
count_nonzero
(
cell_transitions
)
num_transitions
=
np
.
count_nonzero
(
cell_transitions
)
exploring
=
False
exploring
=
False
# Detect Switches that can only be used by other agents.
if
total_transitions
>
2
>
num_transitions
:
unusable_switch
=
tot_dist
if
num_transitions
==
1
:
if
num_transitions
==
1
:
# Check if dead-end, or if we can go forward along direction
# Check if dead-end, or if we can go forward along direction
nbits
=
0
nbits
=
0
...
@@ -462,32 +476,35 @@ class TreeObsForRailEnv(ObservationBuilder):
...
@@ -462,32 +476,35 @@ class TreeObsForRailEnv(ObservationBuilder):
observation
=
[
own_target_encountered
,
observation
=
[
own_target_encountered
,
other_target_encountered
,
other_target_encountered
,
other_agent_encountered
,
other_agent_encountered
,
root_observation
[
3
]
+
num_steps
,
potential_conflict
,
unusable_switch
,
tot_dist
,
0
,
0
,
other_agent_same_direction
,
other_agent_same_direction
,
other_agent_opposite_direction
,
other_agent_opposite_direction
potential_conflict
]
]
elif
last_isTerminal
:
elif
last_isTerminal
:
observation
=
[
own_target_encountered
,
observation
=
[
own_target_encountered
,
other_target_encountered
,
other_target_encountered
,
other_agent_encountered
,
other_agent_encountered
,
potential_conflict
,
unusable_switch
,
np
.
inf
,
np
.
inf
,
np
.
inf
,
self
.
distance_map
[
handle
,
position
[
0
],
position
[
1
],
direction
]
,
other_agent_same_direction
,
other_agent_same_direction
,
other_agent_opposite_direction
,
other_agent_opposite_direction
potential_conflict
]
]
else
:
else
:
observation
=
[
own_target_encountered
,
observation
=
[
own_target_encountered
,
other_target_encountered
,
other_target_encountered
,
other_agent_encountered
,
other_agent_encountered
,
root_observation
[
3
]
+
num_steps
,
potential_conflict
,
unusable_switch
,
tot_dist
,
self
.
distance_map
[
handle
,
position
[
0
],
position
[
1
],
direction
],
self
.
distance_map
[
handle
,
position
[
0
],
position
[
1
],
direction
],
other_agent_same_direction
,
other_agent_same_direction
,
other_agent_opposite_direction
,
other_agent_opposite_direction
,
potential_conflict
]
]
# #############################
# #############################
# #############################
# #############################
...
@@ -531,7 +548,7 @@ class TreeObsForRailEnv(ObservationBuilder):
...
@@ -531,7 +548,7 @@ class TreeObsForRailEnv(ObservationBuilder):
return
observation
,
visited
return
observation
,
visited
def
util_print_obs_subtree
(
self
,
tree
,
num_features_per_node
=
8
,
prompt
=
''
,
current_depth
=
0
):
def
util_print_obs_subtree
(
self
,
tree
,
num_features_per_node
=
9
,
prompt
=
''
,
current_depth
=
0
):
"""
"""
Utility function to pretty-print tree observations returned by this object.
Utility function to pretty-print tree observations returned by this object.
"""
"""
...
...
This diff is collapsed.
Click to expand it.
flatland/utils/rendertools.py
+
1
−
1
View file @
8176e2fe
...
@@ -38,7 +38,7 @@ class RenderTool(object):
...
@@ -38,7 +38,7 @@ class RenderTool(object):
gTheta
=
np
.
linspace
(
0
,
np
.
pi
/
2
,
5
)
gTheta
=
np
.
linspace
(
0
,
np
.
pi
/
2
,
5
)
gArc
=
array
([
np
.
cos
(
gTheta
),
np
.
sin
(
gTheta
)]).
T
# from [1,0] to [0,1]
gArc
=
array
([
np
.
cos
(
gTheta
),
np
.
sin
(
gTheta
)]).
T
# from [1,0] to [0,1]
def
__init__
(
self
,
env
,
gl
=
"
PILSVG
"
,
jupyter
=
False
,
agentRenderVariant
=
AgentRenderVariant
.
AGENT_SHOWS_OPTIONS
):
def
__init__
(
self
,
env
,
gl
=
"
PILSVG
"
,
jupyter
=
False
,
agentRenderVariant
=
AgentRenderVariant
.
ONE_STEP_BEHIND
):
self
.
env
=
env
self
.
env
=
env
self
.
iFrame
=
0
self
.
iFrame
=
0
self
.
time1
=
time
.
time
()
self
.
time1
=
time
.
time
()
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment