Commit f3e8257b authored by hagrid67's avatar hagrid67

count deadlocks; store deadlock status per agent in saved episode

parent b4ddce5d
Pipeline #5598 passed with stages
in 94 minutes and 31 seconds
......@@ -12,6 +12,9 @@ class MotionCheck(object):
"""
def __init__(self):
    """ Create an empty motion-check.

        Attributes:
            G: directed graph of intended agent moves; nodes are (row, col) cell
               tuples (see addAgent), edges point from current to next cell.
            nDeadlocks: deadlock counter, initialised to zero here.
            svDeadlocked: set of nodes marked as deadlocked in the last check.
    """
    # Start with no deadlocks recorded and an empty move graph.
    self.nDeadlocks = 0
    self.svDeadlocked = set()
    self.G = nx.DiGraph()
def addAgent(self, iAg, rc1, rc2, xlabel=None):
""" add an agent and its motion as row,col tuples of current and next position.
......@@ -60,6 +63,10 @@ class MotionCheck(object):
return svStops
def find_stop_preds(self, svStops=None):
""" Find the predecessors to a list of stopped agents (ie the nodes / vertices)
Returns the set of predecessors.
Includes "chained" predecessors.
"""
if svStops is None:
svStops = self.find_stops2()
......@@ -73,9 +80,10 @@ class MotionCheck(object):
for oWCC in lWCC:
#print("Component:", oWCC)
# Get the node details for this WCC in a subgraph
Gwcc = self.G.subgraph(oWCC)
# Find all the stops in this chain
# Find all the stops in this chain or tree
svCompStops = svStops.intersection(Gwcc)
#print(svCompStops)
......@@ -91,11 +99,14 @@ class MotionCheck(object):
lStops = list(iter_stops)
svBlocked.update(lStops)
# the set of all the nodes/agents blocked by this set of stopped nodes
return svBlocked
def find_swaps(self):
""" find all the swap conflicts where two agents are trying to exchange places.
These appear as simple cycles of length 2.
These agents are necessarily deadlocked (since they can't change direction in flatland) -
meaning they will now be stuck for the rest of the episode.
"""
#svStops = self.find_stops2()
llvLoops = list(nx.algorithms.cycles.simple_cycles(self.G))
......@@ -109,30 +120,73 @@ class MotionCheck(object):
"""
pass
def block_preds(self, svStops, color="red"):
    """ Take a list of stopped agents, and apply a stop color to any chains/trees
        of agents trying to head toward those cells.
        Count the number of agents blocked, ignoring those which are already marked.
        (Otherwise it can double count swaps)

        :param svStops: iterable of stopped nodes (cells) to propagate the block from
        :param color: node color applied to each newly blocked node
        :return: the set of all blocked nodes, including svStops themselves
    """
    iCount = 0
    svBlocked = set()
    # The reversed graph allows us to follow directed edges to find affected agents.
    Grev = self.G.reverse()
    for v in svStops:
        # Use depth-first-search to find a tree of agents heading toward the blocked cell.
        # NOTE: dfs_postorder_nodes includes the source v itself in its output.
        lvPred = list(nx.traversal.dfs_postorder_nodes(Grev, source=v))
        svBlocked |= set(lvPred)
        svBlocked.add(v)
        #print("node:", v, "set", svBlocked)
        # only count those not already marked
        # (v is visited twice below - once from [v], once from lvPred - but the
        # color guard makes the second visit a no-op, so iCount is not inflated)
        for v2 in [v]+lvPred:
            if self.G.nodes[v2].get("color") != color:
                self.G.nodes[v2]["color"] = color
                iCount += 1
    # NOTE(review): iCount is computed but neither returned nor stored - confirm
    # whether it was meant to feed self.nDeadlocks (never incremented in this commit).
    return svBlocked
def find_conflicts(self):
svStops = self.find_stops2() # { u for u,v in nx.classes.function.selfloop_edges(self.G) }
#llvLoops = list(nx.algorithms.cycles.simple_cycles(self.G))
#llvSwaps = [lvLoop for lvLoop in llvLoops if len(lvLoop) == 2 ]
#svSwaps = { v for lvSwap in llvSwaps for v in lvSwap }
svSwaps = self.find_swaps()
svBlocked = self.find_stop_preds(svStops.union(svSwaps))
svStops = self.find_stops2() # voluntarily stopped agents - have self-loops
svSwaps = self.find_swaps() # deadlocks - adjacent head-on collisions
# Block all swaps and their tree of predessors
self.svDeadlocked = self.block_preds(svSwaps, color="purple")
# Take the union of the above, and find all the predecessors
#svBlocked = self.find_stop_preds(svStops.union(svSwaps))
# Just look for the the tree of preds for each voluntarily stopped agent
svBlocked = self.find_stop_preds(svStops)
# iterate the nodes v with their predecessors dPred (dict of nodes->{})
for (v, dPred) in self.G.pred.items():
if v in svSwaps:
self.G.nodes[v]["color"] = "purple"
elif v in svBlocked:
# mark any swaps with purple - these are directly deadlocked
#if v in svSwaps:
# self.G.nodes[v]["color"] = "purple"
# If they are not directly deadlocked, but are in the union of stopped + deadlocked
#elif v in svBlocked:
# if in blocked, it will not also be in a swap pred tree, so no need to worry about overwriting
if v in svBlocked:
self.G.nodes[v]["color"] = "red"
# not blocked but has two or more predecessors, ie >=2 agents waiting to enter this node
elif len(dPred)>1:
if self.G.nodes[v].get("color") == "red":
# if this agent is already red/blocked, ignore. CHECK: why?
# certainly we want to ignore purple so we don't overwrite with red.
if self.G.nodes[v].get("color") in ("red", "purple"):
continue
# if this node has no agent, and >=2 want to enter it.
if self.G.nodes[v].get("agent") is None:
self.G.nodes[v]["color"] = "blue"
# this node has an agent and >=2 want to enter
else:
self.G.nodes[v]["color"] = "magenta"
# predecessors of a contended cell
# predecessors of a contended cell: {agent index -> node}
diAgCell = {self.G.nodes[vPred].get("agent"): vPred for vPred in dPred}
# remove the agent with the lowest index, who wins
......@@ -140,13 +194,15 @@ class MotionCheck(object):
diAgCell.pop(iAgWinner)
# Block all the remaining predessors, and their tree of preds
for iAg, v in diAgCell.items():
self.G.nodes[v]["color"] = "red"
for vPred in nx.traversal.dfs_postorder_nodes(self.G.reverse(), source=v):
self.G.nodes[vPred]["color"] = "red"
#for iAg, v in diAgCell.items():
# self.G.nodes[v]["color"] = "red"
# for vPred in nx.traversal.dfs_postorder_nodes(self.G.reverse(), source=v):
# self.G.nodes[vPred]["color"] = "red"
self.block_preds(diAgCell.values(), "red")
def check_motion(self, iAgent, rcPos):
""" If agent position is None, we use a dummy position of (-1, iAgent)
""" Returns tuple of boolean can the agent move, and the cell it will move into.
If agent position is None, we use a dummy position of (-1, iAgent)
"""
if rcPos is None:
......@@ -168,7 +224,7 @@ class MotionCheck(object):
# This should never happen - only the next cell of an agent has no successor
if len(dSucc)==0:
print(f"error condition - agent {iAg} node {rcPos} has no successor")
print(f"error condition - agent {iAgent} node {rcPos} has no successor")
return (False, rcPos)
# This agent has a successor
......@@ -181,6 +237,7 @@ class MotionCheck(object):
def render(omc:MotionCheck, horizontal=True):
try:
oAG = nx.drawing.nx_agraph.to_agraph(omc.G)
......
......@@ -728,6 +728,8 @@ class RailEnv(Environment):
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
def _step_agent_cf(self, i_agent, action: Optional[RailEnvActions] = None):
""" "close following" version of step_agent.
"""
agent = self.agents[i_agent]
if agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]: # this agent has already completed...
return
......@@ -748,6 +750,8 @@ class RailEnv(Environment):
# if agent is broken, actions are ignored and agent does not move.
# full step penalty in this case
# TODO: this means that deadlocked agents which suffer a malfunction are marked as
# stopped rather than deadlocked.
if agent.malfunction_data['malfunction'] > 0:
self.motionCheck.addAgent(i_agent, agent.position, agent.position)
# agent will get penalty in step_agent2_cf
......@@ -999,7 +1003,8 @@ class RailEnv(Environment):
list_agents_state.append([
*pos, int(agent.direction),
agent.malfunction_data["malfunction"],
int(agent.status)
int(agent.status),
int(agent.position in self.motionCheck.svDeadlocked)
])
self.cur_episode.append(list_agents_state)
......
......@@ -59,7 +59,8 @@ def makeEnv2(nAg=2, shape=(20,10), llrcPaths=[], lrcStarts=[], lrcTargs=[], liDi
number_of_agents=nAg,
schedule_generator=oSG,
obs_builder_object=obs.TreeObsForRailEnv(max_depth=1),
close_following=bUCF)
close_following=bUCF,
record_steps=True)
envModel = editor.EditorModel(env)
env.reset()
......
......@@ -29,6 +29,16 @@ class AlwaysForward(Behaviour):
def getActions(self):
    """ Every agent simply keeps driving: map each agent index to MOVE_FORWARD. """
    # dict.fromkeys gives every agent index the same MOVE_FORWARD action,
    # equivalent to the comprehension { i: MOVE_FORWARD for i in range(nAg) }.
    return dict.fromkeys(range(self.nAg), RailEnvActions.MOVE_FORWARD)
class DelayedStartForward(AlwaysForward):
    """ Like AlwaysForward, but agents are released gradually: one additional
        agent starts moving after every nStartDelay elapsed steps.
    """
    def __init__(self, env, nStartDelay=2):
        # number of steps to wait between successive agent starts
        self.nStartDelay = nStartDelay
        super().__init__(env)

    def getActions(self):
        # 1-based step count so the first agent can move on the very first step
        # when nStartDelay == 1
        nStep = self.env._elapsed_steps + 1
        # integer-divide to release one more agent per delay period, capped at
        # the total number of agents
        nMoving = min(self.nAg, nStep // self.nStartDelay)
        return { iAg: RailEnvActions.MOVE_FORWARD for iAg in range(nMoving) }
AgentPause = NamedTuple("AgentPause",
[
("iAg", int),
......
......@@ -73,7 +73,9 @@
"import networkx as nx\n",
"import PIL\n",
"from IPython import display\n",
"import time"
"import time\n",
"from matplotlib import pyplot as plt\n",
"import numpy as np"
]
},
{
......@@ -147,6 +149,16 @@
"gvDot"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#for v, dPred in omc.G.pred.items():\n",
"# print (v, dPred)"
]
},
{
"cell_type": "code",
"execution_count": null,
......@@ -257,7 +269,87 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"env.motionCheck.svDeadlocked"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Deadlocking agents"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env, envModel = eeu.makeTestEnv(\"loop_with_loops\", nAg=10, bUCF=True)\n",
"oEC = ju.EnvCanvas(env, behaviour=ju.DelayedStartForward(env, nStartDelay=1))\n",
"oEC.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for i in range(25):\n",
" oEC.step()\n",
" oEC.render()\n",
" \n",
" #display.display_html(f\"<br>Step: {i}\\n\", raw=True)\n",
" #display.display_svg(ac.render(env.motionCheck, horizontal=(i>=3)))\n",
" time.sleep(0.1) "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env.motionCheck.svDeadlocked"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"g3Ep = np.array(env.cur_episode)\n",
"g3Ep.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"nSteps = g3Ep.shape[0]\n",
"plt.step(range(nSteps), np.sum(g3Ep[:,:,5], axis=1))\n",
"plt.title(\"Deadlocked agents\")\n",
"plt.xticks(range(g3Ep.shape[0]))\n",
"plt.yticks(range(11))\n",
"plt.grid()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gnDeadlockExpected = np.array([ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 5, 7, 9, 10, 10, 10, 10])\n",
"gnDeadlock = np.sum(g3Ep[:,:,5], axis=1)\n",
"\n",
"assert np.all(gnDeadlock == gnDeadlockExpected), \"Deadlocks by step do not match expected values!\""
]
}
],
"metadata": {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment