From be854ec7ccb81ab73114dbd52936a00f865f7fb1 Mon Sep 17 00:00:00 2001 From: nimishsantosh107 <nimishsantosh107@icloud.com> Date: Sun, 1 Aug 2021 22:38:12 +0530 Subject: [PATCH] reward scheme implemented, treeobs fixed, untested changes --- flatland/envs/agent_utils.py | 37 ++++++++-- flatland/envs/observations.py | 21 ++++-- flatland/envs/predictions.py | 5 +- flatland/envs/rail_env.py | 124 ++++++++++++++++------------------ 4 files changed, 109 insertions(+), 78 deletions(-) diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 9831f603..81e8eedf 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -1,3 +1,5 @@ +import numpy as np + from enum import IntEnum from itertools import starmap from typing import Tuple, Optional, NamedTuple @@ -7,14 +9,12 @@ from attr import attr, attrs, attrib, Factory from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.schedule_utils import Schedule - class RailAgentStatus(IntEnum): WAITING = 0 READY_TO_DEPART = 1 # not in grid yet (position is None) -> prediction as if it were at initial position ACTIVE = 2 # in grid (position is not None), not done -> prediction is remaining path DONE = 3 # in grid (position is not None), but done -> prediction is stay at target forever DONE_REMOVED = 4 # removed from grid (position is None) -> prediction is None - CANCELLED = 5 Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]), @@ -29,6 +29,7 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]), ('handle', int), ('status', RailAgentStatus), ('position', Tuple[int, int]), + ('arrival_time', int), ('old_direction', Grid4TransitionsEnum), ('old_position', Tuple[int, int])]) @@ -66,6 +67,9 @@ class EnvAgent: status = attrib(default=RailAgentStatus.WAITING, type=RailAgentStatus) position = attrib(default=None, type=Optional[Tuple[int, int]]) + # NEW : EnvAgent Reward Handling + arrival_time = attrib(default=None, type=int) + # used in rendering old_direction = attrib(default=None) old_position = attrib(default=None) @@ -83,6 +87,8 @@ class EnvAgent: else: self.status = RailAgentStatus.WAITING + self.arrival_time = None + self.old_position = None self.old_direction = None self.moving = False @@ -96,12 +102,33 @@ class EnvAgent: self.malfunction_data['nr_malfunctions'] = 0 self.malfunction_data['moving_before_malfunction'] = False + # NEW : Callables + def get_shortest_path(self, distance_map): + from flatland.envs.rail_env_shortest_paths import get_shortest_paths # Circular dep fix + return get_shortest_paths(distance_map=distance_map, agent_handle=self.handle)[self.handle] + + def get_travel_time_on_shortest_path(self, distance_map): + distance = len(self.get_shortest_path(distance_map)) + speed = self.speed_data['speed'] + return int(np.ceil(distance / speed)) + + def get_time_remaining_until_latest_arrival(self, elapsed_steps): + return self.latest_arrival - elapsed_steps + + def get_current_delay(self, elapsed_steps, distance_map): + ''' + +ve if arrival time is projected before latest arrival + -ve if arrival time is projected after latest arrival + ''' + return self.get_time_remaining_until_latest_arrival(elapsed_steps) - \ + self.get_travel_time_on_shortest_path(distance_map) + def to_agent(self) -> Agent: return Agent(initial_position=self.initial_position, initial_direction=self.initial_direction, direction=self.direction, target=self.target, moving=self.moving, earliest_departure=self.earliest_departure, - latest_arrival=self.latest_arrival, 
speed_data=self.speed_data, - malfunction_data=self.malfunction_data, handle=self.handle, status=self.status, - position=self.position, old_direction=self.old_direction, old_position=self.old_position) + latest_arrival=self.latest_arrival, speed_data=self.speed_data, malfunction_data=self.malfunction_data, + handle=self.handle, status=self.status, position=self.position, arrival_time=self.arrival_time, + old_direction=self.old_direction, old_position=self.old_position) @classmethod def from_schedule(cls, schedule: Schedule): diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index beabe4a4..4de36060 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -101,10 +101,13 @@ class TreeObsForRailEnv(ObservationBuilder): self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data[ 'malfunction'] - if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \ + # [NIMISH] WHAT IS THIS + if _agent.status in [RailAgentStatus.READY_TO_DEPART, RailAgentStatus.WAITING] and \ _agent.initial_position: - self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \ - self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1 + self.location_has_agent_ready_to_depart.setdefault(tuple(_agent.initial_position), 0) + self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] += 1 + # self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \ + # self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1 observations = super().get_many(handles) @@ -192,8 +195,10 @@ class TreeObsForRailEnv(ObservationBuilder): if handle > len(self.env.agents): print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents)) agent = self.env.agents[handle] # TODO: handle being treated as index - - if agent.status == RailAgentStatus.READY_TO_DEPART: + + if agent.status == RailAgentStatus.WAITING: + agent_virtual_position = agent.initial_position + elif agent.status == RailAgentStatus.READY_TO_DEPART: agent_virtual_position = agent.initial_position elif agent.status == RailAgentStatus.ACTIVE: agent_virtual_position = agent.position @@ -564,7 +569,9 @@ class GlobalObsForRailEnv(ObservationBuilder): def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray): agent = self.env.agents[handle] - if agent.status == RailAgentStatus.READY_TO_DEPART: + if agent.status == RailAgentStatus.WAITING: + agent_virtual_position = agent.initial_position + elif agent.status == RailAgentStatus.READY_TO_DEPART: agent_virtual_position = agent.initial_position elif agent.status == RailAgentStatus.ACTIVE: agent_virtual_position = agent.position @@ -602,7 +609,7 @@ class GlobalObsForRailEnv(ObservationBuilder): obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction'] obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed'] # fifth channel: all ready to depart on this position - if other_agent.status == RailAgentStatus.READY_TO_DEPART: + if other_agent.status == RailAgentStatus.READY_TO_DEPART or other_agent.status == RailAgentStatus.WAITING: obs_agents_state[other_agent.initial_position][4] += 1 return self.rail_obs, obs_agents_state, obs_targets diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 40646893..3cd3b714 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -126,8 +126,9 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): prediction_dict = 
{} for agent in agents: - - if agent.status == RailAgentStatus.READY_TO_DEPART: + if agent.status == RailAgentStatus.WAITING: + agent_virtual_position = agent.initial_position + elif agent.status == RailAgentStatus.READY_TO_DEPART: agent_virtual_position = agent.initial_position elif agent.status == RailAgentStatus.ACTIVE: agent_virtual_position = agent.position diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 329d8aeb..472a016d 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -121,15 +121,18 @@ class RailEnv(Environment): For Round 2, they will be passed to the constructor as arguments, to allow for more flexibility. """ - alpha = 1.0 - beta = 1.0 # Epsilon to avoid rounding errors epsilon = 0.01 - invalid_action_penalty = 0 # previously -2; GIACOMO: we decided that invalid actions will carry no penalty + # NEW : REW: Sparse Reward + alpha = 0 + beta = 0 step_penalty = -1 * alpha global_reward = 1 * beta + invalid_action_penalty = 0 # previously -2; GIACOMO: we decided that invalid actions will carry no penalty stop_penalty = 0 # penalty for stopping a moving agent start_penalty = 0 # penalty for starting a stopped agent + cancellation_factor = 1 + cancellation_time_buffer = 0 def __init__(self, width, @@ -367,21 +370,24 @@ class RailEnv(Environment): # Look at the specific schedule generator used to see where this number comes from self._max_episode_steps = schedule.max_episode_steps # NEW UPDATE THIS! - # Agent Positions Map - self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1 + # Reset distance map - basically initializing + self.distance_map.reset(self.agents, self.rail) - # Reset distance map - basically initializing - self.distance_map.reset(self.agents, self.rail) + # NEW : Time Schedule Generation + # find agent speeds (needed for max_ep_steps recalculation) + if (type(self.schedule_generator.speed_ratio_map) is dict): + config_speeds = list(self.schedule_generator.speed_ratio_map.keys()) + else: + config_speeds = [1.0] - # NEW : Time Schedule Generation - # find agent speeds (needed for max_ep_steps recalculation) - if (type(self.schedule_generator.speed_ratio_map) is dict): - config_speeds = list(self.schedule_generator.speed_ratio_map.keys()) - else: - config_speeds = [1.0] + self._max_episode_steps = schedule_time_generator(self.agents, config_speeds, self.distance_map, + self._max_episode_steps, self.np_random, temp_info=optionals) + + # Reset distance map - again (just in case if regen_schedule = False) + self.distance_map.reset(self.agents, self.rail) - self._max_episode_steps = schedule_time_generator(self.agents, config_speeds, self.distance_map, - self._max_episode_steps, self.np_random, temp_info=optionals) + # Agent Positions Map + self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1 # Reset agents to initial states self.reset_agents() @@ -488,21 +494,7 @@ class RailEnv(Environment): # If we're done, set reward and info_dict and step() is done. 
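        # [review note, not part of the patch] The comment above describes the old
        # behaviour; with this change, calling step() again once the episode has
        # ended raises an exception instead of returning the global reward.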
if self.dones["__all__"]: - self.rewards_dict = {} - info_dict = { - "action_required": {}, - "malfunction": {}, - "speed": {}, - "status": {}, - } - for i_agent, agent in enumerate(self.agents): - self.rewards_dict[i_agent] = self.global_reward - info_dict["action_required"][i_agent] = False - info_dict["malfunction"][i_agent] = 0 - info_dict["speed"][i_agent] = 0 - info_dict["status"][i_agent] = agent.status - - return self._get_observations(), self.rewards_dict, self.dones, info_dict + raise Exception("Episode is done, cannot call step()") # Reset the step rewards self.rewards_dict = dict() @@ -570,28 +562,40 @@ class RailEnv(Environment): # Fix agents that finished their malfunction such that they can perform an action in the next step self._fix_agent_after_malfunction(agent) - # Check for end of episode + set global reward to all rewards! - if have_all_agents_ended: - self.dones["__all__"] = True - self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())} - - if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps): - self.dones["__all__"] = True + # NEW : REW: (END) + if ((self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)) \ + or have_all_agents_ended : + for i_agent, agent in enumerate(self.agents): - # NEW : STEP:REW: CANCELLED check / reward (never departed) - if (agent.status == RailAgentStatus.READY_TO_DEPART): - agent.status = RailAgentStatus.CANCELLED - # NEGATIVE REWARD? - # NEW : STEP:REW: Departed but never reached - if (agent.status == RailAgentStatus.ACTIVE): - pass - # NEGATIVE REWARD? - + # agent done? (arrival_time is not None) + if (self.dones[i_agent]): + + # if agent arrived earlier or on time = 0 + # if agent arrived later = -ve reward based on how late + reward = min(agent.latest_arrival - agent.arrival_time, 0) + self.rewards_dict[i_agent] += reward + + # Agents not done (arrival_time is None) + else: + + # CANCELLED check (never departed) + if (agent.status == RailAgentStatus.READY_TO_DEPART): + reward = -1 * self.cancellation_factor * \ + (agent.get_travel_time_on_shortest_path(self.distance_map) + 0) # 0 replaced with buffer + self.rewards_dict[i_agent] += reward + + # Departed but never reached + if (agent.status == RailAgentStatus.ACTIVE): + reward = agent.get_current_delay(self.distance_map) + self.rewards_dict[i_agent] += reward + self.dones[i_agent] = True + self.dones["__all__"] = True + if self.record_steps: self.record_timestep(action_dict_) @@ -859,7 +863,7 @@ class RailEnv(Environment): def _step_agent2_cf(self, i_agent): agent = self.agents[i_agent] - # NEW : REW: no reward during WAITING... + # NEW : REW: (WAITING) no reward during WAITING... if agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED, RailAgentStatus.WAITING]: return @@ -901,23 +905,16 @@ class RailEnv(Environment): agent.direction = new_direction agent.speed_data['position_fraction'] = 0.0 - # NEW : REW: Check DONE before / after LA & Check if RUNNING before / after LA + # NEW : STEP: Check DONE before / after LA & Check if RUNNING before / after LA # has the agent reached its target? 
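            # [review note, not part of the patch] Reaching the target now records
            # agent.arrival_time and marks the agent DONE regardless of lateness;
            # the lateness penalty min(latest_arrival - arrival_time, 0) is applied
            # later, in the end-of-episode reward block of step().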
if np.equal(agent.position, agent.target).all(): - # arrived before Latest Arrival - if (self._elapsed_steps <= agent.latest_arrival): - agent.status = RailAgentStatus.DONE - self.dones[i_agent] = True - self.active_agents.remove(i_agent) - agent.moving = False - self._remove_agent_from_scene(agent) - else: # arrived after latest arrival - agent.status = RailAgentStatus.DONE - self.dones[i_agent] = True - self.active_agents.remove(i_agent) - agent.moving = False - self._remove_agent_from_scene(agent) - # NEGATIVE REWARD? + # arrived before or after Latest Arrival + agent.status = RailAgentStatus.DONE + self.dones[i_agent] = True + self.active_agents.remove(i_agent) + agent.moving = False + agent.arrival_time = self._elapsed_steps + self._remove_agent_from_scene(agent) else: # not reached its target and moving # running before Latest Arrival @@ -930,8 +927,7 @@ class RailEnv(Environment): if (self._elapsed_steps <= agent.latest_arrival): self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] else: # stopped (!move) after Latest Arrival - self.rewards_dict[i_agent] += self.step_penalty * \ - agent.speed_data['speed'] # + # NEGATIVE REWARD? per step? + self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] # + # NEGATIVE REWARD? per step? def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D): """ -- GitLab
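Review sketch (not part of the patch): the end-of-episode reward cases added in rail_env.py, pulled into one standalone Python function so the scheme can be read in isolation. Names mirror the patch, but the agent status is passed as a plain string rather than RailAgentStatus, and every number in the example calls is an illustrative assumption. The sketch threads elapsed_steps through explicitly, since the helper get_current_delay(self, elapsed_steps, distance_map) expects the elapsed step count as well as the distance map.

def end_of_episode_reward(done, status, latest_arrival, arrival_time,
                          elapsed_steps, shortest_path_travel_time,
                          cancellation_factor=1, cancellation_time_buffer=0):
    """Sketch of the reward added on top of the per-step rewards once the
    episode ends (max steps reached or all agents finished)."""
    if done:
        # Arrived: 0 if on time or early, negative by the number of steps late.
        return min(latest_arrival - arrival_time, 0)
    if status == "READY_TO_DEPART":
        # Never departed ("cancelled"): penalised by the shortest-path travel
        # time it would have needed, scaled and padded by the class constants.
        return -1 * cancellation_factor * (shortest_path_travel_time + cancellation_time_buffer)
    if status == "ACTIVE":
        # Departed but never arrived: projected delay, i.e. the time left until
        # latest_arrival minus the remaining shortest-path travel time
        # (negative once the agent can no longer make it on time).
        return (latest_arrival - elapsed_steps) - shortest_path_travel_time
    return 0


# Illustrative, assumed numbers only:
print(end_of_episode_reward(True, "DONE", latest_arrival=50, arrival_time=58,
                            elapsed_steps=60, shortest_path_travel_time=0))   # -8  (8 steps late)
print(end_of_episode_reward(False, "READY_TO_DEPART", latest_arrival=50, arrival_time=None,
                            elapsed_steps=60, shortest_path_travel_time=12))  # -12 (cancelled)
print(end_of_episode_reward(False, "ACTIVE", latest_arrival=50, arrival_time=None,
                            elapsed_steps=60, shortest_path_travel_time=12))  # -22 (still en route)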