Commit 8c357c5b authored by Dipam Chakraborty
Browse files

End episode once dones is completed

parent fc7e752c
Pipeline #8447 canceled with stages
......@@ -461,39 +461,32 @@ class RailEnv(Environment):
action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
return action
def clear_rewards_dict(self):
""" Reset the step rewards """
def step(self, action_dict_: Dict[int, RailEnvActions]):
"""
Updates rewards for the agents at a step.
"""
self._elapsed_steps += 1
# If we're done, set reward and info_dict and step() is done.
if self.dones["__all__"]: # TODO: Move boilerplate to different function
self.rewards_dict = {}
info_dict = {
"action_required": {},
"malfunction": {},
"speed": {},
"status": {},
}
for i_agent, agent in enumerate(self.agents):
self.rewards_dict[i_agent] = self.global_reward
info_dict["action_required"][i_agent] = False
info_dict["malfunction"][i_agent] = 0
info_dict["speed"][i_agent] = 0
info_dict["status"][i_agent] = agent.status
return self._get_observations(), self.rewards_dict, self.dones, info_dict
# Reset the step rewards
self.rewards_dict = dict()
def get_info_dict(self):  # TODO Important : Update this
    """
    Build the info dictionary that step() returns as its fourth value.

    Returns a newly created dict with the keys "action_required",
    "malfunction", "speed" and "status", each mapped to its own empty
    dict. No per-agent entries are filled in here yet (see the TODO on
    the signature line).
    """
    info_dict = {
        "action_required": {},
        "malfunction": {},
        "speed": {},
        "status": {},
    }
    return info_dict
def step(self, action_dict_: Dict[int, RailEnvActions]):
"""
Updates rewards for the agents at a step.
"""
self._elapsed_steps += 1
# Not allowed to step further once done
if self.dones["__all__"]:
raise Exception("Episode is done, cannot call step()")
self.clear_rewards_dict()
self.motionCheck = ac.MotionCheck() # reset the motion check
......@@ -578,7 +571,7 @@ class RailEnv(Environment):
agent.action_saver.clear_saved_action()
self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))} # TODO : Rewards - Remove this
return self._get_observations(), self.rewards_dict, self.dones, info_dict # TODO : Will need changes?
return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()
def record_timestep(self, dActions):
''' Record the positions and orientations of all agents in memory, in the cur_episode
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment