Commit 608a75b5 authored by Dipam Chakraborty
Browse files

fix seeding pipeline

parent e3c821e5
Pipeline #8499 failed with stages
in 6 minutes and 7 seconds
......@@ -106,7 +106,7 @@ class RailEnv(Environment):
malfunction_generator_and_process_data=None, # mal_gen.no_malfunction_generator(),
malfunction_generator=None,
remove_agents_at_target=True,
random_seed=1,
random_seed=None,
record_steps=False,
):
"""
......@@ -161,7 +161,6 @@ class RailEnv(Environment):
self.number_of_agents = number_of_agents
# self.rail_generator: RailGenerator = rail_generator
if rail_generator is None:
rail_generator = rail_gen.sparse_rail_generator()
self.rail_generator = rail_generator
......@@ -193,9 +192,7 @@ class RailEnv(Environment):
self.action_space = [5]
self._seed()
self._seed()
self.random_seed = random_seed
if self.random_seed:
if random_seed:
self._seed(seed=random_seed)
self.agent_positions = None
......@@ -211,6 +208,14 @@ class RailEnv(Environment):
def _seed(self, seed=None):
    """
    (Re)seed the random number generators used by the environment.

    Parameters
    ----------
    seed : optional
        Value handed to ``seeding.np_random``. The seed returned by that
        call is the one that is applied and recorded.

    Returns
    -------
    list
        Single-element list holding the seed now in effect.
    """
    rng, effective_seed = seeding.np_random(seed)
    self.np_random = rng
    random.seed(effective_seed)
    self.random_seed = effective_seed

    # Keep track of all the seeds in order; a seed equal to the most
    # recent entry is not recorded a second time.
    try:
        past_seeds = self.seed_history
    except AttributeError:
        # First call on this instance: start the history.
        self.seed_history = [effective_seed]
    else:
        if past_seeds[-1] != effective_seed:
            past_seeds.append(effective_seed)
    return [effective_seed]
# no more agent_handles
......@@ -252,7 +257,7 @@ class RailEnv(Environment):
( agent.state.is_on_map_state() and agent.speed_counter.is_cell_entry )
def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *,
random_seed: bool = None) -> Tuple[Dict, Dict]:
random_seed: int = None) -> Tuple[Dict, Dict]:
"""
reset(regenerate_rail, regenerate_schedule, activate_agents, random_seed)
......@@ -264,7 +269,7 @@ class RailEnv(Environment):
regenerate the rails
regenerate_schedule : bool, optional
regenerate the schedule and the static agents
random_seed : bool, optional
random_seed : int, optional
random seed for environment
Returns
......@@ -355,10 +360,10 @@ class RailEnv(Environment):
""" Update the agent_positions array for agents that changed positions """
for agent in self.agents:
if not ignore_old_positions or agent.old_position != agent.position:
self.agent_positions[agent.position] = agent.handle
if agent.position is not None:
self.agent_positions[agent.position] = agent.handle
if agent.old_position is not None:
self.agent_positions[agent.old_position] = -1
def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed):
""" Generate State Transitions Signals used in the state machine """
......@@ -597,7 +602,7 @@ class RailEnv(Environment):
# Check if episode has ended and update rewards and dones
self.end_of_episode_update(have_all_agents_ended)
self._update_agent_positions_map
self._update_agent_positions_map()
return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()
......
......@@ -163,7 +163,7 @@ def sparse_rail_generator(*args, **kwargs):
class SparseRailGen(RailGen):
def __init__(self, max_num_cities: int = 2, grid_mode: bool = False, max_rails_between_cities: int = 2,
max_rail_pairs_in_city: int = 2, seed=0) -> RailGenerator:
max_rail_pairs_in_city: int = 2, seed=None) -> RailGenerator:
"""
Generates railway networks with cities and inner city rails
......@@ -189,7 +189,7 @@ class SparseRailGen(RailGen):
self.grid_mode = grid_mode
self.max_rails_between_cities = max_rails_between_cities
self.max_rail_pairs_in_city = max_rail_pairs_in_city
self.seed = seed # TODO: seed in constructor or generate?
self.seed = seed
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
......@@ -217,8 +217,10 @@ class SparseRailGen(RailGen):
'train_stations': locations of train stations for start and targets
'city_orientations' : orientation of cities
"""
if np_random is None:
if self.seed is not None:
np_random = RandomState(self.seed)
elif np_random is None:
np_random = RandomState(np.random.randint(2**32))
rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
......
......@@ -182,7 +182,7 @@ def test_reward_function_waiting(rendering=False):
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
remove_agents_at_target=False)
remove_agents_at_target=False, random_seed=1)
obs_builder: TreeObsForRailEnv = env.obs_builder
env.reset()
......
......@@ -22,13 +22,14 @@ class RandomAgent:
# Store the given state/action sizes and create the agent's own random generator.
def __init__(self, state_size, action_size):
self.state_size = state_size
self.action_size = action_size
# Dedicated RandomState with a fixed seed (42): draws taken from it do not
# depend on, or disturb, the global np.random state.
self.np_random = np.random.RandomState(seed=42)
def act(self, state):
"""
:param state: input is the observation of the agent
:return: returns an action
"""
return np.random.choice([1, 2, 3])
return self.np_random.choice([1, 2, 3])
def step(self, memories):
"""
......@@ -63,6 +64,7 @@ def test_multi_speed_init():
# Set all the different speeds
# Reset environment and get initial observations for all agents
env.reset(False, False)
env._max_episode_steps = 1000
for a_idx in range(len(env.agents)):
env.agents[a_idx].position = env.agents[a_idx].initial_position
......@@ -204,7 +206,8 @@ def test_multispeed_actions_no_malfunction_blocking():
rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
random_seed=1)
env.reset()
set_penalties_for_replay(env)
......
......@@ -166,52 +166,53 @@ def test_reproducability_env():
env = RailEnv(width=25, height=30, rail_generator=sparse_rail_generator(max_num_cities=5,
max_rails_between_cities=3,
seed=215545, # Random seed
seed=10, # Random seed
grid_mode=True
),
line_generator=sparse_line_generator(speed_ration_map), number_of_agents=1)
env.reset(True, True, random_seed=10)
env.reset(True, True, random_seed=1)
excpeted_grid = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 16386, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 4608],
[0, 49186, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 37408],
[0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[0, 32872, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16386, 1025, 1025, 1025, 17411, 34864],
[16386, 34864, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 16386, 1025, 1025, 33825, 2064],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800, 0, 0, 32800, 0],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800, 0, 0, 32800, 0],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800, 0, 0, 32800, 0],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800, 0, 0, 32800, 0],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800, 0, 0, 32800, 0],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800, 0, 0, 32800, 0],
[32800, 49186, 2064, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16386, 1025, 1025, 1025, 1025, 38505, 3089, 1025, 1025, 2064, 0],
[32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 0],
[32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 32872, 4608, 0, 0, 0, 0],
[32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 49186, 34864, 0, 0, 0, 0],
[32800, 32800, 0, 0, 0, 16386, 1025, 4608, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 32800, 32800, 0, 0, 0, 0],
[72, 1097, 1025, 5633, 17411, 3089, 1025, 1097, 5633, 17411, 1025, 5633, 1025, 1025, 2064, 0, 0, 0, 0, 32800, 32800, 0, 0, 0, 0],
[0, 0, 0, 72, 3089, 5633, 1025, 17411, 1097, 2064, 0, 32800, 0, 0, 0, 0, 0, 0, 0, 32800, 32800, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 72, 1025, 2064, 0, 0, 0, 32800, 0, 0, 0, 0, 0, 0, 0, 32872, 37408, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 0, 0, 0, 49186, 2064, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 72, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 2064, 0, 0, 0, 0, 0]]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 16386, 1025, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[16386, 17411, 1025, 5633, 17411, 3089, 1025, 1097, 5633, 17411, 1025, 5633, 1025, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 4608],
[32800, 32800, 0, 72, 3089, 5633, 1025, 17411, 1097, 2064, 0, 72, 1025, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 37408],
[32800, 32800, 0, 0, 0, 72, 1025, 2064, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[32800, 32872, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16386, 34864],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
[72, 37408, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
[0, 49186, 2064, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 72, 37408],
[0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[0, 32872, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 34864],
[0, 72, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 2064],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
assert env.rail.grid.tolist() == excpeted_grid
# Test that we don't have interference from calling multiple functions outside
env2 = RailEnv(width=25, height=30, rail_generator=sparse_rail_generator(max_num_cities=5,
max_rails_between_cities=3,
seed=215545, # Random seed
seed=10, # Random seed
grid_mode=True
),
line_generator=sparse_line_generator(speed_ration_map), number_of_agents=1)
np.random.seed(10)
np.random.seed(1)
for i in range(10):
np.random.randn()
env2.reset(True, True, random_seed=10)
env2.reset(True, True, random_seed=1)
assert env2.rail.grid.tolist() == excpeted_grid
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment