Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Flatland
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
pranjal_dhole
Flatland
Commits
9d28be8f
Commit
9d28be8f
authored
4 years ago
by
hagrid67
Browse files
Options
Downloads
Patches
Plain Diff
manually merging Adrian's changes (made by Erik) from master
parent
be3c58b0
No related branches found
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
flatland/envs/agent_chains.py
+0
-8
0 additions, 8 deletions
flatland/envs/agent_chains.py
flatland/envs/rail_env.py
+50
-28
50 additions, 28 deletions
flatland/envs/rail_env.py
flatland/utils/env_edit_utils.py
+6
-1
6 additions, 1 deletion
flatland/utils/env_edit_utils.py
with
56 additions
and
37 deletions
flatland/envs/agent_chains.py
+
0
−
8
View file @
9d28be8f
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
import
networkx
as
nx
import
networkx
as
nx
import
numpy
as
np
import
numpy
as
np
import
matplotlib.pyplot
as
plt
from
typing
import
List
,
Tuple
from
typing
import
List
,
Tuple
import
graphviz
as
gv
import
graphviz
as
gv
...
@@ -372,18 +371,11 @@ def test_agent_following():
...
@@ -372,18 +371,11 @@ def test_agent_following():
for
v
in
lvCells
]
for
v
in
lvCells
]
dPos
=
dict
(
zip
(
lvCells
,
lvCells
))
dPos
=
dict
(
zip
(
lvCells
,
lvCells
))
#plt.ion()
nx
.
draw
(
omc
.
G
,
nx
.
draw
(
omc
.
G
,
with_labels
=
True
,
arrowsize
=
20
,
with_labels
=
True
,
arrowsize
=
20
,
pos
=
dPos
,
pos
=
dPos
,
node_color
=
lColours
)
node_color
=
lColours
)
#plt.pause(20)
#plt.show()
def
main
():
def
main
():
test_agent_following
()
test_agent_following
()
...
...
This diff is collapsed.
Click to expand it.
flatland/envs/rail_env.py
+
50
−
28
View file @
9d28be8f
...
@@ -4,13 +4,10 @@ Definition of the RailEnv environment.
...
@@ -4,13 +4,10 @@ Definition of the RailEnv environment.
import
random
import
random
# TODO: _ this is a global method --> utils or remove later
# TODO: _ this is a global method --> utils or remove later
from
enum
import
IntEnum
from
enum
import
IntEnum
from
typing
import
List
,
NamedTuple
,
Optional
,
Dict
from
typing
import
List
,
NamedTuple
,
Optional
,
Dict
,
Tuple
import
msgpack
import
msgpack_numpy
as
m
import
numpy
as
np
import
numpy
as
np
from
gym.utils
import
seeding
from
msgpack
import
Packer
from
flatland.core.env
import
Environment
from
flatland.core.env
import
Environment
from
flatland.core.env_observation_builder
import
ObservationBuilder
from
flatland.core.env_observation_builder
import
ObservationBuilder
...
@@ -28,21 +25,50 @@ from flatland.envs import schedule_generators as sched_gen
...
@@ -28,21 +25,50 @@ from flatland.envs import schedule_generators as sched_gen
from
flatland.envs
import
persistence
from
flatland.envs
import
persistence
from
flatland.envs
import
agent_chains
as
ac
from
flatland.envs
import
agent_chains
as
ac
from
flatland.envs.observations
import
GlobalObsForRailEnv
from
gym.utils
import
seeding
# Direct import of objects / classes does not work with circular imports.
# Direct import of objects / classes does not work with circular imports.
# from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData
# from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData
# from flatland.envs.observations import GlobalObsForRailEnv
# from flatland.envs.observations import GlobalObsForRailEnv
# from flatland.envs.rail_generators import random_rail_generator, RailGenerator
# from flatland.envs.rail_generators import random_rail_generator, RailGenerator
# from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
# from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
from
flatland.envs.observations
import
GlobalObsForRailEnv
# import debugpy
import
pickle
m
.
patch
()
m
.
patch
()
# Adrian Egli performance fix (the fast methods brings more than 50%)
def
fast_isclose
(
a
,
b
,
rtol
):
return
(
a
<
(
b
+
rtol
))
or
(
a
<
(
b
-
rtol
))
def
fast_clip
(
position
:
(
int
,
int
),
min_value
:
(
int
,
int
),
max_value
:
(
int
,
int
))
->
bool
:
return
(
max
(
min_value
[
0
],
min
(
position
[
0
],
max_value
[
0
])),
max
(
min_value
[
1
],
min
(
position
[
1
],
max_value
[
1
]))
)
def
fast_argmax
(
possible_transitions
:
(
int
,
int
,
int
,
int
))
->
bool
:
if
possible_transitions
[
0
]
==
1
:
return
0
if
possible_transitions
[
1
]
==
1
:
return
1
if
possible_transitions
[
2
]
==
1
:
return
2
return
3
def
fast_position_equal
(
pos_1
:
(
int
,
int
),
pos_2
:
(
int
,
int
))
->
bool
:
return
pos_1
[
0
]
==
pos_2
[
0
]
and
pos_1
[
1
]
==
pos_2
[
1
]
def
fast_count_nonzero
(
possible_transitions
:
(
int
,
int
,
int
,
int
)):
return
possible_transitions
[
0
]
+
possible_transitions
[
1
]
+
possible_transitions
[
2
]
+
possible_transitions
[
3
]
class
RailEnvActions
(
IntEnum
):
class
RailEnvActions
(
IntEnum
):
DO_NOTHING
=
0
# implies change of direction in a dead-end!
DO_NOTHING
=
0
# implies change of direction in a dead-end!
MOVE_LEFT
=
1
MOVE_LEFT
=
1
...
@@ -298,11 +324,11 @@ class RailEnv(Environment):
...
@@ -298,11 +324,11 @@ class RailEnv(Environment):
False: Agent cannot provide an action
False: Agent cannot provide an action
"""
"""
return
(
agent
.
status
==
RailAgentStatus
.
READY_TO_DEPART
or
(
return
(
agent
.
status
==
RailAgentStatus
.
READY_TO_DEPART
or
(
agent
.
status
==
RailAgentStatus
.
ACTIVE
and
np
.
isclose
(
agent
.
speed_data
[
'
position_fraction
'
],
0.0
,
agent
.
status
==
RailAgentStatus
.
ACTIVE
and
fast_
isclose
(
agent
.
speed_data
[
'
position_fraction
'
],
0.0
,
rtol
=
1e-03
)))
rtol
=
1e-03
)))
def
reset
(
self
,
regenerate_rail
:
bool
=
True
,
regenerate_schedule
:
bool
=
True
,
activate_agents
:
bool
=
False
,
def
reset
(
self
,
regenerate_rail
:
bool
=
True
,
regenerate_schedule
:
bool
=
True
,
activate_agents
:
bool
=
False
,
random_seed
:
bool
=
None
)
->
(
Dict
,
Dict
)
:
random_seed
:
bool
=
None
)
->
Tuple
[
Dict
,
Dict
]
:
"""
"""
reset(regenerate_rail, regenerate_schedule, activate_agents, random_seed)
reset(regenerate_rail, regenerate_schedule, activate_agents, random_seed)
...
@@ -604,7 +630,7 @@ class RailEnv(Environment):
...
@@ -604,7 +630,7 @@ class RailEnv(Environment):
RailEnvActions
.
MOVE_LEFT
,
RailEnvActions
.
MOVE_RIGHT
,
RailEnvActions
.
MOVE_FORWARD
]
RailEnvActions
.
MOVE_LEFT
,
RailEnvActions
.
MOVE_RIGHT
,
RailEnvActions
.
MOVE_FORWARD
]
if
action
in
[
RailEnvActions
.
MOVE_LEFT
,
RailEnvActions
.
MOVE_RIGHT
,
if
action
in
[
RailEnvActions
.
MOVE_LEFT
,
RailEnvActions
.
MOVE_RIGHT
,
RailEnvActions
.
MOVE_FORWARD
]
and
self
.
cell_free
(
agent
.
initial_position
):
RailEnvActions
.
MOVE_FORWARD
]
and
self
.
cell_free
(
agent
.
initial_position
):
agent
.
status
=
RailAgentStatus
.
ACTIVE
agent
.
status
=
RailAgentStatus
.
ACTIVE
self
.
_set_agent_to_initial_position
(
agent
,
agent
.
initial_position
)
self
.
_set_agent_to_initial_position
(
agent
,
agent
.
initial_position
)
self
.
rewards_dict
[
i_agent
]
+=
self
.
step_penalty
*
agent
.
speed_data
[
'
speed
'
]
self
.
rewards_dict
[
i_agent
]
+=
self
.
step_penalty
*
agent
.
speed_data
[
'
speed
'
]
...
@@ -626,7 +652,7 @@ class RailEnv(Environment):
...
@@ -626,7 +652,7 @@ class RailEnv(Environment):
# Is the agent at the beginning of the cell? Then, it can take an action.
# Is the agent at the beginning of the cell? Then, it can take an action.
# As long as the agent is malfunctioning or stopped at the beginning of the cell,
# As long as the agent is malfunctioning or stopped at the beginning of the cell,
# different actions may be taken!
# different actions may be taken!
if
np
.
isclose
(
agent
.
speed_data
[
'
position_fraction
'
],
0.0
,
rtol
=
1e-03
):
if
fast_
isclose
(
agent
.
speed_data
[
'
position_fraction
'
],
0.0
,
rtol
=
1e-03
):
# No action has been supplied for this agent -> set DO_NOTHING as default
# No action has been supplied for this agent -> set DO_NOTHING as default
if
action
is
None
:
if
action
is
None
:
action
=
RailEnvActions
.
DO_NOTHING
action
=
RailEnvActions
.
DO_NOTHING
...
@@ -686,8 +712,8 @@ class RailEnv(Environment):
...
@@ -686,8 +712,8 @@ class RailEnv(Environment):
# transition_action_on_cellexit if the cell is free.
# transition_action_on_cellexit if the cell is free.
if
agent
.
moving
:
if
agent
.
moving
:
agent
.
speed_data
[
'
position_fraction
'
]
+=
agent
.
speed_data
[
'
speed
'
]
agent
.
speed_data
[
'
position_fraction
'
]
+=
agent
.
speed_data
[
'
speed
'
]
if
agent
.
speed_data
[
'
position_fraction
'
]
>
1.0
or
np
.
isclose
(
agent
.
speed_data
[
'
position_fraction
'
],
1.0
,
if
agent
.
speed_data
[
'
position_fraction
'
]
>
1.0
or
fast_
isclose
(
agent
.
speed_data
[
'
position_fraction
'
],
1.0
,
rtol
=
1e-03
):
rtol
=
1e-03
):
# Perform stored action to transition to the next cell as soon as cell is free
# Perform stored action to transition to the next cell as soon as cell is free
# Notice that we've already checked new_cell_valid and transition valid when we stored the action,
# Notice that we've already checked new_cell_valid and transition valid when we stored the action,
# so we only have to check cell_free now!
# so we only have to check cell_free now!
...
@@ -695,7 +721,7 @@ class RailEnv(Environment):
...
@@ -695,7 +721,7 @@ class RailEnv(Environment):
# Traditional check that next cell is free
# Traditional check that next cell is free
# cell and transition validity was checked when we stored transition_action_on_cellexit!
# cell and transition validity was checked when we stored transition_action_on_cellexit!
cell_free
,
new_cell_valid
,
new_direction
,
new_position
,
transition_valid
=
self
.
_check_action_on_agent
(
cell_free
,
new_cell_valid
,
new_direction
,
new_position
,
transition_valid
=
self
.
_check_action_on_agent
(
agent
.
speed_data
[
'
transition_action_on_cellexit
'
],
agent
)
agent
.
speed_data
[
'
transition_action_on_cellexit
'
],
agent
)
# N.B. validity of new_cell and transition should have been verified before the action was stored!
# N.B. validity of new_cell and transition should have been verified before the action was stored!
assert
new_cell_valid
assert
new_cell_valid
...
@@ -845,7 +871,6 @@ class RailEnv(Environment):
...
@@ -845,7 +871,6 @@ class RailEnv(Environment):
trans_block
=
sbTrans
[
agent
.
direction
*
4
:
agent
.
direction
*
4
+
4
]
trans_block
=
sbTrans
[
agent
.
direction
*
4
:
agent
.
direction
*
4
+
4
]
if
(
trans_block
==
"
0000
"
):
if
(
trans_block
==
"
0000
"
):
print
(
i_agent
,
agent
.
position
,
agent
.
direction
,
sbTrans
,
trans_block
)
print
(
i_agent
,
agent
.
position
,
agent
.
direction
,
sbTrans
,
trans_block
)
# debugpy.breakpoint()
# if agent cannot enter env, then we should have move=False
# if agent cannot enter env, then we should have move=False
...
@@ -862,20 +887,16 @@ class RailEnv(Environment):
...
@@ -862,20 +887,16 @@ class RailEnv(Environment):
if
not
all
([
transition_valid
,
new_cell_valid
]):
if
not
all
([
transition_valid
,
new_cell_valid
]):
print
(
f
"
ERRROR: step_agent2 invalid transition ag
{
i_agent
}
dir
{
new_direction
}
pos
{
agent
.
position
}
next
{
rc_next
}
"
)
print
(
f
"
ERRROR: step_agent2 invalid transition ag
{
i_agent
}
dir
{
new_direction
}
pos
{
agent
.
position
}
next
{
rc_next
}
"
)
# debugpy.breakpoint()
if
new_position
!=
rc_next
:
if
new_position
!=
rc_next
:
print
(
f
"
ERROR: agent
{
i_agent
}
new_pos
{
new_position
}
!= rc_next
{
rc_next
}
"
+
print
(
f
"
ERROR: agent
{
i_agent
}
new_pos
{
new_position
}
!= rc_next
{
rc_next
}
"
+
f
"
pos
{
agent
.
position
}
dir
{
agent
.
direction
}
new_dir
{
new_direction
}
"
+
f
"
pos
{
agent
.
position
}
dir
{
agent
.
direction
}
new_dir
{
new_direction
}
"
+
f
"
stored action:
{
agent
.
speed_data
[
'
transition_action_on_cellexit
'
]
}
"
)
f
"
stored action:
{
agent
.
speed_data
[
'
transition_action_on_cellexit
'
]
}
"
)
# debugpy.breakpoint()
sbTrans
=
format
(
self
.
rail
.
grid
[
agent
.
position
],
"
016b
"
)
sbTrans
=
format
(
self
.
rail
.
grid
[
agent
.
position
],
"
016b
"
)
trans_block
=
sbTrans
[
agent
.
direction
*
4
:
agent
.
direction
*
4
+
4
]
trans_block
=
sbTrans
[
agent
.
direction
*
4
:
agent
.
direction
*
4
+
4
]
if
(
trans_block
==
"
0000
"
):
if
(
trans_block
==
"
0000
"
):
print
(
"
ERROR:
"
,
i_agent
,
agent
.
position
,
agent
.
direction
,
sbTrans
,
trans_block
)
print
(
"
ERROR:
"
,
i_agent
,
agent
.
position
,
agent
.
direction
,
sbTrans
,
trans_block
)
# debugpy.breakpoint()
agent
.
position
=
rc_next
agent
.
position
=
rc_next
agent
.
direction
=
new_direction
agent
.
direction
=
new_direction
...
@@ -937,6 +958,7 @@ class RailEnv(Environment):
...
@@ -937,6 +958,7 @@ class RailEnv(Environment):
self
.
agent_positions
[
agent
.
position
]
=
-
1
self
.
agent_positions
[
agent
.
position
]
=
-
1
if
self
.
remove_agents_at_target
:
if
self
.
remove_agents_at_target
:
agent
.
position
=
None
agent
.
position
=
None
# setting old_position to None here stops the DONE agents from appearing in the rendered image
agent
.
old_position
=
None
agent
.
old_position
=
None
agent
.
status
=
RailAgentStatus
.
DONE_REMOVED
agent
.
status
=
RailAgentStatus
.
DONE_REMOVED
...
@@ -964,9 +986,9 @@ class RailEnv(Environment):
...
@@ -964,9 +986,9 @@ class RailEnv(Environment):
new_position
=
get_new_position
(
agent
.
position
,
new_direction
)
new_position
=
get_new_position
(
agent
.
position
,
new_direction
)
new_cell_valid
=
(
new_cell_valid
=
(
np
.
array
_equal
(
# Check the new position is still in the grid
fast_position
_equal
(
# Check the new position is still in the grid
new_position
,
new_position
,
np
.
clip
(
new_position
,
[
0
,
0
],
[
self
.
height
-
1
,
self
.
width
-
1
]))
fast_
clip
(
new_position
,
[
0
,
0
],
[
self
.
height
-
1
,
self
.
width
-
1
]))
and
# check the new position has some transitions (ie is not an empty cell)
and
# check the new position has some transitions (ie is not an empty cell)
self
.
rail
.
get_full_transitions
(
*
new_position
)
>
0
)
self
.
rail
.
get_full_transitions
(
*
new_position
)
>
0
)
...
@@ -1038,7 +1060,7 @@ class RailEnv(Environment):
...
@@ -1038,7 +1060,7 @@ class RailEnv(Environment):
"""
"""
transition_valid
=
None
transition_valid
=
None
possible_transitions
=
self
.
rail
.
get_transitions
(
*
agent
.
position
,
agent
.
direction
)
possible_transitions
=
self
.
rail
.
get_transitions
(
*
agent
.
position
,
agent
.
direction
)
num_transitions
=
np
.
count_nonzero
(
possible_transitions
)
num_transitions
=
fast_
count_nonzero
(
possible_transitions
)
new_direction
=
agent
.
direction
new_direction
=
agent
.
direction
if
action
==
RailEnvActions
.
MOVE_LEFT
:
if
action
==
RailEnvActions
.
MOVE_LEFT
:
...
@@ -1057,7 +1079,7 @@ class RailEnv(Environment):
...
@@ -1057,7 +1079,7 @@ class RailEnv(Environment):
# - dead-end, straight line or curved line;
# - dead-end, straight line or curved line;
# new_direction will be the only valid transition
# new_direction will be the only valid transition
# - take only available transition
# - take only available transition
new_direction
=
np
.
argmax
(
possible_transitions
)
new_direction
=
fast_
argmax
(
possible_transitions
)
transition_valid
=
True
transition_valid
=
True
return
new_direction
,
transition_valid
return
new_direction
,
transition_valid
...
...
This diff is collapsed.
Click to expand it.
flatland/utils/env_edit_utils.py
+
6
−
1
View file @
9d28be8f
...
@@ -122,5 +122,10 @@ def makeTestEnv(sName="single_alternative", nAg=2, bUCF=True):
...
@@ -122,5 +122,10 @@ def makeTestEnv(sName="single_alternative", nAg=2, bUCF=True):
dSpec
=
ddEnvSpecs
[
sName
]
dSpec
=
ddEnvSpecs
[
sName
]
return
makeEnv2
(
nAg
=
nAg
,
bUCF
=
bUCF
,
**
dSpec
)
return
makeEnv2
(
nAg
=
nAg
,
bUCF
=
bUCF
,
**
dSpec
)
def
getAgentState
(
env
):
dAgState
=
{}
for
iAg
,
ag
in
enumerate
(
env
.
agents
):
dAgState
[
iAg
]
=
(
*
ag
.
position
,
ag
.
direction
)
return
dAgState
\ No newline at end of file
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment