From fcc163ab51722c9aedb6affba7380b2e15a66438 Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Wed, 1 May 2019 19:46:23 +0100
Subject: [PATCH] Added comments and fixed lint

---
 flatland/core/env_observation_builder.py |   2 +-
 flatland/envs/rail_env.py                | 103 ++++++++++++++---------
 2 files changed, 62 insertions(+), 43 deletions(-)

diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index 1ae2819d..de3ee932 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -137,7 +137,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             new_cell = self._new_position(position, neigh_direction)
 
             if new_cell[0] >= 0 and new_cell[0] < self.env.height and \
-                new_cell[1] >= 0 and new_cell[1] < self.env.width:
+                    new_cell[1] >= 0 and new_cell[1] < self.env.width:
 
                 desired_movement_from_new_cell = (neigh_direction + 2) % 4
 
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index f1a3d20c..eff1c108 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -163,9 +163,9 @@ def a_star(rail_trans, rail_array, start, end):
         for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
             node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1])
             if node_pos[0] >= rail_shape[0] or \
-                node_pos[0] < 0 or \
-                node_pos[1] >= rail_shape[1] or \
-                node_pos[1] < 0:
+                    node_pos[0] < 0 or \
+                    node_pos[1] >= rail_shape[1] or \
+                    node_pos[1] < 0:
                 continue
 
             # validate positions
@@ -540,8 +540,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
             for i in range(len(transitions_templates_)):
                 is_match = True
                 for j in range(4):
-                    if template[j] >= 0 and \
-                        template[j] != transitions_templates_[i][0][j]:
+                    if template[j] >= 0 and template[j] != transitions_templates_[i][0][j]:
                         is_match = False
                         break
                 if is_match:
@@ -742,6 +741,29 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
     return generator
 
 
+class EnvAgentStatic(object):
+    """ TODO: EnvAgentStatic - To store initial position, direction and target.
+        This is like static data for the environment - it's where an agent starts,
+        rather than where it is at the moment.
+        The target should also be stored here.
+    """
+    def __init__(self, rcPos, iDir, rcTarget):
+        self.rcPos = rcPos
+        self.iDir = iDir
+        self.rcTarget = rcTarget
+
+
+class EnvAgent(object):
+    """ TODO: EnvAgent - replace separate agent lists with a single list
+        of agent objects.  The EnvAgent represents the environment's view
+        of the dynamic agent state.  So target is not part of it - target is
+        static.
+    """
+    def __init__(self, rcPos, iDir):
+        self.rcPos = rcPos
+        self.iDir = iDir
+
+
 class RailEnv(Environment):
     """
     RailEnv environment class.
@@ -836,6 +858,8 @@ class RailEnv(Environment):
         return self.agents_handles
 
     def fill_valid_positions(self):
+        ''' Populate the valid_positions list for the current TransitionMap.
+        '''
         self.valid_positions = valid_positions = []
         for r in range(self.height):
             for c in range(self.width):
@@ -843,12 +867,20 @@ class RailEnv(Environment):
                     valid_positions.append((r, c))
 
     def check_agent_lists(self):
+        ''' Check that the agents_handles, agents_position and agents_direction lists
+            are all of length number_of_agents.
+            (Suggest this is replaced with a single list of Agent objects :)
+        '''
         for lAgents, name in zip(
-            [self.agents_handles, self.agents_position, self.agents_direction],
-            ["handles", "positions", "directions"]):
+                [self.agents_handles, self.agents_position, self.agents_direction],
+                ["handles", "positions", "directions"]):
             assert self.number_of_agents == len(lAgents), "Inconsistent agent list:" + name
 
     def check_agent_locdirpath(self, iAgent):
+        ''' Check that agent iAgent has a valid location and direction,
+            with a path to its target; return True if so, otherwise False.
+            (Not currently used?)
+        '''
         valid_movements = []
         for direction in range(4):
             position = self.agents_position[iAgent]
@@ -861,13 +893,20 @@ class RailEnv(Environment):
         for m in valid_movements:
             new_position = self._new_position(self.agents_position[iAgent], m[1])
             if m[0] not in valid_starting_directions and \
-                self._path_exists(new_position, m[0], self.agents_target[iAgent]):
+                    self._path_exists(new_position, m[0], self.agents_target[iAgent]):
                 valid_starting_directions.append(m[0])
 
         if len(valid_starting_directions) == 0:
             return False
+        else:
+            return True
 
     def pick_agent_direction(self, rcPos, rcTarget):
+        """ Pick and return a valid direction index (0..3) for an agent starting at
+            row,col rcPos with target rcTarget.
+            Return None if no path exists.
+            Picks random direction if more than one exists (uniformly).
+        """
         valid_movements = []
         for direction in range(4):
             moves = self.rail.get_transitions((*rcPos, direction))
@@ -879,8 +918,7 @@ class RailEnv(Environment):
         valid_starting_directions = []
         for m in valid_movements:
             new_position = self._new_position(rcPos, m[1])
-            if m[0] not in valid_starting_directions and \
-                self._path_exists(new_position, m[0], rcTarget):
+            if m[0] not in valid_starting_directions and self._path_exists(new_position, m[0], rcTarget):
                 valid_starting_directions.append(m[0])
 
         if len(valid_starting_directions) == 0:
@@ -889,6 +927,11 @@ class RailEnv(Environment):
             return valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]]
 
     def add_agent(self, rcPos=None, rcTarget=None, iDir=None):
+        """ Add a new agent at position rcPos with target rcTarget and
+            initial direction index iDir.
+            Should also store this initial position etc as environment "meta-data"
+            but this does not yet exist.
+        """
         self.check_agent_lists()
 
         if rcPos is None:
@@ -949,31 +992,7 @@ class RailEnv(Environment):
                         break
                     else:
                         self.agents_direction[i] = direction
-
-                # Jeremy extracted this into the method pick_agent_direction
-                if False:
-                    for i in range(self.number_of_agents):
-                        valid_movements = []
-                        for direction in range(4):
-                            position = self.agents_position[i]
-                            moves = self.rail.get_transitions((position[0], position[1], direction))
-                            for move_index in range(4):
-                                if moves[move_index]:
-                                    valid_movements.append((direction, move_index))
-
-                        valid_starting_directions = []
-                        for m in valid_movements:
-                            new_position = self._new_position(self.agents_position[i], m[1])
-                            if m[0] not in valid_starting_directions and \
-                                self._path_exists(new_position, m[0], self.agents_target[i]):
-                                valid_starting_directions.append(m[0])
-
-                        if len(valid_starting_directions) == 0:
-                            re_generate = True
-                        else:
-                            self.agents_direction[i] = valid_starting_directions[
-                                np.random.choice(len(valid_starting_directions), 1)[0]]
-
+                
         # Reset the state of the observation builder with the new environment
         self.obs_builder.reset()
 
@@ -1079,14 +1098,14 @@ class RailEnv(Environment):
                                 movement = curv_dir
                             curv_dir = (curv_dir + 1) % 4
 
-
                 new_position = self._new_position(pos, movement)
                 # Is it a legal move?  1) transition allows the movement in the
                 # cell,  2) the new cell is not empty (case 0),  3) the cell is
                 # free, i.e., no agent is currently in that cell
-                if new_position[1] >= self.width or \
-                    new_position[0] >= self.height or \
-                    new_position[0] < 0 or new_position[1] < 0:
+                if (
+                        new_position[1] >= self.width or
+                        new_position[0] >= self.height or
+                        new_position[0] < 0 or new_position[1] < 0):
                     new_cell_isValid = False
 
                 elif self.rail.get_transitions((new_position[0], new_position[1])) > 0:
@@ -1095,7 +1114,7 @@ class RailEnv(Environment):
                     new_cell_isValid = False
 
                 # If transition validity hasn't been checked yet.
-                if transition_isValid == None:
+                if transition_isValid is None:
                     transition_isValid = self.rail.get_transition(
                         (pos[0], pos[1], direction),
                         movement) or is_deadend
@@ -1117,7 +1136,7 @@ class RailEnv(Environment):
 
             # if agent is not in target position, add step penalty
             if self.agents_position[i][0] == self.agents_target[i][0] and \
-                self.agents_position[i][1] == self.agents_target[i][1]:
+                    self.agents_position[i][1] == self.agents_target[i][1]:
                 self.dones[handle] = True
             else:
                 self.rewards_dict[handle] += step_penalty
@@ -1126,7 +1145,7 @@ class RailEnv(Environment):
         num_agents_in_target_position = 0
         for i in range(self.number_of_agents):
             if self.agents_position[i][0] == self.agents_target[i][0] and \
-                self.agents_position[i][1] == self.agents_target[i][1]:
+                    self.agents_position[i][1] == self.agents_target[i][1]:
                 num_agents_in_target_position += 1
 
         if num_agents_in_target_position == self.number_of_agents:
-- 
GitLab