Skip to content
Snippets Groups Projects
Commit dd996ec6 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

added and fixed tests for sparse_rail_generator

parent a0521bb8
No related branches found
No related tags found
No related merge requests found
This diff is collapsed.
...@@ -9,8 +9,9 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum ...@@ -9,8 +9,9 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail2
class SingleAgentNavigationObs(ObservationBuilder): class SingleAgentNavigationObs(ObservationBuilder):
...@@ -69,16 +70,27 @@ def test_malfunction_process(): ...@@ -69,16 +70,27 @@ def test_malfunction_process():
'malfunction_rate': 1000, 'malfunction_rate': 1000,
'min_duration': 3, 'min_duration': 3,
'max_duration': 3} 'max_duration': 3}
np.random.seed(5) random.seed(0)
np.random.seed(0)
env = RailEnv(width=20, stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
height=20, 'malfunction_rate': 70, # Rate of malfunction occurrence
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, 'min_duration': 2, # Minimal duration of malfunction
seed=0), 'max_duration': 5 # Max duration of malfunction
schedule_generator=complex_schedule_generator(), }
number_of_agents=2,
obs_builder_object=SingleAgentNavigationObs(), rail, rail_map = make_simple_rail2()
stochastic_data=stochastic_data)
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=SingleAgentNavigationObs()
)
# reset to initialize agents_static
env.reset()
obs = env.reset(False, False, True) obs = env.reset(False, False, True)
...@@ -117,7 +129,7 @@ def test_malfunction_process(): ...@@ -117,7 +129,7 @@ def test_malfunction_process():
total_down_time += env.agents[0].malfunction_data['malfunction'] total_down_time += env.agents[0].malfunction_data['malfunction']
# Check that the appropriate number of malfunctions is achieved # Check that the appropriate number of malfunctions is achieved
assert env.agents[0].malfunction_data['nr_malfunctions'] == 21, "Actual {}".format( assert env.agents[0].malfunction_data['nr_malfunctions'] == 11, "Actual {}".format(
env.agents[0].malfunction_data['nr_malfunctions']) env.agents[0].malfunction_data['nr_malfunctions'])
# Check that 20 stops were performed # Check that 20 stops were performed
...@@ -135,16 +147,26 @@ def test_malfunction_process_statistically(): ...@@ -135,16 +147,26 @@ def test_malfunction_process_statistically():
'min_duration': 3, 'min_duration': 3,
'max_duration': 3} 'max_duration': 3}
env = RailEnv(width=20,
height=20,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
seed=0),
schedule_generator=complex_schedule_generator(),
number_of_agents=2,
obs_builder_object=SingleAgentNavigationObs(),
stochastic_data=stochastic_data)
np.random.seed(5)
random.seed(0) random.seed(0)
np.random.seed(0)
stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
'malfunction_rate': 70, # Rate of malfunction occurrence
'min_duration': 2, # Minimal duration of malfunction
'max_duration': 5 # Max duration of malfunction
}
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=SingleAgentNavigationObs()
)
# reset to initialize agents_static
env.reset(False, False, True) env.reset(False, False, True)
nb_malfunction = 0 nb_malfunction = 0
for step in range(100): for step in range(100):
...@@ -158,38 +180,39 @@ def test_malfunction_process_statistically(): ...@@ -158,38 +180,39 @@ def test_malfunction_process_statistically():
env.step(action_dict) env.step(action_dict)
# check that generation of malfunctions works as expected # check that generation of malfunctions works as expected
assert nb_malfunction == 128, "nb_malfunction={}".format(nb_malfunction) assert nb_malfunction == 3, "nb_malfunction={}".format(nb_malfunction)
def test_initial_malfunction(): def test_initial_malfunction():
random.seed(0)
np.random.seed(0)
stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
'malfunction_rate': 70, # Rate of malfunction occurrence 'malfunction_rate': 70, # Rate of malfunction occurrence
'min_duration': 2, # Minimal duration of malfunction 'min_duration': 2, # Minimal duration of malfunction
'max_duration': 5 # Max duration of malfunction 'max_duration': 5 # Max duration of malfunction
} }
speed_ration_map = {1.: 1., # Fast passenger train rail, rail_map = make_simple_rail2()
1. / 2.: 0., # Fast freight train
1. / 3.: 0., # Slow commuter train
1. / 4.: 0.} # Slow freight train
np.random.seed(5)
random.seed(0)
env = RailEnv(width=25, env = RailEnv(width=25,
height=30, height=30,
rail_generator=sparse_rail_generator(max_num_cities=5, rail_generator=rail_from_grid_transition_map(rail),
max_rails_between_cities=3, schedule_generator=random_schedule_generator(),
seed=215545,
grid_mode=True
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=1, number_of_agents=1,
stochastic_data=stochastic_data, # Malfunction data generator stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=SingleAgentNavigationObs()
) )
# reset to initialize agents_static
env.reset(False, False, True)
set_penalties_for_replay(env) set_penalties_for_replay(env)
replay_config = ReplayConfig( replay_config = ReplayConfig(
replay=[ replay=[
Replay( Replay(
position=(28, 5), position=(3, 2),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
set_malfunction=3, set_malfunction=3,
...@@ -197,7 +220,7 @@ def test_initial_malfunction(): ...@@ -197,7 +220,7 @@ def test_initial_malfunction():
reward=env.step_penalty # full step penalty when malfunctioning reward=env.step_penalty # full step penalty when malfunctioning
), ),
Replay( Replay(
position=(28, 5), position=(3, 2),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
malfunction=2, malfunction=2,
...@@ -206,7 +229,7 @@ def test_initial_malfunction(): ...@@ -206,7 +229,7 @@ def test_initial_malfunction():
# malfunction stops in the next step and we're still at the beginning of the cell # malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action MOVE_FORWARD, agent should restart and move to the next cell # --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
Replay( Replay(
position=(28, 5), position=(3, 2),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
malfunction=1, malfunction=1,
...@@ -214,15 +237,15 @@ def test_initial_malfunction(): ...@@ -214,15 +237,15 @@ def test_initial_malfunction():
# malfunctioning ends: starting and running at speed 1.0 # malfunctioning ends: starting and running at speed 1.0
), ),
Replay( Replay(
position=(28, 6), position=(3, 3),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
malfunction=0, malfunction=0,
reward=env.step_penalty * 1.0 # running at speed 1.0 reward=env.step_penalty * 1.0 # running at speed 1.0
), ),
Replay( Replay(
position=(27, 6), position=(3, 4),
direction=Grid4TransitionsEnum.NORTH, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
malfunction=0, malfunction=0,
reward=env.step_penalty * 1.0 # running at speed 1.0 reward=env.step_penalty * 1.0 # running at speed 1.0
...@@ -230,37 +253,36 @@ def test_initial_malfunction(): ...@@ -230,37 +253,36 @@ def test_initial_malfunction():
], ],
speed=env.agents[0].speed_data['speed'], speed=env.agents[0].speed_data['speed'],
target=env.agents[0].target, target=env.agents[0].target,
initial_position=(28, 5), initial_position=(3, 2),
initial_direction=Grid4TransitionsEnum.EAST, initial_direction=Grid4TransitionsEnum.EAST,
) )
sparse_generator_stable = False run_replay_config(env, [replay_config])
if sparse_generator_stable:
run_replay_config(env, [replay_config])
def test_initial_malfunction_stop_moving(): def test_initial_malfunction_stop_moving():
random.seed(0)
np.random.seed(0)
stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
'malfunction_rate': 70, # Rate of malfunction occurrence 'malfunction_rate': 70, # Rate of malfunction occurrence
'min_duration': 2, # Minimal duration of malfunction 'min_duration': 2, # Minimal duration of malfunction
'max_duration': 5 # Max duration of malfunction 'max_duration': 5 # Max duration of malfunction
} }
speed_ration_map = {1.: 1., # Fast passenger train rail, rail_map = make_simple_rail2()
1. / 2.: 0., # Fast freight train
1. / 3.: 0., # Slow commuter train
1. / 4.: 0.} # Slow freight train
env = RailEnv(width=25, env = RailEnv(width=25,
height=30, height=30,
rail_generator=sparse_rail_generator(max_num_cities=5, rail_generator=rail_from_grid_transition_map(rail),
max_rails_between_cities=3, schedule_generator=random_schedule_generator(),
seed=215545,
grid_mode=True
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=1, number_of_agents=1,
stochastic_data=stochastic_data, # Malfunction data generator stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=SingleAgentNavigationObs()
) )
# reset to initialize agents_static
print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].status)
set_penalties_for_replay(env) set_penalties_for_replay(env)
replay_config = ReplayConfig( replay_config = ReplayConfig(
replay=[ replay=[
...@@ -274,7 +296,7 @@ def test_initial_malfunction_stop_moving(): ...@@ -274,7 +296,7 @@ def test_initial_malfunction_stop_moving():
status=RailAgentStatus.READY_TO_DEPART status=RailAgentStatus.READY_TO_DEPART
), ),
Replay( Replay(
position=(28, 5), position=(3, 2),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING, action=RailEnvActions.DO_NOTHING,
malfunction=2, malfunction=2,
...@@ -285,7 +307,7 @@ def test_initial_malfunction_stop_moving(): ...@@ -285,7 +307,7 @@ def test_initial_malfunction_stop_moving():
# --> if we take action STOP_MOVING, agent should restart without moving # --> if we take action STOP_MOVING, agent should restart without moving
# #
Replay( Replay(
position=(28, 5), position=(3, 2),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.STOP_MOVING, action=RailEnvActions.STOP_MOVING,
malfunction=1, malfunction=1,
...@@ -294,7 +316,7 @@ def test_initial_malfunction_stop_moving(): ...@@ -294,7 +316,7 @@ def test_initial_malfunction_stop_moving():
), ),
# we have stopped and do nothing --> should stand still # we have stopped and do nothing --> should stand still
Replay( Replay(
position=(28, 5), position=(3, 2),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING, action=RailEnvActions.DO_NOTHING,
malfunction=0, malfunction=0,
...@@ -303,7 +325,7 @@ def test_initial_malfunction_stop_moving(): ...@@ -303,7 +325,7 @@ def test_initial_malfunction_stop_moving():
), ),
# we start to move forward --> should go to next cell now # we start to move forward --> should go to next cell now
Replay( Replay(
position=(28, 5), position=(3, 2),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
malfunction=0, malfunction=0,
...@@ -311,7 +333,7 @@ def test_initial_malfunction_stop_moving(): ...@@ -311,7 +333,7 @@ def test_initial_malfunction_stop_moving():
status=RailAgentStatus.ACTIVE status=RailAgentStatus.ACTIVE
), ),
Replay( Replay(
position=(28, 6), position=(3, 3),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
malfunction=0, malfunction=0,
...@@ -321,12 +343,11 @@ def test_initial_malfunction_stop_moving(): ...@@ -321,12 +343,11 @@ def test_initial_malfunction_stop_moving():
], ],
speed=env.agents[0].speed_data['speed'], speed=env.agents[0].speed_data['speed'],
target=env.agents[0].target, target=env.agents[0].target,
initial_position=(28, 5), initial_position=(3, 2),
initial_direction=Grid4TransitionsEnum.EAST, initial_direction=Grid4TransitionsEnum.EAST,
) )
sparse_generator_stable = False
if sparse_generator_stable: run_replay_config(env, [replay_config], activate_agents=False)
run_replay_config(env, [replay_config], activate_agents=False)
def test_initial_malfunction_do_nothing(): def test_initial_malfunction_do_nothing():
...@@ -339,22 +360,18 @@ def test_initial_malfunction_do_nothing(): ...@@ -339,22 +360,18 @@ def test_initial_malfunction_do_nothing():
'max_duration': 5 # Max duration of malfunction 'max_duration': 5 # Max duration of malfunction
} }
speed_ration_map = {1.: 1., # Fast passenger train rail, rail_map = make_simple_rail2()
1. / 2.: 0., # Fast freight train
1. / 3.: 0., # Slow commuter train
1. / 4.: 0.} # Slow freight train
env = RailEnv(width=25, env = RailEnv(width=25,
height=30, height=30,
rail_generator=sparse_rail_generator(max_num_cities=5, rail_generator=rail_from_grid_transition_map(rail),
max_rails_between_cities=3, schedule_generator=random_schedule_generator(),
seed=215545,
grid_mode=True
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=1, number_of_agents=1,
stochastic_data=stochastic_data, # Malfunction data generator stochastic_data=stochastic_data, # Malfunction data generator
) )
# reset to initialize agents_static
env.reset()
set_penalties_for_replay(env) set_penalties_for_replay(env)
replay_config = ReplayConfig( replay_config = ReplayConfig(
replay=[ replay=[
...@@ -368,7 +385,7 @@ def test_initial_malfunction_do_nothing(): ...@@ -368,7 +385,7 @@ def test_initial_malfunction_do_nothing():
status=RailAgentStatus.READY_TO_DEPART status=RailAgentStatus.READY_TO_DEPART
), ),
Replay( Replay(
position=(28, 5), position=(3, 2),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING, action=RailEnvActions.DO_NOTHING,
malfunction=2, malfunction=2,
...@@ -379,7 +396,7 @@ def test_initial_malfunction_do_nothing(): ...@@ -379,7 +396,7 @@ def test_initial_malfunction_do_nothing():
# --> if we take action DO_NOTHING, agent should restart without moving # --> if we take action DO_NOTHING, agent should restart without moving
# #
Replay( Replay(
position=(28, 5), position=(3, 2),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING, action=RailEnvActions.DO_NOTHING,
malfunction=1, malfunction=1,
...@@ -388,24 +405,24 @@ def test_initial_malfunction_do_nothing(): ...@@ -388,24 +405,24 @@ def test_initial_malfunction_do_nothing():
), ),
# we haven't started moving yet --> stay here # we haven't started moving yet --> stay here
Replay( Replay(
position=(28, 5), position=(3, 2),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING, action=RailEnvActions.DO_NOTHING,
malfunction=0, malfunction=0,
reward=env.step_penalty, # full step penalty while stopped reward=env.step_penalty, # full step penalty while stopped
status=RailAgentStatus.ACTIVE status=RailAgentStatus.ACTIVE
), ),
# we start to move forward --> should go to next cell now
Replay( Replay(
position=(28, 5), position=(3, 2),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
malfunction=0, malfunction=0,
reward=env.start_penalty + env.step_penalty * 1.0, # start penalty + step penalty for speed 1.0 reward=env.start_penalty + env.step_penalty * 1.0, # start penalty + step penalty for speed 1.0
status=RailAgentStatus.ACTIVE status=RailAgentStatus.ACTIVE
), ), # we start to move forward --> should go to next cell now
Replay( Replay(
position=(28, 6), position=(3, 3),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
malfunction=0, malfunction=0,
...@@ -415,12 +432,10 @@ def test_initial_malfunction_do_nothing(): ...@@ -415,12 +432,10 @@ def test_initial_malfunction_do_nothing():
], ],
speed=env.agents[0].speed_data['speed'], speed=env.agents[0].speed_data['speed'],
target=env.agents[0].target, target=env.agents[0].target,
initial_position=(28, 5), initial_position=(3, 2),
initial_direction=Grid4TransitionsEnum.EAST, initial_direction=Grid4TransitionsEnum.EAST,
) )
sparse_generator_stable = False run_replay_config(env, [replay_config], activate_agents=False)
if sparse_generator_stable:
run_replay_config(env, [replay_config], activate_agents=False)
def test_initial_nextmalfunction_not_below_zero(): def test_initial_nextmalfunction_not_below_zero():
...@@ -428,27 +443,23 @@ def test_initial_nextmalfunction_not_below_zero(): ...@@ -428,27 +443,23 @@ def test_initial_nextmalfunction_not_below_zero():
np.random.seed(0) np.random.seed(0)
stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
'malfunction_rate': 0.5, # Rate of malfunction occurrence 'malfunction_rate': 70, # Rate of malfunction occurrence
'min_duration': 5, # Minimal duration of malfunction 'min_duration': 2, # Minimal duration of malfunction
'max_duration': 5 # Max duration of malfunction 'max_duration': 5 # Max duration of malfunction
} }
speed_ration_map = {1.: 1., # Fast passenger train rail, rail_map = make_simple_rail2()
1. / 2.: 0., # Fast freight train
1. / 3.: 0., # Slow commuter train
1. / 4.: 0.} # Slow freight train
env = RailEnv(width=25, env = RailEnv(width=25,
height=30, height=30,
rail_generator=sparse_rail_generator(max_num_cities=5, rail_generator=rail_from_grid_transition_map(rail),
max_rails_between_cities=3, schedule_generator=random_schedule_generator(),
seed=215545,
grid_mode=True
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=1, number_of_agents=1,
stochastic_data=stochastic_data, # Malfunction data generator stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=SingleAgentNavigationObs()
) )
# reset to initialize agents_static
env.reset()
agent = env.agents[0] agent = env.agents[0]
env.step({}) env.step({})
# next_malfunction was -1 before the bugfix https://gitlab.aicrowd.com/flatland/flatland/issues/186 # next_malfunction was -1 before the bugfix https://gitlab.aicrowd.com/flatland/flatland/issues/186
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment