Commit 172ac2c0 authored by nimishsantosh107

end reward handling - decoupled

parent 0eebf660
@@ -15,7 +15,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.grid.grid_utils import IntVector2D
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
+from flatland.envs.agent_utils import Agent, EnvAgent, RailAgentStatus
 from flatland.envs.distance_map import DistanceMap
 from flatland.envs.rail_env_action import RailEnvActions
@@ -475,6 +475,34 @@ class RailEnv(Environment):
         return

+    def _handle_end_reward(self, agent: EnvAgent) -> int:
+        '''
+        Handles end-of-episode reward for a particular agent.
+
+        Parameters
+        ----------
+        agent : EnvAgent
+        '''
+        reward = None
+        # agent done? (arrival_time is not None)
+        if agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED:
+            # if agent arrived earlier or on time = 0
+            # if agent arrived later = -ve reward based on how late
+            reward = min(agent.latest_arrival - agent.arrival_time, 0)
+        # Agents not done (arrival_time is None)
+        else:
+            # CANCELLED check (never departed)
+            if (agent.status == RailAgentStatus.READY_TO_DEPART):
+                reward = -1 * self.cancellation_factor * \
+                    (agent.get_travel_time_on_shortest_path(self.distance_map) + 0)  # 0 replaced with buffer
+            # Departed but never reached
+            if (agent.status == RailAgentStatus.ACTIVE):
+                reward = agent.get_current_delay(self._elapsed_steps, self.distance_map)
+
+        return reward
+
     def step(self, action_dict_: Dict[int, RailEnvActions]):
         """
         Updates rewards for the agents at a step.
@@ -564,26 +592,8 @@ class RailEnv(Environment):
             for i_agent, agent in enumerate(self.agents):
-                # agent done? (arrival_time is not None)
-                if agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED:
-                    # if agent arrived earlier or on time = 0
-                    # if agent arrived later = -ve reward based on how late
-                    reward = min(agent.latest_arrival - agent.arrival_time, 0)
-                    self.rewards_dict[i_agent] += reward
-                # Agents not done (arrival_time is None)
-                else:
-                    # CANCELLED check (never departed)
-                    if (agent.status == RailAgentStatus.READY_TO_DEPART):
-                        reward = -1 * self.cancellation_factor * \
-                            (agent.get_travel_time_on_shortest_path(self.distance_map) + 0)  # 0 replaced with buffer
-                        self.rewards_dict[i_agent] += reward
-                    # Departed but never reached
-                    if (agent.status == RailAgentStatus.ACTIVE):
-                        reward = agent.get_current_delay(self._elapsed_steps, self.distance_map)
-                        self.rewards_dict[i_agent] += reward
+                reward = self._handle_end_reward(agent)
+                self.rewards_dict[i_agent] += reward
                 self.dones[i_agent] = True
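For readers skimming the diff, the sketch below restates the three end-of-episode reward cases that the new _handle_end_reward method encapsulates, as a standalone, runnable example. It is a minimal illustration only: the Status enum, the end_reward helper, its parameters, and the numbers are assumptions made up for this example and are not part of flatland; only the reward rules themselves mirror the diff above.

# Minimal sketch of the end-of-episode reward rules (illustrative, not flatland code).
from enum import IntEnum


class Status(IntEnum):  # stands in for RailAgentStatus in this sketch
    READY_TO_DEPART = 0
    ACTIVE = 1
    DONE = 2


def end_reward(status, latest_arrival=None, arrival_time=None,
               shortest_path_time=None, current_delay=None,
               cancellation_factor=1):
    if status == Status.DONE:
        # arrived: 0 if on time or early, negative if late
        return min(latest_arrival - arrival_time, 0)
    if status == Status.READY_TO_DEPART:
        # never departed: cancellation penalty scaled by shortest-path travel time
        return -1 * cancellation_factor * shortest_path_time
    # departed but never reached the target: reward is the agent's current delay
    return current_delay


# arrived 3 steps after latest_arrival -> -3
print(end_reward(Status.DONE, latest_arrival=20, arrival_time=23))
# never departed, shortest path would take 10 steps -> -10
print(end_reward(Status.READY_TO_DEPART, shortest_path_time=10))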