Skip to content
Snippets Groups Projects
Commit 00431259 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

updated tests to new malfunction behavior

parent 324a397e
No related branches found
No related tags found
No related merge requests found
...@@ -429,7 +429,7 @@ class RailEnv(Environment): ...@@ -429,7 +429,7 @@ class RailEnv(Environment):
agent.malfunction_data['next_malfunction'] = max(next_breakdown, 1) agent.malfunction_data['next_malfunction'] = max(next_breakdown, 1)
# Duration of current malfunction # Duration of current malfunction
num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken, num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken,
self.max_number_of_steps_broken + 1) + 1 self.max_number_of_steps_broken + 1)
agent.malfunction_data['malfunction'] = num_broken_steps agent.malfunction_data['malfunction'] = num_broken_steps
agent.malfunction_data['moving_before_malfunction'] = agent.moving agent.malfunction_data['moving_before_malfunction'] = agent.moving
......
...@@ -156,14 +156,14 @@ def test_malfunction_process_statistically(): ...@@ -156,14 +156,14 @@ def test_malfunction_process_statistically():
env.agents[0].target = (0, 0) env.agents[0].target = (0, 0)
# Next line only for test generation # Next line only for test generation
# agent_malfunction_list = [[] for i in range(20)] #agent_malfunction_list = [[] for i in range(20)]
agent_malfunction_list = [[0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 4, 4, 0, 0, 0, 0], agent_malfunction_list = [[0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 4, 4, 0, 0, 0, 0],
[0, 0, 0, 0, 3, 3, 0, 0, 0, 0], [0, 0, 0, 0, 2, 2, 0, 0, 0, 5], [0, 0, 0, 0, 1, 1, 5, 0, 0, 4], [0, 0, 0, 0, 3, 3, 0, 0, 0, 0], [0, 0, 0, 0, 2, 2, 0, 0, 0, 5], [0, 0, 0, 0, 1, 1, 5, 0, 0, 4],
[0, 0, 0, 5, 0, 0, 4, 5, 0, 3], [5, 0, 0, 4, 5, 5, 3, 4, 0, 2], [4, 5, 0, 3, 4, 4, 2, 3, 5, 1], [0, 0, 0, 5, 0, 0, 4, 5, 0, 3], [5, 0, 0, 4, 0, 0, 3, 4, 0, 2], [4, 5, 0, 3, 5, 5, 2, 3, 5, 1],
[3, 4, 0, 2, 3, 3, 1, 2, 4, 0], [2, 3, 5, 1, 2, 2, 0, 1, 3, 0], [1, 2, 4, 0, 1, 1, 5, 0, 2, 0], [3, 4, 0, 2, 4, 4, 1, 2, 4, 0], [2, 3, 5, 1, 3, 3, 0, 1, 3, 0], [1, 2, 4, 0, 2, 2, 0, 0, 2, 0],
[0, 1, 3, 0, 0, 0, 4, 0, 1, 0], [5, 0, 2, 0, 0, 5, 3, 5, 0, 5], [4, 0, 1, 0, 0, 4, 2, 4, 0, 4], [0, 1, 3, 0, 1, 1, 5, 0, 1, 0], [0, 0, 2, 0, 0, 0, 4, 0, 0, 0], [5, 0, 1, 0, 0, 0, 3, 5, 0, 5],
[3, 0, 0, 0, 0, 3, 1, 3, 5, 3], [2, 0, 0, 0, 0, 2, 0, 2, 4, 2], [1, 0, 5, 5, 5, 1, 5, 1, 3, 1], [4, 0, 0, 0, 5, 0, 2, 4, 0, 4], [3, 0, 0, 0, 4, 0, 1, 3, 5, 3], [2, 0, 0, 0, 3, 0, 0, 2, 4, 2],
[0, 0, 4, 4, 4, 0, 4, 0, 2, 0], [5, 0, 3, 3, 3, 5, 3, 5, 1, 5]] [1, 0, 5, 5, 2, 0, 0, 1, 3, 1], [0, 5, 4, 4, 1, 0, 5, 0, 2, 0]]
for step in range(20): for step in range(20):
action_dict: Dict[int, RailEnvActions] = {} action_dict: Dict[int, RailEnvActions] = {}
...@@ -223,6 +223,11 @@ def test_malfunction_before_entry(): ...@@ -223,6 +223,11 @@ def test_malfunction_before_entry():
assert env.agents[8].malfunction_data['malfunction'] == 0 assert env.agents[8].malfunction_data['malfunction'] == 0
assert env.agents[9].malfunction_data['malfunction'] == 0 assert env.agents[9].malfunction_data['malfunction'] == 0
for a in range(env.get_num_agents()):
print("assert env.agents[{}].malfunction_data['next_malfunction'] == {}".format(a, env.agents[a].malfunction_data['next_malfunction']))
for a in range(env.get_num_agents()):
print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a, env.agents[a].malfunction_data[
'malfunction']))
for step in range(20): for step in range(20):
action_dict: Dict[int, RailEnvActions] = {} action_dict: Dict[int, RailEnvActions] = {}
for agent in env.agents: for agent in env.agents:
...@@ -309,18 +314,18 @@ def test_initial_malfunction(): ...@@ -309,18 +314,18 @@ def test_initial_malfunction():
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
malfunction=1, malfunction=1,
reward=env.start_penalty + env.step_penalty * 1.0 reward= env.step_penalty * 1.0
# malfunctioning ends: starting and running at speed 1.0
), ),# malfunctioning ends: starting and running at speed 1.0
Replay( Replay(
position=(3, 3), position=(3, 2),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
malfunction=0, malfunction=0,
reward=env.step_penalty * 1.0 # running at speed 1.0 reward=env.start_penalty +env.step_penalty * 1.0 # running at speed 1.0
), ),
Replay( Replay(
position=(3, 4), position=(3, 3),
direction=Grid4TransitionsEnum.EAST, direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD, action=RailEnvActions.MOVE_FORWARD,
malfunction=0, malfunction=0,
...@@ -654,7 +659,7 @@ def test_last_malfunction_step(): ...@@ -654,7 +659,7 @@ def test_last_malfunction_step():
action_dict[agent.handle] = RailEnvActions(2) action_dict[agent.handle] = RailEnvActions(2)
# Check if the agent is still allowed to move in this step # Check if the agent is still allowed to move in this step
if env.agents[0].malfunction_data['malfunction'] > 1 or env.agents[0].malfunction_data['next_malfunction'] < 1: if env.agents[0].malfunction_data['malfunction'] > 0 or env.agents[0].malfunction_data['next_malfunction'] < 1:
agent_can_move = False agent_can_move = False
else: else:
agent_can_move = True agent_can_move = True
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment