# rendertools.py (28.28 KiB)
# Captured from a fork of Flatland / Flatland; the fork was 2311 commits
# behind the upstream repository at the time of capture.
import time
from collections import deque

# import xarray as xr
import matplotlib.pyplot as plt
from flatland.utils.render_qt import QTGL, QTSVG
from flatland.utils.graphics_pil import PILGL, PILSVG
from flatland.utils.graphics_layer import GraphicsLayer
from recordtype import recordtype
from numpy import array
import numpy as np

# TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv!


class MPLGL(GraphicsLayer):
    """
    Graphics layer that draws through the matplotlib.pyplot state machine.

    Every drawing method forwards to the pyplot function of the same name,
    so all calls act on pyplot's *current* figure.
    """

    def __init__(self, width, height):
        """
        :param width: number of grid columns; used to derive ``nPixCell``
        :param height: number of grid rows
        """
        self.width = width
        self.height = height
        self.yxBase = array([6, 21])  # pixel offset
        self.nPixCell = 700 / width
        self.img = None  # last frame captured by getImage()

    def open_window(self):
        """Create a new 10x10 inch pyplot figure."""
        plt.figure(figsize=(10, 10))

    def plot(self, *args, **kwargs):
        """Forward to ``pyplot.plot``."""
        plt.plot(*args, **kwargs)

    def scatter(self, *args, **kwargs):
        """Forward to ``pyplot.scatter``."""
        plt.scatter(*args, **kwargs)

    def text(self, *args, **kwargs):
        """Forward to ``pyplot.text``."""
        plt.text(*args, **kwargs)

    def prettify(self, *args, **kwargs):
        """Add integer ticks, a grid and axis labels for the tree plot."""
        ax = plt.gca()
        plt.xticks(range(int(ax.get_xlim()[1]) + 1))
        plt.yticks(range(int(ax.get_ylim()[1]) + 1))
        plt.grid()
        plt.xlabel("Euclidean distance")
        plt.ylabel("Tree / Transition Depth")

    def prettify2(self, width, height, cell_size):
        """
        Fit the axes to the grid and label every column and row at its centre.

        :param width: number of grid columns
        :param height: number of grid rows
        :param cell_size: edge length of one cell in plot units
        """
        plt.xlim([0, width * cell_size])
        plt.ylim([-height * cell_size, 0])

        # One tick per COLUMN (the x axis spans `width` cells, not `height`).
        gTicks = (np.arange(0, width) + 0.5) * cell_size
        gLabels = np.arange(0, width)
        plt.xticks(gTicks, gLabels)

        # One tick per row, scaled by cell_size so that the number of ticks
        # always equals the number of labels; row 0 is at the top (y = 0).
        gTicks = (np.arange(-height, 0) + 0.5) * cell_size
        gLabels = np.arange(height - 1, -1, -1)
        plt.yticks(gTicks, gLabels)

        # Re-apply the limits after the tick calls.
        plt.xlim([0, width * cell_size])
        plt.ylim([-height * cell_size, 0])

    def show(self, block=False):
        """Forward to ``pyplot.show``."""
        plt.show(block=block)

    def pause(self, seconds=0.00001):
        """Forward to ``pyplot.pause``."""
        plt.pause(seconds)

    def clf(self):
        """Clear and close the current figure."""
        plt.clf()
        plt.close()

    def get_cmap(self, *args, **kwargs):
        """Forward to ``pyplot.get_cmap`` and return the colormap."""
        return plt.get_cmap(*args, **kwargs)

    def beginFrame(self):
        """Drop the cached image and start drawing on a fresh, cleared figure."""
        self.img = None
        plt.figure(figsize=(10, 10))
        plt.clf()

    def endFrame(self):
        """Capture the current figure into ``self.img`` and then close it."""
        self.img = self.getImage(force=True)
        plt.clf()
        plt.close()

    def getImage(self, force=False):
        """
        Return the current figure as a ``(height, width, 3)`` uint8 RGB array.

        The rendered image is cached in ``self.img``; it is re-rendered only
        when no image is cached or when ``force`` is true.
        """
        if self.img is None or force:
            ax = plt.gca()
            fig = ax.get_figure()
            fig.tight_layout(pad=0)
            canvas = fig.canvas
            canvas.draw()
            if hasattr(canvas, "tostring_rgb"):
                data = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
                data = data.reshape(canvas.get_width_height()[::-1] + (3,))
            else:
                # Canvases without tostring_rgb (it was deprecated and later
                # removed from Matplotlib): take the RGBA buffer and drop the
                # alpha channel, copying so the array outlives the figure.
                data = np.ascontiguousarray(
                    np.asarray(canvas.buffer_rgba())[:, :, :3])
            self.img = data
        return self.img

    def adaptColor(self, color, lighten=False):
        """
        Adapt ``color`` via the base class, then rescale it for matplotlib.

        Each channel returned by the base class is divided by 255 because
        MPL expects RGBA in [0,1]^4, not integers in [0,255]^4.
        """
        # Zero-argument super(): super(self.__class__, self) would recurse
        # without end if this method were inherited by a subclass.
        color = super().adaptColor(color, lighten)
        color = tuple([iRGBA / 255 for iRGBA in color])
        return color


class RenderTool(object):
    """
    Draws an environment (rails, agents, observations, search trees)
    through a graphics layer object stored in ``self.gl``.

    The class attributes below are geometry constants shared by the
    drawing methods.
    """
    # Mutable record for one node of a transition tree:
    # rc = (row, col) cell, iDir = direction index, iDepth = depth in the
    # tree, prev = the Visit this one was reached from (None for the root).
    Visit = recordtype("Visit", ["rc", "iDir", "iDepth", "prev"])

    # Single-character colour codes.
    lColors = list("brgcmyk")
    # \delta RC for NESW
    gTransRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
    nPixCell = 1  # misnomer...
    nPixHalf = nPixCell / 2
    # Offset added after the rc->xy mapping: +half a cell in x, -half in y.
    xyHalf = array([nPixHalf, -nPixHalf])
    # Matrix mapping [row, col] to [x, y]: x = col, y = -row (times nPixCell).
    grc2xy = array([[0, -nPixCell], [nPixCell, 0]])
    # x and y coordinates of a 10x10 mesh, scaled by nPixCell.
    gGrid = array(np.meshgrid(np.arange(10), -np.arange(10))) * array([[[nPixCell]], [[nPixCell]]])
    # xyPixHalf = xr.DataArray([nPixHalf, -nPixHalf],
    #                         dims="xy",
    #                         coords={"xy": ["x", "y"]})
    # gCentres = xr.DataArray(gGrid,
    #                        dims=["xy", "p1", "p2"],
    #                        coords={"xy": ["x", "y"]}) + xyPixHalf
    # Five angles from 0 to pi/2 and the matching unit quarter-circle points.
    gTheta = np.linspace(0, np.pi / 2, 5)
    gArc = array([np.cos(gTheta), np.sin(gTheta)]).T  # from [1,0] to [0,1]

    def __init__(self, env, gl="MPL", show=False):
        """
        Create a render tool for ``env`` using the named graphics layer.

        :param env: environment to draw; ``env.width`` and ``env.height``
            are passed to the graphics layer constructor
        :param gl: graphics layer name, one of "MPL", "QT", "PIL",
            "PILSVG" or "QTSVG"
        :param show: accepted for backward compatibility; not used here
        :raises ValueError: if ``gl`` is not one of the known names
        """
        self.env = env
        self.iFrame = 0
        self.time1 = time.time()
        self.lTimes = deque()

        dLayers = {
            "MPL": MPLGL,
            "QT": QTGL,
            "PIL": PILGL,
            "PILSVG": PILSVG,
            "QTSVG": QTSVG,
        }
        if gl not in dLayers:
            # Fail here rather than leave self.gl undefined, which would only
            # surface later as an AttributeError on the first draw call.
            raise ValueError(
                "Unknown graphics layer {!r}; expected one of {}".format(
                    gl, sorted(dLayers)))
        self.gl = dLayers[gl](env.width, env.height)

        self.new_rail = True

    def resize(self):
        """Hand the environment to the graphics layer's ``resize``."""
        oLayer = self.gl
        oLayer.resize(self.env)

    def set_new_rail(self):
        """Raise the ``new_rail`` flag (renderEnv2 redraws the rails when it is set)."""
        setattr(self, "new_rail", True)

    def plotTreeOnRail(self, lVisits, color="r"):
        """
        Plot, for every visit in lVisits, the transitions the rail offers
        at that visit's cell and direction, labelled with the visit depth.

        :param lVisits: iterable of Visit records (rc, iDir, iDepth, prev)
        :param color: colour passed on to plotTrans
        """
        cls = self.__class__

        for oVisit in lVisits:
            # flags of the transitions available from this cell / direction
            tbAvail = self.env.rail.get_transitions((*oVisit.rc, oVisit.iDir))
            giAvail = np.where(tbAvail)[0]
            self.plotTrans(oVisit.rc, cls.gTransRC[giAvail],
                           depth=str(oVisit.iDepth), color=color)

    def plotAgents(self, targets=True, iSelectedAgent=None):
        """
        Plot the static agents (lightened) and then the live agents, each
        in a colour taken from an 'hsv' colormap indexed by agent number.

        :param targets: if True, also draw each agent's target
        :param iSelectedAgent: index of the static agent to mark as selected
        """
        nColors = max(len(self.env.agents), len(self.env.agents_static) + 1)
        cmap = self.gl.get_cmap('hsv', lut=nColors)

        for iAgent, oAgent in enumerate(self.env.agents_static):
            if oAgent is not None:
                self.plotAgent(oAgent.position, oAgent.direction, cmap(iAgent),
                               target=oAgent.target if targets else None,
                               static=True, selected=(iAgent == iSelectedAgent))

        for iAgent, oAgent in enumerate(self.env.agents):
            if oAgent is not None:
                self.plotAgent(oAgent.position, oAgent.direction, cmap(iAgent),
                               target=oAgent.target if targets else None)

    def getTransRC(self, rcPos, iDir, bgiTrans=False):
        """
        Return the transitions available at cell rcPos when facing iDir,
        expressed as (row, col) deltas.

        Example: at rcPos = (4,5) facing iDir = 0 (N) with N and E
        available, the result is [[-1,0], [0,1]] (N = up one row,
        E = right one col).

        With bgiTrans=True a tuple is returned instead:
        (deltas, indices), eg ([[-1,0], [0,1]], [0, 1]).

        If the rail reports no transition at all, the reverse of iDir is
        returned as the single available transition.
        """
        tbAvail = self.env.rail.get_transitions((*rcPos, iDir))
        giAvail = np.where(tbAvail)[0]

        # HACK: workaround dead-end transitions - turn around
        if giAvail.size == 0:
            iBack = (iDir + 2) % 4
            tbAvail = tuple(int(iCand == iBack) for iCand in range(4))
            giAvail = np.where(tbAvail)[0]

        gDeltas = self.__class__.gTransRC[giAvail]
        return (gDeltas, giAvail) if bgiTrans else gDeltas

    def plotAgent(self, rcPos, iDir, color="r", target=None, static=False, selected=False):
        """
        Draw one agent: a dot, a short line showing its direction, and
        optionally a selection square and a small square on its target.
        Needs a graphics layer that is ready to draw (cf a MPL figure).

        :param rcPos: (row, col) of the agent
        :param iDir: direction index into gTransRC (NESW)
        :param color: colour of the agent
        :param target: (row, col) of the target, or None to skip it
        :param static: if True, the colour is lightened by the graphics layer
        :param selected: if True, a unit square is drawn around the agent
        """
        cls = self.__class__

        rcHeading = cls.gTransRC[iDir]  # direction as a row/col delta
        xyHeading = np.matmul(rcHeading, cls.grc2xy)  # same delta in xy

        # the dot sits half a cell behind the cell centre along the heading
        xyAgent = np.matmul(rcPos - rcHeading / 2, cls.grc2xy) + cls.xyHalf

        if static:
            color = self.gl.adaptColor(color, lighten=True)

        self.gl.scatter(*xyAgent, color=color, layer=1, marker="o", s=100)
        xyStem = array([xyAgent, xyAgent + xyHeading / 2]).T
        self.gl.plot(*xyStem, color=color, layer=1, lw=5, ms=0, alpha=0.6)
        if selected:
            self._draw_square(xyAgent, 1, color)

        if target is None:
            return
        xyTarget = np.matmul(array(target), cls.grc2xy) + cls.xyHalf
        self._draw_square(xyTarget, 1 / 3, color, layer=1)

    def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None):
        """
        Mark the transitions gTransRCAg around the centre of cell rcPos.

        gTransRCAg is a 2d numpy array of RC deltas,
        eg [[-1,0], [0,1]] means N, E.
        If depth is given, it is written as text at every mark.
        """
        cls = self.__class__
        xyCell = np.matmul(rcPos, cls.grc2xy) + cls.xyHalf
        # each mark is offset from the centre by the delta scaled by 1/2.4
        gxyMarks = xyCell + np.matmul(gTransRCAg, cls.grc2xy / 2.4)
        self.gl.scatter(*gxyMarks.T, color=color, marker="o", s=50, alpha=0.2)
        if depth is None:
            return
        for xMark, yMark in gxyMarks:
            self.gl.text(xMark, yMark, depth)

    def getTreeFromRail(self, rcPos, iDir, nDepth=10, bBFS=True, bPlot=False):
        """
        Generate a tree of transitions from the env starting at rcPos, iDir.

        :param rcPos: (row, col) start cell; any sequence, stored as a tuple
        :param iDir: start direction index
        :param nDepth: visits deeper than this are discarded
        :param bBFS: True for breadth-first order, False for depth-first
        :param bPlot: if True, plot the transitions of every expanded node
        :return: list of Visit records in the order they were taken from the
            queue. A (cell, direction) pair reached more than once appears
            once per arrival, but is only expanded the first time.
        """
        rt = self.__class__

        # A deque pops from either end in O(1); list.pop(0) is O(n).
        # tuple(rcPos) keeps the root hashable for the `visited` set.
        frontier = deque([rt.Visit(tuple(rcPos), iDir, 0, None)])
        take = frontier.popleft if bBFS else frontier.pop  # BF / DF search

        visited = set()
        lVisits = []
        while frontier:
            visit = take()
            if visit.iDepth > nDepth:
                continue
            lVisits.append(visit)

            rcd = (visit.rc, visit.iDir)
            if rcd in visited:
                continue
            visited.add(rcd)

            gTransRCAg, giTrans = self.getTransRC(visit.rc,
                                                  visit.iDir,
                                                  bgiTrans=True)

            # enqueue the next nodes (ie transitions from this node)
            for gTransRC2, iTrans in zip(gTransRCAg, giTrans):
                frontier.append(rt.Visit(tuple(visit.rc + gTransRC2),
                                         iTrans,
                                         visit.iDepth + 1,
                                         visit))

            # plot the available transitions from this node
            if bPlot:
                self.plotTrans(
                    visit.rc, gTransRCAg,
                    depth=str(visit.iDepth))

        return lVisits

    def plotTree(self, lVisits, xyTarg):
        '''
        Plot a vertical tree of transitions.

        Each visit is drawn at x = (euclidean distance from its cell to
        xyTarg) + iDir / 4 and y = its depth, labelled with its cell, and
        joined by a line to the visit it was reached from.  The chain of
        visits leading to the destination is then redrawn in red.

        :param lVisits: list of Visit records, eg from getTreeFromRail
        :param xyTarg: target cell, in the same coordinates as Visit.rc
        :return: the "visit" to the destination
            (ie where euclidean distance is near zero) or None if absent.
        '''
        gTarg = array(xyTarg)

        def locate(oVisit):
            # (distance to target, x position in the plot) for one visit
            rDistance = np.linalg.norm(array(oVisit.rc) - gTarg)
            return rDistance, rDistance + oVisit.iDir / 4

        visitDest = None

        for visit in lVisits:
            rDist, xLoc = locate(visit)

            # point labelled with its cell
            self.gl.scatter(xLoc, visit.iDepth, color="k", s=2)
            self.gl.text(xLoc, visit.iDepth, visit.rc, color="k", rotation=45)

            if visit.prev is not None:
                # line from prev node
                xLocPrev = locate(visit.prev)[1]
                self.gl.plot([xLocPrev, xLoc],
                             [visit.iDepth - 1, visit.iDepth],
                             color="k", alpha=0.5, lw=1)

            if rDist < 0.1:
                visitDest = visit

        # Walk backwards from destination to origin, plotting in red
        visit = visitDest
        xLocPrev = None
        while visit is not None:
            xLoc = locate(visit)[1]
            if xLocPrev is not None:
                self.gl.plot([xLoc, xLocPrev], [visit.iDepth, visit.iDepth + 1],
                             color="r", alpha=0.5, lw=2)
            xLocPrev = xLoc
            visit = visit.prev

        self.gl.prettify()
        return visitDest

    def plotPath(self, visitDest):
        """
        Plot the path that ends at visitDest by following visit.prev back
        to the start.

        Every step is drawn as a red line shifted slightly to one side of
        the cell centres, so that steps in opposite directions do not
        overlap, with an arrow head at 3/4 of the way along the line.
        Lines are straight only; arrows are not made to meet at corners.
        Nothing is drawn when visitDest is None.
        """
        cls = self.__class__

        oVisit = visitDest
        xyLater = None  # centre of the cell handled in the previous iteration
        while oVisit is not None:
            xyHere = np.matmul(oVisit.rc, cls.grc2xy) + cls.xyHalf
            if xyLater is not None:
                dx, dy = (xyLater - xyHere) / 20
                xySegment = array([xyHere, xyLater]) + array([dy, dx])
                self.gl.plot(*xySegment.T, color="r", alpha=0.5, lw=1)

                # arrow head, with its tip 3/4 along the segment
                xyTip = np.sum(xySegment * [[1 / 4], [3 / 4]], axis=0)
                xyHead = array([
                    xyTip + [-dx - dy, +dx - dy],
                    xyTip,
                    xyTip + [-dx + dy, -dx - dy]])
                self.gl.plot(*xyHead.T, color="r")

            oVisit, xyLater = oVisit.prev, xyHere

    def drawTrans(self, oFrom, oTo, sColor="gray"):
        """Draw a straight line from point oFrom to point oTo (each indexable as [x, y])."""
        gxEnds = [oFrom[0], oTo[0]]
        gyEnds = [oFrom[1], oTo[1]]
        self.gl.plot(gxEnds, gyEnds, color=sColor)

    def drawTrans2(self,
                   xyLine, xyCentre,
                   rotation, bDeadEnd=False,
                   sColor="gray",
                   bArrow=True,
                   spacing=0.1):
        """
        Draw one transition through a cell, as a line or as an arc.

        xyLine is a numpy 2d array of two points (from, to),
        in the plotting space / coords.
        eg:
        [[0,.5],[1,0.2]] means a line
        from x=0, y=0.5
        to   x=1, y=0.2

        :param xyLine: 2x2 array, rows are the from / to points
        :param xyCentre: centre of the cell, used for dead ends and arcs
        :param rotation: 0 or 2 is drawn as a line, anything else as an arc;
            rotation 1 uses the smaller arc radius
        :param bDeadEnd: if True (and the transition is a line), draw two
            lines meeting at the cell centre instead of the line itself
        :param sColor: colour, or "auto" to pick "C1" / "C2" by direction
        :param bArrow: if True, add an arrow head (not for dead ends)
        :param spacing: sideways offset of lines / change of arc radius
        """
        rt = self.__class__
        bStraight = rotation in [0, 2]
        # half of `spacing` times the from->to vector
        dx, dy = np.squeeze(np.diff(xyLine, axis=0)) * spacing / 2

        if bStraight:

            if sColor == "auto":
                if dx > 0 or dy > 0:
                    sColor = "C1"  # N or E
                else:
                    sColor = "C2"  # S or W

            if bDeadEnd:
                # two points either side of the "to" end, joined via the centre
                xyLine2 = array([
                    xyLine[1] + [dy, dx],
                    xyCentre,
                    xyLine[1] - [dy, dx],
                ])
                self.gl.plot(*xyLine2.T, color=sColor)
            else:
                # shift the whole line sideways by [-dy, dx]
                xyLine2 = xyLine + [-dy, dx]
                self.gl.plot(*xyLine2.T, color=sColor)

                if bArrow:
                    # arrow tip at 3/4 of the way along the shifted line
                    xyMid = np.sum(xyLine2 * [[1 / 4], [3 / 4]], axis=0)

                    xyArrow = array([
                        xyMid + [-dx - dy, +dx - dy],
                        xyMid,
                        xyMid + [-dx + dy, -dx - dy]])
                    self.gl.plot(*xyArrow.T, color=sColor)

        else:

            # xyCorner is the cell centre mirrored through the midpoint of
            # the from/to points; the arc is drawn around it.
            xyMid = np.mean(xyLine, axis=0)
            dxy = xyMid - xyCentre
            xyCorner = xyMid + dxy
            if rotation == 1:
                rArcFactor = 1 - spacing
                sColorAuto = "C1"
            else:
                rArcFactor = 1 + spacing
                sColorAuto = "C2"
            dxy2 = (xyCentre - xyCorner) * rArcFactor  # for scaling the arc

            if sColor == "auto":
                sColor = sColorAuto

            # scale the unit quarter-circle gArc and move it to the corner
            self.gl.plot(*(rt.gArc * dxy2 + xyCorner).T, color=sColor)

            if bArrow:
                dx, dy = np.squeeze(np.diff(xyLine, axis=0)) / 20
                # arrow tip on the middle point of the arc
                iArc = int(len(rt.gArc) / 2)
                xyMid = xyCorner + rt.gArc[iArc] * dxy2
                xyArrow = array([
                    xyMid + [-dx - dy, +dx - dy],
                    xyMid,
                    xyMid + [-dx + dy, -dx - dy]])
                self.gl.plot(*xyArrow.T, color=sColor)

    def renderObs(self, agent_handles, observation_dict):
        """
        Highlight every cell that appears in the observation of each agent
        by drawing a square on it in the agent's colour.  The side of the
        square is 1 / (agent index + 1.1).

        :param agent_handles: List of agent indices, used for the colour and to
            look up the observation
        :param observation_dict: dictionary mapping agent index to the cells
            of that agent's observation; the first two entries of each cell
            are used as its coordinates
        """
        cls = self.__class__

        for iHandle in agent_handles:
            oColor = self.gl.getAgentColor(iHandle)
            for tCell in observation_dict[iHandle]:
                xyCell = np.matmul(array(tCell[:2]), cls.grc2xy) + cls.xyHalf
                self._draw_square(xyCell, 1 / (iHandle + 1.1), oColor,
                                  layer=1, opacity=100)

    def renderRail(self, spacing=False, sRailColor="gray", curves=True, arrows=False):
        """
        Draw the cell grid and every rail transition of the environment.

        :param spacing: passed to drawTrans2 as its ``spacing``
            (the default False acts as 0 there)
        :param sRailColor: colour of the rails
        :param curves: if True, transitions are drawn by drawTrans2;
            otherwise by drawTrans as plain straight lines
        :param arrows: passed to drawTrans2 as ``bArrow`` (only used when
            curves is True)
        """
        cell_size = 1  # TODO: remove cell_size
        env = self.env

        # Draw cells grid
        grid_color = [0.95, 0.95, 0.95]
        for r in range(env.height + 1):
            self.gl.plot([0, (env.width + 1) * cell_size],
                         [-r * cell_size, -r * cell_size],
                         color=grid_color, linewidth=2)
        for c in range(env.width + 1):
            self.gl.plot([c * cell_size, c * cell_size],
                         [0, -(env.height + 1) * cell_size],
                         color=grid_color, linewidth=2)

        # Draw each cell independently
        for r in range(env.height):
            for c in range(env.width):

                # bounding box of the grid cell
                x0 = cell_size * c  # left
                x1 = cell_size * (c + 1)  # right
                y0 = cell_size * -r  # top
                y1 = cell_size * -(r + 1)  # bottom

                # centres of cell edges
                coords = [
                    ((x0 + x1) / 2.0, y0),  # N middle top
                    (x1, (y0 + y1) / 2.0),  # E middle right
                    ((x0 + x1) / 2.0, y1),  # S middle bottom
                    (x0, (y0 + y1) / 2.0)  # W middle left
                ]

                # cell centre
                xyCentre = array([x0, y1]) + cell_size / 2

                # cell transition values
                oCell = env.rail.get_transitions((r, c))

                bCellValid = env.rail.cell_neighbours_valid((r, c), check_this_cell=True)

                # A cell with exactly one transition bit set is treated as
                # a dead end.
                bDeadEnd = bin(int(oCell)).count("1") == 1

                if not bCellValid:
                    # flag an invalid cell with a red dot
                    self.gl.scatter(*xyCentre, color="r", s=30)

                for orientation in range(4):  # ori is where we're heading
                    from_ori = (orientation + 2) % 4  # 0123=NESW -> 2301=SWNE
                    from_xy = coords[from_ori]

                    tMoves = env.rail.get_transitions((r, c, orientation))

                    for to_ori in range(4):
                        if not tMoves[to_ori]:  # we don't have this transition
                            continue

                        to_xy = coords[to_ori]
                        rotation = (to_ori - from_ori) % 4

                        if bDeadEnd:
                            self.drawTrans2(
                                array([from_xy, to_xy]), xyCentre,
                                rotation, bDeadEnd=True, spacing=spacing,
                                sColor=sRailColor)
                        elif curves:
                            self.drawTrans2(
                                array([from_xy, to_xy]), xyCentre,
                                rotation, spacing=spacing, bArrow=arrows,
                                sColor=sRailColor)
                        else:
                            # drawTrans is a bound method: passing `self`
                            # explicitly as well raised a TypeError here.
                            self.drawTrans(from_xy, to_xy, sRailColor)

    def renderEnv(self, show=False, curves=True, spacing=False,
                  arrows=False, agents=True, renderobs=True, show_observations=True, sRailColor="gray", frames=False,
                  iEpisode=None, iStep=None,
                  iSelectedAgent=None, action_dict=None):
        """
        Draw the environment: rails, agents, observations and status text.

        If the graphics layer reports that it is not a raster layer, the
        whole call is delegated to renderEnv2.

        :param show: if True, call the graphics layer's show()
            (Use show=False from a Jupyter notebook with %matplotlib inline)
        :param curves: passed to renderRail
        :param spacing: passed to renderRail
        :param arrows: passed to renderRail
        :param agents: if True, draw each agent, its orientation and target
        :param renderobs: accepted for backward compatibility; not used
        :param show_observations: if True, highlight the observed cells
        :param sRailColor: passed to renderRail
        :param frames: if True, print the frame counter
        :param iEpisode: episode number to print, or None
        :param iStep: step number to print, or None
        :param iSelectedAgent: index of the agent drawn as selected
        :param action_dict: only forwarded to renderEnv2
        """

        if not self.gl.is_raster():
            self.renderEnv2(show=show, curves=curves, spacing=spacing,
                            arrows=arrows, agents=agents, show_observations=show_observations,
                            sRailColor=sRailColor,
                            frames=frames, iEpisode=iEpisode, iStep=iStep,
                            iSelectedAgent=iSelectedAgent, action_dict=action_dict)
            return

        if type(self.gl) in (QTGL, PILGL, MPLGL):
            self.gl.beginFrame()

        env = self.env

        # Forward the rail options; previously they were accepted by this
        # method but dropped, so renderRail always ran with its defaults.
        self.renderRail(spacing=spacing, sRailColor=sRailColor,
                        curves=curves, arrows=arrows)

        # Draw each agent + its orientation + its target
        if agents:
            self.plotAgents(targets=True, iSelectedAgent=iSelectedAgent)
        if show_observations:
            self.renderObs(range(env.get_num_agents()), env.dev_obs_dict)
        # Draw some textual information like fps
        yText = [-0.3, -0.6, -0.9]
        if frames:
            self.gl.text(0.1, yText[2], "Frame:{:}".format(self.iFrame))
        self.iFrame += 1

        if iEpisode is not None:
            self.gl.text(0.1, yText[1], "Ep:{}".format(iEpisode))

        if iStep is not None:
            self.gl.text(0.1, yText[0], "Step:{}".format(iStep))

        tNow = time.time()
        self.gl.text(2, yText[2], "elapsed:{:.2f}s".format(tNow - self.time1))
        # fps is averaged over (at most) the last 20 frame timestamps
        self.lTimes.append(tNow)
        if len(self.lTimes) > 20:
            self.lTimes.popleft()
        if len(self.lTimes) > 1:
            rFps = (len(self.lTimes) - 1) / (self.lTimes[-1] - self.lTimes[0])
            self.gl.text(2, yText[1], "fps:{:.2f}".format(rFps))

        self.gl.prettify2(env.width, env.height, self.nPixCell)

        # TODO: for MPL, we don't want to call clf (called by endframe)
        # for QT, we need to call endFrame()
        if type(self.gl) is QTGL:
            self.gl.endFrame()
            if show:
                self.gl.show(block=False)

        if type(self.gl) is MPLGL:
            if show:
                self.gl.show(block=False)

        if show and type(self.gl) is PILGL:
            self.gl.show()

        self.gl.pause(0.00001)

        return

    def _draw_square(self, center, size, color, opacity=255, layer=0):
        """Plot the closed outline of an axis-aligned square of side ``size`` centred on ``center``."""
        rHalf = size / 2
        xLo, xHi = center[0] - rHalf, center[0] + rHalf
        yLo, yHi = center[1] - rHalf, center[1] + rHalf
        # five points: the four corners, then back to the first one
        gxOutline = [xLo, xHi, xHi, xLo, xLo]
        gyOutline = [yLo, yLo, yHi, yHi, yLo]
        self.gl.plot(gxOutline, gyOutline, color=color, layer=layer, opacity=opacity)

    def getImage(self):
        """Return whatever the graphics layer's ``getImage()`` returns."""
        oImage = self.gl.getImage()
        return oImage

    def plotTreeObs(self, gObs):
        """
        Plot the skeleton of a tree with branching factor 4 and 2 levels.

        Nodes of level i are spread along x (shrunk by 1 / (i + 1)) at
        y = -i; every node is joined to its parent on the level above.

        :param gObs: accepted but not used by the current implementation
        """
        nBranchFactor = 4
        nDepth = 2

        gP0 = array([[0, 0, 0]]).T  # points of the previous level
        for i in range(nDepth):
            nDepthNodes = nBranchFactor ** i
            rShrinkDepth = 1 / (i + 1)

            gX1 = np.linspace(-(nDepthNodes - 1), (nDepthNodes - 1), nDepthNodes) * rShrinkDepth
            gY1 = np.ones((nDepthNodes)) * i
            gZ1 = np.zeros((nDepthNodes))
            gP1 = array([gX1, gY1, gZ1])

            if nDepthNodes > 1:
                # previous level's points followed by this level's points
                gP01 = np.append(gP0, gP1, axis=1)

                # Floor division keeps the point indices integral; true
                # division produced float indices.
                nDepthNodesPrev = nDepthNodes // nBranchFactor
                giP0 = np.repeat(np.arange(nDepthNodesPrev), nBranchFactor)
                giP1 = np.arange(0, nDepthNodes) + nDepthNodesPrev
                # interleave as parent, child, parent, child, ...
                giLinePoints = np.stack([giP0, giP1]).ravel("F")
                self.gl.plot(gP01[0], -gP01[1], lines=giLinePoints, color="gray")

            gP0 = gP1

    def renderEnv2(
            self, show=False, curves=True, spacing=False, arrows=False, agents=True,
            show_observations=True, sRailColor="gray",
            frames=False, iEpisode=None, iStep=None, iSelectedAgent=None,
            action_dict=None):
        """
        Draw the environment through the graphics layer's cell-based
        interface (setRailAt / setAgentAt).

        The rails are only re-sent to the graphics layer when the
        ``new_rail`` flag is set; the flag is cleared afterwards.

        Only ``show`` and ``show_observations`` are used here; the other
        parameters mirror renderEnv so both can be called the same way.

        :param show: if True, call the graphics layer's show()
        :param show_observations: if True, highlight the observed cells
        :param action_dict: not used; defaults to None rather than a shared
            mutable dict
        """

        env = self.env

        self.gl.beginFrame()

        if self.new_rail:
            self.new_rail = False
            self.gl.clear_rails()

            # store the targets: target cell -> index of the static agent
            dTargets = {}
            for iAgent, agent in enumerate(self.env.agents_static):
                if agent is None:
                    continue
                dTargets[tuple(agent.target)] = iAgent

            # Draw each cell independently
            for r in range(env.height):
                for c in range(env.width):
                    binTrans = env.rail.grid[r, c]
                    # None when no agent has its target on this cell
                    target = dTargets.get((r, c))
                    self.gl.setRailAt(r, c, binTrans, iTarget=target)

        for iAgent, agent in enumerate(self.env.agents):
            if agent is None:
                continue

            direction = agent.direction
            if agent.old_position is not None:
                position = agent.old_position
                old_direction = agent.old_direction
            else:
                position = agent.position
                old_direction = agent.direction

            # setAgentAt uses the agent index for the color
            self.gl.setAgentAt(iAgent, *position, old_direction, direction)

        if show_observations:
            self.renderObs(range(env.get_num_agents()), env.dev_obs_dict)

        if show:
            self.gl.show()
        for _ in range(3):
            self.gl.processEvents()

        self.iFrame += 1
        return

    def close_window(self):
        """Call ``close_window()`` on the graphics layer."""
        oLayer = self.gl
        oLayer.close_window()