Commit 7209aad8 authored by u214892's avatar u214892
Browse files

merge #141 renamed agent_generator* to schedule_generator*

parent ccd73a5b
Pipeline #1836 failed with stages
in 8 minutes and 41 seconds
......@@ -5,7 +5,7 @@ import numpy as np
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_rail_generator_agents_placer
from flatland.envs.schedule_generators import complex_schedule_generator
def run_benchmark():
......@@ -16,7 +16,7 @@ def run_benchmark():
# Example generate a random rail
env = RailEnv(width=15, height=15,
rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
agent_generator=complex_rail_generator_agents_placer(),
schedule_generator=complex_schedule_generator(),
number_of_agents=5)
n_trials = 20
......
......@@ -9,7 +9,7 @@ from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import random_rail_generator, complex_rail_generator
from flatland.envs.schedule_generators import complex_rail_generator_agents_placer
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.utils.rendertools import RenderTool
random.seed(100)
......@@ -93,7 +93,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
env = RailEnv(width=7,
height=7,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0),
agent_generator=complex_rail_generator_agents_placer(),
schedule_generator=complex_schedule_generator(),
number_of_agents=1,
obs_builder_object=SingleAgentNavigationObs())
......@@ -204,7 +204,7 @@ CustomObsBuilder = ObservePredictions(CustomPredictor)
env = RailEnv(width=10,
height=10,
rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
agent_generator=complex_rail_generator_agents_placer(),
schedule_generator=complex_schedule_generator(),
number_of_agents=3,
obs_builder_object=CustomObsBuilder)
......
......@@ -29,7 +29,7 @@ def custom_rail_generator() -> RailGenerator:
return generator
def custom_agent_generator() -> ScheduleGenerator:
def custom_schedule_generator() -> ScheduleGenerator:
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct:
agents_positions = []
agents_direction = []
......
......@@ -6,7 +6,7 @@ import numpy as np
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_rail_generator_agents_placer
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.utils.rendertools import RenderTool
random.seed(1)
......@@ -61,7 +61,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
env = RailEnv(width=14,
height=14,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0),
agent_generator=complex_rail_generator_agents_placer(),
schedule_generator=complex_schedule_generator(),
number_of_agents=2,
obs_builder_object=SingleAgentNavigationObs())
......
......@@ -4,7 +4,7 @@ from flatland.envs.generators import sparse_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.schedule_generators import sparse_rail_generator_agents_placer
from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.utils.rendertools import RenderTool
np.random.seed(1)
......@@ -32,7 +32,7 @@ env = RailEnv(width=20,
realistic_mode=True,
enhance_intersection=True
),
agent_generator=sparse_rail_generator_agents_placer(),
schedule_generator=sparse_schedule_generator(),
number_of_agents=5,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=TreeObservation)
......
......@@ -5,7 +5,7 @@ import numpy as np
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_rail_generator_agents_placer
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.utils.rendertools import RenderTool
random.seed(1)
......@@ -14,7 +14,7 @@ np.random.seed(1)
env = RailEnv(width=7,
height=7,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
agent_generator=complex_rail_generator_agents_placer(),
schedule_generator=complex_schedule_generator(),
number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
......
......@@ -4,7 +4,7 @@ from flatland.envs.observations import TreeObsForRailEnv, LocalObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_rail_generator_agents_placer
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.utils.rendertools import RenderTool
np.random.seed(1)
......@@ -17,7 +17,7 @@ LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
env = RailEnv(width=20,
height=20,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
agent_generator=complex_rail_generator_agents_placer(),
schedule_generator=complex_schedule_generator(),
obs_builder_object=TreeObservation,
number_of_agents=3)
......
......@@ -10,7 +10,7 @@ import redis
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_rail_generator_agents_placer
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.evaluators.service import FlatlandRemoteEvaluationService
from flatland.utils.rendertools import RenderTool
......@@ -26,7 +26,7 @@ def demo(args=None):
nr_extra=1,
min_dist=8,
max_dist=99999),
agent_generator=complex_rail_generator_agents_placer(),
schedule_generator=complex_schedule_generator(),
number_of_agents=5)
env._max_episode_steps = int(15 * (env.width + env.height))
......
......@@ -193,81 +193,3 @@ def connect_to_nodes(rail_trans, rail_array, start, end):
current_dir = new_dir
return path
def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
"""
Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).
TODO: add extensive documentation, as users may need this function to simplify their custom level generators.
"""
def _path_exists(rail, start, direction, end):
# BFS - Check if a path exists between the 2 nodes
visited = set()
stack = [(start, direction)]
while stack:
node = stack.pop()
if node[0][0] == end[0] and node[0][1] == end[1]:
return 1
if node not in visited:
visited.add(node)
moves = rail.get_transitions(node[0][0], node[0][1], node[1])
for move_index in range(4):
if moves[move_index]:
stack.append((get_new_position(node[0], move_index),
move_index))
# If cell is a dead-end, append previous node with reversed
# orientation!
nbits = 0
tmp = rail.get_full_transitions(node[0][0], node[0][1])
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
if nbits == 1:
stack.append((node[0], (node[1] + 2) % 4))
return 0
valid_positions = []
for r in range(rail.height):
for c in range(rail.width):
if rail.get_full_transitions(r, c) > 0:
valid_positions.append((r, c))
re_generate = True
while re_generate:
agents_position = [
valid_positions[i] for i in
np.random.choice(len(valid_positions), num_agents)]
agents_target = [
valid_positions[i] for i in
np.random.choice(len(valid_positions), num_agents)]
# agents_direction must be a direction for which a solution is
# guaranteed.
agents_direction = [0] * num_agents
re_generate = False
for i in range(num_agents):
valid_movements = []
for direction in range(4):
position = agents_position[i]
moves = rail.get_transitions(position[0], position[1], direction)
for move_index in range(4):
if moves[move_index]:
valid_movements.append((direction, move_index))
valid_starting_directions = []
for m in valid_movements:
new_position = get_new_position(agents_position[i], m[1])
if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0], agents_target[i]):
valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0:
re_generate = True
else:
agents_direction[i] = valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]]
return agents_position, agents_direction, agents_target
......@@ -15,7 +15,7 @@ from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_generators import random_rail_generator, RailGenerator
from flatland.envs.schedule_generators import get_rnd_agents_pos_tgt_dir_on_rail, ScheduleGenerator
from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
m.patch()
......@@ -94,7 +94,7 @@ class RailEnv(Environment):
width,
height,
rail_generator: RailGenerator = random_rail_generator(),
agent_generator: ScheduleGenerator = get_rnd_agents_pos_tgt_dir_on_rail(),
schedule_generator: ScheduleGenerator = random_schedule_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2),
max_episode_steps=None,
......@@ -110,10 +110,10 @@ class RailEnv(Environment):
height and agents handles of a rail environment, along with the number of times
the env has been reset, and returns a GridTransitionMap object and a list of
starting positions, targets, and initial orientations for agent handle.
The rail_generator can pass a distance map in the hints or information for specific agent_generators.
The rail_generator can pass a distance map in the hints or information for specific schedule_generators.
Implementations can be found in flatland/envs/rail_generators.py
agent_generator : function
The agent_generator function is a function that takes the grid, the number of agents and optional hints
schedule_generator : function
The schedule_generator function is a function that takes the grid, the number of agents and optional hints
and returns a list of starting positions, targets, initial orientations and speed for all agent handles.
Implementations can be found in flatland/envs/schedule_generators.py
width : int
......@@ -134,7 +134,7 @@ class RailEnv(Environment):
"""
self.rail_generator: RailGenerator = rail_generator
self.agent_generator: ScheduleGenerator = agent_generator
self.schedule_generator: ScheduleGenerator = schedule_generator
self.rail_generator = rail_generator
self.rail: GridTransitionMap = None
self.width = width
......@@ -237,7 +237,7 @@ class RailEnv(Environment):
if optionals and 'agents_hints' in optionals:
agents_hints = optionals['agents_hints']
self.agents_static = EnvAgentStatic.from_lists(
*self.agent_generator(self.rail, self.get_num_agents(), hints=agents_hints))
*self.schedule_generator(self.rail, self.get_num_agents(), hints=agents_hints))
self.restart_agents()
......
......@@ -38,7 +38,7 @@ def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float,
return list(map(lambda index: speeds[index], np.random.choice(nb_classes, nb_agents, p=speed_ratios)))
def complex_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator:
def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator:
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None):
start_goal = hints['start_goal']
start_dir = hints['start_dir']
......@@ -56,7 +56,7 @@ def complex_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float]
return generator
def sparse_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator:
def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator:
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None):
train_stations = hints['train_stations']
agent_start_targets_nodes = hints['agent_start_targets_nodes']
......@@ -111,7 +111,7 @@ def sparse_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float] =
return generator
def get_rnd_agents_pos_tgt_dir_on_rail(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator:
def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator:
"""
Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).
......
......@@ -6,7 +6,7 @@ from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import get_rnd_agents_pos_tgt_dir_on_rail
from flatland.envs.schedule_generators import random_schedule_generator
def test_walker():
......@@ -28,7 +28,7 @@ def test_walker():
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2,
predictor=ShortestPathPredictorForRailEnv(max_depth=10)),
......
......@@ -9,7 +9,7 @@ from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import get_rnd_agents_pos_tgt_dir_on_rail
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail
......@@ -22,7 +22,7 @@ def test_global_obs():
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
......@@ -92,7 +92,7 @@ def test_reward_function_conflict(rendering=False):
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
schedule_generator=random_schedule_generator(),
number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
......@@ -171,7 +171,7 @@ def test_reward_function_waiting(rendering=False):
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
schedule_generator=random_schedule_generator(),
number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
......
......@@ -9,7 +9,7 @@ from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import get_rnd_agents_pos_tgt_dir_on_rail
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail
......@@ -22,7 +22,7 @@ def test_dummy_predictor(rendering=False):
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
)
......@@ -113,7 +113,7 @@ def test_shortest_path_predictor(rendering=False):
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
......@@ -233,7 +233,7 @@ def test_shortest_path_predictor_conflicts(rendering=False):
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
schedule_generator=random_schedule_generator(),
number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
......
......@@ -10,7 +10,7 @@ from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import get_rnd_agents_pos_tgt_dir_on_rail, complex_rail_generator_agents_placer
from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator
"""Tests for `flatland` package."""
......@@ -27,7 +27,7 @@ def test_load_env():
def test_save_load():
env = RailEnv(width=10, height=10,
rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=0),
agent_generator=complex_rail_generator_agents_placer(),
schedule_generator=complex_schedule_generator(),
number_of_agents=2)
env.reset()
agent_1_pos = env.agents_static[0].position
......@@ -79,7 +79,7 @@ def test_rail_environment_single_agent():
rail_env = RailEnv(width=3,
height=3,
rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
......@@ -159,7 +159,7 @@ def test_dead_end():
rail_env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
......@@ -204,7 +204,7 @@ def test_dead_end():
rail_env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
......
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_rail_generator_agents_placer
from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.utils.rendertools import RenderTool
......@@ -17,7 +17,7 @@ def test_sparse_rail_generator():
seed=5, # Random seed
realistic_mode=False # Ordered distribution of nodes
),
agent_generator=sparse_rail_generator_agents_placer(),
schedule_generator=sparse_schedule_generator(),
number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
# reset to initialize agents_static
......
......@@ -3,7 +3,7 @@ import numpy as np
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_rail_generator_agents_placer
from flatland.envs.schedule_generators import complex_schedule_generator
class SingleAgentNavigationObs(TreeObsForRailEnv):
......@@ -63,7 +63,7 @@ def test_malfunction_process():
height=20,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
seed=0),
agent_generator=complex_rail_generator_agents_placer(),
schedule_generator=complex_schedule_generator(),
number_of_agents=2,
obs_builder_object=SingleAgentNavigationObs(),
stochastic_data=stochastic_data)
......
......@@ -2,7 +2,7 @@ import numpy as np
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_rail_generator_agents_placer
from flatland.envs.schedule_generators import complex_schedule_generator
np.random.seed(1)
......@@ -48,7 +48,7 @@ def test_multi_speed_init():
height=50,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999,
seed=0),
agent_generator=complex_rail_generator_agents_placer(),
schedule_generator=complex_schedule_generator(),
number_of_agents=5)
# Initialize the agent with the parameters corresponding to the environment and observation_builder
agent = RandomAgent(218, 4)
......
......@@ -3,7 +3,7 @@ import numpy as np
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import speed_initialization_helper, complex_rail_generator_agents_placer
from flatland.envs.schedule_generators import speed_initialization_helper, complex_schedule_generator
def test_speed_initialization_helper():
......@@ -22,7 +22,7 @@ def test_rail_env_speed_intializer():
height=50,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999,
seed=0),
agent_generator=complex_rail_generator_agents_placer(),
schedule_generator=complex_schedule_generator(),
number_of_agents=10)
env.reset()
actual_speeds = list(map(lambda agent: agent.speed_data['speed'], env.agents))
......
......@@ -8,7 +8,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \
random_rail_generator, empty_rail_generator
from flatland.envs.schedule_generators import get_rnd_agents_pos_tgt_dir_on_rail, complex_rail_generator_agents_placer, \
from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator, \
agents_from_file
from flatland.utils.simple_rail import make_simple_rail
......@@ -61,7 +61,7 @@ def test_complex_rail_generator():
height=y_dim,
number_of_agents=n_agents,
rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist),
agent_generator=complex_rail_generator_agents_placer()
schedule_generator=complex_schedule_generator()
)
assert env.get_num_agents() == 2
assert env.rail.grid.shape == (y_dim, x_dim)
......@@ -73,7 +73,7 @@ def test_complex_rail_generator():
height=y_dim,
number_of_agents=n_agents,
rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist),
agent_generator=complex_rail_generator_agents_placer()
schedule_generator=complex_schedule_generator()
)
assert env.get_num_agents() == 0
assert env.rail.grid.shape == (y_dim, x_dim)
......@@ -87,7 +87,7 @@ def test_complex_rail_generator():
height=y_dim,
number_of_agents=n_agents,
rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist),
agent_generator=complex_rail_generator_agents_placer()
schedule_generator=complex_schedule_generator()
)
assert env.get_num_agents() == n_agents
assert env.rail.grid.shape == (y_dim, x_dim)
......@@ -99,7 +99,7 @@ def test_rail_from_grid_transition_map():
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
schedule_generator=random_schedule_generator(),
number_of_agents=n_agents
)
nr_rail_elements = np.count_nonzero(env.rail.grid)
......@@ -124,7 +124,7 @@ def tests_rail_from_file():
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
schedule_generator=random_schedule_generator(),
number_of_agents=3,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
......@@ -137,7 +137,7 @@ def tests_rail_from_file():
env = RailEnv(width=1,
height=1,
rail_generator=rail_from_file(file_name),
agent_generator=agents_from_file(file_name),
schedule_generator=agents_from_file(file_name),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
......@@ -159,7 +159,7 @@ def tests_rail_from_file():
env2 = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
agent_generator=get_rnd_agents_pos_tgt_dir_on_rail(),
schedule_generator=random_schedule_generator(),
number_of_agents=3,
obs_builder_object=GlobalObsForRailEnv(),
)
......@@ -173,7 +173,7 @@ def tests_rail_from_file():
env2 = RailEnv(width=1,
height=1,
rail_generator=rail_from_file(file_name_2),
agent_generator=agents_from_file(file_name_2),
schedule_generator=agents_from_file(file_name_2),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv(),
)
......@@ -190,7 +190,7 @@ def tests_rail_from_file():
env3 = RailEnv(width=1,
height=1,
rail_generator=rail_from_file(file_name),
agent_generator=agents_from_file(file_name),
schedule_generator=agents_from_file(file_name),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv(),
)
......@@ -208,7 +208,7 @@ def tests_rail_from_file():
env4 = RailEnv(width=1,
height=1,