From 9edc0ff85133c21d448d0c1a7e34c9c06c2fd121 Mon Sep 17 00:00:00 2001
From: u229589 <christian.baumberger@sbb.ch>
Date: Tue, 29 Oct 2019 08:27:11 +0100
Subject: [PATCH] Merge branch 'master' of gitlab.aicrowd.com:flatland/flatland
 into 259-reduce-computational-complexity-of-env-step

# Conflicts:
#	flatland/envs/rail_env.py
---
 flatland/envs/rail_env.py | 116 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 116 insertions(+)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 153f668b..32dd70e5 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -84,6 +84,7 @@ class RailEnv(Environment):
     - invalid_action_penalty = 0
     - step_penalty = -alpha
     - global_reward = beta
+    - epsilon = avoid rounding errors
     - stop_penalty = 0  # penalty for stopping a moving agent
     - start_penalty = 0  # penalty for starting a stopped agent
 
@@ -217,6 +218,8 @@ class RailEnv(Environment):
         # Reset environment
 
         self.valid_positions = None
+
+        # global numpy array of the agents' positions, True means that there is an agent at that cell
         self.agent_positions: np.ndarray = np.full((height, width), False)
 
     def _seed(self, seed=None):
@@ -437,7 +440,14 @@ class RailEnv(Environment):
         return False
 
     def step(self, action_dict_: Dict[int, RailEnvActions]):
+        """
+        Updates rewards for the agents at a step.
 
+        Parameters
+        ----------
+        action_dict_ : Dict[int,RailEnvActions]
+
+        """
         self._elapsed_steps += 1
 
         # If we're done, set reward and info_dict and step() is done.
@@ -631,15 +641,41 @@ class RailEnv(Environment):
             self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
 
     def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D):
+        """
+        Sets the agent to its initial position. Updates the agent object and the position
+        of the agent inside the global agent_positions numpy array.
+
+        Parameters
+        ----------
+        agent: EnvAgent object
+        new_position: IntVector2D
+        """
         agent.position = new_position
         self.agent_positions[agent.position] = True
 
     def _move_agent_to_new_position(self, agent: EnvAgent, new_position: IntVector2D):
+        """
+        Moves the agent to a new position. Updates the agent object and the position
+        of the agent inside the global agent_positions numpy array.
+
+        Parameters
+        ----------
+        agent: EnvAgent object
+        new_position: IntVector2D
+        """
         agent.position = new_position
         self.agent_positions[agent.old_position] = False
         self.agent_positions[agent.position] = True
 
     def _remove_agent_from_scene(self, agent: EnvAgent):
+        """
+        Removes the agent from the scene. Updates the agent object and the position
+        of the agent inside the global agent_positions numpy array.
+
+        Parameters
+        ----------
+        agent: EnvAgent object
+        """
         self.agent_positions[agent.position] = False
         if self.remove_agents_at_target:
             agent.position = None
@@ -687,6 +723,19 @@ class RailEnv(Environment):
         return cell_free, new_cell_valid, new_direction, new_position, transition_valid
 
     def cell_free(self, position: IntVector2D) -> bool:
+        """
+        Utility to check if a cell is free
+
+        Parameters
+        ----------
+        position : Tuple[int, int]
+
+        Returns
+        -------
+        bool
+            is the cell free or not?
+
+        """
         try:
             return not self.agent_positions[position]
         except IndexError as error:
@@ -734,13 +783,35 @@ class RailEnv(Environment):
         return new_direction, transition_valid
 
     def _get_observations(self):
+        """
+        Utility which returns the observations of all agents with respect to the environment
+
+        Returns
+        -------
+        Dict object
+        """
         self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
         return self.obs_dict
 
     def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
+        """
+        Returns directions in which the agent can move
+
+        Parameters
+        ----------
+        row : int
+        col : int
+
+        Returns
+        -------
+        List[int]
+        """
         return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col))
 
     def get_full_state_msg(self):
+        """
+        Returns state of environment in msgpack object
+        """
         grid_data = self.rail.grid.tolist()
         agent_static_data = [agent.to_list() for agent in self.agents_static]
         agent_data = [agent.to_list() for agent in self.agents]
@@ -754,12 +825,22 @@ class RailEnv(Environment):
         return msgpack.packb(msg_data, use_bin_type=True)
 
     def get_agent_state_msg(self):
+        """
+        Returns agents information in msgpack object
+        """
         agent_data = [agent.to_list() for agent in self.agents]
         msg_data = {
             "agents": agent_data}
         return msgpack.packb(msg_data, use_bin_type=True)
 
     def set_full_state_msg(self, msg_data):
+        """
+        Sets environment state with msgdata object passed as argument
+
+        Parameters
+        ----------
+        msg_data: msgpack object
+        """
         data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
         self.rail.grid = np.array(data["grid"])
         # agents are always reset as not moving
@@ -772,6 +853,13 @@ class RailEnv(Environment):
         self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
 
     def set_full_state_dist_msg(self, msg_data):
+        """
+        Sets environment grid state and distance map with msgdata object passed as argument
+
+        Parameters
+        ----------
+        msg_data: msgpack object
+        """
         data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
         self.rail.grid = np.array(data["grid"])
         # agents are always reset as not moving
@@ -786,6 +874,9 @@ class RailEnv(Environment):
         self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
 
     def get_full_state_dist_msg(self):
+        """
+        Returns environment information with distance map information as msgpack object
+        """
         grid_data = self.rail.grid.tolist()
         agent_static_data = [agent.to_list() for agent in self.agents_static]
         agent_data = [agent.to_list() for agent in self.agents]
@@ -803,6 +894,14 @@ class RailEnv(Environment):
         return msgpack.packb(msg_data, use_bin_type=True)
 
     def save(self, filename, save_distance_maps=False):
+        """
+        Saves environment and distance map information in a file
+
+        Parameters
+        ----------
+        filename: string
+        save_distance_maps: bool
+        """
         if save_distance_maps is True:
             if self.distance_map.get() is not None:
                 if len(self.distance_map.get()) > 0:
@@ -819,14 +918,31 @@ class RailEnv(Environment):
                 file_out.write(self.get_full_state_msg())
 
     def load(self, filename):
+        """
+        Load environment with distance map from a file
+
+        Parameters
+        ----------
+        filename: string
+        """
         with open(filename, "rb") as file_in:
             load_data = file_in.read()
             self.set_full_state_dist_msg(load_data)
 
     def load_pkl(self, pkl_data):
+        """
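+        Load environment state from serialized data (passed on to set_full_state_msg)
+
+        Parameters
+        ----------
+        pkl_data: msgpack-encoded environment state as accepted by set_full_state_msg (data, not a file name)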
+        """
         self.set_full_state_msg(pkl_data)
 
     def load_resource(self, package, resource):
+        """
+        Load environment state from a binary resource of a package (passed on to set_full_state_msg)
+        """
         from importlib_resources import read_binary
         load_data = read_binary(package, resource)
         self.set_full_state_msg(load_data)
-- 
GitLab