diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index f082f0801ba782bab8d4106856739fc798a3efb8..b1e89ebb6ce8c8223b71aeb8e836e0e4e68795fd 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -83,6 +83,7 @@ class RailEnv(Environment):
                  rail_generator=random_rail_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2),
+                 max_episode_steps=None
                  ):
         """
         Environment init.
@@ -113,6 +114,8 @@ class RailEnv(Environment):
         obs_builder_object: ObservationBuilder object
             ObservationBuilder-derived object that takes builds observation
             vectors for each agent.
+        max_episode_steps : int or None
+            Maximum number of steps per episode; ``None`` disables the limit.
 
         """
 
@@ -126,6 +129,9 @@ class RailEnv(Environment):
         self.obs_builder = obs_builder_object
         self.obs_builder._set_env(self)
 
+        self._max_episode_steps = max_episode_steps
+        self._elapsed_steps = 0
+
         self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)
 
         self.obs_dict = {}
@@ -191,6 +197,7 @@ class RailEnv(Environment):
                 agent.speed_data['position_fraction'] = 0.0
 
         self.num_resets += 1
+        self._elapsed_steps = 0
 
         # TODO perhaps dones should be part of each agent.
         self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
@@ -203,6 +210,8 @@ class RailEnv(Environment):
         return self._get_observations()
 
     def step(self, action_dict_):
+        self._elapsed_steps += 1
+
         action_dict = action_dict_.copy()
 
         alpha = 1.0
@@ -330,6 +339,11 @@ class RailEnv(Environment):
             self.dones["__all__"] = True
             self.rewards_dict = {i: 0 * r + global_reward for i, r in self.rewards_dict.items()}
 
+        if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps):
+            # Step limit reached: flag every agent (and "__all__") as done.
+            for k in self.dones.keys():
+                self.dones[k] = True
+
         return self._get_observations(), self.rewards_dict, self.dones, {}
 
     def _check_action_on_agent(self, action, agent):
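
For reference, a minimal usage sketch of the new step cap (not part of the patch): the grid size, the import path for random_rail_generator, and the constant "move forward" action below are illustrative assumptions; only the max_episode_steps argument and the dones behaviour come from the diff above.

    # Illustrative sketch only: width/height, the generator import path and the
    # fixed action value are assumptions, not taken from the patch.
    from flatland.envs.generators import random_rail_generator
    from flatland.envs.rail_env import RailEnv

    env = RailEnv(width=20, height=20,
                  rail_generator=random_rail_generator(),
                  number_of_agents=2,
                  max_episode_steps=50)  # new argument: episode ends after at most 50 steps
    obs = env.reset()

    dones = {"__all__": False}
    steps = 0
    while not dones["__all__"]:
        # Action 2 (move forward) for every agent; once _elapsed_steps reaches
        # max_episode_steps, step() flips every entry of dones to True.
        obs, rewards, dones, _ = env.step({a: 2 for a in range(env.get_num_agents())})
        steps += 1
    assert steps <= 50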