diff --git a/.gitignore b/.gitignore
index 2f1f81d1ba05de2544aeb53d61d2a222b59de31f..9477ceed17b5028f29fa2eff9018eb4a2aef212d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -119,3 +119,6 @@
 test_save.dat
 .visualizations
 playground/
+
+*.pkl
+**/tmp
\ No newline at end of file
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 6c4a4083a8f96e55d55bb16caa3a4dad09608466..46876ac953535f4c49b57036045b405c6b986cc3 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -425,7 +425,7 @@ class RailEnv(Environment):
         '''
         reward = None
         # agent done? (arrival_time is not None)
-        if agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED:
+        if agent.state == TrainState.DONE:
             # if agent arrived earlier or on time = 0
             # if agent arrived later = -ve reward based on how late
             reward = min(agent.latest_arrival - agent.arrival_time, 0)
@@ -433,12 +433,12 @@ class RailEnv(Environment):
         # Agents not done (arrival_time is None)
         else:
             # CANCELLED check (never departed)
-            if (agent.status == RailAgentStatus.READY_TO_DEPART):
+            if (agent.state == TrainState.READY_TO_DEPART):
                 reward = -1 * self.cancellation_factor * \
                     (agent.get_travel_time_on_shortest_path(self.distance_map) + self.cancellation_time_buffer)
 
             # Departed but never reached
-            if (agent.status == RailAgentStatus.ACTIVE):
+            if (agent.state.is_on_map_state()):
                 reward = agent.get_current_delay(self._elapsed_steps, self.distance_map)
 
         return reward
@@ -488,6 +488,8 @@ class RailEnv(Environment):
 
         self.clear_rewards_dict()
 
+        have_all_agents_ended = True # Boolean flag to check if all agents are done
+
         self.motionCheck = ac.MotionCheck()  # reset the motion check
 
         temp_transition_data = {}
@@ -557,9 +559,11 @@
             # Remove agent is required
             if self.remove_agents_at_target and agent.state == TrainState.DONE:
                 agent.position = None
+
+            have_all_agents_ended &= (agent.state == TrainState.DONE)
 
             ## Update rewards
-            # self.update_rewards(i_agent, agent, rail) # TODO : Rewards - Fix this
+            # self.update_rewards(i_agent, agent, rail) # TODO : Step Rewards
 
             ## Update counters (malfunction and speed)
             agent.speed_counter.update_counter(agent.state)
@@ -568,8 +572,22 @@
             # Clear old action when starting in new cell
             if agent.speed_counter.is_cell_entry:
                 agent.action_saver.clear_saved_action()
+
+
+        self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}
 
-        self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))} # TODO : Rewards - Remove this
+        if ((self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)) \
+                or have_all_agents_ended:
+
+            for i_agent, agent in enumerate(self.agents):
+
+                reward = self._handle_end_reward(agent)
+                self.rewards_dict[i_agent] += reward
+
+                self.dones[i_agent] = True
+
+            self.dones["__all__"] = True
+
         return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()
 
     def record_timestep(self, dActions):
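
Note: this patch zeroes the per-step rewards and pays out only once, when the episode ends (max steps reached or all agents DONE). Below is a minimal standalone sketch of the three end-reward cases the diff implements, using a hypothetical AgentStub in place of Flatland's real EnvAgent; the field names, string states, and numbers are illustrative assumptions, not the library's API.

    # Standalone sketch of the end-of-episode reward rules in the diff above.
    # AgentStub and its fields are hypothetical stand-ins for Flatland's EnvAgent.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class AgentStub:
        state: str                            # "DONE", "READY_TO_DEPART", or "ON_MAP"
        arrival_time: Optional[int] = None    # step at which the agent reached its target
        latest_arrival: Optional[int] = None  # scheduled deadline
        shortest_path_time: int = 0           # stand-in for get_travel_time_on_shortest_path()
        current_delay: int = 0                # stand-in for get_current_delay(); assumed non-positive when late

    def end_of_episode_reward(agent: AgentStub,
                              cancellation_factor: int = 1,
                              cancellation_time_buffer: int = 0) -> int:
        if agent.state == "DONE":
            # Arrived on time or early -> 0; arrived late -> negative lateness.
            return min(agent.latest_arrival - agent.arrival_time, 0)
        if agent.state == "READY_TO_DEPART":
            # Never departed -> cancellation penalty based on the shortest-path travel time.
            return -1 * cancellation_factor * (agent.shortest_path_time + cancellation_time_buffer)
        # Departed but never reached the target -> the agent's current delay.
        return agent.current_delay

    # Arrived 3 steps late -> -3; never departed with a 20-step shortest path -> -20.
    print(end_of_episode_reward(AgentStub(state="DONE", arrival_time=53, latest_arrival=50)))  # -3
    print(end_of_episode_reward(AgentStub(state="READY_TO_DEPART", shortest_path_time=20)))    # -20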