From a5d4ec9fc091f9254ca14c40720b13eae7abca27 Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Fri, 26 Apr 2019 19:58:26 +0100
Subject: [PATCH] add various buttons to editor and update notebook

---
 flatland/utils/editor.py     | 196 +++++++++++++++++++++++++---------
 notebooks/CanvasEditor.ipynb | 199 ++++++++++++++++++++++++-----------
 2 files changed, 282 insertions(+), 113 deletions(-)

diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py
index c62ad0e..567c894 100644
--- a/flatland/utils/editor.py
+++ b/flatland/utils/editor.py
@@ -5,31 +5,56 @@ from collections import deque
 from matplotlib import pyplot as plt
 from contextlib import redirect_stdout
 import os
+import sys
 
 # import io
 # from PIL import Image
 # from ipywidgets import IntSlider, link, VBox
 
-# from flatland.envs.rail_env import RailEnv, random_rail_generator
+from flatland.envs.rail_env import RailEnv, random_rail_generator
 # from flatland.core.transitions import RailEnvTransitions
-# from flatland.core.env_observation_builder import TreeObsForRailEnv
+from flatland.core.env_observation_builder import TreeObsForRailEnv
 import flatland.utils.rendertools as rt
 
 
 class JupEditor(object):
-    def __init__(self, env):
+    def __init__(self, env, wid_img):
         self.env = env
+        self.wid_img = wid_img
+
         self.qEvents = deque()
 
+        self.regen_size = 10
+
         # TODO: These are currently estimated values
         self.yxBase = array([6, 21])  # pixel offset
-        self.nPixCell = 35
+        self.nPixCell = 700 / self.env.rail.width  # 35
 
         self.rcHistory = []
         self.iTransLast = -1
         self.gRCTrans = array([[-1, 0], [0, 1], [1, 0], [0, -1]])  # NESW in RC
+
+        self.debug = False
+        self.wid_output = None
+        self.drawMode = "Draw"
+        self.env_filename = "temp.npy"
+
+    def set_env(self, env):
+        self.env = env
+        self.yxBase = array([6, 21])  # pixel offset
+        self.nPixCell = 700 / self.env.rail.width  # 35
         self.oRT = rt.RenderTool(env)
 
+    def setDebug(self, dEvent):
+        self.debug = dEvent["new"]
+        self.log("Debug:", self.debug)
+
+    def setOutput(self, wid_output):
+        self.wid_output = wid_output
+
+    def setDrawMode(self, dEvent):
+        self.drawMode = dEvent["new"]
+
     def event_handler(self, wid, event):
         """Mouse motion event handler
         """
@@ -41,6 +66,11 @@ class JupEditor(object):
         bRedrawn = False
         writableData = None
 
+        if self.debug:
+            self.log("debug:", len(qEvents), len(rcHistory), event)
+
+        assert wid == self.wid_img, "wid not same as wid_img"
+
         # If the mouse is held down, enqueue an event in our own queue
         if event["buttons"] > 0:
             qEvents.append((time.time(), x, y))
@@ -49,9 +79,9 @@ class JupEditor(object):
             tNow = time.time()
             if tNow - qEvents[0][0] > 0.1:   # wait before trying to draw
                 height, width = wid.data.shape[:2]
-                writableData = np.copy(wid.data)  # writable copy of image - wid.data is somehow readonly
+                writableData = np.copy(self.wid_img.data)  # writable copy of image - wid_img.data is somehow readonly
                 
-                with wid.hold_sync():
+                with self.wid_img.hold_sync():
                     while len(qEvents) > 0:
                         t, x, y = qEvents.popleft()  # get events from our queue
 
@@ -70,53 +100,119 @@ class JupEditor(object):
                         else:
                             rcHistory.append(rcCell)
 
-        # If we have already touched 3 cells
-        # We have a transition into a cell, and out of it.
-        if len(rcHistory) >= 3:
-            rc3Cells = array(rcHistory[:3])  # the 3 cells
-            rcMiddle = rc3Cells[1]  # the middle cell which we will update
-            # get the 2 row, col deltas between the 3 cells, eg [-1,0] = North
-            rc2Trans = np.diff(rc3Cells, axis=0)
+        elif len(rcHistory) >= 3:
+            # If we have already touched 3 cells
+            # We have a transition into a cell, and out of it.
             
-            # get the direction index for the 2 transitions
-            liTrans = []
-            for rcTrans in rc2Trans:
-                iTrans = np.argwhere(np.all(self.gRCTrans - rcTrans == 0, axis=1))
-                if len(iTrans) > 0:
-                    iTrans = iTrans[0][0]
-                    liTrans.append(iTrans)
-
-            if len(liTrans) == 2:
-                # Set the transition
-                # oEnv.rail.set_transition((*rcLast, iTransLast), iTrans, True) # does nothing
-                iValCell = env.rail.transitions.set_transition(
-                    env.rail.grid[tuple(rcMiddle)], liTrans[0], liTrans[1], True)
-
-                # Also set the reverse transition
-                iValCell = env.rail.transitions.set_transition(
-                    iValCell,
-                    (liTrans[1] + 2) % 4,
-                    (liTrans[0] + 2) % 4,
-                    True)
-
-                # Write the cell transition value back into the grid
-                env.rail.grid[tuple(rcMiddle)] = iValCell
+            while len(rcHistory) >= 3:
+                rc3Cells = array(rcHistory[:3])  # the 3 cells
+                rcMiddle = rc3Cells[1]  # the middle cell which we will update
+                # get the 2 row, col deltas between the 3 cells, eg [-1,0] = North
+                rc2Trans = np.diff(rc3Cells, axis=0)
                 
-                # TODO: bit of a hack - can we suppress the console messages from MPL at source?
-                with redirect_stdout(os.devnull):
-                    plt.figure(figsize=(10, 10))
-                    self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray", show=False)
-                    img = self.oRT.getImage()
-                    plt.clf()
-                    plt.close()
-
-                # This updates the image in the browser with the new rendered image
-                wid.data = img
-                bRedrawn = True
-        
-            rcHistory.pop(0)  # remove the last-but-one
+                # get the direction index for the 2 transitions
+                liTrans = []
+                for rcTrans in rc2Trans:
+                    iTrans = np.argwhere(np.all(self.gRCTrans - rcTrans == 0, axis=1))
+                    if len(iTrans) > 0:
+                        iTrans = iTrans[0][0]
+                        liTrans.append(iTrans)
+
+                if len(liTrans) == 2:
+                    # Set the transition
+                    # oEnv.rail.set_transition((*rcLast, iTransLast), iTrans, True) # does nothing
+                    iValCell = env.rail.transitions.set_transition(
+                        env.rail.grid[tuple(rcMiddle)], liTrans[0], liTrans[1], True)
+
+                    # Also set the reverse transition
+                    iValCell = env.rail.transitions.set_transition(
+                        iValCell,
+                        (liTrans[1] + 2) % 4,
+                        (liTrans[0] + 2) % 4,
+                        True)
+
+                    # Write the cell transition value back into the grid
+                    env.rail.grid[tuple(rcMiddle)] = iValCell
             
+                rcHistory.pop(0)  # remove the last-but-one
+            
+            self.redraw()
+            bRedrawn = True
+
+        # only redraw with the dots/squares if necessary
         if not bRedrawn and writableData is not None:
             # This updates the image in the browser to be the new edited version
-            wid.data = writableData
+            self.wid_img.data = writableData
+    
+    def on_click(self, event):
+        pass
+
+    def redraw(self, hide_stdout=True, update=True):
+
+        if hide_stdout:
+            stdout_dest = os.devnull
+        else:
+            stdout_dest = sys.stdout
+
+        # TODO: bit of a hack - can we suppress the console messages from MPL at source?
+        with redirect_stdout(stdout_dest):
+            plt.figure(figsize=(10, 10))
+            self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray", show=False)
+            img = self.oRT.getImage()
+            plt.clf()
+            plt.close()
+        
+        if update:
+            self.wid_img.data = img
+        return img
+
+    def redraw_event(self, event):
+        img = self.redraw()
+        self.wid_img.data = img
+    
+    def clear(self, event):
+        self.env.rail.grid[:, :] = 0
+        self.redraw_event(event)
+
+    def setFilename(self, filename):
+        self.log("filename = ", filename, type(filename))
+        self.env_filename = filename
+
+    def setFilename_event(self, event):
+        self.setFilename(event["new"])
+
+    def load(self, event):
+        self.env.rail.load_transition_map(self.env_filename, override_gridsize=True)
+        self.fix_env()
+        self.set_env(self.env)
+        self.wid_img.data = self.redraw()
+    
+    def save(self, event):
+        self.log("save to ", self.env_filename)
+        self.env.rail.save_transition_map(self.env_filename)
+
+    def regenerate_event(self, event):
+        self.env = RailEnv(width=self.regen_size,
+              height=self.regen_size,
+              rail_generator=random_rail_generator(cell_type_relative_proportion=[1, 1] + [0.5] * 6),
+              number_of_agents=0,
+              obs_builder_object=TreeObsForRailEnv(max_depth=2))
+        self.env.reset()
+        self.set_env(self.env)
+        self.redraw()
+        
+    def setRegenSize_event(self, event):
+        self.regen_size = event["new"]
+    
+    def fix_env(self):
+        self.env.width = self.env.rail.width
+        self.env.height = self.env.rail.height
+
+    def log(self, *args, **kwargs):
+
+        if self.wid_output:
+            with self.wid_output:
+                print(*args, **kwargs)
+        else:
+            print(*args, **kwargs)
 
diff --git a/notebooks/CanvasEditor.ipynb b/notebooks/CanvasEditor.ipynb
index 6013c92..faa57ce 100644
--- a/notebooks/CanvasEditor.ipynb
+++ b/notebooks/CanvasEditor.ipynb
@@ -48,11 +48,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 52,
    "metadata": {},
    "outputs": [],
    "source": [
-    "from ipywidgets import IntSlider, link, VBox, RadioButtons, HBox"
+    "from ipywidgets import IntSlider, link, VBox, RadioButtons, HBox, interact"
    ]
   },
   {
@@ -103,9 +103,7 @@
     "              rail_generator=random_rail_generator(cell_type_relative_proportion=[1,1] + [0.5] * 6),\n",
     "              number_of_agents=0,\n",
     "              obs_builder_object=TreeObsForRailEnv(max_depth=2))\n",
-    "obs = oEnv.reset()\n",
-    "\n",
-    "oRT = rt.RenderTool(oEnv)"
+    "obs = oEnv.reset()"
    ]
   },
   {
@@ -115,8 +113,39 @@
    "outputs": [],
    "source": [
     "sfEnv = \"../flatland/env-data/tests/test1.npy\"\n",
-    "if False:\n",
-    "    oEnv.rail.load_transition_map(sfEnv)"
+    "if True:\n",
+    "    oEnv.rail.load_transition_map(sfEnv)\n",
+    "    oEnv.width = oEnv.rail.width\n",
+    "    oEnv.height = oEnv.rail.height"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "oRT = rt.RenderTool(oEnv)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "10"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "oEnv.width"
    ]
   },
   {
@@ -128,7 +157,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 10,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -145,7 +174,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 144,
    "metadata": {},
    "outputs": [
     {
@@ -185,25 +214,65 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 156,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "([], deque([]))"
-      ]
-     },
-     "execution_count": 10,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
     "wid_img.unregister_all()\n",
-    "oEditor = JupEditor(oEnv)\n",
-    "wid_img.register_move(oEditor.event_handler)\n",
-    "oEditor.rcHistory, oEditor.qEvents"
+    "oEditor = JupEditor(oEnv, wid_img)\n",
+    "wid_img.register_move(oEditor.event_handler)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Some more widgets"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 157,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "wid_drawMode = ipywidgets.RadioButtons(options=[\"Draw\", \"Erase\", \"Origin\", \"Destination\"])\n",
+    "wid_drawMode.observe(oEditor.setDrawMode, names=\"value\")\n",
+    "wid_refresh = ipywidgets.Button(description=\"Refresh\")\n",
+    "wid_refresh.on_click(oEditor.redraw_event)\n",
+    "wid_clear = ipywidgets.Button(description = \"Clear\")\n",
+    "wid_clear.on_click(oEditor.clear)\n",
+    "wid_debug = ipywidgets.Checkbox(description = \"Debug\")\n",
+    "wid_debug.observe(oEditor.setDebug, names=\"value\")\n",
+    "wid_output = ipywidgets.Output()\n",
+    "oEditor.setOutput(wid_output)\n",
+    "wid_regen = ipywidgets.Button(description = \"Regenerate\")\n",
+    "wid_filename = ipywidgets.Text(description = \"Filename\")\n",
+    "wid_filename.value = sfEnv\n",
+    "oEditor.setFilename(sfEnv)\n",
+    "wid_filename.observe(oEditor.setFilename_event, names=\"value\")\n",
+    "\n",
+    "wid_size = ipywidgets.IntSlider(min=5, max=30, step=5, description=\"Regen Size\")\n",
+    "wid_size.observe(oEditor.setRegenSize_event, names=\"value\")\n",
+    "\n",
+    "\n",
+    "ldButtons = [\n",
+    "    dict(name = \"Refresh\", method = oEditor.redraw_event),\n",
+    "    dict(name = \"Clear\", method = oEditor.clear),\n",
+    "    dict(name = \"Regenerate\", method = oEditor.regenerate_event),\n",
+    "    dict(name = \"Load\", method = oEditor.load),\n",
+    "    dict(name = \"Save\", method = oEditor.save)\n",
+    "]\n",
+    "\n",
+    "lwid_buttons = []\n",
+    "for dButton in ldButtons:\n",
+    "    wid_button = ipywidgets.Button(description = dButton[\"name\"])\n",
+    "    wid_button.on_click(dButton[\"method\"])\n",
+    "    lwid_buttons.append(wid_button)\n",
+    "    \n",
+    "\n",
+    "#wid_debug = interact(oEditor.setDebug, debug=False)\n",
+    "vbox_controls = VBox([wid_filename, wid_drawMode, *lwid_buttons, wid_size, wid_debug])\n"
    ]
   },
   {
@@ -216,7 +285,36 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 158,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2fceb907aab945788d32e2c4555d5071",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "HBox(children=(Canvas(), VBox(children=(Text(value='../flatland/env-data/tests/test1.npy', description='Filena…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "# wid_box\n",
+    "wid_main = HBox([wid_img, vbox_controls])\n",
+    "wid_output.clear_output()\n",
+    "wid_main"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 138,
    "metadata": {
     "scrolled": false
    },
@@ -224,12 +322,12 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "e36f66779f454856882018ee3fa8e8b3",
+       "model_id": "b9c28e5dab4e46b49ab1fb7dd9f3834b",
        "version_major": 2,
        "version_minor": 0
       },
       "text/plain": [
-       "Canvas()"
+       "Output(outputs=({'output_type': 'stream', 'text': 'Debug: True\\n', 'name': 'stdout'},))"
       ]
      },
      "metadata": {},
@@ -237,9 +335,7 @@
     }
    ],
    "source": [
-    "#wid_box\n",
-    "#HBox([wid_img, wid_buttons])\n",
-    "wid_img"
+    "wid_output"
    ]
   },
   {
@@ -251,7 +347,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 15,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -267,32 +363,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "52bb87bcae69447fb1ecbf06fff971bc",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "RadioButtons(options=('Draw', 'Erase'), value='Draw')"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
-   "source": [
-    "wid_buttons = ipywidgets.RadioButtons(options=[\"Draw\", \"Erase\"])\n",
-    "wid_buttons"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 16,
    "metadata": {},
    "outputs": [
     {
@@ -301,7 +372,7 @@
        "'Draw'"
       ]
      },
-     "execution_count": 14,
+     "execution_count": 16,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -312,7 +383,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 17,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -322,21 +393,23 @@
     "    yxBase = array([6, 21])\n",
     "    nPixCell = 35\n",
     "    rcCell = ((array([y, x]) - yxBase) / nPixCell).astype(int)\n",
+    "    print(ev)\n",
     "    print(x, y, (x-21) / 35, (y-6) / 35, rcCell)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 18,
    "metadata": {},
    "outputs": [],
    "source": [
-    "#wid_img.register_click(evListen)"
+    "# wid_img.register_click(evListen)\n",
+    "#wid_img.register(evListen)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 19,
    "metadata": {},
    "outputs": [],
    "source": [
-- 
GitLab