From ed76c8ecb7afd02437f6abbfcf4563de5cca5d91 Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Tue, 28 Jul 2020 13:35:50 +0100
Subject: [PATCH] close_following seems to be working. Increased window size.
 Set old_position to None for a done agent so it disappears from the rendered
 image.

---
 flatland/envs/agent_chains.py  | 104 ++++++++++++-
 flatland/envs/rail_env.py      | 275 ++++++++++++++++++++++++++++++---
 flatland/utils/graphics_pgl.py |   2 +-
 3 files changed, 355 insertions(+), 26 deletions(-)

diff --git a/flatland/envs/agent_chains.py b/flatland/envs/agent_chains.py
index 7d7f0c16..ac3f9135 100644
--- a/flatland/envs/agent_chains.py
+++ b/flatland/envs/agent_chains.py
@@ -18,6 +18,15 @@ class MotionCheck(object):
             The agent's current position is given an "agent" attribute recording the agent index.
             If an agent does not move this round then its cell is 
         """
+
+        # Agents which have not yet entered the env have position None.
+        # Substitute this for the row = -1, column = agent index
+        if rc1 is None:
+            rc1 = (-1, iAg)
+
+        if rc2 is None:
+            rc2 = (-1, iAg)
+
         self.G.add_node(rc1, agent=iAg)
         if xlabel:
             self.G.nodes[rc1]["xlabel"] = xlabel
@@ -73,12 +82,12 @@ class MotionCheck(object):
             #print(svCompStops)
             
             if len(svCompStops) > 0:
-                print("component contains a stop")
+                #print("component contains a stop")
                 for vStop in svCompStops:
                     
                     iter_stops = nx.algorithms.traversal.dfs_postorder_nodes(Gwcc.reverse(), vStop)
                     lStops = list(iter_stops)
-                    print(vStop, "affected preds:", lStops)
+                    #print(vStop, "affected preds:", lStops)
                     svBlocked.update(lStops)
         
         return svBlocked
@@ -90,15 +99,87 @@ class MotionCheck(object):
         #svStops = self.find_stops2()
         llvLoops = list(nx.algorithms.cycles.simple_cycles(self.G))
         llvSwaps = [lvLoop for lvLoop in llvLoops if len(lvLoop) == 2 ]
-        return llvSwaps
+        svSwaps = { v for lvSwap in llvSwaps for v in lvSwap }
+        return svSwaps
     
     def find_same_dest(self):
         """ find groups of agents which are trying to land on the same cell.
             ie there is a gap of one cell between them and they are both landing on it.
         """
+        pass
+
+    def find_conflicts(self):
+        svStops = self.find_stops2() # { u for u,v in nx.classes.function.selfloop_edges(self.G) }
+        #llvLoops = list(nx.algorithms.cycles.simple_cycles(self.G))
+        #llvSwaps = [lvLoop for lvLoop in llvLoops if len(lvLoop) == 2 ]
+        #svSwaps = { v for lvSwap in llvSwaps for v in lvSwap }
+        svSwaps = self.find_swaps()
+        svBlocked = self.find_stop_preds(svStops.union(svSwaps))
+
+        for (v, dPred) in self.G.pred.items():
+            if v in svSwaps:
+                self.G.nodes[v]["color"] = "purple"
+            elif v in svBlocked:
+                self.G.nodes[v]["color"] = "red"
+            elif len(dPred)>1:
+                
+                if self.G.nodes[v].get("color") == "red":
+                    continue
+                
+                if self.G.nodes[v].get("agent") is None:
+                    self.G.nodes[v]["color"] = "blue"    
+                else:
+                    self.G.nodes[v]["color"] = "magenta"
+                
+                # predecessors of a contended cell
+                diAgCell = {self.G.nodes[vPred].get("agent"): vPred  for vPred in dPred}
+                
+                # remove the agent with the lowest index, who wins
+                iAgWinner = min(diAgCell)
+                diAgCell.pop(iAgWinner)
+                
+                # Block all the remaining predecessors, and their tree of preds
+                for iAg, v in diAgCell.items():
+                    self.G.nodes[v]["color"] = "red"
+                    for vPred in nx.traversal.dfs_postorder_nodes(self.G.reverse(), source=v):
+                        self.G.nodes[vPred]["color"] = "red"
+
+    def check_motion(self, iAgent, rcPos):
+        """ If agent position is None, we use a dummy position of (-1, iAgent)
+        """
+
+        if rcPos is None:
+            rcPos = (-1, iAgent)
+
+        dAttr = self.G.nodes.get(rcPos)
+        #print("pos:", rcPos, "dAttr:", dAttr)
+
+        if dAttr is None:
+            dAttr = {}
+
+        # If it's been marked red or purple then it can't move
+        if "color" in dAttr:
+            sColor = dAttr["color"]
+            if sColor in [ "red", "purple" ]:
+                return (False, rcPos)
         
+        dSucc = self.G.succ[rcPos]
+
+        # This should never happen - only the next cell of an agent has no successor
+        if len(dSucc)==0:
+            print(f"error condition - agent {iAgent} node {rcPos} has no successor")
+            return (False, rcPos)
+
+        # This agent has a successor
+        rcNext = self.G.successors(rcPos).__next__()
+        if rcNext == rcPos:  # the agent didn't want to move
+            return (False, rcNext)
+        # The agent wanted to move, and it can
+        return (True, rcNext)
 
 
+            
+
 def render(omc:MotionCheck):
     oAG = nx.drawing.nx_agraph.to_agraph(omc.G)
     oAG.layout("dot")
@@ -228,6 +309,7 @@ def create_test_agents2(omc:MotionCheck):
     cte.addAgentToRow(6, 5)
     cte.addAgentToRow(7, 6)
 
+
     cte.nextRow()
     cte.addAgentToRow(1, 2, "3-way\nsame")
     cte.addAgentToRow(3, 2)
@@ -251,6 +333,19 @@ def create_test_agents2(omc:MotionCheck):
     cte.nextRow()
     
 
+    cte.nextRow()
+    cte.addAgentToRow(1, 2, "Tree")
+    cte.addAgentToRow(2, 3)
+    cte.addAgentToRow(3, 4)
+    r1 = cte.iRowNext
+    r2 = cte.iRowNext+1
+    r3 = cte.iRowNext+2
+    cte.addAgent((r2, 3), (r1, 3))
+    cte.addAgent((r2, 2), (r2, 3))
+    cte.addAgent((r3, 2), (r2, 3))
+
+    cte.nextRow()
+
 
 def test_agent_following():
     omc = MotionCheck()
@@ -288,4 +383,5 @@ def main():
     test_agent_following()
 
 if __name__=="__main__":
-    main()
\ No newline at end of file
+    main()
+    
\ No newline at end of file
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 03678dec..1fa69f8a 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -26,6 +26,7 @@ from flatland.envs import malfunction_generators as mal_gen
 from flatland.envs import rail_generators as rail_gen
 from flatland.envs import schedule_generators as sched_gen
 from flatland.envs import persistence
+from flatland.envs import agent_chains as ac
 
 # Direct import of objects / classes does not work with circular imports.
 # from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData
@@ -35,7 +36,7 @@ from flatland.envs import persistence
 
 from flatland.envs.observations import GlobalObsForRailEnv
 
-
+import debugpy  # NOTE(review): debug-only dependency — remove or guard before merge
 
 import pickle
 
@@ -244,6 +245,10 @@ class RailEnv(Environment):
         self.cur_episode = []  
         self.list_actions = [] # save actions in here
 
+        self.close_following = True  # use close following logic
+        self.motionCheck = ac.MotionCheck()
+
+
     def _seed(self, seed=None):
         self.np_random, seed = seeding.np_random(seed)
         random.seed(seed)
@@ -498,27 +503,66 @@ class RailEnv(Environment):
         }
         have_all_agents_ended = True  # boolean flag to check if all agents are done
 
-        for i_agent, agent in enumerate(self.agents):
-            # Reset the step rewards
-            self.rewards_dict[i_agent] = 0
+        self.motionCheck = ac.MotionCheck()  # reset the motion check
 
-            # Induce malfunction before we do a step, thus a broken agent can't move in this step
-            self._break_agent(agent)
+        if not self.close_following:
+            for i_agent, agent in enumerate(self.agents):
+                # Reset the step rewards
+                self.rewards_dict[i_agent] = 0
 
-            # Perform step on the agent
-            self._step_agent(i_agent, action_dict_.get(i_agent))
+                # Induce malfunction before we do a step, thus a broken agent can't move in this step
+                self._break_agent(agent)
 
-            # manage the boolean flag to check if all agents are indeed done (or done_removed)
-            have_all_agents_ended &= (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED])
+                # Perform step on the agent
+                self._step_agent(i_agent, action_dict_.get(i_agent))
 
-            # Build info dict
-            info_dict["action_required"][i_agent] = self.action_required(agent)
-            info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction']
-            info_dict["speed"][i_agent] = agent.speed_data['speed']
-            info_dict["status"][i_agent] = agent.status
+                # manage the boolean flag to check if all agents are indeed done (or done_removed)
+                have_all_agents_ended &= (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED])
+
+                # Build info dict
+                info_dict["action_required"][i_agent] = self.action_required(agent)
+                info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction']
+                info_dict["speed"][i_agent] = agent.speed_data['speed']
+                info_dict["status"][i_agent] = agent.status
+
+                # Fix agents that finished their malfunction such that they can perform an action in the next step
+                self._fix_agent_after_malfunction(agent)
+
+
+        else:
+            for i_agent, agent in enumerate(self.agents):
+                # Reset the step rewards
+                self.rewards_dict[i_agent] = 0
+
+                # Induce malfunction before we do a step, thus a broken agent can't move in this step
+                self._break_agent(agent)
+
+                # Perform step on the agent
+                self._step_agent_cf(i_agent, action_dict_.get(i_agent))
+
+        
+            # second loop: check for collisions / conflicts
+            self.motionCheck.find_conflicts()
+
+
+            # third loop: update positions
+            for i_agent, agent in enumerate(self.agents):
+                self._step_agent2_cf(i_agent)
+                
+                # manage the boolean flag to check if all agents are indeed done (or done_removed)
+                have_all_agents_ended &= (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED])
+
+                # Build info dict
+                info_dict["action_required"][i_agent] = self.action_required(agent)
+                info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction']
+                info_dict["speed"][i_agent] = agent.speed_data['speed']
+                info_dict["status"][i_agent] = agent.status
+
+                # Fix agents that finished their malfunction such that they can perform an action in the next step
+                self._fix_agent_after_malfunction(agent)
+        
+        
 
-            # Fix agents that finished their malfunction such that they can perform an action in the next step
-            self._fix_agent_after_malfunction(agent)
 
         # Check for end of episode + set global reward to all rewards!
         if have_all_agents_ended:
@@ -533,6 +577,8 @@ class RailEnv(Environment):
 
         return self._get_observations(), self.rewards_dict, self.dones, info_dict
 
+
+
     def _step_agent(self, i_agent, action: Optional[RailEnvActions] = None):
         """
         Performs a step and step, start and stop penalty on a single agent in the following sub steps:
@@ -552,12 +598,15 @@ class RailEnv(Environment):
 
         # agent gets active by a MOVE_* action and if c
         if agent.status == RailAgentStatus.READY_TO_DEPART:
+            initial_cell_free = self.cell_free(agent.initial_position)
+            is_action_starting = action in [
+                RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_FORWARD]
+
             if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT,
-                          RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position):
+                        RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position):
                 agent.status = RailAgentStatus.ACTIVE
                 self._set_agent_to_initial_position(agent, agent.initial_position)
                 self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
-                return
             else:
                 # TODO: Here we need to check for the departure time in future releases with full schedules
                 self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
@@ -604,6 +653,7 @@ class RailEnv(Environment):
 
             # Store the action if action is moving
             # If not moving, the action will be stored when the agent starts moving again.
+            new_position = None
             if agent.moving:
                 _action_stored = False
                 _, new_cell_valid, new_direction, new_position, transition_valid = \
@@ -640,10 +690,11 @@ class RailEnv(Environment):
                 # Perform stored action to transition to the next cell as soon as cell is free
                 # Notice that we've already checked new_cell_valid and transition valid when we stored the action,
                 # so we only have to check cell_free now!
-
+                
+                # Traditional check that next cell is free
                 # cell and transition validity was checked when we stored transition_action_on_cellexit!
                 cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
-                    agent.speed_data['transition_action_on_cellexit'], agent)
+                        agent.speed_data['transition_action_on_cellexit'], agent)
 
                 # N.B. validity of new_cell and transition should have been verified before the action was stored!
                 assert new_cell_valid
@@ -652,7 +703,184 @@ class RailEnv(Environment):
                     self._move_agent_to_new_position(agent, new_position)
                     agent.direction = new_direction
                     agent.speed_data['position_fraction'] = 0.0
+                
+
+            # has the agent reached its target?
+            if np.equal(agent.position, agent.target).all():
+                agent.status = RailAgentStatus.DONE
+                self.dones[i_agent] = True
+                self.active_agents.remove(i_agent)
+                agent.moving = False
+                self._remove_agent_from_scene(agent)
+            else:
+                self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
+        else:
+            # step penalty if not moving (stopped now or before)
+            self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
+
+    def _step_agent_cf(self, i_agent, action: Optional[RailEnvActions] = None):
+        agent = self.agents[i_agent]
+        if agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]:  # this agent has already completed...
+            return
+
+        # agent gets active by a MOVE_* action and if c
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            is_action_starting = action in [
+                RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_FORWARD]
+
+            if is_action_starting:  # agent is trying to start
+                self.motionCheck.addAgent(i_agent, None, agent.initial_position)
+            else:  # agent wants to remain unstarted
+                self.motionCheck.addAgent(i_agent, None, None)
+            return
+
+
+        agent.old_direction = agent.direction
+        agent.old_position = agent.position
+
+        # if agent is broken, actions are ignored and agent does not move.
+        # full step penalty in this case
+        if agent.malfunction_data['malfunction'] > 0:
+            self.motionCheck.addAgent(i_agent, agent.position, agent.position)
+            self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
+            return
+
+        # Is the agent at the beginning of the cell? Then, it can take an action.
+        # As long as the agent is malfunctioning or stopped at the beginning of the cell,
+        # different actions may be taken!
+        if np.isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03):
+            # No action has been supplied for this agent -> set DO_NOTHING as default
+            if action is None:
+                action = RailEnvActions.DO_NOTHING
+
+            if action < 0 or action > len(RailEnvActions):
+                print('ERROR: illegal action=', action,
+                      'for agent with index=', i_agent,
+                      '"DO NOTHING" will be executed instead')
+                action = RailEnvActions.DO_NOTHING
+
+            if action == RailEnvActions.DO_NOTHING and agent.moving:
+                # Keep moving
+                action = RailEnvActions.MOVE_FORWARD
+
+            if action == RailEnvActions.STOP_MOVING and agent.moving:
+                # Only allow halting an agent on entering new cells.
+                agent.moving = False
+                self.rewards_dict[i_agent] += self.stop_penalty
+
+            if not agent.moving and not (
+                action == RailEnvActions.DO_NOTHING or
+                action == RailEnvActions.STOP_MOVING):
+                # Allow agent to start with any forward or direction action
+                agent.moving = True
+                self.rewards_dict[i_agent] += self.start_penalty
 
+            # Store the action if action is moving
+            # If not moving, the action will be stored when the agent starts moving again.
+            new_position = None
+            if agent.moving:
+                _action_stored = False
+                _, new_cell_valid, new_direction, new_position, transition_valid = \
+                    self._check_action_on_agent(action, agent)
+
+                if all([new_cell_valid, transition_valid]):
+                    agent.speed_data['transition_action_on_cellexit'] = action
+                    _action_stored = True
+                else:
+                    # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
+                    # try to keep moving forward!
+                    if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT):
+                        _, new_cell_valid, new_direction, new_position, transition_valid = \
+                            self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)
+
+                        if all([new_cell_valid, transition_valid]):
+                            agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
+                            _action_stored = True
+
+                if not _action_stored:
+                    # If the agent cannot move due to an invalid transition, we set its state to not moving
+                    self.rewards_dict[i_agent] += self.invalid_action_penalty
+                    self.rewards_dict[i_agent] += self.stop_penalty
+                    agent.moving = False
+                    self.motionCheck.addAgent(i_agent, agent.position, agent.position)    
+                    return
+
+            if new_position is None:
+                self.motionCheck.addAgent(i_agent, agent.position, agent.position)
+                if agent.moving:
+                    print("Agent", i_agent, "new_pos none, but moving")
+            
+        # Check the pos_frac position fraction
+        if agent.moving:
+            agent.speed_data['position_fraction'] += agent.speed_data['speed']
+            if agent.speed_data['position_fraction'] > 0.999:
+                stored_action = agent.speed_data["transition_action_on_cellexit"]
+
+                # find the next cell using the stored action
+                _, new_cell_valid, new_direction, new_position, transition_valid = \
+                    self._check_action_on_agent(stored_action, agent)
+
+                # if it's valid, record it as the new position
+                if all([new_cell_valid, transition_valid]):
+                    self.motionCheck.addAgent(i_agent, agent.position, new_position)
+                else:  # if the action wasn't valid then record the agent as stationary
+                    self.motionCheck.addAgent(i_agent, agent.position, agent.position)
+            else:  # This agent hasn't yet crossed the cell
+                self.motionCheck.addAgent(i_agent, agent.position, agent.position)
+
+
+
+    def _step_agent2_cf(self, i_agent):
+        agent = self.agents[i_agent]
+
+        if agent.status in [ RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED ]:
+            return
+
+        (move, rc_next) = self.motionCheck.check_motion(i_agent, agent.position)
+
+        if agent.position is not None:
+            sbTrans = format(self.rail.grid[agent.position], "016b")
+            trans_block = sbTrans[agent.direction*4 : agent.direction * 4 + 4]
+            if (trans_block == "0000"):
+                print (i_agent, agent.position, agent.direction, sbTrans, trans_block)
+                debugpy.breakpoint()
+
+        # if agent cannot enter env, then we should have move=False
+        
+        if move:
+            if agent.position is None:  # agent is entering the env
+                print(i_agent, "writing new pos ", rc_next, " into agent position (None)")
+                agent.position = rc_next
+                agent.status = RailAgentStatus.ACTIVE
+                agent.speed_data['position_fraction'] = 0.0
+
+            else:  # normal agent move
+                cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
+                    agent.speed_data['transition_action_on_cellexit'], agent)
+        
+                if not all([transition_valid, new_cell_valid]):
+                    print(f"ERROR: step_agent2 invalid transition ag {i_agent} dir {new_direction} pos {agent.position} next {rc_next}")
+                    debugpy.breakpoint()
+
+                if new_position != rc_next:
+                    print(f"ERROR: agent {i_agent} new_pos {new_position} != rc_next {rc_next}  " + 
+                        f"pos {agent.position} dir {agent.direction} new_dir {new_direction}" +
+                        f"stored action: {agent.speed_data['transition_action_on_cellexit']}")
+                    debugpy.breakpoint()
+
+
+                sbTrans = format(self.rail.grid[agent.position], "016b")
+                trans_block = sbTrans[agent.direction*4 : agent.direction * 4 + 4]
+                if (trans_block == "0000"):
+                    print (i_agent, agent.position, agent.direction, sbTrans, trans_block)
+                    debugpy.breakpoint()
+
+
+
+                agent.position = rc_next
+                agent.direction = new_direction
+                agent.speed_data['position_fraction'] = 0.0                        
+          
             # has the agent reached its target?
             if np.equal(agent.position, agent.target).all():
                 agent.status = RailAgentStatus.DONE
@@ -666,6 +894,10 @@ class RailEnv(Environment):
             # step penalty if not moving (stopped now or before)
             self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
 
+
+
+
+
     def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D):
         """
         Sets the agent to its initial position. Updates the agent object and the position
@@ -705,6 +937,7 @@ class RailEnv(Environment):
         self.agent_positions[agent.position] = -1
         if self.remove_agents_at_target:
             agent.position = None
+            agent.old_position = None
             agent.status = RailAgentStatus.DONE_REMOVED
 
     def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent):
diff --git a/flatland/utils/graphics_pgl.py b/flatland/utils/graphics_pgl.py
index 299459cc..a35776a2 100644
--- a/flatland/utils/graphics_pgl.py
+++ b/flatland/utils/graphics_pgl.py
@@ -21,7 +21,7 @@ class PGLGL(PILSVG):
     def open_window(self):
         print("open_window - pyglet")
         assert self.window_open is False, "Window is already open!"
-        self.window = pgl.window.Window(resizable=True, vsync=False)
+        self.window = pgl.window.Window(resizable=True, vsync=False, width=1200, height=800)
         #self.__class__.window.title("Flatland")
         #self.__class__.window.configure(background='grey')
         self.window_open = True
-- 
GitLab