rendertools.py 26.9 KB
Newer Older
u214892's avatar
u214892 committed
1
2
import time
from collections import deque
3
from enum import IntEnum
4

hagrid67's avatar
hagrid67 committed
5
import numpy as np
u214892's avatar
u214892 committed
6
7
from numpy import array
from recordtype import recordtype
Erik Nygren's avatar
Erik Nygren committed
8

u214892's avatar
u214892 committed
9
from flatland.utils.graphics_pil import PILGL, PILSVG
10

u214892's avatar
u214892 committed
11
12

# TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv!
hagrid67's avatar
hagrid67 committed
13

14
15
16
17
18
19
20
class AgentRenderVariant(IntEnum):
    """How RenderTool.renderEnv2 draws each agent."""
    # only mark the agent's current cell (setCellOccupied)
    BOX_ONLY = 0
    # draw the agent at its old_position (when known)
    ONE_STEP_BEHIND = 1
    # draw the agent once for every valid exit direction from its cell
    AGENT_SHOWS_OPTIONS = 2
    # ONE_STEP_BEHIND, plus mark the current cell
    ONE_STEP_BEHIND_AND_BOX = 3
    # AGENT_SHOWS_OPTIONS, plus mark the current cell
    AGENT_SHOWS_OPTIONS_AND_BOX = 4


class RenderTool(object):
    """ Render a RailEnv and its agents.

        Two layers are used: layer 0 holds the (mostly static) rails and
        layer 1 the agents and other dynamic content.  The rail layer is only
        rebuilt after set_new_rail() has been called.
        Drawing goes through a "GraphicsLayer" (gl): currently PILGL or PILSVG.
    """

    # Node of a transition tree: cell (row, col), heading, depth, parent visit.
    Visit = recordtype("Visit", ["rc", "iDir", "iDepth", "prev"])

    # Row/column delta of a single step, indexed by direction N, E, S, W.
    gTransRC = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])

    lColors = list("brgcmyk")

    nPixCell = 1  # misnomer: used as the cell size in plot units
    nPixHalf = nPixCell / 2
    # offset from a cell's top-left corner (in xy) to its centre
    xyHalf = array([nPixHalf, -nPixHalf])
    # (row, col) @ grc2xy == (x, y) with x = col and y = -row
    grc2xy = array([[0, -nPixCell], [nPixCell, 0]])
    gGrid = array(np.meshgrid(np.arange(10), -np.arange(10))) * array([[[nPixCell]], [[nPixCell]]])
    # quarter-circle sample points, from [1, 0] to [0, 1]
    gTheta = np.linspace(0, np.pi / 2, 5)
    gArc = array([np.cos(gTheta), np.sin(gTheta)]).T

    def __init__(self, env, gl="PILSVG", jupyter=False, agentRenderVariant=AgentRenderVariant.ONE_STEP_BEHIND):
        """Create a renderer for ``env``.

        :param env: the RailEnv to draw; its width/height size the graphics layer
        :param gl: graphics layer name, "PIL" or "PILSVG"; anything else
            prints a message and falls back to PILSVG
        :param jupyter: forwarded to the graphics layer constructor
        :param agentRenderVariant: how renderEnv2 draws agents (AgentRenderVariant)
        """
        self.env = env

        # frame counter and timestamps used for the elapsed / fps read-out
        self.iFrame = 0
        self.time1 = time.time()
        self.lTimes = deque()

        self.agentRenderVariant = agentRenderVariant

        if gl not in ("PIL", "PILSVG"):
            print("[", gl, "] not found, switch to PILSVG")
        layer_class = PILGL if gl == "PIL" else PILSVG
        self.gl = layer_class(env.width, env.height, jupyter)

        # make the first renderEnv2 call build the rail layer
        self.new_rail = True
        self.update_background()

    def update_background(self):
        # create background map
        dTargets = {}
        for iAgent, agent in enumerate(self.env.agents_static):
            if agent is None:
                continue
            dTargets[tuple(agent.target)] = iAgent
        self.gl.build_background_map(dTargets)
68

69
70
71
    def resize(self):
        self.gl.resize(self.env)

72
    def set_new_rail(self):
73
74
75
        """ Tell the renderer that the rail has changed.
            eg when the rail has been regenerated, or updated in the editor.
        """
76
        self.new_rail = True
77

78
    def plotTreeOnRail(self, lVisits, color="r"):
79
        """
80
        DEFUNCT
81
82
        Derives and plots a tree of transitions starting at position rcPos
        in direction iDir.
83
84
        Returns a list of Visits which are the nodes / vertices in the tree.
        """
85
        rt = self.__class__
86

87
88
        for visit in lVisits:
            # transition for next cell
89
            tbTrans = self.env.rail.get_transitions((*visit.rc, visit.iDir))
90
91
            giTrans = np.where(tbTrans)[0]  # RC list of transitions
            gTransRCAg = rt.gTransRC[giTrans]
92
            self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color)
93

hagrid67's avatar
hagrid67 committed
94
95
    def plotAgents(self, targets=True, iSelectedAgent=None):
        cmap = self.gl.get_cmap('hsv',
Erik Nygren's avatar
Erik Nygren committed
96
                                lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
hagrid67's avatar
hagrid67 committed
97
98

        for iAgent, agent in enumerate(self.env.agents_static):
99
100
            if agent is None:
                continue
hagrid67's avatar
hagrid67 committed
101
102
            oColor = cmap(iAgent)
            self.plotAgent(agent.position, agent.direction, oColor, target=agent.target if targets else None,
Erik Nygren's avatar
Erik Nygren committed
103
                           static=True, selected=iAgent == iSelectedAgent)
hagrid67's avatar
hagrid67 committed
104

105
        for iAgent, agent in enumerate(self.env.agents):
106
107
            if agent is None:
                continue
108
            oColor = cmap(iAgent)
109
            self.plotAgent(agent.position, agent.direction, oColor, target=agent.target if targets else None)
110
111

    def getTransRC(self, rcPos, iDir, bgiTrans=False):
112
113
114
115
116
117
118
119
120
121
122
123
124
125
        """
        Get the available transitions for rcPos in direction iDir,
        as row & col deltas.

        If bgiTrans is True, return a grid of indices of available transitions.

        eg for a cell rcPos = (4,5), in direction iDir = 0 (N),
        where the available transitions are N and E, returns:
        [[-1,0], [0,1]] ie N=up one row, and E=right one col.
        and if bgiTrans is True, returns a tuple:
        (
            [[-1,0], [0,1]], # deltas as before
            [0, 1] #  available transition indices, ie N, E
        )
126
127
        """

128
        tbTrans = self.env.rail.get_transitions((*rcPos, iDir))
129
        giTrans = np.where(tbTrans)[0]  # RC list of transitions
130

131
132
133
        # HACK: workaround dead-end transitions
        if len(giTrans) == 0:
            iDirReverse = (iDir + 2) % 4
134
            tbTrans = tuple(int(iDir2 == iDirReverse) for iDir2 in range(4))
135
            giTrans = np.where(tbTrans)[0]  # RC list of transitions
136
137

        gTransRCAg = self.__class__.gTransRC[giTrans]
138
139
140
141
142
143

        if bgiTrans:
            return gTransRCAg, giTrans
        else:
            return gTransRCAg

hagrid67's avatar
hagrid67 committed
144
    def plotAgent(self, rcPos, iDir, color="r", target=None, static=False, selected=False):
145
146
        """
        Plot a simple agent.
147
        Assumes a working graphics layer context (cf a MPL figure).
148
149
        """
        rt = self.__class__
150

Erik Nygren's avatar
Erik Nygren committed
151
152
        rcDir = rt.gTransRC[iDir]  # agent direction in RC
        xyDir = np.matmul(rcDir, rt.grc2xy)  # agent direction in xy
153
154

        xyPos = np.matmul(rcPos - rcDir / 2, rt.grc2xy) + rt.xyHalf
hagrid67's avatar
hagrid67 committed
155
156
157
158

        if static:
            color = self.gl.adaptColor(color, lighten=True)

159
160
161
        color = color

        self.gl.scatter(*xyPos, color=color, layer=1, marker="o", s=100)  # agent location
162
        xyDirLine = array([xyPos, xyPos + xyDir / 2]).T  # line for agent orient.
163
        self.gl.plot(*xyDirLine, color=color, layer=1, lw=5, ms=0, alpha=0.6)
hagrid67's avatar
hagrid67 committed
164
165
        if selected:
            self._draw_square(xyPos, 1, color)
166
167
168
169

        if target is not None:
            rcTarget = array(target)
            xyTarget = np.matmul(rcTarget, rt.grc2xy) + rt.xyHalf
170
            self._draw_square(xyTarget, 1 / 3, color, layer=1)
171

172
    def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None):
hagrid67's avatar
hagrid67 committed
173
174
175
176
177
178
179
        """
        plot the transitions in gTransRCAg at position rcPos.
        gTransRCAg is a 2d numpy array containing a list of RC transitions,
        eg [[-1,0], [0,1]] means N, E.

        """

180
181
        rt = self.__class__
        xyPos = np.matmul(rcPos, rt.grc2xy) + rt.xyHalf
spiglerg's avatar
spiglerg committed
182
        gxyTrans = xyPos + np.matmul(gTransRCAg, rt.grc2xy / 2.4)
183
        self.gl.scatter(*gxyTrans.T, color=color, marker="o", s=50, alpha=0.2)
184
        if depth is not None:
185
            for x, y in gxyTrans:
186
                self.gl.text(x, y, depth)
187

188
    def getTreeFromRail(self, rcPos, iDir, nDepth=10, bBFS=True, bPlot=False):
189
        """
190
        DEFUNCT
191
192
193
194
        Generate a tree from the env starting at rcPos, iDir.
        """
        rt = self.__class__
        print(rcPos, iDir)
195
196
        iPos = 0 if bBFS else -1  # BF / DF Search

197
198
199
        iDepth = 0
        visited = set()
        lVisits = []
200
        stack = [rt.Visit(rcPos, iDir, iDepth, None)]
201
202
203
204
205
206
        while stack:
            visit = stack.pop(iPos)
            rcd = (visit.rc, visit.iDir)
            if visit.iDepth > nDepth:
                continue
            lVisits.append(visit)
207

208
209
            if rcd not in visited:
                visited.add(rcd)
210
211
212
213

                gTransRCAg, giTrans = self.getTransRC(visit.rc,
                                                      visit.iDir,
                                                      bgiTrans=True)
214
215
                # enqueue the next nodes (ie transitions from this node)
                for gTransRC2, iTrans in zip(gTransRCAg, giTrans):
216
217
                    visitNext = rt.Visit(tuple(visit.rc + gTransRC2),
                                         iTrans,
spiglerg's avatar
spiglerg committed
218
                                         visit.iDepth + 1,
219
220
221
                                         visit)
                    stack.append(visitNext)

222
                # plot the available transitions from this node
223
                if bPlot:
hagrid67's avatar
hagrid67 committed
224
225
226
                    self.plotTrans(
                        visit.rc, gTransRCAg,
                        depth=str(visit.iDepth))
227

228
229
230
231
232
        return lVisits

    def plotTree(self, lVisits, xyTarg):
        '''
        Plot a vertical tree of transitions.
233
234
        Returns the "visit" to the destination
        (ie where euclidean distance is near zero) or None if absent.
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
        '''

        dPos = {}
        iPos = 0

        visitDest = None

        for iVisit, visit in enumerate(lVisits):

            if visit.rc in dPos:
                xLoc = dPos[visit.rc]
            else:
                xLoc = dPos[visit.rc] = iPos
                iPos += 1

250
251
            rDist = np.linalg.norm(array(visit.rc) - array(xyTarg))

252
            xLoc = rDist + visit.iDir / 4
253

254
            # point labelled with distance
spiglerg's avatar
spiglerg committed
255
            self.gl.scatter(xLoc, visit.iDepth, color="k", s=2)
256
            self.gl.text(xLoc, visit.iDepth, visit.rc, color="k", rotation=45)
257
258

            # if len(dPos)>1:
259
260
            if visit.prev:
                xLocPrev = dPos[visit.prev.rc]
261
262
263
264

                rDistPrev = np.linalg.norm(array(visit.prev.rc) -
                                           array(xyTarg))

265
                xLocPrev = rDistPrev + visit.prev.iDir / 4
266

267
                # line from prev node
268
                self.gl.plot([xLocPrev, xLoc],
spiglerg's avatar
spiglerg committed
269
270
                             [visit.iDepth - 1, visit.iDepth],
                             color="k", alpha=0.5, lw=1)
271

272
273
274
275
276
277
278
279
            if rDist < 0.1:
                visitDest = visit

        # Walk backwards from destination to origin, plotting in red
        if visitDest is not None:
            visit = visitDest
            xLocPrev = None
            while visit is not None:
280
                rDist = np.linalg.norm(array(visit.rc) - array(xyTarg))
281
282
                xLoc = rDist + visit.iDir / 4
                if xLocPrev is not None:
spiglerg's avatar
spiglerg committed
283
284
                    self.gl.plot([xLoc, xLocPrev], [visit.iDepth, visit.iDepth + 1],
                                 color="r", alpha=0.5, lw=2)
285
286
287
                xLocPrev = xLoc
                visit = visit.prev

288
        self.gl.prettify()
289
290
291
        return visitDest

    def plotPath(self, visitDest):
hagrid67's avatar
hagrid67 committed
292
293
294
295
296
297
298
299
300
301
        """
        Given a "final" visit visitDest, plotPath recurses back through the path
        using the visit.prev field (previous) to get back to the start of the path.
        The path of transitions is plotted with arrows at 3/4 along the line.
        The transition is plotted slightly to one side of the rail, so that
        transitions in opposite directions are separate.
        Currently, no attempt is made to make the transition arrows coincide
        at corners, and they are straight only.
        """

302
303
304
305
306
        rt = self.__class__
        # Walk backwards from destination to origin
        if visitDest is not None:
            visit = visitDest
            xyPrev = None
307
            while visit is not None:
308
309
                xy = np.matmul(visit.rc, rt.grc2xy) + rt.xyHalf
                if xyPrev is not None:
hagrid67's avatar
hagrid67 committed
310
311
312
                    dx, dy = (xyPrev - xy) / 20
                    xyLine = array([xy, xyPrev]) + array([dy, dx])

313
                    self.gl.plot(*xyLine.T, color="r", alpha=0.5, lw=1)
hagrid67's avatar
hagrid67 committed
314

spiglerg's avatar
spiglerg committed
315
                    xyMid = np.sum(xyLine * [[1 / 4], [3 / 4]], axis=0)
hagrid67's avatar
hagrid67 committed
316
317

                    xyArrow = array([
spiglerg's avatar
spiglerg committed
318
                        xyMid + [-dx - dy, +dx - dy],
hagrid67's avatar
hagrid67 committed
319
                        xyMid,
spiglerg's avatar
spiglerg committed
320
                        xyMid + [-dx + dy, -dx - dy]])
321
                    self.gl.plot(*xyArrow.T, color="r")
hagrid67's avatar
hagrid67 committed
322

323
324
325
                visit = visit.prev
                xyPrev = xy

326
327
328
329
330
331
332
    def drawTrans(self, oFrom, oTo, sColor="gray"):
        self.gl.plot(
            [oFrom[0], oTo[0]],  # x
            [oFrom[1], oTo[1]],  # y
            color=sColor
        )

u214892's avatar
u214892 committed
333
334
335
336
337
338
    def drawTrans2(self,
                   xyLine, xyCentre,
                   rotation, bDeadEnd=False,
                   sColor="gray",
                   bArrow=True,
                   spacing=0.1):
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
        """
        gLine is a numpy 2d array of points,
        in the plotting space / coords.
        eg:
        [[0,.5],[1,0.2]] means a line
        from x=0, y=0.5
        to   x=1, y=0.2
        """
        rt = self.__class__
        bStraight = rotation in [0, 2]
        dx, dy = np.squeeze(np.diff(xyLine, axis=0)) * spacing / 2

        if bStraight:

            if sColor == "auto":
                if dx > 0 or dy > 0:
Erik Nygren's avatar
Erik Nygren committed
355
                    sColor = "C1"  # N or E
356
                else:
Erik Nygren's avatar
Erik Nygren committed
357
                    sColor = "C2"  # S or W
358
359
360
361
362
363
364

            if bDeadEnd:
                xyLine2 = array([
                    xyLine[1] + [dy, dx],
                    xyCentre,
                    xyLine[1] - [dy, dx],
                ])
365
                self.gl.plot(*xyLine2.T, color=sColor)
366
367
            else:
                xyLine2 = xyLine + [-dy, dx]
368
                self.gl.plot(*xyLine2.T, color=sColor)
369
370

                if bArrow:
spiglerg's avatar
spiglerg committed
371
                    xyMid = np.sum(xyLine2 * [[1 / 4], [3 / 4]], axis=0)
372
373

                    xyArrow = array([
spiglerg's avatar
spiglerg committed
374
                        xyMid + [-dx - dy, +dx - dy],
375
                        xyMid,
spiglerg's avatar
spiglerg committed
376
                        xyMid + [-dx + dy, -dx - dy]])
377
                    self.gl.plot(*xyArrow.T, color=sColor)
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394

        else:

            xyMid = np.mean(xyLine, axis=0)
            dxy = xyMid - xyCentre
            xyCorner = xyMid + dxy
            if rotation == 1:
                rArcFactor = 1 - spacing
                sColorAuto = "C1"
            else:
                rArcFactor = 1 + spacing
                sColorAuto = "C2"
            dxy2 = (xyCentre - xyCorner) * rArcFactor  # for scaling the arc

            if sColor == "auto":
                sColor = sColorAuto

395
            self.gl.plot(*(rt.gArc * dxy2 + xyCorner).T, color=sColor)
396
397
398
399
400
401

            if bArrow:
                dx, dy = np.squeeze(np.diff(xyLine, axis=0)) / 20
                iArc = int(len(rt.gArc) / 2)
                xyMid = xyCorner + rt.gArc[iArc] * dxy2
                xyArrow = array([
spiglerg's avatar
spiglerg committed
402
                    xyMid + [-dx - dy, +dx - dy],
403
                    xyMid,
spiglerg's avatar
spiglerg committed
404
                    xyMid + [-dx + dy, -dx - dy]])
405
                self.gl.plot(*xyArrow.T, color=sColor)
Erik Nygren's avatar
Erik Nygren committed
406

407
    def renderObs(self, agent_handles, observation_dict):
Erik Nygren's avatar
Erik Nygren committed
408
        """
409
410
        Render the extent of the observation of each agent. All cells that appear in the agent
        observation will be highlighted.
411
412
        :param agent_handles: List of agent indices to adapt color and get correct observation
        :param observation_dict: dictionary containing sets of cells of the agent observation
Erik Nygren's avatar
Erik Nygren committed
413
414
415
416
417

        """
        rt = self.__class__

        for agent in agent_handles:
418
            color = self.gl.getAgentColor(agent)
419
            for visited_cell in observation_dict[agent]:
Erik Nygren's avatar
Erik Nygren committed
420
                cell_coord = array(visited_cell[:2])
Erik Nygren's avatar
Erik Nygren committed
421
                cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf
u214892's avatar
u214892 committed
422
                self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
Erik Nygren's avatar
Erik Nygren committed
423

424
    def renderRail(self, spacing=False, sRailColor="gray", curves=True, arrows=False):
425

426
        cell_size = 1  # TODO: remove cell_size
427
428
429
        env = self.env

        # Draw cells grid
430
        grid_color = [0.95, 0.95, 0.95]
spiglerg's avatar
spiglerg committed
431
432
433
        for r in range(env.height + 1):
            self.gl.plot([0, (env.width + 1) * cell_size],
                         [-r * cell_size, -r * cell_size],
hagrid67's avatar
hagrid67 committed
434
                         color=grid_color, linewidth=2)
spiglerg's avatar
spiglerg committed
435
436
437
        for c in range(env.width + 1):
            self.gl.plot([c * cell_size, c * cell_size],
                         [0, -(env.height + 1) * cell_size],
hagrid67's avatar
hagrid67 committed
438
                         color=grid_color, linewidth=2)
439
440
441
442
443

        # Draw each cell independently
        for r in range(env.height):
            for c in range(env.width):

hagrid67's avatar
hagrid67 committed
444
                # bounding box of the grid cell
Erik Nygren's avatar
Erik Nygren committed
445
446
447
                x0 = cell_size * c  # left
                x1 = cell_size * (c + 1)  # right
                y0 = cell_size * -r  # top
spiglerg's avatar
spiglerg committed
448
                y1 = cell_size * -(r + 1)  # bottom
449

hagrid67's avatar
hagrid67 committed
450
                # centres of cell edges
451
                coords = [
spiglerg's avatar
spiglerg committed
452
453
454
                    ((x0 + x1) / 2.0, y0),  # N middle top
                    (x1, (y0 + y1) / 2.0),  # E middle right
                    ((x0 + x1) / 2.0, y1),  # S middle bottom
Erik Nygren's avatar
Erik Nygren committed
455
                    (x0, (y0 + y1) / 2.0)  # W middle left
456
457
                ]

hagrid67's avatar
hagrid67 committed
458
                # cell centre
hagrid67's avatar
hagrid67 committed
459
460
                xyCentre = array([x0, y1]) + cell_size / 2

hagrid67's avatar
hagrid67 committed
461
                # cell transition values
462
                oCell = env.rail.get_transitions((r, c))
463

464
                bCellValid = env.rail.cell_neighbours_valid((r, c), check_this_cell=True)
465

hagrid67's avatar
hagrid67 committed
466
467
468
469
470
471
472
473
474
475
476
477
                # Special Case 7, with a single bit; terminate at center
                nbits = 0
                tmp = oCell

                while tmp > 0:
                    nbits += (tmp & 1)
                    tmp = tmp >> 1

                # as above - move the from coord to the centre
                # it's a dead env.
                bDeadEnd = nbits == 1

478
                if not bCellValid:
479
                    self.gl.scatter(*xyCentre, color="r", s=30)
480

481
482
                for orientation in range(4):  # ori is where we're heading
                    from_ori = (orientation + 2) % 4  # 0123=NESW -> 2301=SWNE
483
484
                    from_xy = coords[from_ori]

485
                    tMoves = env.rail.get_transitions((r, c, orientation))
486

hagrid67's avatar
hagrid67 committed
487
488
489
                    for to_ori in range(4):
                        to_xy = coords[to_ori]
                        rotation = (to_ori - from_ori) % 4
490

hagrid67's avatar
hagrid67 committed
491
                        if (tMoves[to_ori]):  # if we have this transition
492

hagrid67's avatar
hagrid67 committed
493
                            if bDeadEnd:
494
                                self.drawTrans2(
hagrid67's avatar
hagrid67 committed
495
                                    array([from_xy, to_xy]), xyCentre,
496
497
                                    rotation, bDeadEnd=True, spacing=spacing,
                                    sColor=sRailColor)
498

hagrid67's avatar
hagrid67 committed
499
                            else:
500

hagrid67's avatar
hagrid67 committed
501
                                if curves:
502
                                    self.drawTrans2(
hagrid67's avatar
hagrid67 committed
503
                                        array([from_xy, to_xy]), xyCentre,
504
505
                                        rotation, spacing=spacing, bArrow=arrows,
                                        sColor=sRailColor)
hagrid67's avatar
hagrid67 committed
506
                                else:
507
                                    self.drawTrans(self, from_xy, to_xy, sRailColor)
hagrid67's avatar
hagrid67 committed
508
509
510
511
512
513
514
515
516
517
518

                            if False:
                                print(
                                    "r,c,ori: ", r, c, orientation,
                                    "cell:", "{0:b}".format(oCell),
                                    "moves:", tMoves,
                                    "from:", from_ori, from_xy,
                                    "to: ", to_ori, to_xy,
                                    "cen:", *xyCentre,
                                    "rot:", rotation,
                                )
519

520
    def renderEnv(self,
                  show=False,  # whether to call matplotlib show() or equivalent after completion
                  # use false when calling from Jupyter.  (and matplotlib no longer supported!)
                  curves=True,  # draw turns as curves instead of straight diagonal lines
                  spacing=False,  # defunct - size of spacing between rails
                  arrows=False,  # defunct - draw arrows on rail lines
                  agents=True,  # whether to include agents
                  show_observations=True,  # whether to include observations
                  sRailColor="gray",  # color to use in drawing rails (not used with SVG)
                  frames=False,  # frame counter to show (intended since invocation)
                  iEpisode=None,  # int episode number to show
                  iStep=None,  # int step number to show in image
                  iSelectedAgent=None,  # indicate which agent is "selected" in the editor
                  action_dict=None):  # defunct - was used to indicate agent intention to turn
        """ Draw the environment using the GraphicsLayer this RenderTool was created with.
            (Use show=False from a Jupyter notebook with %matplotlib inline)

            If the layer is not raster (gl.is_raster() is False), drawing is
            delegated to renderEnv2 with the same arguments.  Otherwise this
            draws the rail (honouring curves / spacing / arrows / sRailColor),
            the agents, the observations, and a text overlay (frame, episode,
            step, elapsed time, fps over up to the last 20 frame timestamps).
        """
        if not self.gl.is_raster():
            self.renderEnv2(show=show, curves=curves, spacing=spacing,
                            arrows=arrows, agents=agents, show_observations=show_observations,
                            sRailColor=sRailColor,
                            frames=frames, iEpisode=iEpisode, iStep=iStep,
                            iSelectedAgent=iSelectedAgent, action_dict=action_dict)
            return

        if type(self.gl) is PILGL:
            self.gl.beginFrame()

        env = self.env

        # forward the rail drawing options so they take effect on raster layers too
        self.renderRail(spacing=spacing, sRailColor=sRailColor, curves=curves, arrows=arrows)

        # Draw each agent + its orientation + its target
        if agents:
            self.plotAgents(targets=True, iSelectedAgent=iSelectedAgent)
        if show_observations:
            self.renderObs(range(env.get_num_agents()), env.dev_obs_dict)

        # Draw some textual information like fps
        yText = [-0.3, -0.6, -0.9]
        if frames:
            self.gl.text(0.1, yText[2], "Frame:{:}".format(self.iFrame))
        self.iFrame += 1

        if iEpisode is not None:
            self.gl.text(0.1, yText[1], "Ep:{}".format(iEpisode))

        if iStep is not None:
            self.gl.text(0.1, yText[0], "Step:{}".format(iStep))

        tNow = time.time()
        self.gl.text(2, yText[2], "elapsed:{:.2f}s".format(tNow - self.time1))

        # fps over a sliding window of up to 20 frame timestamps
        self.lTimes.append(tNow)
        if len(self.lTimes) > 20:
            self.lTimes.popleft()
        if len(self.lTimes) > 1:
            rWindow = self.lTimes[-1] - self.lTimes[0]
            # with a coarse clock two timestamps can be equal: skip rather than divide by 0
            if rWindow > 0:
                rFps = (len(self.lTimes) - 1) / rWindow
                self.gl.text(2, yText[1], "fps:{:.2f}".format(rFps))

        self.gl.prettify2(env.width, env.height, self.nPixCell)

        # TODO: for MPL, we don't want to call clf (called by endframe)
        # if not show:

        if show and type(self.gl) is PILGL:
            self.gl.show()

        self.gl.pause(0.00001)

    def _draw_square(self, center, size, color, opacity=255, layer=0):
spiglerg's avatar
spiglerg committed
592
593
594
595
        x0 = center[0] - size / 2
        x1 = center[0] + size / 2
        y0 = center[1] - size / 2
        y1 = center[1] + size / 2
596
        self.gl.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color, layer=layer, opacity=opacity)
hagrid67's avatar
hagrid67 committed
597
598
599

    def getImage(self):
        return self.gl.getImage()
hagrid67's avatar
hagrid67 committed
600
601
602
603
604
605
606

    def plotTreeObs(self, gObs):
        nBranchFactor = 4

        gP0 = array([[0, 0, 0]]).T
        nDepth = 2
        for i in range(nDepth):
Erik Nygren's avatar
Erik Nygren committed
607
608
609
610
            nDepthNodes = nBranchFactor ** i
            rShrinkDepth = 1 / (i + 1)

            gX1 = np.linspace(-(nDepthNodes - 1), (nDepthNodes - 1), nDepthNodes) * rShrinkDepth
hagrid67's avatar
hagrid67 committed
611
612
            gY1 = np.ones((nDepthNodes)) * i
            gZ1 = np.zeros((nDepthNodes))
Erik Nygren's avatar
Erik Nygren committed
613

hagrid67's avatar
hagrid67 committed
614
615
            gP1 = array([gX1, gY1, gZ1])
            gP01 = np.append(gP0, gP1, axis=1)
Erik Nygren's avatar
Erik Nygren committed
616

hagrid67's avatar
hagrid67 committed
617
618
619
620
621
622
623
624
            if nDepthNodes > 1:
                nDepthNodesPrev = nDepthNodes / nBranchFactor
                giP0 = np.repeat(np.arange(nDepthNodesPrev), nBranchFactor)
                giP1 = np.arange(0, nDepthNodes) + nDepthNodesPrev
                giLinePoints = np.stack([giP0, giP1]).ravel("F")
                self.gl.plot(gP01[0], -gP01[1], lines=giLinePoints, color="gray")

            gP0 = array([gX1, gY1, gZ1])
