Skip to content
Snippets Groups Projects
Commit 45fdd081 authored by spmohanty's avatar spmohanty
Browse files

Merge dict compresisions into a single loop. Remove manipuation of agent...

Merge dict compresisions into a single loop. Remove manipuation of agent status in case of max episode step violation
parent bcbb734a
No related branches found
No related tags found
1 merge request!239Redis opts
......@@ -443,13 +443,20 @@ class RailEnv(Environment):
# If we're done, set reward and info_dict and step() is done.
if self.dones["__all__"]:
self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())}
self.rewards_dict = {}
info_dict = {
'action_required': {i: False for i in range(self.get_num_agents())},
'malfunction': {i: 0 for i in range(self.get_num_agents())},
'speed': {i: 0 for i in range(self.get_num_agents())},
'status': {i: agent.status for i, agent in enumerate(self.agents)}
"action_required" : {},
"malfunction" : {},
"speed" : {},
"status" : {},
}
for i_agent in range(self.get_num_agents()):
self.rewards_dict[i_agent] = self.global_reward
info_dict["action_required"][i_agent] = False
info_dict["malfunction"][i_agent] = 0
info_dict["speed"][i_agent] = 0
info_dict["status"][i_agent] = self.agents[i_agent].status
return self._get_observations(), self.rewards_dict, self.dones, info_dict
# Perform step on all agents
......@@ -462,9 +469,6 @@ class RailEnv(Environment):
self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())}
if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps):
self.dones["__all__"] = True
for i in range(self.get_num_agents()):
self.agents[i].status = RailAgentStatus.DONE
self.dones[i] = True
info_dict = {
'action_required': {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment