From 20d0f8fb7f2c81535161f68f5d3a5bd70faa2d85 Mon Sep 17 00:00:00 2001
From: Giacomo Spigler <spiglerg@gmail.com>
Date: Wed, 5 Jun 2019 13:38:48 +0200
Subject: [PATCH] tmp mods

---
 examples/simple_example_3.py |  8 +++---
 flatland/envs/agent_utils.py |  8 +++---
 flatland/envs/rail_env.py    | 51 ++++++++++++++++++++++++++++--------
 3 files changed, 49 insertions(+), 18 deletions(-)

diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py
index 1978e27c..8c98eea7 100644
--- a/examples/simple_example_3.py
+++ b/examples/simple_example_3.py
@@ -6,8 +6,8 @@ from flatland.utils.rendertools import RenderTool
 from flatland.envs.observations import TreeObsForRailEnv
 import numpy as np
 
-random.seed(100)
-np.random.seed(100)
+random.seed(10)
+np.random.seed(10)
 
 env = RailEnv(width=7,
               height=7,
@@ -24,7 +24,7 @@ obs, all_rewards, done, _ = env.step({0: 0})
 for i in range(env.get_num_agents()):
     env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5)
 
-env_renderer = RenderTool(env, gl="QT")
+env_renderer = RenderTool(env, gl="PIL")
 env_renderer.renderEnv(show=True)
 
 print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \
@@ -52,4 +52,4 @@ for step in range(100):
             i = i + 1
         i += 1
 
-    env_renderer.renderEnv(show=True)
+    env_renderer.renderEnv(show=True, frames=True)
diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index a4dc6962..5105ec5d 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -29,17 +29,19 @@ class EnvAgentStatic(object):
     position = attrib()
     direction = attrib()
     target = attrib()
+    moving = attrib()
 
     def __init__(self, position, direction, target):
         self.position = position
         self.direction = direction
         self.target = target
+        self.moving = False
 
     @classmethod
     def from_lists(cls, positions, directions, targets):
         """ Create a list of EnvAgentStatics from lists of positions, directions and targets
         """
-        return list(starmap(EnvAgentStatic, zip(positions, directions, targets)))
+        return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False]*len(positions))))
 
     def to_list(self):
 
@@ -53,7 +55,7 @@ class EnvAgentStatic(object):
         if type(lTarget) is np.ndarray:
             lTarget = lTarget.tolist()
 
-        return [lPos, int(self.direction), lTarget]
+        return [lPos, int(self.direction), lTarget, int(self.moving)]
 
 
 @attrs
@@ -77,7 +79,7 @@ class EnvAgent(EnvAgentStatic):
     def to_list(self):
         return [
             self.position, self.direction, self.target, self.handle, 
-            self.old_direction, self.old_position]
+            self.old_direction, self.old_position, self.moving]
 
     @classmethod
     def from_static(cls, oStatic):
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index a23b6ac7..ba4e9e46 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -9,6 +9,7 @@ a GridTransitionMap object.
 
 import msgpack
 import numpy as np
+from enum import IntEnum
 
 from flatland.core.env import Environment
 from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
@@ -21,6 +22,14 @@ from flatland.envs.observations import TreeObsForRailEnv
 # from flatland.core.transition_map import GridTransitionMap
 
 
+class RailEnvActions(IntEnum):
+    DO_NOTHING = 0
+    MOVE_LEFT = 1
+    MOVE_FORWARD = 2
+    MOVE_RIGHT = 3
+    STOP_MOVING = 4
+
+
 class RailEnv(Environment):
     """
     RailEnv environment class.
@@ -32,9 +41,10 @@ class RailEnv(Environment):
 
     The valid actions in the environment are:
         0: do nothing
-        1: turn left and move to the next cell
-        2: move to the next cell in front of the agent
-        3: turn right and move to the next cell
+        1: turn left and move to the next cell; if the agent was not moving, movement is started
+        2: move to the next cell in front of the agent; if the agent was not moving, movement is started
+        3: turn right and move to the next cell; if the agent was not moving, movement is started
+        4: stop moving
 
     Moving forward in a dead-end cell makes the agent turn 180 degrees and step
     to the cell it came from.
@@ -176,7 +186,7 @@ class RailEnv(Environment):
         alpha = 1.0
         beta = 1.0
 
-        invalid_action_penalty = -2
+        invalid_action_penalty = 0 # -2 GIACOMO: we decided that invalid actions will carry no penalty
         step_penalty = -1 * alpha
         global_reward = 1 * beta
 
@@ -198,7 +208,11 @@ class RailEnv(Environment):
             agent = self.agents[iAgent]
 
             if iAgent not in action_dict:  # no action has been supplied for this agent
-                continue
+                if agent.moving:
+                    # Keep moving
+                    action_dict[iAgent] = RailEnvActions.MOVE_FORWARD
+                else:
+                    action_dict[iAgent] = RailEnvActions.DO_NOTHING
 
             if self.dones[iAgent]:  # this agent has already completed...
                 # print("rail_env.py @", currentframe().f_back.f_lineno, " agent ", iAgent,
@@ -206,12 +220,27 @@ class RailEnv(Environment):
                 continue
             action = action_dict[iAgent]
 
-            if action < 0 or action > 3:
+            if action < 0 or action > len(RailEnvActions):
                 print('ERROR: illegal action=', action,
                       'for agent with index=', iAgent)
                 return
 
-            if action > 0:
+            if action == RailEnvActions.DO_NOTHING and agent.moving:
+                # Keep moving
+                action_dict[iAgent] = RailEnvActions.MOVE_FORWARD
+                action = RailEnvActions.MOVE_FORWARD
+
+            if action == RailEnvActions.STOP_MOVING and agent.moving:
+                action_dict[iAgent] = RailEnvActions.DO_NOTHING
+                action = RailEnvActions.DO_NOTHING
+                agent.moving = False
+                # TODO: possibly, penalty for stopping!
+
+            if not agent.moving and (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_FORWARD or action == RailEnvActions.MOVE_RIGHT):
+                agent.moving = True
+                # TODO: possibly, may add a penalty for starting, but the best is only for stopping (GIACOMO's opinion)
+
+            if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING:
                 # pos = agent.position #  self.agents_position[i]
                 # direction = agent.direction # self.agents_direction[i]
 
@@ -293,7 +322,7 @@ class RailEnv(Environment):
 
         # Reset the step actions (in case some agent doesn't 'register_action'
         # on the next step)
-        self.actions = [0] * self.get_num_agents()
+        self.actions = [RailEnvActions.DO_NOTHING] * self.get_num_agents()
         return self._get_observations(), self.rewards_dict, self.dones, {}
 
     def check_action(self, agent, action):
@@ -303,19 +332,19 @@ class RailEnv(Environment):
 
         new_direction = agent.direction
         # print(nbits,np.sum(possible_transitions))
-        if action == 1:
+        if action == RailEnvActions.MOVE_LEFT:
             new_direction = agent.direction - 1
             if num_transitions <= 1:
                 transition_isValid = False
 
-        elif action == 3:
+        elif action == RailEnvActions.MOVE_RIGHT:
             new_direction = agent.direction + 1
             if num_transitions <= 1:
                 transition_isValid = False
 
         new_direction %= 4
 
-        if action == 2:
+        if action == RailEnvActions.MOVE_FORWARD:
             if num_transitions == 1:
                 # - dead-end, straight line or curved line;
                 # new_direction will be the only valid transition
-- 
GitLab