Skip to content
Snippets Groups Projects
Commit 45fdd081 authored by spmohanty's avatar spmohanty
Browse files

Merge dict compresisions into a single loop. Remove manipuation of agent...

Merge dict compresisions into a single loop. Remove manipuation of agent status in case of max episode step violation
parent bcbb734a
No related branches found
No related tags found
No related merge requests found
...@@ -443,13 +443,20 @@ class RailEnv(Environment): ...@@ -443,13 +443,20 @@ class RailEnv(Environment):
# If we're done, set reward and info_dict and step() is done. # If we're done, set reward and info_dict and step() is done.
if self.dones["__all__"]: if self.dones["__all__"]:
self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())} self.rewards_dict = {}
info_dict = { info_dict = {
'action_required': {i: False for i in range(self.get_num_agents())}, "action_required" : {},
'malfunction': {i: 0 for i in range(self.get_num_agents())}, "malfunction" : {},
'speed': {i: 0 for i in range(self.get_num_agents())}, "speed" : {},
'status': {i: agent.status for i, agent in enumerate(self.agents)} "status" : {},
} }
for i_agent in range(self.get_num_agents()):
self.rewards_dict[i_agent] = self.global_reward
info_dict["action_required"][i_agent] = False
info_dict["malfunction"][i_agent] = 0
info_dict["speed"][i_agent] = 0
info_dict["status"][i_agent] = self.agents[i_agent].status
return self._get_observations(), self.rewards_dict, self.dones, info_dict return self._get_observations(), self.rewards_dict, self.dones, info_dict
# Perform step on all agents # Perform step on all agents
...@@ -462,9 +469,6 @@ class RailEnv(Environment): ...@@ -462,9 +469,6 @@ class RailEnv(Environment):
self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())} self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())}
if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps): if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps):
self.dones["__all__"] = True self.dones["__all__"] = True
for i in range(self.get_num_agents()):
self.agents[i].status = RailAgentStatus.DONE
self.dones[i] = True
info_dict = { info_dict = {
'action_required': { 'action_required': {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment