From 65f30a850585df5af4f9c2b0e7ffc23271d0f21d Mon Sep 17 00:00:00 2001
From: u229589 <christian.baumberger@sbb.ch>
Date: Wed, 23 Oct 2019 10:54:19 +0200
Subject: [PATCH] rename parameters of reset method  (issue #250)

---
 changelog.md                    |  1 +
 docs/specifications/railway.md  |  2 +-
 docs/tutorials/05_multispeed.md |  2 +-
 flatland/envs/rail_env.py       | 37 ++++++++++++++++++++++++++-------
 flatland/evaluators/client.py   |  6 +++---
 flatland/evaluators/service.py  |  4 ++--
 flatland/utils/editor.py        |  6 +++---
 7 files changed, 40 insertions(+), 18 deletions(-)

diff --git a/changelog.md b/changelog.md
index faca0cb8..8a4da1f1 100644
--- a/changelog.md
+++ b/changelog.md
@@ -10,6 +10,7 @@ Changes since Flatland 2.0.0
 ### Changes in rail generator and `RailEnv`
 - renaming of `distance_maps` into `distance_map`
 - by default the reset method of RailEnv is not called in the constructor of RailEnv anymore (compliance for OpenAI Gym). Therefore the reset method needs to be called after the creation of a RailEnv object
+- renaming of parameters RailEnv.reset(): from `regen_rail` to `regenerate_rail`, from `replace_agents` to `regenerate_schedule`
 
 ### Changes in schedule generation
 - return value of schedule generator has changed to the named tuple `Schedule`. From the point of view of a consumer, nothing has changed, this is just a type hint which is introduced where the attributes of `Schedule` have names.
diff --git a/docs/specifications/railway.md b/docs/specifications/railway.md
index e1ee77f4..35b2835d 100644
--- a/docs/specifications/railway.md
+++ b/docs/specifications/railway.md
@@ -430,7 +430,7 @@ The environment's `reset` takes care of applying the two generators:
         self.rail_generator: RailGenerator = rail_generator
         self.schedule_generator: ScheduleGenerator = schedule_generator
 
-    def reset(self, regen_rail=True, replace_agents=True):
+    def reset(self, regenerate_rail=True, regenerate_schedule=True):
         rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
 
         ...
diff --git a/docs/tutorials/05_multispeed.md b/docs/tutorials/05_multispeed.md
index cc45c65e..faf97c26 100644
--- a/docs/tutorials/05_multispeed.md
+++ b/docs/tutorials/05_multispeed.md
@@ -114,7 +114,7 @@ The environment's `reset` takes care of applying the two generators:
         self.rail_generator: RailGenerator = rail_generator
         self.schedule_generator: ScheduleGenerator = schedule_generator
 
-    def reset(self, regen_rail=True, replace_agents=True):
+    def reset(self, regenerate_rail=True, regenerate_schedule=True):
         rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
 
         ...
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 550a6b75..67131060 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -275,17 +275,37 @@ class RailEnv(Environment):
         alpha = 2
         return int(timedelay_factor * alpha * (width + height + ratio_nr_agents_to_nr_cities))
 
-    def reset(self, regen_rail=True, replace_agents=True, activate_agents=False, random_seed=None):
-        """ if regen_rail then regenerate the rails.
-            if replace_agents then regenerate the agents static.
-            Relies on the rail_generator returning agent_static lists (pos, dir, target)
+    def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False,
+              random_seed: bool = None) -> (Dict, Dict):
+        """
+        reset(regenerate_rail, regenerate_schedule, activate_agents, random_seed)
+
+        The method resets the rail environment
+
+        Parameters
+        ----------
+        regenerate_rail : bool, optional
+            regenerate the rails
+        regenerate_schedule : bool, optional
+            regenerate the schedule and the static agents
+        activate_agents : bool, optional
+            activate the agents
+        random_seed : bool, optional
+            random seed for environment
+
+        Returns
+        -------
+        observation_dict: Dict
+            Dictionary with an observation for each agent
+        info_dict: Dict with agent specific information
+
         """
 
         if random_seed:
             self._seed(random_seed)
 
         optionals = {}
-        if regen_rail or self.rail is None:
+        if regenerate_rail or self.rail is None:
             rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
 
             self.rail = rail
@@ -299,7 +319,7 @@ class RailEnv(Environment):
         if optionals and 'distance_map' in optionals:
             self.distance_map.set(optionals['distance_map'])
 
-        if replace_agents or self.agents_static[0] is None:
+        if regenerate_schedule or self.agents_static[0] is None:
             agents_hints = None
             if optionals and 'agents_hints' in optionals:
                 agents_hints = optionals['agents_hints']
@@ -347,7 +367,7 @@ class RailEnv(Environment):
         self.obs_builder.reset()
         self.distance_map.reset(self.agents, self.rail)
 
-        info_dict = {
+        info_dict: Dict = {
             'action_required': {
                 i: (agent.status == RailAgentStatus.READY_TO_DEPART or (
                     agent.status == RailAgentStatus.ACTIVE and agent.speed_data['position_fraction'] == 0.0))
@@ -359,7 +379,8 @@ class RailEnv(Environment):
             'status': {i: agent.status for i, agent in enumerate(self.agents)}
         }
         # Return the new observation vectors for each agent
-        return self._get_observations(), info_dict
+        observation_dict: Dict = self._get_observations()
+        return observation_dict, info_dict
 
     def _agent_malfunction(self, i_agent) -> bool:
         """
diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py
index ef869f40..9dc1e587 100644
--- a/flatland/evaluators/client.py
+++ b/flatland/evaluators/client.py
@@ -213,7 +213,7 @@ class FlatlandRemoteClient(object):
                 "to point to the location of the Tests folder ? \n"
                 "We are currently looking at `{}` for the tests".format(self.test_envs_root)
             )
-        
+
         if self.verbose:
             print("Current env path : ", test_env_file_path)
         self.current_env_path = test_env_file_path
@@ -226,8 +226,8 @@ class FlatlandRemoteClient(object):
         )
 
         local_observation, info = self.env.reset(
-                                regen_rail=True,
-                                replace_agents=True,
+                                regenerate_rail=False,
+                                regenerate_schedule=False,
                                 activate_agents=False,
                                 random_seed=random_seed
                             )
diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py
index aa6eef05..96112390 100644
--- a/flatland/evaluators/service.py
+++ b/flatland/evaluators/service.py
@@ -362,8 +362,8 @@ class FlatlandRemoteEvaluationService:
             self.current_step = 0
 
             _observation, _info = self.env.reset(
-                                regen_rail=True,
-                                replace_agents=True,
+                                regenerate_rail=False,
+                                regenerate_schedule=False,
                                 activate_agents=False,
                                 random_seed=RANDOM_SEED
                                 )
diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py
index e5d55bb5..f8c9afd0 100644
--- a/flatland/utils/editor.py
+++ b/flatland/utils/editor.py
@@ -323,7 +323,7 @@ class Controller(object):
         self.log("Reset - nAgents:", self.view.regen_n_agents.value)
         self.log("Reset - size:", self.model.regen_size_width)
         self.log("Reset - size:", self.model.regen_size_height)
-        self.model.reset(replace_agents=self.view.replace_agents.value,
+        self.model.reset(regenerate_schedule=self.view.replace_agents.value,
                          nAgents=self.view.regen_n_agents.value)
 
     def rotate_agent(self, event):
@@ -611,7 +611,7 @@ class EditorModel(object):
         self.env.rail.grid[cell_row_col[0], cell_row_col[1]] = 0
         self.redraw()
 
-    def reset(self, replace_agents=False, nAgents=0):
+    def reset(self, regenerate_schedule=False, nAgents=0):
         self.regenerate("complex", nAgents=nAgents)
         self.redraw()
 
@@ -676,7 +676,7 @@ class EditorModel(object):
                                obs_builder_object=TreeObsForRailEnv(max_depth=2))
         else:
             self.env = env
-        self.env.reset(regen_rail=True)
+        self.env.reset(regenerate_rail=True)
         self.fix_env()
         self.set_env(self.env)
         self.view.new_env()
-- 
GitLab