diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 22815f33b2ab0df490cddce489c572b905a5e555..364a00db413b32a276e896c109b48a56bdde1d46 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -106,7 +106,7 @@ class RailEnv(Environment): malfunction_generator_and_process_data=None, # mal_gen.no_malfunction_generator(), malfunction_generator=None, remove_agents_at_target=True, - random_seed=1, + random_seed=None, record_steps=False, ): """ @@ -161,7 +161,6 @@ class RailEnv(Environment): self.number_of_agents = number_of_agents - # self.rail_generator: RailGenerator = rail_generator if rail_generator is None: rail_generator = rail_gen.sparse_rail_generator() self.rail_generator = rail_generator @@ -193,9 +192,7 @@ class RailEnv(Environment): self.action_space = [5] self._seed() - self._seed() - self.random_seed = random_seed - if self.random_seed: + if random_seed: self._seed(seed=random_seed) self.agent_positions = None @@ -211,6 +208,14 @@ class RailEnv(Environment): def _seed(self, seed=None): self.np_random, seed = seeding.np_random(seed) random.seed(seed) + self.random_seed = seed + + # Keep track of all the seeds in order + if not hasattr(self, 'seed_history'): + self.seed_history = [seed] + if self.seed_history[-1] != seed: + self.seed_history.append(seed) + return [seed] # no more agent_handles @@ -252,7 +257,7 @@ class RailEnv(Environment): ( agent.state.is_on_map_state() and agent.speed_counter.is_cell_entry ) def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *, - random_seed: bool = None) -> Tuple[Dict, Dict]: + random_seed: int = None) -> Tuple[Dict, Dict]: """ reset(regenerate_rail, regenerate_schedule, activate_agents, random_seed) @@ -264,7 +269,7 @@ class RailEnv(Environment): regenerate the rails regenerate_schedule : bool, optional regenerate the schedule and the static agents - random_seed : bool, optional + random_seed : int, optional random seed for environment Returns @@ -355,10 +360,10 @@ class RailEnv(Environment): """ Update the agent_positions array for agents that changed positions """ for agent in self.agents: if not ignore_old_positions or agent.old_position != agent.position: - self.agent_positions[agent.position] = agent.handle + if agent.position is not None: + self.agent_positions[agent.position] = agent.handle if agent.old_position is not None: self.agent_positions[agent.old_position] = -1 - def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed): """ Generate State Transitions Signals used in the state machine """ @@ -597,7 +602,7 @@ class RailEnv(Environment): # Check if episode has ended and update rewards and dones self.end_of_episode_update(have_all_agents_ended) - self._update_agent_positions_map + self._update_agent_positions_map() return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict() diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 356bfd1e00dba35e10e16815d3a306077f9acf6f..457574caafc61f5b4e5c9dd63cfb9fe0763d2283 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -163,7 +163,7 @@ def sparse_rail_generator(*args, **kwargs): class SparseRailGen(RailGen): def __init__(self, max_num_cities: int = 2, grid_mode: bool = False, max_rails_between_cities: int = 2, - max_rail_pairs_in_city: int = 2, seed=0) -> RailGenerator: + max_rail_pairs_in_city: int = 2, seed=None) -> RailGenerator: """ Generates railway networks with cities and inner city rails @@ -189,7 +189,7 @@ class SparseRailGen(RailGen): self.grid_mode = grid_mode self.max_rails_between_cities = max_rails_between_cities self.max_rail_pairs_in_city = max_rail_pairs_in_city - self.seed = seed # TODO: seed in constructor or generate? + self.seed = seed def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, @@ -217,8 +217,10 @@ class SparseRailGen(RailGen): 'train_stations': locations of train stations for start and targets 'city_orientations' : orientation of cities """ - if np_random is None: + if self.seed is not None: np_random = RandomState(self.seed) + elif np_random is None: + np_random = RandomState(np.random.randint(2**32)) rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index 0d21463d933a3baf70bfb55cdd8719268a97862a..92fbdf0a325934abefd98adaf9c32fd9ecf6cb5f 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -182,7 +182,7 @@ def test_reward_function_waiting(rendering=False): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - remove_agents_at_target=False) + remove_agents_at_target=False, random_seed=1) obs_builder: TreeObsForRailEnv = env.obs_builder env.reset() diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index d98b4b32ad55b739827a5736d9ea8860771583a1..4da868f48ebe0a96c2cc23bc7362e5cd341047b0 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -14,11 +14,12 @@ from flatland.utils.rendertools import RenderTool def test_sparse_rail_generator(): env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(max_num_cities=10, max_rails_between_cities=3, - seed=5, + seed=1, grid_mode=False ), line_generator=sparse_line_generator(), number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv()) + obs_builder_object=GlobalObsForRailEnv(), + random_seed=1) env.reset(False, False) # for r in range(env.height): # for c in range(env.width): @@ -499,8 +500,8 @@ def test_sparse_rail_generator(): for a in range(env.get_num_agents()): s0 = Vec2d.get_manhattan_distance(env.agents[a].initial_position, (0, 0)) s1 = Vec2d.get_chebyshev_distance(env.agents[a].initial_position, (0, 0)) - assert s0 == 44, "actual={}".format(s0) - assert s1 == 34, "actual={}".format(s1) + assert s0 == 46, "actual={}".format(s0) + assert s1 == 26, "actual={}".format(s1) def test_sparse_rail_generator_deterministic(): @@ -516,14 +517,13 @@ def test_sparse_rail_generator_deterministic(): seed=215545, # Random seed grid_mode=True ), - line_generator=sparse_line_generator(speed_ration_map), number_of_agents=1) + line_generator=sparse_line_generator(speed_ration_map), number_of_agents=1, random_seed=1) env.reset() # for r in range(env.height): # for c in range(env.width): # print("assert env.rail.get_full_transitions({}, {}) == {}, \"[{}][{}]\"".format(r, c, # env.rail.get_full_transitions( # r, c), r, c)) - assert env.rail.get_full_transitions(0, 0) == 0, "[0][0]" assert env.rail.get_full_transitions(0, 1) == 0, "[0][1]" assert env.rail.get_full_transitions(0, 2) == 0, "[0][2]" assert env.rail.get_full_transitions(0, 3) == 0, "[0][3]" @@ -561,15 +561,15 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(1, 10) == 0, "[1][10]" assert env.rail.get_full_transitions(1, 11) == 16386, "[1][11]" assert env.rail.get_full_transitions(1, 12) == 1025, "[1][12]" - assert env.rail.get_full_transitions(1, 13) == 1025, "[1][13]" + assert env.rail.get_full_transitions(1, 13) == 17411, "[1][13]" assert env.rail.get_full_transitions(1, 14) == 17411, "[1][14]" assert env.rail.get_full_transitions(1, 15) == 1025, "[1][15]" assert env.rail.get_full_transitions(1, 16) == 1025, "[1][16]" assert env.rail.get_full_transitions(1, 17) == 1025, "[1][17]" assert env.rail.get_full_transitions(1, 18) == 1025, "[1][18]" - assert env.rail.get_full_transitions(1, 19) == 4608, "[1][19]" - assert env.rail.get_full_transitions(1, 20) == 0, "[1][20]" - assert env.rail.get_full_transitions(1, 21) == 0, "[1][21]" + assert env.rail.get_full_transitions(1, 19) == 5633, "[1][19]" + assert env.rail.get_full_transitions(1, 20) == 5633, "[1][20]" + assert env.rail.get_full_transitions(1, 21) == 4608, "[1][21]" assert env.rail.get_full_transitions(1, 22) == 0, "[1][22]" assert env.rail.get_full_transitions(1, 23) == 0, "[1][23]" assert env.rail.get_full_transitions(1, 24) == 0, "[1][24]" @@ -585,16 +585,16 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(2, 9) == 0, "[2][9]" assert env.rail.get_full_transitions(2, 10) == 0, "[2][10]" assert env.rail.get_full_transitions(2, 11) == 32800, "[2][11]" - assert env.rail.get_full_transitions(2, 12) == 0, "[2][12]" - assert env.rail.get_full_transitions(2, 13) == 0, "[2][13]" + assert env.rail.get_full_transitions(2, 12) == 16386, "[2][12]" + assert env.rail.get_full_transitions(2, 13) == 34864, "[2][13]" assert env.rail.get_full_transitions(2, 14) == 32800, "[2][14]" assert env.rail.get_full_transitions(2, 15) == 0, "[2][15]" assert env.rail.get_full_transitions(2, 16) == 0, "[2][16]" assert env.rail.get_full_transitions(2, 17) == 0, "[2][17]" assert env.rail.get_full_transitions(2, 18) == 0, "[2][18]" assert env.rail.get_full_transitions(2, 19) == 32800, "[2][19]" - assert env.rail.get_full_transitions(2, 20) == 0, "[2][20]" - assert env.rail.get_full_transitions(2, 21) == 0, "[2][21]" + assert env.rail.get_full_transitions(2, 20) == 32800, "[2][20]" + assert env.rail.get_full_transitions(2, 21) == 32800, "[2][21]" assert env.rail.get_full_transitions(2, 22) == 0, "[2][22]" assert env.rail.get_full_transitions(2, 23) == 0, "[2][23]" assert env.rail.get_full_transitions(2, 24) == 0, "[2][24]" @@ -610,16 +610,16 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(3, 9) == 0, "[3][9]" assert env.rail.get_full_transitions(3, 10) == 0, "[3][10]" assert env.rail.get_full_transitions(3, 11) == 32800, "[3][11]" - assert env.rail.get_full_transitions(3, 12) == 0, "[3][12]" - assert env.rail.get_full_transitions(3, 13) == 0, "[3][13]" + assert env.rail.get_full_transitions(3, 12) == 32800, "[3][12]" + assert env.rail.get_full_transitions(3, 13) == 32800, "[3][13]" assert env.rail.get_full_transitions(3, 14) == 32800, "[3][14]" assert env.rail.get_full_transitions(3, 15) == 0, "[3][15]" assert env.rail.get_full_transitions(3, 16) == 0, "[3][16]" assert env.rail.get_full_transitions(3, 17) == 0, "[3][17]" assert env.rail.get_full_transitions(3, 18) == 0, "[3][18]" - assert env.rail.get_full_transitions(3, 19) == 32872, "[3][19]" - assert env.rail.get_full_transitions(3, 20) == 4608, "[3][20]" - assert env.rail.get_full_transitions(3, 21) == 0, "[3][21]" + assert env.rail.get_full_transitions(3, 19) == 32800, "[3][19]" + assert env.rail.get_full_transitions(3, 20) == 32872, "[3][20]" + assert env.rail.get_full_transitions(3, 21) == 37408, "[3][21]" assert env.rail.get_full_transitions(3, 22) == 0, "[3][22]" assert env.rail.get_full_transitions(3, 23) == 0, "[3][23]" assert env.rail.get_full_transitions(3, 24) == 0, "[3][24]" @@ -635,16 +635,16 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(4, 9) == 0, "[4][9]" assert env.rail.get_full_transitions(4, 10) == 0, "[4][10]" assert env.rail.get_full_transitions(4, 11) == 32800, "[4][11]" - assert env.rail.get_full_transitions(4, 12) == 0, "[4][12]" - assert env.rail.get_full_transitions(4, 13) == 0, "[4][13]" + assert env.rail.get_full_transitions(4, 12) == 32800, "[4][12]" + assert env.rail.get_full_transitions(4, 13) == 32800, "[4][13]" assert env.rail.get_full_transitions(4, 14) == 32800, "[4][14]" assert env.rail.get_full_transitions(4, 15) == 0, "[4][15]" assert env.rail.get_full_transitions(4, 16) == 0, "[4][16]" assert env.rail.get_full_transitions(4, 17) == 0, "[4][17]" assert env.rail.get_full_transitions(4, 18) == 0, "[4][18]" - assert env.rail.get_full_transitions(4, 19) == 49186, "[4][19]" - assert env.rail.get_full_transitions(4, 20) == 34864, "[4][20]" - assert env.rail.get_full_transitions(4, 21) == 0, "[4][21]" + assert env.rail.get_full_transitions(4, 19) == 32800, "[4][19]" + assert env.rail.get_full_transitions(4, 20) == 32800, "[4][20]" + assert env.rail.get_full_transitions(4, 21) == 32800, "[4][21]" assert env.rail.get_full_transitions(4, 22) == 0, "[4][22]" assert env.rail.get_full_transitions(4, 23) == 0, "[4][23]" assert env.rail.get_full_transitions(4, 24) == 0, "[4][24]" @@ -659,18 +659,18 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(5, 8) == 0, "[5][8]" assert env.rail.get_full_transitions(5, 9) == 0, "[5][9]" assert env.rail.get_full_transitions(5, 10) == 0, "[5][10]" - assert env.rail.get_full_transitions(5, 11) == 32800, "[5][11]" - assert env.rail.get_full_transitions(5, 12) == 0, "[5][12]" - assert env.rail.get_full_transitions(5, 13) == 0, "[5][13]" + assert env.rail.get_full_transitions(5, 11) == 49186, "[5][11]" + assert env.rail.get_full_transitions(5, 12) == 3089, "[5][12]" + assert env.rail.get_full_transitions(5, 13) == 2064, "[5][13]" assert env.rail.get_full_transitions(5, 14) == 32800, "[5][14]" assert env.rail.get_full_transitions(5, 15) == 0, "[5][15]" assert env.rail.get_full_transitions(5, 16) == 0, "[5][16]" assert env.rail.get_full_transitions(5, 17) == 0, "[5][17]" assert env.rail.get_full_transitions(5, 18) == 0, "[5][18]" - assert env.rail.get_full_transitions(5, 19) == 32800, "[5][19]" - assert env.rail.get_full_transitions(5, 20) == 32800, "[5][20]" - assert env.rail.get_full_transitions(5, 21) == 0, "[5][21]" - assert env.rail.get_full_transitions(5, 22) == 0, "[5][22]" + assert env.rail.get_full_transitions(5, 19) == 49186, "[5][19]" + assert env.rail.get_full_transitions(5, 20) == 34864, "[5][20]" + assert env.rail.get_full_transitions(5, 21) == 32872, "[5][21]" + assert env.rail.get_full_transitions(5, 22) == 4608, "[5][22]" assert env.rail.get_full_transitions(5, 23) == 0, "[5][23]" assert env.rail.get_full_transitions(5, 24) == 0, "[5][24]" assert env.rail.get_full_transitions(6, 0) == 16386, "[6][0]" @@ -694,8 +694,8 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(6, 18) == 0, "[6][18]" assert env.rail.get_full_transitions(6, 19) == 32800, "[6][19]" assert env.rail.get_full_transitions(6, 20) == 32800, "[6][20]" - assert env.rail.get_full_transitions(6, 21) == 0, "[6][21]" - assert env.rail.get_full_transitions(6, 22) == 0, "[6][22]" + assert env.rail.get_full_transitions(6, 21) == 32800, "[6][21]" + assert env.rail.get_full_transitions(6, 22) == 32800, "[6][22]" assert env.rail.get_full_transitions(6, 23) == 0, "[6][23]" assert env.rail.get_full_transitions(6, 24) == 0, "[6][24]" assert env.rail.get_full_transitions(7, 0) == 32800, "[7][0]" @@ -717,10 +717,10 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(7, 16) == 0, "[7][16]" assert env.rail.get_full_transitions(7, 17) == 0, "[7][17]" assert env.rail.get_full_transitions(7, 18) == 0, "[7][18]" - assert env.rail.get_full_transitions(7, 19) == 32800, "[7][19]" - assert env.rail.get_full_transitions(7, 20) == 32800, "[7][20]" - assert env.rail.get_full_transitions(7, 21) == 0, "[7][21]" - assert env.rail.get_full_transitions(7, 22) == 0, "[7][22]" + assert env.rail.get_full_transitions(7, 19) == 32872, "[7][19]" + assert env.rail.get_full_transitions(7, 20) == 37408, "[7][20]" + assert env.rail.get_full_transitions(7, 21) == 49186, "[7][21]" + assert env.rail.get_full_transitions(7, 22) == 2064, "[7][22]" assert env.rail.get_full_transitions(7, 23) == 0, "[7][23]" assert env.rail.get_full_transitions(7, 24) == 0, "[7][24]" assert env.rail.get_full_transitions(8, 0) == 32800, "[8][0]" @@ -742,9 +742,9 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(8, 16) == 0, "[8][16]" assert env.rail.get_full_transitions(8, 17) == 0, "[8][17]" assert env.rail.get_full_transitions(8, 18) == 0, "[8][18]" - assert env.rail.get_full_transitions(8, 19) == 32872, "[8][19]" - assert env.rail.get_full_transitions(8, 20) == 37408, "[8][20]" - assert env.rail.get_full_transitions(8, 21) == 0, "[8][21]" + assert env.rail.get_full_transitions(8, 19) == 32800, "[8][19]" + assert env.rail.get_full_transitions(8, 20) == 32800, "[8][20]" + assert env.rail.get_full_transitions(8, 21) == 32800, "[8][21]" assert env.rail.get_full_transitions(8, 22) == 0, "[8][22]" assert env.rail.get_full_transitions(8, 23) == 0, "[8][23]" assert env.rail.get_full_transitions(8, 24) == 0, "[8][24]" @@ -767,9 +767,9 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(9, 16) == 0, "[9][16]" assert env.rail.get_full_transitions(9, 17) == 0, "[9][17]" assert env.rail.get_full_transitions(9, 18) == 0, "[9][18]" - assert env.rail.get_full_transitions(9, 19) == 49186, "[9][19]" - assert env.rail.get_full_transitions(9, 20) == 2064, "[9][20]" - assert env.rail.get_full_transitions(9, 21) == 0, "[9][21]" + assert env.rail.get_full_transitions(9, 19) == 32800, "[9][19]" + assert env.rail.get_full_transitions(9, 20) == 49186, "[9][20]" + assert env.rail.get_full_transitions(9, 21) == 34864, "[9][21]" assert env.rail.get_full_transitions(9, 22) == 0, "[9][22]" assert env.rail.get_full_transitions(9, 23) == 0, "[9][23]" assert env.rail.get_full_transitions(9, 24) == 0, "[9][24]" @@ -793,8 +793,8 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(10, 17) == 0, "[10][17]" assert env.rail.get_full_transitions(10, 18) == 0, "[10][18]" assert env.rail.get_full_transitions(10, 19) == 32800, "[10][19]" - assert env.rail.get_full_transitions(10, 20) == 0, "[10][20]" - assert env.rail.get_full_transitions(10, 21) == 0, "[10][21]" + assert env.rail.get_full_transitions(10, 20) == 32800, "[10][20]" + assert env.rail.get_full_transitions(10, 21) == 32800, "[10][21]" assert env.rail.get_full_transitions(10, 22) == 0, "[10][22]" assert env.rail.get_full_transitions(10, 23) == 0, "[10][23]" assert env.rail.get_full_transitions(10, 24) == 0, "[10][24]" @@ -817,9 +817,9 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(11, 16) == 0, "[11][16]" assert env.rail.get_full_transitions(11, 17) == 0, "[11][17]" assert env.rail.get_full_transitions(11, 18) == 0, "[11][18]" - assert env.rail.get_full_transitions(11, 19) == 32872, "[11][19]" - assert env.rail.get_full_transitions(11, 20) == 5633, "[11][20]" - assert env.rail.get_full_transitions(11, 21) == 4608, "[11][21]" + assert env.rail.get_full_transitions(11, 19) == 32800, "[11][19]" + assert env.rail.get_full_transitions(11, 20) == 32800, "[11][20]" + assert env.rail.get_full_transitions(11, 21) == 32800, "[11][21]" assert env.rail.get_full_transitions(11, 22) == 0, "[11][22]" assert env.rail.get_full_transitions(11, 23) == 0, "[11][23]" assert env.rail.get_full_transitions(11, 24) == 0, "[11][24]" @@ -1017,9 +1017,9 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(19, 16) == 1025, "[19][16]" assert env.rail.get_full_transitions(19, 17) == 1025, "[19][17]" assert env.rail.get_full_transitions(19, 18) == 1025, "[19][18]" - assert env.rail.get_full_transitions(19, 19) == 37408, "[19][19]" - assert env.rail.get_full_transitions(19, 20) == 32800, "[19][20]" - assert env.rail.get_full_transitions(19, 21) == 32800, "[19][21]" + assert env.rail.get_full_transitions(19, 19) == 38505, "[19][19]" + assert env.rail.get_full_transitions(19, 20) == 3089, "[19][20]" + assert env.rail.get_full_transitions(19, 21) == 2064, "[19][21]" assert env.rail.get_full_transitions(19, 22) == 0, "[19][22]" assert env.rail.get_full_transitions(19, 23) == 0, "[19][23]" assert env.rail.get_full_transitions(19, 24) == 0, "[19][24]" @@ -1043,8 +1043,8 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(20, 17) == 0, "[20][17]" assert env.rail.get_full_transitions(20, 18) == 0, "[20][18]" assert env.rail.get_full_transitions(20, 19) == 32800, "[20][19]" - assert env.rail.get_full_transitions(20, 20) == 32800, "[20][20]" - assert env.rail.get_full_transitions(20, 21) == 32800, "[20][21]" + assert env.rail.get_full_transitions(20, 20) == 0, "[20][20]" + assert env.rail.get_full_transitions(20, 21) == 0, "[20][21]" assert env.rail.get_full_transitions(20, 22) == 0, "[20][22]" assert env.rail.get_full_transitions(20, 23) == 0, "[20][23]" assert env.rail.get_full_transitions(20, 24) == 0, "[20][24]" @@ -1067,9 +1067,9 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(21, 16) == 0, "[21][16]" assert env.rail.get_full_transitions(21, 17) == 0, "[21][17]" assert env.rail.get_full_transitions(21, 18) == 0, "[21][18]" - assert env.rail.get_full_transitions(21, 19) == 32800, "[21][19]" - assert env.rail.get_full_transitions(21, 20) == 32872, "[21][20]" - assert env.rail.get_full_transitions(21, 21) == 37408, "[21][21]" + assert env.rail.get_full_transitions(21, 19) == 32872, "[21][19]" + assert env.rail.get_full_transitions(21, 20) == 4608, "[21][20]" + assert env.rail.get_full_transitions(21, 21) == 0, "[21][21]" assert env.rail.get_full_transitions(21, 22) == 0, "[21][22]" assert env.rail.get_full_transitions(21, 23) == 0, "[21][23]" assert env.rail.get_full_transitions(21, 24) == 0, "[21][24]" @@ -1092,9 +1092,9 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(22, 16) == 0, "[22][16]" assert env.rail.get_full_transitions(22, 17) == 0, "[22][17]" assert env.rail.get_full_transitions(22, 18) == 0, "[22][18]" - assert env.rail.get_full_transitions(22, 19) == 32800, "[22][19]" - assert env.rail.get_full_transitions(22, 20) == 32800, "[22][20]" - assert env.rail.get_full_transitions(22, 21) == 32800, "[22][21]" + assert env.rail.get_full_transitions(22, 19) == 49186, "[22][19]" + assert env.rail.get_full_transitions(22, 20) == 34864, "[22][20]" + assert env.rail.get_full_transitions(22, 21) == 0, "[22][21]" assert env.rail.get_full_transitions(22, 22) == 0, "[22][22]" assert env.rail.get_full_transitions(22, 23) == 0, "[22][23]" assert env.rail.get_full_transitions(22, 24) == 0, "[22][24]" @@ -1103,9 +1103,9 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(23, 2) == 0, "[23][2]" assert env.rail.get_full_transitions(23, 3) == 0, "[23][3]" assert env.rail.get_full_transitions(23, 4) == 0, "[23][4]" - assert env.rail.get_full_transitions(23, 5) == 0, "[23][5]" - assert env.rail.get_full_transitions(23, 6) == 0, "[23][6]" - assert env.rail.get_full_transitions(23, 7) == 0, "[23][7]" + assert env.rail.get_full_transitions(23, 5) == 16386, "[23][5]" + assert env.rail.get_full_transitions(23, 6) == 1025, "[23][6]" + assert env.rail.get_full_transitions(23, 7) == 4608, "[23][7]" assert env.rail.get_full_transitions(23, 8) == 0, "[23][8]" assert env.rail.get_full_transitions(23, 9) == 0, "[23][9]" assert env.rail.get_full_transitions(23, 10) == 0, "[23][10]" @@ -1116,11 +1116,11 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(23, 15) == 0, "[23][15]" assert env.rail.get_full_transitions(23, 16) == 0, "[23][16]" assert env.rail.get_full_transitions(23, 17) == 0, "[23][17]" - assert env.rail.get_full_transitions(23, 18) == 0, "[23][18]" - assert env.rail.get_full_transitions(23, 19) == 49186, "[23][19]" - assert env.rail.get_full_transitions(23, 20) == 34864, "[23][20]" - assert env.rail.get_full_transitions(23, 21) == 32872, "[23][21]" - assert env.rail.get_full_transitions(23, 22) == 4608, "[23][22]" + assert env.rail.get_full_transitions(23, 18) == 16386, "[23][18]" + assert env.rail.get_full_transitions(23, 19) == 34864, "[23][19]" + assert env.rail.get_full_transitions(23, 20) == 32872, "[23][20]" + assert env.rail.get_full_transitions(23, 21) == 4608, "[23][21]" + assert env.rail.get_full_transitions(23, 22) == 0, "[23][22]" assert env.rail.get_full_transitions(23, 23) == 0, "[23][23]" assert env.rail.get_full_transitions(23, 24) == 0, "[23][24]" assert env.rail.get_full_transitions(24, 0) == 0, "[24][0]" @@ -1128,9 +1128,9 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(24, 2) == 1025, "[24][2]" assert env.rail.get_full_transitions(24, 3) == 5633, "[24][3]" assert env.rail.get_full_transitions(24, 4) == 17411, "[24][4]" - assert env.rail.get_full_transitions(24, 5) == 1025, "[24][5]" + assert env.rail.get_full_transitions(24, 5) == 3089, "[24][5]" assert env.rail.get_full_transitions(24, 6) == 1025, "[24][6]" - assert env.rail.get_full_transitions(24, 7) == 1025, "[24][7]" + assert env.rail.get_full_transitions(24, 7) == 1097, "[24][7]" assert env.rail.get_full_transitions(24, 8) == 5633, "[24][8]" assert env.rail.get_full_transitions(24, 9) == 17411, "[24][9]" assert env.rail.get_full_transitions(24, 10) == 1025, "[24][10]" @@ -1141,11 +1141,11 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(24, 15) == 0, "[24][15]" assert env.rail.get_full_transitions(24, 16) == 0, "[24][16]" assert env.rail.get_full_transitions(24, 17) == 0, "[24][17]" - assert env.rail.get_full_transitions(24, 18) == 0, "[24][18]" + assert env.rail.get_full_transitions(24, 18) == 32800, "[24][18]" assert env.rail.get_full_transitions(24, 19) == 32800, "[24][19]" assert env.rail.get_full_transitions(24, 20) == 32800, "[24][20]" assert env.rail.get_full_transitions(24, 21) == 32800, "[24][21]" - assert env.rail.get_full_transitions(24, 22) == 32800, "[24][22]" + assert env.rail.get_full_transitions(24, 22) == 0, "[24][22]" assert env.rail.get_full_transitions(24, 23) == 0, "[24][23]" assert env.rail.get_full_transitions(24, 24) == 0, "[24][24]" assert env.rail.get_full_transitions(25, 0) == 0, "[25][0]" @@ -1153,24 +1153,24 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(25, 2) == 0, "[25][2]" assert env.rail.get_full_transitions(25, 3) == 72, "[25][3]" assert env.rail.get_full_transitions(25, 4) == 3089, "[25][4]" - assert env.rail.get_full_transitions(25, 5) == 1025, "[25][5]" + assert env.rail.get_full_transitions(25, 5) == 5633, "[25][5]" assert env.rail.get_full_transitions(25, 6) == 1025, "[25][6]" - assert env.rail.get_full_transitions(25, 7) == 1025, "[25][7]" + assert env.rail.get_full_transitions(25, 7) == 17411, "[25][7]" assert env.rail.get_full_transitions(25, 8) == 1097, "[25][8]" assert env.rail.get_full_transitions(25, 9) == 2064, "[25][9]" assert env.rail.get_full_transitions(25, 10) == 0, "[25][10]" - assert env.rail.get_full_transitions(25, 11) == 32872, "[25][11]" - assert env.rail.get_full_transitions(25, 12) == 5633, "[25][12]" - assert env.rail.get_full_transitions(25, 13) == 4608, "[25][13]" + assert env.rail.get_full_transitions(25, 11) == 32800, "[25][11]" + assert env.rail.get_full_transitions(25, 12) == 0, "[25][12]" + assert env.rail.get_full_transitions(25, 13) == 0, "[25][13]" assert env.rail.get_full_transitions(25, 14) == 0, "[25][14]" assert env.rail.get_full_transitions(25, 15) == 0, "[25][15]" assert env.rail.get_full_transitions(25, 16) == 0, "[25][16]" assert env.rail.get_full_transitions(25, 17) == 0, "[25][17]" - assert env.rail.get_full_transitions(25, 18) == 0, "[25][18]" - assert env.rail.get_full_transitions(25, 19) == 32872, "[25][19]" - assert env.rail.get_full_transitions(25, 20) == 37408, "[25][20]" - assert env.rail.get_full_transitions(25, 21) == 49186, "[25][21]" - assert env.rail.get_full_transitions(25, 22) == 2064, "[25][22]" + assert env.rail.get_full_transitions(25, 18) == 72, "[25][18]" + assert env.rail.get_full_transitions(25, 19) == 37408, "[25][19]" + assert env.rail.get_full_transitions(25, 20) == 49186, "[25][20]" + assert env.rail.get_full_transitions(25, 21) == 2064, "[25][21]" + assert env.rail.get_full_transitions(25, 22) == 0, "[25][22]" assert env.rail.get_full_transitions(25, 23) == 0, "[25][23]" assert env.rail.get_full_transitions(25, 24) == 0, "[25][24]" assert env.rail.get_full_transitions(26, 0) == 0, "[26][0]" @@ -1178,23 +1178,23 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(26, 2) == 0, "[26][2]" assert env.rail.get_full_transitions(26, 3) == 0, "[26][3]" assert env.rail.get_full_transitions(26, 4) == 0, "[26][4]" - assert env.rail.get_full_transitions(26, 5) == 0, "[26][5]" - assert env.rail.get_full_transitions(26, 6) == 0, "[26][6]" - assert env.rail.get_full_transitions(26, 7) == 0, "[26][7]" + assert env.rail.get_full_transitions(26, 5) == 72, "[26][5]" + assert env.rail.get_full_transitions(26, 6) == 1025, "[26][6]" + assert env.rail.get_full_transitions(26, 7) == 2064, "[26][7]" assert env.rail.get_full_transitions(26, 8) == 0, "[26][8]" assert env.rail.get_full_transitions(26, 9) == 0, "[26][9]" assert env.rail.get_full_transitions(26, 10) == 0, "[26][10]" assert env.rail.get_full_transitions(26, 11) == 32800, "[26][11]" - assert env.rail.get_full_transitions(26, 12) == 32800, "[26][12]" - assert env.rail.get_full_transitions(26, 13) == 32800, "[26][13]" + assert env.rail.get_full_transitions(26, 12) == 0, "[26][12]" + assert env.rail.get_full_transitions(26, 13) == 0, "[26][13]" assert env.rail.get_full_transitions(26, 14) == 0, "[26][14]" assert env.rail.get_full_transitions(26, 15) == 0, "[26][15]" assert env.rail.get_full_transitions(26, 16) == 0, "[26][16]" assert env.rail.get_full_transitions(26, 17) == 0, "[26][17]" assert env.rail.get_full_transitions(26, 18) == 0, "[26][18]" - assert env.rail.get_full_transitions(26, 19) == 32800, "[26][19]" - assert env.rail.get_full_transitions(26, 20) == 32800, "[26][20]" - assert env.rail.get_full_transitions(26, 21) == 32800, "[26][21]" + assert env.rail.get_full_transitions(26, 19) == 32872, "[26][19]" + assert env.rail.get_full_transitions(26, 20) == 37408, "[26][20]" + assert env.rail.get_full_transitions(26, 21) == 0, "[26][21]" assert env.rail.get_full_transitions(26, 22) == 0, "[26][22]" assert env.rail.get_full_transitions(26, 23) == 0, "[26][23]" assert env.rail.get_full_transitions(26, 24) == 0, "[26][24]" @@ -1210,16 +1210,16 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(27, 9) == 0, "[27][9]" assert env.rail.get_full_transitions(27, 10) == 0, "[27][10]" assert env.rail.get_full_transitions(27, 11) == 32800, "[27][11]" - assert env.rail.get_full_transitions(27, 12) == 32800, "[27][12]" - assert env.rail.get_full_transitions(27, 13) == 72, "[27][13]" - assert env.rail.get_full_transitions(27, 14) == 4608, "[27][14]" + assert env.rail.get_full_transitions(27, 12) == 0, "[27][12]" + assert env.rail.get_full_transitions(27, 13) == 0, "[27][13]" + assert env.rail.get_full_transitions(27, 14) == 0, "[27][14]" assert env.rail.get_full_transitions(27, 15) == 0, "[27][15]" assert env.rail.get_full_transitions(27, 16) == 0, "[27][16]" assert env.rail.get_full_transitions(27, 17) == 0, "[27][17]" assert env.rail.get_full_transitions(27, 18) == 0, "[27][18]" - assert env.rail.get_full_transitions(27, 19) == 32800, "[27][19]" - assert env.rail.get_full_transitions(27, 20) == 49186, "[27][20]" - assert env.rail.get_full_transitions(27, 21) == 34864, "[27][21]" + assert env.rail.get_full_transitions(27, 19) == 49186, "[27][19]" + assert env.rail.get_full_transitions(27, 20) == 2064, "[27][20]" + assert env.rail.get_full_transitions(27, 21) == 0, "[27][21]" assert env.rail.get_full_transitions(27, 22) == 0, "[27][22]" assert env.rail.get_full_transitions(27, 23) == 0, "[27][23]" assert env.rail.get_full_transitions(27, 24) == 0, "[27][24]" @@ -1235,16 +1235,16 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(28, 9) == 0, "[28][9]" assert env.rail.get_full_transitions(28, 10) == 0, "[28][10]" assert env.rail.get_full_transitions(28, 11) == 32800, "[28][11]" - assert env.rail.get_full_transitions(28, 12) == 72, "[28][12]" - assert env.rail.get_full_transitions(28, 13) == 1025, "[28][13]" - assert env.rail.get_full_transitions(28, 14) == 37408, "[28][14]" + assert env.rail.get_full_transitions(28, 12) == 0, "[28][12]" + assert env.rail.get_full_transitions(28, 13) == 0, "[28][13]" + assert env.rail.get_full_transitions(28, 14) == 0, "[28][14]" assert env.rail.get_full_transitions(28, 15) == 0, "[28][15]" assert env.rail.get_full_transitions(28, 16) == 0, "[28][16]" assert env.rail.get_full_transitions(28, 17) == 0, "[28][17]" assert env.rail.get_full_transitions(28, 18) == 0, "[28][18]" assert env.rail.get_full_transitions(28, 19) == 32800, "[28][19]" - assert env.rail.get_full_transitions(28, 20) == 32800, "[28][20]" - assert env.rail.get_full_transitions(28, 21) == 32800, "[28][21]" + assert env.rail.get_full_transitions(28, 20) == 0, "[28][20]" + assert env.rail.get_full_transitions(28, 21) == 0, "[28][21]" assert env.rail.get_full_transitions(28, 22) == 0, "[28][22]" assert env.rail.get_full_transitions(28, 23) == 0, "[28][23]" assert env.rail.get_full_transitions(28, 24) == 0, "[28][24]" @@ -1262,19 +1262,18 @@ def test_sparse_rail_generator_deterministic(): assert env.rail.get_full_transitions(29, 11) == 72, "[29][11]" assert env.rail.get_full_transitions(29, 12) == 1025, "[29][12]" assert env.rail.get_full_transitions(29, 13) == 1025, "[29][13]" - assert env.rail.get_full_transitions(29, 14) == 1097, "[29][14]" + assert env.rail.get_full_transitions(29, 14) == 1025, "[29][14]" assert env.rail.get_full_transitions(29, 15) == 1025, "[29][15]" assert env.rail.get_full_transitions(29, 16) == 1025, "[29][16]" assert env.rail.get_full_transitions(29, 17) == 1025, "[29][17]" assert env.rail.get_full_transitions(29, 18) == 1025, "[29][18]" - assert env.rail.get_full_transitions(29, 19) == 3089, "[29][19]" - assert env.rail.get_full_transitions(29, 20) == 3089, "[29][20]" - assert env.rail.get_full_transitions(29, 21) == 2064, "[29][21]" + assert env.rail.get_full_transitions(29, 19) == 2064, "[29][19]" + assert env.rail.get_full_transitions(29, 20) == 0, "[29][20]" + assert env.rail.get_full_transitions(29, 21) == 0, "[29][21]" assert env.rail.get_full_transitions(29, 22) == 0, "[29][22]" assert env.rail.get_full_transitions(29, 23) == 0, "[29][23]" assert env.rail.get_full_transitions(29, 24) == 0, "[29][24]" - def test_rail_env_action_required_info(): speed_ration_map = {1.: 0.25, # Fast passenger train 1. / 2.: 0.25, # Fast freight train diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index c517c2c58239b28513991f77592f4730c7fa813b..f809068180e23bfa96b4bbd6d8c647b290a4b039 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -22,13 +22,14 @@ class RandomAgent: def __init__(self, state_size, action_size): self.state_size = state_size self.action_size = action_size + self.np_random = np.random.RandomState(seed=42) def act(self, state): """ :param state: input is the observation of the agent :return: returns an action """ - return np.random.choice([1, 2, 3]) + return self.np_random.choice([1, 2, 3]) def step(self, memories): """ @@ -63,6 +64,7 @@ def test_multi_speed_init(): # Set all the different speeds # Reset environment and get initial observations for all agents env.reset(False, False) + env._max_episode_steps = 1000 for a_idx in range(len(env.agents)): env.agents[a_idx].position = env.agents[a_idx].initial_position @@ -204,7 +206,8 @@ def test_multispeed_actions_no_malfunction_blocking(): rail, rail_map, optionals = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=2, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + random_seed=1) env.reset() set_penalties_for_replay(env) diff --git a/tests/test_random_seeding.py b/tests/test_random_seeding.py index 7ce80ff0d726539e3df1d0b3bdc64a9c40f2fda2..7e1c4b17324101afdfde90925193971c9bd490a0 100644 --- a/tests/test_random_seeding.py +++ b/tests/test_random_seeding.py @@ -166,52 +166,53 @@ def test_reproducability_env(): env = RailEnv(width=25, height=30, rail_generator=sparse_rail_generator(max_num_cities=5, max_rails_between_cities=3, - seed=215545, # Random seed + seed=10, # Random seed grid_mode=True ), line_generator=sparse_line_generator(speed_ration_map), number_of_agents=1) - env.reset(True, True, random_seed=10) + env.reset(True, True, random_seed=1) excpeted_grid = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 16386, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 4608], - [0, 49186, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 37408], - [0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800], - [0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800], - [0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800], - [0, 32872, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16386, 1025, 1025, 1025, 17411, 34864], - [16386, 34864, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 16386, 1025, 1025, 33825, 2064], - [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800, 0, 0, 32800, 0], - [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800, 0, 0, 32800, 0], - [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800, 0, 0, 32800, 0], - [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800, 0, 0, 32800, 0], - [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800, 0, 0, 32800, 0], - [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800, 0, 0, 32800, 0], - [32800, 49186, 2064, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16386, 1025, 1025, 1025, 1025, 38505, 3089, 1025, 1025, 2064, 0], - [32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 0], - [32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 32872, 4608, 0, 0, 0, 0], - [32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 49186, 34864, 0, 0, 0, 0], - [32800, 32800, 0, 0, 0, 16386, 1025, 4608, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 32800, 32800, 0, 0, 0, 0], - [72, 1097, 1025, 5633, 17411, 3089, 1025, 1097, 5633, 17411, 1025, 5633, 1025, 1025, 2064, 0, 0, 0, 0, 32800, 32800, 0, 0, 0, 0], - [0, 0, 0, 72, 3089, 5633, 1025, 17411, 1097, 2064, 0, 32800, 0, 0, 0, 0, 0, 0, 0, 32800, 32800, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 72, 1025, 2064, 0, 0, 0, 32800, 0, 0, 0, 0, 0, 0, 0, 32872, 37408, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 0, 0, 0, 49186, 2064, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 72, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 2064, 0, 0, 0, 0, 0]] + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 16386, 1025, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [16386, 17411, 1025, 5633, 17411, 3089, 1025, 1097, 5633, 17411, 1025, 5633, 1025, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 4608], + [32800, 32800, 0, 72, 3089, 5633, 1025, 17411, 1097, 2064, 0, 72, 1025, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 37408], + [32800, 32800, 0, 0, 0, 72, 1025, 2064, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800], + [32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800], + [32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800], + [32800, 32872, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16386, 34864], + [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800], + [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800], + [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800], + [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800], + [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800], + [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800], + [72, 37408, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800], + [0, 49186, 2064, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 72, 37408], + [0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800], + [0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800], + [0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800], + [0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800], + [0, 32872, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 34864], + [0, 72, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 2064], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] + assert env.rail.grid.tolist() == excpeted_grid # Test that we don't have interference from calling mulitple function outisde env2 = RailEnv(width=25, height=30, rail_generator=sparse_rail_generator(max_num_cities=5, max_rails_between_cities=3, - seed=215545, # Random seed + seed=10, # Random seed grid_mode=True ), line_generator=sparse_line_generator(speed_ration_map), number_of_agents=1) - np.random.seed(10) + np.random.seed(1) for i in range(10): np.random.randn() - env2.reset(True, True, random_seed=10) + env2.reset(True, True, random_seed=1) assert env2.rail.grid.tolist() == excpeted_grid