diff --git a/examples/play_model.py b/examples/play_model.py
index e69b312b1ceb2f450256d247f4b63c14a728acb5..9c67b0bce315ecf028fe898c510e4503e67a8cf4 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -29,7 +29,8 @@ class Player(object):
         self.action_prob = [0]*4
         self.agent = Agent(self.state_size, self.action_size, "FC", 0)
         # self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
-        self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth'))
+        self.agent.qnetwork_local.load_state_dict(torch.load(
+            '../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth'))
 
         self.iFrame = 0
         self.tStart = time.time()
@@ -97,7 +98,7 @@ def main(render=True, delay=0.0):
 
     # Example generate a random rail
     env = RailEnv(width=15, height=15,
-                  rail_generator=complex_rail_generator(nr_start_goal=20, min_dist=12),
+                  rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
                   number_of_agents=5)
 
     if render:
@@ -202,7 +203,7 @@ def main(render=True, delay=0.0):
         if trials % 100 == 0:
             tNow = time.time()
             rFps = iFrame / (tNow - tStart)
-            print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' + 
+            print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' +
                    '\tEpsilon: {:.2f} fps: {:.2f} \t Action Probabilities: \t {}').format(
                    env.number_of_agents,
                    trials,
@@ -215,4 +216,4 @@ def main(render=True, delay=0.0):
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index a6fbae6d0d271f47e98d08262c7fbc2801b7142d..43884a40e1e6cbd0cbfb10d69e569900bfffa72e 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -61,19 +61,23 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.max_depth = max_depth
 
     def reset(self):
-        self.distance_map = np.inf * np.ones(shape=(self.env.number_of_agents,
+        agents = self.env.agents
+        nAgents = len(agents)
+        self.distance_map = np.inf * np.ones(shape=(nAgents,  # self.env.number_of_agents,
                                                     self.env.height,
                                                     self.env.width,
                                                     4))
-        self.max_dist = np.zeros(self.env.number_of_agents)
+        self.max_dist = np.zeros(nAgents)
 
-        for i in range(self.env.number_of_agents):
-            self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i)
+        # for i in range(nAgents):
+        #     self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i)
+        self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
 
         # Update local lookup table for all agents' target locations
         self.location_has_target = {}
-        for loc in self.env.agents_target:
-            self.location_has_target[(loc[0], loc[1])] = 1
+        # for loc in self.env.agents_target:
+        #    self.location_has_target[(loc[0], loc[1])] = 1
+        self.location_has_target = {tuple(agent.target): 1 for agent in agents}
 
     def _distance_map_walker(self, position, target_nr):
         """
@@ -229,28 +233,34 @@
         """
 
         # Update local lookup table for all agents' positions
-        self.location_has_agent = {}
-        for loc in self.env.agents_position:
-            self.location_has_agent[(loc[0], loc[1])] = 1
-
-        position = self.env.agents_position[handle]
-        orientation = self.env.agents_direction[handle]
-        possible_transitions = self.env.rail.get_transitions((position[0], position[1], orientation))
+        # self.location_has_agent = {}
+        # for loc in self.env.agents_position:
+        #    self.location_has_agent[(loc[0], loc[1])] = 1
+        self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents}
+
+        agent = self.env.agents[handle]  # TODO: handle being treated as index
+        # position = self.env.agents_position[handle]
+        # orientation = self.env.agents_direction[handle]
+        possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
         num_transitions = np.count_nonzero(possible_transitions)
         # Root node - current position
-        observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
+        # observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
+        observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)]]
         root_observation = observation[:]
 
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
         # If only one transition is possible, the tree is oriented with this transition as the forward branch.
         # TODO: Test if this works as desired!
+        orientation = agent.direction
         if num_transitions == 1:
-            orientation == np.argmax(possible_transitions)
+            orientation = np.argmax(possible_transitions)
 
-        for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
+        # for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
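+        # (in Python, -1 % 4 == 3, so the "+ 4" offset in the old expression was redundant)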
+        for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
             if possible_transitions[branch_direction]:
-                new_cell = self._new_position(position, branch_direction)
+                new_cell = self._new_position(agent.position, branch_direction)
 
                 branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
                 observation = observation + branch_observation
@@ -307,17 +317,18 @@
             visited.add((position[0], position[1], direction))
 
             # If the target node is encountered, pick that as node. Also, no further branching is possible.
-            if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
+            # if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
+            if np.array_equal(position, self.env.agents[handle].target):
                 last_isTarget = True
                 break
 
-            cell_transitions = self.env.rail.get_transitions((position[0], position[1], direction))
+            cell_transitions = self.env.rail.get_transitions((*position, direction))
             num_transitions = np.count_nonzero(cell_transitions)
             exploring = False
             if num_transitions == 1:
                 # Check if dead-end, or if we can go forward along direction
                 nbits = 0
-                tmp = self.env.rail.get_transitions((position[0], position[1]))
+                tmp = self.env.rail.get_transitions(tuple(position))
                 while tmp > 0:
                     nbits += (tmp & 1)
                     tmp = tmp >> 1
@@ -380,9 +391,9 @@
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
         # Get the possible transitions
-        possible_transitions = self.env.rail.get_transitions((position[0], position[1], direction))
+        possible_transitions = self.env.rail.get_transitions((*position, direction))
         for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
-            if last_isDeadEnd and self.env.rail.get_transition((position[0], position[1], direction),
+            if last_isDeadEnd and self.env.rail.get_transition((*position, direction),
                                                                (branch_direction + 2) % 4):
                 # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                 # it back
@@ -471,20 +482,21 @@
         #     self.targets[target_pos] += 1
 
     def get(self, handle):
-        obs_agents_targets_pos = np.zeros((4, self.env.height, self.env.width))
-        agent_pos = self.env.agents_position[handle]
-        obs_agents_targets_pos[0][agent_pos] += 1
-        for i in range(len(self.env.agents_position)):
-            if i != handle:
-                obs_agents_targets_pos[3][self.env.agents_position[i]] += 1
-
-        agent_target_pos = self.env.agents_target[handle]
-        obs_agents_targets_pos[1][agent_target_pos] += 1
-        for i in range(len(self.env.agents_target)):
-            if i != handle:
-                obs_agents_targets_pos[2][self.env.agents_target[i]] += 1
+        obs = np.zeros((4, self.env.height, self.env.width))
+        agents = self.env.agents
+        agent = agents[handle]
+
+        agent_pos = agent.position
+        obs[0][agent_pos] += 1
+        obs[1][agent.target] += 1
+
+        for i in range(len(agents)):
+            if i != handle:   # TODO: handle used as index...?
+                agent2 = agents[i]
+                obs[3][agent2.position] += 1
+                obs[2][agent2.target] += 1
 
         direction = np.zeros(4)
-        direction[self.env.agents_direction[handle]] = 1
+        direction[agent.direction] = 1
 
-        return self.rail_obs, obs_agents_targets_pos, direction
+        return self.rail_obs, obs, direction
diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..da36fe73e867b7e2ec7f04a5564c73dd3e23a9a5
--- /dev/null
+++ b/flatland/envs/agent_utils.py
@@ -0,0 +1,68 @@
+from attr import attrs, attrib
+from itertools import starmap
+# from flatland.envs.rail_env import RailEnv
+
+
+@attrs
+class EnvDescription(object):
+    n_agents = attrib()
+    height = attrib()
+    width = attrib()
+    rail_generator = attrib()
+    obs_builder = attrib()
+
+
+@attrs
+class EnvAgentStatic(object):
+    """ TODO: EnvAgentStatic - To store initial position, direction and target.
+        This is like static data for the environment - it's where an agent starts,
+        rather than where it is at the moment.
+        The target should also be stored here.
+    """
+    position = attrib()
+    direction = attrib()
+    target = attrib()
+
+    next_handle = 0  # this is not properly implemented
+
+    @classmethod
+    def from_lists(cls, positions, directions, targets):
+        """ Create a list of EnvAgentStatics from lists of positions, directions and targets
+        """
+        return list(starmap(EnvAgentStatic, zip(positions, directions, targets)))
+
+
+@attrs
+class EnvAgent(EnvAgentStatic):
+    """ EnvAgent - replace separate agent_* lists with a single list
+        of agent objects.  The EnvAgent represents the environment's view
+        of the dynamic agent state.
+        We are duplicating target in the EnvAgent, which seems simpler than
+        forcing the env to refer to it in the EnvAgentStatic
+    """
+    handle = attrib(default=None)
+
+    @classmethod
+    def from_static(cls, oStatic):
+        """ Create an EnvAgent from the EnvAgentStatic,
+        copying all the fields, and adding handle with the default 0.
+        """
+        return EnvAgent(**oStatic.__dict__, handle=0)
+
+    @classmethod
+    def list_from_static(cls, lEnvAgentStatic, handles=None):
+        """ Create an EnvAgent from the EnvAgentStatic,
+        copying all the fields, and adding handle with the default 0.
+        """
+        if handles is None:
+            handles = range(len(lEnvAgentStatic))
+
+        return [EnvAgent(**oEAS.__dict__, handle=handle)
+                for handle, oEAS in zip(handles, lEnvAgentStatic)]
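+
+# A minimal usage sketch (illustrative only, assuming tuple coordinates):
+#
+#   agents_static = EnvAgentStatic.from_lists(
+#       [(0, 0), (1, 1)], [0, 2], [(5, 5), (3, 3)])  # positions, directions, targets
+#   agents = EnvAgent.list_from_static(agents_static)
+#   assert agents[1].handle == 1 and agents[1].target == (3, 3)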
diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index fe971e6b24b90e31dadd797359247537078ad5f6..7452d325530bb189084182f0bbc4bf26369e5881 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -9,7 +9,7 @@ from flatland.envs.env_utils import distance_on_rail, connect_rail, get_directio
 from flatland.envs.env_utils import get_rnd_agents_pos_tgt_dir_on_rail
 
 
-def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
+def complex_rail_generator(nr_start_goal=1, nr_extra=10, min_dist=2, max_dist=99999, seed=0):
     """
     Parameters
     -------
@@ -123,7 +123,29 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
                 # print("failed...")
                 created_sanity += 1
 
-        #print("\n> Complex Rail Gen: Created #", len(start_goal), "pairs")
+        # add extra connections between existing rail
+        created_sanity = 0
+        nr_created = 0
+        while nr_created < nr_extra and created_sanity < sanity_max:
+            all_ok = False
+            for _ in range(sanity_max):
+                start = (np.random.randint(0, width), np.random.randint(0, height))
+                goal = (np.random.randint(0, width), np.random.randint(0, height))
+                # check to make sure start,goal pos are not empty
+                if rail_array[goal] == 0 or rail_array[start] == 0:
+                    continue
+                else:
+                    all_ok = True
+                    break
+            if not all_ok:
+                break
+            new_path = connect_rail(rail_trans, rail_array, start, goal)
+            if len(new_path) >= 2:
+                nr_created += 1
+            else:
+                created_sanity += 1
+
+        print("\n> Complex Rail Gen: Created #", len(start_goal), "pairs and #", nr_created, "extra connections")
         # print(start_goal)
 
         agents_position = [sg[0] for sg in start_goal]
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 98abf81f469e1ea329db39e86c3cfe0a7756df28..9767bba42e2b3acf3ef4aa34b14154077dac77bd 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -10,34 +10,12 @@ from flatland.core.env import Environment
 from flatland.core.env_observation_builder import TreeObsForRailEnv
 from flatland.envs.generators import random_rail_generator
 from flatland.envs.env_utils import get_new_position
+from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
 
 # from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
 # from flatland.core.transition_map import GridTransitionMap
 
 
-class EnvAgentStatic(object):
-    """ TODO: EnvAgentStatic - To store initial position, direction and target.
-        This is like static data for the environment - it's where an agent starts,
-        rather than where it is at the moment.
-        The target should also be stored here.
-    """
-    def __init__(self, rcPos, iDir, rcTarget):
-        self.rcPos = rcPos
-        self.iDir = iDir
-        self.rcTarget = rcTarget
-
-
-class EnvAgent(object):
-    """ TODO: EnvAgent - replace separate agent lists with a single list
-        of agent objects.  The EnvAgent represent's the environment's view
-        of the dynamic agent state.  So target is not part of it - target is
-        static.
-    """
-    def __init__(self, rcPos, iDir):
-        self.rcPos = rcPos
-        self.iDir = iDir
-
-
 class RailEnv(Environment):
     """
     RailEnv environment class.
@@ -123,6 +101,7 @@ class RailEnv(Environment):
         # self.agents_position = []
         # self.agents_target = []
         # self.agents_direction = []
+        self.agents = []
         self.num_resets = 0
         self.reset()
         self.num_resets = 0
@@ -137,14 +116,21 @@
         TODO: replace_agents is ignored at the moment; agents will always be replaced.
         """
         if regen_rail or self.rail is None:
-            self.rail, self.agents_position, self.agents_direction, self.agents_target = self.rail_generator(
+            self.rail, agents_position, agents_direction, agents_target = self.rail_generator(
                 self.width,
                 self.height,
                 self.agents_handles,
                 self.num_resets)
 
+            # regenerate the static agents whenever the rail is regenerated
+            self.agents_static = EnvAgentStatic.from_lists(agents_position, agents_direction, agents_target)
+
+        if replace_agents:
+            self.agents = EnvAgent.list_from_static(self.agents_static[:len(self.agents_handles)])
+
         self.num_resets += 1
 
+        # perhaps dones should be part of each agent.
         self.dones = {"__all__": False}
         for handle in self.agents_handles:
             self.dones[handle] = False
@@ -174,11 +160,12 @@
         for i in range(len(self.agents_handles)):
             handle = self.agents_handles[i]
             transition_isValid = None
+            agent = self.agents[i]
 
-            if handle not in action_dict:
+            if handle not in action_dict:  # no action has been supplied for this agent
                 continue
 
-            if self.dones[handle]:
+            if self.dones[handle]:  # this agent has already completed...
                 continue
             action = action_dict[handle]
 
@@ -188,31 +175,28 @@
                 return
 
             if action > 0:
-                pos = self.agents_position[i]
-                direction = self.agents_direction[i]
+                # pos = agent.position #  self.agents_position[i]
+                # direction = agent.direction # self.agents_direction[i]
 
                 # compute number of possible transitions in the current
                 # cell used to check for invalid actions
 
-                possible_transitions = self.rail.get_transitions((pos[0], pos[1], direction))
+                possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
                 num_transitions = np.count_nonzero(possible_transitions)
 
-                movement = direction
+                movement = agent.direction
                 # print(nbits,np.sum(possible_transitions))
                 if action == 1:
-                    movement = direction - 1
+                    movement = agent.direction - 1
                     if num_transitions <= 1:
                         transition_isValid = False
 
                 elif action == 3:
-                    movement = direction + 1
+                    movement = agent.direction + 1
                     if num_transitions <= 1:
                         transition_isValid = False
 
-                if movement < 0:
-                    movement += 4
-                if movement >= 4:
-                    movement -= 4
+                movement %= 4
 
                 if action == 2:
                     if num_transitions == 1:
@@ -222,57 +206,76 @@
                         movement = np.argmax(possible_transitions)
                         transition_isValid = True
 
-                new_position = get_new_position(pos, movement)
-                # Is it a legal move?  1) transition allows the movement in the
-                # cell,  2) the new cell is not empty (case 0),  3) the cell is
-                # free, i.e., no agent is currently in that cell
-                if (
-                        new_position[1] >= self.width or
-                        new_position[0] >= self.height or
-                        new_position[0] < 0 or new_position[1] < 0):
-                    new_cell_isValid = False
-
-                elif self.rail.get_transitions((new_position[0], new_position[1])) > 0:
-                    new_cell_isValid = True
-                else:
-                    new_cell_isValid = False
+                new_position = get_new_position(agent.position, movement)
+                # Is it a legal move?
+                # 1) transition allows the movement in the cell,
+                # 2) the new cell is not empty (case 0),
+                # 3) the cell is free, i.e., no agent is currently in that cell
+
+                # if (
+                #        new_position[1] >= self.width or
+                #        new_position[0] >= self.height or
+                #        new_position[0] < 0 or new_position[1] < 0):
+                #    new_cell_isValid = False
+
+                # if self.rail.get_transitions(new_position) == 0:
+                #     new_cell_isValid = False
+
+                new_cell_isValid = (
+                        np.array_equal(  # Check the new position is still in the grid
+                            new_position,
+                            np.clip(new_position, [0, 0], [self.height-1, self.width-1]))
+                        and  # check the new position has some transitions (ie is not an empty cell)
+                        self.rail.get_transitions(new_position) > 0)
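+                # e.g. with height=4 and width=5, new_position (4, 2) clips to (3, 2),
+                # which no longer equals new_position, so the off-grid move is rejected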
 
                 # If transition validity hasn't been checked yet.
                 if transition_isValid is None:
                     transition_isValid = self.rail.get_transition(
-                        (pos[0], pos[1], direction),
+                        (*agent.position, agent.direction),
                         movement)
 
-                cell_isFree = True
-                for j in range(self.number_of_agents):
-                    if self.agents_position[j] == new_position:
-                        cell_isFree = False
-                        break
-
-                if new_cell_isValid and transition_isValid and cell_isFree:
+                # cell_isFree = True
+                # for j in range(self.number_of_agents):
+                #    if self.agents_position[j] == new_position:
+                #        cell_isFree = False
+                #        break
+                # Check the new position is not the same as any of the existing agent positions
+                # (including itself, for simplicity, since it is moving)
+                cell_isFree = not np.any(
+                        np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
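+                # (np.equal broadcasts the length-2 new_position against the
+                # (n_agents, 2) array of agent positions; .all(1) requires both coords to match)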
+
+                if all([new_cell_isValid, transition_isValid, cell_isFree]):
                     # move and change direction to face the movement that was
                     # performed
-                    self.agents_position[i] = new_position
-                    self.agents_direction[i] = movement
+                    # self.agents_position[i] = new_position
+                    # self.agents_direction[i] = movement
+                    agent.position = new_position
+                    agent.direction = movement
                 else:
                     # the action was not valid, add penalty
                     self.rewards_dict[handle] += invalid_action_penalty
 
             # if agent is not in target position, add step penalty
-            if self.agents_position[i][0] == self.agents_target[i][0] and \
-                    self.agents_position[i][1] == self.agents_target[i][1]:
+            # if self.agents_position[i][0] == self.agents_target[i][0] and \
+            #        self.agents_position[i][1] == self.agents_target[i][1]:
+            #    self.dones[handle] = True
+            if np.equal(agent.position, agent.target).all():
                 self.dones[handle] = True
             else:
                 self.rewards_dict[handle] += step_penalty
 
         # Check for end of episode + add global reward to all rewards!
-        num_agents_in_target_position = 0
-        for i in range(self.number_of_agents):
-            if self.agents_position[i][0] == self.agents_target[i][0] and \
-                    self.agents_position[i][1] == self.agents_target[i][1]:
-                num_agents_in_target_position += 1
-
-        if num_agents_in_target_position == self.number_of_agents:
+        # num_agents_in_target_position = 0
+        # for i in range(self.number_of_agents):
+        #    if self.agents_position[i][0] == self.agents_target[i][0] and \
+        #            self.agents_position[i][1] == self.agents_target[i][1]:
+        #        num_agents_in_target_position += 1
+        # if num_agents_in_target_position == self.number_of_agents:
+        if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
             self.dones["__all__"] = True
             self.rewards_dict = [r + global_reward for r in self.rewards_dict]
 
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 1f731c39f024f4cf8d10a5ad70171ba0b60b260d..c2fcb73b06fb2a2470187c10c90bee4ffc468148 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -158,20 +158,9 @@ class RenderTool(object):
 
     def plotAgents(self, targets=True):
         cmap = self.gl.get_cmap('hsv', lut=self.env.number_of_agents + 1)
-        for iAgent in range(self.env.number_of_agents):
+        for iAgent, agent in enumerate(self.env.agents):
             oColor = cmap(iAgent)
-
-            rcPos = self.env.agents_position[iAgent]
-            iDir = self.env.agents_direction[iAgent]  # agent direction index
-
-            if targets:
-                target = self.env.agents_target[iAgent]
-            else:
-                target = None
-            self.plotAgent(rcPos, iDir, oColor, target=target)
-
-            # gTransRCAg = self.getTransRC(rcPos, iDir)
-            # self.plotTrans(rcPos, gTransRCAg)
+            self.plotAgent(agent.position, agent.direction, oColor, target=agent.target if targets else None)
 
     def getTransRC(self, rcPos, iDir, bgiTrans=False):
         """
@@ -554,7 +543,7 @@ class RenderTool(object):
 
                 if not bCellValid:
                     # print("invalid:", r, c)
-                    self.gl.scatter(*xyCentre, color="r", s=50)
+                    self.gl.scatter(*xyCentre, color="r", s=30)
 
                 for orientation in range(4):  # ori is where we're heading
                     from_ori = (orientation + 2) % 4  # 0123=NESW -> 2301=SWNE
diff --git a/tests/test_environments.py b/tests/test_environments.py
index a10fb0619eae4d27867c9008c27618fe059d52d2..fe788b7c72fbcab358ac2120b0069c7cf64b1801 100644
--- a/tests/test_environments.py
+++ b/tests/test_environments.py
@@ -7,7 +7,7 @@ from flatland.envs.generators import rail_from_GridTransitionMap_generator
 from flatland.core.transitions import Grid4Transitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.core.env_observation_builder import GlobalObsForRailEnv
-
+from flatland.envs.agent_utils import EnvAgent
 
 """Tests for `flatland` package."""
 
@@ -58,16 +58,21 @@ def test_rail_environment_single_agent():
         _ = rail_env.reset()
 
         # We do not care about target for the moment
-        rail_env.agents_target[0] = [-1, -1]
+        # rail_env.agents_target[0] = [-1, -1]
+        agent = rail_env.agents[0]
+        # rail_env.agents[0].target = [-1, -1]
+        agent.target = [-1, -1]
 
         # Check that trains are always initialized at a consistent position
         # or direction.
         # They should always be able to go somewhere.
         assert(transitions.get_transitions(
-            rail_map[rail_env.agents_position[0]],
-            rail_env.agents_direction[0]) != (0, 0, 0, 0))
+            # rail_map[rail_env.agents_position[0]],
+            # rail_env.agents_direction[0]) != (0, 0, 0, 0))
+            rail_map[agent.position],
+            agent.direction) != (0, 0, 0, 0))
 
-        initial_pos = rail_env.agents_position[0]
+        initial_pos = agent.position
 
         valid_active_actions_done = 0
         pos = initial_pos
@@ -78,13 +83,13 @@ def test_rail_environment_single_agent():
             _, _, _, _ = rail_env.step({0: action})
 
             prev_pos = pos
-            pos = rail_env.agents_position[0]
+            pos = agent.position  # rail_env.agents_position[0]
             if prev_pos != pos:
                 valid_active_actions_done += 1
 
         # After 6 movements on this railway network, the train should be back
         # to its original height on the map.
-        assert(initial_pos[0] == rail_env.agents_position[0][0])
+        assert(initial_pos[0] == agent.position[0])
 
         # We check that the train always attains its target after some time
         for _ in range(10):
@@ -135,13 +140,14 @@ def test_dead_end():
         # We run step to check that trains do not move anymore
         # after being done.
         for i in range(7):
-            prev_pos = rail_env.agents_position[0]
+            # prev_pos = rail_env.agents_position[0]
+            prev_pos = rail_env.agents[0].position
 
             # The train cannot turn, so we check that when it tries,
             # it stays where it is.
             _ = rail_env.step({0: 1})
             _ = rail_env.step({0: 3})
-            assert (rail_env.agents_position[0] == prev_pos)
+            assert (rail_env.agents[0].position == prev_pos)
             _, _, dones, _ = rail_env.step({0: 2})
 
             if i < 5:
@@ -151,15 +157,17 @@ def test_dead_end():
 
     # We try the configuration in the 4 directions:
     rail_env.reset()
-    rail_env.agents_target[0] = (0, 0)
-    rail_env.agents_position[0] = (0, 2)
-    rail_env.agents_direction[0] = 1
+    # rail_env.agents_target[0] = (0, 0)
+    # rail_env.agents_position[0] = (0, 2)
+    # rail_env.agents_direction[0] = 1
+    rail_env.agents = [EnvAgent(position=(0, 2), direction=1, target=(0, 0))]
     check_consistency(rail_env)
 
     rail_env.reset()
-    rail_env.agents_target[0] = (0, 4)
-    rail_env.agents_position[0] = (0, 2)
-    rail_env.agents_direction[0] = 3
+    # rail_env.agents_target[0] = (0, 4)
+    # rail_env.agents_position[0] = (0, 2)
+    # rail_env.agents_direction[0] = 3
+    rail_env.agents = [EnvAgent(position=(0, 2), direction=3, target=(0, 4))]
     check_consistency(rail_env)
 
     # In the vertical configuration:
@@ -181,13 +189,15 @@ def test_dead_end():
                        obs_builder_object=GlobalObsForRailEnv())
 
     rail_env.reset()
-    rail_env.agents_target[0] = (0, 0)
-    rail_env.agents_position[0] = (2, 0)
-    rail_env.agents_direction[0] = 2
+    # rail_env.agents_target[0] = (0, 0)
+    # rail_env.agents_position[0] = (2, 0)
+    # rail_env.agents_direction[0] = 2
+    rail_env.agents = [EnvAgent(position=(2, 0), direction=2, target=(0, 0))]
     check_consistency(rail_env)
 
     rail_env.reset()
-    rail_env.agents_target[0] = (4, 0)
-    rail_env.agents_position[0] = (2, 0)
-    rail_env.agents_direction[0] = 0
+    # rail_env.agents_target[0] = (4, 0)
+    # rail_env.agents_position[0] = (2, 0)
+    # rail_env.agents_direction[0] = 0
+    rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0))]
     check_consistency(rail_env)