diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index 97e3e5622ff633a35159f1d1fe1b2249e1803463..b72216a0a85b54714b0a10503e077b06c6a019ac 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -37,6 +37,19 @@ class PILGL(GraphicsLayer): # https://stackoverflow.com/questions/26097811/image-pyimage2-doesnt-exist window = tk.Tk() + RAIL_LAYER = 0 + AGENT_LAYER = 1 + PREDICTION_PATH_LAYER = 2 + SELECTED_AGENT_LAYER = 3 + SELECTED_TARGET_LAYER = 4 + + def create_layers(self, clear=True): + self.create_layer(PILGL.RAIL_LAYER, clear=clear) # rail / background (scene) + self.create_layer(PILGL.AGENT_LAYER, clear=clear) # agents + self.create_layer(PILGL.PREDICTION_PATH_LAYER, clear=clear) # drawing layer for agent's prediction path + self.create_layer(PILGL.SELECTED_AGENT_LAYER, clear=clear) # drawing layer for selected agent + self.create_layer(PILGL.SELECTED_TARGET_LAYER, clear=clear) # drawing layer for selected agent's target + def __init__(self, width, height, jupyter=False): self.yxBase = (0, 0) self.linewidth = 4 @@ -129,7 +142,7 @@ class PILGL(GraphicsLayer): def get_agent_color(self, iAgent): return self.agent_colors[iAgent % self.n_agent_colors] - def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs): + def plot(self, gX, gY, color=None, linewidth=3, layer=RAIL_LAYER, opacity=255, **kwargs): color = self.adapt_color(color) if len(color) == 3: color += (opacity,) @@ -139,14 +152,14 @@ class PILGL(GraphicsLayer): gPoints = list(gPoints.ravel()) self.draws[layer].line(gPoints, fill=color, width=self.linewidth) - def scatter(self, gX, gY, color=None, marker="o", s=50, layer=0, opacity=255, *args, **kwargs): + def scatter(self, gX, gY, color=None, marker="o", s=50, layer=RAIL_LAYER, opacity=255, *args, **kwargs): color = self.adapt_color(color) r = np.sqrt(s) gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.nPixCell for x, y in gPoints: self.draws[layer].rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color) - def draw_image_xy(self, pil_img, xyPixLeftTop, layer=0): + def draw_image_xy(self, pil_img, xyPixLeftTop, layer=RAIL_LAYER, ): if (pil_img.mode == "RGBA"): pil_mask = pil_img else: @@ -154,7 +167,7 @@ class PILGL(GraphicsLayer): self.layers[layer].paste(pil_img, xyPixLeftTop, pil_mask) - def draw_image_row_col(self, pil_img, rcTopLeft, layer=0): + def draw_image_row_col(self, pil_img, rcTopLeft, layer=RAIL_LAYER, ): xyPixLeftTop = tuple((array(rcTopLeft) * self.nPixCell)[[1, 0]]) self.draw_image_xy(pil_img, xyPixLeftTop, layer=layer) @@ -180,7 +193,7 @@ class PILGL(GraphicsLayer): def begin_frame(self): # Create a new agent layer - self.create_layer(iLayer=1, clear=True) + self.create_layer(iLayer=PILGL.AGENT_LAYER, clear=True) def show(self, block=False): img = self.alpha_composite_layers() @@ -254,11 +267,6 @@ class PILGL(GraphicsLayer): if clear: self.clear_layer(iLayer) - def create_layers(self, clear=True): - self.create_layer(0, clear=clear) # rail / background (scene) - self.create_layer(1, clear=clear) # agents - self.create_layer(2, clear=clear) # drawing layer for selected agent - self.create_layer(3, clear=clear) # drawing layer for selected agent's target class PILSVG(PILGL): @@ -485,23 +493,21 @@ class PILSVG(PILGL): return pil - def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None, - agent_rail_color=None, blend_factor=0.5): + + def clear_set_predicion_path_layer(self): + self.clear_layer(PILGL.PREDICTION_PATH_LAYER,0) + + def set_predicion_path_at(self, row, col, binary_trans, agent_rail_color): + colored_rail = self.recolor_image(self.pil_rail_org[binary_trans], + [61, 61, 61], [agent_rail_color], + False)[0] + # pil_track = Image.blend(pil_track,colored_rail,blend_factor) + # pil_track = colored_rail#Image.alpha_composite(pil_track, colored_rail) + self.draw_image_row_col(colored_rail, (row, col), layer=PILGL.PREDICTION_PATH_LAYER) + + def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None): if binary_trans in self.pil_rail: pil_track = self.pil_rail[binary_trans] - if agent_rail_color is not None: - colored_rail = self.recolor_image(self.pil_rail_org[binary_trans], - [61, 61, 61], [agent_rail_color], - False)[0] - rcTopLeft1 = (row, col) - rcTopLeft2 = (row + 1, col + 1) - rcTopLeft1 = tuple((array(rcTopLeft1) * self.nPixCell)[[1, 0]]) - rcTopLeft2 = tuple((array(rcTopLeft2) * self.nPixCell)[[1, 0]]) - pil_track = Image.blend( - self.layers[0].crop((rcTopLeft1[0], rcTopLeft1[1], rcTopLeft2[0], rcTopLeft2[1])), - colored_rail, - blend_factor) - if target is not None: pil_track = Image.alpha_composite(pil_track, self.station_colors[target % len(self.station_colors)]) @@ -522,15 +528,15 @@ class PILSVG(PILGL): a = a2 pil_track = self.scenery[a % len(self.scenery)] - self.draw_image_row_col(pil_track, (row, col)) + self.draw_image_row_col(pil_track, (row, col), layer=PILGL.RAIL_LAYER) else: print("Illegal rail:", row, col, format(binary_trans, "#018b")[2:], binary_trans) if target is not None: if is_selected: svgBG = self.pil_from_svg_file("svg", "Selected_Target.svg") - self.clear_layer(3, 0) - self.draw_image_row_col(svgBG, (row, col), layer=3) + self.clear_layer(PILGL.SELECTED_TARGET_LAYER, 0) + self.draw_image_row_col(svgBG, (row, col), layer=PILGL.SELECTED_TARGET_LAYER) def recolor_image(self, pil, a3BaseColor, ltColors, invert=False): rgbaImg = array(pil) @@ -591,12 +597,12 @@ class PILSVG(PILGL): if delta_dir == 2: in_direction = out_direction pil_zug = self.pil_zug[(in_direction % 4, out_direction % 4, color_idx)] - self.draw_image_row_col(pil_zug, (row, col), layer=1) + self.draw_image_row_col(pil_zug, (row, col), layer=PILGL.AGENT_LAYER) if is_selected: bg_svg = self.pil_from_svg_file("svg", "Selected_Agent.svg") - self.clear_layer(2, 0) - self.draw_image_row_col(bg_svg, (row, col), layer=2) + self.clear_layer(PILGL.SELECTED_AGENT_LAYER, 0) + self.draw_image_row_col(bg_svg, (row, col), layer=PILGL.SELECTED_AGENT_LAYER) def set_cell_occupied(self, agent_idx, row, col): occupied_im = self.cell_occupied[agent_idx % len(self.cell_occupied)] diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 9f1b59e10aa36987e06e89d5b1e8145e55bfead1..14c49add9fe03ba90f2691b9f00de32cb0863330 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -292,19 +292,21 @@ class RenderTool(object): """ rt = self.__class__ + if type(self.gl) is PILSVG: + self.gl.clear_set_predicion_path_layer() + for agent in agent_handles: color = self.gl.get_agent_color(agent) for visited_cell in prediction_dict[agent]: cell_coord = array(visited_cell[:2]) - cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half if type(self.gl) is PILSVG: # TODO : Track highlighting (Adrian) r = cell_coord[0] c = cell_coord[1] transitions = self.env.rail.grid[r, c] - self.gl.set_rail_at(r, c, transitions, target=None, is_selected=False, rail_grid=self.env.rail.grid, - agent_rail_color=color) + self.gl.set_predicion_path_at(r, c, transitions, agent_rail_color=color) else: + cell_coord_trans = np.matmul(cell_coord, rt.row_col_to_xy) + rt.x_y_half self._draw_square(cell_coord_trans, 1 / (agent + 1.1), color, layer=1, opacity=100) def render_rail(self, spacing=False, rail_color="gray", curves=True, arrows=False): diff --git a/notebooks/Simple_Rendering_Demo.ipynb b/notebooks/Simple_Rendering_Demo.ipynb index ad76f054cd3d83a2cf1c239a6656c8d95f9a2720..2084ee46cbb7de2839ec71bce98c7a5bc46c2469 100644 --- a/notebooks/Simple_Rendering_Demo.ipynb +++ b/notebooks/Simple_Rendering_Demo.ipynb @@ -9,18 +9,9 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], + "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" @@ -28,7 +19,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -40,7 +31,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -50,34 +41,22 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import flatland.core.env\n", "import flatland.utils.rendertools as rt\n", "from flatland.envs.rail_env import RailEnv, random_rail_generator\n", - "from flatland.envs.observations import TreeObsForRailEnv" + "from flatland.envs.observations import TreeObsForRailEnv\n", + "from flatland.envs.predictions import ShortestPathPredictorForRailEnv" ] }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "<style>.container { width:90% !important; }</style>" - ], - "text/plain": [ - "<IPython.core.display.HTML object>" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "from IPython.core.display import display, HTML\n", "display(HTML(\"<style>.container { width:90% !important; }</style>\"))" @@ -92,7 +71,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -102,7 +81,7 @@ " height=10,\n", " rail_generator=fnMethod,\n", " number_of_agents=nAgents,\n", - " obs_builder_object=TreeObsForRailEnv(max_depth=2))" + " obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))" ] }, { @@ -114,71 +93,40 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "oRT = rt.RenderTool(env,gl=\"PILSVG\")\n", - "env.dev_pred_dict = env.dev_obs_dict\n", "oRT.render_env(show_observations=False,show_predictions=True)\n", "img = oRT.get_image()" ] }, { "cell_type": "code", - "execution_count": 32, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "6a9401eacf31417d97119674cf249bc7", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Canvas()" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "jpy_canvas.Canvas(img)" ] }, { "cell_type": "code", - "execution_count": 29, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "oRT = rt.RenderTool(env,gl=\"PIL\")\n", - "oRT.render_env()\n", + "oRT.render_env(show_observations=False,show_predictions=True)\n", "img = oRT.get_image()" ] }, { "cell_type": "code", - "execution_count": 30, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "cedca754ea334df9b4678c1a29a4788e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Canvas()" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "jpy_canvas.Canvas(img)" ]