pranjal_dhole / Flatland · Commit d4e6af1c

SIM-119 refactoring ActionPlan

Authored 5 years ago by u214892
Parent: a2061acf
Changes: 3 changed files, with 397 additions and 0 deletions

- flatland/action_plan/__init__.py: 0 additions, 0 deletions
- flatland/action_plan/action_plan.py: 304 additions, 0 deletions
- tests/test_action_plan.py: 93 additions, 0 deletions
flatland/action_plan/__init__.py (new file @ 100644; empty)
flatland/action_plan/action_plan.py (new file @ 100644, 304 additions)
```python
import pprint
from typing import Dict, List, Optional, NamedTuple

import numpy as np

from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_env_shortest_paths import WalkingElement, get_action_for_move
from flatland.utils.rendertools import RenderTool, AgentRenderVariant

# ---- Input Data Structures (graph representation) -----------------------------------------------
# A cell pin represents one of the four pins through which the cell at (row, column) may be entered.
CellPin = NamedTuple('CellPin', [('r', int), ('c', int), ('d', int)])

# A path schedule element represents the entry time of an agent at a cell pin.
PathScheduleElement = NamedTuple('PathScheduleElement', [
    ('scheduled_at', int),
    ('cell_pin', CellPin)
])

# A path schedule is the list of an agent's cell pin entries.
PathSchedule = List[PathScheduleElement]
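
# For illustration, a hypothetical two-entry schedule (the values mirror the start of
# agent 0's path in the accompanying test): enter cell (3, 0) facing direction 3 at
# step 0, then cell (3, 1) facing direction 1 at step 2.
#
#   example_schedule: PathSchedule = [
#       PathScheduleElement(scheduled_at=0, cell_pin=CellPin(r=3, c=0, d=3)),
#       PathScheduleElement(scheduled_at=2, cell_pin=CellPin(r=3, c=1, d=1)),
#   ]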

# ---- Output Data Structures (FLATland representation) -------------------------------------------
# An action plan element represents the action to be taken by an agent at a deterministic
# time step, plus the agent's position before the action.
ActionPlanElement = NamedTuple('ActionPlanElement', [
    ('scheduled_at', int),
    ('walking_element', WalkingElement)
])

# An action plan deterministically represents all the actions to be taken by an agent,
# plus its position before the actions are taken.
ActionPlan = Dict[int, List[ActionPlanElement]]
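
# For illustration, a hypothetical plan fragment (the element is the first entry of
# agent 0's expected plan in the accompanying test): the agent is still off the grid
# (position None) and enters with MOVE_FORWARD at step 0.
#
#   example_plan: ActionPlan = {
#       0: [ActionPlanElement(0, WalkingElement(position=None, direction=3,
#                                               next_action=RailEnvActions.MOVE_FORWARD))],
#   }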

class ActionPlanReplayer:
    """Deduces an `ActionPlan` from the agents' `PathSchedule`s so that it can be
    replayed/verified in a FLATland env without malfunction."""

    pp = pprint.PrettyPrinter(indent=4)

    def __init__(self,
                 env: RailEnv,
                 chosen_path_dict: Dict[int, PathSchedule]):
        self.env = env
        self.action_plan = [[] for _ in range(self.env.get_num_agents())]

        for agent_id, chosen_path in chosen_path_dict.items():
            self._add_agent_to_action_plan(self.action_plan, agent_id, chosen_path)

    def get_walking_element_before_or_at_step(self, agent_id: int, step: int) -> WalkingElement:
        """Get the walking element from which the current position can be extracted.

        Parameters
        ----------
        agent_id
        step

        Returns
        -------
        WalkingElement
        """
        # scan the plan in schedule order and keep the last element scheduled at or before `step`
        walking_element = None
        for action in self.action_plan[agent_id]:
            if step < action.scheduled_at:
                return walking_element
            if step >= action.scheduled_at:
                walking_element = action.walking_element
        assert walking_element is not None
        return walking_element
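
    # For example (a hypothetical query): if an agent's plan has elements scheduled at
    # steps 0 and 2, get_walking_element_before_or_at_step(agent_id, 1) returns the
    # element scheduled at step 0, because the step-2 element is not yet due.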
    def get_action_at_step(self, agent_id: int, current_step: int) -> Optional[RailEnvActions]:
        """Get the current action if any is defined in the `ActionPlan`.

        Parameters
        ----------
        agent_id
        current_step

        Returns
        -------
        RailEnvActions, optional
        """
        for action_plan_step in self.action_plan[agent_id]:
            action_plan_step: ActionPlanElement = action_plan_step
            scheduled_at = action_plan_step.scheduled_at
            walking_element: WalkingElement = action_plan_step.walking_element
            if scheduled_at > current_step:
                return None
            elif np.isclose(current_step, scheduled_at):
                return walking_element.next_action
        return None

    def get_action_dict_for_step_replay(self, current_step: int) -> Dict[int, RailEnvActions]:
        """Get the action dictionary to be replayed at the current step.

        Parameters
        ----------
        current_step: int

        Returns
        -------
        Dict[int, RailEnvActions]
        """
        action_dict = {}
        for agent_id, agent in enumerate(self.env.agents):
            action: Optional[RailEnvActions] = self.get_action_at_step(agent_id, current_step)
            if action is not None:
                action_dict[agent_id] = action
        return action_dict

    def replay_verify(self, MAX_EPISODE_STEPS: int, env: RailEnv, rendering: bool):
        """Replays this deterministic `ActionPlan` and verifies whether it is feasible."""
        if rendering:
            renderer = RenderTool(env, gl="PILSVG",
                                  agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX,
                                  show_debug=True,
                                  clear_debug_text=True,
                                  screen_height=1000,
                                  screen_width=1000)
            renderer.render_env(show=True, show_observations=False, show_predictions=False)
        i = 0
        while not env.dones['__all__'] and i <= MAX_EPISODE_STEPS:
            # verify that each agent is where the action plan expects it to be
            for agent_id, agent in enumerate(env.agents):
                walking_element: WalkingElement = self.get_walking_element_before_or_at_step(agent_id, i)
                assert agent.position == walking_element.position, \
                    "before {}, agent {} at {}, expected {}".format(i, agent_id, agent.position,
                                                                    walking_element.position)
            actions = self.get_action_dict_for_step_replay(i)
            print("actions for {}: {}".format(i, actions))

            obs, all_rewards, done, _ = env.step(actions)
            if rendering:
                renderer.render_env(show=True, show_observations=False, show_predictions=False)
            i += 1

    def print_action_plan(self):
        for agent_id, plan in enumerate(self.action_plan):
            print("{}:".format(agent_id))
            for step in plan:
                print("  {}".format(step))

    @staticmethod
    def compare_action_plans(expected_action_plan: ActionPlan, actual_action_plan: ActionPlan):
        assert len(expected_action_plan) == len(actual_action_plan)
        for k in range(len(expected_action_plan)):
            assert len(expected_action_plan[k]) == len(actual_action_plan[k]), \
                "len for agent {} should be the same.\n\n  expected ({}) = {}\n\n  actual ({}) = {}".format(
                    k,
                    len(expected_action_plan[k]),
                    ActionPlanReplayer.pp.pformat(expected_action_plan[k]),
                    len(actual_action_plan[k]),
                    ActionPlanReplayer.pp.pformat(actual_action_plan[k]))
            for i in range(len(expected_action_plan[k])):
                assert expected_action_plan[k][i] == actual_action_plan[k][i], \
                    "not the same at agent {} at step {}\n\n  expected = {}\n\n  actual = {}".format(
                        k, i,
                        ActionPlanReplayer.pp.pformat(expected_action_plan[k][i]),
                        ActionPlanReplayer.pp.pformat(actual_action_plan[k][i]))

    def _add_agent_to_action_plan(self, action_plan, agent_id, agent_path_new):
        agent = self.env.agents[agent_id]
        minimum_cell_time = int(np.ceil(1.0 / agent.speed_data['speed']))
        for path_loop, path_schedule_element in enumerate(agent_path_new):
            path_schedule_element: PathScheduleElement = path_schedule_element

            position = (path_schedule_element.cell_pin.r, path_schedule_element.cell_pin.c)

            if Vec2d.is_equal(agent.target, position):
                break

            next_path_schedule_element: PathScheduleElement = agent_path_new[path_loop + 1]
            next_position = (next_path_schedule_element.cell_pin.r, next_path_schedule_element.cell_pin.c)

            if path_loop == 0:
                self._create_action_plan_for_first_path_element_of_agent(
                    action_plan, agent_id, path_schedule_element, next_path_schedule_element)
                continue

            just_before_target = Vec2d.is_equal(agent.target, next_position)
            self._create_action_plan_for_current_path_element(
                action_plan, agent_id, minimum_cell_time, path_schedule_element,
                next_path_schedule_element)

            # add a final element
            if just_before_target:
                self._create_action_plan_for_target_at_path_element_just_before_target(
                    action_plan, agent_id, minimum_cell_time, path_schedule_element,
                    next_path_schedule_element)

    def _create_action_plan_for_current_path_element(self,
                                                     action_plan: ActionPlan,
                                                     agent_id: int,
                                                     minimum_cell_time: int,
                                                     path_schedule_element: PathScheduleElement,
                                                     next_path_schedule_element: PathScheduleElement):
        scheduled_at = path_schedule_element.scheduled_at
        next_entry_value = next_path_schedule_element.scheduled_at

        position = (path_schedule_element.cell_pin.r, path_schedule_element.cell_pin.c)
        direction = path_schedule_element.cell_pin.d
        next_position = next_path_schedule_element.cell_pin.r, next_path_schedule_element.cell_pin.c
        next_direction = next_path_schedule_element.cell_pin.d
        next_action = get_action_for_move(position, direction, next_position, next_direction, self.env.rail)

        walking_element = WalkingElement(position, direction, next_action)

        # If the next entry is more than minimum_cell_time after this one, stop here and
        # start moving again minimum_cell_time steps before the next entry.
        # We have to do this since agents in the RailEnv are processed in step() in the
        # order of their handles.
        if next_entry_value > scheduled_at + minimum_cell_time:
            action = ActionPlanElement(scheduled_at,
                                       WalkingElement(position=position,
                                                      direction=direction,
                                                      next_action=RailEnvActions.STOP_MOVING))
            action_plan[agent_id].append(action)
            action = ActionPlanElement(next_entry_value - minimum_cell_time, walking_element)
            action_plan[agent_id].append(action)
        else:
            action = ActionPlanElement(scheduled_at, walking_element)
            action_plan[agent_id].append(action)
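
    # A worked example of the stop-and-go branch above, taken from agent 0 in the
    # accompanying test: the agent enters (3, 2) at step 3, the next entry is scheduled
    # at step 14, and minimum_cell_time is 1. Since 14 > 3 + 1, the plan gets
    # STOP_MOVING at step 3 and the move action at step 14 - 1 = 13.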
    def _create_action_plan_for_target_at_path_element_just_before_target(self,
                                                                          action_plan: ActionPlan,
                                                                          agent_id: int,
                                                                          minimum_cell_time: int,
                                                                          path_schedule_element: PathScheduleElement,
                                                                          next_path_schedule_element: PathScheduleElement):
        scheduled_at = path_schedule_element.scheduled_at

        action = ActionPlanElement(scheduled_at + minimum_cell_time,
                                   WalkingElement(
                                       position=None,
                                       direction=next_path_schedule_element.cell_pin.d,
                                       next_action=RailEnvActions.STOP_MOVING))
        action_plan[agent_id].append(action)

    def _create_action_plan_for_first_path_element_of_agent(self,
                                                            action_plan: ActionPlan,
                                                            agent_id: int,
                                                            path_schedule_element: PathScheduleElement,
                                                            next_path_schedule_element: PathScheduleElement):
        scheduled_at = path_schedule_element.scheduled_at
        position = (path_schedule_element.cell_pin.r, path_schedule_element.cell_pin.c)
        direction = path_schedule_element.cell_pin.d
        next_position = next_path_schedule_element.cell_pin.r, next_path_schedule_element.cell_pin.c
        next_direction = next_path_schedule_element.cell_pin.d

        # add an initial DO_NOTHING if we do not enter immediately
        if scheduled_at > 0:
            action = ActionPlanElement(0,
                                       WalkingElement(
                                           position=None,
                                           direction=direction,
                                           next_action=RailEnvActions.DO_NOTHING))
            action_plan[agent_id].append(action)
        # add the action to enter the grid
        action = ActionPlanElement(scheduled_at,
                                   WalkingElement(
                                       position=None,
                                       direction=direction,
                                       next_action=RailEnvActions.MOVE_FORWARD))
        action_plan[agent_id].append(action)

        next_action = get_action_for_move(position, direction, next_position, next_direction, self.env.rail)

        # now that we have a position, we need to perform the action
        action = ActionPlanElement(scheduled_at + 1,
                                   WalkingElement(
                                       position=position,
                                       direction=direction,
                                       next_action=next_action))
        action_plan[agent_id].append(action)
```
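
A minimal usage sketch (assuming an already constructed `RailEnv` named `env` and a `chosen_path_dict` like the one built in the test below; both names are placeholders):

```python
# Build the action plan from the chosen paths, print it, and replay it headless.
replayer = ActionPlanReplayer(env, chosen_path_dict)
replayer.print_action_plan()
replayer.replay_verify(MAX_EPISODE_STEPS=50, env=env, rendering=False)
```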
tests/test_action_plan.py (new file @ 100644, 93 additions)
```python
from flatland.action_plan.action_plan import PathScheduleElement, CellPin, ActionPlanReplayer
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_env_shortest_paths import WalkingElement
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail


def test_action_plan(rendering: bool = False):
    """Tests `ActionPlanReplayer`: do action plan generation and replay work as expected?"""
    rail, rail_map = make_simple_rail()
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  schedule_generator=random_schedule_generator(seed=77),
                  number_of_agents=2,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2,
                                                       predictor=ShortestPathPredictorForRailEnv()),
                  remove_agents_at_target=True)
    env.reset()
    env.agents[0].initial_position = (3, 0)
    env.agents[0].target = (3, 8)
    env.agents[0].initial_direction = Grid4TransitionsEnum.WEST
    env.agents[1].initial_position = (3, 8)
    env.agents[1].initial_direction = Grid4TransitionsEnum.WEST
    env.agents[1].target = (0, 3)
    env.agents[1].speed_data['speed'] = 0.5  # two time steps per cell
    env.reset(False, False, False)
    for handle, agent in enumerate(env.agents):
        print("[{}] {} -> {}".format(handle, agent.initial_position, agent.target))

    chosen_path_dict = {
        0: [PathScheduleElement(scheduled_at=0, cell_pin=CellPin(r=3, c=0, d=3)),
            PathScheduleElement(scheduled_at=2, cell_pin=CellPin(r=3, c=1, d=1)),
            PathScheduleElement(scheduled_at=3, cell_pin=CellPin(r=3, c=2, d=1)),
            PathScheduleElement(scheduled_at=14, cell_pin=CellPin(r=3, c=3, d=1)),
            PathScheduleElement(scheduled_at=15, cell_pin=CellPin(r=3, c=4, d=1)),
            PathScheduleElement(scheduled_at=16, cell_pin=CellPin(r=3, c=5, d=1)),
            PathScheduleElement(scheduled_at=17, cell_pin=CellPin(r=3, c=6, d=1)),
            PathScheduleElement(scheduled_at=18, cell_pin=CellPin(r=3, c=7, d=1)),
            PathScheduleElement(scheduled_at=19, cell_pin=CellPin(r=3, c=8, d=1)),
            PathScheduleElement(scheduled_at=20, cell_pin=CellPin(r=3, c=8, d=5))],
        1: [PathScheduleElement(scheduled_at=0, cell_pin=CellPin(r=3, c=8, d=3)),
            PathScheduleElement(scheduled_at=3, cell_pin=CellPin(r=3, c=7, d=3)),
            PathScheduleElement(scheduled_at=5, cell_pin=CellPin(r=3, c=6, d=3)),
            PathScheduleElement(scheduled_at=7, cell_pin=CellPin(r=3, c=5, d=3)),
            PathScheduleElement(scheduled_at=9, cell_pin=CellPin(r=3, c=4, d=3)),
            PathScheduleElement(scheduled_at=11, cell_pin=CellPin(r=3, c=3, d=3)),
            PathScheduleElement(scheduled_at=13, cell_pin=CellPin(r=2, c=3, d=0)),
            PathScheduleElement(scheduled_at=15, cell_pin=CellPin(r=1, c=3, d=0)),
            PathScheduleElement(scheduled_at=17, cell_pin=CellPin(r=0, c=3, d=0)),
            PathScheduleElement(scheduled_at=18, cell_pin=CellPin(r=0, c=3, d=5))]}

    expected_action_plan = [[
        # take action to enter the grid
        (0, WalkingElement(position=None, direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
        # take action to enter the cell properly
        (1, WalkingElement(position=(3, 0), direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
        (2, WalkingElement(position=(3, 1), direction=1, next_action=RailEnvActions.MOVE_FORWARD)),
        (3, WalkingElement(position=(3, 2), direction=1, next_action=RailEnvActions.STOP_MOVING)),
        (13, WalkingElement(position=(3, 2), direction=1, next_action=RailEnvActions.MOVE_FORWARD)),
        (14, WalkingElement(position=(3, 3), direction=1, next_action=RailEnvActions.MOVE_FORWARD)),
        (15, WalkingElement(position=(3, 4), direction=1, next_action=RailEnvActions.MOVE_FORWARD)),
        (16, WalkingElement(position=(3, 5), direction=1, next_action=RailEnvActions.MOVE_FORWARD)),
        (17, WalkingElement(position=(3, 6), direction=1, next_action=RailEnvActions.MOVE_FORWARD)),
        (18, WalkingElement(position=(3, 7), direction=1, next_action=RailEnvActions.MOVE_FORWARD)),
        (19, WalkingElement(position=None, direction=1, next_action=RailEnvActions.STOP_MOVING))
    ], [
        (0, WalkingElement(position=None, direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
        (1, WalkingElement(position=(3, 8), direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
        (3, WalkingElement(position=(3, 7), direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
        (5, WalkingElement(position=(3, 6), direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
        (7, WalkingElement(position=(3, 5), direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
        (9, WalkingElement(position=(3, 4), direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
        (11, WalkingElement(position=(3, 3), direction=3, next_action=RailEnvActions.MOVE_RIGHT)),
        (13, WalkingElement(position=(2, 3), direction=0, next_action=RailEnvActions.MOVE_FORWARD)),
        (15, WalkingElement(position=(1, 3), direction=0, next_action=RailEnvActions.MOVE_FORWARD)),
        (17, WalkingElement(position=None, direction=0, next_action=RailEnvActions.STOP_MOVING)),
    ]]

    MAX_EPISODE_STEPS = 50

    actual_action_plan = ActionPlanReplayer(env, chosen_path_dict)
    actual_action_plan.print_action_plan()
    ActionPlanReplayer.compare_action_plans(expected_action_plan, actual_action_plan.action_plan)
    assert actual_action_plan.action_plan == expected_action_plan, \
        "expected {}, found {}".format(expected_action_plan, actual_action_plan.action_plan)

    actual_action_plan.replay_verify(MAX_EPISODE_STEPS, env, rendering)
```
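
The test builds the two path schedules by hand, checks the generated plan against `expected_action_plan` element by element, and then replays it step by step in the environment. It runs headless by default; calling `test_action_plan(rendering=True)` directly would presumably open the PILSVG renderer, while under the repository's test setup it can be run as an ordinary pytest test.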