diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 75b10e8f71b45b078017f0e1ff9061249c594c75..1410556e9701b56c6b2f86e545b3baa7c240022f 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -461,39 +461,32 @@ class RailEnv(Environment): action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction) return action + + def clear_rewards_dict(self): + """ Reset the step rewards """ - def step(self, action_dict_: Dict[int, RailEnvActions]): - """ - Updates rewards for the agents at a step. - """ - self._elapsed_steps += 1 - - # If we're done, set reward and info_dict and step() is done. - if self.dones["__all__"]: # TODO: Move boilerplate to different function - self.rewards_dict = {} - info_dict = { - "action_required": {}, - "malfunction": {}, - "speed": {}, - "status": {}, - } - for i_agent, agent in enumerate(self.agents): - self.rewards_dict[i_agent] = self.global_reward - info_dict["action_required"][i_agent] = False - info_dict["malfunction"][i_agent] = 0 - info_dict["speed"][i_agent] = 0 - info_dict["status"][i_agent] = agent.status - - return self._get_observations(), self.rewards_dict, self.dones, info_dict - - # Reset the step rewards self.rewards_dict = dict() + + def get_info_dict(self): # TODO Important : Update this info_dict = { "action_required": {}, "malfunction": {}, "speed": {}, "status": {}, } + return info_dict + + def step(self, action_dict_: Dict[int, RailEnvActions]): + """ + Updates rewards for the agents at a step. + """ + self._elapsed_steps += 1 + + # Not allowed to step further once done + if self.dones["__all__"]: + raise Exception("Episode is done, cannot call step()") + + self.clear_rewards_dict() self.motionCheck = ac.MotionCheck() # reset the motion check @@ -578,7 +571,7 @@ class RailEnv(Environment): agent.action_saver.clear_saved_action() self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))} # TODO : Rewards - Remove this - return self._get_observations(), self.rewards_dict, self.dones, info_dict # TODO : Will need changes? + return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict() def record_timestep(self, dActions): ''' Record the positions and orientations of all agents in memory, in the cur_episode