diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 1c2df4eaa68c17900a480ab3906084eac8a0b08e..36e9084affa4fd9676c09783574366478019b11e 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -160,10 +160,10 @@ class RailEnv(Environment): self.rail: Optional[GridTransitionMap] = None self.width = width self.height = height - + self.initial_num_agents = number_of_agents self.remove_agents_at_target = remove_agents_at_target - self.rewards = [0] * number_of_agents + self.rewards = [0] * self.initial_num_agents self.done = False self.obs_builder = obs_builder_object self.obs_builder.set_env(self) @@ -171,15 +171,15 @@ class RailEnv(Environment): self._max_episode_steps = max_episode_steps self._elapsed_steps = 0 - self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False) + self.dones = dict.fromkeys(list(range(self.initial_num_agents)) + ["__all__"], False) self.obs_dict = {} self.rewards_dict = {} self.dev_obs_dict = {} self.dev_pred_dict = {} - self.agents: List[EnvAgent] = [None] * number_of_agents # live agents - self.agents_static: List[EnvAgentStatic] = [None] * number_of_agents # static agent information + self.agents: List[EnvAgent] = [None] * self.initial_num_agents # live agents + self.agents_static: List[EnvAgentStatic] = [None] * self.initial_num_agents # static agent information self.num_resets = 0 self.distance_map = DistanceMap(self.agents, self.height, self.width) @@ -248,6 +248,8 @@ class RailEnv(Environment): # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/172 # can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition? + self.agents: List[EnvAgent] = [None] * self.initial_num_agents # live agents + self.agents_static: List[EnvAgentStatic] = [None] * self.initial_num_agents # static agent information rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets) if regen_rail or self.rail is None: