Commit b1b2c42e authored by Erik Nygren 🚅

Updated malfunction behavior. Now a train can break at any point during its travel. We no longer count the number of stops. This is to simulate more complex stochastic events.
parent cc6b05f0
Pipeline #1723 passed in 52 minutes and 13 seconds
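For reference, the per-step malfunction update introduced in the diff below boils down to the following logic. This is a simplified, standalone sketch: update_malfunction is a hypothetical helper operating on a plain dict shaped like agent.malfunction_data, not the library's actual method, and the min/max broken-step arguments are illustrative defaults.

    import numpy as np

    def update_malfunction(malfunction_data,
                           min_number_of_steps_broken=3,
                           max_number_of_steps_broken=3):
        # Decrease counter for the next event (called every step, not only on stops)
        malfunction_data['next_malfunction'] -= 1

        # Only agents with a positive malfunction rate that are not currently broken
        if malfunction_data['malfunction_rate'] > 0 >= malfunction_data['malfunction']:
            if malfunction_data['next_malfunction'] <= 0:
                malfunction_data['nr_malfunctions'] += 1
                # Sample the gap (in steps) until the next breakdown
                malfunction_data['next_malfunction'] = int(
                    np.random.exponential(scale=malfunction_data['malfunction_rate']))
                # Sample the duration of the current breakdown
                malfunction_data['malfunction'] = np.random.randint(
                    min_number_of_steps_broken, max_number_of_steps_broken + 1) + 1
        return malfunction_data

Because the counter is now decremented on every environment step, a breakdown can occur anywhere along the route rather than only when the train halts.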
@@ -238,7 +238,7 @@ class RailEnv(Environment):
             agent.speed_data['position_fraction'] = 0.0
             agent.malfunction_data['malfunction'] = 0
-            self._agent_stopped(i_agent)
+            self._agent_malfunction(agent)
         self.num_resets += 1
         self._elapsed_steps = 0
@@ -253,29 +253,29 @@ class RailEnv(Environment):
         # Return the new observation vectors for each agent
         return self._get_observations()
-    def _agent_stopped(self, i_agent):
+    def _agent_malfunction(self, agent):
         # Decrease counter for next event
-        self.agents[i_agent].malfunction_data['next_malfunction'] -= 1
+        agent.malfunction_data['next_malfunction'] -= 1
-        # Only agents that have a positive rate for malfunctions are considered
-        if self.agents[i_agent].malfunction_data['malfunction_rate'] > 0 >= self.agents[i_agent].malfunction_data[
+        # Only agents that have a positive rate for malfunctions and are not currently broken are considered
+        if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data[
             'malfunction']:
             # If counter has come to zero --> Agent has malfunction
             # set next malfunction time and duration of current malfunction
-            if self.agents[i_agent].malfunction_data['next_malfunction'] <= 0:
+            if agent.malfunction_data['next_malfunction'] <= 0:
                 # Increase number of malfunctions
-                self.agents[i_agent].malfunction_data['nr_malfunctions'] += 1
+                agent.malfunction_data['nr_malfunctions'] += 1
                 # Next malfunction in number of stops
                 next_breakdown = int(
-                    np.random.exponential(scale=self.agents[i_agent].malfunction_data['malfunction_rate']))
-                self.agents[i_agent].malfunction_data['next_malfunction'] = next_breakdown
+                    np.random.exponential(scale=agent.malfunction_data['malfunction_rate']))
+                agent.malfunction_data['next_malfunction'] = next_breakdown
                 # Duration of current malfunction
                 num_broken_steps = np.random.randint(self.min_number_of_steps_broken,
                                                      self.max_number_of_steps_broken + 1) + 1
-                self.agents[i_agent].malfunction_data['malfunction'] = num_broken_steps
+                agent.malfunction_data['malfunction'] = num_broken_steps
     def step(self, action_dict_):
         self._elapsed_steps += 1
@@ -306,6 +306,9 @@ class RailEnv(Environment):
             agent.old_direction = agent.direction
             agent.old_position = agent.position
+            # Check if agent breaks at this step
+            self._agent_malfunction(agent)
             if self.dones[i_agent]:  # this agent has already completed...
                 continue
@@ -341,7 +344,6 @@ class RailEnv(Environment):
                 # Only allow halting an agent on entering new cells.
                 agent.moving = False
                 self.rewards_dict[i_agent] += stop_penalty
-                self._agent_stopped(i_agent)
             if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
                 # Allow agent to start with any forward or direction action
@@ -385,8 +387,6 @@ class RailEnv(Environment):
                     self.rewards_dict[i_agent] += invalid_action_penalty
                     self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
                     self.rewards_dict[i_agent] += stop_penalty
-                    if agent.moving:
-                        self._agent_stopped(i_agent)
                     agent.moving = False
                     continue
                 else:
@@ -394,8 +394,6 @@ class RailEnv(Environment):
                     self.rewards_dict[i_agent] += invalid_action_penalty
                     self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
                     self.rewards_dict[i_agent] += stop_penalty
-                    if agent.moving:
-                        self._agent_stopped(i_agent)
                     agent.moving = False
                     continue
@@ -416,14 +414,11 @@ class RailEnv(Environment):
                     agent.speed_data['position_fraction'] = 0.0
                 else:
                     # If the agent cannot move due to any reason, we set its state to not moving
-                    if agent.moving:
-                        self._agent_stopped(i_agent)
                     agent.moving = False
             if np.equal(agent.position, agent.target).all():
                 self.dones[i_agent] = True
                 agent.moving = False
-                # Do not call self._agent_stopped, as the agent has terminated its task
             else:
                 self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
@@ -53,13 +53,13 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
 def test_malfunction_process():
     # Set fixed malfunction duration for this test
     stochastic_data = {'prop_malfunction': 1.,
-                       'malfunction_rate': 5,
+                       'malfunction_rate': 1000,
                        'min_duration': 3,
                        'max_duration': 3}
     np.random.seed(5)
-    env = RailEnv(width=14,
-                  height=14,
+    env = RailEnv(width=20,
+                  height=20,
                   rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
                                                         seed=0),
                   number_of_agents=2,
@@ -82,17 +82,17 @@ def test_malfunction_process():
         if step % 5 == 0:
             # Stop the agent and set it to be malfunctioning
-            actions[0] = 4
+            env.agents[0].malfunction_data['malfunction'] = -1
+            env.agents[0].malfunction_data['next_malfunction'] = 0
             agent_halts += 1
-        obs, all_rewards, done, _ = env.step(actions)
         if env.agents[0].malfunction_data['malfunction'] > 0:
             agent_malfunctioning = True
         else:
             agent_malfunctioning = False
+        obs, all_rewards, done, _ = env.step(actions)
         if agent_malfunctioning:
             # Check that agent is not moving while malfunctioning
             assert agent_old_position == env.agents[0].position
@@ -101,7 +101,7 @@ def test_malfunction_process():
         total_down_time += env.agents[0].malfunction_data['malfunction']
     # Check that the appropriate number of malfunctions is achieved
-    assert env.agents[0].malfunction_data['nr_malfunctions'] == 5
+    assert env.agents[0].malfunction_data['nr_malfunctions'] == 21
     # Check that 20 stops where performed
     assert agent_halts == 20
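The updated test exercises the new behavior by forcing agent 0's malfunction_data and stepping the environment. A rough usage sketch along the same lines is shown below; the import paths, the keyword used to pass stochastic_data to RailEnv, and the forward action value are not shown in this diff and are assumptions, so treat this as illustrative only.

    import numpy as np

    from flatland.envs.rail_env import RailEnv                   # assumed import path
    from flatland.envs.generators import complex_rail_generator  # assumed import path

    # Malfunction parameters as used in the updated test
    stochastic_data = {'prop_malfunction': 1.,
                       'malfunction_rate': 1000,
                       'min_duration': 3,
                       'max_duration': 3}

    env = RailEnv(width=20,
                  height=20,
                  rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1,
                                                        min_dist=5, max_dist=99999, seed=0),
                  number_of_agents=2,
                  stochastic_data=stochastic_data)  # assumed keyword argument

    obs = env.reset()
    for step in range(100):
        # Ask both agents to move forward; the environment itself holds a
        # malfunctioning agent in place, so no explicit stop action is needed.
        actions = {0: 2, 1: 2}
        obs, all_rewards, done, _ = env.step(actions)

        malfunction = env.agents[0].malfunction_data['malfunction']
        if malfunction > 0:
            print("step {}: agent 0 is broken for {} more steps".format(step, malfunction))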