diff --git a/flatland/envs/agent_chains.py b/flatland/envs/agent_chains.py
index a2792479c0877aef2b41a8b62af7d55f5ab46bb0..3e566ad0617a3c49ec69e1049be36231b705916f 100644
--- a/flatland/envs/agent_chains.py
+++ b/flatland/envs/agent_chains.py
@@ -12,6 +12,9 @@ class MotionCheck(object):
     """
     def __init__(self):
         self.G = nx.DiGraph()
+        self.nDeadlocks = 0
+        self.svDeadlocked = set()
+    
 
     def addAgent(self, iAg, rc1, rc2, xlabel=None):
         """ add an agent and its motion as row,col tuples of current and next position.
@@ -60,6 +63,10 @@ class MotionCheck(object):
         return svStops
 
     def find_stop_preds(self, svStops=None):
+        """ Find the predecessors to a list of stopped agents (ie the nodes / vertices)
+            Returns the set of predecessors.
+            Includes "chained" predecessors.
+        """
 
         if svStops is None:
             svStops = self.find_stops2()
@@ -73,9 +80,10 @@ class MotionCheck(object):
 
         for oWCC in lWCC:
             #print("Component:", oWCC)
+            # Get the node details for this WCC in a subgraph
             Gwcc = self.G.subgraph(oWCC)
             
-            # Find all the stops in this chain
+            # Find all the stops in this chain or tree
             svCompStops = svStops.intersection(Gwcc)
             #print(svCompStops)
 
@@ -91,11 +99,14 @@ class MotionCheck(object):
                     lStops = list(iter_stops)
                     svBlocked.update(lStops)
 
+        # the set of all the nodes/agents blocked by this set of stopped nodes
         return svBlocked
 
     def find_swaps(self):
         """ find all the swap conflicts where two agents are trying to exchange places.
             These appear as simple cycles of length 2.
+            These agents are necessarily deadlocked (since they can't change direction in flatland) -
+            meaning they will now be stuck for the rest of the episode.
         """
         #svStops = self.find_stops2()
         llvLoops = list(nx.algorithms.cycles.simple_cycles(self.G))
@@ -109,30 +120,73 @@ class MotionCheck(object):
         """
         pass
 
+    def block_preds(self, svStops, color="red"):
+        """ Take a list of stopped agents, and apply a stop color to any chains/trees
+            of agents trying to head toward those cells.
+            Count the number of agents blocked, ignoring those which are already marked.
+            (Otherwise it can double count swaps)
+
+        """
+        iCount = 0
+        svBlocked = set()
+        # The reversed graph allows us to follow directed edges to find affected agents.
+        Grev = self.G.reverse()
+        for v in svStops:
+            
+            # Use depth-first-search to find a tree of agents heading toward the blocked cell.
+            lvPred = list(nx.traversal.dfs_postorder_nodes(Grev, source=v))
+            svBlocked |= set(lvPred)
+            svBlocked.add(v)
+            #print("node:", v, "set", svBlocked)
+            # only count those not already marked
+            for v2 in [v]+lvPred:
+                if self.G.nodes[v2].get("color") != color:
+                    self.G.nodes[v2]["color"] = color
+                    iCount += 1
+
+        return svBlocked
+
+
     def find_conflicts(self):
-        svStops = self.find_stops2() # { u for u,v in nx.classes.function.selfloop_edges(self.G) }
-        #llvLoops = list(nx.algorithms.cycles.simple_cycles(self.G))
-        #llvSwaps = [lvLoop for lvLoop in llvLoops if len(lvLoop) == 2 ]
-        #svSwaps = { v for lvSwap in llvSwaps for v in lvSwap }
-        svSwaps = self.find_swaps()
-        svBlocked = self.find_stop_preds(svStops.union(svSwaps))
+        svStops = self.find_stops2()  # voluntarily stopped agents - have self-loops
+        svSwaps = self.find_swaps()   # deadlocks - adjacent head-on collisions
+
+        # Block all swaps and their tree of predessors
+        self.svDeadlocked = self.block_preds(svSwaps, color="purple")
 
+        # Take the union of the above, and find all the predecessors
+        #svBlocked = self.find_stop_preds(svStops.union(svSwaps))
+
+        # Just look for the the tree of preds for each voluntarily stopped agent
+        svBlocked = self.find_stop_preds(svStops)
+
+        # iterate the nodes v with their predecessors dPred (dict of nodes->{})
         for (v, dPred) in self.G.pred.items():
-            if v in svSwaps:
-                self.G.nodes[v]["color"] = "purple"
-            elif v in svBlocked:
+            # mark any swaps with purple - these are directly deadlocked
+            #if v in svSwaps:
+            #    self.G.nodes[v]["color"] = "purple"
+            # If they are not directly deadlocked, but are in the union of stopped + deadlocked
+            #elif v in svBlocked:
+
+            # if in blocked, it will not also be in a swap pred tree, so no need to worry about overwriting
+            if v in svBlocked:
                 self.G.nodes[v]["color"] = "red"
+            # not blocked but has two or more predecessors, ie >=2 agents waiting to enter this node
             elif len(dPred)>1:
 
-                if self.G.nodes[v].get("color") == "red":
+                # if this agent is already red/blocked, ignore. CHECK: why?
+                # certainly we want to ignore purple so we don't overwrite with red.
+                if self.G.nodes[v].get("color") in ("red", "purple"):
                     continue
 
+                # if this node has no agent, and >=2 want to enter it.
                 if self.G.nodes[v].get("agent") is None:
                     self.G.nodes[v]["color"] = "blue"
+                # this node has an agent and >=2 want to enter
                 else:
                     self.G.nodes[v]["color"] = "magenta"
 
-                # predecessors of a contended cell
+                # predecessors of a contended cell: {agent index -> node}
                 diAgCell = {self.G.nodes[vPred].get("agent"): vPred  for vPred in dPred}
 
                 # remove the agent with the lowest index, who wins
@@ -140,13 +194,15 @@ class MotionCheck(object):
                 diAgCell.pop(iAgWinner)
 
                 # Block all the remaining predessors, and their tree of preds
-                for iAg, v in diAgCell.items():
-                    self.G.nodes[v]["color"] = "red"
-                    for vPred in nx.traversal.dfs_postorder_nodes(self.G.reverse(), source=v):
-                        self.G.nodes[vPred]["color"] = "red"
+                #for iAg, v in diAgCell.items():
+                #    self.G.nodes[v]["color"] = "red"
+                #    for vPred in nx.traversal.dfs_postorder_nodes(self.G.reverse(), source=v):
+                #        self.G.nodes[vPred]["color"] = "red"
+                self.block_preds(diAgCell.values(), "red")
 
     def check_motion(self, iAgent, rcPos):
-        """ If agent position is None, we use a dummy position of (-1, iAgent)
+        """ Returns tuple of boolean can the agent move, and the cell it will move into.
+            If agent position is None, we use a dummy position of (-1, iAgent)
         """
 
         if rcPos is None:
@@ -168,7 +224,7 @@ class MotionCheck(object):
 
         # This should never happen - only the next cell of an agent has no successor
         if len(dSucc)==0:
-            print(f"error condition - agent {iAg} node {rcPos} has no successor")
+            print(f"error condition - agent {iAgent} node {rcPos} has no successor")
             return (False, rcPos)
 
         # This agent has a successor
@@ -181,6 +237,7 @@ class MotionCheck(object):
 
 
 
+
 def render(omc:MotionCheck, horizontal=True):
     try:
         oAG = nx.drawing.nx_agraph.to_agraph(omc.G)
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index d9ff7120789d5ad58ba4acc15d16586f0e34b827..94d911eaab8dfe545d6960f9cc7068e53fee8e16 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -728,6 +728,8 @@ class RailEnv(Environment):
             self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
 
     def _step_agent_cf(self, i_agent, action: Optional[RailEnvActions] = None):
+        """ "close following" version of step_agent.
+        """
         agent = self.agents[i_agent]
         if agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]:  # this agent has already completed...
             return
@@ -748,6 +750,8 @@ class RailEnv(Environment):
 
         # if agent is broken, actions are ignored and agent does not move.
         # full step penalty in this case
+        # TODO: this means that deadlocked agents which suffer a malfunction are marked as 
+        # stopped rather than deadlocked.
         if agent.malfunction_data['malfunction'] > 0:
             self.motionCheck.addAgent(i_agent, agent.position, agent.position)
             # agent will get penalty in step_agent2_cf
@@ -999,7 +1003,8 @@ class RailEnv(Environment):
             list_agents_state.append([
                     *pos, int(agent.direction), 
                     agent.malfunction_data["malfunction"],  
-                    int(agent.status)
+                    int(agent.status),
+                    int(agent.position in self.motionCheck.svDeadlocked)
                     ])
 
         self.cur_episode.append(list_agents_state)
diff --git a/flatland/utils/env_edit_utils.py b/flatland/utils/env_edit_utils.py
index bf6aa32a6c0d6050acd2c28de7177b4d0bade6bf..98a22b809d668a9062646c7ca3644b0c35c2092f 100644
--- a/flatland/utils/env_edit_utils.py
+++ b/flatland/utils/env_edit_utils.py
@@ -59,7 +59,8 @@ def makeEnv2(nAg=2, shape=(20,10), llrcPaths=[], lrcStarts=[], lrcTargs=[], liDi
                 number_of_agents=nAg,
                 schedule_generator=oSG,
                 obs_builder_object=obs.TreeObsForRailEnv(max_depth=1),
-                close_following=bUCF)
+                close_following=bUCF,
+                record_steps=True)
 
     envModel = editor.EditorModel(env)
     env.reset()
diff --git a/flatland/utils/jupyter_utils.py b/flatland/utils/jupyter_utils.py
index 3b7bc3e0c69310bf2696aa1831219f178cd5b500..f28f07af1142d8e7029aa7553d6c7a5c667f46b3 100644
--- a/flatland/utils/jupyter_utils.py
+++ b/flatland/utils/jupyter_utils.py
@@ -29,6 +29,16 @@ class AlwaysForward(Behaviour):
     def getActions(self):
         return { i:RailEnvActions.MOVE_FORWARD for i in range(self.nAg) }
 
+class DelayedStartForward(AlwaysForward):
+    def __init__(self, env, nStartDelay=2):
+        self.nStartDelay = nStartDelay
+        super().__init__(env)
+
+    def getActions(self):
+        iStep = self.env._elapsed_steps + 1
+        nAgentsMoving = min(self.nAg, iStep // self.nStartDelay)
+        return { i:RailEnvActions.MOVE_FORWARD for i in range(nAgentsMoving) }
+
 AgentPause = NamedTuple("AgentPause", 
     [
         ("iAg", int),
diff --git a/notebooks/Agent-Close-Following.ipynb b/notebooks/Agent-Close-Following.ipynb
index e0769089c2b2cfb774fce6d09524d2e7295c22ba..3fb4ed51644e01c7f971e13b88839625de25a46a 100644
--- a/notebooks/Agent-Close-Following.ipynb
+++ b/notebooks/Agent-Close-Following.ipynb
@@ -73,7 +73,9 @@
     "import networkx as nx\n",
     "import PIL\n",
     "from IPython import display\n",
-    "import time"
+    "import time\n",
+    "from matplotlib import pyplot as plt\n",
+    "import numpy as np"
    ]
   },
   {
@@ -147,6 +149,16 @@
     "gvDot"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#for v, dPred in omc.G.pred.items():\n",
+    "#    print (v, dPred)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -257,7 +269,87 @@
    "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "env.motionCheck.svDeadlocked"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Deadlocking agents"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "env, envModel = eeu.makeTestEnv(\"loop_with_loops\", nAg=10, bUCF=True)\n",
+    "oEC = ju.EnvCanvas(env, behaviour=ju.DelayedStartForward(env, nStartDelay=1))\n",
+    "oEC.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for i in range(25):\n",
+    "    oEC.step()\n",
+    "    oEC.render()\n",
+    "    \n",
+    "    #display.display_html(f\"<br>Step: {i}\\n\", raw=True)\n",
+    "    #display.display_svg(ac.render(env.motionCheck, horizontal=(i>=3)))\n",
+    "    time.sleep(0.1)   "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "env.motionCheck.svDeadlocked"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "g3Ep = np.array(env.cur_episode)\n",
+    "g3Ep.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "nSteps = g3Ep.shape[0]\n",
+    "plt.step(range(nSteps), np.sum(g3Ep[:,:,5], axis=1))\n",
+    "plt.title(\"Deadlocked agents\")\n",
+    "plt.xticks(range(g3Ep.shape[0]))\n",
+    "plt.yticks(range(11))\n",
+    "plt.grid()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "gnDeadlockExpected = np.array([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 3,  5,  7,  9, 10, 10, 10, 10])\n",
+    "gnDeadlock = np.sum(g3Ep[:,:,5], axis=1)\n",
+    "\n",
+    "assert np.all(gnDeadlock == gnDeadlockExpected), \"Deadlocks by step do not match expected values!\""
+   ]
   }
  ],
  "metadata": {