Commit dd996ec6 authored by Erik Nygren 🚅
Browse files

added and fixed tests for sparse_rail_generator

parent a0521bb8
Pipeline #2380 passed with stages
in 44 minutes and 45 seconds
......@@ -9,8 +9,9 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail2
class SingleAgentNavigationObs(ObservationBuilder):
......@@ -69,16 +70,27 @@ def test_malfunction_process():
'malfunction_rate': 1000,
'min_duration': 3,
'max_duration': 3}
np.random.seed(5)
random.seed(0)
np.random.seed(0)
env = RailEnv(width=20,
height=20,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
seed=0),
schedule_generator=complex_schedule_generator(),
number_of_agents=2,
obs_builder_object=SingleAgentNavigationObs(),
stochastic_data=stochastic_data)
stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
'malfunction_rate': 70, # Rate of malfunction occurence
'min_duration': 2, # Minimal duration of malfunction
'max_duration': 5 # Max duration of malfunction
}
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=SingleAgentNavigationObs()
)
# reset to initialize agents_static
env.reset()
obs = env.reset(False, False, True)
......@@ -117,7 +129,7 @@ def test_malfunction_process():
total_down_time += env.agents[0].malfunction_data['malfunction']
# Check that the appropriate number of malfunctions is achieved
assert env.agents[0].malfunction_data['nr_malfunctions'] == 21, "Actual {}".format(
assert env.agents[0].malfunction_data['nr_malfunctions'] == 11, "Actual {}".format(
env.agents[0].malfunction_data['nr_malfunctions'])
# Check that 20 stops were performed
......@@ -135,16 +147,26 @@ def test_malfunction_process_statistically():
'min_duration': 3,
'max_duration': 3}
env = RailEnv(width=20,
height=20,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
seed=0),
schedule_generator=complex_schedule_generator(),
number_of_agents=2,
obs_builder_object=SingleAgentNavigationObs(),
stochastic_data=stochastic_data)
np.random.seed(5)
random.seed(0)
np.random.seed(0)
stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
'malfunction_rate': 70, # Rate of malfunction occurence
'min_duration': 2, # Minimal duration of malfunction
'max_duration': 5 # Max duration of malfunction
}
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=SingleAgentNavigationObs()
)
# reset to initialize agents_static
env.reset(False, False, True)
nb_malfunction = 0
for step in range(100):
......@@ -158,38 +180,39 @@ def test_malfunction_process_statistically():
env.step(action_dict)
# check that generation of malfunctions works as expected
assert nb_malfunction == 128, "nb_malfunction={}".format(nb_malfunction)
assert nb_malfunction == 3, "nb_malfunction={}".format(nb_malfunction)
def test_initial_malfunction():
random.seed(0)
np.random.seed(0)
stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
'malfunction_rate': 70, # Rate of malfunction occurence
'min_duration': 2, # Minimal duration of malfunction
'max_duration': 5 # Max duration of malfunction
}
speed_ration_map = {1.: 1., # Fast passenger train
1. / 2.: 0., # Fast freight train
1. / 3.: 0., # Slow commuter train
1. / 4.: 0.} # Slow freight train
np.random.seed(5)
random.seed(0)
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=sparse_rail_generator(max_num_cities=5,
max_rails_between_cities=3,
seed=215545,
grid_mode=True
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=SingleAgentNavigationObs()
)
# reset to initialize agents_static
env.reset(False, False, True)
set_penalties_for_replay(env)
replay_config = ReplayConfig(
replay=[
Replay(
position=(28, 5),
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
set_malfunction=3,
......@@ -197,7 +220,7 @@ def test_initial_malfunction():
reward=env.step_penalty # full step penalty when malfunctioning
),
Replay(
position=(28, 5),
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=2,
......@@ -206,7 +229,7 @@ def test_initial_malfunction():
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action MOVE_FORWARD, agent should restart and move to the next cell
Replay(
position=(28, 5),
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=1,
......@@ -214,15 +237,15 @@ def test_initial_malfunction():
# malfunctioning ends: starting and running at speed 1.0
),
Replay(
position=(28, 6),
position=(3, 3),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.step_penalty * 1.0 # running at speed 1.0
),
Replay(
position=(27, 6),
direction=Grid4TransitionsEnum.NORTH,
position=(3, 4),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.step_penalty * 1.0 # running at speed 1.0
......@@ -230,37 +253,36 @@ def test_initial_malfunction():
],
speed=env.agents[0].speed_data['speed'],
target=env.agents[0].target,
initial_position=(28, 5),
initial_position=(3, 2),
initial_direction=Grid4TransitionsEnum.EAST,
)
sparse_generator_stable = False
if sparse_generator_stable:
run_replay_config(env, [replay_config])
run_replay_config(env, [replay_config])
def test_initial_malfunction_stop_moving():
random.seed(0)
np.random.seed(0)
stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
'malfunction_rate': 70, # Rate of malfunction occurence
'min_duration': 2, # Minimal duration of malfunction
'max_duration': 5 # Max duration of malfunction
}
speed_ration_map = {1.: 1., # Fast passenger train
1. / 2.: 0., # Fast freight train
1. / 3.: 0., # Slow commuter train
1. / 4.: 0.} # Slow freight train
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=sparse_rail_generator(max_num_cities=5,
max_rails_between_cities=3,
seed=215545,
grid_mode=True
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=SingleAgentNavigationObs()
)
# reset to initialize agents_static
print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].status)
set_penalties_for_replay(env)
replay_config = ReplayConfig(
replay=[
......@@ -274,7 +296,7 @@ def test_initial_malfunction_stop_moving():
status=RailAgentStatus.READY_TO_DEPART
),
Replay(
position=(28, 5),
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=2,
......@@ -285,7 +307,7 @@ def test_initial_malfunction_stop_moving():
# --> if we take action STOP_MOVING, agent should restart without moving
#
Replay(
position=(28, 5),
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.STOP_MOVING,
malfunction=1,
......@@ -294,7 +316,7 @@ def test_initial_malfunction_stop_moving():
),
# we have stopped and do nothing --> should stand still
Replay(
position=(28, 5),
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=0,
......@@ -303,7 +325,7 @@ def test_initial_malfunction_stop_moving():
),
# we start to move forward --> should go to next cell now
Replay(
position=(28, 5),
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
......@@ -311,7 +333,7 @@ def test_initial_malfunction_stop_moving():
status=RailAgentStatus.ACTIVE
),
Replay(
position=(28, 6),
position=(3, 3),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
......@@ -321,12 +343,11 @@ def test_initial_malfunction_stop_moving():
],
speed=env.agents[0].speed_data['speed'],
target=env.agents[0].target,
initial_position=(28, 5),
initial_position=(3, 2),
initial_direction=Grid4TransitionsEnum.EAST,
)
sparse_generator_stable = False
if sparse_generator_stable:
run_replay_config(env, [replay_config], activate_agents=False)
run_replay_config(env, [replay_config], activate_agents=False)
def test_initial_malfunction_do_nothing():
......@@ -339,22 +360,18 @@ def test_initial_malfunction_do_nothing():
'max_duration': 5 # Max duration of malfunction
}
speed_ration_map = {1.: 1., # Fast passenger train
1. / 2.: 0., # Fast freight train
1. / 3.: 0., # Slow commuter train
1. / 4.: 0.} # Slow freight train
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=sparse_rail_generator(max_num_cities=5,
max_rails_between_cities=3,
seed=215545,
grid_mode=True
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
stochastic_data=stochastic_data, # Malfunction data generator
)
# reset to initialize agents_static
env.reset()
set_penalties_for_replay(env)
replay_config = ReplayConfig(
replay=[
......@@ -368,7 +385,7 @@ def test_initial_malfunction_do_nothing():
status=RailAgentStatus.READY_TO_DEPART
),
Replay(
position=(28, 5),
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=2,
......@@ -379,7 +396,7 @@ def test_initial_malfunction_do_nothing():
# --> if we take action DO_NOTHING, agent should restart without moving
#
Replay(
position=(28, 5),
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=1,
......@@ -388,24 +405,24 @@ def test_initial_malfunction_do_nothing():
),
# we haven't started moving yet --> stay here
Replay(
position=(28, 5),
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=0,
reward=env.step_penalty, # full step penalty while stopped
status=RailAgentStatus.ACTIVE
),
# we start to move forward --> should go to next cell now
Replay(
position=(28, 5),
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.start_penalty + env.step_penalty * 1.0, # start penalty + step penalty for speed 1.0
status=RailAgentStatus.ACTIVE
),
), # we start to move forward --> should go to next cell now
Replay(
position=(28, 6),
position=(3, 3),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
......@@ -415,12 +432,10 @@ def test_initial_malfunction_do_nothing():
],
speed=env.agents[0].speed_data['speed'],
target=env.agents[0].target,
initial_position=(28, 5),
initial_position=(3, 2),
initial_direction=Grid4TransitionsEnum.EAST,
)
sparse_generator_stable = False
if sparse_generator_stable:
run_replay_config(env, [replay_config], activate_agents=False)
run_replay_config(env, [replay_config], activate_agents=False)
def test_initial_nextmalfunction_not_below_zero():
......@@ -428,27 +443,23 @@ def test_initial_nextmalfunction_not_below_zero():
np.random.seed(0)
stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
'malfunction_rate': 0.5, # Rate of malfunction occurence
'min_duration': 5, # Minimal duration of malfunction
'malfunction_rate': 70, # Rate of malfunction occurence
'min_duration': 2, # Minimal duration of malfunction
'max_duration': 5 # Max duration of malfunction
}
speed_ration_map = {1.: 1., # Fast passenger train
1. / 2.: 0., # Fast freight train
1. / 3.: 0., # Slow commuter train
1. / 4.: 0.} # Slow freight train
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=sparse_rail_generator(max_num_cities=5,
max_rails_between_cities=3,
seed=215545,
grid_mode=True
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=SingleAgentNavigationObs()
)
# reset to initialize agents_static
env.reset()
agent = env.agents[0]
env.step({})
# next_malfunction was -1 before the bugfix https://gitlab.aicrowd.com/flatland/flatland/issues/186
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment