From 5bdf72e03457ce69f0b36d364802580852037746 Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Tue, 14 May 2019 21:39:07 +0100
Subject: [PATCH] SVG trains bend around corners

---
 examples/play_model.py        | 20 +++++------
 flatland/envs/agent_utils.py  |  3 +-
 flatland/envs/rail_env.py     | 68 +++++++++++++++++++----------------
 flatland/utils/render_qt.py   | 26 +++++++++-----
 flatland/utils/rendertools.py | 26 +++++++++++---
 flatland/utils/svg.py         | 10 +++---
 6 files changed, 91 insertions(+), 62 deletions(-)

diff --git a/examples/play_model.py b/examples/play_model.py
index 4f6ba79..f80d9a1 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -106,10 +106,7 @@ def main(render=True, delay=0.0):
 
     if render:
         env_renderer = RenderTool(env, gl="QTSVG")
-    # plt.figure(figsize=(5,5))
-    # fRedis = redis.Redis()
-
-    # handle = env.get_agent_handles()
+        # env_renderer = RenderTool(env, gl="QT")
 
     state_size = 105
     action_size = 4
@@ -167,6 +164,14 @@ def main(render=True, delay=0.0):
                 action_prob[action] += 1
                 action_dict.update({a: action})
 
+            if render:
+                env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step, action_dict=action_dict)
+                #time.sleep(10)
+                if delay > 0:
+                    time.sleep(delay)
+
+            iFrame += 1
+
             # Environment step
             next_obs, all_rewards, done, _ = env.step(action_dict)
             for a in range(env.get_num_agents()):
@@ -177,13 +182,6 @@ def main(render=True, delay=0.0):
                 agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
                 score += all_rewards[a]
 
-            if render:
-                env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step)
-                #time.sleep(10)
-                if delay > 0:
-                    time.sleep(delay)
-
-            iFrame += 1
 
             obs = next_obs.copy()
             if done['__all__']:
diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index 6ef8aa4..4920946 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -23,8 +23,7 @@ class EnvAgentStatic(object):
     position = attrib()
     direction = attrib()
     target = attrib()
-
-    next_handle = 0  # this is not properly implemented
+    old_direction = attrib(default=None)
 
     @classmethod
     def from_lists(cls, positions, directions, targets):
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 33b9f96..adb40f0 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -208,34 +208,11 @@ class RailEnv(Environment):
                 # compute number of possible transitions in the current
                 # cell used to check for invalid actions
 
-                possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
-                num_transitions = np.count_nonzero(possible_transitions)
-
-                movement = agent.direction
-                # print(nbits,np.sum(possible_transitions))
-                if action == 1:
-                    movement = agent.direction - 1
-                    if num_transitions <= 1:
-                        transition_isValid = False
-
-                elif action == 3:
-                    movement = agent.direction + 1
-                    if num_transitions <= 1:
-                        transition_isValid = False
-
-                movement %= 4
-
-                if action == 2:
-                    if num_transitions == 1:
-                        # - dead-end, straight line or curved line;
-                        # movement will be the only valid transition
-                        # - take only available transition
-                        movement = np.argmax(possible_transitions)
-                        transition_isValid = True
-
-                new_position = get_new_position(agent.position, movement)
+                new_direction, transition_isValid = self.check_action(agent, action)
+
+                new_position = get_new_position(agent.position, new_direction)
                 # Is it a legal move?
-                # 1) transition allows the movement in the cell,
+                # 1) transition allows the new_direction in the cell,
                 # 2) the new cell is not empty (case 0),
                 # 3) the cell is free, i.e., no agent is currently in that cell
                 
@@ -259,7 +236,7 @@ class RailEnv(Environment):
                 if transition_isValid is None:
                     transition_isValid = self.rail.get_transition(
                         (*agent.position, agent.direction),
-                        movement)
+                        new_direction)
 
                 # cell_isFree = True
                 # for j in range(self.number_of_agents):
@@ -272,12 +249,13 @@ class RailEnv(Environment):
                         np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
 
                 if all([new_cell_isValid, transition_isValid, cell_isFree]):
-                    # move and change direction to face the movement that was
+                    # move and change direction to face the new_direction that was
                     # performed
                     # self.agents_position[i] = new_position
-                    # self.agents_direction[i] = movement
+                    # self.agents_direction[i] = new_direction
                     agent.position = new_position
-                    agent.direction = movement
+                    agent.old_direction = agent.direction
+                    agent.direction = new_direction
                 else:
                     # the action was not valid, add penalty
                     self.rewards_dict[iAgent] += invalid_action_penalty
@@ -307,6 +285,34 @@ class RailEnv(Environment):
         self.actions = [0] * self.get_num_agents()
         return self._get_observations(), self.rewards_dict, self.dones, {}
 
+    def check_action(self, agent, action):
+        transition_isValid = None
+        possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
+        num_transitions = np.count_nonzero(possible_transitions)
+
+        new_direction = agent.direction
+        # print(nbits,np.sum(possible_transitions))
+        if action == 1:
+            new_direction = agent.direction - 1
+            if num_transitions <= 1:
+                transition_isValid = False
+
+        elif action == 3:
+            new_direction = agent.direction + 1
+            if num_transitions <= 1:
+                transition_isValid = False
+
+        new_direction %= 4
+
+        if action == 2:
+            if num_transitions == 1:
+                # - dead-end, straight line or curved line;
+                # new_direction will be the only valid transition
+                # - take only available transition
+                new_direction = np.argmax(possible_transitions)
+                transition_isValid = True
+        return new_direction, transition_isValid
+
     def _get_observations(self):
         self.obs_dict = {}
         # for handle in self.agents_handles:
diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py
index 5608806..f70afae 100644
--- a/flatland/utils/render_qt.py
+++ b/flatland/utils/render_qt.py
@@ -175,38 +175,46 @@ class QTSVG(GraphicsLayer):
             svgWidget.renderer().load(bySVG)
             self.layout.addWidget(svgWidget, row, col)
             self.lwTrack.append(svgWidget)
+        else:
+            print("Illegal rail:", row, col, format(binTrans, "#018b")[2:])
 
     def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut):
         if iAgent < len(self.lwAgents):
             wAgent = self.lwAgents[iAgent]
             agentPrev = self.agents_prev[iAgent]
+
+            # If we have an existing agent widget, we can just move it
             if wAgent is not None:
                 self.layout.removeWidget(wAgent)
                 self.layout.addWidget(wAgent, row, col)
 
-                if agentPrev.direction == iDirIn:
-                    # print("moved ", iAgent, row, col, iDirIn)
+                # We can only reuse the image if noth new and old are straight and the same:
+                if iDirIn == iDirOut and \
+                        agentPrev.direction == iDirIn and \
+                        agentPrev.old_direction == agentPrev.direction:
                     return
                 else:
+                    # need to load new image
                     # print("new dir:", iAgent, row, col, agentPrev.direction, iDirIn)
-                    agentPrev.direction = iDirIn
+                    agentPrev.direction = iDirOut
+                    agentPrev.old_direction = iDirIn
                     sSVG = self.zug.getSvg(iAgent, iDirIn, iDirOut).to_string()
                     bySVG = bytearray(sSVG, encoding='utf-8')
                     wAgent.renderer().load(bySVG)
                     return
 
-        else:
-            # Ensure we have adequate slots in the list lwAgents
-            for i in range(len(self.lwAgents), iAgent+1):
-                self.lwAgents.append(None)
-                self.agents_prev.append(None)
+        # Ensure we have adequate slots in the list lwAgents
+        for i in range(len(self.lwAgents), iAgent+1):
+            self.lwAgents.append(None)
+            self.agents_prev.append(None)
 
+        # Create a new widget for the agent
         sSVG = self.zug.getSvg(iAgent, iDirIn, iDirOut).to_string()
         bySVG = bytearray(sSVG, encoding='utf-8')
         svgWidget = QtSvg.QSvgWidget()
         svgWidget.renderer().load(bySVG)
         self.lwAgents[iAgent] = svgWidget
-        self.agents_prev[iAgent] = EnvAgent((row, col), iDirIn, (0, 0))
+        self.agents_prev[iAgent] = EnvAgent((row, col), iDirOut, (0, 0), old_direction=iDirIn)
         self.layout.addWidget(svgWidget, row, col)
         # print("Created ", iAgent, row, col)
 
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 16fcd05..18bc3c8 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -476,7 +476,8 @@ class RenderTool(object):
             self, show=False, curves=True, spacing=False,
             arrows=False, agents=True, sRailColor="gray",
             frames=False, iEpisode=None, iStep=None,
-            iSelectedAgent=None):
+            iSelectedAgent=None,
+            action_dict=None):
         """
         Draw the environment using matplotlib.
         Draw into the figure if provided.
@@ -489,7 +490,7 @@ class RenderTool(object):
             self.renderEnv2(show, curves, spacing,
             arrows, agents, sRailColor,
             frames, iEpisode, iStep,
-            iSelectedAgent)
+            iSelectedAgent, action_dict)
             return
 
         # cell_size is a bit pointless with matplotlib - it does not relate to pixels,
@@ -691,7 +692,8 @@ class RenderTool(object):
             self, show=False, curves=True, spacing=False,
             arrows=False, agents=True, sRailColor="gray",
             frames=False, iEpisode=None, iStep=None,
-            iSelectedAgent=None):
+            iSelectedAgent=None,
+            action_dict=dict()):
         """
         Draw the environment using matplotlib.
         Draw into the figure if provided.
@@ -715,10 +717,24 @@ class RenderTool(object):
         for iAgent, agent in enumerate(self.env.agents):
             if agent is None:
                 continue
-            self.gl.setAgentAt(iAgent, *agent.position, agent.direction, agent.direction)
 
+            new_direction = agent.direction
+            action_isValid = False
+
+            if iAgent in action_dict:
+                iAction = action_dict[iAgent]
+                new_direction, action_isValid = self.env.check_action(agent, iAction)
+            
+            if action_isValid:
+                self.gl.setAgentAt(iAgent, *agent.position, agent.direction, new_direction)
+            else:
+                pass
+                # print("invalid action - agent ", iAgent, " bend ", agent.direction, new_direction)
+                # self.gl.setAgentAt(iAgent, *agent.position, agent.direction, new_direction)
+                
         self.gl.show()
-        self.gl.processEvents()
+        for i in range(3):
+            self.gl.processEvents()
 
         self.iFrame += 1
         return
diff --git a/flatland/utils/svg.py b/flatland/utils/svg.py
index f9daf19..bd6ab4e 100644
--- a/flatland/utils/svg.py
+++ b/flatland/utils/svg.py
@@ -65,19 +65,21 @@ class Zug(object):
 
     def getSvg(self, iAgent, iDirIn, iDirOut):
         delta_dir = (iDirOut - iDirIn) % 4
+        # if delta_dir != 0:
+        #    print("Bend:", iAgent, iDirIn, iDirOut)
 
         if delta_dir in (0, 2):
             svg = self.svg_straight.copy()
             svg.set_rotate(iDirIn * 90)
             return svg
         
-        if delta_dir == 1:
+        if delta_dir == 1:  # bend to right, eg N->E, E->S
             svg = self.svg_curve1.copy()
             svg.set_rotate((iDirIn - 1) * 90)
             return svg
 
-        elif delta_dir == 3:
-            svg = self.svg_curve1.copy()
+        elif delta_dir == 3:  # bend to left, eg N->W
+            svg = self.svg_curve2.copy()
             svg.set_rotate(iDirIn * 90)
             return svg
 
@@ -94,7 +96,7 @@ class Track(object):
             "ES NW": "Gleis_Kurve_unten_links.svg",
             "NE WS": "Gleis_Kurve_unten_rechts.svg",
             "NN SS": "Gleis_vertikal.svg",
-            "NN SS ES NW SE WN": "Weiche_Double_Slip.svg",
+            "NN SS EE WW ES NW SE WN": "Weiche_Double_Slip.svg",
             "EE WW EN SW": "Weiche_horizontal_oben_links.svg",
             "EE WW SE WN": "Weiche_horizontal_oben_rechts.svg",
             "EE WW ES NW": "Weiche_horizontal_unten_links.svg",
-- 
GitLab