From f23796428a4f02faa5359f4878a7515806181ff9 Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Thu, 9 May 2019 23:34:31 +0100
Subject: [PATCH] moved agent_* lists to a list of EnvAgents

---
 examples/play_model.py                   |   7 +-
 flatland/core/env_observation_builder.py |  81 ++++++-----
 flatland/envs/agent_utils.py             | 176 +++++------------------
 flatland/envs/rail_env.py                | 109 ++++++++------
 flatland/utils/rendertools.py            |  17 +--
 tests/test_environments.py               |  52 ++++---
 6 files changed, 181 insertions(+), 261 deletions(-)

diff --git a/examples/play_model.py b/examples/play_model.py
index e69b312b..db4109d1 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -29,7 +29,8 @@ class Player(object):
         self.action_prob = [0]*4
         self.agent = Agent(self.state_size, self.action_size, "FC", 0)
         # self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
-        self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth'))
+        self.agent.qnetwork_local.load_state_dict(torch.load(
+            '../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth'))
 
         self.iFrame = 0
         self.tStart = time.time()
@@ -202,7 +203,7 @@ def main(render=True, delay=0.0):
         if trials % 100 == 0:
             tNow = time.time()
             rFps = iFrame / (tNow - tStart)
-            print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' + 
+            print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' +
                    '\tEpsilon: {:.2f} fps: {:.2f} \t Action Probabilities: \t {}').format(
                    env.number_of_agents,
                    trials,
@@ -215,4 +216,4 @@ def main(render=True, delay=0.0):
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index a6fbae6d..43884a40 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -61,19 +61,23 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.max_depth = max_depth
 
     def reset(self):
-        self.distance_map = np.inf * np.ones(shape=(self.env.number_of_agents,
+        agents = self.env.agents
+        nAgents = len(agents)
+        self.distance_map = np.inf * np.ones(shape=(nAgents,  # self.env.number_of_agents,
                                                     self.env.height,
                                                     self.env.width,
                                                     4))
-        self.max_dist = np.zeros(self.env.number_of_agents)
+        self.max_dist = np.zeros(nAgents)
 
-        for i in range(self.env.number_of_agents):
-            self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i)
+        # for i in range(nAgents):
+        #     self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i)
+        self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
 
         # Update local lookup table for all agents' target locations
         self.location_has_target = {}
-        for loc in self.env.agents_target:
-            self.location_has_target[(loc[0], loc[1])] = 1
+        # for loc in self.env.agents_target:
+        #    self.location_has_target[(loc[0], loc[1])] = 1
+        self.location_has_target = {agent.target: 1 for agent in agents}
 
     def _distance_map_walker(self, position, target_nr):
         """
@@ -229,28 +233,33 @@ class TreeObsForRailEnv(ObservationBuilder):
         """
 
         # Update local lookup table for all agents' positions
-        self.location_has_agent = {}
-        for loc in self.env.agents_position:
-            self.location_has_agent[(loc[0], loc[1])] = 1
-
-        position = self.env.agents_position[handle]
-        orientation = self.env.agents_direction[handle]
-        possible_transitions = self.env.rail.get_transitions((position[0], position[1], orientation))
+        # self.location_has_agent = {}
+        # for loc in self.env.agents_position:
+        #    self.location_has_agent[(loc[0], loc[1])] = 1
+        self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents}
+
+        agent = self.env.agents[handle]  # TODO: handle being treated as index
+        # position = self.env.agents_position[handle]
+        # orientation = self.env.agents_direction[handle]
+        possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
         num_transitions = np.count_nonzero(possible_transitions)
         # Root node - current position
-        observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
+        # observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
+        observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)]]
         root_observation = observation[:]
 
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
         # If only one transition is possible, the tree is oriented with this transition as the forward branch.
         # TODO: Test if this works as desired!
+        orientation = agent.direction
         if num_transitions == 1:
             orientation == np.argmax(possible_transitions)
 
-        for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
+        # for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
+        for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
             if possible_transitions[branch_direction]:
-                new_cell = self._new_position(position, branch_direction)
+                new_cell = self._new_position(agent.position, branch_direction)
 
                 branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
                 observation = observation + branch_observation
@@ -307,17 +316,18 @@ class TreeObsForRailEnv(ObservationBuilder):
             visited.add((position[0], position[1], direction))
 
             # If the target node is encountered, pick that as node. Also, no further branching is possible.
-            if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
+            # if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
+            if np.array_equal(position, self.env.agents[handle].target):
                 last_isTarget = True
                 break
 
-            cell_transitions = self.env.rail.get_transitions((position[0], position[1], direction))
+            cell_transitions = self.env.rail.get_transitions((*position, direction))
             num_transitions = np.count_nonzero(cell_transitions)
             exploring = False
             if num_transitions == 1:
                 # Check if dead-end, or if we can go forward along direction
                 nbits = 0
-                tmp = self.env.rail.get_transitions((position[0], position[1]))
+                tmp = self.env.rail.get_transitions(tuple(position))
                 while tmp > 0:
                     nbits += (tmp & 1)
                     tmp = tmp >> 1
@@ -380,9 +390,9 @@ class TreeObsForRailEnv(ObservationBuilder):
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
         # Get the possible transitions
-        possible_transitions = self.env.rail.get_transitions((position[0], position[1], direction))
+        possible_transitions = self.env.rail.get_transitions((*position, direction))
         for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
-            if last_isDeadEnd and self.env.rail.get_transition((position[0], position[1], direction),
+            if last_isDeadEnd and self.env.rail.get_transition((*position, direction),
                                                                (branch_direction + 2) % 4):
                 # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                 # it back
@@ -471,20 +481,21 @@ class GlobalObsForRailEnv(ObservationBuilder):
         #     self.targets[target_pos] += 1
 
     def get(self, handle):
-        obs_agents_targets_pos = np.zeros((4, self.env.height, self.env.width))
-        agent_pos = self.env.agents_position[handle]
-        obs_agents_targets_pos[0][agent_pos] += 1
-        for i in range(len(self.env.agents_position)):
-            if i != handle:
-                obs_agents_targets_pos[3][self.env.agents_position[i]] += 1
-
-        agent_target_pos = self.env.agents_target[handle]
-        obs_agents_targets_pos[1][agent_target_pos] += 1
-        for i in range(len(self.env.agents_target)):
-            if i != handle:
-                obs_agents_targets_pos[2][self.env.agents_target[i]] += 1
+        obs = np.zeros((4, self.env.height, self.env.width))
+        agents = self.env.agents
+        agent = agents[handle]
+
+        agent_pos = agents[handle].position
+        obs[0][agent_pos] += 1
+        obs[1][agent.target] += 1
+
+        for i in range(len(agents)):
+            if i != handle:   # TODO: handle used as index...?
+                agent2 = agents[i]
+                obs[3][agent2.position] += 1
+                obs[2][agent2.target] += 1
 
         direction = np.zeros(4)
-        direction[self.env.agents_direction[handle]] = 1
+        direction[agent.direction] = 1
 
-        return self.rail_obs, obs_agents_targets_pos, direction
+        return self.rail_obs, obs, direction
diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index c29839e6..da36fe73 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -1,7 +1,17 @@
 
 from attr import attrs, attrib
-from itertools import starmap, count
-import numpy as np
+from itertools import starmap
+# from flatland.envs.rail_env import RailEnv
+
+
+@attrs
+class EnvDescription(object):
+    n_agents = attrib()
+    height = attrib()
+    width = attrib()
+    rail_generator = attrib()
+    obs_builder = attrib()
+
 
 @attrs
 class EnvAgentStatic(object):
@@ -13,157 +23,41 @@ class EnvAgentStatic(object):
     position = attrib()
     direction = attrib()
     target = attrib()
-    handle = attrib()
 
-    next_handle = 0
+    next_handle = 0  # this is not properly implemented
 
     @classmethod
-    def from_lists(positions, directions, targets):
+    def from_lists(cls, positions, directions, targets):
         """ Create a list of EnvAgentStatics from lists of positions, directions and targets
         """
-        return starmap(EnvAgentStatic, zip(positions, directions, targets, count()))
+        return list(starmap(EnvAgentStatic, zip(positions, directions, targets)))
         
 
+@attrs
 class EnvAgent(EnvAgentStatic):
-    """ TODO: EnvAgent - replace separate agent lists with a single list
+    """ EnvAgent - replace separate agent_* lists with a single list
         of agent objects.  The EnvAgent represent's the environment's view
-        of the dynamic agent state.  So target is not part of it - target is
-        static.
+        of the dynamic agent state.
+        We are duplicating target in the EnvAgent, which seems simpler than
+        forcing the env to refer to it in the EnvAgentStatic
     """
+    handle = attrib(default=None)
 
-
-class EnvManager(object):
-    def __init__(self, env=None):
-        self.env = env
-        
-
-    def load_env(self, sFilename):
-        pass
-    
-    def save_env(self, sFilename):
-        pass
-    
-    def regen_rail(self):
-        pass
-
-    def replace_agents(self):
-        pass
-
-    def add_agent_static(self, agent_static):
-        """ Add a new agent_static
-        """
-        iAgent = self.number_of_agents
-
-        if iDir is None:
-            iDir = self.pick_agent_direction(rcPos, rcTarget)
-        if iDir is None:
-            print("Error picking agent direction at pos:", rcPos)
-            return None
-
-        self.agents_position.append(tuple(rcPos))  # ensure it's a tuple not a list
-        self.agents_handles.append(max(self.agents_handles + [-1]) + 1)  # max(handles) + 1, starting at 0
-        self.agents_direction.append(iDir)
-        self.agents_target.append(rcPos)  # set the target to the origin initially
-        self.number_of_agents += 1
-        self.check_agent_lists()
-        return iAgent
-
-
-
-    def add_agent_old(self, rcPos=None, rcTarget=None, iDir=None):
-        """ Add a new agent at position rcPos with target rcTarget and
-            initial direction index iDir.
-            Should also store this initial position etc as environment "meta-data"
-            but this does not yet exist.
+    @classmethod
+    def from_static(cls, oStatic):
+        """ Create an EnvAgent from the EnvAgentStatic,
+        copying all the fields, and adding handle with the default 0.
         """
-        self.check_agent_lists()
-
-        if rcPos is None:
-            rcPos = np.random.choice(len(self.valid_positions))
-
-        iAgent = self.number_of_agents
-
-        if iDir is None:
-            iDir = self.pick_agent_direction(rcPos, rcTarget)
-        if iDir is None:
-            print("Error picking agent direction at pos:", rcPos)
-            return None
-
-        self.agents_position.append(tuple(rcPos))  # ensure it's a tuple not a list
-        self.agents_handles.append(max(self.agents_handles + [-1]) + 1)  # max(handles) + 1, starting at 0
-        self.agents_direction.append(iDir)
-        self.agents_target.append(rcPos)  # set the target to the origin initially
-        self.number_of_agents += 1
-        self.check_agent_lists()
-        return iAgent
-
-    def fill_valid_positions(self):
-        ''' Populate the valid_positions list for the current TransitionMap.
-            TODO: put this elsewhere
-        '''
-        self.env.valid_positions = valid_positions = []
-        for r in range(self.env.height):
-            for c in range(self.env.width):
-                if self.env.rail.get_transitions((r, c)) > 0:
-                    valid_positions.append((r, c))
+        return EnvAgent(**oStatic.__dict__, handle=0)
 
-    def check_agent_lists(self):
-        ''' Check that the agent_handles, position and direction lists are all of length
-            number_of_agents.
-            (Suggest this is replaced with a single list of Agent objects :)
-        '''
-        for lAgents, name in zip(
-                [self.env.agents_handles, self.env.agents_position, self.env.agents_direction],
-                ["handles", "positions", "directions"]):
-            assert self.env.number_of_agents == len(lAgents), "Inconsistent agent list:" + name
-
-    def check_agent_locdirpath(self, iAgent):
-        ''' Check that agent iAgent has a valid location and direction,
-            with a path to its target.
-            (Not currently used?)
-        '''
-        valid_movements = []
-        for direction in range(4):
-            position = self.env.agents_position[iAgent]
-            moves = self.env.rail.get_transitions((position[0], position[1], direction))
-            for move_index in range(4):
-                if moves[move_index]:
-                    valid_movements.append((direction, move_index))
-
-        valid_starting_directions = []
-        for m in valid_movements:
-            new_position = self.env._new_position(self.env.agents_position[iAgent], m[1])
-            if m[0] not in valid_starting_directions and \
-                    self.env._path_exists(new_position, m[0], self.env.agents_target[iAgent]):
-                valid_starting_directions.append(m[0])
-
-        if len(valid_starting_directions) == 0:
-            return False
-        else:
-            return True
-
-    def pick_agent_direction(self, rcPos, rcTarget):
-        """ Pick and return a valid direction index (0..3) for an agent starting at
-            row,col rcPos with target rcTarget.
-            Return None if no path exists.
-            Picks random direction if more than one exists (uniformly).
+    @classmethod
+    def list_from_static(cls, lEnvAgentStatic, handles=None):
+        """ Create an EnvAgent from the EnvAgentStatic,
+        copying all the fields, and adding handle with the default 0.
         """
-        valid_movements = []
-        for direction in range(4):
-            moves = self.env.rail.get_transitions((*rcPos, direction))
-            for move_index in range(4):
-                if moves[move_index]:
-                    valid_movements.append((direction, move_index))
-        # print("pos", rcPos, "targ", rcTarget, "valid movements", valid_movements)
-
-        valid_starting_directions = []
-        for m in valid_movements:
-            new_position = self.env._new_position(rcPos, m[1])
-            if m[0] not in valid_starting_directions and self.env._path_exists(new_position, m[0], rcTarget):
-                valid_starting_directions.append(m[0])
-
-        if len(valid_starting_directions) == 0:
-            return None
-        else:
-            return valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]]
+        if handles is None:
+            handles = range(len(lEnvAgentStatic))
+            
+        return [EnvAgent(**oEAS.__dict__, handle=handle)
+                for handle, oEAS in zip(handles, lEnvAgentStatic)]
 
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index ea8c3dca..9767bba4 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -10,7 +10,7 @@ from flatland.core.env import Environment
 from flatland.core.env_observation_builder import TreeObsForRailEnv
 from flatland.envs.generators import random_rail_generator
 from flatland.envs.env_utils import get_new_position
-from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, EnvManager
+from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
 
 # from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
 # from flatland.core.transition_map import GridTransitionMap
@@ -124,10 +124,11 @@ class RailEnv(Environment):
 
         if replace_agents:
             self.agents_static = EnvAgentStatic.from_lists(agents_position, agents_direction, agents_target)
-            self.agents = copy(agents_static)
+            self.agents = EnvAgent.list_from_static(self.agents_static[:len(self.agents_handles)])
 
         self.num_resets += 1
 
+        # perhaps dones should be part of each agent.
         self.dones = {"__all__": False}
         for handle in self.agents_handles:
             self.dones[handle] = False
@@ -157,11 +158,12 @@ class RailEnv(Environment):
         for i in range(len(self.agents_handles)):
             handle = self.agents_handles[i]
             transition_isValid = None
+            agent = self.agents[i]
 
-            if handle not in action_dict:
+            if handle not in action_dict:  # no action has been supplied for this agent
                 continue
 
-            if self.dones[handle]:
+            if self.dones[handle]:  # this agent has already completed...
                 continue
             action = action_dict[handle]
 
@@ -171,31 +173,28 @@ class RailEnv(Environment):
                 return
 
             if action > 0:
-                pos = self.agents_position[i]
-                direction = self.agents_direction[i]
+                # pos = agent.position #  self.agents_position[i]
+                # direction = agent.direction # self.agents_direction[i]
 
                 # compute number of possible transitions in the current
                 # cell used to check for invalid actions
 
-                possible_transitions = self.rail.get_transitions((pos[0], pos[1], direction))
+                possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
                 num_transitions = np.count_nonzero(possible_transitions)
 
-                movement = direction
+                movement = agent.direction
                 # print(nbits,np.sum(possible_transitions))
                 if action == 1:
-                    movement = direction - 1
+                    movement = agent.direction - 1
                     if num_transitions <= 1:
                         transition_isValid = False
 
                 elif action == 3:
-                    movement = direction + 1
+                    movement = agent.direction + 1
                     if num_transitions <= 1:
                         transition_isValid = False
 
-                if movement < 0:
-                    movement += 4
-                if movement >= 4:
-                    movement -= 4
+                movement %= 4
 
                 if action == 2:
                     if num_transitions == 1:
@@ -205,57 +204,72 @@ class RailEnv(Environment):
                         movement = np.argmax(possible_transitions)
                         transition_isValid = True
 
-                new_position = get_new_position(pos, movement)
-                # Is it a legal move?  1) transition allows the movement in the
-                # cell,  2) the new cell is not empty (case 0),  3) the cell is
-                # free, i.e., no agent is currently in that cell
-                if (
-                        new_position[1] >= self.width or
-                        new_position[0] >= self.height or
-                        new_position[0] < 0 or new_position[1] < 0):
-                    new_cell_isValid = False
-
-                elif self.rail.get_transitions((new_position[0], new_position[1])) > 0:
-                    new_cell_isValid = True
-                else:
-                    new_cell_isValid = False
+                new_position = get_new_position(agent.position, movement)
+                # Is it a legal move?
+                # 1) transition allows the movement in the cell,
+                # 2) the new cell is not empty (case 0),
+                # 3) the cell is free, i.e., no agent is currently in that cell
+                
+                # if (
+                #        new_position[1] >= self.width or
+                #        new_position[0] >= self.height or
+                #        new_position[0] < 0 or new_position[1] < 0):
+                #    new_cell_isValid = False
+
+                # if self.rail.get_transitions(new_position) == 0:
+                #     new_cell_isValid = False
+
+                new_cell_isValid = (
+                        np.array_equal(  # Check the new position is still in the grid
+                            new_position,
+                            np.clip(new_position, [0, 0], [self.height-1, self.width-1]))
+                        and  # check the new position has some transitions (ie is not an empty cell)
+                        self.rail.get_transitions(new_position) > 0)
 
                 # If transition validity hasn't been checked yet.
                 if transition_isValid is None:
                     transition_isValid = self.rail.get_transition(
-                        (pos[0], pos[1], direction),
+                        (*agent.position, agent.direction),
                         movement)
 
-                cell_isFree = True
-                for j in range(self.number_of_agents):
-                    if self.agents_position[j] == new_position:
-                        cell_isFree = False
-                        break
-
-                if new_cell_isValid and transition_isValid and cell_isFree:
+                # cell_isFree = True
+                # for j in range(self.number_of_agents):
+                #    if self.agents_position[j] == new_position:
+                #        cell_isFree = False
+                #        break
+                # Check the new position is not the same as any of the existing agent positions
+                # (including itself, for simplicity, since it is moving)
+                cell_isFree = not np.any(
+                        np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
+
+                if all([new_cell_isValid, transition_isValid, cell_isFree]):
                     # move and change direction to face the movement that was
                     # performed
-                    self.agents_position[i] = new_position
-                    self.agents_direction[i] = movement
+                    # self.agents_position[i] = new_position
+                    # self.agents_direction[i] = movement
+                    agent.position = new_position
+                    agent.direction = movement
                 else:
                     # the action was not valid, add penalty
                     self.rewards_dict[handle] += invalid_action_penalty
 
             # if agent is not in target position, add step penalty
-            if self.agents_position[i][0] == self.agents_target[i][0] and \
-                    self.agents_position[i][1] == self.agents_target[i][1]:
+            # if self.agents_position[i][0] == self.agents_target[i][0] and \
+            #        self.agents_position[i][1] == self.agents_target[i][1]:
+            #    self.dones[handle] = True
+            if np.equal(agent.position, agent.target).all():
                 self.dones[handle] = True
             else:
                 self.rewards_dict[handle] += step_penalty
 
         # Check for end of episode + add global reward to all rewards!
-        num_agents_in_target_position = 0
-        for i in range(self.number_of_agents):
-            if self.agents_position[i][0] == self.agents_target[i][0] and \
-                    self.agents_position[i][1] == self.agents_target[i][1]:
-                num_agents_in_target_position += 1
-
-        if num_agents_in_target_position == self.number_of_agents:
+        # num_agents_in_target_position = 0
+        # for i in range(self.number_of_agents):
+        #    if self.agents_position[i][0] == self.agents_target[i][0] and \
+        #            self.agents_position[i][1] == self.agents_target[i][1]:
+        #        num_agents_in_target_position += 1
+        # if num_agents_in_target_position == self.number_of_agents:
+        if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
             self.dones["__all__"] = True
             self.rewards_dict = [r + global_reward for r in self.rewards_dict]
 
@@ -273,3 +287,4 @@ class RailEnv(Environment):
     def render(self):
         # TODO:
         pass
+
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 1f731c39..c2fcb73b 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -158,20 +158,9 @@ class RenderTool(object):
 
     def plotAgents(self, targets=True):
         cmap = self.gl.get_cmap('hsv', lut=self.env.number_of_agents + 1)
-        for iAgent in range(self.env.number_of_agents):
+        for iAgent, agent in enumerate(self.env.agents):
             oColor = cmap(iAgent)
-
-            rcPos = self.env.agents_position[iAgent]
-            iDir = self.env.agents_direction[iAgent]  # agent direction index
-
-            if targets:
-                target = self.env.agents_target[iAgent]
-            else:
-                target = None
-            self.plotAgent(rcPos, iDir, oColor, target=target)
-
-            # gTransRCAg = self.getTransRC(rcPos, iDir)
-            # self.plotTrans(rcPos, gTransRCAg)
+            self.plotAgent(agent.position, agent.direction, oColor, target=agent.target if targets else None)
 
     def getTransRC(self, rcPos, iDir, bgiTrans=False):
         """
@@ -554,7 +543,7 @@ class RenderTool(object):
 
                 if not bCellValid:
                     # print("invalid:", r, c)
-                    self.gl.scatter(*xyCentre, color="r", s=50)
+                    self.gl.scatter(*xyCentre, color="r", s=30)
 
                 for orientation in range(4):  # ori is where we're heading
                     from_ori = (orientation + 2) % 4  # 0123=NESW -> 2301=SWNE
diff --git a/tests/test_environments.py b/tests/test_environments.py
index a10fb061..fe788b7c 100644
--- a/tests/test_environments.py
+++ b/tests/test_environments.py
@@ -7,7 +7,7 @@ from flatland.envs.generators import rail_from_GridTransitionMap_generator
 from flatland.core.transitions import Grid4Transitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.core.env_observation_builder import GlobalObsForRailEnv
-
+from flatland.envs.agent_utils import EnvAgent
 
 """Tests for `flatland` package."""
 
@@ -58,16 +58,21 @@ def test_rail_environment_single_agent():
         _ = rail_env.reset()
 
         # We do not care about target for the moment
-        rail_env.agents_target[0] = [-1, -1]
+        # rail_env.agents_target[0] = [-1, -1]
+        agent = rail_env.agents[0]
+        # rail_env.agents[0].target = [-1, -1]
+        agent.target = [-1, -1]
 
         # Check that trains are always initialized at a consistent position
         # or direction.
         # They should always be able to go somewhere.
         assert(transitions.get_transitions(
-            rail_map[rail_env.agents_position[0]],
-            rail_env.agents_direction[0]) != (0, 0, 0, 0))
+            # rail_map[rail_env.agents_position[0]],
+            # rail_env.agents_direction[0]) != (0, 0, 0, 0))
+            rail_map[agent.position],
+            agent.direction) != (0, 0, 0, 0))
 
-        initial_pos = rail_env.agents_position[0]
+        initial_pos = agent.position
 
         valid_active_actions_done = 0
         pos = initial_pos
@@ -78,13 +83,13 @@ def test_rail_environment_single_agent():
             _, _, _, _ = rail_env.step({0: action})
 
             prev_pos = pos
-            pos = rail_env.agents_position[0]
+            pos = agent.position  # rail_env.agents_position[0]
             if prev_pos != pos:
                 valid_active_actions_done += 1
 
         # After 6 movements on this railway network, the train should be back
         # to its original height on the map.
-        assert(initial_pos[0] == rail_env.agents_position[0][0])
+        assert(initial_pos[0] == agent.position[0])
 
         # We check that the train always attains its target after some time
         for _ in range(10):
@@ -135,13 +140,14 @@ def test_dead_end():
         # We run step to check that trains do not move anymore
         # after being done.
         for i in range(7):
-            prev_pos = rail_env.agents_position[0]
+            # prev_pos = rail_env.agents_position[0]
+            prev_pos = rail_env.agents[0].position
 
             # The train cannot turn, so we check that when it tries,
             # it stays where it is.
             _ = rail_env.step({0: 1})
             _ = rail_env.step({0: 3})
-            assert (rail_env.agents_position[0] == prev_pos)
+            assert (rail_env.agents[0].position == prev_pos)
             _, _, dones, _ = rail_env.step({0: 2})
 
             if i < 5:
@@ -151,15 +157,17 @@ def test_dead_end():
 
     # We try the configuration in the 4 directions:
     rail_env.reset()
-    rail_env.agents_target[0] = (0, 0)
-    rail_env.agents_position[0] = (0, 2)
-    rail_env.agents_direction[0] = 1
+    # rail_env.agents_target[0] = (0, 0)
+    # rail_env.agents_position[0] = (0, 2)
+    # rail_env.agents_direction[0] = 1
+    rail_env.agents = [EnvAgent(position=(0, 2), direction=1, target=(0, 0))]
     check_consistency(rail_env)
 
     rail_env.reset()
-    rail_env.agents_target[0] = (0, 4)
-    rail_env.agents_position[0] = (0, 2)
-    rail_env.agents_direction[0] = 3
+    # rail_env.agents_target[0] = (0, 4)
+    # rail_env.agents_position[0] = (0, 2)
+    # rail_env.agents_direction[0] = 3
+    rail_env.agents = [EnvAgent(position=(0, 2), direction=3, target=(0, 4))]
     check_consistency(rail_env)
 
     # In the vertical configuration:
@@ -181,13 +189,15 @@ def test_dead_end():
                        obs_builder_object=GlobalObsForRailEnv())
 
     rail_env.reset()
-    rail_env.agents_target[0] = (0, 0)
-    rail_env.agents_position[0] = (2, 0)
-    rail_env.agents_direction[0] = 2
+    # rail_env.agents_target[0] = (0, 0)
+    # rail_env.agents_position[0] = (2, 0)
+    # rail_env.agents_direction[0] = 2
+    rail_env.agents = [EnvAgent(position=(2, 0), direction=2, target=(0, 0))]
     check_consistency(rail_env)
 
     rail_env.reset()
-    rail_env.agents_target[0] = (4, 0)
-    rail_env.agents_position[0] = (2, 0)
-    rail_env.agents_direction[0] = 0
+    # rail_env.agents_target[0] = (4, 0)
+    # rail_env.agents_position[0] = (2, 0)
+    # rail_env.agents_direction[0] = 0
+    rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0))]
     check_consistency(rail_env)
-- 
GitLab