Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
Flatland
Commits
5bf451eb
Commit
5bf451eb
authored
Jun 19, 2019
by
spiglerg
Browse files
prevent stopping in the middle of a cell
parent
65397f68
Changes
2
Hide whitespace changes
Inline
Side-by-side
flatland/envs/agent_utils.py
View file @
5bf451eb
...
...
@@ -28,19 +28,32 @@ class EnvAgentStatic(object):
position
=
attrib
()
direction
=
attrib
()
target
=
attrib
()
moving
=
attrib
()
def
__init__
(
self
,
position
,
direction
,
target
,
moving
=
False
):
moving
=
attrib
(
default
=
False
)
# speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
# after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
# cell if speed=1, as default)
speed_data
=
attrib
(
default
=
dict
({
'position_fraction'
:
0.0
,
'speed'
:
1.0
,
'transition_action_on_cellexit'
:
0
}))
def
__init__
(
self
,
position
,
direction
,
target
,
moving
=
False
,
speed_data
=
{
'position_fraction'
:
0.0
,
'speed'
:
1.0
,
'transition_action_on_cellexit'
:
0
}):
self
.
position
=
position
self
.
direction
=
direction
self
.
target
=
target
self
.
moving
=
moving
self
.
speed_data
=
speed_data
@
classmethod
def
from_lists
(
cls
,
positions
,
directions
,
targets
):
""" Create a list of EnvAgentStatics from lists of positions, directions and targets
"""
return
list
(
starmap
(
EnvAgentStatic
,
zip
(
positions
,
directions
,
targets
,
[
False
]
*
len
(
positions
))))
speed_datas
=
[]
for
i
in
range
(
len
(
positions
)):
speed_datas
.
append
({
'position_fraction'
:
0.0
,
'speed'
:
1.0
,
'transition_action_on_cellexit'
:
0
})
return
list
(
starmap
(
EnvAgentStatic
,
zip
(
positions
,
directions
,
targets
,
[
False
]
*
len
(
positions
),
speed_datas
)))
def
to_list
(
self
):
...
...
@@ -54,7 +67,7 @@ class EnvAgentStatic(object):
if
type
(
lTarget
)
is
np
.
ndarray
:
lTarget
=
lTarget
.
tolist
()
return
[
lPos
,
int
(
self
.
direction
),
lTarget
,
int
(
self
.
moving
)]
return
[
lPos
,
int
(
self
.
direction
),
lTarget
,
int
(
self
.
moving
)
,
self
.
speed_data
]
@
attrs
...
...
@@ -78,7 +91,7 @@ class EnvAgent(EnvAgentStatic):
def
to_list
(
self
):
return
[
self
.
position
,
self
.
direction
,
self
.
target
,
self
.
handle
,
self
.
old_direction
,
self
.
old_position
,
self
.
moving
]
self
.
old_direction
,
self
.
old_position
,
self
.
moving
,
self
.
speed_data
]
@
classmethod
def
from_static
(
cls
,
oStatic
):
...
...
flatland/envs/rail_env.py
View file @
5bf451eb
...
...
@@ -73,7 +73,7 @@ class RailEnv(Environment):
random_rail_generator : generate a random rail of given size
rail_from_GridTransitionMap_generator(rail_map) : generate a rail from
a GridTransitionMap object
rail_from_manual_specifications_generator(rail_spec) : generate a rail from
rail_from_manual_sp
ecifications_generator(rail_spec) : generate a rail from
a rail specifications array
TODO: generate_rail_from_saved_list or from list of ndarray bitmaps ---
width : int
...
...
@@ -101,7 +101,6 @@ class RailEnv(Environment):
self
.
action_space
=
[
1
]
self
.
observation_space
=
self
.
obs_builder
.
observation_space
# updated on resets?
self
.
actions
=
[
0
]
*
number_of_agents
self
.
rewards
=
[
0
]
*
number_of_agents
self
.
done
=
False
...
...
@@ -192,29 +191,33 @@ class RailEnv(Environment):
# for i in range(len(self.agents_handles)):
for
iAgent
in
range
(
self
.
get_num_agents
()):
agent
=
self
.
agents
[
iAgent
]
if
iAgent
not
in
action_dict
:
# no action has been supplied for this agent
if
agent
.
moving
:
# Keep moving
# Change MOVE_FORWARD to DO_NOTHING
action_dict
[
iAgent
]
=
RailEnvActions
.
DO_NOTHING
else
:
action_dict
[
iAgent
]
=
RailEnvActions
.
DO_NOTHING
agent
.
speed_data
[
'speed'
]
=
0.5
if
self
.
dones
[
iAgent
]:
# this agent has already completed...
continue
action
=
action_dict
[
iAgent
]
if
action
<
0
or
action
>
len
(
RailEnvActions
):
print
(
'ERROR: illegal action='
,
action
,
'for agent with index='
,
iAgent
)
return
if
np
.
equal
(
agent
.
position
,
agent
.
target
).
all
():
self
.
dones
[
iAgent
]
=
True
else
:
self
.
rewards_dict
[
iAgent
]
+=
step_penalty
*
agent
.
speed_data
[
'speed'
]
if
iAgent
not
in
action_dict
:
# no action has been supplied for this agent
action_dict
[
iAgent
]
=
RailEnvActions
.
DO_NOTHING
if
action_dict
[
iAgent
]
<
0
or
action_dict
[
iAgent
]
>
len
(
RailEnvActions
):
print
(
'ERROR: illegal action='
,
action_dict
[
iAgent
],
'for agent with index='
,
iAgent
,
'"DO NOTHING" will be executed instead'
)
action_dict
[
iAgent
]
=
RailEnvActions
.
DO_NOTHING
action
=
action_dict
[
iAgent
]
if
action
==
RailEnvActions
.
DO_NOTHING
and
agent
.
moving
:
# Keep moving
action
=
RailEnvActions
.
MOVE_FORWARD
if
action
==
RailEnvActions
.
STOP_MOVING
and
agent
.
moving
:
if
action
==
RailEnvActions
.
STOP_MOVING
and
agent
.
moving
and
agent
.
speed_data
[
'position_fraction'
]
<
0.01
:
# Only allow halting an agent on entering new cells.
agent
.
moving
=
False
self
.
rewards_dict
[
iAgent
]
+=
stop_penalty
...
...
@@ -223,47 +226,73 @@ class RailEnv(Environment):
agent
.
moving
=
True
self
.
rewards_dict
[
iAgent
]
+=
start_penalty
if
action
!=
RailEnvActions
.
DO_NOTHING
and
action
!=
RailEnvActions
.
STOP_MOVING
:
cell_isFree
,
new_cell_isValid
,
new_direction
,
new_position
,
transition_isValid
=
\
self
.
_check_action_on_agent
(
action
,
agent
)
if
all
([
new_cell_isValid
,
transition_isValid
,
cell_isFree
]):
agent
.
old_direction
=
agent
.
direction
agent
.
old_position
=
agent
.
position
agent
.
position
=
new_position
agent
.
direction
=
new_direction
else
:
# Logic: if the chosen action is invalid,
# and it was LEFT or RIGHT, and the agent was moving, then keep moving FORWARD.
if
(
action
==
RailEnvActions
.
MOVE_LEFT
or
action
==
RailEnvActions
.
MOVE_RIGHT
)
and
agent
.
moving
:
cell_isFree
,
new_cell_isValid
,
new_direction
,
new_position
,
transition_isValid
=
\
self
.
_check_action_on_agent
(
RailEnvActions
.
MOVE_FORWARD
,
agent
)
if
all
([
new_cell_isValid
,
transition_isValid
,
cell_isFree
]):
agent
.
old_direction
=
agent
.
direction
agent
.
old_position
=
agent
.
position
agent
.
position
=
new_position
agent
.
direction
=
new_direction
# Now perform a movement.
# If the agent is in an initial position within a new cell (agent.speed_data['position_fraction']<eps)
# store the desired action in `transition_action_on_cellexit' (only if the desired transition is
# allowed! otherwise DO_NOTHING!)
# Then in any case (if agent.moving) and the `transition_action_on_cellexit' is valid, increment the
# position_fraction by the speed of the agent (regardless of action taken, as long as no
# STOP_MOVING, but that makes agent.moving=False)
# If the new position fraction is >= 1, reset to 0, and perform the stored
# transition_action_on_cellexit
# If the agent can make an action
action_selected
=
False
if
agent
.
speed_data
[
'position_fraction'
]
<
0.01
:
if
action
!=
RailEnvActions
.
DO_NOTHING
and
action
!=
RailEnvActions
.
STOP_MOVING
:
cell_isFree
,
new_cell_isValid
,
new_direction
,
new_position
,
transition_isValid
=
\
self
.
_check_action_on_agent
(
action
,
agent
)
if
all
([
new_cell_isValid
,
transition_isValid
,
cell_isFree
]):
agent
.
speed_data
[
'transition_action_on_cellexit'
]
=
action
action_selected
=
True
else
:
# But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
# try to keep moving forward!
if
(
action
==
RailEnvActions
.
MOVE_LEFT
or
action
==
RailEnvActions
.
MOVE_RIGHT
)
and
agent
.
moving
:
cell_isFree
,
new_cell_isValid
,
new_direction
,
new_position
,
transition_isValid
=
\
self
.
_check_action_on_agent
(
RailEnvActions
.
MOVE_FORWARD
,
agent
)
if
all
([
new_cell_isValid
,
transition_isValid
,
cell_isFree
]):
agent
.
speed_data
[
'transition_action_on_cellexit'
]
=
RailEnvActions
.
MOVE_FORWARD
action_selected
=
True
else
:
# TODO: an invalid action was chosen after entering the cell. The agent cannot move.
self
.
rewards_dict
[
iAgent
]
+=
invalid_action_penalty
agent
.
moving
=
False
self
.
rewards_dict
[
iAgent
]
+=
stop_penalty
continue
else
:
#
the action was not valid, add penalty
#
TODO: an invalid action was chosen after entering the cell. The agent cannot move.
self
.
rewards_dict
[
iAgent
]
+=
invalid_action_penalty
agent
.
moving
=
False
self
.
rewards_dict
[
iAgent
]
+=
stop_penalty
continue
else
:
# the action was not valid, add penalty
self
.
rewards_dict
[
iAgent
]
+=
invalid_action_penalty
if
agent
.
moving
and
(
action_selected
or
agent
.
speed_data
[
'position_fraction'
]
>=
0.01
):
agent
.
speed_data
[
'position_fraction'
]
+=
agent
.
speed_data
[
'speed'
]
if
np
.
equal
(
agent
.
position
,
agent
.
target
).
all
():
self
.
dones
[
iAgent
]
=
True
else
:
self
.
rewards_dict
[
iAgent
]
+=
step_penalty
if
agent
.
speed_data
[
'position_fraction'
]
>=
1.0
:
agent
.
speed_data
[
'position_fraction'
]
=
0.0
# Perform stored action to transition to the next cell
# Now 'transition_action_on_cellexit' will be guaranteed to be valid; it was checked on entering
# the cell
cell_isFree
,
new_cell_isValid
,
new_direction
,
new_position
,
transition_isValid
=
\
self
.
_check_action_on_agent
(
agent
.
speed_data
[
'transition_action_on_cellexit'
],
agent
)
agent
.
old_direction
=
agent
.
direction
agent
.
old_position
=
agent
.
position
agent
.
position
=
new_position
agent
.
direction
=
new_direction
# Check for end of episode + add global reward to all rewards!
if
np
.
all
([
np
.
array_equal
(
agent2
.
position
,
agent2
.
target
)
for
agent2
in
self
.
agents
]):
self
.
dones
[
"__all__"
]
=
True
self
.
rewards_dict
=
[
0
*
r
+
global_reward
for
r
in
self
.
rewards_dict
]
# Reset the step actions (in case some agent doesn't 'register_action'
# on the next step)
self
.
actions
=
[
0
]
*
self
.
get_num_agents
()
return
self
.
_get_observations
(),
self
.
rewards_dict
,
self
.
dones
,
{}
def
_check_action_on_agent
(
self
,
action
,
agent
):
...
...
@@ -271,6 +300,7 @@ class RailEnv(Environment):
# cell used to check for invalid actions
new_direction
,
transition_isValid
=
self
.
check_action
(
agent
,
action
)
new_position
=
get_new_position
(
agent
.
position
,
new_direction
)
# Is it a legal move?
# 1) transition allows the new_direction in the cell,
# 2) the new cell is not empty (case 0),
...
...
@@ -281,11 +311,13 @@ class RailEnv(Environment):
np
.
clip
(
new_position
,
[
0
,
0
],
[
self
.
height
-
1
,
self
.
width
-
1
]))
and
# check the new position has some transitions (ie is not an empty cell)
self
.
rail
.
get_transitions
(
new_position
)
>
0
)
# If transition validity hasn't been checked yet.
if
transition_isValid
is
None
:
transition_isValid
=
self
.
rail
.
get_transition
(
(
*
agent
.
position
,
agent
.
direction
),
new_direction
)
# Check the new position is not the same as any of the existing agent positions
# (including itself, for simplicity, since it is moving)
cell_isFree
=
not
np
.
any
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment