diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index b4632f3e37bd5f879349488d8b24a74c4c5d9759..7ebf73f0c8acc98f9690c219032550a4afead3e3 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -106,7 +106,7 @@ def test_malfunction_process():
         if done["__all__"]:
             break
 
-        if env.agents[0].malfunction_data['malfunction'] > 0:
+        if env.agents[0].malfunction_handler.malfunction_down_counter > 0:
             agent_malfunctioning = True
         else:
             agent_malfunctioning = False
@@ -116,11 +116,11 @@ def test_malfunction_process():
             assert agent_old_position == env.agents[0].position
         agent_old_position = env.agents[0].position
-        total_down_time += env.agents[0].malfunction_data['malfunction']
+        total_down_time += env.agents[0].malfunction_handler.malfunction_down_counter
 
     # Check that the appropriate number of malfunctions is achieved
     # Dipam: The number of malfunctions varies by seed
-    assert env.agents[0].malfunction_data['nr_malfunctions'] == 21, "Actual {}".format(
-        env.agents[0].malfunction_data['nr_malfunctions'])
+    assert env.agents[0].malfunction_handler.num_malfunctions == 46, "Actual {}".format(
+        env.agents[0].malfunction_handler.num_malfunctions)
 
     # Check that malfunctioning data was standing around
     assert total_down_time > 0
@@ -150,9 +150,9 @@ def test_malfunction_process_statistically():
     env.agents[0].target = (0, 0)
 
     # Next line only for test generation
-    # agent_malfunction_list = [[] for i in range(2)]
-    agent_malfunction_list = [[0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
-                              [5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 5]]
+    agent_malfunction_list = [[] for i in range(2)]
+    agent_malfunction_list = [[0, 0, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 2, 1, 0],
+                              [0, 4, 3, 2, 1, 0, 0, 0, 0, 0, 4, 3, 2, 1, 0, 4, 3, 2, 1, 0]]
 
     for step in range(20):
         action_dict: Dict[int, RailEnvActions] = {}
@@ -160,10 +160,11 @@ def test_malfunction_process_statistically():
             # We randomly select an action
            action_dict[agent_idx] = RailEnvActions(np.random.randint(4))
             # For generating tests only:
-            # agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction'])
-            assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[agent_idx][step]
+            # agent_malfunction_list[agent_idx].append(
+            #     env.agents[agent_idx].malfunction_handler.malfunction_down_counter)
+            assert env.agents[agent_idx].malfunction_handler.malfunction_down_counter == \
+                   agent_malfunction_list[agent_idx][step]
         env.step(action_dict)
-    # print(agent_malfunction_list)
 
 
 def test_malfunction_before_entry():
@@ -221,18 +222,18 @@ def test_malfunction_values_and_behavior():
 
     env.reset(False, False, random_seed=10)
 
+    env._max_episode_steps = 20
+
     # Assertions
-    assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 10, 9, 8, 7, 6, 5]
-    print("[")
+    assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 9, 8, 7, 6, 5]
     for time_step in range(15):
         # Move in the env
         _, _, dones, _ = env.step(action_dict)
         # Check that next_step decreases as expected
-        assert env.agents[0].malfunction_data['malfunction'] == assert_list[time_step]
+        assert env.agents[0].malfunction_handler.malfunction_down_counter == assert_list[time_step]
         if dones['__all__']:
             break
 
-
 def test_initial_malfunction():
     stochastic_data = MalfunctionParameters(malfunction_rate=1/1000,  # Rate of malfunction occurence
                                             min_duration=2,  # Minimal duration of malfunction
@@ -321,7 +322,7 @@ def test_initial_malfunction_stop_moving():
     set_penalties_for_replay(env)
     replay_config = ReplayConfig(
         replay=[
-            Replay(
+            Replay(  # 0
                 position=None,
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.MOVE_FORWARD,
@@ -330,44 +331,60 @@ def test_initial_malfunction_stop_moving():
                 reward=env.step_penalty,  # full step penalty when stopped
                 state=TrainState.READY_TO_DEPART
             ),
-            Replay(
-                position=(3, 2),
+            Replay(  # 1
+                position=None,
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=2,
                 reward=env.step_penalty,  # full step penalty when stopped
-                state=TrainState.READY_TO_DEPART
+                state=TrainState.MALFUNCTION_OFF_MAP
             ),
             # malfunction stops in the next step and we're still at the beginning of the cell
             # --> if we take action STOP_MOVING, agent should restart without moving
             #
-            Replay(
-                position=(3, 2),
+            Replay(  # 2
+                position=None,
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.STOP_MOVING,
                 malfunction=1,
                 reward=env.step_penalty,  # full step penalty while stopped
-                state=TrainState.STOPPED
+                state=TrainState.MALFUNCTION_OFF_MAP
             ),
             # we have stopped and do nothing --> should stand still
-            Replay(
-                position=(3, 2),
+            Replay(  # 3
+                position=None,
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.DO_NOTHING,
                 malfunction=0,
                 reward=env.step_penalty,  # full step penalty while stopped
-                state=TrainState.STOPPED
+                state=TrainState.MALFUNCTION_OFF_MAP
             ),
             # we start to move forward --> should go to next cell now
-            Replay(
+            Replay(  # 4
                 position=(3, 2),
                 direction=Grid4TransitionsEnum.EAST,
-                action=RailEnvActions.MOVE_FORWARD,
+                action=RailEnvActions.STOP_MOVING,
                 malfunction=0,
                 reward=env.start_penalty + env.step_penalty * 1.0,  # full step penalty while stopped
+                state=TrainState.MOVING
+            ),
+            Replay(  # 5
+                position=(3, 2),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.MOVE_FORWARD,
+                malfunction=0,
+                reward=env.step_penalty * 1.0,  # full step penalty while stopped
                 state=TrainState.STOPPED
             ),
-            Replay(
+            Replay(  # 6
+                position=(3, 3),
+                direction=Grid4TransitionsEnum.EAST,
+                action=RailEnvActions.STOP_MOVING,
+                malfunction=0,
+                reward=env.step_penalty * 1.0,  # full step penalty while stopped
+                state=TrainState.MOVING
+            ),
+            Replay(  # 7
                 position=(3, 3),
                 direction=Grid4TransitionsEnum.EAST,
                 action=RailEnvActions.MOVE_FORWARD,
@@ -382,7 +399,8 @@ def test_initial_malfunction_stop_moving():
         initial_direction=Grid4TransitionsEnum.EAST,
     )
 
-    run_replay_config(env, [replay_config], activate_agents=False, skip_reward_check=True)
+    run_replay_config(env, [replay_config], activate_agents=False,
+                      skip_reward_check=True, set_ready_to_depart=True, skip_action_required_check=True)
 
 
 def test_initial_malfunction_do_nothing():
@@ -403,6 +421,7 @@ def test_initial_malfunction_do_nothing():
     )
     env.reset()
     env._max_episode_steps = 1000
 
+    set_penalties_for_replay(env)
     replay_config = ReplayConfig(
         replay=[
@@ -416,32 +435,32 @@ def test_initial_malfunction_do_nothing():
                 state=TrainState.READY_TO_DEPART
             ),
             Replay(
-                position=(3, 2),
+                position=None,
                 direction=Grid4TransitionsEnum.EAST,
-                action=RailEnvActions.DO_NOTHING,
+                action=None,
                 malfunction=2,
                 reward=env.step_penalty,  # full step penalty while malfunctioning
-                state=TrainState.MOVING
+                state=TrainState.MALFUNCTION_OFF_MAP
             ),
             # malfunction stops in the next step and we're still at the beginning of the cell
             # --> if we take action DO_NOTHING, agent should restart without moving
             #
             Replay(
-                position=(3, 2),
+                position=None,
                 direction=Grid4TransitionsEnum.EAST,
-                action=RailEnvActions.DO_NOTHING,
+                action=None,
                 malfunction=1,
                 reward=env.step_penalty,  # full step penalty while stopped
-                state=TrainState.MOVING
+                state=TrainState.MALFUNCTION_OFF_MAP
             ),
             # we haven't started moving yet --> stay here
             Replay(
-                position=(3, 2),
+                position=None,
                 direction=Grid4TransitionsEnum.EAST,
-                action=RailEnvActions.DO_NOTHING,
+                action=None,
                 malfunction=0,
                 reward=env.step_penalty,  # full step penalty while stopped
-                state=TrainState.MOVING
+                state=TrainState.MALFUNCTION_OFF_MAP
             ),
 
             Replay(
@@ -466,7 +485,8 @@ def test_initial_malfunction_do_nothing():
         initial_position=(3, 2),
         initial_direction=Grid4TransitionsEnum.EAST,
     )
-    run_replay_config(env, [replay_config], activate_agents=False, skip_reward_check=True)
+    run_replay_config(env, [replay_config], activate_agents=False,
+                      skip_reward_check=True, set_ready_to_depart=True)
 
 
 def tests_random_interference_from_outside():
@@ -532,7 +552,6 @@ def test_last_malfunction_step():
 
     # Set fixed malfunction duration for this test
     rail, rail_map, optionals = make_simple_rail2()
-    # import pdb; pdb.set_trace()
 
     env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
diff --git a/tests/test_flatland_rail_agent_status.py b/tests/test_flatland_rail_agent_status.py
index 82a2089f17cf1d25eb8bb28bd58e6918537035d2..0c76174ef01afa26a7387a6684240c385ca39775 100644
--- a/tests/test_flatland_rail_agent_status.py
+++ b/tests/test_flatland_rail_agent_status.py
@@ -120,7 +120,8 @@ def test_initial_status():
         speed=0.5
     )
 
-    run_replay_config(env, [test_config], activate_agents=False, skip_reward_check=True)
+    run_replay_config(env, [test_config], activate_agents=False, skip_reward_check=True,
+                      set_ready_to_depart=True)
 
     assert env.agents[0].state == TrainState.DONE
 
@@ -236,5 +237,6 @@ def test_status_done_remove():
         speed=0.5
     )
 
-    run_replay_config(env, [test_config], activate_agents=False, skip_reward_check=True)
+    run_replay_config(env, [test_config], activate_agents=False, skip_reward_check=True,
+                      set_ready_to_depart=True)
     assert env.agents[0].state == TrainState.DONE
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index 56a3a33ff69aaf4fdc23dc5b049c477826de3de5..c517c2c58239b28513991f77592f4730c7fa813b 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -196,7 +196,7 @@ def test_multispeed_actions_no_malfunction_no_blocking():
         initial_direction=Grid4TransitionsEnum.EAST,
     )
 
-    run_replay_config(env, [test_config], skip_reward_check=True)
+    run_replay_config(env, [test_config], skip_reward_check=True, skip_action_required_check=True)
 
 
 def test_multispeed_actions_no_malfunction_blocking():
@@ -206,11 +206,6 @@ def test_multispeed_actions_no_malfunction_blocking():
                   line_generator=sparse_line_generator(), number_of_agents=2,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env.reset()
-
-    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
-    for _ in range(max([agent.earliest_departure for agent in env.agents])):
-        env.step({})  # DO_NOTHING for all agents
-
     set_penalties_for_replay(env)
 
     test_configs = [
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 391ba535a343c25df03df6fa18f8651543fab06f..fdae8f5c32f4ab305e54f31293e98fbba5c0a41a 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -43,7 +43,8 @@ def set_penalties_for_replay(env: RailEnv):
     env.invalid_action_penalty = -29
 
 
-def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False, activate_agents=True, skip_reward_check=False):
+def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False, activate_agents=True,
+                      skip_reward_check=False, set_ready_to_depart=False, skip_action_required_check=False):
     """
     Runs the replay configs and checks assertions.
@@ -90,7 +91,14 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
             agent.target = test_config.target
             agent.speed_counter = SpeedCounter(speed=test_config.speed)
         env.reset(False, False)
-        if activate_agents:
+
+        if set_ready_to_depart:
+            # Set all agents to ready to depart
+            for i_agent in range(len(env.agents)):
+                env.agents[i_agent].earliest_departure = 0
+                env.agents[i_agent]._set_state(TrainState.READY_TO_DEPART)
+
+        elif activate_agents:
             for a_idx in range(len(env.agents)):
                 env.agents[a_idx].position = env.agents[a_idx].initial_position
                 env.agents[a_idx]._set_state(TrainState.MOVING)
@@ -113,12 +121,14 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
                 _assert(a, agent.state, replay.state, 'state')
 
                 if replay.action is not None:
-                    assert info_dict['action_required'][
+                    if not skip_action_required_check:
+                        assert info_dict['action_required'][
                             a] == True or agent.state == TrainState.READY_TO_DEPART, "[{}] agent {} expecting action_required={} or agent status READY_TO_DEPART".format(
                             step, a, True)
                     action_dict[a] = replay.action
                 else:
-                    assert info_dict['action_required'][
+                    if not skip_action_required_check:
+                        assert info_dict['action_required'][
                             a] == False, "[{}] agent {} expecting action_required={}, but found {}".format(
                             step, a, False, info_dict['action_required'][a])
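
For context on the recurring pattern in this patch: per-agent malfunction state moved off the `malfunction_data` dict onto a `malfunction_handler` object, which is why every assertion above changed shape, and `run_replay_config` grew `set_ready_to_depart` / `skip_action_required_check` flags so replay tests can start agents in `READY_TO_DEPART` instead of force-activating them. Below is a minimal sketch of the before/after accessors, illustration only and not part of the patch; it assumes flatland-rl 3.x and the same helpers the tests import (`make_simple_rail2` from `flatland.utils.simple_rail`, the generators from `flatland.envs`):

```python
# Sketch: old vs. new malfunction accessors exercised by the updated tests.
# Env construction mirrors test_last_malfunction_step above (flatland-rl 3.x assumed).
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.utils.simple_rail import make_simple_rail2

rail, rail_map, optionals = make_simple_rail2()
env = RailEnv(width=25, height=30,
              rail_generator=rail_from_grid_transition_map(rail, optionals),
              line_generator=sparse_line_generator(seed=2),
              number_of_agents=1, random_seed=1)
env.reset(False, False, random_seed=10)
env.step({})  # an empty action dict means DO_NOTHING for every agent

agent = env.agents[0]
# Old accessors (removed by this patch):
#   agent.malfunction_data['malfunction']      -> remaining down time
#   agent.malfunction_data['nr_malfunctions']  -> number of malfunctions so far
# New accessors (asserted throughout the updated tests):
remaining_down_time = agent.malfunction_handler.malfunction_down_counter
malfunctions_so_far = agent.malfunction_handler.num_malfunctions
```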