Commit 57b15b9f authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

tests fixed

parent 94feb039
Pipeline #8495 failed with stages
in 5 minutes and 45 seconds
......@@ -106,7 +106,7 @@ def test_malfunction_process():
if done["__all__"]:
break
if env.agents[0].malfunction_data['malfunction'] > 0:
if env.agents[0].malfunction_handler.malfunction_down_counter > 0:
agent_malfunctioning = True
else:
agent_malfunctioning = False
......@@ -116,11 +116,11 @@ def test_malfunction_process():
assert agent_old_position == env.agents[0].position
agent_old_position = env.agents[0].position
total_down_time += env.agents[0].malfunction_data['malfunction']
total_down_time += env.agents[0].malfunction_handler.malfunction_down_counter
# Check that the appropriate number of malfunctions is achieved
# Dipam: The number of malfunctions varies by seed
assert env.agents[0].malfunction_data['nr_malfunctions'] == 21, "Actual {}".format(
env.agents[0].malfunction_data['nr_malfunctions'])
assert env.agents[0].malfunction_handler.num_malfunctions == 46, "Actual {}".format(
env.agents[0].malfunction_handler.num_malfunctions)
# Check that malfunctioning data was standing around
assert total_down_time > 0
......@@ -150,9 +150,9 @@ def test_malfunction_process_statistically():
env.agents[0].target = (0, 0)
# Next line only for test generation
# agent_malfunction_list = [[] for i in range(2)]
agent_malfunction_list = [[0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
[5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 5]]
agent_malfunction_list = [[] for i in range(2)]
agent_malfunction_list = [[0, 0, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 2, 1, 0],
[0, 4, 3, 2, 1, 0, 0, 0, 0, 0, 4, 3, 2, 1, 0, 4, 3, 2, 1, 0]]
for step in range(20):
action_dict: Dict[int, RailEnvActions] = {}
......@@ -160,10 +160,11 @@ def test_malfunction_process_statistically():
# We randomly select an action
action_dict[agent_idx] = RailEnvActions(np.random.randint(4))
# For generating tests only:
# agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction'])
assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[agent_idx][step]
# agent_malfunction_list[agent_idx].append(
# env.agents[agent_idx].malfunction_handler.malfunction_down_counter)
assert env.agents[agent_idx].malfunction_handler.malfunction_down_counter == \
agent_malfunction_list[agent_idx][step]
env.step(action_dict)
# print(agent_malfunction_list)
def test_malfunction_before_entry():
......@@ -221,18 +222,18 @@ def test_malfunction_values_and_behavior():
env.reset(False, False, random_seed=10)
env._max_episode_steps = 20
# Assertions
assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 10, 9, 8, 7, 6, 5]
print("[")
assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 9, 8, 7, 6, 5]
for time_step in range(15):
# Move in the env
_, _, dones,_ = env.step(action_dict)
# Check that next_step decreases as expected
assert env.agents[0].malfunction_data['malfunction'] == assert_list[time_step]
assert env.agents[0].malfunction_handler.malfunction_down_counter == assert_list[time_step]
if dones['__all__']:
break
def test_initial_malfunction():
stochastic_data = MalfunctionParameters(malfunction_rate=1/1000, # Rate of malfunction occurence
min_duration=2, # Minimal duration of malfunction
......@@ -321,7 +322,7 @@ def test_initial_malfunction_stop_moving():
set_penalties_for_replay(env)
replay_config = ReplayConfig(
replay=[
Replay(
Replay( # 0
position=None,
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
......@@ -330,44 +331,60 @@ def test_initial_malfunction_stop_moving():
reward=env.step_penalty, # full step penalty when stopped
state=TrainState.READY_TO_DEPART
),
Replay(
position=(3, 2),
Replay( # 1
position=None,
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=2,
reward=env.step_penalty, # full step penalty when stopped
state=TrainState.READY_TO_DEPART
state=TrainState.MALFUNCTION_OFF_MAP
),
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action STOP_MOVING, agent should restart without moving
#
Replay(
position=(3, 2),
Replay( # 2
position=None,
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.STOP_MOVING,
malfunction=1,
reward=env.step_penalty, # full step penalty while stopped
state=TrainState.STOPPED
state=TrainState.MALFUNCTION_OFF_MAP
),
# we have stopped and do nothing --> should stand still
Replay(
position=(3, 2),
Replay( # 3
position=None,
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=0,
reward=env.step_penalty, # full step penalty while stopped
state=TrainState.STOPPED
state=TrainState.MALFUNCTION_OFF_MAP
),
# we start to move forward --> should go to next cell now
Replay(
Replay( # 4
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
action=RailEnvActions.STOP_MOVING,
malfunction=0,
reward=env.start_penalty + env.step_penalty * 1.0, # full step penalty while stopped
state=TrainState.MOVING
),
Replay( # 5
position=(3, 2),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.step_penalty * 1.0, # full step penalty while stopped
state=TrainState.STOPPED
),
Replay(
Replay( # 6
position=(3, 3),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.STOP_MOVING,
malfunction=0,
reward=env.step_penalty * 1.0, # full step penalty while stopped
state=TrainState.MOVING
),
Replay( # 6
position=(3, 3),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
......@@ -382,7 +399,8 @@ def test_initial_malfunction_stop_moving():
initial_direction=Grid4TransitionsEnum.EAST,
)
run_replay_config(env, [replay_config], activate_agents=False, skip_reward_check=True)
run_replay_config(env, [replay_config], activate_agents=False,
skip_reward_check=True, set_ready_to_depart=True, skip_action_required_check=True)
def test_initial_malfunction_do_nothing():
......@@ -403,6 +421,7 @@ def test_initial_malfunction_do_nothing():
)
env.reset()
env._max_episode_steps = 1000
set_penalties_for_replay(env)
replay_config = ReplayConfig(
replay=[
......@@ -416,32 +435,32 @@ def test_initial_malfunction_do_nothing():
state=TrainState.READY_TO_DEPART
),
Replay(
position=(3, 2),
position=None,
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
action=None,
malfunction=2,
reward=env.step_penalty, # full step penalty while malfunctioning
state=TrainState.MOVING
state=TrainState.MALFUNCTION_OFF_MAP
),
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action DO_NOTHING, agent should restart without moving
#
Replay(
position=(3, 2),
position=None,
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
action=None,
malfunction=1,
reward=env.step_penalty, # full step penalty while stopped
state=TrainState.MOVING
state=TrainState.MALFUNCTION_OFF_MAP
),
# we haven't started moving yet --> stay here
Replay(
position=(3, 2),
position=None,
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
action=None,
malfunction=0,
reward=env.step_penalty, # full step penalty while stopped
state=TrainState.MOVING
state=TrainState.MALFUNCTION_OFF_MAP
),
Replay(
......@@ -466,7 +485,8 @@ def test_initial_malfunction_do_nothing():
initial_position=(3, 2),
initial_direction=Grid4TransitionsEnum.EAST,
)
run_replay_config(env, [replay_config], activate_agents=False, skip_reward_check=True)
run_replay_config(env, [replay_config], activate_agents=False,
skip_reward_check=True, set_ready_to_depart=True)
def tests_random_interference_from_outside():
......@@ -532,7 +552,6 @@ def test_last_malfunction_step():
# Set fixed malfunction duration for this test
rail, rail_map, optionals = make_simple_rail2()
# import pdb; pdb.set_trace()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
......
......@@ -120,7 +120,8 @@ def test_initial_status():
speed=0.5
)
run_replay_config(env, [test_config], activate_agents=False, skip_reward_check=True)
run_replay_config(env, [test_config], activate_agents=False, skip_reward_check=True,
set_ready_to_depart=True)
assert env.agents[0].state == TrainState.DONE
......@@ -236,5 +237,6 @@ def test_status_done_remove():
speed=0.5
)
run_replay_config(env, [test_config], activate_agents=False, skip_reward_check=True)
run_replay_config(env, [test_config], activate_agents=False, skip_reward_check=True,
set_ready_to_depart=True)
assert env.agents[0].state == TrainState.DONE
......@@ -196,7 +196,7 @@ def test_multispeed_actions_no_malfunction_no_blocking():
initial_direction=Grid4TransitionsEnum.EAST,
)
run_replay_config(env, [test_config], skip_reward_check=True)
run_replay_config(env, [test_config], skip_reward_check=True, skip_action_required_check=True)
def test_multispeed_actions_no_malfunction_blocking():
......@@ -206,11 +206,6 @@ def test_multispeed_actions_no_malfunction_blocking():
line_generator=sparse_line_generator(), number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.reset()
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents])):
env.step({}) # DO_NOTHING for all agents
set_penalties_for_replay(env)
test_configs = [
......
......@@ -43,7 +43,8 @@ def set_penalties_for_replay(env: RailEnv):
env.invalid_action_penalty = -29
def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False, activate_agents=True, skip_reward_check=False):
def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False, activate_agents=True,
skip_reward_check=False, set_ready_to_depart=False, skip_action_required_check=False):
"""
Runs the replay configs and checks assertions.
......@@ -90,7 +91,14 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
agent.target = test_config.target
agent.speed_counter = SpeedCounter(speed=test_config.speed)
env.reset(False, False)
if activate_agents:
if set_ready_to_depart:
# Set all agents to ready to depart
for i_agent in range(len(env.agents)):
env.agents[i_agent].earliest_departure = 0
env.agents[i_agent]._set_state(TrainState.READY_TO_DEPART)
elif activate_agents:
for a_idx in range(len(env.agents)):
env.agents[a_idx].position = env.agents[a_idx].initial_position
env.agents[a_idx]._set_state(TrainState.MOVING)
......@@ -113,12 +121,14 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
_assert(a, agent.state, replay.state, 'state')
if replay.action is not None:
assert info_dict['action_required'][
if not skip_action_required_check:
assert info_dict['action_required'][
a] == True or agent.state == TrainState.READY_TO_DEPART, "[{}] agent {} expecting action_required={} or agent status READY_TO_DEPART".format(
step, a, True)
action_dict[a] = replay.action
else:
assert info_dict['action_required'][
if not skip_action_required_check:
assert info_dict['action_required'][
a] == False, "[{}] agent {} expecting action_required={}, but found {}".format(
step, a, False, info_dict['action_required'][a])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment