From c1b861317eaca2166094585b521586168339bc11 Mon Sep 17 00:00:00 2001
From: Giacomo Spigler <spiglerg@gmail.com>
Date: Mon, 15 Jul 2019 11:13:03 +0200
Subject: [PATCH] Add optional max_episode_steps limit to RailEnv (closes issue #93)

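Add an optional `max_episode_steps` argument to RailEnv. The environment
counts elapsed steps in step() and, once the limit is reached, marks every
agent as done, including the "__all__" key; reset() restarts the counter.
Passing None (the default) keeps the previous unlimited behaviour.

Minimal usage sketch; the width/height arguments, the import paths, and
the action value (2 assumed to mean "move forward") are illustrative and
may differ between flatland versions:

    from flatland.envs.rail_env import RailEnv
    from flatland.envs.generators import random_rail_generator

    env = RailEnv(width=20,
                  height=20,
                  rail_generator=random_rail_generator(),
                  number_of_agents=1,
                  max_episode_steps=50)
    obs = env.reset()
    while not env.dones["__all__"]:
        # The episode ends after at most 50 calls to step().
        obs, rewards, dones, info = env.step({0: 2})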
---
 flatland/envs/rail_env.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 996301a8..6d8aff5a 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -80,6 +80,7 @@ class RailEnv(Environment):
                  rail_generator=random_rail_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2),
+                 max_episode_steps=None,
                  ):
         """
         Environment init.
@@ -110,6 +111,8 @@ class RailEnv(Environment):
         obs_builder_object: ObservationBuilder object
             ObservationBuilder-derived object that takes builds observation
             vectors for each agent.
+        max_episode_steps: int or None
+            Maximum number of steps per episode; None (the default) disables the limit.
         file_name: you can load a pickle file.
         """
 
@@ -123,6 +126,9 @@ class RailEnv(Environment):
         self.obs_builder = obs_builder_object
         self.obs_builder._set_env(self)
 
+        self._max_episode_steps = max_episode_steps
+        self._elapsed_steps = 0
+
         self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)
 
         self.obs_dict = {}
@@ -184,6 +190,7 @@ class RailEnv(Environment):
             agent.speed_data['position_fraction'] = 0.0
 
         self.num_resets += 1
+        self._elapsed_steps = 0
 
         # TODO perhaps dones should be part of each agent.
         self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
@@ -196,6 +203,8 @@ class RailEnv(Environment):
         return self._get_observations()
 
     def step(self, action_dict_):
+        self._elapsed_steps += 1
+
         action_dict = action_dict_.copy()
 
         alpha = 1.0
@@ -323,6 +332,11 @@ class RailEnv(Environment):
             self.dones["__all__"] = True
             self.rewards_dict = {i: 0 * r + global_reward for i, r in self.rewards_dict.items()}
 
+        if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps):
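+            # Step limit reached: flag every agent, including "__all__", as done.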
+            self.dones["__all__"] = True
+            for k in self.dones.keys():
+                self.dones[k] = True
+
         return self._get_observations(), self.rewards_dict, self.dones, {}
 
     def _check_action_on_agent(self, action, agent):
-- 
GitLab