Flatland · Commits

Commit 4de49a72
authored Jul 05, 2019 by u214892

#62 increase coverage
#83 cleanup

parent 15a725f0

Changes: 4 files
flatland/core/transition_map.py

@@ -336,8 +336,4 @@ class GridTransitionMap(TransitionMap):
         return True

-    # TODO: GIACOMO: is it better to provide those methods with lists of cell_ids
-    # (most general implementation) or to make Grid-class specific methods for
-    # slicing over the 3 dimensions? I'd say both perhaps.
-    # TODO: override __getitem__ and __setitem__ (cell contents, not transitions?)
+    # TODO: improvement override __getitem__ and __setitem__ (cell contents, not transitions?)
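The surviving TODO asks for __getitem__/__setitem__ overrides for cell contents. A minimal sketch of what that could look like, assuming the map stores its cells in a 2-D NumPy array attribute named grid (the class name, attribute name, and integer cell encoding here are illustrative assumptions, not taken from this commit):

import numpy as np

class CellIndexableMap:
    # Sketch only: a grid of integer-encoded cells, indexable by (row, column).
    def __init__(self, width, height):
        self.grid = np.zeros((height, width), dtype=np.uint16)

    def __getitem__(self, cell_id):
        # cell_id is a (row, column) tuple; returns the raw cell encoding,
        # not the per-direction transition bits.
        return self.grid[cell_id]

    def __setitem__(self, cell_id, value):
        self.grid[cell_id] = value

With this in place, rail_map[2, 3] would read a cell and rail_map[2, 3] = 0 would clear it.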
flatland/envs/observations.py

@@ -23,6 +23,7 @@ class TreeObsForRailEnv(ObservationBuilder):
     observation_dim = 9

     def __init__(self, max_depth, predictor=None):
         super().__init__()
         self.max_depth = max_depth
         # Compute the size of the returned observation vector
@@ -41,15 +42,14 @@ class TreeObsForRailEnv(ObservationBuilder):
     def reset(self):
         agents = self.env.agents
-        nAgents = len(agents)
+        nb_agents = len(agents)
         compute_distance_map = True
-        if self.agents_previous_reset is not None:
-            if nAgents == len(self.agents_previous_reset):
-                compute_distance_map = False
-                for i in range(nAgents):
-                    if agents[i].target != self.agents_previous_reset[i].target:
-                        compute_distance_map = True
+        if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset):
+            compute_distance_map = False
+            for i in range(nb_agents):
+                if agents[i].target != self.agents_previous_reset[i].target:
+                    compute_distance_map = True
         self.agents_previous_reset = agents
         if compute_distance_map:
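The reset refactor folds the nested guards into a single and-condition while keeping the cache-invalidation rule: the distance map is recomputed only when the roster size or any agent's target changed since the previous reset. The same decision as a standalone sketch, with illustrative names:

def needs_recompute(agents, previous_agents):
    # Recompute unless the previous roster exists, has the same size,
    # and every agent still has the same target.
    if previous_agents is None or len(agents) != len(previous_agents):
        return True
    return any(a.target != p.target for a, p in zip(agents, previous_agents))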
@@ -57,12 +57,12 @@ class TreeObsForRailEnv(ObservationBuilder):
     def _compute_distance_map(self):
         agents = self.env.agents
-        nAgents = len(agents)
-        self.distance_map = np.inf * np.ones(shape=(nAgents,  # self.env.number_of_agents,
+        nb_agents = len(agents)
+        self.distance_map = np.inf * np.ones(shape=(nb_agents,
                                                     self.env.height,
                                                     self.env.width,
                                                     4))
-        self.max_dist = np.zeros(nAgents)
+        self.max_dist = np.zeros(nb_agents)
         self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
         # Update local lookup table for all agents' target locations
         self.location_has_target = {tuple(agent.target): 1 for agent in agents}
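np.inf * np.ones(shape=...) first allocates an array of ones and then multiplies it elementwise; np.full builds the same infinity-filled array in one step. A quick equivalence check with illustrative dimensions:

import numpy as np

a = np.inf * np.ones(shape=(2, 5, 5, 4))  # style used in the diff
b = np.full((2, 5, 5, 4), np.inf)         # single-step equivalent
assert np.array_equal(a, b)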
@@ -83,10 +83,8 @@ class TreeObsForRailEnv(ObservationBuilder):
         # BFS from target `position' to all the reachable nodes in the grid
         # Stop the search if the target position is re-visited, in any direction
-        visited = set([(position[0], position[1], 0),
-                       (position[0], position[1], 1),
-                       (position[0], position[1], 2),
-                       (position[0], position[1], 3)])
+        visited = {(position[0], position[1], 0), (position[0], position[1], 1),
+                   (position[0], position[1], 2), (position[0], position[1], 3)}
         max_distance = 0
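The change swaps set([...]) for a set literal, which builds the set directly instead of materializing an intermediate list. Since the four entries differ only in the direction component, they could equally be written as a comprehension; a small check of the equivalence (the position value is illustrative):

position = (3, 5)

by_literal = {(position[0], position[1], 0), (position[0], position[1], 1),
              (position[0], position[1], 2), (position[0], position[1], 3)}
by_comprehension = {(position[0], position[1], d) for d in range(4)}
assert by_literal == by_comprehension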
@@ -133,10 +131,10 @@ class TreeObsForRailEnv(ObservationBuilder):
             # Check all possible transitions in new_cell
             for agent_orientation in range(4):
                 # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
-                isValid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
-                                                       desired_movement_from_new_cell)
+                is_valid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
+                                                        desired_movement_from_new_cell)

-                if isValid:
+                if is_valid:
                     """
                     # TODO: check that it works with deadends! -- still bugged!
                     movement = desired_movement_from_new_cell
@@ -163,12 +161,14 @@ class TreeObsForRailEnv(ObservationBuilder):
         elif movement == Grid4TransitionsEnum.WEST:
             return (position[0], position[1] - 1)

-    def get_many(self, handles=[]):
+    def get_many(self, handles=None):
         """
         Called whenever an observation has to be computed for the `env' environment, for each agent with handle
         in the `handles' list.
         """
+        if handles is None:
+            handles = []
         if self.predictor:
             self.predicted_pos = {}
             self.predicted_dir = {}
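Switching the default from handles=[] to handles=None avoids Python's mutable-default-argument pitfall: a default list is created once, when the function is defined, and shared by every call that omits the argument. A minimal demonstration of the difference:

def buggy(items=[]):
    items.append(1)
    return items

def fixed(items=None):
    if items is None:
        items = []
    items.append(1)
    return items

print(buggy())  # [1]
print(buggy())  # [1, 1] -- the single default list is reused across calls
print(fixed())  # [1]
print(fixed())  # [1]    -- a fresh list on every call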
@@ -259,7 +259,6 @@ class TreeObsForRailEnv(ObservationBuilder):
         # Root node - current position
         observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0]
-        root_observation = observation[:]
         visited = set()

         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
@@ -273,7 +272,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             if possible_transitions[branch_direction]:
                 new_cell = self._new_position(agent.position, branch_direction)
                 branch_observation, branch_visited = \
-                    self._explore_branch(handle, new_cell, branch_direction, root_observation, 1, 1)
+                    self._explore_branch(handle, new_cell, branch_direction, 1, 1)
                 observation = observation + branch_observation
                 visited = visited.union(branch_visited)
             else:
@@ -291,7 +290,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             pow4 *= 4
         return num_observations * self.observation_dim

-    def _explore_branch(self, handle, position, direction, root_observation, tot_dist, depth):
+    def _explore_branch(self, handle, position, direction, tot_dist, depth):
         """
         Utility function to compute tree-based observations.
         We walk along the branch and collect the information documented in the get() function.
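The surrounding code sizes the flattened observation as (number of tree nodes) x observation_dim, where the pow4 loop shown above accumulates powers of four: a complete 4-ary tree of depth max_depth has 4^0 + 4^1 + ... + 4^max_depth nodes. The arithmetic spelled out (a sketch, not the library's exact code):

observation_dim = 9  # per-node feature count, as in the class attribute above

def observation_size(max_depth):
    # Nodes in a complete 4-ary tree: 4**0 + 4**1 + ... + 4**max_depth
    num_observations = sum(4 ** i for i in range(max_depth + 1))
    return num_observations * observation_dim

print(observation_size(2))  # (1 + 4 + 16) * 9 = 189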
@@ -305,10 +304,10 @@ class TreeObsForRailEnv(ObservationBuilder):
         # until no transitions are possible along the current direction (i.e., dead-ends)
         # We treat dead-ends as nodes, instead of going back, to avoid loops
         exploring = True
-        last_isSwitch = False
-        last_isDeadEnd = False
-        last_isTerminal = False  # wrong cell OR cycle; either way, we don't want the agent to land here
-        last_isTarget = False
+        last_is_switch = False
+        last_is_dead_end = False
+        last_is_terminal = False  # wrong cell OR cycle; either way, we don't want the agent to land here
+        last_is_target = False

         visited = set()
         agent = self.env.agents[handle]
@@ -369,21 +368,19 @@ class TreeObsForRailEnv(ObservationBuilder):
                 if tot_dist < other_target_encountered:
                     other_target_encountered = tot_dist

-            if position == agent.target:
-                if tot_dist < own_target_encountered:
-                    own_target_encountered = tot_dist
+            if position == agent.target and tot_dist < own_target_encountered:
+                own_target_encountered = tot_dist

             # #############################
             # #############################
             if (position[0], position[1], direction) in visited:
-                last_isTerminal = True
+                last_is_terminal = True
                 break
             visited.add((position[0], position[1], direction))

             # If the target node is encountered, pick that as node. Also, no further branching is possible.
             if np.array_equal(position, self.env.agents[handle].target):
-                last_isTarget = True
+                last_is_target = True
                 break

             cell_transitions = self.env.rail.get_transitions((*position, direction))
@@ -403,9 +400,9 @@ class TreeObsForRailEnv(ObservationBuilder):
                     tmp = tmp >> 1
                 if nbits == 1:
                     # Dead-end!
-                    last_isDeadEnd = True
+                    last_is_dead_end = True

-                if not last_isDeadEnd:
+                if not last_is_dead_end:
                     # Keep walking through the tree along `direction'
                     exploring = True
                     # convert one-hot encoding to 0,1,2,3
@@ -415,14 +412,14 @@ class TreeObsForRailEnv(ObservationBuilder):
                     tot_dist += 1
             elif num_transitions > 0:
                 # Switch detected
-                last_isSwitch = True
+                last_is_switch = True
                 break

             elif num_transitions == 0:
                 # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
                 print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell",
                       position[0], position[1], direction)
-                last_isTerminal = True
+                last_is_terminal = True
                 break

         # `position' is either a terminal node or a switch
@@ -433,7 +430,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         # #############################
         # Modify here to append new / different features for each visited cell!
-        if last_isTarget:
+        if last_is_target:
             observation = [own_target_encountered,
                            other_target_encountered,
                            other_agent_encountered,
@@ -445,7 +442,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                            other_agent_opposite_direction
                            ]

-        elif last_isTerminal:
+        elif last_is_terminal:
             observation = [own_target_encountered,
                            other_target_encountered,
                            other_agent_encountered,
@@ -476,25 +473,25 @@ class TreeObsForRailEnv(ObservationBuilder):
         # Get the possible transitions
         possible_transitions = self.env.rail.get_transitions((*position, direction))
         for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
-            if last_isDeadEnd and self.env.rail.get_transition((*position, direction),
-                                                               (branch_direction + 2) % 4):
+            if last_is_dead_end and self.env.rail.get_transition((*position, direction),
+                                                                 (branch_direction + 2) % 4):
                 # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                 # it back
                 new_cell = self._new_position(position, (branch_direction + 2) % 4)

                 branch_observation, branch_visited = self._explore_branch(handle, new_cell,
                                                                           (branch_direction + 2) % 4,
-                                                                          new_root_observation, tot_dist + 1,
+                                                                          tot_dist + 1,
                                                                           depth + 1)
                 observation = observation + branch_observation

                 if len(branch_visited) != 0:
                     visited = visited.union(branch_visited)
-            elif last_isSwitch and possible_transitions[branch_direction]:
+            elif last_is_switch and possible_transitions[branch_direction]:
                 new_cell = self._new_position(position, branch_direction)
                 branch_observation, branch_visited = self._explore_branch(handle, new_cell, branch_direction,
-                                                                          new_root_observation, tot_dist + 1,
+                                                                          tot_dist + 1,
                                                                           depth + 1)
                 observation = observation + branch_observation

                 if len(branch_visited) != 0:
flatland/envs/rail_env.py

@@ -109,7 +109,7 @@ class RailEnv(Environment):
         self.obs_builder._set_env(self)

         self.action_space = [1]
-        self.observation_space = self.obs_builder.observation_space  # updated on resets?
+        self.observation_space = self.obs_builder.observation_space

         self.rewards = [0] * number_of_agents
         self.done = False
@@ -195,31 +195,29 @@ class RailEnv(Environment):
         # Reset the step rewards
         self.rewards_dict = dict()
-        for iAgent in range(self.get_num_agents()):
-            self.rewards_dict[iAgent] = 0
+        for i_agent in range(self.get_num_agents()):
+            self.rewards_dict[i_agent] = 0

         if self.dones["__all__"]:
             self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()}
             return self._get_observations(), self.rewards_dict, self.dones, {}

         # for i in range(len(self.agents_handles)):
-        for iAgent in range(self.get_num_agents()):
-            agent = self.agents[iAgent]
+        for i_agent, agent in enumerate(self.get_num_agents()):
             agent.old_direction = agent.direction
             agent.old_position = agent.position
-            if self.dones[iAgent]:  # this agent has already completed...
+            if self.dones[i_agent]:  # this agent has already completed...
                 continue

-            if iAgent not in action_dict:  # no action has been supplied for this agent
-                action_dict[iAgent] = RailEnvActions.DO_NOTHING
+            if i_agent not in action_dict:  # no action has been supplied for this agent
+                action_dict[i_agent] = RailEnvActions.DO_NOTHING

-            if action_dict[iAgent] < 0 or action_dict[iAgent] > len(RailEnvActions):
-                print('ERROR: illegal action=', action_dict[iAgent],
-                      'for agent with index=', iAgent,
+            if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions):
+                print('ERROR: illegal action=', action_dict[i_agent],
+                      'for agent with index=', i_agent,
                       '"DO NOTHING" will be executed instead')
-                action_dict[iAgent] = RailEnvActions.DO_NOTHING
+                action_dict[i_agent] = RailEnvActions.DO_NOTHING

-            action = action_dict[iAgent]
+            action = action_dict[i_agent]

             if action == RailEnvActions.DO_NOTHING and agent.moving:
                 # Keep moving
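The loop refactor replaces manual indexing (agent = self.agents[iAgent]) with unpacking from enumerate. Note that enumerate must be handed the iterable itself, such as the agent list, rather than a count like get_num_agents(); a minimal illustration of the idiom with stand-in data:

agents = ['agent-0', 'agent-1', 'agent-2']  # stand-in for self.agents

# Manual indexing, as in the old code:
for i_agent in range(len(agents)):
    agent = agents[i_agent]

# enumerate yields the same (index, element) pairs without the explicit lookup:
for i_agent, agent in enumerate(agents):
    print(i_agent, agent)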
@@ -228,12 +226,12 @@ class RailEnv(Environment):
             if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.:
                 # Only allow halting an agent on entering new cells.
                 agent.moving = False
-                self.rewards_dict[iAgent] += stop_penalty
+                self.rewards_dict[i_agent] += stop_penalty

             if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
                 # Only allow agent to start moving by pressing forward.
                 agent.moving = True
-                self.rewards_dict[iAgent] += start_penalty
+                self.rewards_dict[i_agent] += start_penalty

             # Now perform a movement.
             # If the agent is in an initial position within a new cell (agent.speed_data['position_fraction']<eps)
@@ -269,16 +267,16 @@ class RailEnv(Environment):
                 else:
                     # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
-                    self.rewards_dict[iAgent] += invalid_action_penalty
+                    self.rewards_dict[i_agent] += invalid_action_penalty
                     agent.moving = False
-                    self.rewards_dict[iAgent] += stop_penalty
+                    self.rewards_dict[i_agent] += stop_penalty
                     continue
             else:
                 # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
-                self.rewards_dict[iAgent] += invalid_action_penalty
+                self.rewards_dict[i_agent] += invalid_action_penalty
                 agent.moving = False
-                self.rewards_dict[iAgent] += stop_penalty
+                self.rewards_dict[i_agent] += stop_penalty
                 continue
@@ -300,9 +298,9 @@ class RailEnv(Environment):
                 agent.speed_data['position_fraction'] = 0.0

             if np.equal(agent.position, agent.target).all():
-                self.dones[iAgent] = True
+                self.dones[i_agent] = True
             else:
-                self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed']
+                self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']

         # Check for end of episode + add global reward to all rewards!
         if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
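Two elementwise-comparison idioms sit side by side here: np.equal(a, b).all() for the per-agent check and np.array_equal(a, b) in the episode-end check. For same-shape arrays they agree; np.array_equal additionally returns False when the shapes differ, where a raw elementwise comparison may fail to broadcast. A quick check:

import numpy as np

position = np.array([2, 3])
target = np.array([2, 3])

assert np.equal(position, target).all()  # elementwise ==, then reduce with all()
assert np.array_equal(position, target)  # shape check plus elementwise ==
assert not np.array_equal(position, np.array([2, 3, 0]))  # shape mismatch -> False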
tests/test_flatland_core_transition_map.py

@@ -3,7 +3,7 @@ from flatland.core.grid.grid8 import Grid8Transitions, Grid8TransitionsEnum
 from flatland.core.transition_map import GridTransitionMap


-def test_grid4_set_transitions():
+def test_grid4_get_transitions():
     grid4_map = GridTransitionMap(2, 2, Grid4Transitions([]))
     assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (0, 0, 0, 0)
     grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 1)
@@ -19,3 +19,7 @@ def test_grid8_set_transitions():
     assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (1, 0, 0, 0, 0, 0, 0, 0)
+    grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 0)
+    assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (0, 0, 0, 0, 0, 0, 0, 0)
+
+# TODO GridTransitionMap
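The new trailing TODO leaves a dedicated GridTransitionMap test unwritten. Following the set/get round-trip pattern of the tests in this file, one possible sketch (the expected tuples assume the same transition encoding the existing asserts use):

def test_grid4_set_and_clear_transition():
    # Sketch only: round-trip a single transition bit, mirroring the tests above.
    grid4_map = GridTransitionMap(2, 2, Grid4Transitions([]))
    grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 1)
    assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (1, 0, 0, 0)
    grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 0)
    assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (0, 0, 0, 0)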