From a5d4ec9fc091f9254ca14c40720b13eae7abca27 Mon Sep 17 00:00:00 2001 From: hagrid67 <jdhwatson@gmail.com> Date: Fri, 26 Apr 2019 19:58:26 +0100 Subject: [PATCH] add various buttons to editor and update notebook --- flatland/utils/editor.py | 196 +++++++++++++++++++++++++--------- notebooks/CanvasEditor.ipynb | 199 ++++++++++++++++++++++++----------- 2 files changed, 282 insertions(+), 113 deletions(-) diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index c62ad0e..567c894 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -5,31 +5,56 @@ from collections import deque from matplotlib import pyplot as plt from contextlib import redirect_stdout import os +import sys # import io # from PIL import Image # from ipywidgets import IntSlider, link, VBox -# from flatland.envs.rail_env import RailEnv, random_rail_generator +from flatland.envs.rail_env import RailEnv, random_rail_generator # from flatland.core.transitions import RailEnvTransitions -# from flatland.core.env_observation_builder import TreeObsForRailEnv +from flatland.core.env_observation_builder import TreeObsForRailEnv import flatland.utils.rendertools as rt class JupEditor(object): - def __init__(self, env): + def __init__(self, env, wid_img): self.env = env + self.wid_img = wid_img + self.qEvents = deque() + self.regen_size = 10 + # TODO: These are currently estimated values self.yxBase = array([6, 21]) # pixel offset - self.nPixCell = 35 + self.nPixCell = 700 / self.env.rail.width # 35 self.rcHistory = [] self.iTransLast = -1 self.gRCTrans = array([[-1, 0], [0, 1], [1, 0], [0, -1]]) # NESW in RC + + self.debug = False + self.wid_output = None + self.drawMode = "Draw" + self.env_filename = "temp.npy" + + def set_env(self, env): + self.env = env + self.yxBase = array([6, 21]) # pixel offset + self.nPixCell = 700 / self.env.rail.width # 35 self.oRT = rt.RenderTool(env) + def setDebug(self, dEvent): + self.debug = dEvent["new"] + self.log("Debug:", self.debug) + + def setOutput(self, wid_output): + self.wid_output = wid_output + + def setDrawMode(self, dEvent): + self.drawMode = dEvent["new"] + def event_handler(self, wid, event): """Mouse motion event handler """ @@ -41,6 +66,11 @@ class JupEditor(object): bRedrawn = False writableData = None + if self.debug: + self.log("debug:", len(qEvents), len(rcHistory), event) + + assert wid == self.wid_img, "wid not same as wid_img" + # If the mouse is held down, enqueue an event in our own queue if event["buttons"] > 0: qEvents.append((time.time(), x, y)) @@ -49,9 +79,9 @@ class JupEditor(object): tNow = time.time() if tNow - qEvents[0][0] > 0.1: # wait before trying to draw height, width = wid.data.shape[:2] - writableData = np.copy(wid.data) # writable copy of image - wid.data is somehow readonly + writableData = np.copy(self.wid_img.data) # writable copy of image - wid_img.data is somehow readonly - with wid.hold_sync(): + with self.wid_img.hold_sync(): while len(qEvents) > 0: t, x, y = qEvents.popleft() # get events from our queue @@ -70,53 +100,119 @@ class JupEditor(object): else: rcHistory.append(rcCell) - # If we have already touched 3 cells - # We have a transition into a cell, and out of it. - if len(rcHistory) >= 3: - rc3Cells = array(rcHistory[:3]) # the 3 cells - rcMiddle = rc3Cells[1] # the middle cell which we will update - # get the 2 row, col deltas between the 3 cells, eg [-1,0] = North - rc2Trans = np.diff(rc3Cells, axis=0) + elif len(rcHistory) >= 3: + # If we have already touched 3 cells + # We have a transition into a cell, and out of it. - # get the direction index for the 2 transitions - liTrans = [] - for rcTrans in rc2Trans: - iTrans = np.argwhere(np.all(self.gRCTrans - rcTrans == 0, axis=1)) - if len(iTrans) > 0: - iTrans = iTrans[0][0] - liTrans.append(iTrans) - - if len(liTrans) == 2: - # Set the transition - # oEnv.rail.set_transition((*rcLast, iTransLast), iTrans, True) # does nothing - iValCell = env.rail.transitions.set_transition( - env.rail.grid[tuple(rcMiddle)], liTrans[0], liTrans[1], True) - - # Also set the reverse transition - iValCell = env.rail.transitions.set_transition( - iValCell, - (liTrans[1] + 2) % 4, - (liTrans[0] + 2) % 4, - True) - - # Write the cell transition value back into the grid - env.rail.grid[tuple(rcMiddle)] = iValCell + while len(rcHistory) >= 3: + rc3Cells = array(rcHistory[:3]) # the 3 cells + rcMiddle = rc3Cells[1] # the middle cell which we will update + # get the 2 row, col deltas between the 3 cells, eg [-1,0] = North + rc2Trans = np.diff(rc3Cells, axis=0) - # TODO: bit of a hack - can we suppress the console messages from MPL at source? - with redirect_stdout(os.devnull): - plt.figure(figsize=(10, 10)) - self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray", show=False) - img = self.oRT.getImage() - plt.clf() - plt.close() - - # This updates the image in the browser with the new rendered image - wid.data = img - bRedrawn = True - - rcHistory.pop(0) # remove the last-but-one + # get the direction index for the 2 transitions + liTrans = [] + for rcTrans in rc2Trans: + iTrans = np.argwhere(np.all(self.gRCTrans - rcTrans == 0, axis=1)) + if len(iTrans) > 0: + iTrans = iTrans[0][0] + liTrans.append(iTrans) + + if len(liTrans) == 2: + # Set the transition + # oEnv.rail.set_transition((*rcLast, iTransLast), iTrans, True) # does nothing + iValCell = env.rail.transitions.set_transition( + env.rail.grid[tuple(rcMiddle)], liTrans[0], liTrans[1], True) + + # Also set the reverse transition + iValCell = env.rail.transitions.set_transition( + iValCell, + (liTrans[1] + 2) % 4, + (liTrans[0] + 2) % 4, + True) + + # Write the cell transition value back into the grid + env.rail.grid[tuple(rcMiddle)] = iValCell + rcHistory.pop(0) # remove the last-but-one + + self.redraw() + bRedrawn = True + + # only redraw with the dots/squares if necessary if not bRedrawn and writableData is not None: # This updates the image in the browser to be the new edited version - wid.data = writableData + self.wid_img.data = writableData + + def on_click(self, event): + pass + + def redraw(self, hide_stdout=True, update=True): + + if hide_stdout: + stdout_dest = os.devnull + else: + stdout_dest = sys.stdout + + # TODO: bit of a hack - can we suppress the console messages from MPL at source? + with redirect_stdout(stdout_dest): + plt.figure(figsize=(10, 10)) + self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray", show=False) + img = self.oRT.getImage() + plt.clf() + plt.close() + + if update: + self.wid_img.data = img + return img + + def redraw_event(self, event): + img = self.redraw() + self.wid_img.data = img + + def clear(self, event): + self.env.rail.grid[:, :] = 0 + self.redraw_event(event) + + def setFilename(self, filename): + self.log("filename = ", filename, type(filename)) + self.env_filename = filename + + def setFilename_event(self, event): + self.setFilename(event["new"]) + + def load(self, event): + self.env.rail.load_transition_map(self.env_filename, override_gridsize=True) + self.fix_env() + self.set_env(self.env) + self.wid_img.data = self.redraw() + + def save(self, event): + self.log("save to ", self.env_filename) + self.env.rail.save_transition_map(self.env_filename) + + def regenerate_event(self, event): + self.env = RailEnv(width=self.regen_size, + height=self.regen_size, + rail_generator=random_rail_generator(cell_type_relative_proportion=[1, 1] + [0.5] * 6), + number_of_agents=0, + obs_builder_object=TreeObsForRailEnv(max_depth=2)) + self.env.reset() + self.set_env(self.env) + self.redraw() + + def setRegenSize_event(self, event): + self.regen_size = event["new"] + + def fix_env(self): + self.env.width = self.env.rail.width + self.env.height = self.env.rail.height + + def log(self, *args, **kwargs): + + if self.wid_output: + with self.wid_output: + print(*args, **kwargs) + else: + print(*args, **kwargs) diff --git a/notebooks/CanvasEditor.ipynb b/notebooks/CanvasEditor.ipynb index 6013c92..faa57ce 100644 --- a/notebooks/CanvasEditor.ipynb +++ b/notebooks/CanvasEditor.ipynb @@ -48,11 +48,11 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 52, "metadata": {}, "outputs": [], "source": [ - "from ipywidgets import IntSlider, link, VBox, RadioButtons, HBox" + "from ipywidgets import IntSlider, link, VBox, RadioButtons, HBox, interact" ] }, { @@ -103,9 +103,7 @@ " rail_generator=random_rail_generator(cell_type_relative_proportion=[1,1] + [0.5] * 6),\n", " number_of_agents=0,\n", " obs_builder_object=TreeObsForRailEnv(max_depth=2))\n", - "obs = oEnv.reset()\n", - "\n", - "oRT = rt.RenderTool(oEnv)" + "obs = oEnv.reset()" ] }, { @@ -115,8 +113,39 @@ "outputs": [], "source": [ "sfEnv = \"../flatland/env-data/tests/test1.npy\"\n", - "if False:\n", - " oEnv.rail.load_transition_map(sfEnv)" + "if True:\n", + " oEnv.rail.load_transition_map(sfEnv)\n", + " oEnv.width = oEnv.rail.width\n", + " oEnv.height = oEnv.rail.height" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "oRT = rt.RenderTool(oEnv)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "10" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "oEnv.width" ] }, { @@ -128,7 +157,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -145,7 +174,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 144, "metadata": {}, "outputs": [ { @@ -185,25 +214,65 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 156, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "([], deque([]))" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "wid_img.unregister_all()\n", - "oEditor = JupEditor(oEnv)\n", - "wid_img.register_move(oEditor.event_handler)\n", - "oEditor.rcHistory, oEditor.qEvents" + "oEditor = JupEditor(oEnv, wid_img)\n", + "wid_img.register_move(oEditor.event_handler)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Some more widgets" + ] + }, + { + "cell_type": "code", + "execution_count": 157, + "metadata": {}, + "outputs": [], + "source": [ + "wid_drawMode = ipywidgets.RadioButtons(options=[\"Draw\", \"Erase\", \"Origin\", \"Destination\"])\n", + "wid_drawMode.observe(oEditor.setDrawMode, names=\"value\")\n", + "wid_refresh = ipywidgets.Button(description=\"Refresh\")\n", + "wid_refresh.on_click(oEditor.redraw_event)\n", + "wid_clear = ipywidgets.Button(description = \"Clear\")\n", + "wid_clear.on_click(oEditor.clear)\n", + "wid_debug = ipywidgets.Checkbox(description = \"Debug\")\n", + "wid_debug.observe(oEditor.setDebug, names=\"value\")\n", + "wid_output = ipywidgets.Output()\n", + "oEditor.setOutput(wid_output)\n", + "wid_regen = ipywidgets.Button(description = \"Regenerate\")\n", + "wid_filename = ipywidgets.Text(description = \"Filename\")\n", + "wid_filename.value = sfEnv\n", + "oEditor.setFilename(sfEnv)\n", + "wid_filename.observe(oEditor.setFilename_event, names=\"value\")\n", + "\n", + "wid_size = ipywidgets.IntSlider(min=5, max=30, step=5, description=\"Regen Size\")\n", + "wid_size.observe(oEditor.setRegenSize_event, names=\"value\")\n", + "\n", + "\n", + "ldButtons = [\n", + " dict(name = \"Refresh\", method = oEditor.redraw_event),\n", + " dict(name = \"Clear\", method = oEditor.clear),\n", + " dict(name = \"Regenerate\", method = oEditor.regenerate_event),\n", + " dict(name = \"Load\", method = oEditor.load),\n", + " dict(name = \"Save\", method = oEditor.save)\n", + "]\n", + "\n", + "lwid_buttons = []\n", + "for dButton in ldButtons:\n", + " wid_button = ipywidgets.Button(description = dButton[\"name\"])\n", + " wid_button.on_click(dButton[\"method\"])\n", + " lwid_buttons.append(wid_button)\n", + " \n", + "\n", + "#wid_debug = interact(oEditor.setDebug, debug=False)\n", + "vbox_controls = VBox([wid_filename, wid_drawMode, *lwid_buttons, wid_size, wid_debug])\n" ] }, { @@ -216,7 +285,36 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 158, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2fceb907aab945788d32e2c4555d5071", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(Canvas(), VBox(children=(Text(value='../flatland/env-data/tests/test1.npy', description='Filena…" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# wid_box\n", + "wid_main = HBox([wid_img, vbox_controls])\n", + "wid_output.clear_output()\n", + "wid_main" + ] + }, + { + "cell_type": "code", + "execution_count": 138, "metadata": { "scrolled": false }, @@ -224,12 +322,12 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "e36f66779f454856882018ee3fa8e8b3", + "model_id": "b9c28e5dab4e46b49ab1fb7dd9f3834b", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Canvas()" + "Output(outputs=({'output_type': 'stream', 'text': 'Debug: True\\n', 'name': 'stdout'},))" ] }, "metadata": {}, @@ -237,9 +335,7 @@ } ], "source": [ - "#wid_box\n", - "#HBox([wid_img, wid_buttons])\n", - "wid_img" + "wid_output" ] }, { @@ -251,7 +347,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -267,32 +363,7 @@ }, { "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "52bb87bcae69447fb1ecbf06fff971bc", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "RadioButtons(options=('Draw', 'Erase'), value='Draw')" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "wid_buttons = ipywidgets.RadioButtons(options=[\"Draw\", \"Erase\"])\n", - "wid_buttons" - ] - }, - { - "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -301,7 +372,7 @@ "'Draw'" ] }, - "execution_count": 14, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -312,7 +383,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -322,21 +393,23 @@ " yxBase = array([6, 21])\n", " nPixCell = 35\n", " rcCell = ((array([y, x]) - yxBase) / nPixCell).astype(int)\n", + " print(ev)\n", " print(x, y, (x-21) / 35, (y-6) / 35, rcCell)" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ - "#wid_img.register_click(evListen)" + "# wid_img.register_click(evListen)\n", + "#wid_img.register(evListen)" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ -- GitLab