From c1b861317eaca2166094585b521586168339bc11 Mon Sep 17 00:00:00 2001 From: Giacomo Spigler <spiglerg@gmail.com> Date: Mon, 15 Jul 2019 11:13:03 +0200 Subject: [PATCH] closing issue #93 --- flatland/envs/rail_env.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 996301a8..6d8aff5a 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -80,6 +80,7 @@ class RailEnv(Environment): rail_generator=random_rail_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2), + max_episode_steps = None ): """ Environment init. @@ -110,6 +111,8 @@ class RailEnv(Environment): obs_builder_object: ObservationBuilder object ObservationBuilder-derived object that takes builds observation vectors for each agent. + max_episode_steps : int or None + file_name: you can load a pickle file. """ @@ -123,6 +126,9 @@ class RailEnv(Environment): self.obs_builder = obs_builder_object self.obs_builder._set_env(self) + self._max_episode_steps = max_episode_steps + self._elapsed_steps = 0 + self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False) self.obs_dict = {} @@ -184,6 +190,7 @@ class RailEnv(Environment): agent.speed_data['position_fraction'] = 0.0 self.num_resets += 1 + self._elapsed_steps = 0 # TODO perhaps dones should be part of each agent. self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) @@ -196,6 +203,8 @@ class RailEnv(Environment): return self._get_observations() def step(self, action_dict_): + self._elapsed_steps += 1 + action_dict = action_dict_.copy() alpha = 1.0 @@ -323,6 +332,11 @@ class RailEnv(Environment): self.dones["__all__"] = True self.rewards_dict = {i: 0 * r + global_reward for i, r in self.rewards_dict.items()} + if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps): + self.dones["__all__"] = True + for k in self.dones.keys(): + self.dones[k] = True + return self._get_observations(), self.rewards_dict, self.dones, {} def _check_action_on_agent(self, action, agent): -- GitLab