From 3b1ffed4e6af0e70c34fe0bc124548dd16e09bf7 Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Wed, 17 Apr 2019 20:22:48 +0100
Subject: [PATCH] merged a bit of refactoring in rendertools.py with
 TransitionMap changes

---
 flatland/utils/rendertools.py | 172 ++++++++++++++++++----------------
 1 file changed, 93 insertions(+), 79 deletions(-)

diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 771f228f..d0a78917 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -24,6 +24,8 @@ class RenderTool(object):
     gCentres = xr.DataArray(gGrid,
                             dims=["xy", "p1", "p2"],
                             coords={"xy": ["x", "y"]}) + xyPixHalf
+    gTheta = np.linspace(0, np.pi/2, 10)
+    gArc = array([np.cos(gTheta), np.sin(gTheta)]).T  # from [1,0] to [0,1]
 
     def __init__(self, env):
         self.env = env
@@ -54,14 +56,13 @@ class RenderTool(object):
                     # TODO: this was `rcDir' but it was undefined
                     rcNext = rcPos + iDir
                     # transition for next cell
-                    tbTrans = self.env.rail. \
-                        get_transitions((rcNext[0], rcNext[1], iDir))
+                    tbTrans = self.env.rail.get_transitions((*rcNext, iDir))
                     giTrans = np.where(tbTrans)[0]  # RC list of transitions
-                    gTransRCAg = self.__class__.gTransRC[giTrans]
+                    gTransRCAg = rt.gTransRC[giTrans]
 
         for visit in lVisits:
             # transition for next cell
-            tbTrans = self.env.rail.get_transitions((visit.rc[0], visit.rc[1], visit.iDir))
+            tbTrans = self.env.rail.get_transitions((*visit.rc, visit.iDir))
             giTrans = np.where(tbTrans)[0]  # RC list of transitions
             gTransRCAg = rt.gTransRC[giTrans]
             self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color)
@@ -99,8 +100,7 @@ class RenderTool(object):
         )
         """
 
-        # TODO: suggest we provide an accessor in RailEnv
-        tbTrans = self.env.rail.get_transitions((rcPos[0], rcPos[1], iDir))
+        tbTrans = self.env.rail.get_transitions((*rcPos, iDir))
         giTrans = np.where(tbTrans)[0]  # RC list of transitions
 
         # HACK: workaround dead-end transitions
@@ -310,7 +310,86 @@ class RenderTool(object):
                 visit = visit.prev
                 xyPrev = xy
 
-    def renderEnv(self, show=False, curves=True, spacing=False, arrows=False, agents=True):
+    def drawTrans2(
+            self,
+            xyLine, xyCentre,
+            rotation, bDeadEnd=False,
+            sColor="gray",
+            bArrow=True,
+            spacing=0.1):
+        """
+        gLine is a numpy 2d array of points,
+        in the plotting space / coords.
+        eg:
+        [[0,.5],[1,0.2]] means a line
+        from x=0, y=0.5
+        to   x=1, y=0.2
+        """
+        rt = self.__class__
+        bStraight = rotation in [0, 2]
+        dx, dy = np.squeeze(np.diff(xyLine, axis=0)) * spacing / 2
+
+        if bStraight:
+
+            if sColor == "auto":
+                if dx > 0 or dy > 0:
+                    sColor = "C1"   # N or E
+                else:
+                    sColor = "C2"   # S or W
+
+            if bDeadEnd:
+                xyLine2 = array([
+                    xyLine[1] + [dy, dx],
+                    xyCentre,
+                    xyLine[1] - [dy, dx],
+                ])
+                plt.plot(*xyLine2.T, color=sColor)
+            else:
+                xyLine2 = xyLine + [-dy, dx]
+                plt.plot(*xyLine2.T, color=sColor)
+
+                if bArrow:
+                    xyMid = np.sum(xyLine2 * [[1/4], [3/4]], axis=0)
+
+                    xyArrow = array([
+                        xyMid + [-dx-dy, +dx-dy],
+                        xyMid,
+                        xyMid + [-dx+dy, -dx-dy]
+                        ])
+                    plt.plot(*xyArrow.T, color=sColor)
+
+        else:
+
+            xyMid = np.mean(xyLine, axis=0)
+            dxy = xyMid - xyCentre
+            xyCorner = xyMid + dxy
+            if rotation == 1:
+                rArcFactor = 1 - spacing
+                sColorAuto = "C1"
+            else:
+                rArcFactor = 1 + spacing
+                sColorAuto = "C2"
+            dxy2 = (xyCentre - xyCorner) * rArcFactor  # for scaling the arc
+
+            if sColor == "auto":
+                sColor = sColorAuto
+
+            plt.plot(*(rt.gArc * dxy2 + xyCorner).T, color=sColor)
+
+            if bArrow:
+                dx, dy = np.squeeze(np.diff(xyLine, axis=0)) / 20
+                iArc = int(len(rt.gArc) / 2)
+                xyMid = xyCorner + rt.gArc[iArc] * dxy2
+                xyArrow = array([
+                    xyMid + [-dx-dy, +dx-dy],
+                    xyMid,
+                    xyMid + [-dx+dy, -dx-dy]
+                    ])
+                plt.plot(*xyArrow.T, color=sColor)
+
+    def renderEnv(
+            self, show=False, curves=True, spacing=False,
+            arrows=False, agents=True, sRailColor="gray"):
         """
         Draw the environment using matplotlib.
         Draw into the figure if provided.
@@ -326,9 +405,6 @@ class RenderTool(object):
         # if oFigure is None:
         #    oFigure = plt.figure()
 
-        gTheta = np.linspace(0, np.pi/2, 10)
-        gArc = array([np.cos(gTheta), np.sin(gTheta)]).T  # from [1,0] to [0,1]
-
         def drawTrans(oFrom, oTo, sColor="gray"):
             plt.plot(
                 [oFrom[0], oTo[0]],  # x
@@ -336,70 +412,6 @@ class RenderTool(object):
                 color=sColor
             )
 
-        def drawTrans2(
-                xyLine, xyCentre,
-                rotation, bDeadEnd=False,
-                sColor="gray",
-                bArrow=True,
-                spacing=0.1):
-            """
-            gLine is a numpy 2d array of points,
-            in the plotting space / coords.
-            eg:
-            [[0,.5],[1,0.2]] means a line
-            from x=0, y=0.5
-            to   x=1, y=0.2
-            """
-
-            bStraight = rotation in [0, 2]
-            if bStraight:
-
-                dx, dy = np.squeeze(np.diff(xyLine, axis=0)) * spacing / 2
-                if bDeadEnd:
-                    xyLine2 = array([
-                        xyLine[1] + [dy, dx],
-                        xyCentre,
-                        xyLine[1] - [dy, dx],
-                    ])
-                    plt.plot(*xyLine2.T, color=sColor)
-                else:
-                    xyLine2 = xyLine + [dy, dx]
-                    plt.plot(*xyLine2.T, color=sColor)
-
-                    if bArrow:
-                        xyMid = np.sum(xyLine2 * [[1/4], [3/4]], axis=0)
-
-                        xyArrow = array([
-                            xyMid + [-dx-dy, +dx-dy],
-                            xyMid,
-                            xyMid + [-dx+dy, -dx-dy]
-                            ])
-                        plt.plot(*xyArrow.T, color=sColor)
-
-            else:
-
-                xyMid = np.mean(xyLine, axis=0)
-                dxy = xyMid - xyCentre
-                xyCorner = xyMid + dxy
-                if rotation == 1:
-                    rArcFactor = 1 - spacing
-                else:
-                    rArcFactor = 1 + spacing
-                dxy2 = (xyCentre - xyCorner) * rArcFactor  # for scaling the arc
-
-                plt.plot(*(gArc * dxy2 + xyCorner).T, color=sColor)
-
-                if bArrow:
-                    dx, dy = np.squeeze(np.diff(xyLine, axis=0)) / 20
-                    iArc = int(len(gArc) / 2)
-                    xyMid = xyCorner + gArc[iArc] * dxy2
-                    xyArrow = array([
-                        xyMid + [-dx-dy, +dx-dy],
-                        xyMid,
-                        xyMid + [-dx+dy, -dx-dy]
-                        ])
-                    plt.plot(*xyArrow.T, color=sColor)
-
         env = self.env
 
         # Draw cells grid
@@ -466,18 +478,20 @@ class RenderTool(object):
                         if (tMoves[to_ori]):  # if we have this transition
 
                             if bDeadEnd:
-                                drawTrans2(
+                                self.drawTrans2(
                                     array([from_xy, to_xy]), xyCentre,
-                                    rotation, bDeadEnd=True, spacing=spacing)
+                                    rotation, bDeadEnd=True, spacing=spacing,
+                                    sColor=sRailColor)
 
                             else:
 
                                 if curves:
-                                    drawTrans2(
+                                    self.drawTrans2(
                                         array([from_xy, to_xy]), xyCentre,
-                                        rotation, spacing=spacing, bArrow=arrows)
+                                        rotation, spacing=spacing, bArrow=arrows,
+                                        sColor=sRailColor)
                                 else:
-                                    drawTrans(from_xy, to_xy)
+                                    drawTrans(from_xy, to_xy, sRailColor)
 
                             if False:
                                 print(
-- 
GitLab