Commit 64a36242 authored by nimishsantosh107's avatar nimishsantosh107

handle_end_rewards used in new env-step()

parent 8c357c5b
Pipeline #8449 failed with stages in 5 minutes and 23 seconds
@@ -119,3 +119,6 @@ test_save.dat
 .visualizations
+playground/
+*.pkl
+**/tmp
\ No newline at end of file
@@ -425,7 +425,7 @@ class RailEnv(Environment):
         '''
         reward = None
         # agent done? (arrival_time is not None)
-        if agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED:
+        if agent.state == TrainState.DONE:
             # if agent arrived earlier or on time = 0
             # if agent arrived later = -ve reward based on how late
             reward = min(agent.latest_arrival - agent.arrival_time, 0)
@@ -433,12 +433,12 @@ class RailEnv(Environment):
         # Agents not done (arrival_time is None)
         else:
             # CANCELLED check (never departed)
-            if (agent.status == RailAgentStatus.READY_TO_DEPART):
+            if (agent.state == TrainState.READY_TO_DEPART):
                 reward = -1 * self.cancellation_factor * \
                     (agent.get_travel_time_on_shortest_path(self.distance_map) + self.cancellation_time_buffer)

             # Departed but never reached
-            if (agent.status == RailAgentStatus.ACTIVE):
+            if (agent.state.is_on_map_state()):
                 reward = agent.get_current_delay(self._elapsed_steps, self.distance_map)

         return reward
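
The two hunks above migrate the end-of-episode reward logic from the old `RailAgentStatus` checks to the new `TrainState` API, and define three reward cases: done agents get `min(latest_arrival - arrival_time, 0)` (zero if on time, negative if late); agents that never departed are treated as cancelled and penalized for their whole planned trip plus a buffer; agents still on the map get their current delay. A minimal, self-contained sketch of that branching (`FakeAgent`, `handle_end_reward`, and the field names are illustrative stand-ins, not flatland API):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeAgent:
    state: str                   # "DONE", "READY_TO_DEPART", or on-map
    arrival_time: Optional[int]  # step the agent reached its target, else None
    latest_arrival: int          # latest acceptable arrival step
    shortest_travel_time: int    # stand-in for get_travel_time_on_shortest_path()
    current_delay: int           # stand-in for get_current_delay(); assumed non-positive

def handle_end_reward(agent, cancellation_factor=1, cancellation_time_buffer=0):
    """Mirrors the branching in _handle_end_reward, under the assumptions above."""
    if agent.state == "DONE":
        # Arrived on time or early -> 0; arrived late -> negative, scaled by lateness.
        return min(agent.latest_arrival - agent.arrival_time, 0)
    if agent.state == "READY_TO_DEPART":
        # Cancelled (never departed): penalize the planned trip plus a buffer.
        return -1 * cancellation_factor * (agent.shortest_travel_time + cancellation_time_buffer)
    # Departed but never reached the target: penalize the accumulated delay.
    return agent.current_delay

# Example: an agent that arrived 3 steps late gets a reward of -3.
late = FakeAgent("DONE", arrival_time=53, latest_arrival=50,
                 shortest_travel_time=20, current_delay=0)
assert handle_end_reward(late) == -3
```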
@@ -488,6 +488,8 @@ class RailEnv(Environment):
         self.clear_rewards_dict()
         have_all_agents_ended = True  # Boolean flag to check if all agents are done
         self.motionCheck = ac.MotionCheck()  # reset the motion check
+
+        temp_transition_data = {}
@@ -558,9 +560,11 @@ class RailEnv(Environment):
             # Remove agent if required
             if self.remove_agents_at_target and agent.state == TrainState.DONE:
                 agent.position = None

             have_all_agents_ended &= (agent.state == TrainState.DONE)

             ## Update rewards
-            # self.update_rewards(i_agent, agent, rail) # TODO : Rewards - Fix this
+            # self.update_rewards(i_agent, agent, rail) # TODO : Step Rewards

             ## Update counters (malfunction and speed)
             agent.speed_counter.update_counter(agent.state)
@@ -569,8 +573,22 @@ class RailEnv(Environment):
             # Clear old action when starting in new cell
             if agent.speed_counter.is_cell_entry:
                 agent.action_saver.clear_saved_action()

-        self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}
+        self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}  # TODO : Rewards - Remove this
+
+        if ((self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)) \
+                or have_all_agents_ended:
+            for i_agent, agent in enumerate(self.agents):
+                reward = self._handle_end_reward(agent)
+                self.rewards_dict[i_agent] += reward
+                self.dones[i_agent] = True
+            self.dones["__all__"] = True

         return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()

     def record_timestep(self, dActions):
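
With this new tail of `step()`, rewards are paid out only once per episode: `rewards_dict` is zeroed every step (flagged with a TODO for removal), and `_handle_end_reward` contributes only on the terminal step, when either the step limit is reached or all agents are done. A hedged driver-loop sketch against that contract (`run_episode`, `env`, and `choose_action` are hypothetical; the only assumption taken from the diff is the `step()` return signature):

```python
def run_episode(env, choose_action):
    # `env` is assumed to be a RailEnv-like object whose step() matches the
    # signature in the diff: (observations, rewards_dict, dones, info_dict).
    obs = env.reset()  # assumption: reset() returns the per-agent observation dict
    totals = {handle: 0 for handle in obs}  # per-agent cumulative reward
    while True:
        actions = {handle: choose_action(handle, obs[handle]) for handle in obs}
        obs, rewards, dones, info = env.step(actions)
        for handle, r in rewards.items():
            totals[handle] += r
        # Per the diff, per-step rewards are all zero, so nothing accumulates
        # until _handle_end_reward pays out on the step where dones["__all__"]
        # flips to True.
        if dones["__all__"]:
            break
    return totals
```

Under this design the per-agent totals equal exactly the end rewards computed by `_handle_end_reward`, which makes the episode return sparse but easy to audit.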