Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
Flatland
Commits
329120c7
Commit
329120c7
authored
Aug 04, 2021
by
Dipam Chakraborty
Browse files
change names from schedule to line in test and evaluators
parent
7768750d
Changes
22
Hide whitespace changes
Inline
Side-by-side
flatland/envs/rail_env_utils.py
View file @
329120c7
...
...
@@ -3,7 +3,7 @@ from flatland.envs.observations import TreeObsForRailEnv
from
flatland.envs.predictions
import
ShortestPathPredictorForRailEnv
from
flatland.envs.rail_env
import
RailEnv
from
flatland.envs.rail_generators
import
rail_from_file
from
flatland.envs.
schedul
e_generators
import
schedul
e_from_file
from
flatland.envs.
lin
e_generators
import
lin
e_from_file
def
load_flatland_environment_from_file
(
file_name
:
str
,
...
...
@@ -33,7 +33,7 @@ def load_flatland_environment_from_file(file_name: str,
max_depth
=
2
,
predictor
=
ShortestPathPredictorForRailEnv
(
max_depth
=
10
))
environment
=
RailEnv
(
width
=
1
,
height
=
1
,
rail_generator
=
rail_from_file
(
file_name
,
load_from_package
),
schedule_generator
=
schedul
e_from_file
(
file_name
,
load_from_package
),
schedule_generator
=
lin
e_from_file
(
file_name
,
load_from_package
),
number_of_agents
=
1
,
obs_builder_object
=
obs_builder_object
,
record_steps
=
record_steps
,
...
...
flatland/evaluators/client.py
View file @
329120c7
...
...
@@ -15,7 +15,7 @@ import flatland
from
flatland.envs.malfunction_generators
import
malfunction_from_file
from
flatland.envs.rail_env
import
RailEnv
from
flatland.envs.rail_generators
import
rail_from_file
from
flatland.envs.
schedul
e_generators
import
schedul
e_from_file
from
flatland.envs.
lin
e_generators
import
lin
e_from_file
from
flatland.evaluators
import
messages
from
flatland.core.env_observation_builder
import
DummyObservationBuilder
...
...
flatland/evaluators/service.py
View file @
329120c7
...
...
@@ -26,7 +26,7 @@ from flatland.envs.agent_utils import RailAgentStatus
from
flatland.envs.malfunction_generators
import
malfunction_from_file
from
flatland.envs.rail_env
import
RailEnv
from
flatland.envs.rail_generators
import
rail_from_file
from
flatland.envs.
schedul
e_generators
import
schedul
e_from_file
from
flatland.envs.
lin
e_generators
import
lin
e_from_file
from
flatland.evaluators
import
aicrowd_helpers
from
flatland.evaluators
import
messages
from
flatland.utils.rendertools
import
RenderTool
...
...
tests/test_action_plan.py
View file @
329120c7
...
...
@@ -6,7 +6,7 @@ from flatland.envs.observations import GlobalObsForRailEnv
from
flatland.envs.rail_env
import
RailEnv
,
RailEnvActions
from
flatland.envs.rail_generators
import
rail_from_grid_transition_map
from
flatland.envs.rail_trainrun_data_structures
import
Waypoint
from
flatland.envs.
schedul
e_generators
import
random_
schedul
e_generator
from
flatland.envs.
lin
e_generators
import
random_
lin
e_generator
from
flatland.utils.rendertools
import
RenderTool
,
AgentRenderVariant
from
flatland.utils.simple_rail
import
make_simple_rail
...
...
@@ -17,7 +17,7 @@ def test_action_plan(rendering: bool = False):
env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(
seed
=
77
),
lin
e_generator
=
random_
lin
e_generator
(
seed
=
77
),
number_of_agents
=
2
,
obs_builder_object
=
GlobalObsForRailEnv
(),
remove_agents_at_target
=
True
...
...
@@ -34,25 +34,25 @@ def test_action_plan(rendering: bool = False):
for
handle
,
agent
in
enumerate
(
env
.
agents
):
print
(
"[{}] {} -> {}"
.
format
(
handle
,
agent
.
initial_position
,
agent
.
target
))
chosen_path_dict
=
{
0
:
[
TrainrunWaypoint
(
schedul
ed_at
=
0
,
waypoint
=
Waypoint
(
position
=
(
3
,
0
),
direction
=
3
)),
TrainrunWaypoint
(
schedul
ed_at
=
2
,
waypoint
=
Waypoint
(
position
=
(
3
,
1
),
direction
=
1
)),
TrainrunWaypoint
(
schedul
ed_at
=
3
,
waypoint
=
Waypoint
(
position
=
(
3
,
2
),
direction
=
1
)),
TrainrunWaypoint
(
schedul
ed_at
=
14
,
waypoint
=
Waypoint
(
position
=
(
3
,
3
),
direction
=
1
)),
TrainrunWaypoint
(
schedul
ed_at
=
15
,
waypoint
=
Waypoint
(
position
=
(
3
,
4
),
direction
=
1
)),
TrainrunWaypoint
(
schedul
ed_at
=
16
,
waypoint
=
Waypoint
(
position
=
(
3
,
5
),
direction
=
1
)),
TrainrunWaypoint
(
schedul
ed_at
=
17
,
waypoint
=
Waypoint
(
position
=
(
3
,
6
),
direction
=
1
)),
TrainrunWaypoint
(
schedul
ed_at
=
18
,
waypoint
=
Waypoint
(
position
=
(
3
,
7
),
direction
=
1
)),
TrainrunWaypoint
(
schedul
ed_at
=
19
,
waypoint
=
Waypoint
(
position
=
(
3
,
8
),
direction
=
1
)),
TrainrunWaypoint
(
schedul
ed_at
=
20
,
waypoint
=
Waypoint
(
position
=
(
3
,
8
),
direction
=
5
))],
1
:
[
TrainrunWaypoint
(
schedul
ed_at
=
0
,
waypoint
=
Waypoint
(
position
=
(
3
,
8
),
direction
=
3
)),
TrainrunWaypoint
(
schedul
ed_at
=
3
,
waypoint
=
Waypoint
(
position
=
(
3
,
7
),
direction
=
3
)),
TrainrunWaypoint
(
schedul
ed_at
=
5
,
waypoint
=
Waypoint
(
position
=
(
3
,
6
),
direction
=
3
)),
TrainrunWaypoint
(
schedul
ed_at
=
7
,
waypoint
=
Waypoint
(
position
=
(
3
,
5
),
direction
=
3
)),
TrainrunWaypoint
(
schedul
ed_at
=
9
,
waypoint
=
Waypoint
(
position
=
(
3
,
4
),
direction
=
3
)),
TrainrunWaypoint
(
schedul
ed_at
=
11
,
waypoint
=
Waypoint
(
position
=
(
3
,
3
),
direction
=
3
)),
TrainrunWaypoint
(
schedul
ed_at
=
13
,
waypoint
=
Waypoint
(
position
=
(
2
,
3
),
direction
=
0
)),
TrainrunWaypoint
(
schedul
ed_at
=
15
,
waypoint
=
Waypoint
(
position
=
(
1
,
3
),
direction
=
0
)),
TrainrunWaypoint
(
schedul
ed_at
=
17
,
waypoint
=
Waypoint
(
position
=
(
0
,
3
),
direction
=
0
))]}
chosen_path_dict
=
{
0
:
[
TrainrunWaypoint
(
lin
ed_at
=
0
,
waypoint
=
Waypoint
(
position
=
(
3
,
0
),
direction
=
3
)),
TrainrunWaypoint
(
lin
ed_at
=
2
,
waypoint
=
Waypoint
(
position
=
(
3
,
1
),
direction
=
1
)),
TrainrunWaypoint
(
lin
ed_at
=
3
,
waypoint
=
Waypoint
(
position
=
(
3
,
2
),
direction
=
1
)),
TrainrunWaypoint
(
lin
ed_at
=
14
,
waypoint
=
Waypoint
(
position
=
(
3
,
3
),
direction
=
1
)),
TrainrunWaypoint
(
lin
ed_at
=
15
,
waypoint
=
Waypoint
(
position
=
(
3
,
4
),
direction
=
1
)),
TrainrunWaypoint
(
lin
ed_at
=
16
,
waypoint
=
Waypoint
(
position
=
(
3
,
5
),
direction
=
1
)),
TrainrunWaypoint
(
lin
ed_at
=
17
,
waypoint
=
Waypoint
(
position
=
(
3
,
6
),
direction
=
1
)),
TrainrunWaypoint
(
lin
ed_at
=
18
,
waypoint
=
Waypoint
(
position
=
(
3
,
7
),
direction
=
1
)),
TrainrunWaypoint
(
lin
ed_at
=
19
,
waypoint
=
Waypoint
(
position
=
(
3
,
8
),
direction
=
1
)),
TrainrunWaypoint
(
lin
ed_at
=
20
,
waypoint
=
Waypoint
(
position
=
(
3
,
8
),
direction
=
5
))],
1
:
[
TrainrunWaypoint
(
lin
ed_at
=
0
,
waypoint
=
Waypoint
(
position
=
(
3
,
8
),
direction
=
3
)),
TrainrunWaypoint
(
lin
ed_at
=
3
,
waypoint
=
Waypoint
(
position
=
(
3
,
7
),
direction
=
3
)),
TrainrunWaypoint
(
lin
ed_at
=
5
,
waypoint
=
Waypoint
(
position
=
(
3
,
6
),
direction
=
3
)),
TrainrunWaypoint
(
lin
ed_at
=
7
,
waypoint
=
Waypoint
(
position
=
(
3
,
5
),
direction
=
3
)),
TrainrunWaypoint
(
lin
ed_at
=
9
,
waypoint
=
Waypoint
(
position
=
(
3
,
4
),
direction
=
3
)),
TrainrunWaypoint
(
lin
ed_at
=
11
,
waypoint
=
Waypoint
(
position
=
(
3
,
3
),
direction
=
3
)),
TrainrunWaypoint
(
lin
ed_at
=
13
,
waypoint
=
Waypoint
(
position
=
(
2
,
3
),
direction
=
0
)),
TrainrunWaypoint
(
lin
ed_at
=
15
,
waypoint
=
Waypoint
(
position
=
(
1
,
3
),
direction
=
0
)),
TrainrunWaypoint
(
lin
ed_at
=
17
,
waypoint
=
Waypoint
(
position
=
(
0
,
3
),
direction
=
0
))]}
expected_action_plan
=
[[
# take action to enter the grid
ActionPlanElement
(
0
,
RailEnvActions
.
MOVE_FORWARD
),
...
...
tests/test_distance_map.py
View file @
329120c7
...
...
@@ -6,7 +6,7 @@ from flatland.envs.observations import TreeObsForRailEnv
from
flatland.envs.predictions
import
ShortestPathPredictorForRailEnv
from
flatland.envs.rail_env
import
RailEnv
from
flatland.envs.rail_generators
import
rail_from_grid_transition_map
from
flatland.envs.
schedul
e_generators
import
random_
schedul
e_generator
from
flatland.envs.
lin
e_generators
import
random_
lin
e_generator
def
test_walker
():
...
...
@@ -28,7 +28,7 @@ def test_walker():
env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
1
,
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
ShortestPathPredictorForRailEnv
(
max_depth
=
10
)),
...
...
tests/test_flaltland_rail_agent_status.py
View file @
329120c7
...
...
@@ -4,7 +4,7 @@ from flatland.envs.observations import TreeObsForRailEnv
from
flatland.envs.predictions
import
ShortestPathPredictorForRailEnv
from
flatland.envs.rail_env
import
RailEnv
,
RailEnvActions
from
flatland.envs.rail_generators
import
rail_from_grid_transition_map
from
flatland.envs.
schedul
e_generators
import
random_
schedul
e_generator
from
flatland.envs.
lin
e_generators
import
random_
lin
e_generator
from
flatland.utils.simple_rail
import
make_simple_rail
from
test_utils
import
ReplayConfig
,
Replay
,
run_replay_config
,
set_penalties_for_replay
...
...
@@ -13,7 +13,7 @@ def test_initial_status():
"""Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
rail
,
rail_map
=
make_simple_rail
()
env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
number_of_agents
=
1
,
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
1
,
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
ShortestPathPredictorForRailEnv
()),
remove_agents_at_target
=
False
)
env
.
reset
()
...
...
@@ -121,7 +121,7 @@ def test_status_done_remove():
"""Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
rail
,
rail_map
=
make_simple_rail
()
env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
number_of_agents
=
1
,
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
1
,
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
ShortestPathPredictorForRailEnv
()),
remove_agents_at_target
=
True
)
env
.
reset
()
...
...
tests/test_flatland_core_transition_map.py
View file @
329120c7
...
...
@@ -6,7 +6,7 @@ from flatland.envs.observations import TreeObsForRailEnv
from
flatland.envs.predictions
import
ShortestPathPredictorForRailEnv
from
flatland.envs.rail_env
import
RailEnv
from
flatland.envs.rail_generators
import
rail_from_grid_transition_map
from
flatland.envs.
schedul
e_generators
import
random_
schedul
e_generator
from
flatland.envs.
lin
e_generators
import
random_
lin
e_generator
from
flatland.utils.rendertools
import
RenderTool
from
flatland.utils.simple_rail
import
make_simple_rail
,
make_simple_rail_unconnected
...
...
@@ -70,7 +70,7 @@ def test_path_exists(rendering=False):
env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
1
,
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
ShortestPathPredictorForRailEnv
()),
)
...
...
@@ -134,7 +134,7 @@ def test_path_not_exists(rendering=False):
env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
1
,
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
ShortestPathPredictorForRailEnv
()),
)
...
...
tests/test_flatland_envs_observations.py
View file @
329120c7
...
...
@@ -10,7 +10,7 @@ from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
from
flatland.envs.predictions
import
ShortestPathPredictorForRailEnv
from
flatland.envs.rail_env
import
RailEnv
,
RailEnvActions
from
flatland.envs.rail_generators
import
rail_from_grid_transition_map
from
flatland.envs.
schedul
e_generators
import
random_
schedul
e_generator
from
flatland.envs.
lin
e_generators
import
random_
lin
e_generator
from
flatland.utils.rendertools
import
RenderTool
from
flatland.utils.simple_rail
import
make_simple_rail
...
...
@@ -21,7 +21,7 @@ def test_global_obs():
rail
,
rail_map
=
make_simple_rail
()
env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
number_of_agents
=
1
,
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
1
,
obs_builder_object
=
GlobalObsForRailEnv
())
global_obs
,
info
=
env
.
reset
()
...
...
@@ -93,7 +93,7 @@ def _step_along_shortest_path(env, obs_builder, rail):
def
test_reward_function_conflict
(
rendering
=
False
):
rail
,
rail_map
=
make_simple_rail
()
env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
number_of_agents
=
2
,
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
2
,
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
ShortestPathPredictorForRailEnv
()))
obs_builder
:
TreeObsForRailEnv
=
env
.
obs_builder
env
.
reset
()
...
...
@@ -181,7 +181,7 @@ def test_reward_function_conflict(rendering=False):
def
test_reward_function_waiting
(
rendering
=
False
):
rail
,
rail_map
=
make_simple_rail
()
env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
number_of_agents
=
2
,
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
2
,
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
ShortestPathPredictorForRailEnv
()),
remove_agents_at_target
=
False
)
obs_builder
:
TreeObsForRailEnv
=
env
.
obs_builder
...
...
tests/test_flatland_envs_predictions.py
View file @
329120c7
...
...
@@ -12,7 +12,7 @@ from flatland.envs.rail_env import RailEnv
from
flatland.envs.rail_env_shortest_paths
import
get_shortest_paths
from
flatland.envs.rail_generators
import
rail_from_grid_transition_map
from
flatland.envs.rail_trainrun_data_structures
import
Waypoint
from
flatland.envs.
schedul
e_generators
import
random_
schedul
e_generator
from
flatland.envs.
lin
e_generators
import
random_
lin
e_generator
from
flatland.utils.rendertools
import
RenderTool
from
flatland.utils.simple_rail
import
make_simple_rail
,
make_simple_rail2
,
make_invalid_simple_rail
...
...
@@ -25,7 +25,7 @@ def test_dummy_predictor(rendering=False):
env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
1
,
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
DummyPredictorForRailEnv
(
max_depth
=
10
)),
)
...
...
@@ -116,7 +116,7 @@ def test_shortest_path_predictor(rendering=False):
env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
1
,
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
ShortestPathPredictorForRailEnv
()),
)
...
...
@@ -247,7 +247,7 @@ def test_shortest_path_predictor_conflicts(rendering=False):
env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
2
,
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
ShortestPathPredictorForRailEnv
()),
)
...
...
tests/test_flatland_envs_rail_env.py
View file @
329120c7
...
...
@@ -11,7 +11,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from
flatland.envs.rail_env
import
RailEnv
,
RailEnvActions
from
flatland.envs.rail_generators
import
complex_rail_generator
,
rail_from_file
from
flatland.envs.rail_generators
import
rail_from_grid_transition_map
from
flatland.envs.
schedul
e_generators
import
random_
schedul
e_generator
,
complex_
schedul
e_generator
,
schedul
e_from_file
from
flatland.envs.
lin
e_generators
import
random_
lin
e_generator
,
complex_
lin
e_generator
,
lin
e_from_file
from
flatland.utils.simple_rail
import
make_simple_rail
from
flatland.envs.persistence
import
RailEnvPersister
from
flatland.utils.rendertools
import
RenderTool
...
...
@@ -38,7 +38,7 @@ def test_load_env():
def
test_save_load
():
env
=
RailEnv
(
width
=
10
,
height
=
10
,
rail_generator
=
complex_rail_generator
(
nr_start_goal
=
2
,
nr_extra
=
5
,
min_dist
=
6
,
seed
=
1
),
schedul
e_generator
=
complex_
schedul
e_generator
(),
number_of_agents
=
2
)
lin
e_generator
=
complex_
lin
e_generator
(),
number_of_agents
=
2
)
env
.
reset
()
agent_1_pos
=
env
.
agents
[
0
].
position
agent_1_dir
=
env
.
agents
[
0
].
direction
...
...
@@ -68,7 +68,7 @@ def test_save_load():
def
test_save_load_mpk
():
env
=
RailEnv
(
width
=
10
,
height
=
10
,
rail_generator
=
complex_rail_generator
(
nr_start_goal
=
2
,
nr_extra
=
5
,
min_dist
=
6
,
seed
=
1
),
schedul
e_generator
=
complex_
schedul
e_generator
(),
number_of_agents
=
2
)
lin
e_generator
=
complex_
lin
e_generator
(),
number_of_agents
=
2
)
env
.
reset
()
os
.
makedirs
(
"tmp"
,
exist_ok
=
True
)
...
...
@@ -120,7 +120,7 @@ def test_rail_environment_single_agent(show=False):
rail
=
GridTransitionMap
(
width
=
3
,
height
=
3
,
transitions
=
transitions
)
rail
.
grid
=
rail_map
rail_env
=
RailEnv
(
width
=
3
,
height
=
3
,
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
number_of_agents
=
1
,
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
1
,
obs_builder_object
=
GlobalObsForRailEnv
())
else
:
rail_env
,
env_dict
=
RailEnvPersister
.
load_new
(
"test_env_loop.pkl"
,
"env_data.tests"
)
...
...
@@ -203,7 +203,7 @@ def test_rail_environment_single_agent(show=False):
rail_env
.
agents
[
0
].
direction
=
0
# JW - to avoid problem with random_
schedul
e_generator.
# JW - to avoid problem with random_
lin
e_generator.
#rail_env.agents[0].position = (1,2)
iStep
=
0
...
...
@@ -246,7 +246,7 @@ def test_dead_end():
rail
.
grid
=
rail_map
rail_env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
number_of_agents
=
1
,
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
1
,
obs_builder_object
=
GlobalObsForRailEnv
())
# We try the configuration in the 4 directions:
...
...
@@ -269,7 +269,7 @@ def test_dead_end():
rail
.
grid
=
rail_map
rail_env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
number_of_agents
=
1
,
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
1
,
obs_builder_object
=
GlobalObsForRailEnv
())
rail_env
.
reset
()
...
...
@@ -284,7 +284,7 @@ def test_dead_end():
def
test_get_entry_directions
():
rail
,
rail_map
=
make_simple_rail
()
env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
number_of_agents
=
1
,
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
1
,
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
ShortestPathPredictorForRailEnv
()))
env
.
reset
()
...
...
@@ -319,7 +319,7 @@ def test_rail_env_reset():
rail
,
rail_map
=
make_simple_rail
()
env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
number_of_agents
=
3
,
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
3
,
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
ShortestPathPredictorForRailEnv
()))
env
.
reset
()
...
...
@@ -331,7 +331,7 @@ def test_rail_env_reset():
agents_initial
=
env
.
agents
#env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
#
schedul
e_generator=
schedul
e_from_file(file_name), number_of_agents=1,
#
lin
e_generator=
lin
e_from_file(file_name), number_of_agents=1,
# obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
#env2.reset(False, False, False)
env2
,
env2_dict
=
RailEnvPersister
.
load_new
(
file_name
)
...
...
@@ -343,7 +343,7 @@ def test_rail_env_reset():
assert
agents_initial
==
agents_loaded
env3
=
RailEnv
(
width
=
1
,
height
=
1
,
rail_generator
=
rail_from_file
(
file_name
),
schedul
e_generator
=
schedul
e_from_file
(
file_name
),
number_of_agents
=
1
,
lin
e_generator
=
lin
e_from_file
(
file_name
),
number_of_agents
=
1
,
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
ShortestPathPredictorForRailEnv
()))
env3
.
reset
(
False
,
True
,
False
)
rails_loaded
=
env3
.
rail
.
grid
...
...
@@ -353,7 +353,7 @@ def test_rail_env_reset():
assert
agents_initial
==
agents_loaded
env4
=
RailEnv
(
width
=
1
,
height
=
1
,
rail_generator
=
rail_from_file
(
file_name
),
schedul
e_generator
=
schedul
e_from_file
(
file_name
),
number_of_agents
=
1
,
lin
e_generator
=
lin
e_from_file
(
file_name
),
number_of_agents
=
1
,
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
ShortestPathPredictorForRailEnv
()))
env4
.
reset
(
True
,
False
,
False
)
rails_loaded
=
env4
.
rail
.
grid
...
...
tests/test_flatland_envs_rail_env_shortest_paths.py
View file @
329120c7
...
...
@@ -9,7 +9,7 @@ from flatland.envs.rail_env_shortest_paths import get_shortest_paths, get_k_shor
from
flatland.envs.rail_env_utils
import
load_flatland_environment_from_file
from
flatland.envs.rail_generators
import
rail_from_grid_transition_map
from
flatland.envs.rail_trainrun_data_structures
import
Waypoint
from
flatland.envs.
schedul
e_generators
import
random_
schedul
e_generator
from
flatland.envs.
lin
e_generators
import
random_
lin
e_generator
from
flatland.utils.rendertools
import
RenderTool
from
flatland.utils.simple_rail
import
make_disconnected_simple_rail
,
make_simple_rail_with_alternatives
from
flatland.envs.persistence
import
RailEnvPersister
...
...
@@ -19,7 +19,7 @@ def test_get_shortest_paths_unreachable():
rail
,
rail_map
=
make_disconnected_simple_rail
()
env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
number_of_agents
=
1
,
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
1
,
obs_builder_object
=
GlobalObsForRailEnv
())
env
.
reset
()
...
...
@@ -238,7 +238,7 @@ def test_get_k_shortest_paths(rendering=False):
env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
1
,
obs_builder_object
=
GlobalObsForRailEnv
(),
)
...
...
tests/test_flatland_envs_sparse_rail_generator.py
View file @
329120c7
...
...
@@ -7,7 +7,7 @@ from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from
flatland.envs.observations
import
GlobalObsForRailEnv
from
flatland.envs.rail_env
import
RailEnv
from
flatland.envs.rail_generators
import
sparse_rail_generator
from
flatland.envs.
schedul
e_generators
import
sparse_
schedul
e_generator
from
flatland.envs.
lin
e_generators
import
sparse_
lin
e_generator
from
flatland.utils.rendertools
import
RenderTool
...
...
@@ -17,7 +17,7 @@ def test_sparse_rail_generator():
seed
=
5
,
grid_mode
=
False
),
schedul
e_generator
=
sparse_
schedul
e_generator
(),
number_of_agents
=
10
,
lin
e_generator
=
sparse_
lin
e_generator
(),
number_of_agents
=
10
,
obs_builder_object
=
GlobalObsForRailEnv
())
env
.
reset
(
False
,
False
,
True
)
for
r
in
range
(
env
.
height
):
...
...
@@ -602,7 +602,7 @@ def test_sparse_rail_generator_deterministic():
seed
=
215545
,
# Random seed
grid_mode
=
True
),
schedul
e_generator
=
sparse_
schedul
e_generator
(
speed_ration_map
),
number_of_agents
=
1
)
lin
e_generator
=
sparse_
lin
e_generator
(
speed_ration_map
),
number_of_agents
=
1
)
env
.
reset
()
# for r in range(env.height):
# for c in range(env.width):
...
...
@@ -1371,7 +1371,7 @@ def test_rail_env_action_required_info():
max_rails_between_cities
=
3
,
seed
=
5
,
# Random seed
grid_mode
=
False
# Ordered distribution of nodes
),
schedul
e_generator
=
sparse_
schedul
e_generator
(
speed_ration_map
),
number_of_agents
=
10
,
),
lin
e_generator
=
sparse_
lin
e_generator
(
speed_ration_map
),
number_of_agents
=
10
,
obs_builder_object
=
GlobalObsForRailEnv
(),
remove_agents_at_target
=
False
)
env_only_if_action_required
=
RailEnv
(
width
=
50
,
height
=
50
,
rail_generator
=
sparse_rail_generator
(
...
...
@@ -1380,7 +1380,7 @@ def test_rail_env_action_required_info():
seed
=
5
,
# Random seed
grid_mode
=
False
# Ordered distribution of nodes
),
schedul
e_generator
=
sparse_
schedul
e_generator
(
speed_ration_map
),
number_of_agents
=
10
,
),
lin
e_generator
=
sparse_
lin
e_generator
(
speed_ration_map
),
number_of_agents
=
10
,
obs_builder_object
=
GlobalObsForRailEnv
(),
remove_agents_at_target
=
False
)
env_renderer
=
RenderTool
(
env_always_action
,
gl
=
"PILSVG"
,
)
...
...
@@ -1442,7 +1442,7 @@ def test_rail_env_malfunction_speed_info():
seed
=
5
,
grid_mode
=
False
),
schedul
e_generator
=
sparse_
schedul
e_generator
(),
number_of_agents
=
10
,
lin
e_generator
=
sparse_
lin
e_generator
(),
number_of_agents
=
10
,
obs_builder_object
=
GlobalObsForRailEnv
())
env
.
reset
(
False
,
False
,
True
)
...
...
@@ -1476,7 +1476,7 @@ def test_sparse_generator_with_too_man_cities_does_not_break_down():
max_rails_between_cities
=
3
,
seed
=
5
,
grid_mode
=
False
),
schedul
e_generator
=
sparse_
schedul
e_generator
(),
number_of_agents
=
10
,
obs_builder_object
=
GlobalObsForRailEnv
())
),
lin
e_generator
=
sparse_
lin
e_generator
(),
number_of_agents
=
10
,
obs_builder_object
=
GlobalObsForRailEnv
())
def
test_sparse_generator_with_illegal_params_aborts
():
...
...
@@ -1489,7 +1489,7 @@ def test_sparse_generator_with_illegal_params_aborts():
max_rails_between_cities
=
3
,
seed
=
5
,
grid_mode
=
False
),
schedul
e_generator
=
sparse_
schedul
e_generator
(),
number_of_agents
=
10
,
),
lin
e_generator
=
sparse_
lin
e_generator
(),
number_of_agents
=
10
,
obs_builder_object
=
GlobalObsForRailEnv
()).
reset
()
with
unittest
.
TestCase
.
assertRaises
(
test_sparse_generator_with_illegal_params_aborts
,
ValueError
):
...
...
@@ -1498,7 +1498,7 @@ def test_sparse_generator_with_illegal_params_aborts():
max_rails_between_cities
=
3
,
seed
=
5
,
grid_mode
=
False
),
schedul
e_generator
=
sparse_
schedul
e_generator
(),
number_of_agents
=
10
,
),
lin
e_generator
=
sparse_
lin
e_generator
(),
number_of_agents
=
10
,
obs_builder_object
=
GlobalObsForRailEnv
()).
reset
()
...
...
@@ -1515,7 +1515,7 @@ def test_sparse_generator_changes_to_grid_mode():
max_rails_in_city
=
2
,
seed
=
15
,
grid_mode
=
False
),
schedul
e_generator
=
sparse_
schedul
e_generator
(),
number_of_agents
=
10
,
),
lin
e_generator
=
sparse_
lin
e_generator
(),
number_of_agents
=
10
,
obs_builder_object
=
GlobalObsForRailEnv
())
for
test_run
in
range
(
10
):
with
warnings
.
catch_warnings
(
record
=
True
)
as
w
:
...
...
tests/test_flatland_malfunction.py
View file @
329120c7
...
...
@@ -10,7 +10,7 @@ from flatland.envs.agent_utils import RailAgentStatus
from
flatland.envs.malfunction_generators
import
malfunction_from_params
,
MalfunctionParameters
from
flatland.envs.rail_env
import
RailEnv
,
RailEnvActions
from
flatland.envs.rail_generators
import
rail_from_grid_transition_map
from
flatland.envs.
schedul
e_generators
import
random_
schedul
e_generator
from
flatland.envs.
lin
e_generators
import
random_
lin
e_generator
from
flatland.utils.simple_rail
import
make_simple_rail2
from
test_utils
import
Replay
,
ReplayConfig
,
run_replay_config
,
set_penalties_for_replay
...
...
@@ -77,7 +77,7 @@ def test_malfunction_process():
env
=
RailEnv
(
width
=
25
,
height
=
30
,
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
1
,
malfunction_generator_and_process_data
=
malfunction_from_params
(
stochastic_data
),
obs_builder_object
=
SingleAgentNavigationObs
()
...
...
@@ -131,7 +131,7 @@ def test_malfunction_process_statistically():
env
=
RailEnv
(
width
=
25
,
height
=
30
,
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
10
,
malfunction_generator_and_process_data
=
malfunction_from_params
(
stochastic_data
),
obs_builder_object
=
SingleAgentNavigationObs
()
...
...
@@ -178,7 +178,7 @@ def test_malfunction_before_entry():
env
=
RailEnv
(
width
=
25
,
height
=
30
,
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
10
,
malfunction_generator_and_process_data
=
malfunction_from_params
(
stochastic_data
),
obs_builder_object
=
SingleAgentNavigationObs
()
...
...
@@ -222,7 +222,7 @@ def test_malfunction_values_and_behavior():
env
=
RailEnv
(
width
=
25
,
height
=
30
,
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
1
,
malfunction_generator_and_process_data
=
malfunction_from_params
(
stochastic_data
),
obs_builder_object
=
SingleAgentNavigationObs
()
...
...
@@ -251,7 +251,7 @@ def test_initial_malfunction():
env
=
RailEnv
(
width
=
25
,
height
=
30
,
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(
seed
=
10
),
lin
e_generator
=
random_
lin
e_generator
(
seed
=
10
),
number_of_agents
=
1
,
malfunction_generator_and_process_data
=
malfunction_from_params
(
stochastic_data
),
# Malfunction data generator
...
...
@@ -316,7 +316,7 @@ def test_initial_malfunction_stop_moving():
rail
,
rail_map
=
make_simple_rail2
()
env
=
RailEnv
(
width
=
25
,
height
=
30
,
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
number_of_agents
=
1
,
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
1
,
obs_builder_object
=
SingleAgentNavigationObs
())
env
.
reset
()
...
...
@@ -400,7 +400,7 @@ def test_initial_malfunction_do_nothing():
env
=
RailEnv
(
width
=
25
,
height
=
30
,
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
1
,
malfunction_generator_and_process_data
=
malfunction_from_params
(
stochastic_data
),
# Malfunction data generator
...
...
@@ -477,7 +477,7 @@ def tests_random_interference_from_outside():
# Set fixed malfunction duration for this test
rail
,
rail_map
=
make_simple_rail2
()
env
=
RailEnv
(
width
=
25
,
height
=
30
,
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(
seed
=
2
),
number_of_agents
=
1
,
random_seed
=
1
)
lin
e_generator
=
random_
lin
e_generator
(
seed
=
2
),
number_of_agents
=
1
,
random_seed
=
1
)
env
.
reset
()
env
.
agents
[
0
].
speed_data
[
'speed'
]
=
0.33
env
.
reset
(
False
,
False
,
False
,
random_seed
=
10
)
...
...
@@ -501,7 +501,7 @@ def tests_random_interference_from_outside():
random
.
seed
(
47
)
np
.
random
.
seed
(
1234
)
env
=
RailEnv
(
width
=
25
,
height
=
30
,
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(
seed
=
2
),
number_of_agents
=
1
,
random_seed
=
1
)
lin
e_generator
=
random_
lin
e_generator
(
seed
=
2
),
number_of_agents
=
1
,
random_seed
=
1
)
env
.
reset
()
env
.
agents
[
0
].
speed_data
[
'speed'
]
=
0.33
env
.
reset
(
False
,
False
,
False
,
random_seed
=
10
)
...
...
@@ -533,7 +533,7 @@ def test_last_malfunction_step():
rail
,
rail_map
=
make_simple_rail2
()
env
=
RailEnv
(
width
=
25
,
height
=
30
,
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(
seed
=
2
),
number_of_agents
=
1
,
random_seed
=
1
)
lin
e_generator
=
random_
lin
e_generator
(
seed
=
2
),
number_of_agents
=
1
,
random_seed
=
1
)
env
.
reset
()
env
.
agents
[
0
].
speed_data
[
'speed'
]
=
1.
/
3.
env
.
agents
[
0
].
target
=
(
0
,
0
)
...
...
tests/test_flatland_multiprocessing.py
View file @
329120c7
...
...
@@ -6,7 +6,7 @@ from flatland.envs.observations import TreeObsForRailEnv
from
flatland.envs.predictions
import
ShortestPathPredictorForRailEnv
from
flatland.envs.rail_env
import
RailEnv
from
flatland.envs.rail_generators
import
rail_from_grid_transition_map
from
flatland.envs.
schedul
e_generators
import
random_
schedul
e_generator
from
flatland.envs.
lin
e_generators
import
random_
lin
e_generator
from
flatland.utils.simple_rail
import
make_simple_rail
"""Tests for `flatland` package."""
...
...
@@ -19,7 +19,7 @@ def test_multiprocessing_tree_obs():
obs_builder
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
ShortestPathPredictorForRailEnv
())
env
=
RailEnv
(
width
=
rail_map
.
shape
[
1
],
height
=
rail_map
.
shape
[
0
],
rail_generator
=
rail_from_grid_transition_map
(
rail
),
schedul
e_generator
=
random_
schedul
e_generator
(),
number_of_agents
=
number_of_agents
,
lin
e_generator
=
random_
lin
e_generator
(),
number_of_agents
=
number_of_agents
,
obs_builder_object
=
obs_builder
)
env
.
reset
(
True
,
True
)
...
...
tests/test_flatland_schedule_from_file.py
View file @
329120c7
...
...
@@ -3,11 +3,11 @@ from test_utils import create_and_save_env
from
flatland.envs.rail_env
import
RailEnv
from
flatland.envs.rail_generators
import
sparse_rail_generator
,
random_rail_generator
,
complex_rail_generator
,
\
rail_from_file
from
flatland.envs.
schedul
e_generators
import
sparse_
schedul
e_generator
,
random_
schedul
e_generator
,
\
complex_
schedul
e_generator
,
schedul
e_from_file