From 8c982cbac3aa304aad855a300fdc57960fdb6321 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Wed, 4 Sep 2019 16:52:54 +0200 Subject: [PATCH] #167 bugfix action_on_cellexit --- flatland/envs/rail_env.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 6f4b037e..e6c41882 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -291,6 +291,9 @@ class RailEnv(Environment): self.max_number_of_steps_broken + 1) + 1 agent.malfunction_data['malfunction'] = num_broken_steps + + + def step(self, action_dict_): self._elapsed_steps += 1 @@ -325,8 +328,6 @@ class RailEnv(Environment): agent.old_direction = agent.direction agent.old_position = agent.position - # Check if agent breaks at this step - self._agent_malfunction(agent) if self.dones[i_agent]: # this agent has already completed... continue @@ -335,6 +336,15 @@ class RailEnv(Environment): if i_agent not in action_dict: action_dict[i_agent] = RailEnvActions.DO_NOTHING + + if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions): + print('ERROR: illegal action=', action_dict[i_agent], + 'for agent with index=', i_agent, + '"DO NOTHING" will be executed instead') + action_dict[i_agent] = RailEnvActions.DO_NOTHING + + action = action_dict[i_agent] + # The train is broken if agent.malfunction_data['malfunction'] > 0: @@ -350,17 +360,14 @@ class RailEnv(Environment): # Broken agents are stopped self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed'] self.agents[i_agent].moving = False + action_dict[i_agent] = RailEnvActions.DO_NOTHING # Nothing left to do with broken agent continue - if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions): - print('ERROR: illegal action=', action_dict[i_agent], - 'for agent with index=', i_agent, - '"DO NOTHING" will be executed instead') - action_dict[i_agent] = RailEnvActions.DO_NOTHING + # Check if agent breaks at this step + self._agent_malfunction(agent) - action = action_dict[i_agent] if action == RailEnvActions.DO_NOTHING and agent.moving: # Keep moving @@ -458,7 +465,7 @@ class RailEnv(Environment): self.dones[k] = True action_required_agents = { - i: self.agents[i].speed_data['position_fraction'] <= epsilon for i in range(self.get_num_agents()) + i: self.agents[i].speed_data['position_fraction'] == 0.0 for i in range(self.get_num_agents()) } malfunction_agents = { i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents()) -- GitLab