Erik Nygren's avatar
Erik Nygren committed
625

hagrid67's avatar
hagrid67 committed
626
    def renderEnv2(
u214892's avatar
u214892 committed
627
628
629
630
631
        self, show=False, curves=True, spacing=False, arrows=False, agents=True,
        show_observations=True, sRailColor="gray",
        frames=False, iEpisode=None, iStep=None, iSelectedAgent=None,
        action_dict=dict()
    ):
632
633
634
635
636
637
638
639
640
641
        """
        Draw the environment using matplotlib.
        Draw into the figure if provided.

        Call pyplot.show() if show==True.
        (Use show=False from a Jupyter notebook with %matplotlib inline)
        """

        env = self.env

hagrid67's avatar
hagrid67 committed
642
643
        self.gl.beginFrame()

644
645
646
        if self.new_rail:
            self.new_rail = False
            self.gl.clear_rails()
hagrid67's avatar
hagrid67 committed
647
648
649

            # store the targets
            dTargets = {}
650
            dSelected = {}
hagrid67's avatar
hagrid67 committed
651
652
653
            for iAgent, agent in enumerate(self.env.agents_static):
                if agent is None:
                    continue
654
                dTargets[tuple(agent.target)] = iAgent
u214892's avatar
u214892 committed
655
                dSelected[tuple(agent.target)] = (iAgent == iSelectedAgent)
hagrid67's avatar
hagrid67 committed
656

657
658
659
660
            # Draw each cell independently
            for r in range(env.height):
                for c in range(env.width):
                    binTrans = env.rail.grid[r, c]
hagrid67's avatar
hagrid67 committed
661
662
                    if (r, c) in dTargets:
                        target = dTargets[(r, c)]
663
                        isSelected = dSelected[(r, c)]
hagrid67's avatar
hagrid67 committed
664
665
                    else:
                        target = None
666
667
                        isSelected = False

668
                    self.gl.setRailAt(r, c, binTrans, iTarget=target, isSelected=isSelected, rail_grid=env.rail.grid)
669

670
671
            self.gl.build_background_map(dTargets)

672
        for iAgent, agent in enumerate(self.env.agents):
673

674
675
676
            if agent is None:
                continue

677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
            if self.agentRenderVariant == AgentRenderVariant.BOX_ONLY:
                self.gl.setCellOccupied(iAgent, *(agent.position))
            elif self.agentRenderVariant == AgentRenderVariant.ONE_STEP_BEHIND or \
                    self.agentRenderVariant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX:
                if agent.old_position is not None:
                    position = agent.old_position
                    direction = agent.direction
                    old_direction = agent.old_direction
                else:
                    position = agent.position
                    direction = agent.direction
                    old_direction = agent.direction

                # setAgentAt uses the agent index for the color
                if self.agentRenderVariant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX:
                    self.gl.setCellOccupied(iAgent, *(agent.position))
                self.gl.setAgentAt(iAgent, *position, old_direction, direction, iSelectedAgent == iAgent)
694
            else:
695
696
                position = agent.position
                direction = agent.direction
697
698
699
700
701
702
703
704
705
706
707
708
709
                for possible_directions in range(4):
                    # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
                    isValid = env.rail.get_transition((*agent.position, agent.direction), possible_directions)
                    if isValid:
                        direction = possible_directions

                        # setAgentAt uses the agent index for the color
                        self.gl.setAgentAt(iAgent, *position, agent.direction, direction, iSelectedAgent == iAgent)

                # setAgentAt uses the agent index for the color
                if self.agentRenderVariant == AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX:
                    self.gl.setCellOccupied(iAgent, *(agent.position))
                self.gl.setAgentAt(iAgent, *position, agent.direction, direction, iSelectedAgent == iAgent)
Erik Nygren's avatar
Erik Nygren committed
710

711
712
713
        if show_observations:
            self.renderObs(range(env.get_num_agents()), env.dev_obs_dict)

hagrid67's avatar
hagrid67 committed
714
715
        if show:
            self.gl.show()
hagrid67's avatar
hagrid67 committed
716
717
        for i in range(3):
            self.gl.processEvents()
718
719
720

        self.iFrame += 1
        return
hagrid67's avatar
hagrid67 committed
721
722
723

    def close_window(self):
        self.gl.close_window()