Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Flatland
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
yoogottamk
Flatland
Commits
5bf451eb
Commit
5bf451eb
authored
5 years ago
by
spiglerg
Browse files
Options
Downloads
Patches
Plain Diff
prevent stopping in the middle of a cell
parent
65397f68
No related branches found
No related tags found
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
flatland/envs/agent_utils.py
+19
-6
19 additions, 6 deletions
flatland/envs/agent_utils.py
flatland/envs/rail_env.py
+79
-47
79 additions, 47 deletions
flatland/envs/rail_env.py
with
98 additions
and
53 deletions
flatland/envs/agent_utils.py
+
19
−
6
View file @
5bf451eb
...
@@ -28,19 +28,32 @@ class EnvAgentStatic(object):
...
@@ -28,19 +28,32 @@ class EnvAgentStatic(object):
position
=
attrib
()
position
=
attrib
()
direction
=
attrib
()
direction
=
attrib
()
target
=
attrib
()
target
=
attrib
()
moving
=
attrib
()
moving
=
attrib
(
default
=
False
)
# speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
def
__init__
(
self
,
position
,
direction
,
target
,
moving
=
False
):
# after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
# cell if speed=1, as default)
speed_data
=
attrib
(
default
=
dict
({
'
position_fraction
'
:
0.0
,
'
speed
'
:
1.0
,
'
transition_action_on_cellexit
'
:
0
}))
def
__init__
(
self
,
position
,
direction
,
target
,
moving
=
False
,
speed_data
=
{
'
position_fraction
'
:
0.0
,
'
speed
'
:
1.0
,
'
transition_action_on_cellexit
'
:
0
}):
self
.
position
=
position
self
.
position
=
position
self
.
direction
=
direction
self
.
direction
=
direction
self
.
target
=
target
self
.
target
=
target
self
.
moving
=
moving
self
.
moving
=
moving
self
.
speed_data
=
speed_data
@classmethod
@classmethod
def
from_lists
(
cls
,
positions
,
directions
,
targets
):
def
from_lists
(
cls
,
positions
,
directions
,
targets
):
"""
Create a list of EnvAgentStatics from lists of positions, directions and targets
"""
Create a list of EnvAgentStatics from lists of positions, directions and targets
"""
"""
return
list
(
starmap
(
EnvAgentStatic
,
zip
(
positions
,
directions
,
targets
,
[
False
]
*
len
(
positions
))))
speed_datas
=
[]
for
i
in
range
(
len
(
positions
)):
speed_datas
.
append
({
'
position_fraction
'
:
0.0
,
'
speed
'
:
1.0
,
'
transition_action_on_cellexit
'
:
0
})
return
list
(
starmap
(
EnvAgentStatic
,
zip
(
positions
,
directions
,
targets
,
[
False
]
*
len
(
positions
),
speed_datas
)))
def
to_list
(
self
):
def
to_list
(
self
):
...
@@ -54,7 +67,7 @@ class EnvAgentStatic(object):
...
@@ -54,7 +67,7 @@ class EnvAgentStatic(object):
if
type
(
lTarget
)
is
np
.
ndarray
:
if
type
(
lTarget
)
is
np
.
ndarray
:
lTarget
=
lTarget
.
tolist
()
lTarget
=
lTarget
.
tolist
()
return
[
lPos
,
int
(
self
.
direction
),
lTarget
,
int
(
self
.
moving
)]
return
[
lPos
,
int
(
self
.
direction
),
lTarget
,
int
(
self
.
moving
)
,
self
.
speed_data
]
@attrs
@attrs
...
@@ -78,7 +91,7 @@ class EnvAgent(EnvAgentStatic):
...
@@ -78,7 +91,7 @@ class EnvAgent(EnvAgentStatic):
def
to_list
(
self
):
def
to_list
(
self
):
return
[
return
[
self
.
position
,
self
.
direction
,
self
.
target
,
self
.
handle
,
self
.
position
,
self
.
direction
,
self
.
target
,
self
.
handle
,
self
.
old_direction
,
self
.
old_position
,
self
.
moving
]
self
.
old_direction
,
self
.
old_position
,
self
.
moving
,
self
.
speed_data
]
@classmethod
@classmethod
def
from_static
(
cls
,
oStatic
):
def
from_static
(
cls
,
oStatic
):
...
...
This diff is collapsed.
Click to expand it.
flatland/envs/rail_env.py
+
79
−
47
View file @
5bf451eb
...
@@ -73,7 +73,7 @@ class RailEnv(Environment):
...
@@ -73,7 +73,7 @@ class RailEnv(Environment):
random_rail_generator : generate a random rail of given size
random_rail_generator : generate a random rail of given size
rail_from_GridTransitionMap_generator(rail_map) : generate a rail from
rail_from_GridTransitionMap_generator(rail_map) : generate a rail from
a GridTransitionMap object
a GridTransitionMap object
rail_from_manual_specifications_generator(rail_spec) : generate a rail from
rail_from_manual_specifications_generator(rail_spec) : generate a rail from
a rail specifications array
a rail specifications array
TODO: generate_rail_from_saved_list or from list of ndarray bitmaps ---
TODO: generate_rail_from_saved_list or from list of ndarray bitmaps ---
width : int
width : int
...
@@ -101,7 +101,6 @@ class RailEnv(Environment):
...
@@ -101,7 +101,6 @@ class RailEnv(Environment):
self
.
action_space
=
[
1
]
self
.
action_space
=
[
1
]
self
.
observation_space
=
self
.
obs_builder
.
observation_space
# updated on resets?
self
.
observation_space
=
self
.
obs_builder
.
observation_space
# updated on resets?
self
.
actions
=
[
0
]
*
number_of_agents
self
.
rewards
=
[
0
]
*
number_of_agents
self
.
rewards
=
[
0
]
*
number_of_agents
self
.
done
=
False
self
.
done
=
False
...
@@ -192,29 +191,33 @@ class RailEnv(Environment):
...
@@ -192,29 +191,33 @@ class RailEnv(Environment):
# for i in range(len(self.agents_handles)):
# for i in range(len(self.agents_handles)):
for
iAgent
in
range
(
self
.
get_num_agents
()):
for
iAgent
in
range
(
self
.
get_num_agents
()):
agent
=
self
.
agents
[
iAgent
]
agent
=
self
.
agents
[
iAgent
]
agent
.
speed_data
[
'
speed
'
]
=
0.5
if
iAgent
not
in
action_dict
:
# no action has been supplied for this agent
if
agent
.
moving
:
# Keep moving
# Change MOVE_FORWARD to DO_NOTHING
action_dict
[
iAgent
]
=
RailEnvActions
.
DO_NOTHING
else
:
action_dict
[
iAgent
]
=
RailEnvActions
.
DO_NOTHING
if
self
.
dones
[
iAgent
]:
# this agent has already completed...
if
self
.
dones
[
iAgent
]:
# this agent has already completed...
continue
continue
action
=
action_dict
[
iAgent
]
if
action
<
0
or
action
>
len
(
RailEnvActions
):
if
np
.
equal
(
agent
.
position
,
agent
.
target
).
all
():
print
(
'
ERROR: illegal action=
'
,
action
,
self
.
dones
[
iAgent
]
=
True
'
for agent with index=
'
,
iAgent
)
else
:
return
self
.
rewards_dict
[
iAgent
]
+=
step_penalty
*
agent
.
speed_data
[
'
speed
'
]
if
iAgent
not
in
action_dict
:
# no action has been supplied for this agent
action_dict
[
iAgent
]
=
RailEnvActions
.
DO_NOTHING
if
action_dict
[
iAgent
]
<
0
or
action_dict
[
iAgent
]
>
len
(
RailEnvActions
):
print
(
'
ERROR: illegal action=
'
,
action_dict
[
iAgent
],
'
for agent with index=
'
,
iAgent
,
'"
DO NOTHING
"
will be executed instead
'
)
action_dict
[
iAgent
]
=
RailEnvActions
.
DO_NOTHING
action
=
action_dict
[
iAgent
]
if
action
==
RailEnvActions
.
DO_NOTHING
and
agent
.
moving
:
if
action
==
RailEnvActions
.
DO_NOTHING
and
agent
.
moving
:
# Keep moving
# Keep moving
action
=
RailEnvActions
.
MOVE_FORWARD
action
=
RailEnvActions
.
MOVE_FORWARD
if
action
==
RailEnvActions
.
STOP_MOVING
and
agent
.
moving
:
if
action
==
RailEnvActions
.
STOP_MOVING
and
agent
.
moving
and
agent
.
speed_data
[
'
position_fraction
'
]
<
0.01
:
# Only allow halting an agent on entering new cells.
agent
.
moving
=
False
agent
.
moving
=
False
self
.
rewards_dict
[
iAgent
]
+=
stop_penalty
self
.
rewards_dict
[
iAgent
]
+=
stop_penalty
...
@@ -223,47 +226,73 @@ class RailEnv(Environment):
...
@@ -223,47 +226,73 @@ class RailEnv(Environment):
agent
.
moving
=
True
agent
.
moving
=
True
self
.
rewards_dict
[
iAgent
]
+=
start_penalty
self
.
rewards_dict
[
iAgent
]
+=
start_penalty
if
action
!=
RailEnvActions
.
DO_NOTHING
and
action
!=
RailEnvActions
.
STOP_MOVING
:
# Now perform a movement.
cell_isFree
,
new_cell_isValid
,
new_direction
,
new_position
,
transition_isValid
=
\
# If the agent is in an initial position within a new cell (agent.speed_data['position_fraction']<eps)
self
.
_check_action_on_agent
(
action
,
agent
)
# store the desired action in `transition_action_on_cellexit' (only if the desired transition is
if
all
([
new_cell_isValid
,
transition_isValid
,
cell_isFree
]):
# allowed! otherwise DO_NOTHING!)
agent
.
old_direction
=
agent
.
direction
# Then in any case (if agent.moving) and the `transition_action_on_cellexit' is valid, increment the
agent
.
old_position
=
agent
.
position
# position_fraction by the speed of the agent (regardless of action taken, as long as no
agent
.
position
=
new_position
# STOP_MOVING, but that makes agent.moving=False)
agent
.
direction
=
new_direction
# If the new position fraction is >= 1, reset to 0, and perform the stored
else
:
# transition_action_on_cellexit
# Logic: if the chosen action is invalid,
# and it was LEFT or RIGHT, and the agent was moving, then keep moving FORWARD.
# If the agent can make an action
if
(
action
==
RailEnvActions
.
MOVE_LEFT
or
action
==
RailEnvActions
.
MOVE_RIGHT
)
and
agent
.
moving
:
action_selected
=
False
cell_isFree
,
new_cell_isValid
,
new_direction
,
new_position
,
transition_isValid
=
\
if
agent
.
speed_data
[
'
position_fraction
'
]
<
0.01
:
self
.
_check_action_on_agent
(
RailEnvActions
.
MOVE_FORWARD
,
agent
)
if
action
!=
RailEnvActions
.
DO_NOTHING
and
action
!=
RailEnvActions
.
STOP_MOVING
:
cell_isFree
,
new_cell_isValid
,
new_direction
,
new_position
,
transition_isValid
=
\
if
all
([
new_cell_isValid
,
transition_isValid
,
cell_isFree
]):
self
.
_check_action_on_agent
(
action
,
agent
)
agent
.
old_direction
=
agent
.
direction
agent
.
old_position
=
agent
.
position
if
all
([
new_cell_isValid
,
transition_isValid
,
cell_isFree
]):
agent
.
position
=
new_position
agent
.
speed_data
[
'
transition_action_on_cellexit
'
]
=
action
agent
.
direction
=
new_direction
action_selected
=
True
else
:
# But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
# try to keep moving forward!
if
(
action
==
RailEnvActions
.
MOVE_LEFT
or
action
==
RailEnvActions
.
MOVE_RIGHT
)
and
agent
.
moving
:
cell_isFree
,
new_cell_isValid
,
new_direction
,
new_position
,
transition_isValid
=
\
self
.
_check_action_on_agent
(
RailEnvActions
.
MOVE_FORWARD
,
agent
)
if
all
([
new_cell_isValid
,
transition_isValid
,
cell_isFree
]):
agent
.
speed_data
[
'
transition_action_on_cellexit
'
]
=
RailEnvActions
.
MOVE_FORWARD
action_selected
=
True
else
:
# TODO: an invalid action was chosen after entering the cell. The agent cannot move.
self
.
rewards_dict
[
iAgent
]
+=
invalid_action_penalty
agent
.
moving
=
False
self
.
rewards_dict
[
iAgent
]
+=
stop_penalty
continue
else
:
else
:
# the action was not valid, add penalty
# TODO: an invalid action was chosen after entering the cell. The agent cannot move.
self
.
rewards_dict
[
iAgent
]
+=
invalid_action_penalty
self
.
rewards_dict
[
iAgent
]
+=
invalid_action_penalty
agent
.
moving
=
False
self
.
rewards_dict
[
iAgent
]
+=
stop_penalty
continue
else
:
if
agent
.
moving
and
(
action_selected
or
agent
.
speed_data
[
'
position_fraction
'
]
>=
0.01
):
# the action was not valid, add penalty
agent
.
speed_data
[
'
position_fraction
'
]
+=
agent
.
speed_data
[
'
speed
'
]
self
.
rewards_dict
[
iAgent
]
+=
invalid_action_penalty
if
np
.
equal
(
agent
.
position
,
agent
.
target
).
all
():
if
agent
.
speed_data
[
'
position_fraction
'
]
>=
1.0
:
self
.
dones
[
iAgent
]
=
True
agent
.
speed_data
[
'
position_fraction
'
]
=
0.0
else
:
self
.
rewards_dict
[
iAgent
]
+=
step_penalty
# Perform stored action to transition to the next cell
# Now 'transition_action_on_cellexit' will be guaranteed to be valid; it was checked on entering
# the cell
cell_isFree
,
new_cell_isValid
,
new_direction
,
new_position
,
transition_isValid
=
\
self
.
_check_action_on_agent
(
agent
.
speed_data
[
'
transition_action_on_cellexit
'
],
agent
)
agent
.
old_direction
=
agent
.
direction
agent
.
old_position
=
agent
.
position
agent
.
position
=
new_position
agent
.
direction
=
new_direction
# Check for end of episode + add global reward to all rewards!
# Check for end of episode + add global reward to all rewards!
if
np
.
all
([
np
.
array_equal
(
agent2
.
position
,
agent2
.
target
)
for
agent2
in
self
.
agents
]):
if
np
.
all
([
np
.
array_equal
(
agent2
.
position
,
agent2
.
target
)
for
agent2
in
self
.
agents
]):
self
.
dones
[
"
__all__
"
]
=
True
self
.
dones
[
"
__all__
"
]
=
True
self
.
rewards_dict
=
[
0
*
r
+
global_reward
for
r
in
self
.
rewards_dict
]
self
.
rewards_dict
=
[
0
*
r
+
global_reward
for
r
in
self
.
rewards_dict
]
# Reset the step actions (in case some agent doesn't 'register_action'
# on the next step)
self
.
actions
=
[
0
]
*
self
.
get_num_agents
()
return
self
.
_get_observations
(),
self
.
rewards_dict
,
self
.
dones
,
{}
return
self
.
_get_observations
(),
self
.
rewards_dict
,
self
.
dones
,
{}
def
_check_action_on_agent
(
self
,
action
,
agent
):
def
_check_action_on_agent
(
self
,
action
,
agent
):
...
@@ -271,6 +300,7 @@ class RailEnv(Environment):
...
@@ -271,6 +300,7 @@ class RailEnv(Environment):
# cell used to check for invalid actions
# cell used to check for invalid actions
new_direction
,
transition_isValid
=
self
.
check_action
(
agent
,
action
)
new_direction
,
transition_isValid
=
self
.
check_action
(
agent
,
action
)
new_position
=
get_new_position
(
agent
.
position
,
new_direction
)
new_position
=
get_new_position
(
agent
.
position
,
new_direction
)
# Is it a legal move?
# Is it a legal move?
# 1) transition allows the new_direction in the cell,
# 1) transition allows the new_direction in the cell,
# 2) the new cell is not empty (case 0),
# 2) the new cell is not empty (case 0),
...
@@ -281,11 +311,13 @@ class RailEnv(Environment):
...
@@ -281,11 +311,13 @@ class RailEnv(Environment):
np
.
clip
(
new_position
,
[
0
,
0
],
[
self
.
height
-
1
,
self
.
width
-
1
]))
np
.
clip
(
new_position
,
[
0
,
0
],
[
self
.
height
-
1
,
self
.
width
-
1
]))
and
# check the new position has some transitions (ie is not an empty cell)
and
# check the new position has some transitions (ie is not an empty cell)
self
.
rail
.
get_transitions
(
new_position
)
>
0
)
self
.
rail
.
get_transitions
(
new_position
)
>
0
)
# If transition validity hasn't been checked yet.
# If transition validity hasn't been checked yet.
if
transition_isValid
is
None
:
if
transition_isValid
is
None
:
transition_isValid
=
self
.
rail
.
get_transition
(
transition_isValid
=
self
.
rail
.
get_transition
(
(
*
agent
.
position
,
agent
.
direction
),
(
*
agent
.
position
,
agent
.
direction
),
new_direction
)
new_direction
)
# Check the new position is not the same as any of the existing agent positions
# Check the new position is not the same as any of the existing agent positions
# (including itself, for simplicity, since it is moving)
# (including itself, for simplicity, since it is moving)
cell_isFree
=
not
np
.
any
(
cell_isFree
=
not
np
.
any
(
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment