From ca36e40eeb553c6ffecdb837da6e6635f191634d Mon Sep 17 00:00:00 2001
From: Dipam Chakraborty <dipam@aicrowd.com>
Date: Thu, 9 Sep 2021 18:34:21 +0530
Subject: [PATCH] change all RailEnvStatus to TrainState

---
 flatland/envs/agent_utils.py                  | 25 ++++---
 flatland/envs/malfunction_generators.py       |  6 +-
 flatland/envs/observations.py                 | 19 +++--
 flatland/envs/persistence.py                  |  2 +-
 flatland/envs/rail_env.py                     | 74 +++++--------------
 flatland/envs/rail_env_shortest_paths.py      | 10 +--
 .../envs/step_utils/action_preprocessing.py   |  6 +-
 flatland/envs/step_utils/states.py            |  9 +++
 flatland/envs/step_utils/transition_utils.py  |  5 +-
 flatland/utils/rendertools.py                 |  6 +-
 tests/test_env_step_utils.py                  |  6 +-
 11 files changed, 70 insertions(+), 98 deletions(-)

diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index 632caeea..9230659e 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -2,18 +2,19 @@ from flatland.envs.rail_trainrun_data_structures import Waypoint
 import numpy as np
 
 from enum import IntEnum
-from flatland.envs.step_utils.states import TrainState
+
 from itertools import starmap
 from typing import Tuple, Optional, NamedTuple, List
 
 from attr import attr, attrs, attrib, Factory
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.envs.schedule_utils import Schedule
+from flatland.envs.timetable_utils import Line
 
 from flatland.envs.step_utils.action_saver import ActionSaver
 from flatland.envs.step_utils.speed_counter import SpeedCounter
 from flatland.envs.step_utils.state_machine import TrainStateMachine
+from flatland.envs.step_utils.states import TrainState
 from flatland.envs.step_utils.malfunction_handler import MalfunctionHandler
 
 Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
@@ -137,8 +138,8 @@ class EnvAgent:
         """
         speed_datas = []
         speed_counters = []
-        for i in range(len(schedule.agent_positions)):
-            speed = schedule.agent_speeds[i] if schedule.agent_speeds is not None else 1.0
+        for i in range(len(line.agent_positions)):
+            speed = line.agent_speeds[i] if line.agent_speeds is not None else 1.0
             speed_datas.append({'position_fraction': 0.0,
                                 'speed': speed,
                                 'transition_action_on_cellexit': 0})
@@ -152,16 +153,16 @@ class EnvAgent:
                                       'next_malfunction': 0,
                                       'nr_malfunctions': 0})
         
-        return list(starmap(EnvAgent, zip(schedule.agent_positions,  # TODO : Dipam - Really want to change this way of loading agents
-                                          schedule.agent_directions,
-                                          schedule.agent_directions,
-                                          schedule.agent_targets, 
-                                          [False] * len(schedule.agent_positions), 
-                                          [None] * len(schedule.agent_positions), # earliest_departure
-                                          [None] * len(schedule.agent_positions), # latest_arrival
+        return list(starmap(EnvAgent, zip(line.agent_positions,  # TODO : Dipam - Really want to change this way of loading agents
+                                          line.agent_directions,
+                                          line.agent_directions,
+                                          line.agent_targets, 
+                                          [False] * len(line.agent_positions), 
+                                          [None] * len(line.agent_positions), # earliest_departure
+                                          [None] * len(line.agent_positions), # latest_arrival
                                           speed_datas,
                                           malfunction_datas,
-                                          range(len(schedule.agent_positions)),
+                                          range(len(line.agent_positions)),
                                           speed_counters,
                                           )))
 
diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py
index 2fecddf1..0dfafb36 100644
--- a/flatland/envs/malfunction_generators.py
+++ b/flatland/envs/malfunction_generators.py
@@ -5,7 +5,8 @@ from typing import Callable, NamedTuple, Optional, Tuple
 import numpy as np
 from numpy.random.mtrand import RandomState
 
-from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
+from flatland.envs.agent_utils import EnvAgent
+from flatland.envs.step_utils.states import TrainState
 from flatland.envs import persistence
 
 
@@ -155,7 +156,8 @@ def single_malfunction_generator(earlierst_malfunction: int, malfunction_duratio
             malfunction_calls[agent.handle] = 1
 
         # Break an agent that is active at the time of the malfunction
-        if agent.status == RailAgentStatus.ACTIVE and malfunction_calls[agent.handle] >= earlierst_malfunction: #TODO : Dipam : Is this needed?
+        if (agent.state == TrainState.MOVING or agent.state == TrainState.STOPPED) \
+            and malfunction_calls[agent.handle] >= earlierst_malfunction: #TODO : Dipam : Is this needed?
             global_nr_malfunctions += 1
             return Malfunction(malfunction_duration)
         else:
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 4de36060..a140da06 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -11,7 +11,8 @@ from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.env_prediction_builder import PredictionBuilder
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.grid.grid_utils import coordinate_to_position
-from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
+from flatland.envs.agent_utils import EnvAgent
+from flatland.envs.step_utils.states import TrainState
 from flatland.utils.ordered_set import OrderedSet
 
 
@@ -93,7 +94,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.location_has_agent_ready_to_depart = {}
 
         for _agent in self.env.agents:
-            if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \
+            if not TrainState.off_map_state(_agent.state) and \
                 _agent.position:
                 self.location_has_agent[tuple(_agent.position)] = 1
                 self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction
@@ -102,7 +103,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                     'malfunction']
 
             # [NIMISH] WHAT IS THIS
-            if _agent.status in [RailAgentStatus.READY_TO_DEPART, RailAgentStatus.WAITING] and \
+            if TrainState.off_map_state(_agent.state) and \
                 _agent.initial_position:
                     self.location_has_agent_ready_to_depart.setdefault(tuple(_agent.initial_position), 0)
                     self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] += 1
@@ -569,13 +570,11 @@ class GlobalObsForRailEnv(ObservationBuilder):
     def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
 
         agent = self.env.agents[handle]
-        if agent.status == RailAgentStatus.WAITING:
-            agent_virtual_position = agent.initial_position
-        elif agent.status == RailAgentStatus.READY_TO_DEPART:
+        if TrainState.off_map_state(agent.state):
             agent_virtual_position = agent.initial_position
-        elif agent.status == RailAgentStatus.ACTIVE:
+        elif TrainState.on_map_state(agent.state):
             agent_virtual_position = agent.position
-        elif agent.status == RailAgentStatus.DONE:
+        elif agent.state == TrainState.DONE:
             agent_virtual_position = agent.target
         else:
             return None
@@ -596,7 +595,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
             other_agent: EnvAgent = self.env.agents[i]
 
             # ignore other agents not in the grid any more
-            if other_agent.status == RailAgentStatus.DONE_REMOVED:
+            if other_agent.state == TrainState.DONE:
                 continue
 
             obs_targets[other_agent.target][1] = 1
@@ -609,7 +608,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
                 obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
                 obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
             # fifth channel: all ready to depart on this position
-            if other_agent.status == RailAgentStatus.READY_TO_DEPART or other_agent.status == RailAgentStatus.WAITING:
+            if TrainState.off_map_state(other_agent.state):
                 obs_agents_state[other_agent.initial_position][4] += 1
         return self.rail_obs, obs_agents_state, obs_targets
 
diff --git a/flatland/envs/persistence.py b/flatland/envs/persistence.py
index 41f352e7..c5ec8f33 100644
--- a/flatland/envs/persistence.py
+++ b/flatland/envs/persistence.py
@@ -13,7 +13,7 @@ from flatland.core.env_observation_builder import DummyObservationBuilder
 #from flatland.core.grid.grid4_utils import get_new_position
 #from flatland.core.grid.grid_utils import IntVector2D
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.agent_utils import Agent, EnvAgent, RailAgentStatus
+from flatland.envs.agent_utils import Agent, EnvAgent
 from flatland.envs.distance_map import DistanceMap
 
 #from flatland.envs.observations import GlobalObsForRailEnv
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 1dc332d9..68bb0de2 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -17,7 +17,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.grid.grid_utils import IntVector2D, position_to_coordinate
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.agent_utils import Agent, EnvAgent, RailAgentStatus
+from flatland.envs.agent_utils import Agent, EnvAgent
 from flatland.envs.distance_map import DistanceMap
 from flatland.envs.rail_env_action import RailEnvActions
 
@@ -39,8 +39,7 @@ from gym.utils import seeding
 # from flatland.envs.line_generators import random_line_generator, LineGenerator
 
 
-# NEW : Imports 
-from flatland.envs.schedule_time_generators import schedule_time_generator
+from flatland.envs.timetable_generators import timetable_generator
 from flatland.envs.step_utils.states import TrainState
 from flatland.envs.step_utils.transition_utils import check_action
 
@@ -285,9 +284,9 @@ class RailEnv(Environment):
         True: Agent needs to provide an action
         False: Agent cannot provide an action
         """
-        return (agent.status == RailAgentStatus.READY_TO_DEPART or (
-            agent.status == RailAgentStatus.ACTIVE and fast_isclose(agent.speed_data['position_fraction'], 0.0,
-                                                                    rtol=1e-03)))
+        return agent.state == TrainState.READY_TO_DEPART or \
+               (TrainState.on_map_state(agent.state) and \
+                fast_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03) )
 
     def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *,
               random_seed: bool = None) -> Tuple[Dict, Dict]:
@@ -400,7 +399,7 @@ class RailEnv(Environment):
                 i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents)
             },
             'speed': {i: agent.speed_data['speed'] for i, agent in enumerate(self.agents)},
-            'status': {i: agent.status for i, agent in enumerate(self.agents)}
+            'state': {i: agent.state for i, agent in enumerate(self.agents)}
         }
         # Return the new observation vectors for each agent
         observation_dict: Dict = self._get_observations()
@@ -425,6 +424,8 @@ class RailEnv(Environment):
         st_signals['target_reached'] = fast_position_equal(agent.position, agent.target)
         st_signals['movement_conflict'] = (not movement_allowed) and agent.speed_counter.is_cell_exit # TODO: Modify motion check to provide proper conflict information
 
+        return st_signals
+
     def _handle_end_reward(self, agent: EnvAgent) -> int:
         '''
         Handles end-of-episode reward for a particular agent.
@@ -456,8 +457,7 @@ class RailEnv(Environment):
     def step(self, action_dict_: Dict[int, RailEnvActions]):
         """
         Updates rewards for the agents at a step.
-
-    def step(self, action_dict):
+        """
         self._elapsed_steps += 1
 
         # If we're done, set reward and info_dict and step() is done.
@@ -497,7 +497,7 @@ class RailEnv(Environment):
             agent.malfunction_handler.generate_malfunction(self.malfunction_generator, self.np_random)
 
             # Get action for the agent
-            action = action_dict.get(i_agent, RailEnvActions.DO_NOTHING)
+            action = action_dict_.get(i_agent, RailEnvActions.DO_NOTHING)
             # TODO: Add the bottom stuff to separate function(s)
 
             # Preprocess action
@@ -509,7 +509,7 @@ class RailEnv(Environment):
             agent_not_on_map = current_position is None
             if agent_not_on_map: # Agent not added on map yet
                 current_position, current_direction = agent.initial_position, agent.initial_direction
-            action = preprocess_moving_action(action, agent.state, self.rail, current_position, current_direction)
+            action = preprocess_moving_action(action, self.rail, current_position, current_direction)
 
             # Save moving actions in not already saved
             agent.action_saver.save_action_if_allowed(action, agent.state)
@@ -519,6 +519,7 @@ class RailEnv(Environment):
             if agent_not_on_map and agent.action_saver.is_action_saved:
                 temp_new_position = agent.initial_position
                 temp_new_direction = agent.initial_direction
+                preprocessed_action = action
                 
             # When cell exit occurs apply saved action independent of other agents
             elif agent.speed_counter.is_cell_exit and agent.action_saver.is_action_saved:
@@ -526,11 +527,13 @@ class RailEnv(Environment):
                 # Apply action independent of other agents and get temporary new position and direction
                 temp_pd = self.apply_action_independent(saved_action, self.rail, agent.position, agent.direction)
                 temp_new_position, temp_new_direction = temp_pd
+                preprocessed_action = saved_action
             else:
                 temp_new_position, temp_new_direction = agent.position, agent.direction
+                preprocessed_action = action
 
             # TODO: Saving temporary positon shouldn't be needed if recheck of position is not needed later (see TAG#1)
-            temp_saved_data[i_agent] = temp_new_position, temp_new_direction, action
+            temp_saved_data[i_agent] = temp_new_position, temp_new_direction, preprocessed_action
             self.motionCheck.addAgent(i_agent, agent.position, temp_new_position)
 
         # Find conflicts
@@ -554,6 +557,10 @@ class RailEnv(Environment):
             else:
                 final_new_position = agent.position
                 final_new_direction = agent.direction
+            # if final_new_position and self.rail.grid[final_new_position] == 0:
+                # import pdb; pdb.set_trace()
+            # if final_new_position and not (final_new_position[0] >= 0 and final_new_position[1] >= 0 and final_new_position[0] < self.rail.height and final_new_position[1] < self.rail.width): # TODO: Remove this
+                # import pdb; pdb.set_trace()
             agent.position = final_new_position
             agent.direction = final_new_direction
 
@@ -581,49 +588,6 @@ class RailEnv(Environment):
         self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))} # TODO : Remove this
         return self._get_observations(), self.rewards_dict, self.dones, info_dict # TODO : Will need changes?
 
-    def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D):
-        """
-        Sets the agent to its initial position. Updates the agent object and the position
-        of the agent inside the global agent_position numpy array
-
-        Parameters
-        -------
-        agent: EnvAgent object
-        new_position: IntVector2D
-        """
-        agent.position = new_position
-        self.agent_positions[agent.position] = agent.handle
-
-    def _move_agent_to_new_position(self, agent: EnvAgent, new_position: IntVector2D):
-        """
-        Move the agent to the a new position. Updates the agent object and the position
-        of the agent inside the global agent_position numpy array
-
-        Parameters
-        -------
-        agent: EnvAgent object
-        new_position: IntVector2D
-        """
-        agent.position = new_position
-        self.agent_positions[agent.old_position] = -1
-        self.agent_positions[agent.position] = agent.handle
-
-    def _remove_agent_from_scene(self, agent: EnvAgent):
-        """
-        Remove the agent from the scene. Updates the agent object and the position
-        of the agent inside the global agent_position numpy array
-
-        Parameters
-        -------
-        agent: EnvAgent object
-        """
-        self.agent_positions[agent.position] = -1
-        if self.remove_agents_at_target:
-            agent.position = None
-            # setting old_position to None here stops the DONE agents from appearing in the rendered image
-            agent.old_position = None
-            agent.status = RailAgentStatus.DONE_REMOVED
-
     def record_timestep(self, dActions):
         ''' Record the positions and orientations of all agents in memory, in the cur_episode
         '''
diff --git a/flatland/envs/rail_env_shortest_paths.py b/flatland/envs/rail_env_shortest_paths.py
index 8c981778..7d063282 100644
--- a/flatland/envs/rail_env_shortest_paths.py
+++ b/flatland/envs/rail_env_shortest_paths.py
@@ -7,7 +7,7 @@ import numpy as np
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.step_utils.states import TrainState
 from flatland.envs.distance_map import DistanceMap
 from flatland.envs.rail_env_action import RailEnvActions, RailEnvNextAction
 from flatland.envs.rail_trainrun_data_structures import Waypoint
@@ -227,13 +227,11 @@ def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = Non
     shortest_paths = dict()
 
     def _shortest_path_for_agent(agent):
-        if agent.status == RailAgentStatus.WAITING:
+        if TrainState.off_map_state(agent.state):
             position = agent.initial_position
-        elif agent.status == RailAgentStatus.READY_TO_DEPART:
-            position = agent.initial_position
-        elif agent.status == RailAgentStatus.ACTIVE:
+        elif TrainState.on_map_state(agent.state):
             position = agent.position
-        elif agent.status == RailAgentStatus.DONE:
+        elif agent.state == TrainState.DONE:
             position = agent.target
         else:
             shortest_paths[agent.handle] = None
diff --git a/flatland/envs/step_utils/action_preprocessing.py b/flatland/envs/step_utils/action_preprocessing.py
index e8ad1d79..c777342d 100644
--- a/flatland/envs/step_utils/action_preprocessing.py
+++ b/flatland/envs/step_utils/action_preprocessing.py
@@ -20,8 +20,8 @@ def process_do_nothing(state: TrainState):
     return action
 
 
-def process_left_right(action, state, rail, position, direction):
-    if not check_valid_action(action, state, rail, position, direction):
+def process_left_right(action, rail, position, direction):
+    if not check_valid_action(action, rail, position, direction):
         action = RailEnvActions.MOVE_FORWARD
     return action
 
@@ -48,7 +48,7 @@ def preprocess_raw_action(action, state):
 
     return action
 
-def preprocess_moving_action(action, state, rail, position, direction):
+def preprocess_moving_action(action, rail, position, direction):
     """
     LEFT/RIGHT is converted to FORWARD if left/right is not available and train is moving
     FORWARD is converted to STOP_MOVING if leading to dead end?
diff --git a/flatland/envs/step_utils/states.py b/flatland/envs/step_utils/states.py
index 4c991b66..b8040939 100644
--- a/flatland/envs/step_utils/states.py
+++ b/flatland/envs/step_utils/states.py
@@ -16,6 +16,15 @@ class TrainState(IntEnum):
     @staticmethod
     def is_malfunction_state(state):
         return state in [2, 5] # TODO: Can this be done with names instead?
+
+    @staticmethod
+    def off_map_state(state):
+        return state in [0, 1, 2]
+    
+    @staticmethod    
+    def on_map_state(state):
+        return state in [3, 4, 5]
+
     
 
 
diff --git a/flatland/envs/step_utils/transition_utils.py b/flatland/envs/step_utils/transition_utils.py
index 2d58d21e..157db5ac 100644
--- a/flatland/envs/step_utils/transition_utils.py
+++ b/flatland/envs/step_utils/transition_utils.py
@@ -66,9 +66,8 @@ def check_action_on_agent(action, rail, position, direction):
     new_direction, transition_valid = check_action(action, position, direction, rail)
     new_position = get_new_position(position, new_direction)
 
-    cell_inside_grid = check_bounds(new_position, rail.height, rail.width)
-    cell_not_empty = rail.get_full_transitions(*new_position) > 0
-    new_cell_valid = cell_inside_grid and cell_not_empty
+    new_cell_valid = check_bounds(new_position, rail.height, rail.width) and \
+                     rail.get_full_transitions(*new_position) > 0
 
     # If transition validity hasn't been checked yet.
     if transition_valid is None:
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 910dec32..9499c5c4 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -7,7 +7,7 @@ import numpy as np
 from numpy import array
 from recordtype import recordtype
 
-from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.step_utils.states import TrainState
 
 from flatland.utils.graphics_pil import PILGL, PILSVG
 from flatland.utils.graphics_pgl import PGLGL
@@ -741,9 +741,9 @@ class RenderLocal(RenderBase):
                         self.gl.set_cell_occupied(agent_idx, *(agent.position))
                     
                     if show_inactive_agents:
-                        show_this_agent=True
+                        show_this_agent = True
                     else:
-                        show_this_agent = agent.status == RailAgentStatus.ACTIVE
+                        show_this_agent = TrainState.on_map_state(agent.state)
 
                     if show_this_agent:
                         self.gl.set_agent_at(agent_idx, *position, agent.direction, direction, 
diff --git a/tests/test_env_step_utils.py b/tests/test_env_step_utils.py
index 739d3d06..4c249de3 100644
--- a/tests/test_env_step_utils.py
+++ b/tests/test_env_step_utils.py
@@ -10,7 +10,7 @@ from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_env import RailEnvActions
 from flatland.envs.rail_generators import sparse_rail_generator
 #from flatland.envs.sparse_rail_gen import SparseRailGen
-from flatland.envs.schedule_generators import sparse_schedule_generator
+from flatland.envs.line_generators import sparse_line_generator
 
 
 def get_small_two_agent_env():
@@ -35,7 +35,7 @@ def get_small_two_agent_env():
                     1. / 3.: 0.25,  # Slow commuter train
                     1. / 4.: 0.25}  # Slow freight train
 
-    schedule_generator = sparse_schedule_generator(speed_ration_map)
+    line_generator = sparse_line_generator(speed_ration_map)
 
 
     stochastic_data = MalfunctionParameters(malfunction_rate=1/10000,  # Rate of malfunction occurence
@@ -48,7 +48,7 @@ def get_small_two_agent_env():
     env = RailEnv(width=width,
                 height=height,
                 rail_generator=rail_generator,
-                schedule_generator=schedule_generator,
+                line_generator=line_generator,
                 number_of_agents=nr_trains,
                 obs_builder_object=observation_builder,
                 #malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
-- 
GitLab