Skip to content
Snippets Groups Projects
Commit b1b2c42e authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

updated malfunction behavior. Now a train can break at any point during its...

Updated malfunction behavior. A train can now break down at any point during its travel; we no longer count the number of stops. This is to simulate more complex stochastic events.
parent cc6b05f0
No related branches found
No related tags found
No related merge requests found
......@@ -238,7 +238,7 @@ class RailEnv(Environment):
agent.speed_data['position_fraction'] = 0.0
agent.malfunction_data['malfunction'] = 0
self._agent_stopped(i_agent)
self._agent_malfunction(agent)
self.num_resets += 1
self._elapsed_steps = 0
......@@ -253,29 +253,29 @@ class RailEnv(Environment):
# Return the new observation vectors for each agent
return self._get_observations()
def _agent_stopped(self, i_agent):
def _agent_malfunction(self, agent):
# Decrease counter for next event
self.agents[i_agent].malfunction_data['next_malfunction'] -= 1
agent.malfunction_data['next_malfunction'] -= 1
# Only agents that have a positive rate for malfunctions are considered
if self.agents[i_agent].malfunction_data['malfunction_rate'] > 0 >= self.agents[i_agent].malfunction_data[
# Only agents that have a positive rate for malfunctions and are not currently broken are considered
if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data[
'malfunction']:
# If counter has come to zero --> Agent has malfunction
# set next malfunction time and duration of current malfunction
if self.agents[i_agent].malfunction_data['next_malfunction'] <= 0:
if agent.malfunction_data['next_malfunction'] <= 0:
# Increase number of malfunctions
self.agents[i_agent].malfunction_data['nr_malfunctions'] += 1
agent.malfunction_data['nr_malfunctions'] += 1
# Next malfunction in number of steps
next_breakdown = int(
np.random.exponential(scale=self.agents[i_agent].malfunction_data['malfunction_rate']))
self.agents[i_agent].malfunction_data['next_malfunction'] = next_breakdown
np.random.exponential(scale=agent.malfunction_data['malfunction_rate']))
agent.malfunction_data['next_malfunction'] = next_breakdown
# Duration of current malfunction
num_broken_steps = np.random.randint(self.min_number_of_steps_broken,
self.max_number_of_steps_broken + 1) + 1
self.agents[i_agent].malfunction_data['malfunction'] = num_broken_steps
agent.malfunction_data['malfunction'] = num_broken_steps
def step(self, action_dict_):
self._elapsed_steps += 1
......@@ -306,6 +306,9 @@ class RailEnv(Environment):
agent.old_direction = agent.direction
agent.old_position = agent.position
# Check if agent breaks at this step
self._agent_malfunction(agent)
if self.dones[i_agent]: # this agent has already completed...
continue
......@@ -341,7 +344,6 @@ class RailEnv(Environment):
# Only allow halting an agent on entering new cells.
agent.moving = False
self.rewards_dict[i_agent] += stop_penalty
self._agent_stopped(i_agent)
if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
# Allow agent to start with any forward or direction action
......@@ -385,8 +387,6 @@ class RailEnv(Environment):
self.rewards_dict[i_agent] += invalid_action_penalty
self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
self.rewards_dict[i_agent] += stop_penalty
if agent.moving:
self._agent_stopped(i_agent)
agent.moving = False
continue
else:
......@@ -394,8 +394,6 @@ class RailEnv(Environment):
self.rewards_dict[i_agent] += invalid_action_penalty
self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
self.rewards_dict[i_agent] += stop_penalty
if agent.moving:
self._agent_stopped(i_agent)
agent.moving = False
continue
......@@ -416,14 +414,11 @@ class RailEnv(Environment):
agent.speed_data['position_fraction'] = 0.0
else:
# If the agent cannot move due to any reason, we set its state to not moving
if agent.moving:
self._agent_stopped(i_agent)
agent.moving = False
if np.equal(agent.position, agent.target).all():
self.dones[i_agent] = True
agent.moving = False
# Do not call self._agent_stopped, as the agent has terminated its task
else:
self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
......
......@@ -53,13 +53,13 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
def test_malfunction_process():
# Set fixed malfunction duration for this test
stochastic_data = {'prop_malfunction': 1.,
'malfunction_rate': 5,
'malfunction_rate': 1000,
'min_duration': 3,
'max_duration': 3}
np.random.seed(5)
env = RailEnv(width=14,
height=14,
env = RailEnv(width=20,
height=20,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
seed=0),
number_of_agents=2,
......@@ -82,17 +82,17 @@ def test_malfunction_process():
if step % 5 == 0:
# Stop the agent and set it to be malfunctioning
actions[0] = 4
env.agents[0].malfunction_data['malfunction'] = -1
env.agents[0].malfunction_data['next_malfunction'] = 0
agent_halts += 1
obs, all_rewards, done, _ = env.step(actions)
if env.agents[0].malfunction_data['malfunction'] > 0:
agent_malfunctioning = True
else:
agent_malfunctioning = False
obs, all_rewards, done, _ = env.step(actions)
if agent_malfunctioning:
# Check that agent is not moving while malfunctioning
assert agent_old_position == env.agents[0].position
......@@ -101,7 +101,7 @@ def test_malfunction_process():
total_down_time += env.agents[0].malfunction_data['malfunction']
# Check that the appropriate number of malfunctions is achieved
assert env.agents[0].malfunction_data['nr_malfunctions'] == 5
assert env.agents[0].malfunction_data['nr_malfunctions'] == 21
# Check that 20 stops were performed
assert agent_halts == 20
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment