diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index ab0e14879a354993331519315d7f1fbc423f2af2..4cd1d2d4683e9e945318c914c1e90ea7f2abeb52 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -15,7 +15,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid_utils import IntVector2D from flatland.core.transition_map import GridTransitionMap -from flatland.envs.agent_utils import EnvAgent, RailAgentStatus +from flatland.envs.agent_utils import Agent, EnvAgent, RailAgentStatus from flatland.envs.distance_map import DistanceMap from flatland.envs.rail_env_action import RailEnvActions @@ -475,6 +475,34 @@ class RailEnv(Environment): return + def _handle_end_reward(self, agent: EnvAgent) -> int: + ''' + Handles end-of-episode reward for a particular agent. + + Parameters + ---------- + agent : EnvAgent + ''' + reward = None + # agent done? (arrival_time is not None) + if agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED: + # if agent arrived earlier or on time = 0 + # if agent arrived later = -ve reward based on how late + reward = min(agent.latest_arrival - agent.arrival_time, 0) + + # Agents not done (arrival_time is None) + else: + # CANCELLED check (never departed) + if (agent.status == RailAgentStatus.READY_TO_DEPART): + reward = -1 * self.cancellation_factor * \ + (agent.get_travel_time_on_shortest_path(self.distance_map) + 0) # 0 replaced with buffer + + # Departed but never reached + if (agent.status == RailAgentStatus.ACTIVE): + reward = agent.get_current_delay(self._elapsed_steps, self.distance_map) + + return reward + def step(self, action_dict_: Dict[int, RailEnvActions]): """ Updates rewards for the agents at a step. @@ -564,26 +592,8 @@ class RailEnv(Environment): for i_agent, agent in enumerate(self.agents): # agent done? (arrival_time is not None) - if agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED: - - # if agent arrived earlier or on time = 0 - # if agent arrived later = -ve reward based on how late - reward = min(agent.latest_arrival - agent.arrival_time, 0) - self.rewards_dict[i_agent] += reward - - # Agents not done (arrival_time is None) - else: - - # CANCELLED check (never departed) - if (agent.status == RailAgentStatus.READY_TO_DEPART): - reward = -1 * self.cancellation_factor * \ - (agent.get_travel_time_on_shortest_path(self.distance_map) + 0) # 0 replaced with buffer - self.rewards_dict[i_agent] += reward - - # Departed but never reached - if (agent.status == RailAgentStatus.ACTIVE): - reward = agent.get_current_delay(self._elapsed_steps, self.distance_map) - self.rewards_dict[i_agent] += reward + reward = self._handle_end_reward(agent) + self.rewards_dict[i_agent] += reward self.dones[i_agent] = True