diff --git a/flatland/envs/agent_chains.py b/flatland/envs/agent_chains.py
index a2792479c0877aef2b41a8b62af7d55f5ab46bb0..3e566ad0617a3c49ec69e1049be36231b705916f 100644
--- a/flatland/envs/agent_chains.py
+++ b/flatland/envs/agent_chains.py
@@ -12,6 +12,9 @@ class MotionCheck(object):
     """
     def __init__(self):
         self.G = nx.DiGraph()
+        self.nDeadlocks = 0
+        self.svDeadlocked = set()
+

     def addAgent(self, iAg, rc1, rc2, xlabel=None):
         """ add an agent and its motion as row,col tuples of current and next position.
@@ -60,6 +63,10 @@ class MotionCheck(object):
         return svStops

     def find_stop_preds(self, svStops=None):
+        """ Find the predecessors to a list of stopped agents (i.e. the nodes / vertices).
+            Returns the set of predecessors.
+            Includes "chained" predecessors.
+        """
         if svStops is None:
             svStops = self.find_stops2()

@@ -73,9 +80,10 @@ class MotionCheck(object):

         for oWCC in lWCC:
             #print("Component:", oWCC)
+            # Get the node details for this WCC in a subgraph
             Gwcc = self.G.subgraph(oWCC)

-            # Find all the stops in this chain
+            # Find all the stops in this chain or tree
             svCompStops = svStops.intersection(Gwcc)
             #print(svCompStops)

@@ -91,11 +99,14 @@ class MotionCheck(object):
                 lStops = list(iter_stops)
                 svBlocked.update(lStops)

+        # the set of all the nodes/agents blocked by this set of stopped nodes
         return svBlocked

     def find_swaps(self):
         """ find all the swap conflicts where two agents are trying to exchange places.
            These appear as simple cycles of length 2.
+            These agents are necessarily deadlocked (since they can't change direction in flatland),
+            meaning they will now be stuck for the rest of the episode.
         """
         #svStops = self.find_stops2()
         llvLoops = list(nx.algorithms.cycles.simple_cycles(self.G))
@@ -109,30 +120,73 @@ class MotionCheck(object):
         """
         pass

+    def block_preds(self, svStops, color="red"):
+        """ Take a list of stopped agents, and apply a stop color to any chains/trees
+            of agents trying to head toward those cells.
+            Count the number of agents blocked, ignoring those which are already marked.
+            (Otherwise it can double-count swaps.)
+
+        """
+        iCount = 0
+        svBlocked = set()
+        # The reversed graph allows us to follow directed edges to find affected agents.
+        Grev = self.G.reverse()
+        for v in svStops:
+
+            # Use depth-first-search to find a tree of agents heading toward the blocked cell.
+            lvPred = list(nx.traversal.dfs_postorder_nodes(Grev, source=v))
+            svBlocked |= set(lvPred)
+            svBlocked.add(v)
+            #print("node:", v, "set", svBlocked)
+            # only count those not already marked
+            for v2 in [v]+lvPred:
+                if self.G.nodes[v2].get("color") != color:
+                    self.G.nodes[v2]["color"] = color
+                    iCount += 1
+
+        return svBlocked
+
+
     def find_conflicts(self):
-        svStops = self.find_stops2()  # { u for u,v in nx.classes.function.selfloop_edges(self.G) }
-        #llvLoops = list(nx.algorithms.cycles.simple_cycles(self.G))
-        #llvSwaps = [lvLoop for lvLoop in llvLoops if len(lvLoop) == 2 ]
-        #svSwaps = { v for lvSwap in llvSwaps for v in lvSwap }
-        svSwaps = self.find_swaps()
-        svBlocked = self.find_stop_preds(svStops.union(svSwaps))
+        svStops = self.find_stops2()  # voluntarily stopped agents - have self-loops
+        svSwaps = self.find_swaps()  # deadlocks - adjacent head-on collisions
+
+        # Block all swaps and their tree of predecessors
+        self.svDeadlocked = self.block_preds(svSwaps, color="purple")
+        # Take the union of the above, and find all the predecessors
+        #svBlocked = self.find_stop_preds(svStops.union(svSwaps))
+
+        # Just look for the tree of preds for each voluntarily stopped agent
+        svBlocked = self.find_stop_preds(svStops)
+
+        # iterate the nodes v with their predecessors dPred (dict of nodes->{})
         for (v, dPred) in self.G.pred.items():
-            if v in svSwaps:
-                self.G.nodes[v]["color"] = "purple"
-            elif v in svBlocked:
+            # mark any swaps with purple - these are directly deadlocked
+            #if v in svSwaps:
+            #    self.G.nodes[v]["color"] = "purple"
+            # If they are not directly deadlocked, but are in the union of stopped + deadlocked
+            #elif v in svBlocked:
+
+            # if in blocked, it will not also be in a swap pred tree, so no need to worry about overwriting
+            if v in svBlocked:
                 self.G.nodes[v]["color"] = "red"
+            # not blocked but has two or more predecessors, i.e. >=2 agents waiting to enter this node
             elif len(dPred)>1:
-                if self.G.nodes[v].get("color") == "red":
+                # if this agent is already red/blocked, ignore. CHECK: why?
+                # certainly we want to ignore purple so we don't overwrite it with red.
+                if self.G.nodes[v].get("color") in ("red", "purple"):
                     continue

+                # if this node has no agent, and >=2 want to enter it.
                 if self.G.nodes[v].get("agent") is None:
                     self.G.nodes[v]["color"] = "blue"
+                # this node has an agent and >=2 want to enter
                 else:
                     self.G.nodes[v]["color"] = "magenta"

-                # predecessors of a contended cell
+                # predecessors of a contended cell: {agent index -> node}
                 diAgCell = {self.G.nodes[vPred].get("agent"): vPred for vPred in dPred}

                 # remove the agent with the lowest index, who wins
@@ -140,13 +194,15 @@ class MotionCheck(object):
                 diAgCell.pop(iAgWinner)

                 # Block all the remaining predessors, and their tree of preds
-                for iAg, v in diAgCell.items():
-                    self.G.nodes[v]["color"] = "red"
-                    for vPred in nx.traversal.dfs_postorder_nodes(self.G.reverse(), source=v):
-                        self.G.nodes[vPred]["color"] = "red"
+                #for iAg, v in diAgCell.items():
+                #    self.G.nodes[v]["color"] = "red"
+                #    for vPred in nx.traversal.dfs_postorder_nodes(self.G.reverse(), source=v):
+                #        self.G.nodes[vPred]["color"] = "red"
+                self.block_preds(diAgCell.values(), "red")

     def check_motion(self, iAgent, rcPos):
-        """ If agent position is None, we use a dummy position of (-1, iAgent)
+        """ Returns a tuple of (whether the agent can move, the cell it will move into).
+            If agent position is None, we use a dummy position of (-1, iAgent)
         """

         if rcPos is None:
@@ -168,7 +224,7 @@ class MotionCheck(object):

         # This should never happen - only the next cell of an agent has no successor
         if len(dSucc)==0:
-            print(f"error condition - agent {iAg} node {rcPos} has no successor")
+            print(f"error condition - agent {iAgent} node {rcPos} has no successor")
             return (False, rcPos)

         # This agent has a successor
@@ -181,6 +237,7 @@ class MotionCheck(object):


+
 def render(omc:MotionCheck, horizontal=True):
     try:
         oAG = nx.drawing.nx_agraph.to_agraph(omc.G)
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index d9ff7120789d5ad58ba4acc15d16586f0e34b827..94d911eaab8dfe545d6960f9cc7068e53fee8e16 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -728,6 +728,8 @@ class RailEnv(Environment):
             self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']

     def _step_agent_cf(self, i_agent, action: Optional[RailEnvActions] = None):
+        """ "close following" version of step_agent.
+        """
         agent = self.agents[i_agent]
         if agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]:  # this agent has already completed...
             return
@@ -748,6 +750,8 @@ class RailEnv(Environment):

         # if agent is broken, actions are ignored and agent does not move.
         # full step penalty in this case
+        # TODO: this means that deadlocked agents which suffer a malfunction are marked as
+        # stopped rather than deadlocked.
         if agent.malfunction_data['malfunction'] > 0:
             self.motionCheck.addAgent(i_agent, agent.position, agent.position)
             # agent will get penalty in step_agent2_cf
@@ -999,7 +1003,8 @@ class RailEnv(Environment):
             list_agents_state.append([
                 *pos, int(agent.direction),
                 agent.malfunction_data["malfunction"],
-                int(agent.status)
+                int(agent.status),
+                int(agent.position in self.motionCheck.svDeadlocked)
                 ])

         self.cur_episode.append(list_agents_state)
diff --git a/flatland/utils/env_edit_utils.py b/flatland/utils/env_edit_utils.py
index bf6aa32a6c0d6050acd2c28de7177b4d0bade6bf..98a22b809d668a9062646c7ca3644b0c35c2092f 100644
--- a/flatland/utils/env_edit_utils.py
+++ b/flatland/utils/env_edit_utils.py
@@ -59,7 +59,8 @@ def makeEnv2(nAg=2, shape=(20,10), llrcPaths=[], lrcStarts=[], lrcTargs=[], liDi
         number_of_agents=nAg,
         schedule_generator=oSG,
         obs_builder_object=obs.TreeObsForRailEnv(max_depth=1),
-        close_following=bUCF)
+        close_following=bUCF,
+        record_steps=True)

     envModel = editor.EditorModel(env)
     env.reset()
diff --git a/flatland/utils/jupyter_utils.py b/flatland/utils/jupyter_utils.py
index 3b7bc3e0c69310bf2696aa1831219f178cd5b500..f28f07af1142d8e7029aa7553d6c7a5c667f46b3 100644
--- a/flatland/utils/jupyter_utils.py
+++ b/flatland/utils/jupyter_utils.py
@@ -29,6 +29,16 @@ class AlwaysForward(Behaviour):
     def getActions(self):
         return { i:RailEnvActions.MOVE_FORWARD for i in range(self.nAg) }

+class DelayedStartForward(AlwaysForward):
+    def __init__(self, env, nStartDelay=2):
+        self.nStartDelay = nStartDelay
+        super().__init__(env)
+
+    def getActions(self):
+        iStep = self.env._elapsed_steps + 1
+        nAgentsMoving = min(self.nAg, iStep // self.nStartDelay)
+        return { i:RailEnvActions.MOVE_FORWARD for i in range(nAgentsMoving) }
+
 AgentPause = NamedTuple("AgentPause", [
     ("iAg", int),
diff --git a/notebooks/Agent-Close-Following.ipynb b/notebooks/Agent-Close-Following.ipynb
index e0769089c2b2cfb774fce6d09524d2e7295c22ba..3fb4ed51644e01c7f971e13b88839625de25a46a 100644
--- a/notebooks/Agent-Close-Following.ipynb
+++ b/notebooks/Agent-Close-Following.ipynb
@@ -73,7 +73,9 @@
     "import networkx as nx\n",
     "import PIL\n",
     "from IPython import display\n",
-    "import time"
+    "import time\n",
+    "from matplotlib import pyplot as plt\n",
+    "import numpy as np"
    ]
   },
   {
@@ -147,6 +149,16 @@
     "gvDot"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#for v, dPred in omc.G.pred.items():\n",
+    "#    print (v, dPred)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -257,7 +269,87 @@
    "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "env.motionCheck.svDeadlocked"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Deadlocking agents"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "env, envModel = eeu.makeTestEnv(\"loop_with_loops\", nAg=10, bUCF=True)\n",
+    "oEC = ju.EnvCanvas(env, behaviour=ju.DelayedStartForward(env, nStartDelay=1))\n",
+    "oEC.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for i in range(25):\n",
+    "    oEC.step()\n",
+    "    oEC.render()\n",
+    "    \n",
+    "    #display.display_html(f\"<br>Step: {i}\\n\", raw=True)\n",
+    "    #display.display_svg(ac.render(env.motionCheck, horizontal=(i>=3)))\n",
+    "    time.sleep(0.1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "env.motionCheck.svDeadlocked"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "g3Ep = np.array(env.cur_episode)\n",
+    "g3Ep.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "nSteps = g3Ep.shape[0]\n",
+    "plt.step(range(nSteps), np.sum(g3Ep[:,:,5], axis=1))\n",
+    "plt.title(\"Deadlocked agents\")\n",
+    "plt.xticks(range(g3Ep.shape[0]))\n",
+    "plt.yticks(range(11))\n",
+    "plt.grid()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "gnDeadlockExpected = np.array([ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 5, 7, 9, 10, 10, 10, 10])\n",
+    "gnDeadlock = np.sum(g3Ep[:,:,5], axis=1)\n",
+    "\n",
+    "assert np.all(gnDeadlock == gnDeadlockExpected), \"Deadlocks by step do not match expected values!\""
+   ]
   }
  ],
  "metadata": {
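
A minimal usage sketch of the new deadlock tracking, assuming a flatland build that includes this patch; it uses only the MotionCheck API visible in the diff (addAgent, find_conflicts, check_motion and the new svDeadlocked attribute). The three-agent head-on scenario and its cell coordinates are illustrative and are not taken from the test environments above.

from flatland.envs.agent_chains import MotionCheck

omc = MotionCheck()

# Agents 0 and 1 try to swap cells - the head-on "swap" case that
# find_conflicts() colours purple and records in svDeadlocked.
omc.addAgent(0, (0, 0), (0, 1))
omc.addAgent(1, (0, 1), (0, 0))

# Agent 2 is following agent 0 into the swap, so block_preds() sweeps it
# into the deadlocked tree as well.
omc.addAgent(2, (1, 0), (0, 0))

omc.find_conflicts()

# svDeadlocked holds the cells of deadlocked agents; RailEnv uses it to fill
# the extra per-agent deadlock column recorded when record_steps is enabled.
print("deadlocked cells:", omc.svDeadlocked)

# check_motion() answers, per agent: may it move, and into which cell?
for iAg, rcPos in [(0, (0, 0)), (1, (0, 1)), (2, (1, 0))]:
    print(iAg, omc.check_motion(iAg, rcPos))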