diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 535f44b4d422ef54a012c838c41d1ff432527e7f..dd49fb54f6cef27748f8936972bbbd7a8136d004 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -30,7 +30,7 @@ class RenderTool(object): def __init__(self, env): self.env = env - def plotTreeOnRail(self, rcPos, iDir, nDepth=10): + def plotTreeOnRail(self, lVisits, color="r"): """ Derives and plots a tree of transitions starting at position rcPos in direction iDir. @@ -50,7 +50,7 @@ class RenderTool(object): self.plotAgent(rcPos, iDir, sColor) gTransRCAg = self.getTransRC(rcPos, iDir) - self.plotTrans(rcPos, gTransRCAg) + self.plotTrans(rcPos, gTransRCAg, color=color) if False: # TODO: this was `rcDir' but it was undefined @@ -62,13 +62,15 @@ class RenderTool(object): giTrans = np.where(tbTrans)[0] # RC list of transitions gTransRCAg = self.__class__.gTransRC[giTrans] - # rcPos=(6,4) - # iDir=2 - gTransRCAg = self.getTransRC(rcPos, iDir) - self.plotTrans(rcPos, gTransRCAg) - - lVisits = self.getTreeFromRail(rcPos, iDir, nDepth=nDepth) - return lVisits + for visit in lVisits: + # transition for next cell + oTrans = self.env.rail[visit.rc] + tbTrans = rt.RETrans.get_transitions(oTrans, visit.iDir) + giTrans = np.where(tbTrans)[0] # RC list of transitions + gTransRCAg = rt.gTransRC[giTrans] + self.plotTrans( + visit.rc, gTransRCAg, + depth=str(visit.iDepth), color=color) def plotAgents(self): rt = self.__class__ @@ -148,13 +150,12 @@ class RenderTool(object): rt = self.__class__ xyPos = np.matmul(rcPos, rt.grc2xy) + rt.xyHalf gxyTrans = xyPos + np.matmul(gTransRCAg, rt.grc2xy/2.4) - # print(gxyTrans) plt.scatter(*gxyTrans.T, color=color, marker="o", s=50, alpha=0.2) if depth is not None: for x, y in gxyTrans: plt.text(x, y, depth) - def getTreeFromRail(self, rcPos, iDir, nDepth=10, bBFS=True): + def getTreeFromRail(self, rcPos, iDir, nDepth=10, bBFS=True, bPlot=False): """ Generate a tree from the env starting at rcPos, iDir. """ @@ -194,7 +195,10 @@ class RenderTool(object): stack.append(visitNext) # plot the available transitions from this node - self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth)) + if bPlot: + self.plotTrans( + visit.rc, gTransRCAg, + depth=str(visit.iDepth)) return lVisits diff --git a/images/env-tree-spatial.png b/images/env-tree-spatial.png index 255feb209f35a315179dba8121ba81abe8d93c5c..b04b617f423b726fc0708bf10a78ed09e44c1a02 100644 Binary files a/images/env-tree-spatial.png and b/images/env-tree-spatial.png differ diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py index b73b892a474180486d2a6d1d7143e6a5bee69e60..c68f84db56e0ab95ab2b384c0ffab683945424c2 100644 --- a/tests/test_rendertools.py +++ b/tests/test_rendertools.py @@ -61,10 +61,11 @@ def test_render_env(): plt.figure(figsize=(10,10)) oRT.renderEnv() - lVisits = oRT.plotTreeOnRail( + + lVisits = oRT.getTreeFromRail( oEnv.agents_position[0], oEnv.agents_direction[0], - nDepth=17) + nDepth=17, bPlot=True) checkFrozenImage("env-tree-spatial.png")