Skip to content
Snippets Groups Projects
Commit ed12d9c1 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

Refactoring: Intdocued fix_agents function.

At each step we first check for possible malfunctions.
THen we apply them and do all the agent steps
At the end we look at what agents should be fixed. This lead so to:

Current malfunction value says how many step calls are necessary till agent can move again

e.g. malfunction = 2, we have to call step() twice untill the agent can move again.
parent 7c974837
No related branches found
No related tags found
No related merge requests found
......@@ -367,15 +367,16 @@ class RailEnv(Environment):
for i_agent in range(self.get_num_agents()):
self.set_agent_active(i_agent)
# See if agents are already broken
# Induce malfunctions
self._malfunction(self.mean_malfunction_rate)
for i_agent, agent in enumerate(self.agents):
initial_malfunction = self._agent_malfunction(i_agent)
if initial_malfunction:
for agent in self.agents:
if agent.malfunction_data["malfunction"] > 0:
agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING
# Fix agents that finished their malfunciton
self._fix_agents()
self.num_resets += 1
self._elapsed_steps = 0
......@@ -398,26 +399,27 @@ class RailEnv(Environment):
observation_dict: Dict = self._get_observations()
return observation_dict, info_dict
def _agent_malfunction(self, i_agent) -> bool:
def _fix_agents(self):
"""
Returns true if the agent enters into malfunction. (False, if not broken down or already broken down before).
Updates agent malfunction variables and fixes broken agents
"""
agent = self.agents[i_agent]
for agent in self.agents:
# Reduce number of malfunction steps left
if agent.malfunction_data['malfunction'] > 0:
agent.malfunction_data['malfunction'] -= 1
return True
# Ignore agents that OK
if agent.malfunction_data['fixed']:
continue
# Ignore agents that OK
if agent.malfunction_data['fixed']:
return False
# Reduce number of malfunction steps left
if agent.malfunction_data['malfunction'] > 1:
agent.malfunction_data['malfunction'] -= 1
continue
# Restart agents at the end of their malfunction
agent.malfunction_data['fixed'] = True
if 'moving_before_malfunction' in agent.malfunction_data:
self.agents[i_agent].moving = agent.malfunction_data['moving_before_malfunction']
return False
# Restart agents at the end of their malfunction
agent.malfunction_data['malfunction'] -= 1
agent.malfunction_data['fixed'] = True
if 'moving_before_malfunction' in agent.malfunction_data:
agent.moving = agent.malfunction_data['moving_before_malfunction']
continue
def _malfunction(self, rate):
"""
......@@ -434,7 +436,7 @@ class RailEnv(Environment):
# TODO: Do we want to guarantee that we have the desired rate or are we happy with lower rates?
if breaking_agent.malfunction_data['malfunction'] < 1:
num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken,
self.max_number_of_steps_broken + 1)
self.max_number_of_steps_broken + 1) + 1
breaking_agent.malfunction_data['malfunction'] = num_broken_steps
breaking_agent.malfunction_data['moving_before_malfunction'] = breaking_agent.moving
breaking_agent.malfunction_data['fixed'] = False
......@@ -479,7 +481,7 @@ class RailEnv(Environment):
}
have_all_agents_ended = True # boolean flag to check if all agents are done
# Evoke the malfunction generator
# Induce malfunctions
self._malfunction(self.mean_malfunction_rate)
for i_agent, agent in enumerate(self.agents):
......@@ -498,6 +500,9 @@ class RailEnv(Environment):
info_dict["speed"][i_agent] = agent.speed_data['speed']
info_dict["status"][i_agent] = agent.status
# Fix agents that finished their malfunction
self._fix_agents()
# Check for end of episode + set global reward to all rewards!
if have_all_agents_ended:
self.dones["__all__"] = True
......@@ -542,12 +547,9 @@ class RailEnv(Environment):
agent.old_direction = agent.direction
agent.old_position = agent.position
# is the agent malfunctioning?
malfunction = self._agent_malfunction(i_agent)
# if agent is broken, actions are ignored and agent does not move.
# full step penalty in this case
if malfunction:
if agent.malfunction_data['malfunction'] > 0:
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
return
......
......@@ -110,7 +110,7 @@ def test_malfunction_process():
total_down_time += env.agents[0].malfunction_data['malfunction']
# Check that the appropriate number of malfunctions is achieved
assert env.agents[0].malfunction_data['nr_malfunctions'] == 28, "Actual {}".format(
assert env.agents[0].malfunction_data['nr_malfunctions'] == 22, "Actual {}".format(
env.agents[0].malfunction_data['nr_malfunctions'])
# Check that malfunctioning data was standing around
......@@ -140,17 +140,17 @@ def test_malfunction_process_statistically():
env.agents[0].target = (0, 0)
# Next line only for test generation
#agent_malfunction_list = [[] for i in range(20)]
agent_malfunction_list = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 2, 1, 0],
[0, 0, 0, 0, 0, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 5, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [], [], [], [], [], [], [], [], [], []]
#agent_malfunction_list = [[] for i in range(10)]
agent_malfunction_list = [[0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1],
[0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5],
[0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
for step in range(20):
action_dict: Dict[int, RailEnvActions] = {}
......@@ -188,17 +188,25 @@ def test_malfunction_before_entry():
# Test initial malfunction values for all agents
# we want some agents to be malfuncitoning already and some to be working
# we want different next_malfunction values for the agents
for a in range(10):
print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,env.agents[a].malfunction_data['malfunction']))
assert env.agents[0].malfunction_data['malfunction'] == 0
assert env.agents[1].malfunction_data['malfunction'] == 0
assert env.agents[2].malfunction_data['malfunction'] == 0
assert env.agents[3].malfunction_data['malfunction'] == 0
assert env.agents[4].malfunction_data['malfunction'] == 0
assert env.agents[5].malfunction_data['malfunction'] == 10
assert env.agents[6].malfunction_data['malfunction'] == 0
assert env.agents[7].malfunction_data['malfunction'] == 0
assert env.agents[8].malfunction_data['malfunction'] == 0
assert env.agents[9].malfunction_data['malfunction'] == 0
#for a in range(10):
# print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,env.agents[a].malfunction_data['malfunction']))
def test_malfunction_values_and_behavior():
"""
Test that the next malfunction occurs when desired.
Test the malfunction counts down as desired
Returns
-------
......@@ -207,7 +215,7 @@ def test_malfunction_values_and_behavior():
rail, rail_map = make_simple_rail2()
action_dict: Dict[int, RailEnvActions] = {}
stochastic_data = {'malfunction_rate': 0.01,
stochastic_data = {'malfunction_rate': 0.001,
'min_duration': 10,
'max_duration': 10}
env = RailEnv(width=25,
......@@ -223,7 +231,7 @@ def test_malfunction_values_and_behavior():
env.reset(False, False, activate_agents=True, random_seed=10)
# Assertions
assert_list = [8, 7, 6, 5, 4, 3, 2, 1, 0, 9, 8, 7, 6, 5, 4]
assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 10, 9, 8, 7, 6, 5]
print("[")
for time_step in range(15):
# Move in the env
......@@ -233,8 +241,7 @@ def test_malfunction_values_and_behavior():
def test_initial_malfunction():
stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
'malfunction_rate': 100, # Rate of malfunction occurence
stochastic_data = {'malfunction_rate': 1000, # Rate of malfunction occurence
'min_duration': 2, # Minimal duration of malfunction
'max_duration': 5 # Max duration of malfunction
}
......@@ -278,7 +285,7 @@ def test_initial_malfunction():
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=1,
reward=env.step_penalty * 1.0
reward=env.step_penalty
), # malfunctioning ends: starting and running at speed 1.0
Replay(
......@@ -293,7 +300,7 @@ def test_initial_malfunction():
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.step_penalty * 1.0 # running at speed 1.0
reward=env.step_penalty # running at speed 1.0
)
],
speed=env.agents[0].speed_data['speed'],
......@@ -341,7 +348,7 @@ def test_initial_malfunction_stop_moving():
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=3,
malfunction=2,
reward=env.step_penalty, # full step penalty when stopped
status=RailAgentStatus.ACTIVE
),
......@@ -352,7 +359,7 @@ def test_initial_malfunction_stop_moving():
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.STOP_MOVING,
malfunction=2,
malfunction=1,
reward=env.step_penalty, # full step penalty while stopped
status=RailAgentStatus.ACTIVE
),
......@@ -361,7 +368,7 @@ def test_initial_malfunction_stop_moving():
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=1,
malfunction=0,
reward=env.step_penalty, # full step penalty while stopped
status=RailAgentStatus.ACTIVE
),
......@@ -429,7 +436,7 @@ def test_initial_malfunction_do_nothing():
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=3,
malfunction=2,
reward=env.step_penalty, # full step penalty while malfunctioning
status=RailAgentStatus.ACTIVE
),
......@@ -440,7 +447,7 @@ def test_initial_malfunction_do_nothing():
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=2,
malfunction=1,
reward=env.step_penalty, # full step penalty while stopped
status=RailAgentStatus.ACTIVE
),
......@@ -449,7 +456,7 @@ def test_initial_malfunction_do_nothing():
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=1,
malfunction=0,
reward=env.step_penalty, # full step penalty while stopped
status=RailAgentStatus.ACTIVE
),
......
......@@ -119,8 +119,9 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
# We also set next malfunction to infitiy to avoid interference with our tests
agent.malfunction_data['malfunction'] = replay.set_malfunction
agent.malfunction_data['moving_before_malfunction'] = agent.moving
agent.malfunction_data['fixed'] = False
_assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
print(step)
print(step, agent.moving, agent.malfunction_data['fixed'], agent.malfunction_data['malfunction'])
_, rewards_dict, _, info_dict = env.step(action_dict)
if rendering:
renderer.render_env(show=True, show_observations=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment