Commit cf5196a7 authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

remove complex and random line generators

parent 54ec3c5d
......@@ -67,8 +67,8 @@ class SparseLineGen(BaseLineGen):
:param seed: Initiate random seed generator
"""
def generate(self, rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0,
np_random: RandomState = None) -> Line:
def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict, num_resets: int,
np_random: RandomState) -> Line:
"""
The generator that assigns tasks to all the agents
......
......@@ -205,10 +205,10 @@ class RailEnv(Environment):
# self.rail_generator: RailGenerator = rail_generator
if rail_generator is None:
rail_generator = rail_gen.random_rail_generator()
rail_generator = rail_gen.sparse_rail_generator()
self.rail_generator = rail_generator
if line_generator is None:
line_generator = line_gen.random_line_generator()
line_generator = line_gen.sparse_line_generator()
self.line_generator = line_generator
self.rail: Optional[GridTransitionMap] = None
......@@ -381,10 +381,6 @@ class RailEnv(Environment):
agent.earliest_departure = schedule.earliest_departures[agent_i]
agent.latest_arrival = schedule.latest_arrivals[agent_i]
# Reset distance map - again (just in case if regen_schedule = False)
self.distance_map.reset(self.agents, self.rail)
# Agent Positions Map
self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1
......
......@@ -65,7 +65,7 @@ class EmptyRailGen(RailGen):
rail_array.fill(0)
return grid_map, None
def rail_from_manual_specifications_generator(rail_spec):
"""
......@@ -144,38 +144,17 @@ def rail_from_file(filename, load_from_package=None) -> RailGenerator:
return generator
class RailFromGridGen(RailGen):
def __init__(self, rail_map):
def __init__(self, rail_map, optionals=None):
self.rail_map = rail_map
self.optionals = optionals
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
np_random: RandomState = None) -> RailGenerator:
return self.rail_map, None
return self.rail_map, self.optionals
def rail_from_grid_transition_map(rail_map) -> RailGenerator:
return RailFromGridGen(rail_map)
def rail_from_grid_transition_map_old(rail_map) -> RailGenerator:
"""
Utility to convert a rail given by a GridTransitionMap map with the correct
16-bit transitions specifications.
Parameters
----------
rail_map : GridTransitionMap object
GridTransitionMap object to return when the generator is called.
Returns
-------
function
Generator function that always returns the given `rail_map` object.
"""
def generator(width: int, height: int, num_agents: int, num_resets: int = 0,
np_random: RandomState = None) -> RailGenerator:
return rail_map, None
return generator
def rail_from_grid_transition_map(rail_map, optionals=None) -> RailGenerator:
return RailFromGridGen(rail_map, optionals)
def sparse_rail_generator(*args, **kwargs):
......@@ -304,7 +283,6 @@ class SparseRailGen(RailGen):
# Fix all transition elements
self._fix_transitions(city_cells, inter_city_lines, grid_map, vector_field)
return grid_map, {'agents_hints': {
'num_agents': num_agents,
'city_positions': city_positions,
......
......@@ -42,7 +42,19 @@ def make_simple_rail() -> Tuple[GridTransitionMap, np.array]:
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
return rail, rail_map
city_positions = [(0,3), (6, 6)]
train_stations = [
[( (0, 3), 0 ), ( (1, 3), 1 ) ],
[( (6, 6), 0 ), ( (5, 6), 1 ) ],
]
city_orientations = [0, 2]
agents_hints = {'num_agents': 100,
'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
optionals = {'agents_hints': agents_hints}
return rail, rail_map, optionals
def make_disconnected_simple_rail() -> Tuple[GridTransitionMap, np.array]:
......
......@@ -6,18 +6,18 @@ from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.rail_trainrun_data_structures import Waypoint
from flatland.envs.line_generators import random_line_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.utils.simple_rail import make_simple_rail
def test_action_plan(rendering: bool = False):
"""Tests ActionPlanReplayer: does action plan generation and replay work as expected."""
rail, rail_map = make_simple_rail()
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(seed=77),
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=77),
number_of_agents=2,
obs_builder_object=GlobalObsForRailEnv(),
remove_agents_at_target=True
......
......@@ -6,7 +6,7 @@ from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import random_line_generator
from flatland.envs.line_generators import sparse_line_generator
def test_walker():
......@@ -28,7 +28,7 @@ def test_walker():
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(),
line_generator=sparse_line_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2,
predictor=ShortestPathPredictorForRailEnv(max_depth=10)),
......
......@@ -4,7 +4,7 @@ from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import random_line_generator
from flatland.envs.line_generators import rail_from_grid_transition_map
from flatland.utils.simple_rail import make_simple_rail
from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
......@@ -13,7 +13,7 @@ def test_initial_status():
"""Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(), number_of_agents=1,
line_generator=rail_from_grid_transition_map(), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
remove_agents_at_target=False)
env.reset()
......@@ -126,7 +126,7 @@ def test_status_done_remove():
"""Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(), number_of_agents=1,
line_generator=rail_from_grid_transition_map(), number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
remove_agents_at_target=True)
env.reset()
......
......@@ -6,7 +6,7 @@ from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import random_line_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail, make_simple_rail_unconnected
......@@ -70,7 +70,7 @@ def test_path_exists(rendering=False):
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(),
line_generator=sparse_line_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
......@@ -134,7 +134,7 @@ def test_path_not_exists(rendering=False):
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(),
line_generator=sparse_line_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
......
......@@ -10,7 +10,7 @@ from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import random_line_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail
......@@ -21,7 +21,7 @@ def test_global_obs():
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(), number_of_agents=1,
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
global_obs, info = env.reset()
......@@ -93,7 +93,7 @@ def _step_along_shortest_path(env, obs_builder, rail):
def test_reward_function_conflict(rendering=False):
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(), number_of_agents=2,
line_generator=sparse_line_generator(), number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
obs_builder: TreeObsForRailEnv = env.obs_builder
env.reset()
......@@ -181,7 +181,7 @@ def test_reward_function_conflict(rendering=False):
def test_reward_function_waiting(rendering=False):
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(), number_of_agents=2,
line_generator=sparse_line_generator(), number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
remove_agents_at_target=False)
obs_builder: TreeObsForRailEnv = env.obs_builder
......
......@@ -12,7 +12,7 @@ from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.rail_trainrun_data_structures import Waypoint
from flatland.envs.line_generators import random_line_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail
......@@ -25,7 +25,7 @@ def test_dummy_predictor(rendering=False):
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(),
line_generator=sparse_line_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
)
......@@ -116,7 +116,7 @@ def test_shortest_path_predictor(rendering=False):
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(),
line_generator=sparse_line_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
......@@ -247,7 +247,7 @@ def test_shortest_path_predictor_conflicts(rendering=False):
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(),
line_generator=sparse_line_generator(),
number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
......
......@@ -11,7 +11,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import sparse_rail_generator, rail_from_file
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import random_line_generator, sparse_line_generator, line_from_file
from flatland.envs.line_generators import sparse_line_generator, line_from_file
from flatland.utils.simple_rail import make_simple_rail
from flatland.envs.persistence import RailEnvPersister
from flatland.utils.rendertools import RenderTool
......@@ -69,7 +69,7 @@ def test_save_load():
def test_save_load_mpk():
env = RailEnv(width=10, height=10,
rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1),
line_generator=complex_line_generator(), number_of_agents=2)
line_generator=sparse_line_generator(), number_of_agents=2)
env.reset()
os.makedirs("tmp", exist_ok=True)
......@@ -121,7 +121,7 @@ def test_rail_environment_single_agent(show=False):
rail = GridTransitionMap(width=3, height=3, transitions=transitions)
rail.grid = rail_map
rail_env = RailEnv(width=3, height=3, rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(), number_of_agents=1,
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
else:
rail_env, env_dict = RailEnvPersister.load_new("test_env_loop.pkl", "env_data.tests")
......
......@@ -9,7 +9,7 @@ from flatland.envs.rail_env_shortest_paths import get_shortest_paths, get_k_shor
from flatland.envs.rail_env_utils import load_flatland_environment_from_file
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.rail_trainrun_data_structures import Waypoint
from flatland.envs.line_generators import random_line_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_disconnected_simple_rail, make_simple_rail_with_alternatives
from flatland.envs.persistence import RailEnvPersister
......@@ -19,7 +19,7 @@ def test_get_shortest_paths_unreachable():
rail, rail_map = make_disconnected_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(), number_of_agents=1,
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
env.reset()
......@@ -242,7 +242,7 @@ def test_get_k_shortest_paths(rendering=False):
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(),
line_generator=sparse_line_generator(),
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv(),
)
......
......@@ -10,7 +10,7 @@ from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import random_line_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail2
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
......@@ -77,7 +77,7 @@ def test_malfunction_process():
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(),
line_generator=sparse_line_generator(),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=SingleAgentNavigationObs()
......@@ -133,7 +133,7 @@ def test_malfunction_process_statistically():
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(),
line_generator=sparse_line_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=SingleAgentNavigationObs()
......@@ -180,7 +180,7 @@ def test_malfunction_before_entry():
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(),
line_generator=sparse_line_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=SingleAgentNavigationObs()
......@@ -224,7 +224,7 @@ def test_malfunction_values_and_behavior():
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(),
line_generator=sparse_line_generator(),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=SingleAgentNavigationObs()
......@@ -253,7 +253,7 @@ def test_initial_malfunction():
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(seed=10),
line_generator=sparse_line_generator(seed=10),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
# Malfunction data generator
......@@ -318,7 +318,7 @@ def test_initial_malfunction_stop_moving():
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(), number_of_agents=1,
line_generator=sparse_line_generator(), number_of_agents=1,
obs_builder_object=SingleAgentNavigationObs())
env.reset()
......@@ -402,7 +402,7 @@ def test_initial_malfunction_do_nothing():
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(),
line_generator=sparse_line_generator(),
number_of_agents=1,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
# Malfunction data generator
......@@ -479,7 +479,7 @@ def tests_random_interference_from_outside():
# Set fixed malfunction duration for this test
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(seed=2), number_of_agents=1, random_seed=1)
line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
env.reset()
env.agents[0].speed_data['speed'] = 0.33
env.reset(False, False, False, random_seed=10)
......@@ -503,7 +503,7 @@ def tests_random_interference_from_outside():
random.seed(47)
np.random.seed(1234)
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(seed=2), number_of_agents=1, random_seed=1)
line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
env.reset()
env.agents[0].speed_data['speed'] = 0.33
env.reset(False, False, False, random_seed=10)
......@@ -535,7 +535,7 @@ def test_last_malfunction_step():
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(seed=2), number_of_agents=1, random_seed=1)
line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
env.reset()
env.agents[0].speed_data['speed'] = 1. / 3.
env.agents[0].target = (0, 0)
......
......@@ -6,7 +6,7 @@ from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import random_line_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail
"""Tests for `flatland` package."""
......@@ -19,7 +19,7 @@ def test_multiprocessing_tree_obs():
obs_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(), number_of_agents=number_of_agents,
line_generator=sparse_line_generator(), number_of_agents=number_of_agents,
obs_builder_object=obs_builder)
env.reset(True, True)
......
......@@ -2,7 +2,7 @@ from flatland.envs.malfunction_generators import malfunction_from_params, malfun
single_malfunction_generator, MalfunctionParameters
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import random_line_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail2
from flatland.envs.persistence import RailEnvPersister
......@@ -22,7 +22,7 @@ def test_malfanction_from_params():
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(),
line_generator=sparse_line_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data)
)
......@@ -49,7 +49,7 @@ def test_malfanction_to_and_from_file():
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(),
line_generator=sparse_line_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data)
)
......@@ -62,7 +62,7 @@ def test_malfanction_to_and_from_file():
env2 = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(),
line_generator=sparse_line_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data)
)
......@@ -87,7 +87,7 @@ def test_single_malfunction_generator():
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(),
line_generator=sparse_line_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=single_malfunction_generator(earlierst_malfunction=10,
malfunction_duration=5)
......
......@@ -4,7 +4,7 @@ from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map, sparse_rail_generator
from flatland.envs.line_generators import random_line_generator, sparse_line_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail2
......@@ -15,7 +15,7 @@ def test_random_seeding():
# Move target to unreachable position in order to not interfere with test
for idx in range(100):
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(seed=12), number_of_agents=10)
line_generator=sparse_line_generator(seed=12), number_of_agents=10)
env.reset(True, True, False, random_seed=1)
env.agents[0].target = (0, 0)
......@@ -49,11 +49,11 @@ def test_seeding_and_observations():
# Make two seperate envs with different observation builders
# Global Observation
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(seed=12), number_of_agents=10,
line_generator=rail_from_grid_transition_map(seed=12), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
# Tree Observation
env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(seed=12), number_of_agents=10,
line_generator=rail_from_grid_transition_map(seed=12), number_of_agents=10,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.reset(False, False, False, random_seed=12)
......@@ -107,12 +107,12 @@ def test_seeding_and_malfunction():
# Global Observation
for tests in range(1, 100):
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(), number_of_agents=10,
line_generator=rail_from_grid_transition_map(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
# Tree Observation
env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
line_generator=random_line_generator(), number_of_agents=10,
line_generator=rail_from_grid_transition_map(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
env.reset(True, False, True, random_seed=tests)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment