From 1b4de9444f3caa4415b77ea0a93da0a31287b511 Mon Sep 17 00:00:00 2001
From: u229589 <christian.baumberger@sbb.ch>
Date: Tue, 5 Nov 2019 09:28:35 +0100
Subject: [PATCH] include improvements suggested by Ch. Eichenberger

---
 changelog.md                                  |  3 +
 flatland/envs/agent_utils.py                  | 32 +++++++++--
 flatland/envs/rail_env.py                     | 10 ++--
 flatland/envs/schedule_generators.py          |  2 +-
 flatland/utils/editor.py                      | 57 +++++++++----------
 ...t_flatland_envs_rail_env_shortest_paths.py |  3 +
 6 files changed, 67 insertions(+), 40 deletions(-)

diff --git a/changelog.md b/changelog.md
index 5eba35db..543fc87a 100644
--- a/changelog.md
+++ b/changelog.md
@@ -3,6 +3,9 @@ Changelog
 
 Changes since Flatland 2.0.0
 --------------------------
+### Changes in `EnvAgent`
+- class `EnvAgentStatic` was removed, so there is only class `EnvAgent` left which should simplify the handling of agents. The member `self.agents_static` of `RailEnv` was therefore also removed. Old Scence saved as pickle files cannot be loaded anymore.
+
 ### Changes in malfunction behavior
 - agent attribute `next_malfunction`is not used anymore, it will be removed fully in future versions.
 - `break_agent()` function is introduced which induces malfunctions in agent according to poisson process
diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index 9f86f41b..bcc0fcc6 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -1,6 +1,6 @@
 from enum import IntEnum
 from itertools import starmap
-from typing import Tuple, Optional
+from typing import Tuple, Optional, NamedTuple
 
 from attr import attrs, attrib, Factory
 
@@ -15,6 +15,20 @@ class RailAgentStatus(IntEnum):
     DONE_REMOVED = 3  # removed from grid (position is None) -> prediction is None
 
 
+Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
+                             ('initial_direction', Grid4TransitionsEnum),
+                             ('direction', Grid4TransitionsEnum),
+                             ('target', Tuple[int, int]),
+                             ('moving', bool),
+                             ('speed_data', dict),
+                             ('malfunction_data', dict),
+                             ('handle', int),
+                             ('status', RailAgentStatus),
+                             ('position', Tuple[int, int]),
+                             ('old_direction', Grid4TransitionsEnum),
+                             ('old_position', Tuple[int, int])])
+
+
 @attrs
 class EnvAgent:
 
@@ -55,10 +69,18 @@ class EnvAgent:
         self.old_direction = None
         self.moving = False
 
-    def to_list(self):
-        return [self.initial_position, self.initial_direction, int(self.direction), self.target, int(self.moving),
-                self.speed_data, self.malfunction_data, self.handle, self.status, self.position, self.old_direction,
-                self.old_position]
+    def move(self, new_pos: Tuple[int, int], new_dir: Tuple[int, int] = None):
+        self.old_position = self.position
+        self.position = new_pos
+        if new_dir is not None:
+            self.old_direction = self.direction
+            self.direction = new_dir
+
+    def to_agent(self) -> Agent:
+        return Agent(initial_position=self.initial_position, initial_direction=self.initial_direction,
+                     direction=self.direction, target=self.target, moving=self.moving, speed_data=self.speed_data,
+                     malfunction_data=self.malfunction_data, handle=self.handle, status=self.status,
+                     position=self.position, old_direction=self.old_direction, old_position=self.old_position)
 
     @classmethod
     def from_schedule(cls, schedule: Schedule):
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 7eacb528..f6537274 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -813,7 +813,7 @@ class RailEnv(Environment):
         Returns state of environment in msgpack object
         """
         grid_data = self.rail.grid.tolist()
-        agent_data = [agent.to_list() for agent in self.agents]
+        agent_data = [agent.to_agent() for agent in self.agents]
         msgpack.packb(grid_data, use_bin_type=True)
         msgpack.packb(agent_data, use_bin_type=True)
         msg_data = {
@@ -825,7 +825,7 @@ class RailEnv(Environment):
         """
         Returns agents information in msgpack object
         """
-        agent_data = [agent.to_list() for agent in self.agents]
+        agent_data = [agent.to_agent() for agent in self.agents]
         msg_data = {
             "agents": agent_data}
         return msgpack.packb(msg_data, use_bin_type=True)
@@ -841,7 +841,7 @@ class RailEnv(Environment):
         data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
         self.rail.grid = np.array(data["grid"])
         # agents are always reset as not moving
-        self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11]) for d in data["agents"]]
+        self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
         # setup with loaded data
         self.height, self.width = self.rail.grid.shape
         self.rail.height = self.height
@@ -859,7 +859,7 @@ class RailEnv(Environment):
         data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
         self.rail.grid = np.array(data["grid"])
         # agents are always reset as not moving
-        self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11]) for d in data["agents"]]
+        self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
         if "distance_map" in data.keys():
             self.distance_map.set(data["distance_map"])
         # setup with loaded data
@@ -873,7 +873,7 @@ class RailEnv(Environment):
         Returns environment information with distance map information as msgpack object
         """
         grid_data = self.rail.grid.tolist()
-        agent_data = [agent.to_list() for agent in self.agents]
+        agent_data = [agent.to_agent() for agent in self.agents]
         msgpack.packb(grid_data, use_bin_type=True)
         msgpack.packb(agent_data, use_bin_type=True)
         distance_map_data = self.distance_map.get()
diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
index be19fda5..656bf70c 100644
--- a/flatland/envs/schedule_generators.py
+++ b/flatland/envs/schedule_generators.py
@@ -291,7 +291,7 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator:
             with open(filename, "rb") as file_in:
                 load_data = file_in.read()
         data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8')
-        agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11]) for d in data["agents"]]
+        agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
 
         # setup with loaded data
         agents_position = [a.initial_position for a in agents]
diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py
index ffc0e30f..eadea813 100644
--- a/flatland/utils/editor.py
+++ b/flatland/utils/editor.py
@@ -178,12 +178,12 @@ class View(object):
             self.writableData[(y - 2):(y + 2), (x - 2):(x + 2), :3] = 0
 
     def xy_to_rc(self, x, y):
-        rcCell = ((array([y, x]) - self.yxBase))
+        rc_cell = ((array([y, x]) - self.yxBase))
         nX = np.floor((self.yxSize[0] - self.yxBase[0]) / self.model.env.height)
         nY = np.floor((self.yxSize[1] - self.yxBase[1]) / self.model.env.width)
-        rcCell[0] = max(0, min(np.floor(rcCell[0] / nY), self.model.env.height - 1))
-        rcCell[1] = max(0, min(np.floor(rcCell[1] / nX), self.model.env.width - 1))
-        return rcCell
+        rc_cell[0] = max(0, min(np.floor(rc_cell[0] / nY), self.model.env.height - 1))
+        rc_cell[1] = max(0, min(np.floor(rc_cell[1] / nX), self.model.env.width - 1))
+        return rc_cell
 
     def log(self, *args, **kwargs):
         if self.output_generator:
@@ -215,23 +215,23 @@ class Controller(object):
         y = event['canvasY']
         self.debug("debug:", x, y)
 
-        rcCell = self.view.xy_to_rc(x, y)
+        rc_cell = self.view.xy_to_rc(x, y)
 
         bShift = event["shiftKey"]
         bCtrl = event["ctrlKey"]
         bAlt = event["altKey"]
         if bCtrl and not bShift and not bAlt:
-            self.model.click_agent(rcCell)
+            self.model.click_agent(rc_cell)
             self.lrcStroke = []
         elif bShift and bCtrl:
-            self.model.add_target(rcCell)
+            self.model.add_target(rc_cell)
             self.lrcStroke = []
         elif bAlt and not bShift and not bCtrl:
-            self.model.clear_cell(rcCell)
+            self.model.clear_cell(rc_cell)
             self.lrcStroke = []
 
-        self.debug("click in cell", rcCell)
-        self.model.debug_cell(rcCell)
+        self.debug("click in cell", rc_cell)
+        self.model.debug_cell(rc_cell)
 
         if self.model.selected_agent is not None:
             self.lrcStroke = []
@@ -304,8 +304,8 @@ class Controller(object):
                     self.view.drag_path_element(x, y)
 
                     # Translate and scale from x,y to integer row,col (note order change)
-                    rcCell = self.view.xy_to_rc(x, y)
-                    self.editor.drag_path_element(rcCell)
+                    rc_cell = self.view.xy_to_rc(x, y)
+                    self.editor.drag_path_element(rc_cell)
 
                 self.view.redisplay_image()
 
@@ -413,12 +413,12 @@ class EditorModel(object):
     def set_draw_mode(self, draw_mode):
         self.draw_mode = draw_mode
 
-    def interpolate_path(self, rcLast, rcCell):
-        if np.array_equal(rcLast, rcCell):
+    def interpolate_path(self, rcLast, rc_cell):
+        if np.array_equal(rcLast, rc_cell):
             return []
         rcLast = array(rcLast)
-        rcCell = array(rcCell)
-        rcDelta = rcCell - rcLast
+        rc_cell = array(rc_cell)
+        rcDelta = rc_cell - rcLast
 
         lrcInterp = []  # extra row,col points
 
@@ -450,7 +450,7 @@ class EditorModel(object):
             lrcInterp = list(map(tuple, g2Interp))
         return lrcInterp
 
-    def drag_path_element(self, rcCell):
+    def drag_path_element(self, rc_cell):
         """Mouse motion event handler for drawing.
         """
         lrcStroke = self.lrcStroke
@@ -458,15 +458,15 @@ class EditorModel(object):
         # Store the row,col location of the click, if we have entered a new cell
         if len(lrcStroke) > 0:
             rcLast = lrcStroke[-1]
-            if not np.array_equal(rcLast, rcCell):  # only save at transition
-                lrcInterp = self.interpolate_path(rcLast, rcCell)
+            if not np.array_equal(rcLast, rc_cell):  # only save at transition
+                lrcInterp = self.interpolate_path(rcLast, rc_cell)
                 lrcStroke.extend(lrcInterp)
-                self.debug("lrcStroke ", len(lrcStroke), rcCell, "interp:", lrcInterp)
+                self.debug("lrcStroke ", len(lrcStroke), rc_cell, "interp:", lrcInterp)
 
         else:
             # This is the first cell in a mouse stroke
-            lrcStroke.append(rcCell)
-            self.debug("lrcStroke ", len(lrcStroke), rcCell)
+            lrcStroke.append(rc_cell)
+            self.debug("lrcStroke ", len(lrcStroke), rc_cell)
 
     def mod_path(self, bAddRemove):
         # disabled functionality (no longer required)
@@ -701,8 +701,7 @@ class EditorModel(object):
             else:
                 # Move the selected agent to this cell
                 agent = self.env.agents[self.selected_agent]
-                agent.position = cell_row_col
-                agent.old_position = cell_row_col
+                agent.move(cell_row_col)
         else:
             # Yes
             # Have they clicked on the agent already selected?
@@ -715,9 +714,9 @@ class EditorModel(object):
 
         self.redraw()
 
-    def add_target(self, rcCell):
+    def add_target(self, rc_cell):
         if self.selected_agent is not None:
-            self.env.agents[self.selected_agent].target = rcCell
+            self.env.agents[self.selected_agent].target = rc_cell
             self.view.oRT.update_background()
             self.redraw()
 
@@ -735,11 +734,11 @@ class EditorModel(object):
         if self.debug_bool:
             self.log(*args, **kwargs)
 
-    def debug_cell(self, rcCell):
-        binTrans = self.env.rail.get_full_transitions(*rcCell)
+    def debug_cell(self, rc_cell):
+        binTrans = self.env.rail.get_full_transitions(*rc_cell)
         sbinTrans = format(binTrans, "#018b")[2:]
         self.debug("cell ",
-                   rcCell,
+                   rc_cell,
                    "Transitions: ",
                    binTrans,
                    sbinTrans,
diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py
index 5a0c35df..8b066028 100644
--- a/tests/test_flatland_envs_rail_env_shortest_paths.py
+++ b/tests/test_flatland_envs_rail_env_shortest_paths.py
@@ -43,6 +43,7 @@ def test_get_shortest_paths_unreachable():
 
 
 # todo file test_002.pkl has to be generated automatically
+# see https://gitlab.aicrowd.com/flatland/flatland/issues/279
 @pytest.mark.skip
 def test_get_shortest_paths():
     env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests')
@@ -174,6 +175,7 @@ def test_get_shortest_paths():
 
 
 # todo file test_002.pkl has to be generated automatically
+# see https://gitlab.aicrowd.com/flatland/flatland/issues/279
 @pytest.mark.skip
 def test_get_shortest_paths_max_depth():
     env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests')
@@ -205,6 +207,7 @@ def test_get_shortest_paths_max_depth():
 
 
 # todo file Level_distance_map_shortest_path.pkl has to be generated automatically
+# see https://gitlab.aicrowd.com/flatland/flatland/issues/279
 @pytest.mark.skip
 def test_get_shortest_paths_agent_handle():
     env = load_flatland_environment_from_file('Level_distance_map_shortest_path.pkl', 'env_data.tests')
-- 
GitLab