Skip to content
Snippets Groups Projects
Commit b1b2c42e authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

updated malfunction behavior. Now a train can break at any point during its...

updated malfunction behavior. Now a train can break at any point during its travel. We don't count number of stops anymore. This is to simulate more complex stochastic events.
parent cc6b05f0
No related branches found
No related tags found
No related merge requests found
...@@ -238,7 +238,7 @@ class RailEnv(Environment): ...@@ -238,7 +238,7 @@ class RailEnv(Environment):
agent.speed_data['position_fraction'] = 0.0 agent.speed_data['position_fraction'] = 0.0
agent.malfunction_data['malfunction'] = 0 agent.malfunction_data['malfunction'] = 0
self._agent_stopped(i_agent) self._agent_malfunction(agent)
self.num_resets += 1 self.num_resets += 1
self._elapsed_steps = 0 self._elapsed_steps = 0
...@@ -253,29 +253,29 @@ class RailEnv(Environment): ...@@ -253,29 +253,29 @@ class RailEnv(Environment):
# Return the new observation vectors for each agent # Return the new observation vectors for each agent
return self._get_observations() return self._get_observations()
def _agent_stopped(self, i_agent): def _agent_malfunction(self, agent):
# Decrease counter for next event # Decrease counter for next event
self.agents[i_agent].malfunction_data['next_malfunction'] -= 1 agent.malfunction_data['next_malfunction'] -= 1
# Only agents that have a positive rate for malfunctions are considered # Only agents that have a positive rate for malfunctions and are not currently broken are considered
if self.agents[i_agent].malfunction_data['malfunction_rate'] > 0 >= self.agents[i_agent].malfunction_data[ if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data[
'malfunction']: 'malfunction']:
# If counter has come to zero --> Agent has malfunction # If counter has come to zero --> Agent has malfunction
# set next malfunction time and duration of current malfunction # set next malfunction time and duration of current malfunction
if self.agents[i_agent].malfunction_data['next_malfunction'] <= 0: if agent.malfunction_data['next_malfunction'] <= 0:
# Increase number of malfunctions # Increase number of malfunctions
self.agents[i_agent].malfunction_data['nr_malfunctions'] += 1 agent.malfunction_data['nr_malfunctions'] += 1
# Next malfunction in number of stops # Next malfunction in number of stops
next_breakdown = int( next_breakdown = int(
np.random.exponential(scale=self.agents[i_agent].malfunction_data['malfunction_rate'])) np.random.exponential(scale=agent.malfunction_data['malfunction_rate']))
self.agents[i_agent].malfunction_data['next_malfunction'] = next_breakdown agent.malfunction_data['next_malfunction'] = next_breakdown
# Duration of current malfunction # Duration of current malfunction
num_broken_steps = np.random.randint(self.min_number_of_steps_broken, num_broken_steps = np.random.randint(self.min_number_of_steps_broken,
self.max_number_of_steps_broken + 1) + 1 self.max_number_of_steps_broken + 1) + 1
self.agents[i_agent].malfunction_data['malfunction'] = num_broken_steps agent.malfunction_data['malfunction'] = num_broken_steps
def step(self, action_dict_): def step(self, action_dict_):
self._elapsed_steps += 1 self._elapsed_steps += 1
...@@ -306,6 +306,9 @@ class RailEnv(Environment): ...@@ -306,6 +306,9 @@ class RailEnv(Environment):
agent.old_direction = agent.direction agent.old_direction = agent.direction
agent.old_position = agent.position agent.old_position = agent.position
# Check if agent breaks at this step
self._agent_malfunction(agent)
if self.dones[i_agent]: # this agent has already completed... if self.dones[i_agent]: # this agent has already completed...
continue continue
...@@ -341,7 +344,6 @@ class RailEnv(Environment): ...@@ -341,7 +344,6 @@ class RailEnv(Environment):
# Only allow halting an agent on entering new cells. # Only allow halting an agent on entering new cells.
agent.moving = False agent.moving = False
self.rewards_dict[i_agent] += stop_penalty self.rewards_dict[i_agent] += stop_penalty
self._agent_stopped(i_agent)
if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING): if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
# Allow agent to start with any forward or direction action # Allow agent to start with any forward or direction action
...@@ -385,8 +387,6 @@ class RailEnv(Environment): ...@@ -385,8 +387,6 @@ class RailEnv(Environment):
self.rewards_dict[i_agent] += invalid_action_penalty self.rewards_dict[i_agent] += invalid_action_penalty
self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed'] self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
self.rewards_dict[i_agent] += stop_penalty self.rewards_dict[i_agent] += stop_penalty
if agent.moving:
self._agent_stopped(i_agent)
agent.moving = False agent.moving = False
continue continue
else: else:
...@@ -394,8 +394,6 @@ class RailEnv(Environment): ...@@ -394,8 +394,6 @@ class RailEnv(Environment):
self.rewards_dict[i_agent] += invalid_action_penalty self.rewards_dict[i_agent] += invalid_action_penalty
self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed'] self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
self.rewards_dict[i_agent] += stop_penalty self.rewards_dict[i_agent] += stop_penalty
if agent.moving:
self._agent_stopped(i_agent)
agent.moving = False agent.moving = False
continue continue
...@@ -416,14 +414,11 @@ class RailEnv(Environment): ...@@ -416,14 +414,11 @@ class RailEnv(Environment):
agent.speed_data['position_fraction'] = 0.0 agent.speed_data['position_fraction'] = 0.0
else: else:
# If the agent cannot move due to any reason, we set its state to not moving # If the agent cannot move due to any reason, we set its state to not moving
if agent.moving:
self._agent_stopped(i_agent)
agent.moving = False agent.moving = False
if np.equal(agent.position, agent.target).all(): if np.equal(agent.position, agent.target).all():
self.dones[i_agent] = True self.dones[i_agent] = True
agent.moving = False agent.moving = False
# Do not call self._agent_stopped, as the agent has terminated its task
else: else:
self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed'] self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
......
...@@ -53,13 +53,13 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): ...@@ -53,13 +53,13 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
def test_malfunction_process(): def test_malfunction_process():
# Set fixed malfunction duration for this test # Set fixed malfunction duration for this test
stochastic_data = {'prop_malfunction': 1., stochastic_data = {'prop_malfunction': 1.,
'malfunction_rate': 5, 'malfunction_rate': 1000,
'min_duration': 3, 'min_duration': 3,
'max_duration': 3} 'max_duration': 3}
np.random.seed(5) np.random.seed(5)
env = RailEnv(width=14, env = RailEnv(width=20,
height=14, height=20,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
seed=0), seed=0),
number_of_agents=2, number_of_agents=2,
...@@ -82,17 +82,17 @@ def test_malfunction_process(): ...@@ -82,17 +82,17 @@ def test_malfunction_process():
if step % 5 == 0: if step % 5 == 0:
# Stop the agent and set it to be malfunctioning # Stop the agent and set it to be malfunctioning
actions[0] = 4 env.agents[0].malfunction_data['malfunction'] = -1
env.agents[0].malfunction_data['next_malfunction'] = 0 env.agents[0].malfunction_data['next_malfunction'] = 0
agent_halts += 1 agent_halts += 1
obs, all_rewards, done, _ = env.step(actions)
if env.agents[0].malfunction_data['malfunction'] > 0: if env.agents[0].malfunction_data['malfunction'] > 0:
agent_malfunctioning = True agent_malfunctioning = True
else: else:
agent_malfunctioning = False agent_malfunctioning = False
obs, all_rewards, done, _ = env.step(actions)
if agent_malfunctioning: if agent_malfunctioning:
# Check that agent is not moving while malfunctioning # Check that agent is not moving while malfunctioning
assert agent_old_position == env.agents[0].position assert agent_old_position == env.agents[0].position
...@@ -101,7 +101,7 @@ def test_malfunction_process(): ...@@ -101,7 +101,7 @@ def test_malfunction_process():
total_down_time += env.agents[0].malfunction_data['malfunction'] total_down_time += env.agents[0].malfunction_data['malfunction']
# Check that the appropriate number of malfunctions is achieved # Check that the appropriate number of malfunctions is achieved
assert env.agents[0].malfunction_data['nr_malfunctions'] == 5 assert env.agents[0].malfunction_data['nr_malfunctions'] == 21
# Check that 20 stops where performed # Check that 20 stops where performed
assert agent_halts == 20 assert agent_halts == 20
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment