From 4fae3ccb8ed4aa54e1a3bf4cf1ef697672263734 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Thu, 5 Sep 2019 00:47:29 +0200
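Subject: [PATCH] #167 bugfix action_on_cellexit

In rail_env.py, move the handling of DO_NOTHING (keep moving),
STOP_MOVING and restarting a stopped agent into the branch that runs
only when position_fraction == 0.0, ahead of the call to
_check_action_on_agent. Previously this handling sat at loop level,
after that branch.

Also included:
- annotate self.agents / self.agents_static with List[...] types
- rename rcPos to rc_pos
- merge the nested malfunction conditions into a single if
- pass agents_hints positionally to the schedule generator
- narrow RailGeneratorProduct from Optional[Any] to Optional[Dict]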

---
 flatland/envs/rail_env.py        | 77 ++++++++++++++++----------------
 flatland/envs/rail_generators.py |  4 +-
 2 files changed, 41 insertions(+), 40 deletions(-)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 8b4f43fe..cc85604e 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -4,6 +4,7 @@ Definition of the RailEnv environment.
 # TODO:  _ this is a global method --> utils or remove later
 import warnings
 from enum import IntEnum
+from typing import List
 
 import msgpack
 import msgpack_numpy as m
@@ -165,8 +166,8 @@ class RailEnv(Environment):
         self.dev_obs_dict = {}
         self.dev_pred_dict = {}
 
-        self.agents = [None] * number_of_agents  # live agents
-        self.agents_static = [None] * number_of_agents  # static agent information
+        self.agents: List[EnvAgent] = [None] * number_of_agents  # live agents
+        self.agents_static: List[EnvAgentStatic] = [None] * number_of_agents  # static agent information
         self.num_resets = 0
 
         self.action_space = [1]
@@ -239,17 +240,17 @@ class RailEnv(Environment):
             self.height, self.width = self.rail.grid.shape
             for r in range(self.height):
                 for c in range(self.width):
-                    rcPos = (r, c)
-                    check = self.rail.cell_neighbours_valid(rcPos, True)
+                    rc_pos = (r, c)
+                    check = self.rail.cell_neighbours_valid(rc_pos, True)
                     if not check:
-                        warnings.warn("Invalid grid at {} -> {}".format(rcPos, check))
+                        warnings.warn("Invalid grid at {} -> {}".format(rc_pos, check))
 
         if replace_agents:
             agents_hints = None
             if optionals and 'agents_hints' in optionals:
                 agents_hints = optionals['agents_hints']
             self.agents_static = EnvAgentStatic.from_lists(
-                *self.schedule_generator(self.rail, self.get_num_agents(), hints=agents_hints))
+                *self.schedule_generator(self.rail, self.get_num_agents(), agents_hints))
         self.restart_agents()
 
         for i_agent in range(self.get_num_agents()):
@@ -284,25 +285,24 @@ class RailEnv(Environment):
             agent.malfunction_data['next_malfunction'] -= 1
 
         # Only agents that have a positive rate for malfunctions and are not currently broken are considered
-        if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction']:
-
-            # If counter has come to zero --> Agent has malfunction
-            # set next malfunction time and duration of current malfunction
-            if agent.malfunction_data['next_malfunction'] <= 0:
-                # Increase number of malfunctions
-                agent.malfunction_data['nr_malfunctions'] += 1
-
-                # Next malfunction in number of stops
-                next_breakdown = int(
-                    np.random.exponential(scale=agent.malfunction_data['malfunction_rate']))
-                agent.malfunction_data['next_malfunction'] = next_breakdown
-
-                # Duration of current malfunction
-                num_broken_steps = np.random.randint(self.min_number_of_steps_broken,
-                                                     self.max_number_of_steps_broken + 1) + 1
-                agent.malfunction_data['malfunction'] = num_broken_steps
-
-                return True
+        # If counter has come to zero --> Agent has malfunction
+        # set next malfunction time and duration of current malfunction
+        if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction'] and \
+            agent.malfunction_data['next_malfunction'] <= 0:
+            # Increase number of malfunctions
+            agent.malfunction_data['nr_malfunctions'] += 1
+
+            # Next malfunction in number of stops
+            next_breakdown = int(
+                np.random.exponential(scale=agent.malfunction_data['malfunction_rate']))
+            agent.malfunction_data['next_malfunction'] = next_breakdown
+
+            # Duration of current malfunction
+            num_broken_steps = np.random.randint(self.min_number_of_steps_broken,
+                                                 self.max_number_of_steps_broken + 1) + 1
+            agent.malfunction_data['malfunction'] = num_broken_steps
+
+            return True
         return False
 
     def step(self, action_dict_):
@@ -353,6 +353,20 @@ class RailEnv(Environment):
             # TODO refactor!!!
             # If the agent can make an action
             if agent.speed_data['position_fraction'] == 0.0:
+                if action == RailEnvActions.DO_NOTHING and agent.moving:
+                    # Keep moving
+                    action = RailEnvActions.MOVE_FORWARD
+
+                if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.0:
+                    # Only allow halting an agent on entering new cells.
+                    agent.moving = False
+                    self.rewards_dict[i_agent] += self.stop_penalty
+
+                if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
+                    # Allow agent to start with any forward or direction action
+                    agent.moving = True
+                    self.rewards_dict[i_agent] += self.start_penalty
+
                 if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING:
                     cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
                         self._check_action_on_agent(action, agent)
@@ -408,19 +422,6 @@ class RailEnv(Environment):
                     # Nothing left to do with broken agent
                     continue
 
-            if action == RailEnvActions.DO_NOTHING and agent.moving:
-                # Keep moving
-                action = RailEnvActions.MOVE_FORWARD
-
-            if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.0:
-                # Only allow halting an agent on entering new cells.
-                agent.moving = False
-                self.rewards_dict[i_agent] += self.stop_penalty
-
-            if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
-                # Allow agent to start with any forward or direction action
-                agent.moving = True
-                self.rewards_dict[i_agent] += self.start_penalty
 
             # Now perform a movement.
             # If the agent is in an initial position within a new cell (agent.speed_data['position_fraction']<eps)
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 8573c25c..9d55198b 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -1,6 +1,6 @@
 """Rail generators (infrastructure manager, "Infrastrukturbetreiber")."""
 import warnings
-from typing import Callable, Tuple, Any, Optional
+from typing import Callable, Tuple, Optional, Dict
 
 import msgpack
 import numpy as np
@@ -11,7 +11,7 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_from_nodes
 
-RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Any]]
+RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]]
 RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
 
 
-- 
GitLab