Skip to content
Snippets Groups Projects
Commit 8c982cba authored by u214892's avatar u214892
Browse files

#167 bugfix action_on_cellexit

parent 4ad3eae7
No related branches found
No related tags found
No related merge requests found
...@@ -291,6 +291,9 @@ class RailEnv(Environment): ...@@ -291,6 +291,9 @@ class RailEnv(Environment):
self.max_number_of_steps_broken + 1) + 1 self.max_number_of_steps_broken + 1) + 1
agent.malfunction_data['malfunction'] = num_broken_steps agent.malfunction_data['malfunction'] = num_broken_steps
def step(self, action_dict_): def step(self, action_dict_):
self._elapsed_steps += 1 self._elapsed_steps += 1
...@@ -325,8 +328,6 @@ class RailEnv(Environment): ...@@ -325,8 +328,6 @@ class RailEnv(Environment):
agent.old_direction = agent.direction agent.old_direction = agent.direction
agent.old_position = agent.position agent.old_position = agent.position
# Check if agent breaks at this step
self._agent_malfunction(agent)
if self.dones[i_agent]: # this agent has already completed... if self.dones[i_agent]: # this agent has already completed...
continue continue
...@@ -335,6 +336,15 @@ class RailEnv(Environment): ...@@ -335,6 +336,15 @@ class RailEnv(Environment):
if i_agent not in action_dict: if i_agent not in action_dict:
action_dict[i_agent] = RailEnvActions.DO_NOTHING action_dict[i_agent] = RailEnvActions.DO_NOTHING
if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions):
print('ERROR: illegal action=', action_dict[i_agent],
'for agent with index=', i_agent,
'"DO NOTHING" will be executed instead')
action_dict[i_agent] = RailEnvActions.DO_NOTHING
action = action_dict[i_agent]
# The train is broken # The train is broken
if agent.malfunction_data['malfunction'] > 0: if agent.malfunction_data['malfunction'] > 0:
...@@ -350,17 +360,14 @@ class RailEnv(Environment): ...@@ -350,17 +360,14 @@ class RailEnv(Environment):
# Broken agents are stopped # Broken agents are stopped
self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed'] self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
self.agents[i_agent].moving = False self.agents[i_agent].moving = False
action_dict[i_agent] = RailEnvActions.DO_NOTHING
# Nothing left to do with broken agent # Nothing left to do with broken agent
continue continue
if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions): # Check if agent breaks at this step
print('ERROR: illegal action=', action_dict[i_agent], self._agent_malfunction(agent)
'for agent with index=', i_agent,
'"DO NOTHING" will be executed instead')
action_dict[i_agent] = RailEnvActions.DO_NOTHING
action = action_dict[i_agent]
if action == RailEnvActions.DO_NOTHING and agent.moving: if action == RailEnvActions.DO_NOTHING and agent.moving:
# Keep moving # Keep moving
...@@ -458,7 +465,7 @@ class RailEnv(Environment): ...@@ -458,7 +465,7 @@ class RailEnv(Environment):
self.dones[k] = True self.dones[k] = True
action_required_agents = { action_required_agents = {
i: self.agents[i].speed_data['position_fraction'] <= epsilon for i in range(self.get_num_agents()) i: self.agents[i].speed_data['position_fraction'] == 0.0 for i in range(self.get_num_agents())
} }
malfunction_agents = { malfunction_agents = {
i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents()) i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment