From 45fdd081feb848f5da8940ecfe352f3d99de9cf6 Mon Sep 17 00:00:00 2001
From: "S.P. Mohanty" <spmohanty91@gmail.com>
Date: Thu, 24 Oct 2019 18:07:09 +0200
Subject: [PATCH] Merge dict comprehensions into a single loop. Remove
 manipulation of agent status in case of max episode step violation

---
 flatland/envs/rail_env.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 70312ecb..3bbc1ab0 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -443,13 +443,20 @@ class RailEnv(Environment):
 
         # If we're done, set reward and info_dict and step() is done.
         if self.dones["__all__"]:
-            self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())}
+            self.rewards_dict = {}
             info_dict = {
-                'action_required': {i: False for i in range(self.get_num_agents())},
-                'malfunction': {i: 0 for i in range(self.get_num_agents())},
-                'speed': {i: 0 for i in range(self.get_num_agents())},
-                'status': {i: agent.status for i, agent in enumerate(self.agents)}
+                "action_required": {},
+                "malfunction": {},
+                "speed": {},
+                "status": {},
             }
+            for i_agent in range(self.get_num_agents()):
+                self.rewards_dict[i_agent] = self.global_reward
+                info_dict["action_required"][i_agent] = False
+                info_dict["malfunction"][i_agent] = 0
+                info_dict["speed"][i_agent] = 0
+                info_dict["status"][i_agent] = self.agents[i_agent].status
+
             return self._get_observations(), self.rewards_dict, self.dones, info_dict
 
         # Perform step on all agents
@@ -462,9 +469,6 @@ class RailEnv(Environment):
             self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())}
         if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps):
             self.dones["__all__"] = True
-            for i in range(self.get_num_agents()):
-                self.agents[i].status = RailAgentStatus.DONE
-                self.dones[i] = True
 
         info_dict = {
             'action_required': {
--
GitLab
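
For reference, a minimal standalone Python sketch of the pattern applied in the first hunk: every per-agent dictionary is filled in one pass over the agents instead of one dict comprehension per field. The names num_agents, global_reward and agent_status are simplified stand-ins for the corresponding RailEnv attributes, not the library's API.

# Sketch of the single-loop pattern from the first hunk.
# num_agents, global_reward and agent_status are illustrative stand-ins
# for self.get_num_agents(), self.global_reward and agent.status.
num_agents = 3
global_reward = 1
agent_status = ["ACTIVE", "ACTIVE", "DONE"]

rewards_dict = {}
info_dict = {"action_required": {}, "malfunction": {}, "speed": {}, "status": {}}
for i_agent in range(num_agents):
    rewards_dict[i_agent] = global_reward
    info_dict["action_required"][i_agent] = False
    info_dict["malfunction"][i_agent] = 0
    info_dict["speed"][i_agent] = 0
    info_dict["status"][i_agent] = agent_status[i_agent]

print(rewards_dict)          # {0: 1, 1: 1, 2: 1}
print(info_dict["status"])   # {0: 'ACTIVE', 1: 'ACTIVE', 2: 'DONE'}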