Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
Flatland
Commits
f4dc1668
Commit
f4dc1668
authored
Sep 10, 2021
by
Dipam Chakraborty
Browse files
change railenvstatus and speed data in tests
parent
4169a0f1
Pipeline
#8455
failed with stages
in 6 minutes and 27 seconds
Changes
11
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
tests/test_eval_timeout.py
View file @
f4dc1668
...
...
@@ -8,8 +8,6 @@ import time
from
flatland.core.env
import
Environment
from
flatland.core.env_observation_builder
import
ObservationBuilder
from
flatland.core.env_prediction_builder
import
PredictionBuilder
from
flatland.envs.agent_utils
import
RailAgentStatus
,
EnvAgent
class
CustomObservationBuilder
(
ObservationBuilder
):
...
...
tests/test_flaltland_rail_agent_status.py
View file @
f4dc1668
from
flatland.core.grid.grid4
import
Grid4TransitionsEnum
from
flatland.envs.agent_utils
import
RailAgentStatus
from
flatland.envs.observations
import
TreeObsForRailEnv
from
flatland.envs.predictions
import
ShortestPathPredictorForRailEnv
from
flatland.envs.rail_env
import
RailEnv
,
RailEnvActions
...
...
@@ -7,7 +6,7 @@ from flatland.envs.rail_generators import rail_from_grid_transition_map
from
flatland.envs.line_generators
import
sparse_line_generator
from
flatland.utils.simple_rail
import
make_simple_rail
from
test_utils
import
ReplayConfig
,
Replay
,
run_replay_config
,
set_penalties_for_replay
from
flatland.envs.step_utils.states
import
TrainState
def
test_initial_status
():
"""Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
...
...
@@ -30,7 +29,7 @@ def test_initial_status():
Replay
(
position
=
None
,
# not entered grid yet
direction
=
Grid4TransitionsEnum
.
EAST
,
stat
us
=
RailAgent
Stat
us
.
READY_TO_DEPART
,
stat
e
=
Train
Stat
e
.
READY_TO_DEPART
,
action
=
RailEnvActions
.
DO_NOTHING
,
reward
=
env
.
step_penalty
*
0.5
,
...
...
@@ -38,35 +37,35 @@ def test_initial_status():
Replay
(
position
=
None
,
# not entered grid yet before step
direction
=
Grid4TransitionsEnum
.
EAST
,
stat
us
=
RailAgent
Stat
us
.
READY_TO_DEPART
,
stat
e
=
Train
Stat
e
.
READY_TO_DEPART
,
action
=
RailEnvActions
.
MOVE_LEFT
,
reward
=
env
.
step_penalty
*
0.5
,
# auto-correction left to forward without penalty!
),
Replay
(
position
=
(
3
,
9
),
direction
=
Grid4TransitionsEnum
.
EAST
,
stat
us
=
RailAgentStatus
.
ACTIVE
,
stat
e
=
TrainState
.
MOVING
,
action
=
RailEnvActions
.
MOVE_LEFT
,
reward
=
env
.
start_penalty
+
env
.
step_penalty
*
0.5
,
# running at speed 0.5
),
Replay
(
position
=
(
3
,
9
),
direction
=
Grid4TransitionsEnum
.
EAST
,
stat
us
=
RailAgentStatus
.
ACTIVE
,
stat
e
=
TrainState
.
MOVING
,
action
=
None
,
reward
=
env
.
step_penalty
*
0.5
,
# running at speed 0.5
),
Replay
(
position
=
(
3
,
8
),
direction
=
Grid4TransitionsEnum
.
WEST
,
stat
us
=
RailAgentStatus
.
ACTIVE
,
stat
e
=
TrainState
.
MOVING
,
action
=
RailEnvActions
.
MOVE_FORWARD
,
reward
=
env
.
step_penalty
*
0.5
,
# running at speed 0.5
),
Replay
(
position
=
(
3
,
8
),
direction
=
Grid4TransitionsEnum
.
WEST
,
stat
us
=
RailAgentStatus
.
ACTIVE
,
stat
e
=
TrainState
.
MOVING
,
action
=
None
,
reward
=
env
.
step_penalty
*
0.5
,
# running at speed 0.5
...
...
@@ -76,28 +75,28 @@ def test_initial_status():
direction
=
Grid4TransitionsEnum
.
WEST
,
action
=
RailEnvActions
.
MOVE_FORWARD
,
reward
=
env
.
step_penalty
*
0.5
,
# running at speed 0.5
stat
us
=
RailAgentStatus
.
ACTIVE
stat
e
=
TrainState
.
MOVING
),
Replay
(
position
=
(
3
,
7
),
direction
=
Grid4TransitionsEnum
.
WEST
,
action
=
None
,
reward
=
env
.
step_penalty
*
0.5
,
# wrong action is corrected to forward without penalty!
stat
us
=
RailAgentStatus
.
ACTIVE
stat
e
=
TrainState
.
MOVING
),
Replay
(
position
=
(
3
,
6
),
direction
=
Grid4TransitionsEnum
.
WEST
,
action
=
RailEnvActions
.
MOVE_RIGHT
,
reward
=
env
.
step_penalty
*
0.5
,
#
stat
us
=
RailAgentStatus
.
ACTIVE
stat
e
=
TrainState
.
MOVING
),
Replay
(
position
=
(
3
,
6
),
direction
=
Grid4TransitionsEnum
.
WEST
,
action
=
None
,
reward
=
env
.
global_reward
,
#
stat
us
=
RailAgentStatus
.
ACTIVE
stat
e
=
TrainState
.
MOVING
),
# Replay(
# position=(3, 5),
...
...
@@ -122,7 +121,7 @@ def test_initial_status():
)
run_replay_config
(
env
,
[
test_config
],
activate_agents
=
False
,
skip_reward_check
=
True
)
assert
env
.
agents
[
0
].
stat
us
==
RailAgent
Stat
us
.
DONE
assert
env
.
agents
[
0
].
stat
e
==
Train
Stat
e
.
DONE
def
test_status_done_remove
():
...
...
@@ -146,7 +145,7 @@ def test_status_done_remove():
Replay
(
position
=
None
,
# not entered grid yet
direction
=
Grid4TransitionsEnum
.
EAST
,
stat
us
=
RailAgent
Stat
us
.
READY_TO_DEPART
,
stat
e
=
Train
Stat
e
.
READY_TO_DEPART
,
action
=
RailEnvActions
.
DO_NOTHING
,
reward
=
env
.
step_penalty
*
0.5
,
...
...
@@ -154,35 +153,35 @@ def test_status_done_remove():
Replay
(
position
=
None
,
# not entered grid yet before step
direction
=
Grid4TransitionsEnum
.
EAST
,
stat
us
=
RailAgent
Stat
us
.
READY_TO_DEPART
,
stat
e
=
Train
Stat
e
.
READY_TO_DEPART
,
action
=
RailEnvActions
.
MOVE_LEFT
,
reward
=
env
.
step_penalty
*
0.5
,
# auto-correction left to forward without penalty!
),
Replay
(
position
=
(
3
,
9
),
direction
=
Grid4TransitionsEnum
.
EAST
,
stat
us
=
RailAgentStatus
.
ACTIVE
,
stat
e
=
TrainState
.
MOVING
,
action
=
RailEnvActions
.
MOVE_FORWARD
,
reward
=
env
.
start_penalty
+
env
.
step_penalty
*
0.5
,
# running at speed 0.5
),
Replay
(
position
=
(
3
,
9
),
direction
=
Grid4TransitionsEnum
.
EAST
,
stat
us
=
RailAgentStatus
.
ACTIVE
,
stat
e
=
TrainState
.
MOVING
,
action
=
None
,
reward
=
env
.
step_penalty
*
0.5
,
# running at speed 0.5
),
Replay
(
position
=
(
3
,
8
),
direction
=
Grid4TransitionsEnum
.
WEST
,
stat
us
=
RailAgentStatus
.
ACTIVE
,
stat
e
=
TrainState
.
MOVING
,
action
=
RailEnvActions
.
MOVE_FORWARD
,
reward
=
env
.
step_penalty
*
0.5
,
# running at speed 0.5
),
Replay
(
position
=
(
3
,
8
),
direction
=
Grid4TransitionsEnum
.
WEST
,
stat
us
=
RailAgentStatus
.
ACTIVE
,
stat
e
=
TrainState
.
MOVING
,
action
=
None
,
reward
=
env
.
step_penalty
*
0.5
,
# running at speed 0.5
...
...
@@ -192,28 +191,28 @@ def test_status_done_remove():
direction
=
Grid4TransitionsEnum
.
WEST
,
action
=
RailEnvActions
.
MOVE_RIGHT
,
reward
=
env
.
step_penalty
*
0.5
,
# running at speed 0.5
stat
us
=
RailAgentStatus
.
ACTIVE
stat
e
=
TrainState
.
MOVING
),
Replay
(
position
=
(
3
,
7
),
direction
=
Grid4TransitionsEnum
.
WEST
,
action
=
None
,
reward
=
env
.
step_penalty
*
0.5
,
# wrong action is corrected to forward without penalty!
stat
us
=
RailAgentStatus
.
ACTIVE
stat
e
=
TrainState
.
MOVING
),
Replay
(
position
=
(
3
,
6
),
direction
=
Grid4TransitionsEnum
.
WEST
,
action
=
RailEnvActions
.
MOVE_FORWARD
,
reward
=
env
.
step_penalty
*
0.5
,
# done
stat
us
=
RailAgentStatus
.
ACTIVE
stat
e
=
TrainState
.
MOVING
),
Replay
(
position
=
(
3
,
6
),
direction
=
Grid4TransitionsEnum
.
WEST
,
action
=
None
,
reward
=
env
.
global_reward
,
# already done
stat
us
=
RailAgentStatus
.
ACTIVE
stat
e
=
TrainState
.
MOVING
),
# Replay(
# position=None,
...
...
@@ -238,4 +237,4 @@ def test_status_done_remove():
)
run_replay_config
(
env
,
[
test_config
],
activate_agents
=
False
,
skip_reward_check
=
True
)
assert
env
.
agents
[
0
].
stat
us
==
RailAgent
Stat
us
.
DONE
_REMOVED
assert
env
.
agents
[
0
].
stat
e
==
Train
Stat
e
.
DONE
tests/test_flatland_envs_observations.py
View file @
f4dc1668
...
...
@@ -5,7 +5,6 @@ import numpy as np
from
flatland.core.grid.grid4
import
Grid4TransitionsEnum
from
flatland.core.grid.grid4_utils
import
get_new_position
from
flatland.envs.agent_utils
import
EnvAgent
,
RailAgentStatus
from
flatland.envs.observations
import
GlobalObsForRailEnv
,
TreeObsForRailEnv
from
flatland.envs.predictions
import
ShortestPathPredictorForRailEnv
from
flatland.envs.rail_env
import
RailEnv
,
RailEnvActions
...
...
@@ -13,6 +12,7 @@ from flatland.envs.rail_generators import rail_from_grid_transition_map
from
flatland.envs.line_generators
import
sparse_line_generator
from
flatland.utils.rendertools
import
RenderTool
from
flatland.utils.simple_rail
import
make_simple_rail
from
flatland.envs.step_utils.states
import
TrainState
"""Tests for `flatland` package."""
...
...
@@ -106,7 +106,7 @@ def test_reward_function_conflict(rendering=False):
agent
.
initial_direction
=
0
# north
agent
.
target
=
(
3
,
9
)
# east dead-end
agent
.
moving
=
True
agent
.
status
=
RailAgentStatus
.
ACTIVE
agent
.
_set_state
(
TrainState
.
MOVING
)
agent
=
env
.
agents
[
1
]
agent
.
position
=
(
3
,
8
)
# east dead-end
...
...
@@ -115,13 +115,13 @@ def test_reward_function_conflict(rendering=False):
agent
.
initial_direction
=
3
# west
agent
.
target
=
(
6
,
6
)
# south dead-end
agent
.
moving
=
True
agent
.
status
=
RailAgentStatus
.
ACTIVE
agent
.
_set_state
(
TrainState
.
MOVING
)
env
.
reset
(
False
,
False
)
env
.
agents
[
0
].
moving
=
True
env
.
agents
[
1
].
moving
=
True
env
.
agents
[
0
].
status
=
RailAgentStatus
.
ACTIVE
env
.
agents
[
1
].
status
=
RailAgentStatus
.
ACTIVE
env
.
agents
[
0
].
_set_state
(
TrainState
.
MOVING
)
env
.
agents
[
1
].
_set_state
(
TrainState
.
MOVING
)
env
.
agents
[
0
].
position
=
(
5
,
6
)
env
.
agents
[
1
].
position
=
(
3
,
8
)
print
(
"
\n
"
)
...
...
@@ -195,7 +195,7 @@ def test_reward_function_waiting(rendering=False):
agent
.
initial_direction
=
3
# west
agent
.
target
=
(
3
,
1
)
# west dead-end
agent
.
moving
=
True
agent
.
status
=
RailAgentStatus
.
ACTIVE
agent
.
_set_state
(
TrainState
.
MOVING
)
agent
=
env
.
agents
[
1
]
agent
.
initial_position
=
(
5
,
6
)
# south dead-end
...
...
@@ -204,13 +204,13 @@ def test_reward_function_waiting(rendering=False):
agent
.
initial_direction
=
0
# north
agent
.
target
=
(
3
,
8
)
# east dead-end
agent
.
moving
=
True
agent
.
status
=
RailAgentStatus
.
ACTIVE
agent
.
_set_state
(
TrainState
.
MOVING
)
env
.
reset
(
False
,
False
)
env
.
agents
[
0
].
moving
=
True
env
.
agents
[
1
].
moving
=
True
env
.
agents
[
0
].
status
=
RailAgentStatus
.
ACTIVE
env
.
agents
[
1
].
status
=
RailAgentStatus
.
ACTIVE
env
.
agents
[
0
].
_set_state
(
TrainState
.
MOVING
)
env
.
agents
[
1
].
_set_state
(
TrainState
.
MOVING
)
env
.
agents
[
0
].
position
=
(
3
,
8
)
env
.
agents
[
1
].
position
=
(
5
,
6
)
...
...
tests/test_flatland_envs_predictions.py
View file @
f4dc1668
...
...
@@ -5,7 +5,6 @@ import pprint
import
numpy
as
np
from
flatland.core.grid.grid4
import
Grid4TransitionsEnum
from
flatland.envs.agent_utils
import
RailAgentStatus
from
flatland.envs.observations
import
TreeObsForRailEnv
,
Node
from
flatland.envs.predictions
import
DummyPredictorForRailEnv
,
ShortestPathPredictorForRailEnv
from
flatland.envs.rail_env
import
RailEnv
...
...
@@ -16,6 +15,7 @@ from flatland.envs.line_generators import sparse_line_generator
from
flatland.utils.rendertools
import
RenderTool
from
flatland.utils.simple_rail
import
make_simple_rail
,
make_simple_rail2
,
make_invalid_simple_rail
from
flatland.envs.rail_env_action
import
RailEnvActions
from
flatland.envs.step_utils.states
import
TrainState
"""Test predictions for `flatland` package."""
...
...
@@ -135,7 +135,7 @@ def test_shortest_path_predictor(rendering=False):
agent
.
initial_direction
=
0
# north
agent
.
target
=
(
3
,
9
)
# east dead-end
agent
.
moving
=
True
agent
.
status
=
RailAgentStatus
.
ACTIVE
agent
.
_set_state
(
TrainState
.
MOVING
)
env
.
reset
(
False
,
False
)
env
.
distance_map
.
_compute
(
env
.
agents
,
env
.
rail
)
...
...
@@ -269,7 +269,7 @@ def test_shortest_path_predictor_conflicts(rendering=False):
env
.
agents
[
0
].
initial_direction
=
0
# north
env
.
agents
[
0
].
target
=
(
3
,
9
)
# east dead-end
env
.
agents
[
0
].
moving
=
True
env
.
agents
[
0
].
status
=
RailAgentStatus
.
ACTIVE
env
.
agents
[
0
].
_set_state
(
TrainState
.
MOVING
)
env
.
agents
[
1
].
initial_position
=
(
3
,
8
)
# east dead-end
env
.
agents
[
1
].
position
=
(
3
,
8
)
# east dead-end
...
...
@@ -277,7 +277,7 @@ def test_shortest_path_predictor_conflicts(rendering=False):
env
.
agents
[
1
].
initial_direction
=
3
# west
env
.
agents
[
1
].
target
=
(
6
,
6
)
# south dead-end
env
.
agents
[
1
].
moving
=
True
env
.
agents
[
1
].
status
=
RailAgentStatus
.
ACTIVE
env
.
agents
[
1
].
_set_state
(
TrainState
.
MOVING
)
observations
,
info
=
env
.
reset
(
False
,
False
)
...
...
@@ -285,8 +285,8 @@ def test_shortest_path_predictor_conflicts(rendering=False):
env
.
agent_positions
[
env
.
agents
[
0
].
position
]
=
0
env
.
agents
[
1
].
position
=
(
3
,
8
)
# east dead-end
env
.
agent_positions
[
env
.
agents
[
1
].
position
]
=
1
env
.
agents
[
0
].
status
=
RailAgentStatus
.
ACTIVE
env
.
agents
[
1
].
status
=
RailAgentStatus
.
ACTIVE
env
.
agents
[
0
].
_set_state
(
TrainState
.
MOVING
)
env
.
agents
[
1
].
_set_state
(
TrainState
.
MOVING
)
observations
=
env
.
_get_observations
()
...
...
tests/test_flatland_envs_sparse_rail_generator.py
View file @
f4dc1668
...
...
@@ -1315,8 +1315,8 @@ def test_rail_env_action_required_info():
if
step
==
0
or
info_only_if_action_required
[
'action_required'
][
a
]:
action_dict_only_if_action_required
.
update
({
a
:
action
})
else
:
print
(
"[{}] not action_required {}, speed_
data
={}"
.
format
(
step
,
a
,
env_always_action
.
agents
[
a
].
speed_
data
))
print
(
"[{}] not action_required {}, speed_
counter
={}"
.
format
(
step
,
a
,
env_always_action
.
agents
[
a
].
speed_
counter
))
obs_always_action
,
rewards_always_action
,
done_always_action
,
info_always_action
=
env_always_action
.
step
(
action_dict_always_action
)
...
...
@@ -1375,7 +1375,7 @@ def test_rail_env_malfunction_speed_info():
for
a
in
range
(
env
.
get_num_agents
()):
assert
info
[
'malfunction'
][
a
]
>=
0
assert
info
[
'speed'
][
a
]
>=
0
and
info
[
'speed'
][
a
]
<=
1
assert
info
[
'speed'
][
a
]
==
env
.
agents
[
a
].
speed_
data
[
'
speed
'
]
assert
info
[
'speed'
][
a
]
==
env
.
agents
[
a
].
s
speed_
counter
.
speed
env_renderer
.
render_env
(
show
=
True
,
show_observations
=
False
,
show_predictions
=
False
)
...
...
tests/test_flatland_malfunction.py
View file @
f4dc1668
...
...
@@ -6,14 +6,14 @@ import numpy as np
from
flatland.core.env_observation_builder
import
ObservationBuilder
from
flatland.core.grid.grid4
import
Grid4TransitionsEnum
from
flatland.core.grid.grid4_utils
import
get_new_position
from
flatland.envs.agent_utils
import
RailAgentStatus
from
flatland.envs.malfunction_generators
import
malfunction_from_params
,
MalfunctionParameters
from
flatland.envs.rail_env
import
RailEnv
,
RailEnvActions
from
flatland.envs.rail_generators
import
rail_from_grid_transition_map
from
flatland.envs.line_generators
import
sparse_line_generator
from
flatland.utils.simple_rail
import
make_simple_rail2
from
test_utils
import
Replay
,
ReplayConfig
,
run_replay_config
,
set_penalties_for_replay
from
flatland.envs.step_utils.states
import
TrainState
from
flatland.envs.step_utils.speed_counter
import
SpeedCounter
class
SingleAgentNavigationObs
(
ObservationBuilder
):
"""
...
...
@@ -32,11 +32,11 @@ class SingleAgentNavigationObs(ObservationBuilder):
def
get
(
self
,
handle
:
int
=
0
)
->
List
[
int
]:
agent
=
self
.
env
.
agents
[
handle
]
if
agent
.
stat
us
==
RailAgentStatus
.
READY_TO_DEPART
:
if
agent
.
stat
e
.
is_off_map_state
()
:
agent_virtual_position
=
agent
.
initial_position
elif
agent
.
stat
us
==
RailAgentStatus
.
ACTIVE
:
elif
agent
.
stat
e
.
is_on_map_state
()
:
agent_virtual_position
=
agent
.
position
elif
agent
.
stat
us
==
RailAgent
Stat
us
.
DONE
:
elif
agent
.
stat
e
==
Train
Stat
e
.
DONE
:
agent_virtual_position
=
agent
.
target
else
:
return
None
...
...
@@ -85,7 +85,7 @@ def test_malfunction_process():
obs
,
info
=
env
.
reset
(
False
,
False
,
random_seed
=
10
)
for
a_idx
in
range
(
len
(
env
.
agents
)):
env
.
agents
[
a_idx
].
position
=
env
.
agents
[
a_idx
].
initial_position
env
.
agents
[
a_idx
].
stat
us
=
RailAgentStatus
.
ACTIVE
env
.
agents
[
a_idx
].
stat
e
=
TrainState
.
MOVING
agent_halts
=
0
total_down_time
=
0
...
...
@@ -297,7 +297,7 @@ def test_initial_malfunction():
reward
=
env
.
step_penalty
# running at speed 1.0
)
],
speed
=
env
.
agents
[
0
].
speed_
data
[
'
speed
'
]
,
speed
=
env
.
agents
[
0
].
speed_
counter
.
speed
,
target
=
env
.
agents
[
0
].
target
,
initial_position
=
(
3
,
2
),
initial_direction
=
Grid4TransitionsEnum
.
EAST
,
...
...
@@ -315,7 +315,7 @@ def test_initial_malfunction_stop_moving():
env
.
_max_episode_steps
=
1000
print
(
env
.
agents
[
0
].
initial_position
,
env
.
agents
[
0
].
direction
,
env
.
agents
[
0
].
position
,
env
.
agents
[
0
].
stat
us
)
print
(
env
.
agents
[
0
].
initial_position
,
env
.
agents
[
0
].
direction
,
env
.
agents
[
0
].
position
,
env
.
agents
[
0
].
stat
e
)
set_penalties_for_replay
(
env
)
replay_config
=
ReplayConfig
(
...
...
@@ -327,7 +327,7 @@ def test_initial_malfunction_stop_moving():
set_malfunction
=
3
,
malfunction
=
3
,
reward
=
env
.
step_penalty
,
# full step penalty when stopped
stat
us
=
RailAgent
Stat
us
.
READY_TO_DEPART
stat
e
=
Train
Stat
e
.
READY_TO_DEPART
),
Replay
(
position
=
(
3
,
2
),
...
...
@@ -335,7 +335,7 @@ def test_initial_malfunction_stop_moving():
action
=
RailEnvActions
.
DO_NOTHING
,
malfunction
=
2
,
reward
=
env
.
step_penalty
,
# full step penalty when stopped
stat
us
=
RailAgentStatus
.
ACTIVE
stat
e
=
TrainState
.
READY_TO_DEPART
),
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action STOP_MOVING, agent should restart without moving
...
...
@@ -346,7 +346,7 @@ def test_initial_malfunction_stop_moving():
action
=
RailEnvActions
.
STOP_MOVING
,
malfunction
=
1
,
reward
=
env
.
step_penalty
,
# full step penalty while stopped
stat
us
=
RailAgentStatus
.
ACTIVE
stat
e
=
TrainState
.
STOPPED
),
# we have stopped and do nothing --> should stand still
Replay
(
...
...
@@ -355,7 +355,7 @@ def test_initial_malfunction_stop_moving():
action
=
RailEnvActions
.
DO_NOTHING
,
malfunction
=
0
,
reward
=
env
.
step_penalty
,
# full step penalty while stopped
stat
us
=
RailAgentStatus
.
ACTIVE
stat
e
=
TrainState
.
STOPPED
),
# we start to move forward --> should go to next cell now
Replay
(
...
...
@@ -364,7 +364,7 @@ def test_initial_malfunction_stop_moving():
action
=
RailEnvActions
.
MOVE_FORWARD
,
malfunction
=
0
,
reward
=
env
.
start_penalty
+
env
.
step_penalty
*
1.0
,
# full step penalty while stopped
stat
us
=
RailAgentStatus
.
ACTIVE
stat
e
=
TrainState
.
STOPPED
),
Replay
(
position
=
(
3
,
3
),
...
...
@@ -372,10 +372,10 @@ def test_initial_malfunction_stop_moving():
action
=
RailEnvActions
.
MOVE_FORWARD
,
malfunction
=
0
,
reward
=
env
.
step_penalty
*
1.0
,
# full step penalty while stopped
stat
us
=
RailAgentStatus
.
ACTIVE
stat
e
=
TrainState
.
STOPPED
)
],
speed
=
env
.
agents
[
0
].
speed_
data
[
'
speed
'
]
,
speed
=
env
.
agents
[
0
].
speed_
counter
.
speed
,
target
=
env
.
agents
[
0
].
target
,
initial_position
=
(
3
,
2
),
initial_direction
=
Grid4TransitionsEnum
.
EAST
,
...
...
@@ -412,7 +412,7 @@ def test_initial_malfunction_do_nothing():
set_malfunction
=
3
,
malfunction
=
3
,
reward
=
env
.
step_penalty
,
# full step penalty while malfunctioning
stat
us
=
RailAgent
Stat
us
.
READY_TO_DEPART
stat
e
=
Train
Stat
e
.
READY_TO_DEPART
),
Replay
(
position
=
(
3
,
2
),
...
...
@@ -420,7 +420,7 @@ def test_initial_malfunction_do_nothing():
action
=
RailEnvActions
.
DO_NOTHING
,
malfunction
=
2
,
reward
=
env
.
step_penalty
,
# full step penalty while malfunctioning
stat
us
=
RailAgent
Stat
us
.
ACTIVE
stat
e
=
Train
Stat
e
.
ACTIVE
),
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action DO_NOTHING, agent should restart without moving
...
...
@@ -431,7 +431,7 @@ def test_initial_malfunction_do_nothing():
action
=
RailEnvActions
.
DO_NOTHING
,
malfunction
=
1
,
reward
=
env
.
step_penalty
,
# full step penalty while stopped
stat
us
=
RailAgent
Stat
us
.
ACTIVE
stat
e
=
Train
Stat
e
.
ACTIVE
),
# we haven't started moving yet --> stay here
Replay
(
...
...
@@ -440,7 +440,7 @@ def test_initial_malfunction_do_nothing():
action
=
RailEnvActions
.
DO_NOTHING
,
malfunction
=
0
,
reward
=
env
.
step_penalty
,
# full step penalty while stopped
stat
us
=
RailAgent
Stat
us
.
ACTIVE
stat
e
=
Train
Stat
e
.
ACTIVE
),
Replay
(
...
...
@@ -449,7 +449,7 @@ def test_initial_malfunction_do_nothing():
action
=
RailEnvActions
.
MOVE_FORWARD
,
malfunction
=
0
,
reward
=
env
.
start_penalty
+
env
.
step_penalty
*
1.0
,
# start penalty + step penalty for speed 1.0
stat
us
=
RailAgent
Stat
us
.
ACTIVE
stat
e
=
Train
Stat
e
.
ACTIVE
),
# we start to move forward --> should go to next cell now
Replay
(
position
=
(
3
,
3
),
...
...
@@ -457,10 +457,10 @@ def test_initial_malfunction_do_nothing():
action
=
RailEnvActions
.
MOVE_FORWARD
,
malfunction
=
0
,
reward
=
env
.
step_penalty
*
1.0
,
# step penalty for speed 1.0
stat
us
=
RailAgent
Stat
us
.
ACTIVE
stat
e
=
Train
Stat
e
.
ACTIVE
)
],
speed
=
env
.
agents
[
0
].
speed_
data
[
'
speed
'
]
,
speed
=
env
.
agents
[
0
].
speed_
counter
.
speed
,
target
=
env
.
agents
[
0
].
target
,
initial_position
=
(
3
,
2
),
initial_direction
=
Grid4TransitionsEnum
.
EAST
,
...
...
@@ -475,7 +475,7 @@ def tests_random_interference_from_outside():
env
=
RailEnv
(
width
=
25
,
height
=
30
,
rail_generator
=
rail_from_grid_transition_map
(
rail
,
optionals
),
line_generator
=
sparse_line_generator
(
seed
=
2
),
number_of_agents
=
1
,
random_seed
=
1
)
env
.
reset
()
env
.
agents
[
0
].
speed_
data
[
'speed'
]
=
0.33
env
.
agents
[
0
].
speed_
counter
=
SpeedCounter
(
speed
=
0.33
)
env
.
reset
(
False
,
False
,
random_seed
=
10
)
env_data
=
[]
...
...
@@ -501,7 +501,7 @@ def tests_random_interference_from_outside():
env
=
RailEnv
(
width
=
25
,
height
=
30
,
rail_generator
=
rail_from_grid_transition_map
(
rail
,
optionals
),
line_generator
=
sparse_line_generator
(
seed
=
2
),
number_of_agents
=
1
,
random_seed
=
1
)
env
.
reset
()
env
.
agents
[
0
].
speed_
data
[
'speed'
]
=
0.33
env
.
agents
[
0
].
speed_
counter
=
SpeedCounter
(
speed
=
0.33
)
env
.
reset
(
False
,
False
,
random_seed
=
10
)
dummy_list
=
[
1
,
2
,
6
,
7
,
8
,
9
,
4
,
5
,
4
]
...
...
@@ -536,7 +536,7 @@ def test_last_malfunction_step():
env
=
RailEnv
(
width
=
25
,
height
=
30
,
rail_generator
=
rail_from_grid_transition_map
(
rail
,
optionals
),
line_generator
=
sparse_line_generator
(
seed
=
2
),
number_of_agents
=
1
,
random_seed
=
1
)
env
.
reset
()
env
.
agents
[
0
].
speed_
data
[
'speed'
]
=
1.
/
3.
env
.
agents
[
0
].
speed_
counter
=
SpeedCounter
(
speed
=
1.
/
3.
)
env
.
agents
[
0
].
initial_position
=
(
6
,
6
)
env
.
agents
[
0
].
initial_direction
=
2
env
.
agents
[
0
].
target
=
(
0
,
3
)
...
...
@@ -546,7 +546,7 @@ def test_last_malfunction_step():
env
.
reset
(
False
,
False
)
for
a_idx
in
range
(
len
(
env
.
agents
)):
env
.
agents
[
a_idx
].
position
=
env
.
agents
[
a_idx
].
initial_position
env
.
agents
[
a_idx
].
stat
us
=
RailAgent
Stat
us
.
ACTIVE
env
.
agents
[
a_idx
].
stat
e
=
Train
Stat
e
.
ACTIVE
# Force malfunction to be off at beginning and next malfunction to happen in 2 steps
env
.
agents
[
0
].
malfunction_data
[
'next_malfunction'
]
=
2
env
.
agents
[
0
].
malfunction_data
[
'malfunction'
]
=
0
...
...
@@ -565,13 +565,13 @@ def test_last_malfunction_step():
if
env
.
agents
[
0
].
malfunction_data
[
'malfunction'
]
<
1
:
agent_can_move
=
True
# Store the position before and after the step
pre_position
=
env
.
agents
[
0
].
speed_
data
[
'position_fraction'
]
pre_position
=
env
.
agents
[
0
].
speed_
counter
.
counter
_
,
reward
,
_
,
_
=
env
.
step
(
action_dict
)
# Check if the agent is still allowed to move in this step
if
env
.
agents
[
0
].
malfunction_data
[
'malfunction'
]
>
0
:
agent_can_move
=
False
post_position
=
env
.
agents
[
0
].
speed_
data
[
'position_fraction'
]
post_position
=
env
.
agents
[
0
].
speed_
counter
.
counter
# Assert that the agent moved while it was still allowed
if
agent_can_move
:
assert
pre_position
!=
post_position
...
...
tests/test_generators.py
View file @
f4dc1668
...
...
@@ -10,7 +10,7 @@ from flatland.envs.rail_generators import rail_from_grid_transition_map, rail_fr
from
flatland.envs.line_generators
import
sparse_line_generator
,
line_from_file
from
flatland.utils.simple_rail
import
make_simple_rail
from
flatland.envs.persistence
import
RailEnvPersister
from
flatland.envs.
agent
_utils
import
RailAgent
Stat
us
from
flatland.envs.
step
_utils
.states
import
Train
Stat
e
def
test_empty_rail_generator
():
...
...
@@ -35,7 +35,7 @@ def test_rail_from_grid_transition_map():
for
a_idx
in
range
(
len
(
env
.
agents
)):
env
.
agents
[
a_idx
].
position
=
env
.
agents
[
a_idx
].
initial_position
env
.
agents
[
a_idx
].
status
=
RailAgentStatus
.
ACTIVE
env
.
agents
[
a_idx
].
_set_state
(
TrainState
.
MOVING
)
nr_rail_elements
=
np
.
count_nonzero
(
env
.
rail
.
grid
)
...
...
tests/test_global_observation.py
View file @
f4dc1668
import
numpy
as
np
from
flatland.envs.agent_utils
import
EnvAgent
,
RailAgentStatus
from
flatland.envs.agent_utils
import
EnvAgent
from
flatland.envs.observations
import
GlobalObsForRailEnv
from
flatland.envs.rail_env
import
RailEnv
,
RailEnvActions
from
flatland.envs.rail_generators
import
sparse_rail_generator
from
flatland.envs.line_generators
import
sparse_line_generator
from
flatland.envs.step_utils.states
import
TrainState