Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
Flatland
Commits
8e8f91a5
Commit
8e8f91a5
authored
Sep 17, 2019
by
u214892
Browse files
#178
bugfix initial malfunction
parent
78b1f9ee
Changes
10
Expand all
Hide whitespace changes
Inline
Side-by-side
benchmarks/run_all_examples.py
View file @
8e8f91a5
...
...
@@ -18,6 +18,7 @@ for entry in [entry for entry in importlib_resources.contents('examples') if
with
path
(
'examples'
,
entry
)
as
file_in
:
print
(
""
)
print
(
""
)
print
(
""
)
print
(
"*****************************************************************"
)
print
(
"Running {}"
.
format
(
entry
))
...
...
flatland/core/grid/grid4_astar.py
View file @
8e8f91a5
import
collections
from
flatland.core.grid.grid4_utils
import
validate_new_transition
from
flatland.utils.ordered_set
import
OrderedSet
class
AStarNode
():
...
...
@@ -27,54 +26,6 @@ class AStarNode():
self
.
f
=
other
.
f
# in order for enumeration to be deterministic for testing purposes
# https://stackoverflow.com/questions/1653970/does-python-have-an-ordered-set
class
OrderedSet
(
collections
.
OrderedDict
,
collections
.
MutableSet
):
def
update
(
self
,
*
args
,
**
kwargs
):
if
kwargs
:
raise
TypeError
(
"update() takes no keyword arguments"
)
for
s
in
args
:
for
e
in
s
:
self
.
add
(
e
)
def
add
(
self
,
elem
):
self
[
elem
]
=
None
def
discard
(
self
,
elem
):
self
.
pop
(
elem
,
None
)
def
__le__
(
self
,
other
):
return
all
(
e
in
other
for
e
in
self
)
def
__lt__
(
self
,
other
):
return
self
<=
other
and
self
!=
other
def
__ge__
(
self
,
other
):
return
all
(
e
in
self
for
e
in
other
)
def
__gt__
(
self
,
other
):
return
self
>=
other
and
self
!=
other
def
__repr__
(
self
):
return
'OrderedSet([%s])'
%
(
', '
.
join
(
map
(
repr
,
self
.
keys
())))
def
__str__
(
self
):
return
'{%s}'
%
(
', '
.
join
(
map
(
repr
,
self
.
keys
())))
difference
=
property
(
lambda
self
:
self
.
__sub__
)
difference_update
=
property
(
lambda
self
:
self
.
__isub__
)
intersection
=
property
(
lambda
self
:
self
.
__and__
)
intersection_update
=
property
(
lambda
self
:
self
.
__iand__
)
issubset
=
property
(
lambda
self
:
self
.
__le__
)
issuperset
=
property
(
lambda
self
:
self
.
__ge__
)
symmetric_difference
=
property
(
lambda
self
:
self
.
__xor__
)
symmetric_difference_update
=
property
(
lambda
self
:
self
.
__ixor__
)
union
=
property
(
lambda
self
:
self
.
__or__
)
def
a_star
(
rail_trans
,
rail_array
,
start
,
end
):
"""
Returns a list of tuples as a path from the given start to end.
...
...
flatland/core/transition_map.py
View file @
8e8f91a5
...
...
@@ -10,6 +10,7 @@ from flatland.core.grid.grid4 import Grid4Transitions
from
flatland.core.grid.grid4_utils
import
get_new_position
from
flatland.core.grid.rail_env_grid
import
RailEnvTransitions
from
flatland.core.transitions
import
Transitions
from
flatland.utils.ordered_set
import
OrderedSet
class
TransitionMap
:
...
...
@@ -336,7 +337,7 @@ class GridTransitionMap(TransitionMap):
tmp
=
self
.
get_full_transitions
(
rcPos
[
0
],
rcPos
[
1
])
def
is_simple_turn
(
trans
):
all_simple_turns
=
s
et
()
all_simple_turns
=
OrderedS
et
()
for
trans
in
[
int
(
'0100000000000010'
,
2
),
# Case 1b (8) - simple turn right
int
(
'0001001000000000'
,
2
)
# Case 1c (9) - simple turn left]:
]:
...
...
@@ -351,7 +352,7 @@ class GridTransitionMap(TransitionMap):
# print("_path_exists({},{},{}".format(start, direction, end))
# BFS - Check if a path exists between the 2 nodes
visited
=
s
et
()
visited
=
OrderedS
et
()
stack
=
[(
start
,
direction
)]
while
stack
:
node
=
stack
.
pop
()
...
...
flatland/envs/observations.py
View file @
8e8f91a5
...
...
@@ -9,6 +9,7 @@ import numpy as np
from
flatland.core.env_observation_builder
import
ObservationBuilder
from
flatland.core.grid.grid4
import
Grid4TransitionsEnum
from
flatland.core.grid.grid_utils
import
coordinate_to_position
from
flatland.utils.ordered_set
import
OrderedSet
class
TreeObsForRailEnv
(
ObservationBuilder
):
...
...
@@ -279,7 +280,7 @@ class TreeObsForRailEnv(ObservationBuilder):
observation
=
[
0
,
0
,
0
,
0
,
0
,
0
,
self
.
distance_map
[(
handle
,
*
agent
.
position
,
agent
.
direction
)],
0
,
0
,
agent
.
malfunction_data
[
'malfunction'
],
agent
.
speed_data
[
'speed'
]]
visited
=
s
et
()
visited
=
OrderedS
et
()
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation
...
...
@@ -295,7 +296,7 @@ class TreeObsForRailEnv(ObservationBuilder):
branch_observation
,
branch_visited
=
\
self
.
_explore_branch
(
handle
,
new_cell
,
branch_direction
,
1
,
1
)
observation
=
observation
+
branch_observation
visited
=
visited
.
union
(
branch_visited
)
visited
|
=
branch_visited
else
:
# add cells filled with infinity if no transition is possible
observation
=
observation
+
[
-
np
.
inf
]
*
self
.
_num_cells_to_fill_in
(
self
.
max_depth
)
...
...
@@ -332,7 +333,7 @@ class TreeObsForRailEnv(ObservationBuilder):
last_is_terminal
=
False
# wrong cell OR cycle; either way, we don't want the agent to land here
last_is_target
=
False
visited
=
s
et
()
visited
=
OrderedS
et
()
agent
=
self
.
env
.
agents
[
handle
]
time_per_cell
=
np
.
reciprocal
(
agent
.
speed_data
[
"speed"
])
own_target_encountered
=
np
.
inf
...
...
@@ -545,7 +546,7 @@ class TreeObsForRailEnv(ObservationBuilder):
depth
+
1
)
observation
=
observation
+
branch_observation
if
len
(
branch_visited
)
!=
0
:
visited
=
visited
.
union
(
branch_visited
)
visited
|
=
branch_visited
elif
last_is_switch
and
possible_transitions
[
branch_direction
]:
new_cell
=
self
.
_new_position
(
position
,
branch_direction
)
branch_observation
,
branch_visited
=
self
.
_explore_branch
(
handle
,
...
...
@@ -555,7 +556,7 @@ class TreeObsForRailEnv(ObservationBuilder):
depth
+
1
)
observation
=
observation
+
branch_observation
if
len
(
branch_visited
)
!=
0
:
visited
=
visited
.
union
(
branch_visited
)
visited
|
=
branch_visited
else
:
# no exploring possible, add just cells with infinity
observation
=
observation
+
[
-
np
.
inf
]
*
self
.
_num_cells_to_fill_in
(
self
.
max_depth
-
depth
)
...
...
flatland/envs/predictions.py
View file @
8e8f91a5
...
...
@@ -7,6 +7,7 @@ import numpy as np
from
flatland.core.env_prediction_builder
import
PredictionBuilder
from
flatland.core.grid.grid4_utils
import
get_new_position
from
flatland.envs.rail_env
import
RailEnvActions
from
flatland.utils.ordered_set
import
OrderedSet
class
DummyPredictorForRailEnv
(
PredictionBuilder
):
...
...
@@ -130,7 +131,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
prediction
[
0
]
=
[
0
,
*
_agent_initial_position
,
_agent_initial_direction
,
0
]
new_direction
=
_agent_initial_direction
new_position
=
_agent_initial_position
visited
=
s
et
()
visited
=
OrderedS
et
()
for
index
in
range
(
1
,
self
.
max_depth
+
1
):
# if we're at the target, stop moving...
if
agent
.
position
==
agent
.
target
:
...
...
flatland/envs/rail_env.py
View file @
8e8f91a5
...
...
@@ -4,7 +4,7 @@ Definition of the RailEnv environment.
# TODO: _ this is a global method --> utils or remove later
import
warnings
from
enum
import
IntEnum
from
typing
import
List
,
Set
,
NamedTuple
from
typing
import
List
,
Set
,
NamedTuple
,
Optional
import
msgpack
import
msgpack_numpy
as
m
...
...
@@ -18,6 +18,7 @@ from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
from
flatland.envs.observations
import
TreeObsForRailEnv
from
flatland.envs.rail_generators
import
random_rail_generator
,
RailGenerator
from
flatland.envs.schedule_generators
import
random_schedule_generator
,
ScheduleGenerator
from
flatland.utils.ordered_set
import
OrderedSet
m
.
patch
()
...
...
@@ -153,7 +154,7 @@ class RailEnv(Environment):
self
.
rail_generator
:
RailGenerator
=
rail_generator
self
.
schedule_generator
:
ScheduleGenerator
=
schedule_generator
self
.
rail_generator
=
rail_generator
self
.
rail
:
GridTransitionMap
=
None
self
.
rail
:
Optional
[
GridTransitionMap
]
=
None
self
.
width
=
width
self
.
height
=
height
...
...
@@ -549,7 +550,7 @@ class RailEnv(Environment):
return
new_direction
,
transition_valid
def
get_valid_move_actions
(
self
,
agent
:
EnvAgent
)
->
Set
[
RailEnvNextAction
]:
valid_actions
:
Set
[
RailEnvNextAction
]
=
s
et
()
valid_actions
:
Set
[
RailEnvNextAction
]
=
OrderedS
et
()
agent_position
=
agent
.
position
agent_direction
=
agent
.
direction
possible_transitions
=
self
.
rail
.
get_transitions
(
*
agent_position
,
agent_direction
)
...
...
flatland/utils/ordered_set.py
0 → 100644
View file @
8e8f91a5
# in order for enumeration to be deterministic for testing purposes
# https://stackoverflow.com/questions/1653970/does-python-have-an-ordered-set
from
collections
import
OrderedDict
from
collections.abc
import
MutableSet
class
OrderedSet
(
OrderedDict
,
MutableSet
):
def
update
(
self
,
*
args
,
**
kwargs
):
if
kwargs
:
raise
TypeError
(
"update() takes no keyword arguments"
)
for
s
in
args
:
for
e
in
s
:
self
.
add
(
e
)
def
add
(
self
,
elem
):
self
[
elem
]
=
None
def
discard
(
self
,
elem
):
self
.
pop
(
elem
,
None
)
def
__le__
(
self
,
other
):
return
all
(
e
in
other
for
e
in
self
)
def
__lt__
(
self
,
other
):
return
self
<=
other
and
self
!=
other
def
__ge__
(
self
,
other
):
return
all
(
e
in
self
for
e
in
other
)
def
__gt__
(
self
,
other
):
return
self
>=
other
and
self
!=
other
def
__repr__
(
self
):
return
'OrderedSet([%s])'
%
(
', '
.
join
(
map
(
repr
,
self
.
keys
())))
def
__str__
(
self
):
return
'{%s}'
%
(
', '
.
join
(
map
(
repr
,
self
.
keys
())))
difference
=
property
(
lambda
self
:
self
.
__sub__
)
difference_update
=
property
(
lambda
self
:
self
.
__isub__
)
intersection
=
property
(
lambda
self
:
self
.
__and__
)
intersection_update
=
property
(
lambda
self
:
self
.
__iand__
)
issubset
=
property
(
lambda
self
:
self
.
__le__
)
issuperset
=
property
(
lambda
self
:
self
.
__ge__
)
symmetric_difference
=
property
(
lambda
self
:
self
.
__xor__
)
symmetric_difference_update
=
property
(
lambda
self
:
self
.
__ixor__
)
union
=
property
(
lambda
self
:
self
.
__or__
)
tests/test_flatland_envs_sparse_rail_generator.py
View file @
8e8f91a5
This diff is collapsed.
Click to expand it.
tests/test_multi_speed.py
View file @
8e8f91a5
...
...
@@ -9,7 +9,7 @@ from flatland.envs.rail_generators import complex_rail_generator, rail_from_grid
from
flatland.envs.schedule_generators
import
complex_schedule_generator
,
random_schedule_generator
from
flatland.utils.rendertools
import
RenderTool
from
flatland.utils.simple_rail
import
make_simple_rail
from
test_utils
import
Test
Config
,
Replay
from
test_utils
import
Replay
Config
,
Replay
np
.
random
.
seed
(
1
)
...
...
@@ -117,7 +117,7 @@ def test_multispeed_actions_no_malfunction_no_blocking(rendering=True):
if
rendering
:
renderer
=
RenderTool
(
env
,
gl
=
"PILSVG"
)
test_config
=
Test
Config
(
test_config
=
Replay
Config
(
replay
=
[
Replay
(
position
=
(
3
,
9
),
# east dead-end
...
...
@@ -248,7 +248,7 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True):
renderer
=
RenderTool
(
env
,
gl
=
"PILSVG"
)
test_configs
=
[
Test
Config
(
Replay
Config
(
replay
=
[
Replay
(
position
=
(
3
,
8
),
...
...
@@ -316,7 +316,7 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True):
],
target
=
(
3
,
0
),
# west dead-end
speed
=
1
/
3
),
Test
Config
(
Replay
Config
(
replay
=
[
Replay
(
position
=
(
3
,
9
),
# east dead-end
...
...
@@ -456,7 +456,7 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True):
if
rendering
:
renderer
=
RenderTool
(
env
,
gl
=
"PILSVG"
)
test_config
=
Test
Config
(
test_config
=
Replay
Config
(
replay
=
[
Replay
(
position
=
(
3
,
9
),
# east dead-end
...
...
tests/test_utils.py
View file @
8e8f91a5
...
...
@@ -15,7 +15,7 @@ class Replay(object):
@
attrs
class
Test
Config
(
object
):
class
Replay
Config
(
object
):
replay
=
attrib
(
type
=
List
[
Replay
])
target
=
attrib
()
speed
=
attrib
(
type
=
float
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment