rendertools.py 31.7 KB
Newer Older
u214892's avatar
u214892 committed
1
import time
2
import warnings
u214892's avatar
u214892 committed
3
from collections import deque
4
from enum import IntEnum
5

hagrid67's avatar
hagrid67 committed
6
import numpy as np
u214892's avatar
u214892 committed
7
8
from numpy import array
from recordtype import recordtype
Erik Nygren's avatar
Erik Nygren committed
9

10
from flatland.envs.step_utils.states import TrainState
11

u214892's avatar
u214892 committed
12
from flatland.utils.graphics_pil import PILGL, PILSVG
13
from flatland.utils.graphics_pgl import PGLGL
14

u214892's avatar
u214892 committed
15
16

# TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv!
hagrid67's avatar
hagrid67 committed
17

18
19
20
21
22
23
24
class AgentRenderVariant(IntEnum):
    """Agent drawing style chosen when a renderer is constructed.

    RenderTool and RenderLocal store the chosen member as
    ``agent_render_variant``. Because this is an IntEnum, every member also
    compares equal to its integer value, so the numbering is part of the
    interface and must not change.
    """

    BOX_ONLY = 0
    ONE_STEP_BEHIND = 1  # default used by RenderTool and RenderLocal
    AGENT_SHOWS_OPTIONS = 2
    ONE_STEP_BEHIND_AND_BOX = 3
    AGENT_SHOWS_OPTIONS_AND_BOX = 4
class RenderTool(object):
    """ RenderTool is a facade to a renderer.
        (This was introduced for the Browser / JS renderer which has now been removed.)

        Every call is forwarded to ``self.renderer``, a RenderLocal instance.
    """

    def __init__(self, env, gl="PGL", jupyter=False,
                 agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
                 show_debug=False, clear_debug_text=True, screen_width=800, screen_height=600,
                 host="localhost", port=None):
        """
        :param env: the environment to render (this tool is meant for RailEnv).
        :param gl: graphics layer name: "PIL", "PILSVG" or "PGL". Any other value
            prints a notice and falls back to "PGL".
        :param jupyter: forwarded to the renderer.
        :param agent_render_variant: an AgentRenderVariant; stored and forwarded.
        :param show_debug: forwarded to the renderer.
        :param clear_debug_text: forwarded to the renderer.
        :param screen_width: forwarded to the renderer.
        :param screen_height: forwarded to the renderer.
        :param host: unused (presumably left over from the removed browser renderer).
        :param port: unused (presumably left over from the removed browser renderer).
        """
        self.env = env
        self.frame_nr = 0
        self.start_time = time.time()
        self.times_list = deque()

        self.agent_render_variant = agent_render_variant

        if gl not in ["PIL", "PILSVG", "PGL"]:
            # Previously only this message was printed and no renderer was created,
            # so every later call failed with AttributeError on self.renderer.
            print("[", gl, "] not found, switch to PGL")
            gl = "PGL"

        self.renderer = RenderLocal(env, gl, jupyter,
                                    agent_render_variant,
                                    show_debug, clear_debug_text, screen_width, screen_height)
        self.gl = self.renderer.gl

    def render_env(self,
                   show=False,  # whether to call matplotlib show() or equivalent after completion
                   show_agents=True,  # whether to include agents
                   show_inactive_agents=False,  # whether to show agents before they start
                   show_observations=True,  # whether to include observations
                   show_predictions=False,  # whether to include predictions
                   show_rowcols=False,  # label the rows and columns
                   frames=False,  # frame counter to show (intended since invocation)
                   episode=None,  # int episode number to show
                   step=None,  # int step number to show in image
                   selected_agent=None,  # indicate which agent is "selected" in the editor
                   return_image=False):  # indicate if image is returned for use in monitor
        """ Forward to the renderer's render_env and return whatever it returns. """
        # Keyword arguments keep this facade correct even if the delegate's
        # parameter order changes.
        return self.renderer.render_env(show=show,
                                        show_agents=show_agents,
                                        show_inactive_agents=show_inactive_agents,
                                        show_observations=show_observations,
                                        show_predictions=show_predictions,
                                        show_rowcols=show_rowcols,
                                        frames=frames,
                                        episode=episode,
                                        step=step,
                                        selected_agent=selected_agent,
                                        return_image=return_image)

    def close_window(self):
        self.renderer.close_window()

    def reset(self):
        self.renderer.reset()

    def set_new_rail(self):
        self.renderer.set_new_rail()
        self.renderer.env = self.env  # bit of a hack - copy our env to the delegate

    def update_background(self):
        self.renderer.update_background()

    def get_endpoint_URL(self):
        """ Returns a string URL for the root of the HTTP server, if the renderer has one.
            TODO: Need to update this to work on a remote server!  May be tricky...
            Returns None (after printing a notice) when the renderer has no
            get_endpoint_url method.
        """
        if hasattr(self.renderer, "get_endpoint_url"):
            return self.renderer.get_endpoint_url()
        else:
            print("Attempt to get_endpoint_url from RenderTool - only supported with BROWSER")
            return None

    def get_image(self):
        """ Return the current image from the renderer's graphics layer,
            or None (after printing a notice) if the renderer has no ``gl``.
        """
        if hasattr(self.renderer, "gl"):
            return self.renderer.gl.get_image()
        else:
            print("Attempt to retrieve image from RenderTool - not supported with BROWSER")
            return None


class RenderBase:
    """No-op base interface for renderers.

    Every hook does nothing and returns ``None``; concrete renderers such as
    ``RenderLocal`` override the hooks they need.
    """

    def __init__(self, env):
        # The base keeps no state; ``env`` is accepted only so subclasses share the signature.
        pass

    def render_env(self):
        """Render the current environment state (no-op in the base class)."""

    def close_window(self):
        """Close any display window (no-op in the base class)."""

    def reset(self):
        """Reset renderer state (no-op in the base class)."""

    def set_new_rail(self):
        """Signal to the renderer that the env has changed and will need re-rendering."""

    def update_background(self):
        """A lesser version of set_new_rail?  Refreshes only the background.

        TODO: can update_background be pruned for simplicity?
        """



class RenderLocal(RenderBase):
    """ Class to render the RailEnv and agents.
        Uses two layers, layer 0 for rails (mostly static), layer 1 for agents etc (dynamic)
        The lower / rail layer 0 is only redrawn after set_new_rail() has been called.
        Created with a "GraphicsLayer" or gl - one of PILGL ("PIL"), PILSVG ("PILSVG")
        or PGLGL ("PGL", which is also the fallback for unrecognised names).
    """
    # Mutable record with fields rc, iDir, iDepth, prev (presumably a search-node
    # record: row/col, direction index, depth, previous node -- TODO confirm).
    visit = recordtype("visit", ["rc", "iDir", "iDepth", "prev"])

    # Single-letter colour codes (matplotlib style, presumably).
    color_list = list("brgcmyk")

    # Row/column deltas for one step N, E, S, W (indexed by direction 0..3).
    transitions_row_col = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])
    pix_per_cell = 1  # misnomer: one cell is one plot unit, not one pixel
    half_pix_per_cell = pix_per_cell / 2
    # Offset from a cell's top-left corner to its centre in xy (rows run towards negative y).
    x_y_half = array([half_pix_per_cell, -half_pix_per_cell])
    # (row, col) @ row_col_to_xy == (col, -row) * pix_per_cell: columns map to x, rows to -y.
    row_col_to_xy = array([[0, -pix_per_cell], [pix_per_cell, 0]])
    # xy coordinates of a 10x10 lattice (x = col, y = -row), scaled by pix_per_cell; shape (2, 10, 10).
    grid = array(np.meshgrid(np.arange(10), -np.arange(10))) * array([[[pix_per_cell]], [[pix_per_cell]]])
    # Quarter circle sampled at 5 points; draw_transition scales/offsets it for curved rails.
    theta = np.linspace(0, np.pi / 2, 5)
    arc = array([np.cos(theta), np.sin(theta)]).T  # from [1,0] to [0,1]
    def __init__(self, env, gl="PILSVG", jupyter=False,
                 agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
                 show_debug=False, clear_debug_text=True, screen_width=800, screen_height=600):
        """
        Create the graphics layer and build the initial background map.

        :param env: environment to render; its width, height and agents are read.
        :param gl: "PIL", "PILSVG" or "PGL". Any other value prints a notice and
            falls back to "PGL".
        :param jupyter: passed through to the graphics layer constructor.
        :param agent_render_variant: stored as ``self.agent_render_variant``.
        :param show_debug: stored as ``self.show_debug``.
        :param clear_debug_text: stored as ``self.clear_debug_text``.
        :param screen_width: passed to the graphics layer constructor.
        :param screen_height: passed to the graphics layer constructor.
        """
        self.env = env
        self.frame_nr = 0
        self.start_time = time.time()
        self.times_list = deque()

        self.agent_render_variant = agent_render_variant

        if gl == "PIL":
            gl_class = PILGL
        elif gl == "PILSVG":
            gl_class = PILSVG
        else:
            if gl != "PGL":
                print("[", gl, "] not found, switch to PGL, PILSVG")
                print("Using PGL")
                # Record the fallback so render_env() dispatches to the SVG/PGL code
                # path that matches the PGLGL layer actually created below.
                gl = "PGL"
            gl_class = PGLGL
        self.gl_str = gl
        self.gl = gl_class(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)

        self.new_rail = True
        self.show_debug = show_debug
        self.clear_debug_text = clear_debug_text
        self.update_background()
    def reset(self):
        """
        Reset the renderer (not the environment) for a fresh run: flag the rail for
        redrawing and restart the frame counter and the timing statistics.
        :return: None
        """
        self.set_new_rail()
        self.frame_nr, self.start_time, self.times_list = 0, time.time(), deque()
    def update_background(self):
        """Rebuild the background map from the agents' targets.

        Hands ``{tuple(target): agent index}`` to the graphics layer's
        ``build_background_map``. ``None`` slots in ``env.agents`` are skipped;
        when several agents share a target, the highest index wins.
        """
        target_owner = {
            tuple(agent.target): idx
            for idx, agent in enumerate(self.env.agents)
            if agent is not None
        }
        self.gl.build_background_map(target_owner)
    def resize(self):
        """Ask the graphics layer to resize itself for the current environment."""
        graphics_layer = self.gl
        graphics_layer.resize(self.env)
    def set_new_rail(self):
        """ Flag the rail layer as stale so it is rebuilt on the next render.
            Call this whenever the rail changes, e.g. after regeneration or an edit in the editor.
        """
        self.new_rail = True
    def plot_agents(self, targets=True, selected_agent=None):
        """
        Draw every agent (skipping ``None`` slots) using colours from an HSV colour map.

        Two passes are made: first a lightened ("static") marker, with the selected
        agent outlined, then a normal-colour marker drawn on top.

        :param targets: if true, also mark each agent's target cell.
        :param selected_agent: index of the agent to outline, or None.
        """
        color_map = self.gl.get_cmap('hsv', lut=(len(self.env.agents) + 1))
        live_agents = [(idx, agent) for idx, agent in enumerate(self.env.agents) if agent is not None]

        # NOTE(review): both passes draw at agent.position, so the static marker is
        # mostly overdrawn -- possibly a leftover from a separate static-agent list.
        for idx, agent in live_agents:
            self.plot_single_agent(agent.position, agent.direction, color_map(idx),
                                   target=agent.target if targets else None,
                                   static=True, selected=idx == selected_agent)

        for idx, agent in live_agents:
            self.plot_single_agent(agent.position, agent.direction, color_map(idx),
                                   target=agent.target if targets else None)
    def get_transition_row_col(self, row_col_pos, direction, bgiTrans=False):
        """
        Get the available transitions for row_col_pos when heading in ``direction``,
        as row & col deltas.

        If bgiTrans is True, also return the indices of the available transitions.

        eg for a cell row_col_pos = (4,5), in direction direction = 0 (N),
        where the available transitions are N and E, returns:
        [[-1,0], [0,1]] ie N=up one row, and E=right one col.
        and if bgiTrans is True, returns a tuple:
        (
            [[-1,0], [0,1]], # deltas as before
            [0, 1] #  available transition indices, ie N, E
        )

        A cell with no outgoing transition (dead end) is reported as allowing
        only the reverse direction.
        """
        allowed = self.env.rail.get_transitions(*row_col_pos, direction)
        exit_dirs = np.where(allowed)[0]  # indices (0..3 = NESW) of allowed exits

        if len(exit_dirs) == 0:
            # HACK: dead end -- pretend the only move is to turn back.
            reverse_direction = (direction + 2) % 4
            exit_dirs = np.where([int(d == reverse_direction) for d in range(4)])[0]

        deltas = self.__class__.transitions_row_col[exit_dirs]
        return (deltas, exit_dirs) if bgiTrans else deltas
    def plot_single_agent(self, position_row_col, direction, color="r", target=None, static=False, selected=False):
        """
        Plot a simple agent on layer 1: a dot offset half a cell back against its
        heading, a short line showing the heading, an optional one-cell selection
        square, and an optional small square on its target cell.

        :param position_row_col: (row, col) of the agent; nothing is drawn if None.
        :param direction: heading index 0-3 (N, E, S, W).
        :param color: drawing colour.
        :param target: (row, col) of the target to mark, or None.
        :param static: if true, use a lightened version of ``color``.
        :param selected: if true, draw a one-cell square around the agent.
        """
        if position_row_col is None:
            return

        cls = self.__class__
        heading_rc = cls.transitions_row_col[direction]
        heading_xy = np.matmul(heading_rc, cls.row_col_to_xy)
        agent_xy = np.matmul(position_row_col - heading_rc / 2, cls.row_col_to_xy) + cls.x_y_half

        draw_color = self.gl.adapt_color(color, lighten=True) if static else color

        self.gl.scatter(*agent_xy, color=draw_color, layer=1, marker="o", s=100)  # agent location
        heading_line = array([agent_xy, agent_xy + heading_xy / 2]).T  # line for agent orientation
        self.gl.plot(*heading_line, color=draw_color, layer=1, lw=5, ms=0, alpha=0.6)

        if selected:
            self._draw_square(agent_xy, 1, draw_color)

        if target is not None:
            target_xy = np.matmul(array(target), cls.row_col_to_xy) + cls.x_y_half
            self._draw_square(target_xy, 1 / 3, draw_color, layer=1)
    def plot_transition(self, position_row_col, transition_row_col, color="r", depth=None):
        """
        Plot the transitions in transition_row_col at position position_row_col.
        transition_row_col is a 2d numpy array containing a list of RC transitions,
        eg [[-1,0], [0,1]] means N, E.

        Each transition becomes a faint dot between the cell centre and the matching
        edge; when ``depth`` is given it is also written as text at every dot.
        """
        cls = self.__class__
        centre_xy = np.matmul(position_row_col, cls.row_col_to_xy) + cls.x_y_half
        # Scaling by 1/2.4 (< 1/2) keeps the dots just inside the cell edges.
        dots_xy = centre_xy + np.matmul(transition_row_col, cls.row_col_to_xy / 2.4)
        self.gl.scatter(*dots_xy.T, color=color, marker="o", s=50, alpha=0.2)
        if depth is None:
            return
        for dot_x, dot_y in dots_xy:
            self.gl.text(dot_x, dot_y, depth)
    def draw_transition(self,
                        line,
                        center,
                        rotation,
                        dead_end=False,
                        curves=False,
                        color="gray",
                        arrow=True,
                        spacing=0.1):
        """
        Draw one rail transition inside a cell.

        ``line`` is a 2x2 numpy array of points in plotting coordinates, e.g.
        [[0,.5],[1,0.2]] means a line from x=0, y=0.5 to x=1, y=0.2 (the entry
        and exit edge midpoints). ``center`` is the cell centre in xy.
        ``rotation`` is (exit - entry) mod 4: 0 or 2 is drawn straight, 1 or 3 as a
        quarter-circle arc around the cell corner.

        With neither ``curves`` nor ``dead_end`` a single plain segment is drawn.
        Otherwise straight rails are offset sideways by ``spacing``, dead ends end in
        a stub through the centre, and ``arrow`` adds a small arrow head.
        ``color="auto"`` picks "C1"/"C2" depending on the direction of travel.
        """

        def plot_arrow(tip_xy, adx, ady, arrow_color):
            # Two short barbs either side of tip_xy forming the arrow head.
            head_xy = array([
                tip_xy + [-adx - ady, +adx - ady],
                tip_xy,
                tip_xy + [-adx + ady, -adx - ady]])
            self.gl.plot(*head_xy.T, color=arrow_color)

        if not (curves or dead_end):
            # Plain rail element: one straight segment from entry point to exit point.
            start, end = line[0], line[1]
            self.gl.plot([start[0], end[0]], [start[1], end[1]], color=color)
            return

        cls = self.__class__
        dx, dy = np.squeeze(np.diff(line, axis=0)) * spacing / 2

        if rotation in [0, 2]:
            if color == "auto":
                color = "C1" if (dx > 0 or dy > 0) else "C2"  # C1: N or E, C2: S or W

            if dead_end:
                stub_xy = array([
                    line[1] + [dy, dx],
                    center,
                    line[1] - [dy, dx],
                ])
                self.gl.plot(*stub_xy.T, color=color)
                return

            shifted_xy = line + [-dy, dx]
            self.gl.plot(*shifted_xy.T, color=color)
            if arrow:
                quarter_xy = np.sum(shifted_xy * [[1 / 4], [3 / 4]], axis=0)
                plot_arrow(quarter_xy, dx, dy, color)
            return

        # Curved rail: reflect the centre through the midpoint of the two edge
        # midpoints to find the cell corner the arc is centred on.
        edges_mid_xy = np.mean(line, axis=0)
        corner = edges_mid_xy + (edges_mid_xy - center)
        if rotation == 1:
            arc_factor, auto_color = 1 - spacing, "C1"
        else:
            arc_factor, auto_color = 1 + spacing, "C2"
        radius_xy = (center - corner) * arc_factor  # for scaling the arc

        if color == "auto":
            color = auto_color

        self.gl.plot(*(cls.arc * radius_xy + corner).T, color=color)

        if arrow:
            adx, ady = np.squeeze(np.diff(line, axis=0)) / 20
            arc_mid_xy = corner + cls.arc[len(cls.arc) // 2] * radius_xy
            plot_arrow(arc_mid_xy, adx, ady, color)
    def render_observation(self, agent_handles, observation_dict):
        """
        Render the extent of the observation of each agent. All cells that appear in the agent
        observation will be highlighted with a square in the agent's colour on layer 1.

        :param agent_handles: List of agent indices to adapt color and get correct observation
        :param observation_dict: dictionary mapping agent index to the observed cells; each
            cell's first two entries are (row, col). Typically ``env.dev_obs_dict``.
        """
        rt = self.__class__

        # Check if the observation builder provides an observation
        if len(observation_dict) < 1:
            # The previous message blamed the "Predictor" and, through a backslash
            # continuation inside the literal, embedded source indentation in the text.
            warnings.warn(
                "Observation builder did not provide any observed cells to render. "
                "Observation builder needs to populate: env.dev_obs_dict")
            return

        for agent in agent_handles:
            color = self.gl.get_agent_color(agent)
            for visited_cell in observation_dict[agent]:
                cell_coord = array(visited_cell[:2])
                cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half
                # Square size shrinks with the agent index (presumably so that
                # overlapping observations of different agents remain visible).
                self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
    def render_prediction(self, agent_handles, prediction_dict):
        """
        Render the predicted cells of each agent. With a PILSVG graphics layer the
        predicted rail is highlighted via ``set_predicion_path_at``; with any other
        layer each predicted cell gets a square in the agent's colour on layer 1.

        :param agent_handles: List of agent indices to adapt color and get correct prediction
        :param prediction_dict: dictionary mapping agent index to the predicted cells; each
            cell's first two entries are (row, col). Typically ``env.dev_pred_dict``.
        """
        rt = self.__class__
        if len(prediction_dict) < 1:
            # Implicit concatenation instead of a backslash continuation, which used to
            # embed the source indentation into the warning text.
            warnings.warn(
                "Predictor did not provide any predicted cells to render. "
                "Predictor needs to populate: env.dev_pred_dict")
            return

        for agent in agent_handles:
            color = self.gl.get_agent_color(agent)
            for visited_cell in prediction_dict[agent]:
                cell_coord = array(visited_cell[:2])
                # Exact type check: subclasses of PILSVG take the square-drawing branch.
                if type(self.gl) is PILSVG:
                    # TODO : Track highlighting (Adrian)
                    r = cell_coord[0]
                    c = cell_coord[1]
                    transitions = self.env.rail.grid[r, c]
                    self.gl.set_predicion_path_at(r, c, transitions, agent_rail_color=color)
                else:
                    cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half
                    self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100)
    def render_rail(self, spacing=False, rail_color="gray", curves=True, arrows=False):
        """
        Draw the light-grey cell grid and then every transition of every cell.

        :param spacing: forwarded to draw_transition (False behaves as 0).
        :param rail_color: colour forwarded to draw_transition ("auto" lets it choose).
        :param curves: draw turning transitions as arcs (ignored for dead-end cells).
        :param arrows: NOTE(review): not forwarded; draw_transition's own
            default (arrow=True) applies.
        Cells that fail ``cell_neighbours_valid`` get a red dot at their centre.
        """
        cell_size = 1  # TODO: remove cell_size
        env = self.env

        # Grid: env.height + 1 horizontal and env.width + 1 vertical lines.
        # NOTE(review): the lines extend one cell beyond the grid on the right/bottom.
        grid_color = [0.95, 0.95, 0.95]
        for row in range(env.height + 1):
            y = -row * cell_size
            self.gl.plot([0, (env.width + 1) * cell_size], [y, y],
                         color=grid_color, linewidth=2)
        for col in range(env.width + 1):
            x = col * cell_size
            self.gl.plot([x, x], [0, -(env.height + 1) * cell_size],
                         color=grid_color, linewidth=2)

        # Draw each cell independently
        for row in range(env.height):
            for col in range(env.width):
                # Bounding box of the cell in plot coordinates (rows grow towards -y).
                left, right = cell_size * col, cell_size * (col + 1)
                top, bottom = cell_size * -row, cell_size * -(row + 1)
                mid_x = (left + right) / 2.0
                mid_y = (top + bottom) / 2.0
                # Midpoints of the N, E, S, W edges, indexed by direction.
                edge_mid_xy = [(mid_x, top), (right, mid_y), (mid_x, bottom), (left, mid_y)]
                center_xy = array([left, bottom]) + cell_size / 2

                cell = env.rail.get_full_transitions(row, col)
                cell_valid = env.rail.cell_neighbours_valid((row, col), check_this_cell=True)

                # Exactly one set bit in the transition value means a dead end.
                bit_count, remaining = 0, cell
                while remaining > 0:
                    bit_count += remaining & 1
                    remaining >>= 1
                is_dead_end = bit_count == 1

                if not cell_valid:
                    self.gl.scatter(*center_xy, color="r", s=30)

                for heading in range(4):
                    from_ori = (heading + 2) % 4  # entry edge is opposite the heading (NESW -> SWNE)
                    from_xy = edge_mid_xy[from_ori]
                    moves = env.rail.get_transitions(row, col, heading)
                    for to_ori, to_xy in enumerate(edge_mid_xy):
                        if moves[to_ori]:
                            self.draw_transition(
                                array([from_xy, to_xy]), center_xy,
                                (to_ori - from_ori) % 4,
                                dead_end=is_dead_end, curves=curves and not is_dead_end,
                                spacing=spacing, color=rail_color)
    def render_env(self,
                   show=False,  # whether to call matplotlib show() or equivalent after completion
                   show_agents=True,  # whether to include agents
                   show_inactive_agents=False,
                   show_observations=True,  # whether to include observations
                   show_predictions=False,  # whether to include predictions
                   show_rowcols=False,  # label the rows and columns
                   frames=False,  # frame counter to show (intended since invocation)
                   episode=None,  # int episode number to show
                   step=None,  # int step number to show in image
                   selected_agent=None,  # indicate which agent is "selected" in the editor
                   return_image=False):  # indicate if image is returned for use in monitor
        """ Draw the environment using the graphics layer this renderer was created with.
            (Use show=False from a Jupyter notebook with %matplotlib inline)

            "PILSVG" and "PGL" layers go through render_env_svg; everything else goes
            through render_env_pil. ``frames``, ``episode`` and ``step`` are only
            used by the PIL path. Returns whatever the chosen path returns.
        """
        if self.gl_str not in ["PILSVG", "PGL"]:
            return self.render_env_pil(show=show,
                                       show_agents=show_agents,
                                       show_inactive_agents=show_inactive_agents,
                                       show_observations=show_observations,
                                       show_predictions=show_predictions,
                                       show_rowcols=show_rowcols,
                                       frames=frames,
                                       episode=episode,
                                       step=step,
                                       selected_agent=selected_agent,
                                       return_image=return_image)

        return self.render_env_svg(show=show,
                                   show_observations=show_observations,
                                   show_predictions=show_predictions,
                                   selected_agent=selected_agent,
                                   show_agents=show_agents,
                                   show_inactive_agents=show_inactive_agents,
                                   show_rowcols=show_rowcols,
                                   return_image=return_image)

    def _draw_square(self, center, size, color, opacity=255, layer=0):
        """Outline an axis-aligned square of side ``size`` centred on ``center`` (x, y)."""
        half = size / 2
        left, right = center[0] - half, center[0] + half
        low, high = center[1] - half, center[1] + half
        corners_x = [left, right, right, left, left]
        corners_y = [low, low, high, high, low]
        self.gl.plot(corners_x, corners_y, color=color, layer=layer, opacity=opacity)

    def get_image(self):
        """Return the current frame exactly as the graphics layer's get_image() provides it."""
        graphics_layer = self.gl
        return graphics_layer.get_image()

    def render_env_pil(self,
                       show=False,  # whether to call matplotlib show() or equivalent after completion
                       # use false when calling from Jupyter.  (and matplotlib no longer supported!)
                       show_agents=True,  # whether to include agents
                       show_inactive_agents=False,
                       show_observations=True,  # whether to include observations
                       show_predictions=False,  # whether to include predictions
                       show_rowcols=False,  # label the rows and columns
                       frames=False,  # frame counter to show (intended since invocation)
                       episode=None,  # int episode number to show
                       step=None,  # int step number to show in image
                       selected_agent=None,  # indicate which agent is "selected" in the editor
                       return_image=False  # indicate if image is returned for use in monitor
                       ):
        """
        Render rails, agents, observations/predictions and status text (frame,
        episode, step, elapsed time, fps) with the non-SVG graphics layer.

        ``show_inactive_agents`` and ``show_rowcols`` are accepted for interface
        compatibility but are not used by this code path.

        :return: the rendered image if ``return_image`` is true, otherwise None.
        """
        if type(self.gl) is PILGL:
            self.gl.begin_frame()

        env = self.env

        self.render_rail()

        # Draw each agent + its orientation + its target
        if show_agents:
            self.plot_agents(targets=True, selected_agent=selected_agent)
        if show_observations:
            self.render_observation(range(env.get_num_agents()), env.dev_obs_dict)
        if show_predictions and len(env.dev_pred_dict) > 0:
            self.render_prediction(range(env.get_num_agents()), env.dev_pred_dict)

        # Draw some textual information like fps
        text_y = [-0.3, -0.6, -0.9]
        if frames:
            self.gl.text(0.1, text_y[2], "Frame:{:}".format(self.frame_nr))
        self.frame_nr += 1

        if episode is not None:
            self.gl.text(0.1, text_y[1], "Ep:{}".format(episode))

        if step is not None:
            self.gl.text(0.1, text_y[0], "Step:{}".format(step))

        time_now = time.time()
        self.gl.text(2, text_y[2], "elapsed:{:.2f}s".format(time_now - self.start_time))
        # Rolling window of the last 20 frame timestamps for the fps estimate.
        self.times_list.append(time_now)
        if len(self.times_list) > 20:
            self.times_list.popleft()
        if len(self.times_list) > 1:
            window_s = self.times_list[-1] - self.times_list[0]
            # Renders within the clock's resolution share a timestamp (and the wall
            # clock can step backwards); skip the fps text instead of dividing by zero.
            if window_s > 0:
                rFps = (len(self.times_list) - 1) / window_s
                self.gl.text(2, text_y[1], "fps:{:.2f}".format(rFps))

        self.gl.prettify2(env.width, env.height, self.pix_per_cell)

        # TODO: for MPL, we don't want to call clf (called by endframe)
        if show and type(self.gl) is PILGL:
            self.gl.show()

        self.gl.pause(0.00001)

        if return_image:
            return self.get_image()
        return None

    def render_env_svg(
        self, show=False, show_observations=True, show_predictions=False, selected_agent=None,
        show_agents=True, show_inactive_agents=False, show_rowcols=False, return_image=False
    ):
        """
        Renders the environment with SVG support (nice image).

        The rail layer is rebuilt only when ``self.new_rail`` is set (the flag is
        then cleared). Agents, observations and predictions are drawn on every call.

        :param show: if True, call ``self.gl.show()`` after drawing.
        :param show_observations: draw ``env.dev_obs_dict`` for all agents.
        :param show_predictions: draw ``env.dev_pred_dict`` for all agents
            (skipped when the dict is empty, matching ``render_env``).
        :param selected_agent: index of the agent to highlight, or None.
        :param show_agents: draw the agents.
        :param show_inactive_agents: also draw agents without a position (at their
            initial cell), and, for the AGENT_SHOWS_OPTIONS variants, agents whose
            state is not an on-map state.
        :param show_rowcols: label row/column indices when the rails are rebuilt.
        :param return_image: if True, return ``self.get_image()``.
        :return: the rendered image when ``return_image`` is True, otherwise None.
        """
        env = self.env
        rail_grid = env.rail.grid

        self.gl.begin_frame()

        if self.new_rail:
            self.new_rail = False
            self.gl.clear_rails()

            # Map each target cell to the index of the agent targeting it (the last
            # agent wins when several share a cell) and whether that agent is selected.
            targets = {}
            selected = {}
            for agent_idx, agent in enumerate(env.agents):
                if agent is None:
                    continue
                target_cell = tuple(agent.target)
                targets[target_cell] = agent_idx
                selected[target_cell] = (agent_idx == selected_agent)

            # Draw each cell independently
            num_agents = env.get_num_agents()
            for r in range(env.height):
                for c in range(env.width):
                    self.gl.set_rail_at(r, c, rail_grid[r, c],
                                        target=targets.get((r, c)),
                                        is_selected=selected.get((r, c), False),
                                        rail_grid=rail_grid, num_agents=num_agents,
                                        show_debug=self.show_debug)

            self.gl.build_background_map(targets)

            if show_rowcols:
                # label rows, cols
                for row in range(env.height):
                    self.gl.text_rowcol((row, 0), str(row), layer=self.gl.RAIL_LAYER)
                for col in range(env.width):
                    self.gl.text_rowcol((0, col), str(col), layer=self.gl.RAIL_LAYER)

        if show_agents:
            for agent_idx, agent in enumerate(env.agents):
                if agent is None:
                    continue
                is_selected = (selected_agent == agent_idx)

                # An agent that has not started yet has no position: optionally draw it
                # at its initial cell, and skip the remaining drawing either way.
                if agent.position is None:
                    if show_inactive_agents:
                        self.gl.set_agent_at(agent_idx, *agent.initial_position,
                                             agent.initial_direction, agent.initial_direction,
                                             is_selected=is_selected,
                                             rail_grid=rail_grid,
                                             show_debug=self.show_debug,
                                             clear_debug_text=self.clear_debug_text,
                                             malfunction=False)
                    continue

                is_malfunction = agent.malfunction_data["malfunction"] > 0

                if self.agent_render_variant == AgentRenderVariant.BOX_ONLY:
                    self.gl.set_cell_occupied(agent_idx, *agent.position)

                elif self.agent_render_variant in (AgentRenderVariant.ONE_STEP_BEHIND,
                                                   AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX):
                    # Draw the agent one step behind: at its previous cell, turning from
                    # its previous direction. On its first step there is no previous cell
                    # yet, so draw it at its current cell (position is non-None here).
                    if agent.old_position is not None:
                        position = agent.old_position
                        old_direction = agent.old_direction
                    else:
                        position = agent.position
                        old_direction = agent.direction

                    if self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX:
                        self.gl.set_cell_occupied(agent_idx, *agent.position)
                    # set_agent_at uses the agent index for the color
                    self.gl.set_agent_at(agent_idx, *position, old_direction, agent.direction,
                                         is_selected, rail_grid=rail_grid,
                                         show_debug=self.show_debug,
                                         clear_debug_text=self.clear_debug_text,
                                         malfunction=is_malfunction)

                else:
                    # AGENT_SHOWS_OPTIONS(_AND_BOX): draw one sprite for every direction
                    # the rail allows from the agent's current cell and heading.
                    position = agent.position
                    direction = agent.direction
                    for possible_direction in range(4):
                        if env.rail.get_transition((*agent.position, agent.direction), possible_direction):
                            direction = possible_direction
                            # set_agent_at uses the agent index for the color
                            self.gl.set_agent_at(agent_idx, *position, agent.direction, direction,
                                                 is_selected, rail_grid=rail_grid,
                                                 show_debug=self.show_debug,
                                                 clear_debug_text=self.clear_debug_text,
                                                 malfunction=is_malfunction)

                    if self.agent_render_variant == AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX:
                        self.gl.set_cell_occupied(agent_idx, *agent.position)

                    show_this_agent = show_inactive_agents or TrainState.on_map_state(agent.state)
                    if show_this_agent:
                        # Pass the same debug flags as every other set_agent_at call.
                        self.gl.set_agent_at(agent_idx, *position, agent.direction, direction,
                                             is_selected, rail_grid=rail_grid,
                                             show_debug=self.show_debug,
                                             clear_debug_text=self.clear_debug_text,
                                             malfunction=is_malfunction)

        if show_observations:
            self.render_observation(range(env.get_num_agents()), env.dev_obs_dict)
        # Skip rendering when there are no predictions, as render_env does.
        if show_predictions and len(env.dev_pred_dict) > 0:
            self.render_prediction(range(env.get_num_agents()), env.dev_pred_dict)

        if show:
            self.gl.show()
        # NOTE(review): events are pumped a fixed three times per frame; the reason
        # is not visible here — presumably to keep the window responsive.
        for _ in range(3):
            self.gl.process_events()

        self.frame_nr += 1

        if return_image:
            return self.get_image()
        return None
hagrid67's avatar
hagrid67 committed
769
770
771

    def close_window(self):
        """Ask the underlying graphics layer to close its window."""
        graphics = self.gl
        graphics.close_window()