From 8c357c5bfdba375d29293ac95191e50630a760f6 Mon Sep 17 00:00:00 2001 From: Dipam Chakraborty <dipam@aicrowd.com> Date: Thu, 9 Sep 2021 20:47:10 +0530 Subject: [PATCH] End episode once dones is completed --- flatland/envs/rail_env.py | 45 +++++++++++++++++---------------------- 1 file changed, 19 insertions(+), 26 deletions(-) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 75b10e8f..1410556e 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -461,39 +461,32 @@ class RailEnv(Environment): action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction) return action + + def clear_rewards_dict(self): + """ Reset the step rewards """ - def step(self, action_dict_: Dict[int, RailEnvActions]): - """ - Updates rewards for the agents at a step. - """ - self._elapsed_steps += 1 - - # If we're done, set reward and info_dict and step() is done. - if self.dones["__all__"]: # TODO: Move boilerplate to different function - self.rewards_dict = {} - info_dict = { - "action_required": {}, - "malfunction": {}, - "speed": {}, - "status": {}, - } - for i_agent, agent in enumerate(self.agents): - self.rewards_dict[i_agent] = self.global_reward - info_dict["action_required"][i_agent] = False - info_dict["malfunction"][i_agent] = 0 - info_dict["speed"][i_agent] = 0 - info_dict["status"][i_agent] = agent.status - - return self._get_observations(), self.rewards_dict, self.dones, info_dict - - # Reset the step rewards self.rewards_dict = dict() + + def get_info_dict(self): # TODO Important : Update this info_dict = { "action_required": {}, "malfunction": {}, "speed": {}, "status": {}, } + return info_dict + + def step(self, action_dict_: Dict[int, RailEnvActions]): + """ + Updates rewards for the agents at a step. + """ + self._elapsed_steps += 1 + + # Not allowed to step further once done + if self.dones["__all__"]: + raise Exception("Episode is done, cannot call step()") + + self.clear_rewards_dict() self.motionCheck = ac.MotionCheck() # reset the motion check @@ -578,7 +571,7 @@ class RailEnv(Environment): agent.action_saver.clear_saved_action() self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))} # TODO : Rewards - Remove this - return self._get_observations(), self.rewards_dict, self.dones, info_dict # TODO : Will need changes? + return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict() def record_timestep(self, dActions): ''' Record the positions and orientations of all agents in memory, in the cur_episode -- GitLab