From add024df7dd2b12187b90f7ca7ba02b3b447474d Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Thu, 13 Aug 2020 22:44:24 +0100
Subject: [PATCH] added close_following option to RailEnv, add env_edit_utils,
 test-collisions.ipynb

---
 flatland/envs/observations.py    |   2 +-
 flatland/envs/rail_env.py        |   9 +-
 flatland/utils/editor.py         |  20 ++-
 flatland/utils/env_edit_utils.py | 126 ++++++++++++++++++
 flatland/utils/rendertools.py    |  31 +++--
 notebooks/test-collision.ipynb   | 211 +++++++++++++++++++++++++++++++
 6 files changed, 382 insertions(+), 17 deletions(-)
 create mode 100644 flatland/utils/env_edit_utils.py
 create mode 100644 notebooks/test-collision.ipynb

diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 3738edd3..beabe4a4 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -220,7 +220,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                                        speed_min_fractional=agent.speed_data['speed'],
                                                        num_agents_ready_to_depart=0,
                                                        childs={})
-        print("root node type:", type(root_node_observation))
+        #print("root node type:", type(root_node_observation))
 
         visited = OrderedSet()
 
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 91ab1a00..11c471ff 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -139,7 +139,8 @@ class RailEnv(Environment):
                  malfunction_generator=None,
                  remove_agents_at_target=True,
                  random_seed=1,
-                 record_steps=False
+                 record_steps=False,
+                 close_following=True
                  ):
         """
         Environment init.
@@ -245,7 +246,7 @@ class RailEnv(Environment):
         self.cur_episode = []  
         self.list_actions = [] # save actions in here
 
-        self.close_following = True  # use close following logic
+        self.close_following = close_following  # use close following logic
         self.motionCheck = ac.MotionCheck()
 
 
@@ -607,6 +608,7 @@ class RailEnv(Environment):
                 agent.status = RailAgentStatus.ACTIVE
                 self._set_agent_to_initial_position(agent, agent.initial_position)
                 self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
+                return
             else:
                 # TODO: Here we need to check for the departure time in future releases with full schedules
                 self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
@@ -653,7 +655,6 @@ class RailEnv(Environment):
 
             # Store the action if action is moving
             # If not moving, the action will be stored when the agent starts moving again.
-            new_position = None
             if agent.moving:
                 _action_stored = False
                 _, new_cell_valid, new_direction, new_position, transition_valid = \
@@ -850,7 +851,7 @@ class RailEnv(Environment):
         
         if move:
             if agent.position is None:  # agent is entering the env
-                print(i_agent, "writing new pos ", rc_next, " into agent position (None)")
+                #print(i_agent, "writing new pos ", rc_next, " into agent position (None)")
                 agent.position = rc_next
                 agent.status = RailAgentStatus.ACTIVE
                 agent.speed_data['position_fraction'] = 0.0
diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py
index ff99ec08..f26cbaa9 100644
--- a/flatland/utils/editor.py
+++ b/flatland/utils/editor.py
@@ -420,7 +420,7 @@ class EditorModel(object):
     def set_draw_mode(self, draw_mode):
         self.draw_mode = draw_mode
 
-    def interpolate_path(self, rcLast, rc_cell):
+    def interpolate_pair(self, rcLast, rc_cell):
         if np.array_equal(rcLast, rc_cell):
             return []
         rcLast = array(rcLast)
@@ -456,6 +456,15 @@ class EditorModel(object):
             # Convert the array to a list of tuples
             lrcInterp = list(map(tuple, g2Interp))
         return lrcInterp
+    
+    def interpolate_path(self, lrcPath):
+        lrcPath2 = []  # interpolated version of the path
+        rcLast = None
+        for rcCell in lrcPath:
+            if rcLast is not None:
+                lrcPath2.extend(self.interpolate_pair(rcLast, rcCell))
+            rcLast = rcCell
+        return lrcPath2
 
     def drag_path_element(self, rc_cell):
         """Mouse motion event handler for drawing.
@@ -466,7 +475,7 @@ class EditorModel(object):
         if len(lrcStroke) > 0:
             rcLast = lrcStroke[-1]
             if not np.array_equal(rcLast, rc_cell):  # only save at transition
-                lrcInterp = self.interpolate_path(rcLast, rc_cell)
+                lrcInterp = self.interpolate_pair(rcLast, rc_cell)
                 lrcStroke.extend(lrcInterp)
                 self.debug("lrcStroke ", len(lrcStroke), rc_cell, "interp:", lrcInterp)
 
@@ -492,6 +501,8 @@ class EditorModel(object):
         # If we have already touched 3 cells
         # We have a transition into a cell, and out of it.
 
+        #print(lrcStroke)
+
         if len(lrcStroke) >= 2:
             # If the first cell in a stroke is empty, add a deadend to cell 0
             if self.env.rail.get_full_transitions(*lrcStroke[0]) == 0:
@@ -500,6 +511,7 @@ class EditorModel(object):
         # Add transitions for groups of 3 cells
         # hence inbound and outbound transitions for middle cell
         while len(lrcStroke) >= 3:
+            #print(lrcStroke)
             self.mod_rail_3cells(lrcStroke, bAddRemove=bAddRemove)
 
         # If final cell empty, insert deadend:
@@ -507,6 +519,8 @@ class EditorModel(object):
             if self.env.rail.get_full_transitions(*lrcStroke[1]) == 0:
                 self.mod_rail_2cells(lrcStroke, bAddRemove, iCellToMod=1)
 
+        #print("final:", lrcStroke)
+
         # now empty out the final two cells from the queue
         lrcStroke.clear()
 
@@ -582,6 +596,8 @@ class EditorModel(object):
                 iTrans = iTrans[0][0]
                 liTrans.append(iTrans)
 
+        #self.log("liTrans:", liTrans)
+
         # check that we have one transition
         if len(liTrans) == 1:
             # Set the transition as a deadend
diff --git a/flatland/utils/env_edit_utils.py b/flatland/utils/env_edit_utils.py
new file mode 100644
index 00000000..ac748469
--- /dev/null
+++ b/flatland/utils/env_edit_utils.py
@@ -0,0 +1,126 @@
+
+
+import flatland.envs.schedule_generators as sg
+import flatland.envs.rail_generators as rg
+import flatland.envs.observations as obs
+from flatland.utils import editor
+from flatland.envs.rail_env import RailEnv
+
+
+# Start and end all agents at the same place
+class SchedGen2(sg.BaseSchedGen):
+    def __init__(self, rcStart, rcEnd, iDir):
+        self.rcStart = rcStart
+        self.rcEnd = rcEnd
+        self.iDir = iDir
+        
+    def generate(self, rail, num_agents, hints=None, num_resets=None, np_random=None) -> sg.Schedule:
+        return sg.Schedule(agent_positions = [self.rcStart] * num_agents, 
+                           agent_directions= [self.iDir] * num_agents,
+                           agent_targets = [self.rcEnd] * num_agents,
+                           agent_speeds = [1.0] * num_agents,
+                           agent_malfunction_rates = None,
+                           max_episode_steps=100)
+
+
+
+# cycle through lists of start, end and direction
+class SchedGen3(sg.BaseSchedGen):
+    def __init__(self, lrcStarts, lrcTargs, liDirs):
+        self.lrcStarts = lrcStarts
+        self.lrcTargs = lrcTargs
+        self.liDirs = liDirs
+        
+    def generate(self, rail, num_agents, hints=None, num_resets=None, np_random=None) -> sg.Schedule:
+        return sg.Schedule(agent_positions = [ self.lrcStarts[i % len(self.lrcStarts)] for i in range(num_agents) ],
+                           agent_directions= [ self.liDirs[i % len(self.liDirs)] for i in range(num_agents) ],
+                           agent_targets = [ self.lrcTargs[i % len(self.lrcTargs)] for i in range(num_agents) ],
+                           agent_speeds = [1.0] * num_agents,
+                           agent_malfunction_rates = None,
+                           max_episode_steps=100)
+
+
+def makeEnv(nAg=2, width=20, height=10, oSG=None):
+    env = RailEnv(width=width, height=height, rail_generator=rg.empty_rail_generator(),
+                number_of_agents=nAg,
+                schedule_generator=oSG,
+                obs_builder_object=obs.TreeObsForRailEnv(max_depth=1))
+
+    envModel = editor.EditorModel(env)
+    env.reset()
+    return env, envModel
+
+
+def makeEnv2(nAg=2, shape=(20,10), llrcPaths=[], lrcStarts=[], lrcTargs=[], liDirs=[], bUCF=True):
+    oSG = SchedGen3(lrcStarts, lrcTargs, liDirs)
+
+    env = RailEnv(width=shape[0], height=shape[1], 
+                rail_generator=rg.empty_rail_generator(),
+                number_of_agents=nAg,
+                schedule_generator=oSG,
+                obs_builder_object=obs.TreeObsForRailEnv(max_depth=1),
+                close_following=bUCF)
+
+    envModel = editor.EditorModel(env)
+    env.reset()
+
+    for lrcPath in llrcPaths:
+        envModel.mod_rail_cell_seq(envModel.interpolate_path(lrcPath))
+
+    return env, envModel
+
+
+def makeTestEnv(sName="single_alternative", nAg=2, bUCF=True):
+
+    ddEnvSpecs = {
+        # opposing stations with single alternative path
+        "single_alternative":{
+            "llrcPaths":  [
+                [(1,0), (1,15)],  # across the top
+                [(1,4), (1,6), (3,6), (3, 12), (1,12), (1,14)], # alternative loop below
+                ],
+            "lrcStarts": [ (1,3), (1,14) ],
+            "lrcTargs" : [(1,14), (1,3)],
+            "liDirs" : [1,3]
+            },
+
+        # single spur so one agent needs to wait
+        "single_spur": {
+            "llrcPaths" : [
+                [(1,0), (1,15)],
+                [(4,0), (4,6), (1,6), (1, 8)]],
+            "lrcStarts": [(1,3), (1,14) ],
+            "lrcTargs" : [(1,14), (4,2)],
+            "liDirs" : [1,3]
+            },
+        
+        # single spur so one agent needs to wait
+        "merging_spurs": {
+            "llrcPaths" : [
+                [(1,0), (1,15), (7, 15), (7,0)],
+                [(4,0), (4,6), (1,6), (1, 8)],
+                #[((1,14), (1,16), (7,16), )]
+                ],
+            "lrcStarts": [(1,2), (4,2) ],
+            "lrcTargs" : [(7,3)],
+            "liDirs" : [1]
+            },
+
+        # Concentric Loops
+        "concentric_loops": {
+            "llrcPaths": [
+                [(1,1), (1,5), (8, 5), (8,1), (1,1), (1,3)],
+                [(1,3), (1,10), (8,10), (8,3)]
+                ],
+            
+            "lrcStarts": [(1,3)],
+            "lrcTargs": [(2,1)],
+            "liDirs":  [1]
+            }
+
+        }
+    
+    dSpec = ddEnvSpecs[sName]
+
+
+    return makeEnv2(nAg=nAg, bUCF=bUCF, **dSpec)
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 303eaefd..910dec32 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -7,6 +7,8 @@ import numpy as np
 from numpy import array
 from recordtype import recordtype
 
+from flatland.envs.agent_utils import RailAgentStatus
+
 from flatland.utils.graphics_pil import PILGL, PILSVG
 from flatland.utils.graphics_pgl import PGLGL
 
@@ -675,14 +677,15 @@ class RenderLocal(RenderBase):
                     continue
 
                 # Show an agent even if it hasn't already started
-                if show_inactive_agents and (agent.position is None):
-                    # print("agent ", agent_idx, agent.position, agent.old_position, agent.initial_position)
-                    self.gl.set_agent_at(agent_idx, *(agent.initial_position), 
-                        agent.initial_direction, agent.initial_direction,
-                        is_selected=(selected_agent == agent_idx),
-                        rail_grid=env.rail.grid,
-                        show_debug=self.show_debug, clear_debug_text=self.clear_debug_text,
-                        malfunction=False)
+                if agent.position is None:
+                    if show_inactive_agents:
+                        # print("agent ", agent_idx, agent.position, agent.old_position, agent.initial_position)
+                        self.gl.set_agent_at(agent_idx, *(agent.initial_position), 
+                            agent.initial_direction, agent.initial_direction,
+                            is_selected=(selected_agent == agent_idx),
+                            rail_grid=env.rail.grid,
+                            show_debug=self.show_debug, clear_debug_text=self.clear_debug_text,
+                            malfunction=False)
                     continue
 
                 is_malfunction = agent.malfunction_data["malfunction"] > 0
@@ -736,8 +739,16 @@ class RenderLocal(RenderBase):
                     # set_agent_at uses the agent index for the color
                     if self.agent_render_variant == AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX:
                         self.gl.set_cell_occupied(agent_idx, *(agent.position))
-                    self.gl.set_agent_at(agent_idx, *position, agent.direction, direction, selected_agent == agent_idx,
-                                         rail_grid=env.rail.grid, malfunction=is_malfunction)
+                    
+                    if show_inactive_agents:
+                        show_this_agent=True
+                    else:
+                        show_this_agent = agent.status == RailAgentStatus.ACTIVE
+
+                    if show_this_agent:
+                        self.gl.set_agent_at(agent_idx, *position, agent.direction, direction, 
+                                        selected_agent == agent_idx,
+                                        rail_grid=env.rail.grid, malfunction=is_malfunction)
 
         if show_observations:
             self.render_observation(range(env.get_num_agents()), env.dev_obs_dict)
diff --git a/notebooks/test-collision.ipynb b/notebooks/test-collision.ipynb
new file mode 100644
index 00000000..22b2ae88
--- /dev/null
+++ b/notebooks/test-collision.ipynb
@@ -0,0 +1,211 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2\n",
+    "from IPython.core import display \n",
+    "display.display(display.HTML(\"<style>.container { width:95% !important; }</style>\"))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import PIL\n",
+    "from IPython import display\n",
+    "from ipycanvas import canvas\n",
+    "import time"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from flatland.envs import malfunction_generators as malgen\n",
+    "from flatland.envs.agent_utils import EnvAgent\n",
+    "#from flatland.envs import sparse_rail_gen as spgen\n",
+    "from flatland.envs import rail_generators as rail_gen\n",
+    "from flatland.envs import agent_chains as ac\n",
+    "from flatland.envs.rail_env import RailEnv, RailEnvActions\n",
+    "from flatland.envs.persistence import RailEnvPersister\n",
+    "from flatland.utils.rendertools import RenderTool\n",
+    "from flatland.utils import env_edit_utils as eeu"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "UCF: False\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "eee80deb8b17429d88bd0913cb55790f",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Canvas(layout=Layout(height='300px', width='600px'), size=(600, 300))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "nAg=2\n",
+    "bUCF=False\n",
+    "bReverseStart=False\n",
+    "env, envModel = eeu.makeTestEnv(sName=\"merging_spurs\", nAg=nAg, bUCF=bUCF)\n",
+    "oRT = RenderTool(env, show_debug=True)\n",
+    "oRT.render_env(show_rowcols=True,  show_inactive_agents=True, show_observations=False)\n",
+    "oCan = canvas.Canvas(size=(600,300))\n",
+    "oCan.put_image_data(oRT.get_image())\n",
+    "print(f\"UCF: {bUCF}\")\n",
+    "oCan\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "iPauseStep=10\n",
+    "iPauseLen = 5\n",
+    "\n",
+    "for iStep in range(25):\n",
+    "    \n",
+    "    if bReverseStart:\n",
+    "        iAgentStart = max((nAg - 2 - iStep*2), 0)\n",
+    "    else:\n",
+    "        iAgentStart = 0\n",
+    "    dActions = { i:int(RailEnvActions.MOVE_FORWARD) for i in range(iAgentStart, len(env.agents)) }\n",
+    "    \n",
+    "    if (iStep >= iPauseStep) and (iStep < iPauseStep + iPauseLen):\n",
+    "        dActions[0] = RailEnvActions.STOP_MOVING\n",
+    "        \n",
+    "    #print(dActions)\n",
+    "    \n",
+    "    env.step(dActions)\n",
+    "    \n",
+    "    oRT.render_env(show_rowcols=True,  show_inactive_agents=True, show_observations=False)\n",
+    "    \n",
+    "    aImg = oRT.get_image()\n",
+    "    pilImg = PIL.Image.fromarray(aImg)\n",
+    "    oCan.put_image_data(aImg)\n",
+    "    with open(f\"tmp-images/img-{iStep:03d}.png\", \"wb\") as fImg:\n",
+    "        pilImg.save(fImg)\n",
+    "    \n",
+    "    time.sleep(0.5)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#!ffmpeg -i ./tmp-images/img-%03d.png -vcodec mpeg4 -filter:v \"setpts=8.0*PTS\" -y movie_nucf_stop.mp4"
+   ]
+  }
+ ],
+ "metadata": {
+  "hide_input": false,
+  "kernelspec": {
+   "display_name": "ve367",
+   "language": "python",
+   "name": "ve367"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.7"
+  },
+  "latex_envs": {
+   "LaTeX_envs_menu_present": true,
+   "autoclose": false,
+   "autocomplete": true,
+   "bibliofile": "biblio.bib",
+   "cite_by": "apalike",
+   "current_citInitial": 1,
+   "eqLabelWithNumbers": true,
+   "eqNumInitial": 1,
+   "hotkeys": {
+    "equation": "Ctrl-E",
+    "itemize": "Ctrl-I"
+   },
+   "labels_anchors": false,
+   "latex_user_defs": false,
+   "report_style_numbering": false,
+   "user_envs_cfg": false
+  },
+  "toc": {
+   "base_numbering": 1,
+   "nav_menu": {},
+   "number_sections": true,
+   "sideBar": true,
+   "skip_h1_title": false,
+   "title_cell": "Table of Contents",
+   "title_sidebar": "Contents",
+   "toc_cell": false,
+   "toc_position": {},
+   "toc_section_display": true,
+   "toc_window_display": false
+  },
+  "varInspector": {
+   "cols": {
+    "lenName": 16,
+    "lenType": 16,
+    "lenVar": 40
+   },
+   "kernels_config": {
+    "python": {
+     "delete_cmd_postfix": "",
+     "delete_cmd_prefix": "del ",
+     "library": "var_list.py",
+     "varRefreshCmd": "print(var_dic_list())"
+    },
+    "r": {
+     "delete_cmd_postfix": ") ",
+     "delete_cmd_prefix": "rm(",
+     "library": "var_list.r",
+     "varRefreshCmd": "cat(var_dic_list()) "
+    }
+   },
+   "types_to_exclude": [
+    "module",
+    "function",
+    "builtin_function_or_method",
+    "instance",
+    "_Feature"
+   ],
+   "window_display": false
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
-- 
GitLab