Commit 8a3a043c authored by Dipam Chakraborty

Merge branch 'env-step-facelift' of gitlab.aicrowd.com:flatland/flatland into env-step-facelift

Parents: 3f501522, 64a36242
@@ -119,3 +119,6 @@ test_save.dat
 .visualizations
 playground/
+*.pkl
+**/tmp
\ No newline at end of file
@@ -425,7 +425,7 @@ class RailEnv(Environment):
         '''
         reward = None
         # agent done? (arrival_time is not None)
-        if agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED:
+        if agent.state == TrainState.DONE:
             # if agent arrived earlier or on time = 0
             # if agent arrived later = -ve reward based on how late
             reward = min(agent.latest_arrival - agent.arrival_time, 0)
@@ -433,12 +433,12 @@ class RailEnv(Environment):
         # Agents not done (arrival_time is None)
         else:
             # CANCELLED check (never departed)
-            if (agent.status == RailAgentStatus.READY_TO_DEPART):
+            if (agent.state == TrainState.READY_TO_DEPART):
                 reward = -1 * self.cancellation_factor * \
                     (agent.get_travel_time_on_shortest_path(self.distance_map) + self.cancellation_time_buffer)
             # Departed but never reached
-            if (agent.status == RailAgentStatus.ACTIVE):
+            if (agent.state.is_on_map_state()):
                 reward = agent.get_current_delay(self._elapsed_steps, self.distance_map)
         return reward
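
Taken together, these two hunks port _handle_end_reward from the old RailAgentStatus enum to the new TrainState API; the reward rules themselves are unchanged. Below is a minimal, self-contained sketch of the three cases, with plain numbers standing in for the agent and env attributes (the function name and arguments are illustrative stand-ins, not Flatland API):

def end_of_episode_reward(state, arrival_time, latest_arrival,
                          shortest_travel_time, current_delay,
                          cancellation_factor=1, cancellation_time_buffer=0):
    # Arrived: 0 if on time or early, negative if late.
    if state == "DONE":
        return min(latest_arrival - arrival_time, 0)
    # Never departed: cancelled, penalised as if the whole (shortest) journey remained.
    if state == "READY_TO_DEPART":
        return -1 * cancellation_factor * (shortest_travel_time + cancellation_time_buffer)
    # Departed but never reached the target: penalised by its delay at episode end.
    return current_delay

print(end_of_episode_reward("DONE", 40, 50, 30, 0))              # 0 (arrived early)
print(end_of_episode_reward("DONE", 55, 50, 30, 0))              # -5 (5 steps late)
print(end_of_episode_reward("READY_TO_DEPART", None, 50, 30, 0)) # -30 (cancelled)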
@@ -488,6 +488,8 @@ class RailEnv(Environment):
         self.clear_rewards_dict()
+        have_all_agents_ended = True  # Boolean flag to check if all agents are done
         self.motionCheck = ac.MotionCheck()  # reset the motion check
         temp_transition_data = {}
@@ -557,9 +559,11 @@ class RailEnv(Environment):
             # Remove agent is required
             if self.remove_agents_at_target and agent.state == TrainState.DONE:
                 agent.position = None
+            have_all_agents_ended &= (agent.state == TrainState.DONE)
             ## Update rewards
-            # self.update_rewards(i_agent, agent, rail) # TODO : Rewards - Fix this
+            # self.update_rewards(i_agent, agent, rail) # TODO : Step Rewards
             ## Update counters (malfunction and speed)
             agent.speed_counter.update_counter(agent.state)
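
The &= accumulator added here is a common pattern: the flag is initialised to True just after clear_rewards_dict() (previous hunk) and AND-ed with each agent's done status inside the per-agent loop, so a single unfinished agent pins it to False. A toy illustration, using stand-in state strings rather than Flatland types:

states = ["DONE", "DONE", "MOVING"]  # stand-ins for agent.state

have_all_agents_ended = True  # assume success until one agent disproves it
for state in states:
    have_all_agents_ended &= (state == "DONE")

print(have_all_agents_ended)  # False: the MOVING agent keeps the episode alive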
@@ -568,8 +572,22 @@ class RailEnv(Environment):
             # Clear old action when starting in new cell
             if agent.speed_counter.is_cell_entry:
                 agent.action_saver.clear_saved_action()
-        self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))} # TODO : Rewards - Remove this
+        self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}
+
+        if ((self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)) \
+           or have_all_agents_ended:
+            for i_agent, agent in enumerate(self.agents):
+                reward = self._handle_end_reward(agent)
+                self.rewards_dict[i_agent] += reward
+                self.dones[i_agent] = True
+            self.dones["__all__"] = True

         return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()
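
With this block in place, step() terminates the episode itself: once every agent is DONE, or _elapsed_steps reaches _max_episode_steps, each agent's end-of-episode reward is added to rewards_dict and dones["__all__"] flips to True. A hedged sketch of the resulting driver-loop contract, using a stub in place of a real RailEnv (construction of the real env is elided; StubEnv is a stand-in, not Flatland API):

class StubEnv:
    """Mimics only the termination contract of RailEnv.step()."""
    def __init__(self, max_episode_steps=3):
        self._max_episode_steps = max_episode_steps
        self._elapsed_steps = 0

    def step(self, action_dict):
        self._elapsed_steps += 1
        done = self._elapsed_steps >= self._max_episode_steps
        rewards = {0: -2 if done else 0}  # end reward lands only on the final step
        dones = {0: done, "__all__": done}
        return None, rewards, dones, {}

env = StubEnv()
dones = {"__all__": False}
while not dones["__all__"]:
    obs, rewards, dones, info = env.step({0: 0})

print(rewards, dones)  # {0: -2} {0: True, '__all__': True}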
     def record_timestep(self, dActions):
...