Commit 329120c7 authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

change names from schedule to line in test and evaluators

parent 7768750d
......@@ -3,7 +3,7 @@ from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.schedule_generators import schedule_from_file
from flatland.envs.line_generators import line_from_file
def load_flatland_environment_from_file(file_name: str,
......@@ -33,7 +33,7 @@ def load_flatland_environment_from_file(file_name: str,
max_depth=2,
predictor=ShortestPathPredictorForRailEnv(max_depth=10))
environment = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name, load_from_package),
schedule_generator=schedule_from_file(file_name, load_from_package),
schedule_generator=line_from_file(file_name, load_from_package),
number_of_agents=1,
obs_builder_object=obs_builder_object,
record_steps=record_steps,
......
......@@ -15,7 +15,7 @@ import flatland
from flatland.envs.malfunction_generators import malfunction_from_file
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.schedule_generators import schedule_from_file
from flatland.envs.line_generators import line_from_file
from flatland.evaluators import messages
from flatland.core.env_observation_builder import DummyObservationBuilder
......
......@@ -26,7 +26,7 @@ from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.malfunction_generators import malfunction_from_file
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.schedule_generators import schedule_from_file
from flatland.envs.line_generators import line_from_file
from flatland.evaluators import aicrowd_helpers
from flatland.evaluators import messages
from flatland.utils.rendertools import RenderTool
......
......@@ -6,7 +6,7 @@ from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.rail_trainrun_data_structures import Waypoint
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.envs.line_generators import random_line_generator
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.utils.simple_rail import make_simple_rail
......@@ -17,7 +17,7 @@ def test_action_plan(rendering: bool = False):
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(seed=77),
line_generator=random_line_generator(seed=77),
number_of_agents=2,
obs_builder_object=GlobalObsForRailEnv(),
remove_agents_at_target=True
......@@ -34,25 +34,25 @@ def test_action_plan(rendering: bool = False):
for handle, agent in enumerate(env.agents):
print("[{}] {} -> {}".format(handle, agent.initial_position, agent.target))
chosen_path_dict = {0: [TrainrunWaypoint(scheduled_at=0, waypoint=Waypoint(position=(3, 0), direction=3)),
TrainrunWaypoint(scheduled_at=2, waypoint=Waypoint(position=(3, 1), direction=1)),
TrainrunWaypoint(scheduled_at=3, waypoint=Waypoint(position=(3, 2), direction=1)),
TrainrunWaypoint(scheduled_at=14, waypoint=Waypoint(position=(3, 3), direction=1)),
TrainrunWaypoint(scheduled_at=15, waypoint=Waypoint(position=(3, 4), direction=1)),
TrainrunWaypoint(scheduled_at=16, waypoint=Waypoint(position=(3, 5), direction=1)),
TrainrunWaypoint(scheduled_at=17, waypoint=Waypoint(position=(3, 6), direction=1)),
TrainrunWaypoint(scheduled_at=18, waypoint=Waypoint(position=(3, 7), direction=1)),
TrainrunWaypoint(scheduled_at=19, waypoint=Waypoint(position=(3, 8), direction=1)),
TrainrunWaypoint(scheduled_at=20, waypoint=Waypoint(position=(3, 8), direction=5))],
1: [TrainrunWaypoint(scheduled_at=0, waypoint=Waypoint(position=(3, 8), direction=3)),
TrainrunWaypoint(scheduled_at=3, waypoint=Waypoint(position=(3, 7), direction=3)),
TrainrunWaypoint(scheduled_at=5, waypoint=Waypoint(position=(3, 6), direction=3)),
TrainrunWaypoint(scheduled_at=7, waypoint=Waypoint(position=(3, 5), direction=3)),
TrainrunWaypoint(scheduled_at=9, waypoint=Waypoint(position=(3, 4), direction=3)),
TrainrunWaypoint(scheduled_at=11, waypoint=Waypoint(position=(3, 3), direction=3)),
TrainrunWaypoint(scheduled_at=13, waypoint=Waypoint(position=(2, 3), direction=0)),
TrainrunWaypoint(scheduled_at=15, waypoint=Waypoint(position=(1, 3), direction=0)),
TrainrunWaypoint(scheduled_at=17, waypoint=Waypoint(position=(0, 3), direction=0))]}
chosen_path_dict = {0: [TrainrunWaypoint(lined_at=0, waypoint=Waypoint(position=(3, 0), direction=3)),
TrainrunWaypoint(lined_at=2, waypoint=Waypoint(position=(3, 1), direction=1)),
TrainrunWaypoint(lined_at=3, waypoint=Waypoint(position=(3, 2), direction=1)),
TrainrunWaypoint(lined_at=14, waypoint=Waypoint(position=(3, 3), direction=1)),
TrainrunWaypoint(lined_at=15, waypoint=Waypoint(position=(3, 4), direction=1)),
TrainrunWaypoint(lined_at=16, waypoint=Waypoint(position=(3, 5), direction=1)),
TrainrunWaypoint(lined_at=17, waypoint=Waypoint(position=(3, 6), direction=1)),
TrainrunWaypoint(lined_at=18, waypoint=Waypoint(position=(3, 7), direction=1)),
TrainrunWaypoint(lined_at=19, waypoint=Waypoint(position=(3, 8), direction=1)),
TrainrunWaypoint(lined_at=20, waypoint=Waypoint(position=(3, 8), direction=5))],
1: [TrainrunWaypoint(lined_at=0, waypoint=Waypoint(position=(3, 8), direction=3)),
TrainrunWaypoint(lined_at=3, waypoint=Waypoint(position=(3, 7), direction=3)),
TrainrunWaypoint(lined_at=5, waypoint=Waypoint(position=(3, 6), direction=3)),
TrainrunWaypoint(lined_at=7, waypoint=Waypoint(position=(3, 5), direction=3)),
TrainrunWaypoint(lined_at=9, waypoint=Waypoint(position=(3, 4), direction=3)),
TrainrunWaypoint(lined_at=11, waypoint=Waypoint(position=(3, 3), direction=3)),
TrainrunWaypoint(lined_at=13, waypoint=Waypoint(position=(2, 3), direction=0)),
TrainrunWaypoint(lined_at=15, waypoint=Waypoint(position=(1, 3), direction=0)),
TrainrunWaypoint(lined_at=17, waypoint=Waypoint(position=(0, 3), direction=0))]}
expected_action_plan = [[
# take action to enter the grid
ActionPlanElement(0, RailEnvActions.MOVE_FORWARD),
......
......@@ -6,7 +6,7 @@ from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.envs.line_generators import random_line_generator
def test_walker():
......@@ -28,7 +28,7 @@ def test_walker():
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
line_generator=random_line_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2,
predictor=ShortestPathPredictorForRailEnv(max_depth=10)),
......
......@@ -4,7 +4,7 @@ from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.envs.line_generators import random_line_generator
from flatland.utils.simple_rail import make_simple_rail
from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
......@@ -13,7 +13,7 @@ def test_initial_status():
"""Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=1,
line_generator=random_line_generator(), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
remove_agents_at_target=False)
env.reset()
......@@ -121,7 +121,7 @@ def test_status_done_remove():
"""Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=1,
line_generator=random_line_generator(), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
remove_agents_at_target=True)
env.reset()
......
......@@ -6,7 +6,7 @@ from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.envs.line_generators import random_line_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail, make_simple_rail_unconnected
......@@ -70,7 +70,7 @@ def test_path_exists(rendering=False):
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
line_generator=random_line_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
......@@ -134,7 +134,7 @@ def test_path_not_exists(rendering=False):
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
line_generator=random_line_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
......
......@@ -10,7 +10,7 @@ from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.envs.line_generators import random_line_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail
......@@ -21,7 +21,7 @@ def test_global_obs():
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=1,
line_generator=random_line_generator(), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
global_obs, info = env.reset()
......@@ -93,7 +93,7 @@ def _step_along_shortest_path(env, obs_builder, rail):
def test_reward_function_conflict(rendering=False):
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=2,
line_generator=random_line_generator(), number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
obs_builder: TreeObsForRailEnv = env.obs_builder
env.reset()
......@@ -181,7 +181,7 @@ def test_reward_function_conflict(rendering=False):
def test_reward_function_waiting(rendering=False):
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=2,
line_generator=random_line_generator(), number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
remove_agents_at_target=False)
obs_builder: TreeObsForRailEnv = env.obs_builder
......
......@@ -12,7 +12,7 @@ from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.rail_trainrun_data_structures import Waypoint
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.envs.line_generators import random_line_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail
......@@ -25,7 +25,7 @@ def test_dummy_predictor(rendering=False):
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
line_generator=random_line_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
)
......@@ -116,7 +116,7 @@ def test_shortest_path_predictor(rendering=False):
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
line_generator=random_line_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
......@@ -247,7 +247,7 @@ def test_shortest_path_predictor_conflicts(rendering=False):
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
line_generator=random_line_generator(),
number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
......
......@@ -11,7 +11,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import complex_rail_generator, rail_from_file
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator, schedule_from_file
from flatland.envs.line_generators import random_line_generator, complex_line_generator, line_from_file
from flatland.utils.simple_rail import make_simple_rail
from flatland.envs.persistence import RailEnvPersister
from flatland.utils.rendertools import RenderTool
......@@ -38,7 +38,7 @@ def test_load_env():
def test_save_load():
env = RailEnv(width=10, height=10,
rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1),
schedule_generator=complex_schedule_generator(), number_of_agents=2)
line_generator=complex_line_generator(), number_of_agents=2)
env.reset()
agent_1_pos = env.agents[0].position
agent_1_dir = env.agents[0].direction
......@@ -68,7 +68,7 @@ def test_save_load():
def test_save_load_mpk():
env = RailEnv(width=10, height=10,
rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1),
schedule_generator=complex_schedule_generator(), number_of_agents=2)
line_generator=complex_line_generator(), number_of_agents=2)
env.reset()
os.makedirs("tmp", exist_ok=True)
......@@ -120,7 +120,7 @@ def test_rail_environment_single_agent(show=False):
rail = GridTransitionMap(width=3, height=3, transitions=transitions)
rail.grid = rail_map
rail_env = RailEnv(width=3, height=3, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=1,
line_generator=random_line_generator(), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
else:
rail_env, env_dict = RailEnvPersister.load_new("test_env_loop.pkl", "env_data.tests")
......@@ -203,7 +203,7 @@ def test_rail_environment_single_agent(show=False):
rail_env.agents[0].direction = 0
# JW - to avoid problem with random_schedule_generator.
# JW - to avoid problem with random_line_generator.
#rail_env.agents[0].position = (1,2)
iStep = 0
......@@ -246,7 +246,7 @@ def test_dead_end():
rail.grid = rail_map
rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=1,
line_generator=random_line_generator(), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
# We try the configuration in the 4 directions:
......@@ -269,7 +269,7 @@ def test_dead_end():
rail.grid = rail_map
rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=1,
line_generator=random_line_generator(), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
rail_env.reset()
......@@ -284,7 +284,7 @@ def test_dead_end():
def test_get_entry_directions():
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=1,
line_generator=random_line_generator(), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.reset()
......@@ -319,7 +319,7 @@ def test_rail_env_reset():
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=3,
line_generator=random_line_generator(), number_of_agents=3,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.reset()
......@@ -331,7 +331,7 @@ def test_rail_env_reset():
agents_initial = env.agents
#env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
# schedule_generator=schedule_from_file(file_name), number_of_agents=1,
# line_generator=line_from_file(file_name), number_of_agents=1,
# obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
#env2.reset(False, False, False)
env2, env2_dict = RailEnvPersister.load_new(file_name)
......@@ -343,7 +343,7 @@ def test_rail_env_reset():
assert agents_initial == agents_loaded
env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
schedule_generator=schedule_from_file(file_name), number_of_agents=1,
line_generator=line_from_file(file_name), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env3.reset(False, True, False)
rails_loaded = env3.rail.grid
......@@ -353,7 +353,7 @@ def test_rail_env_reset():
assert agents_initial == agents_loaded
env4 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
schedule_generator=schedule_from_file(file_name), number_of_agents=1,
line_generator=line_from_file(file_name), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env4.reset(True, False, False)
rails_loaded = env4.rail.grid
......
......@@ -9,7 +9,7 @@ from flatland.envs.rail_env_shortest_paths import get_shortest_paths, get_k_shor
from flatland.envs.rail_env_utils import load_flatland_environment_from_file
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.rail_trainrun_data_structures import Waypoint
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.envs.line_generators import random_line_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_disconnected_simple_rail, make_simple_rail_with_alternatives
from flatland.envs.persistence import RailEnvPersister
......@@ -19,7 +19,7 @@ def test_get_shortest_paths_unreachable():
rail, rail_map = make_disconnected_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=1,
line_generator=random_line_generator(), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
env.reset()
......@@ -238,7 +238,7 @@ def test_get_k_shortest_paths(rendering=False):
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
line_generator=random_line_generator(),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv(),
)
......
......@@ -7,7 +7,7 @@ from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.rendertools import RenderTool
......@@ -17,7 +17,7 @@ def test_sparse_rail_generator():
seed=5,
grid_mode=False
),
schedule_generator=sparse_schedule_generator(), number_of_agents=10,
line_generator=sparse_line_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
env.reset(False, False, True)
for r in range(env.height):
......@@ -602,7 +602,7 @@ def test_sparse_rail_generator_deterministic():
seed=215545, # Random seed
grid_mode=True
),
schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=1)
line_generator=sparse_line_generator(speed_ration_map), number_of_agents=1)
env.reset()
# for r in range(env.height):
# for c in range(env.width):
......@@ -1371,7 +1371,7 @@ def test_rail_env_action_required_info():
max_rails_between_cities=3,
seed=5, # Random seed
grid_mode=False # Ordered distribution of nodes
), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=10,
), line_generator=sparse_line_generator(speed_ration_map), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv(), remove_agents_at_target=False)
env_only_if_action_required = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(
......@@ -1380,7 +1380,7 @@ def test_rail_env_action_required_info():
seed=5, # Random seed
grid_mode=False
# Ordered distribution of nodes
), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=10,
), line_generator=sparse_line_generator(speed_ration_map), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv(), remove_agents_at_target=False)
env_renderer = RenderTool(env_always_action, gl="PILSVG", )
......@@ -1442,7 +1442,7 @@ def test_rail_env_malfunction_speed_info():
seed=5,
grid_mode=False
),
schedule_generator=sparse_schedule_generator(), number_of_agents=10,
line_generator=sparse_line_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
env.reset(False, False, True)
......@@ -1476,7 +1476,7 @@ def test_sparse_generator_with_too_man_cities_does_not_break_down():
max_rails_between_cities=3,
seed=5,
grid_mode=False
), schedule_generator=sparse_schedule_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv())
), line_generator=sparse_line_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv())
def test_sparse_generator_with_illegal_params_aborts():
......@@ -1489,7 +1489,7 @@ def test_sparse_generator_with_illegal_params_aborts():
max_rails_between_cities=3,
seed=5,
grid_mode=False
), schedule_generator=sparse_schedule_generator(), number_of_agents=10,
), line_generator=sparse_line_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv()).reset()
with unittest.TestCase.assertRaises(test_sparse_generator_with_illegal_params_aborts, ValueError):
......@@ -1498,7 +1498,7 @@ def test_sparse_generator_with_illegal_params_aborts():
max_rails_between_cities=3,
seed=5,
grid_mode=False
), schedule_generator=sparse_schedule_generator(), number_of_agents=10,
), line_generator=sparse_line_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv()).reset()
......@@ -1515,7 +1515,7 @@ def test_sparse_generator_changes_to_grid_mode():
max_rails_in_city=2,
seed=15,
grid_mode=False
), schedule_generator=sparse_schedule_generator(), number_of_agents=10,
), line_generator=sparse_line_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
for test_run in range(10):
with warnings.catch_warnings(record=True) as w:
......
......@@ -10,7 +10,7 @@ from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.envs.line_generators import random_line_generator
from flatland.utils.simple_rail import make_simple_rail2
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
......@@ -77,7 +77,7 @@ def test_malfunction_process():
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
line_generator=random_line_generator(),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=SingleAgentNavigationObs()
......@@ -131,7 +131,7 @@ def test_malfunction_process_statistically():
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
line_generator=random_line_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=SingleAgentNavigationObs()
......@@ -178,7 +178,7 @@ def test_malfunction_before_entry():
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
line_generator=random_line_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=SingleAgentNavigationObs()
......@@ -222,7 +222,7 @@ def test_malfunction_values_and_behavior():
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
line_generator=random_line_generator(),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=SingleAgentNavigationObs()
......@@ -251,7 +251,7 @@ def test_initial_malfunction():
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(seed=10),
line_generator=random_line_generator(seed=10),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
# Malfunction data generator
......@@ -316,7 +316,7 @@ def test_initial_malfunction_stop_moving():
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=1,
line_generator=random_line_generator(), number_of_agents=1,
obs_builder_object=SingleAgentNavigationObs())
env.reset()
......@@ -400,7 +400,7 @@ def test_initial_malfunction_do_nothing():
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
line_generator=random_line_generator(),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
# Malfunction data generator
......@@ -477,7 +477,7 @@ def tests_random_interference_from_outside():
# Set fixed malfunction duration for this test
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, random_seed=1)
line_generator=random_line_generator(seed=2), number_of_agents=1, random_seed=1)
env.reset()
env.agents[0].speed_data['speed'] = 0.33
env.reset(False, False, False, random_seed=10)
......@@ -501,7 +501,7 @@ def tests_random_interference_from_outside():
random.seed(47)
np.random.seed(1234)
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, random_seed=1)
line_generator=random_line_generator(seed=2), number_of_agents=1, random_seed=1)
env.reset()
env.agents[0].speed_data['speed'] = 0.33
env.reset(False, False, False, random_seed=10)
......@@ -533,7 +533,7 @@ def test_last_malfunction_step():
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, random_seed=1)
line_generator=random_line_generator(seed=2), number_of_agents=1, random_seed=1)
env.reset()
env.agents[0].speed_data['speed'] = 1. / 3.
env.agents[0].target = (0, 0)
......
......@@ -6,7 +6,7 @@ from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.envs.line_generators import random_line_generator
from flatland.utils.simple_rail import make_simple_rail
"""Tests for `flatland` package."""
......@@ -19,7 +19,7 @@ def test_multiprocessing_tree_obs():
obs_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(), number_of_agents=number_of_agents,
line_generator=random_line_generator(), number_of_agents=number_of_agents,
obs_builder_object=obs_builder)
env.reset(True, True)
......
......@@ -3,11 +3,11 @@ from test_utils import create_and_save_env
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator, random_rail_generator, complex_rail_generator, \
rail_from_file
from flatland.envs.schedule_generators import sparse_schedule_generator, random_schedule_generator, \
complex_schedule_generator, schedule_from_file