Commit 00431259 authored by Erik Nygren 🚅
Browse files

Updated tests to match the new malfunction behavior

parent 324a397e
......@@ -429,7 +429,7 @@ class RailEnv(Environment):
agent.malfunction_data['next_malfunction'] = max(next_breakdown, 1)
# Duration of current malfunction
num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken,
self.max_number_of_steps_broken + 1) + 1
self.max_number_of_steps_broken + 1)
agent.malfunction_data['malfunction'] = num_broken_steps
agent.malfunction_data['moving_before_malfunction'] = agent.moving
......
......@@ -156,14 +156,14 @@ def test_malfunction_process_statistically():
env.agents[0].target = (0, 0)
# Next line only for test generation
# agent_malfunction_list = [[] for i in range(20)]
#agent_malfunction_list = [[] for i in range(20)]
agent_malfunction_list = [[0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 4, 4, 0, 0, 0, 0],
[0, 0, 0, 0, 3, 3, 0, 0, 0, 0], [0, 0, 0, 0, 2, 2, 0, 0, 0, 5], [0, 0, 0, 0, 1, 1, 5, 0, 0, 4],
[0, 0, 0, 5, 0, 0, 4, 5, 0, 3], [5, 0, 0, 4, 5, 5, 3, 4, 0, 2], [4, 5, 0, 3, 4, 4, 2, 3, 5, 1],
[3, 4, 0, 2, 3, 3, 1, 2, 4, 0], [2, 3, 5, 1, 2, 2, 0, 1, 3, 0], [1, 2, 4, 0, 1, 1, 5, 0, 2, 0],
[0, 1, 3, 0, 0, 0, 4, 0, 1, 0], [5, 0, 2, 0, 0, 5, 3, 5, 0, 5], [4, 0, 1, 0, 0, 4, 2, 4, 0, 4],
[3, 0, 0, 0, 0, 3, 1, 3, 5, 3], [2, 0, 0, 0, 0, 2, 0, 2, 4, 2], [1, 0, 5, 5, 5, 1, 5, 1, 3, 1],
[0, 0, 4, 4, 4, 0, 4, 0, 2, 0], [5, 0, 3, 3, 3, 5, 3, 5, 1, 5]]
[0, 0, 0, 5, 0, 0, 4, 5, 0, 3], [5, 0, 0, 4, 0, 0, 3, 4, 0, 2], [4, 5, 0, 3, 5, 5, 2, 3, 5, 1],
[3, 4, 0, 2, 4, 4, 1, 2, 4, 0], [2, 3, 5, 1, 3, 3, 0, 1, 3, 0], [1, 2, 4, 0, 2, 2, 0, 0, 2, 0],
[0, 1, 3, 0, 1, 1, 5, 0, 1, 0], [0, 0, 2, 0, 0, 0, 4, 0, 0, 0], [5, 0, 1, 0, 0, 0, 3, 5, 0, 5],
[4, 0, 0, 0, 5, 0, 2, 4, 0, 4], [3, 0, 0, 0, 4, 0, 1, 3, 5, 3], [2, 0, 0, 0, 3, 0, 0, 2, 4, 2],
[1, 0, 5, 5, 2, 0, 0, 1, 3, 1], [0, 5, 4, 4, 1, 0, 5, 0, 2, 0]]
for step in range(20):
action_dict: Dict[int, RailEnvActions] = {}
......@@ -223,6 +223,11 @@ def test_malfunction_before_entry():
assert env.agents[8].malfunction_data['malfunction'] == 0
assert env.agents[9].malfunction_data['malfunction'] == 0
for a in range(env.get_num_agents()):
print("assert env.agents[{}].malfunction_data['next_malfunction'] == {}".format(a, env.agents[a].malfunction_data['next_malfunction']))
for a in range(env.get_num_agents()):
print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a, env.agents[a].malfunction_data[
'malfunction']))
for step in range(20):
action_dict: Dict[int, RailEnvActions] = {}
for agent in env.agents:
......@@ -309,18 +314,18 @@ def test_initial_malfunction():
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=1,
reward=env.start_penalty + env.step_penalty * 1.0
# malfunctioning ends: starting and running at speed 1.0
),
reward= env.step_penalty * 1.0
),# malfunctioning ends: starting and running at speed 1.0
Replay(
position=(3, 3),
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.step_penalty * 1.0 # running at speed 1.0
reward=env.start_penalty +env.step_penalty * 1.0 # running at speed 1.0
),
Replay(
position=(3, 4),
position=(3, 3),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
......@@ -654,7 +659,7 @@ def test_last_malfunction_step():
action_dict[agent.handle] = RailEnvActions(2)
# Check if the agent is still allowed to move in this step
if env.agents[0].malfunction_data['malfunction'] > 1 or env.agents[0].malfunction_data['next_malfunction'] < 1:
if env.agents[0].malfunction_data['malfunction'] > 0 or env.agents[0].malfunction_data['next_malfunction'] < 1:
agent_can_move = False
else:
agent_can_move = True
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment