Commit 58b08468 authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

Fix all tests that dont have hardcoded values

parent 1011d2a9
......@@ -44,11 +44,11 @@ def make_simple_rail() -> Tuple[GridTransitionMap, np.array]:
rail.grid = rail_map
city_positions = [(0,3), (6, 6)]
train_stations = [
[( (0, 3), 0 ), ( (1, 3), 1 ) ],
[( (6, 6), 0 ), ( (5, 6), 1 ) ],
[( (0, 3), 0 ) ],
[( (6, 6), 0 ) ],
]
city_orientations = [0, 2]
agents_hints = {'num_agents': 100,
agents_hints = {'num_agents': 2,
'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
......@@ -94,7 +94,19 @@ def make_disconnected_simple_rail() -> Tuple[GridTransitionMap, np.array]:
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
return rail, rail_map
city_positions = [(0,3), (6, 6)]
train_stations = [
[( (0, 3), 0 ) ],
[( (6, 6), 0 ) ],
]
city_orientations = [0, 2]
agents_hints = {'num_agents': 2,
'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
optionals = {'agents_hints': agents_hints}
return rail, rail_map, optionals
def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]:
......@@ -131,7 +143,19 @@ def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]:
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
return rail, rail_map
city_positions = [(0,3), (6, 6)]
train_stations = [
[( (0, 3), 0 ) ],
[( (6, 6), 0 ) ],
]
city_orientations = [0, 2]
agents_hints = {'num_agents': 2,
'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
optionals = {'agents_hints': agents_hints}
return rail, rail_map, optionals
def make_simple_rail_unconnected() -> Tuple[GridTransitionMap, np.array]:
......@@ -169,7 +193,19 @@ def make_simple_rail_unconnected() -> Tuple[GridTransitionMap, np.array]:
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
return rail, rail_map
city_positions = [(0,3), (6, 6)]
train_stations = [
[( (0, 3), 0 ) ],
[( (6, 6), 0 ) ],
]
city_orientations = [0, 2]
agents_hints = {'num_agents': 2,
'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
optionals = {'agents_hints': agents_hints}
return rail, rail_map, optionals
def make_simple_rail_with_alternatives() -> Tuple[GridTransitionMap, np.array]:
......@@ -213,7 +249,20 @@ def make_simple_rail_with_alternatives() -> Tuple[GridTransitionMap, np.array]:
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
return rail, rail_map
city_positions = [(0,3), (6, 6)]
train_stations = [
[( (0, 3), 0 ) ],
[( (6, 6), 0 ) ],
]
city_orientations = [0, 2]
agents_hints = {'num_agents': 2,
'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
optionals = {'agents_hints': agents_hints}
return rail, rail_map, optionals
def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.array]:
......@@ -251,4 +300,16 @@ def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.array]:
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
return rail, rail_map
city_positions = [(0,3), (6, 6)]
train_stations = [
[( (0, 3), 0 ) ],
[( (6, 6), 0 ) ],
]
city_orientations = [0, 2]
agents_hints = {'num_agents': 2,
'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
optionals = {'agents_hints': agents_hints}
return rail, rail_map, optionals
......@@ -25,9 +25,23 @@ def test_walker():
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
city_positions = [(0,2), (0, 1)]
train_stations = [
[( (0, 1), 0 ) ],
[( (0, 2), 0 ) ],
]
city_orientations = [1, 0]
agents_hints = {'num_agents': 1,
'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
optionals = {'agents_hints': agents_hints}
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2,
......
......@@ -4,16 +4,16 @@ from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail
from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
def test_initial_status():
"""Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
line_generator=rail_from_grid_transition_map(), number_of_agents=1,
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
remove_agents_at_target=False)
env.reset()
......@@ -124,9 +124,9 @@ def test_initial_status():
def test_status_done_remove():
"""Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
line_generator=rail_from_grid_transition_map(), number_of_agents=1,
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
remove_agents_at_target=True)
env.reset()
......
......@@ -66,10 +66,10 @@ def check_path(env, rail, position, direction, target, expected, rendering=False
def test_path_exists(rendering=False):
rail, rail_map = make_simple_rail()
rail, rail_map, optiionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
rail_generator=rail_from_grid_transition_map(rail, optiionals),
line_generator=sparse_line_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
......@@ -130,10 +130,10 @@ def test_path_exists(rendering=False):
def test_path_not_exists(rendering=False):
rail, rail_map = make_simple_rail_unconnected()
rail, rail_map, optionals = make_simple_rail_unconnected()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
......
......@@ -18,9 +18,9 @@ from flatland.utils.simple_rail import make_simple_rail
def test_global_obs():
rail, rail_map = make_simple_rail()
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
......@@ -91,8 +91,8 @@ def _step_along_shortest_path(env, obs_builder, rail):
def test_reward_function_conflict(rendering=False):
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
obs_builder: TreeObsForRailEnv = env.obs_builder
......@@ -179,8 +179,8 @@ def test_reward_function_conflict(rendering=False):
def test_reward_function_waiting(rendering=False):
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
remove_agents_at_target=False)
......
......@@ -20,11 +20,11 @@ from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make
def test_dummy_predictor(rendering=False):
rail, rail_map = make_simple_rail2()
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
......@@ -112,10 +112,10 @@ def test_dummy_predictor(rendering=False):
def test_shortest_path_predictor(rendering=False):
rail, rail_map = make_simple_rail()
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
......@@ -141,9 +141,8 @@ def test_shortest_path_predictor(rendering=False):
# compute the observations and predictions
distance_map = env.distance_map.get()
assert distance_map[0, agent.initial_position[0], agent.initial_position[1], agent.direction] == 5.0, \
"found {} instead of {}".format(
distance_map[agent.handle, agent.initial_position[0], agent.position[1], agent.direction], 5.0)
distance_on_map = distance_map[0, agent.initial_position[0], agent.initial_position[1], agent.direction]
assert distance_on_map == 5.0, "found {} instead of {}".format(distance_on_map, 5.0)
paths = get_shortest_paths(env.distance_map)[0]
assert paths == [
......@@ -243,10 +242,10 @@ def test_shortest_path_predictor(rendering=False):
def test_shortest_path_predictor_conflicts(rendering=False):
rail, rail_map = make_invalid_simple_rail()
rail, rail_map, optionals = make_invalid_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
......
......@@ -36,8 +36,8 @@ def test_load_env():
def test_save_load():
env = RailEnv(width=10, height=10,
rail_generator=sparse_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1),
env = RailEnv(width=30, height=30,
rail_generator=sparse_rail_generator(seed=1),
line_generator=sparse_line_generator(), number_of_agents=2)
env.reset()
......@@ -55,8 +55,8 @@ def test_save_load():
#env.load("test_save.dat")
env, env_dict = RailEnvPersister.load_new("tmp/test_save.pkl")
assert (env.width == 10)
assert (env.height == 10)
assert (env.width == 30)
assert (env.height == 30)
assert (len(env.agents) == 2)
assert (agent_1_pos == env.agents[0].position)
assert (agent_1_dir == env.agents[0].direction)
......@@ -67,8 +67,8 @@ def test_save_load():
def test_save_load_mpk():
env = RailEnv(width=10, height=10,
rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1),
env = RailEnv(width=30, height=30,
rail_generator=sparse_rail_generator(seed=1),
line_generator=sparse_line_generator(), number_of_agents=2)
env.reset()
......@@ -204,7 +204,7 @@ def test_rail_environment_single_agent(show=False):
rail_env.agents[0].direction = 0
# JW - to avoid problem with random_line_generator.
# JW - to avoid problem with sparse_line_generator.
#rail_env.agents[0].position = (1,2)
iStep = 0
......@@ -247,7 +247,7 @@ def test_dead_end():
rail.grid = rail_map
rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(), number_of_agents=1,
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
# We try the configuration in the 4 directions:
......@@ -270,7 +270,7 @@ def test_dead_end():
rail.grid = rail_map
rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(), number_of_agents=1,
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
rail_env.reset()
......@@ -283,9 +283,9 @@ def test_dead_end():
def test_get_entry_directions():
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(), number_of_agents=1,
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.reset()
......@@ -317,10 +317,10 @@ def test_rail_env_reset():
# Test to save and load file.
rail, rail_map = make_simple_rail()
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(), number_of_agents=3,
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=3,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.reset()
......
......@@ -16,9 +16,9 @@ from flatland.envs.persistence import RailEnvPersister
def test_get_shortest_paths_unreachable():
rail, rail_map = make_disconnected_simple_rail()
rail, rail_map, optionals = make_disconnected_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
env.reset()
......@@ -237,11 +237,11 @@ def test_get_shortest_paths_agent_handle():
def test_get_k_shortest_paths(rendering=False):
rail, rail_map = make_simple_rail_with_alternatives()
rail, rail_map, optionals = make_simple_rail_with_alternatives()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv(),
......
......@@ -72,11 +72,11 @@ def test_malfunction_process():
max_duration=3 # Max duration of malfunction
)
rail, rail_map = make_simple_rail2()
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
......@@ -128,11 +128,11 @@ def test_malfunction_process_statistically():
max_duration=5 # Max duration of malfunction
)
rail, rail_map = make_simple_rail2()
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
......@@ -175,11 +175,11 @@ def test_malfunction_before_entry():
max_duration=10 # Max duration of malfunction
)
rail, rail_map = make_simple_rail2()
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
......@@ -215,7 +215,7 @@ def test_malfunction_values_and_behavior():
"""
# Set fixed malfunction duration for this test
rail, rail_map = make_simple_rail2()
rail, rail_map, optionals = make_simple_rail2()
action_dict: Dict[int, RailEnvActions] = {}
stochastic_data = MalfunctionParameters(malfunction_rate=1/0.001, # Rate of malfunction occurence
min_duration=10, # Minimal duration of malfunction
......@@ -223,7 +223,7 @@ def test_malfunction_values_and_behavior():
)
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
......@@ -248,11 +248,11 @@ def test_initial_malfunction():
max_duration=5 # Max duration of malfunction
)
rail, rail_map = make_simple_rail2()
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=10),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
......@@ -315,9 +315,9 @@ def test_initial_malfunction():
def test_initial_malfunction_stop_moving():
rail, rail_map = make_simple_rail2()
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=SingleAgentNavigationObs())
env.reset()
......@@ -397,11 +397,11 @@ def test_initial_malfunction_do_nothing():
max_duration=5 # Max duration of malfunction
)
rail, rail_map = make_simple_rail2()
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
......@@ -477,8 +477,8 @@ def test_initial_malfunction_do_nothing():
def tests_random_interference_from_outside():
"""Tests that malfunctions are produced by stochastic_data!"""
# Set fixed malfunction duration for this test
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
env.reset()
env.agents[0].speed_data['speed'] = 0.33
......@@ -499,10 +499,10 @@ def tests_random_interference_from_outside():
# Run the same test as above but with an external random generator running
# Check that the reward stays the same
rail, rail_map = make_simple_rail2()
rail, rail_map, optionals = make_simple_rail2()
random.seed(47)
np.random.seed(1234)
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
env.reset()
env.agents[0].speed_data['speed'] = 0.33
......@@ -532,9 +532,9 @@ def test_last_malfunction_step():
# Set fixed malfunction duration for this test
rail, rail_map = make_simple_rail2()
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
env.reset()
env.agents[0].speed_data['speed'] = 1. / 3.
......
......@@ -14,11 +14,12 @@ from flatland.utils.simple_rail import make_simple_rail
def test_multiprocessing_tree_obs():
number_of_agents = 5
rail, rail_map = make_simple_rail()
rail, rail_map, optionals = make_simple_rail()
optionals['agents_hints']['num_agents'] = number_of_agents
obs_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=number_of_agents,
obs_builder_object=obs_builder)
env.reset(True, True)
......
......@@ -29,9 +29,9 @@ def test_empty_rail_generator():
def test_rail_from_grid_transition_map():
rail, rail_map = make_simple_rail()
rail, rail_map, optionals = make_simple_rail()
n_agents = 4
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=n_agents)
env.reset(False, False, True)
nr_rail_elements = np.count_nonzero(env.rail.grid)
......@@ -51,9 +51,9 @@ def tests_rail_from_file():
# Test to save and load file with distance map.
rail, rail_map = make_simple_rail()
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=3,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.reset()
......
......@@ -17,11 +17,11 @@ def test_malfanction_from_params():
min_duration=2, # Minimal duration of malfunction
max_duration=5 # Max duration of malfunction
)
rail, rail_map = make_simple_rail2()
rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data)
......@@ -44,11 +44,11 @@ def test_malfanction_to_and_from_file():
max_duration=5 # Max duration of malfunction