Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
Flatland
Commits
608a75b5
Commit
608a75b5
authored
Sep 15, 2021
by
Dipam Chakraborty
Browse files
fix seeding pipeline
parent
e3c821e5
Pipeline
#8499
failed with stages
in 6 minutes and 7 seconds
Changes
6
Pipelines
1
Expand all
Show whitespace changes
Inline
Side-by-side
flatland/envs/rail_env.py
View file @
608a75b5
...
...
@@ -106,7 +106,7 @@ class RailEnv(Environment):
malfunction_generator_and_process_data
=
None
,
# mal_gen.no_malfunction_generator(),
malfunction_generator
=
None
,
remove_agents_at_target
=
True
,
random_seed
=
1
,
random_seed
=
None
,
record_steps
=
False
,
):
"""
...
...
@@ -161,7 +161,6 @@ class RailEnv(Environment):
self
.
number_of_agents
=
number_of_agents
# self.rail_generator: RailGenerator = rail_generator
if
rail_generator
is
None
:
rail_generator
=
rail_gen
.
sparse_rail_generator
()
self
.
rail_generator
=
rail_generator
...
...
@@ -193,9 +192,7 @@ class RailEnv(Environment):
self
.
action_space
=
[
5
]
self
.
_seed
()
self
.
_seed
()
self
.
random_seed
=
random_seed
if
self
.
random_seed
:
if
random_seed
:
self
.
_seed
(
seed
=
random_seed
)
self
.
agent_positions
=
None
...
...
@@ -211,6 +208,14 @@ class RailEnv(Environment):
def
_seed
(
self
,
seed
=
None
):
self
.
np_random
,
seed
=
seeding
.
np_random
(
seed
)
random
.
seed
(
seed
)
self
.
random_seed
=
seed
# Keep track of all the seeds in order
if
not
hasattr
(
self
,
'seed_history'
):
self
.
seed_history
=
[
seed
]
if
self
.
seed_history
[
-
1
]
!=
seed
:
self
.
seed_history
.
append
(
seed
)
return
[
seed
]
# no more agent_handles
...
...
@@ -252,7 +257,7 @@ class RailEnv(Environment):
(
agent
.
state
.
is_on_map_state
()
and
agent
.
speed_counter
.
is_cell_entry
)
def
reset
(
self
,
regenerate_rail
:
bool
=
True
,
regenerate_schedule
:
bool
=
True
,
*
,
random_seed
:
bool
=
None
)
->
Tuple
[
Dict
,
Dict
]:
random_seed
:
int
=
None
)
->
Tuple
[
Dict
,
Dict
]:
"""
reset(regenerate_rail, regenerate_schedule, activate_agents, random_seed)
...
...
@@ -264,7 +269,7 @@ class RailEnv(Environment):
regenerate the rails
regenerate_schedule : bool, optional
regenerate the schedule and the static agents
random_seed :
bool
, optional
random_seed :
int
, optional
random seed for environment
Returns
...
...
@@ -355,11 +360,11 @@ class RailEnv(Environment):
""" Update the agent_positions array for agents that changed positions """
for
agent
in
self
.
agents
:
if
not
ignore_old_positions
or
agent
.
old_position
!=
agent
.
position
:
if
agent
.
position
is
not
None
:
self
.
agent_positions
[
agent
.
position
]
=
agent
.
handle
if
agent
.
old_position
is
not
None
:
self
.
agent_positions
[
agent
.
old_position
]
=
-
1
def
generate_state_transition_signals
(
self
,
agent
,
preprocessed_action
,
movement_allowed
):
""" Generate State Transitions Signals used in the state machine """
st_signals
=
StateTransitionSignals
()
...
...
@@ -597,7 +602,7 @@ class RailEnv(Environment):
# Check if episode has ended and update rewards and dones
self
.
end_of_episode_update
(
have_all_agents_ended
)
self
.
_update_agent_positions_map
self
.
_update_agent_positions_map
()
return
self
.
_get_observations
(),
self
.
rewards_dict
,
self
.
dones
,
self
.
get_info_dict
()
...
...
flatland/envs/rail_generators.py
View file @
608a75b5
...
...
@@ -163,7 +163,7 @@ def sparse_rail_generator(*args, **kwargs):
class
SparseRailGen
(
RailGen
):
def
__init__
(
self
,
max_num_cities
:
int
=
2
,
grid_mode
:
bool
=
False
,
max_rails_between_cities
:
int
=
2
,
max_rail_pairs_in_city
:
int
=
2
,
seed
=
0
)
->
RailGenerator
:
max_rail_pairs_in_city
:
int
=
2
,
seed
=
None
)
->
RailGenerator
:
"""
Generates railway networks with cities and inner city rails
...
...
@@ -189,7 +189,7 @@ class SparseRailGen(RailGen):
self
.
grid_mode
=
grid_mode
self
.
max_rails_between_cities
=
max_rails_between_cities
self
.
max_rail_pairs_in_city
=
max_rail_pairs_in_city
self
.
seed
=
seed
# TODO: seed in constructor or generate?
self
.
seed
=
seed
def
generate
(
self
,
width
:
int
,
height
:
int
,
num_agents
:
int
,
num_resets
:
int
=
0
,
...
...
@@ -217,8 +217,10 @@ class SparseRailGen(RailGen):
'train_stations': locations of train stations for start and targets
'city_orientations' : orientation of cities
"""
if
np_random
is
None
:
if
self
.
seed
is
not
None
:
np_random
=
RandomState
(
self
.
seed
)
elif
np_random
is
None
:
np_random
=
RandomState
(
np
.
random
.
randint
(
2
**
32
))
rail_trans
=
RailEnvTransitions
()
grid_map
=
GridTransitionMap
(
width
=
width
,
height
=
height
,
transitions
=
rail_trans
)
...
...
tests/test_flatland_envs_observations.py
View file @
608a75b5
...
...
@@ -182,7 +182,7 @@ def test_reward_function_waiting(rendering=False):
env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_grid_transition_map
(
rail
,
optionals
),
line_generator
=
sparse_line_generator
(),
number_of_agents
=
2
,
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
ShortestPathPredictorForRailEnv
()),
remove_agents_at_target
=
False
)
remove_agents_at_target
=
False
,
random_seed
=
1
)
obs_builder
:
TreeObsForRailEnv
=
env
.
obs_builder
env
.
reset
()
...
...
tests/test_flatland_envs_sparse_rail_generator.py
View file @
608a75b5
This diff is collapsed.
Click to expand it.
tests/test_multi_speed.py
View file @
608a75b5
...
...
@@ -22,13 +22,14 @@ class RandomAgent:
def
__init__
(
self
,
state_size
,
action_size
):
self
.
state_size
=
state_size
self
.
action_size
=
action_size
self
.
np_random
=
np
.
random
.
RandomState
(
seed
=
42
)
def
act
(
self
,
state
):
"""
:param state: input is the observation of the agent
:return: returns an action
"""
return
np
.
random
.
choice
([
1
,
2
,
3
])
return
self
.
np
_
random
.
choice
([
1
,
2
,
3
])
def
step
(
self
,
memories
):
"""
...
...
@@ -63,6 +64,7 @@ def test_multi_speed_init():
# Set all the different speeds
# Reset environment and get initial observations for all agents
env
.
reset
(
False
,
False
)
env
.
_max_episode_steps
=
1000
for
a_idx
in
range
(
len
(
env
.
agents
)):
env
.
agents
[
a_idx
].
position
=
env
.
agents
[
a_idx
].
initial_position
...
...
@@ -204,7 +206,8 @@ def test_multispeed_actions_no_malfunction_blocking():
rail
,
rail_map
,
optionals
=
make_simple_rail
()
env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_grid_transition_map
(
rail
,
optionals
),
line_generator
=
sparse_line_generator
(),
number_of_agents
=
2
,
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
ShortestPathPredictorForRailEnv
()))
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
ShortestPathPredictorForRailEnv
()),
random_seed
=
1
)
env
.
reset
()
set_penalties_for_replay
(
env
)
...
...
tests/test_random_seeding.py
View file @
608a75b5
...
...
@@ -166,52 +166,53 @@ def test_reproducability_env():
env
=
RailEnv
(
width
=
25
,
height
=
30
,
rail_generator
=
sparse_rail_generator
(
max_num_cities
=
5
,
max_rails_between_cities
=
3
,
seed
=
215545
,
# Random seed
seed
=
10
,
# Random seed
grid_mode
=
True
),
line_generator
=
sparse_line_generator
(
speed_ration_map
),
number_of_agents
=
1
)
env
.
reset
(
True
,
True
,
random_seed
=
1
0
)
env
.
reset
(
True
,
True
,
random_seed
=
1
)
excpeted_grid
=
[[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
],
[
0
,
16386
,
1025
,
5633
,
17411
,
1025
,
1025
,
1025
,
5633
,
17411
,
1025
,
1025
,
1025
,
1025
,
1025
,
1025
,
5633
,
17411
,
1025
,
1025
,
1025
,
5633
,
17411
,
1025
,
4608
],
[
0
,
49186
,
1025
,
1097
,
3089
,
1025
,
1025
,
1025
,
1097
,
3089
,
1025
,
1025
,
1025
,
1025
,
1025
,
1025
,
1097
,
3089
,
1025
,
1025
,
1025
,
1097
,
3089
,
1025
,
37408
],
[
0
,
0
,
0
,
0
,
0
,
16386
,
1025
,
4608
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
],
[
16386
,
17411
,
1025
,
5633
,
17411
,
3089
,
1025
,
1097
,
5633
,
17411
,
1025
,
5633
,
1025
,
1025
,
1025
,
1025
,
5633
,
17411
,
1025
,
1025
,
1025
,
5633
,
17411
,
1025
,
4608
],
[
32800
,
32800
,
0
,
72
,
3089
,
5633
,
1025
,
17411
,
1097
,
2064
,
0
,
72
,
1025
,
1025
,
1025
,
1025
,
1097
,
3089
,
1025
,
1025
,
1025
,
1097
,
3089
,
1025
,
37408
],
[
32800
,
32800
,
0
,
0
,
0
,
72
,
1025
,
2064
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
],
[
32800
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
],
[
32800
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
],
[
32800
,
32872
,
4608
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
16386
,
34864
],
[
32800
,
32800
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
,
32800
],
[
32800
,
32800
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
,
32800
],
[
32800
,
32800
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
,
32800
],
[
32800
,
32800
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
,
32800
],
[
32800
,
32800
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
,
32800
],
[
32800
,
32800
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
,
32800
],
[
72
,
37408
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
,
32800
],
[
0
,
49186
,
2064
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
72
,
37408
],
[
0
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
],
[
0
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
],
[
0
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
],
[
0
,
32872
,
4608
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
16386
,
1025
,
1025
,
1025
,
17411
,
34864
],
[
16386
,
34864
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
,
16386
,
1025
,
1025
,
33825
,
2064
],
[
32800
,
32800
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
,
32800
,
0
,
0
,
32800
,
0
],
[
32800
,
32800
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
,
32800
,
0
,
0
,
32800
,
0
],
[
32800
,
32800
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
,
32800
,
0
,
0
,
32800
,
0
],
[
32800
,
32800
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
,
32800
,
0
,
0
,
32800
,
0
],
[
32800
,
32800
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
,
32800
,
0
,
0
,
32800
,
0
],
[
32800
,
32800
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
,
32800
,
0
,
0
,
32800
,
0
],
[
32800
,
49186
,
2064
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
16386
,
1025
,
1025
,
1025
,
1025
,
38505
,
3089
,
1025
,
1025
,
2064
,
0
],
[
32800
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
,
0
,
0
,
0
,
0
,
32800
,
0
,
0
,
0
,
0
,
0
],
[
32800
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
,
0
,
0
,
0
,
0
,
32872
,
4608
,
0
,
0
,
0
,
0
],
[
32800
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
,
0
,
0
,
0
,
0
,
49186
,
34864
,
0
,
0
,
0
,
0
],
[
32800
,
32800
,
0
,
0
,
0
,
16386
,
1025
,
4608
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
,
0
,
0
,
0
,
0
,
32800
,
32800
,
0
,
0
,
0
,
0
],
[
72
,
1097
,
1025
,
5633
,
17411
,
3089
,
1025
,
1097
,
5633
,
17411
,
1025
,
5633
,
1025
,
1025
,
2064
,
0
,
0
,
0
,
0
,
32800
,
32800
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
72
,
3089
,
5633
,
1025
,
17411
,
1097
,
2064
,
0
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
,
32800
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
72
,
1025
,
2064
,
0
,
0
,
0
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32872
,
37408
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
49186
,
2064
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
,
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
72
,
1025
,
1025
,
1025
,
1025
,
1025
,
1025
,
1025
,
2064
,
0
,
0
,
0
,
0
,
0
]]
[
0
,
32800
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
32800
],
[
0
,
32872
,
1025
,
5633
,
17411
,
1025
,
1025
,
1025
,
5633
,
17411
,
1025
,
1025
,
1025
,
1025
,
1025
,
1025
,
5633
,
17411
,
1025
,
1025
,
1025
,
5633
,
17411
,
1025
,
34864
],
[
0
,
72
,
1025
,
1097
,
3089
,
1025
,
1025
,
1025
,
1097
,
3089
,
1025
,
1025
,
1025
,
1025
,
1025
,
1025
,
1097
,
3089
,
1025
,
1025
,
1025
,
1097
,
3089
,
1025
,
2064
],
[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
]]
assert
env
.
rail
.
grid
.
tolist
()
==
excpeted_grid
# Test that we don't have interference from calling mulitple function outisde
env2
=
RailEnv
(
width
=
25
,
height
=
30
,
rail_generator
=
sparse_rail_generator
(
max_num_cities
=
5
,
max_rails_between_cities
=
3
,
seed
=
215545
,
# Random seed
seed
=
10
,
# Random seed
grid_mode
=
True
),
line_generator
=
sparse_line_generator
(
speed_ration_map
),
number_of_agents
=
1
)
np
.
random
.
seed
(
1
0
)
np
.
random
.
seed
(
1
)
for
i
in
range
(
10
):
np
.
random
.
randn
()
env2
.
reset
(
True
,
True
,
random_seed
=
1
0
)
env2
.
reset
(
True
,
True
,
random_seed
=
1
)
assert
env2
.
rail
.
grid
.
tolist
()
==
excpeted_grid
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment