Commit 608a75b5 authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

fix seeding pipeline

parent e3c821e5
Pipeline #8499 failed with stages
in 6 minutes and 7 seconds
...@@ -106,7 +106,7 @@ class RailEnv(Environment): ...@@ -106,7 +106,7 @@ class RailEnv(Environment):
malfunction_generator_and_process_data=None, # mal_gen.no_malfunction_generator(), malfunction_generator_and_process_data=None, # mal_gen.no_malfunction_generator(),
malfunction_generator=None, malfunction_generator=None,
remove_agents_at_target=True, remove_agents_at_target=True,
random_seed=1, random_seed=None,
record_steps=False, record_steps=False,
): ):
""" """
...@@ -161,7 +161,6 @@ class RailEnv(Environment): ...@@ -161,7 +161,6 @@ class RailEnv(Environment):
self.number_of_agents = number_of_agents self.number_of_agents = number_of_agents
# self.rail_generator: RailGenerator = rail_generator
if rail_generator is None: if rail_generator is None:
rail_generator = rail_gen.sparse_rail_generator() rail_generator = rail_gen.sparse_rail_generator()
self.rail_generator = rail_generator self.rail_generator = rail_generator
...@@ -193,9 +192,7 @@ class RailEnv(Environment): ...@@ -193,9 +192,7 @@ class RailEnv(Environment):
self.action_space = [5] self.action_space = [5]
self._seed() self._seed()
self._seed() if random_seed:
self.random_seed = random_seed
if self.random_seed:
self._seed(seed=random_seed) self._seed(seed=random_seed)
self.agent_positions = None self.agent_positions = None
...@@ -211,6 +208,14 @@ class RailEnv(Environment): ...@@ -211,6 +208,14 @@ class RailEnv(Environment):
def _seed(self, seed=None): def _seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed) self.np_random, seed = seeding.np_random(seed)
random.seed(seed) random.seed(seed)
self.random_seed = seed
# Keep track of all the seeds in order
if not hasattr(self, 'seed_history'):
self.seed_history = [seed]
if self.seed_history[-1] != seed:
self.seed_history.append(seed)
return [seed] return [seed]
# no more agent_handles # no more agent_handles
...@@ -252,7 +257,7 @@ class RailEnv(Environment): ...@@ -252,7 +257,7 @@ class RailEnv(Environment):
( agent.state.is_on_map_state() and agent.speed_counter.is_cell_entry ) ( agent.state.is_on_map_state() and agent.speed_counter.is_cell_entry )
def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *, def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *,
random_seed: bool = None) -> Tuple[Dict, Dict]: random_seed: int = None) -> Tuple[Dict, Dict]:
""" """
reset(regenerate_rail, regenerate_schedule, activate_agents, random_seed) reset(regenerate_rail, regenerate_schedule, activate_agents, random_seed)
...@@ -264,7 +269,7 @@ class RailEnv(Environment): ...@@ -264,7 +269,7 @@ class RailEnv(Environment):
regenerate the rails regenerate the rails
regenerate_schedule : bool, optional regenerate_schedule : bool, optional
regenerate the schedule and the static agents regenerate the schedule and the static agents
random_seed : bool, optional random_seed : int, optional
random seed for environment random seed for environment
Returns Returns
...@@ -355,10 +360,10 @@ class RailEnv(Environment): ...@@ -355,10 +360,10 @@ class RailEnv(Environment):
""" Update the agent_positions array for agents that changed positions """ """ Update the agent_positions array for agents that changed positions """
for agent in self.agents: for agent in self.agents:
if not ignore_old_positions or agent.old_position != agent.position: if not ignore_old_positions or agent.old_position != agent.position:
self.agent_positions[agent.position] = agent.handle if agent.position is not None:
self.agent_positions[agent.position] = agent.handle
if agent.old_position is not None: if agent.old_position is not None:
self.agent_positions[agent.old_position] = -1 self.agent_positions[agent.old_position] = -1
def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed): def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed):
""" Generate State Transitions Signals used in the state machine """ """ Generate State Transitions Signals used in the state machine """
...@@ -597,7 +602,7 @@ class RailEnv(Environment): ...@@ -597,7 +602,7 @@ class RailEnv(Environment):
# Check if episode has ended and update rewards and dones # Check if episode has ended and update rewards and dones
self.end_of_episode_update(have_all_agents_ended) self.end_of_episode_update(have_all_agents_ended)
self._update_agent_positions_map self._update_agent_positions_map()
return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict() return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()
......
...@@ -163,7 +163,7 @@ def sparse_rail_generator(*args, **kwargs): ...@@ -163,7 +163,7 @@ def sparse_rail_generator(*args, **kwargs):
class SparseRailGen(RailGen): class SparseRailGen(RailGen):
def __init__(self, max_num_cities: int = 2, grid_mode: bool = False, max_rails_between_cities: int = 2, def __init__(self, max_num_cities: int = 2, grid_mode: bool = False, max_rails_between_cities: int = 2,
max_rail_pairs_in_city: int = 2, seed=0) -> RailGenerator: max_rail_pairs_in_city: int = 2, seed=None) -> RailGenerator:
""" """
Generates railway networks with cities and inner city rails Generates railway networks with cities and inner city rails
...@@ -189,7 +189,7 @@ class SparseRailGen(RailGen): ...@@ -189,7 +189,7 @@ class SparseRailGen(RailGen):
self.grid_mode = grid_mode self.grid_mode = grid_mode
self.max_rails_between_cities = max_rails_between_cities self.max_rails_between_cities = max_rails_between_cities
self.max_rail_pairs_in_city = max_rail_pairs_in_city self.max_rail_pairs_in_city = max_rail_pairs_in_city
self.seed = seed # TODO: seed in constructor or generate? self.seed = seed
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
...@@ -217,8 +217,10 @@ class SparseRailGen(RailGen): ...@@ -217,8 +217,10 @@ class SparseRailGen(RailGen):
'train_stations': locations of train stations for start and targets 'train_stations': locations of train stations for start and targets
'city_orientations' : orientation of cities 'city_orientations' : orientation of cities
""" """
if np_random is None: if self.seed is not None:
np_random = RandomState(self.seed) np_random = RandomState(self.seed)
elif np_random is None:
np_random = RandomState(np.random.randint(2**32))
rail_trans = RailEnvTransitions() rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
......
...@@ -182,7 +182,7 @@ def test_reward_function_waiting(rendering=False): ...@@ -182,7 +182,7 @@ def test_reward_function_waiting(rendering=False):
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals), env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=2, line_generator=sparse_line_generator(), number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
remove_agents_at_target=False) remove_agents_at_target=False, random_seed=1)
obs_builder: TreeObsForRailEnv = env.obs_builder obs_builder: TreeObsForRailEnv = env.obs_builder
env.reset() env.reset()
......
...@@ -22,13 +22,14 @@ class RandomAgent: ...@@ -22,13 +22,14 @@ class RandomAgent:
def __init__(self, state_size, action_size): def __init__(self, state_size, action_size):
self.state_size = state_size self.state_size = state_size
self.action_size = action_size self.action_size = action_size
self.np_random = np.random.RandomState(seed=42)
def act(self, state): def act(self, state):
""" """
:param state: input is the observation of the agent :param state: input is the observation of the agent
:return: returns an action :return: returns an action
""" """
return np.random.choice([1, 2, 3]) return self.np_random.choice([1, 2, 3])
def step(self, memories): def step(self, memories):
""" """
...@@ -63,6 +64,7 @@ def test_multi_speed_init(): ...@@ -63,6 +64,7 @@ def test_multi_speed_init():
# Set all the different speeds # Set all the different speeds
# Reset environment and get initial observations for all agents # Reset environment and get initial observations for all agents
env.reset(False, False) env.reset(False, False)
env._max_episode_steps = 1000
for a_idx in range(len(env.agents)): for a_idx in range(len(env.agents)):
env.agents[a_idx].position = env.agents[a_idx].initial_position env.agents[a_idx].position = env.agents[a_idx].initial_position
...@@ -204,7 +206,8 @@ def test_multispeed_actions_no_malfunction_blocking(): ...@@ -204,7 +206,8 @@ def test_multispeed_actions_no_malfunction_blocking():
rail, rail_map, optionals = make_simple_rail() rail, rail_map, optionals = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals), env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=2, line_generator=sparse_line_generator(), number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
random_seed=1)
env.reset() env.reset()
set_penalties_for_replay(env) set_penalties_for_replay(env)
......
...@@ -166,52 +166,53 @@ def test_reproducability_env(): ...@@ -166,52 +166,53 @@ def test_reproducability_env():
env = RailEnv(width=25, height=30, rail_generator=sparse_rail_generator(max_num_cities=5, env = RailEnv(width=25, height=30, rail_generator=sparse_rail_generator(max_num_cities=5,
max_rails_between_cities=3, max_rails_between_cities=3,
seed=215545, # Random seed seed=10, # Random seed
grid_mode=True grid_mode=True
), ),
line_generator=sparse_line_generator(speed_ration_map), number_of_agents=1) line_generator=sparse_line_generator(speed_ration_map), number_of_agents=1)
env.reset(True, True, random_seed=10) env.reset(True, True, random_seed=1)
excpeted_grid = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], excpeted_grid = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 16386, 1025, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 16386, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 4608], [16386, 17411, 1025, 5633, 17411, 3089, 1025, 1097, 5633, 17411, 1025, 5633, 1025, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 4608],
[0, 49186, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 37408], [32800, 32800, 0, 72, 3089, 5633, 1025, 17411, 1097, 2064, 0, 72, 1025, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 37408],
[0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800], [32800, 32800, 0, 0, 0, 72, 1025, 2064, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800], [32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800], [32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[0, 32872, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16386, 1025, 1025, 1025, 17411, 34864], [32800, 32872, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16386, 34864],
[16386, 34864, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 16386, 1025, 1025, 33825, 2064], [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800, 0, 0, 32800, 0], [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800, 0, 0, 32800, 0], [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800, 0, 0, 32800, 0], [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800, 0, 0, 32800, 0], [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800, 0, 0, 32800, 0], [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
[32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800, 0, 0, 32800, 0], [72, 37408, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
[32800, 49186, 2064, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16386, 1025, 1025, 1025, 1025, 38505, 3089, 1025, 1025, 2064, 0], [0, 49186, 2064, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 72, 37408],
[32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 0], [0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 32872, 4608, 0, 0, 0, 0], [0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 49186, 34864, 0, 0, 0, 0], [0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[32800, 32800, 0, 0, 0, 16386, 1025, 4608, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 32800, 32800, 0, 0, 0, 0], [0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
[72, 1097, 1025, 5633, 17411, 3089, 1025, 1097, 5633, 17411, 1025, 5633, 1025, 1025, 2064, 0, 0, 0, 0, 32800, 32800, 0, 0, 0, 0], [0, 32872, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 34864],
[0, 0, 0, 72, 3089, 5633, 1025, 17411, 1097, 2064, 0, 32800, 0, 0, 0, 0, 0, 0, 0, 32800, 32800, 0, 0, 0, 0], [0, 72, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 2064],
[0, 0, 0, 0, 0, 72, 1025, 2064, 0, 0, 0, 32800, 0, 0, 0, 0, 0, 0, 0, 32872, 37408, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 0, 0, 0, 49186, 2064, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 0, 0, 0, 32800, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 72, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 2064, 0, 0, 0, 0, 0]] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
assert env.rail.grid.tolist() == excpeted_grid assert env.rail.grid.tolist() == excpeted_grid
# Test that we don't have interference from calling mulitple function outisde # Test that we don't have interference from calling mulitple function outisde
env2 = RailEnv(width=25, height=30, rail_generator=sparse_rail_generator(max_num_cities=5, env2 = RailEnv(width=25, height=30, rail_generator=sparse_rail_generator(max_num_cities=5,
max_rails_between_cities=3, max_rails_between_cities=3,
seed=215545, # Random seed seed=10, # Random seed
grid_mode=True grid_mode=True
), ),
line_generator=sparse_line_generator(speed_ration_map), number_of_agents=1) line_generator=sparse_line_generator(speed_ration_map), number_of_agents=1)
np.random.seed(10) np.random.seed(1)
for i in range(10): for i in range(10):
np.random.randn() np.random.randn()
env2.reset(True, True, random_seed=10) env2.reset(True, True, random_seed=1)
assert env2.rail.grid.tolist() == excpeted_grid assert env2.rail.grid.tolist() == excpeted_grid
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment