diff --git a/tests/test_eval_timeout.py b/tests/test_eval_timeout.py
index dfc406e3b9d091fc8e9a477ea86fae025e7b1936..6c92db298b3c87ca8597ab113b56ab1c8f208cde 100644
--- a/tests/test_eval_timeout.py
+++ b/tests/test_eval_timeout.py
@@ -8,8 +8,6 @@ import time
 from flatland.core.env import Environment
 from flatland.core.env_observation_builder import ObservationBuilder
-from flatland.core.env_prediction_builder import PredictionBuilder
-from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
 
 
 class CustomObservationBuilder(ObservationBuilder):
diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flaltland_rail_agent_status.py
index e3f1ced759fd755db58749cf0215a121a7b13026..82a2089f17cf1d25eb8bb28bd58e6918537035d2 100644
--- a/tests/test_flaltland_rail_agent_status.py
+++ b/tests/test_flaltland_rail_agent_status.py
@@ -1,5 +1,4 @@
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
@@ -7,7 +6,7 @@
 from flatland.envs.rail_generators import rail_from_grid_transition_map
 from flatland.envs.line_generators import sparse_line_generator
 from flatland.utils.simple_rail import make_simple_rail
 from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
-
+from flatland.envs.step_utils.states import TrainState
 
 def test_initial_status():
     """Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
@@ -30,7 +29,7 @@ def test_initial_status():
         Replay(
             position=None,  # not entered grid yet
             direction=Grid4TransitionsEnum.EAST,
-            status=RailAgentStatus.READY_TO_DEPART,
+            state=TrainState.READY_TO_DEPART,
             action=RailEnvActions.DO_NOTHING,
             reward=env.step_penalty * 0.5,
@@ -38,35 +37,35 @@
         Replay(
             position=None,  # not entered grid yet before step
             direction=Grid4TransitionsEnum.EAST,
-            status=RailAgentStatus.READY_TO_DEPART,
+            state=TrainState.READY_TO_DEPART,
             action=RailEnvActions.MOVE_LEFT,
             reward=env.step_penalty * 0.5,  # auto-correction left to forward without penalty!
         ),
         Replay(
             position=(3, 9),
             direction=Grid4TransitionsEnum.EAST,
-            status=RailAgentStatus.ACTIVE,
+            state=TrainState.MOVING,
             action=RailEnvActions.MOVE_LEFT,
             reward=env.start_penalty + env.step_penalty * 0.5,  # running at speed 0.5
         ),
         Replay(
             position=(3, 9),
             direction=Grid4TransitionsEnum.EAST,
-            status=RailAgentStatus.ACTIVE,
+            state=TrainState.MOVING,
             action=None,
             reward=env.step_penalty * 0.5,  # running at speed 0.5
         ),
         Replay(
             position=(3, 8),
             direction=Grid4TransitionsEnum.WEST,
-            status=RailAgentStatus.ACTIVE,
+            state=TrainState.MOVING,
             action=RailEnvActions.MOVE_FORWARD,
             reward=env.step_penalty * 0.5,  # running at speed 0.5
         ),
         Replay(
             position=(3, 8),
             direction=Grid4TransitionsEnum.WEST,
-            status=RailAgentStatus.ACTIVE,
+            state=TrainState.MOVING,
             action=None,
             reward=env.step_penalty * 0.5,  # running at speed 0.5
@@ -76,28 +75,28 @@ def test_initial_status():
             direction=Grid4TransitionsEnum.WEST,
             action=RailEnvActions.MOVE_FORWARD,
             reward=env.step_penalty * 0.5,  # running at speed 0.5
-            status=RailAgentStatus.ACTIVE
+            state=TrainState.MOVING
         ),
         Replay(
             position=(3, 7),
             direction=Grid4TransitionsEnum.WEST,
             action=None,
             reward=env.step_penalty * 0.5,  # wrong action is corrected to forward without penalty!
-            status=RailAgentStatus.ACTIVE
+            state=TrainState.MOVING
         ),
         Replay(
             position=(3, 6),
             direction=Grid4TransitionsEnum.WEST,
             action=RailEnvActions.MOVE_RIGHT,
             reward=env.step_penalty * 0.5,  #
-            status=RailAgentStatus.ACTIVE
+            state=TrainState.MOVING
         ),
         Replay(
             position=(3, 6),
             direction=Grid4TransitionsEnum.WEST,
             action=None,
             reward=env.global_reward,  #
-            status=RailAgentStatus.ACTIVE
+            state=TrainState.MOVING
         ),
         # Replay(
         #     position=(3, 5),
@@ -122,7 +121,7 @@ def test_initial_status():
     )
 
     run_replay_config(env, [test_config], activate_agents=False, skip_reward_check=True)
-    assert env.agents[0].status == RailAgentStatus.DONE
+    assert env.agents[0].state == TrainState.DONE
 
 
 def test_status_done_remove():
@@ -146,7 +145,7 @@ def test_status_done_remove():
         Replay(
             position=None,  # not entered grid yet
             direction=Grid4TransitionsEnum.EAST,
-            status=RailAgentStatus.READY_TO_DEPART,
+            state=TrainState.READY_TO_DEPART,
             action=RailEnvActions.DO_NOTHING,
             reward=env.step_penalty * 0.5,
@@ -154,35 +153,35 @@
         Replay(
             position=None,  # not entered grid yet before step
             direction=Grid4TransitionsEnum.EAST,
-            status=RailAgentStatus.READY_TO_DEPART,
+            state=TrainState.READY_TO_DEPART,
             action=RailEnvActions.MOVE_LEFT,
             reward=env.step_penalty * 0.5,  # auto-correction left to forward without penalty!
         ),
         Replay(
             position=(3, 9),
             direction=Grid4TransitionsEnum.EAST,
-            status=RailAgentStatus.ACTIVE,
+            state=TrainState.MOVING,
             action=RailEnvActions.MOVE_FORWARD,
             reward=env.start_penalty + env.step_penalty * 0.5,  # running at speed 0.5
         ),
         Replay(
             position=(3, 9),
             direction=Grid4TransitionsEnum.EAST,
-            status=RailAgentStatus.ACTIVE,
+            state=TrainState.MOVING,
             action=None,
             reward=env.step_penalty * 0.5,  # running at speed 0.5
         ),
         Replay(
             position=(3, 8),
             direction=Grid4TransitionsEnum.WEST,
-            status=RailAgentStatus.ACTIVE,
+            state=TrainState.MOVING,
             action=RailEnvActions.MOVE_FORWARD,
             reward=env.step_penalty * 0.5,  # running at speed 0.5
         ),
         Replay(
             position=(3, 8),
             direction=Grid4TransitionsEnum.WEST,
-            status=RailAgentStatus.ACTIVE,
+            state=TrainState.MOVING,
             action=None,
             reward=env.step_penalty * 0.5,  # running at speed 0.5
@@ -192,28 +191,28 @@ def test_status_done_remove():
             direction=Grid4TransitionsEnum.WEST,
             action=RailEnvActions.MOVE_RIGHT,
             reward=env.step_penalty * 0.5,  # running at speed 0.5
-            status=RailAgentStatus.ACTIVE
+            state=TrainState.MOVING
         ),
         Replay(
             position=(3, 7),
             direction=Grid4TransitionsEnum.WEST,
             action=None,
             reward=env.step_penalty * 0.5,  # wrong action is corrected to forward without penalty!
-            status=RailAgentStatus.ACTIVE
+            state=TrainState.MOVING
         ),
         Replay(
             position=(3, 6),
             direction=Grid4TransitionsEnum.WEST,
             action=RailEnvActions.MOVE_FORWARD,
             reward=env.step_penalty * 0.5,  # done
-            status=RailAgentStatus.ACTIVE
+            state=TrainState.MOVING
         ),
         Replay(
             position=(3, 6),
             direction=Grid4TransitionsEnum.WEST,
             action=None,
             reward=env.global_reward,  # already done
-            status=RailAgentStatus.ACTIVE
+            state=TrainState.MOVING
         ),
         # Replay(
         #     position=None,
@@ -238,4 +237,4 @@ def test_status_done_remove():
     )
 
     run_replay_config(env, [test_config], activate_agents=False, skip_reward_check=True)
-    assert env.agents[0].status == RailAgentStatus.DONE_REMOVED
+    assert env.agents[0].state == TrainState.DONE
diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py
index 2658813a95d20dac683c94a1fc827fd74eadbdfb..aee47c4009ded6cd4da38a33970a1cf51e08f5b8 100644
--- a/tests/test_flatland_envs_observations.py
+++ b/tests/test_flatland_envs_observations.py
@@ -5,7 +5,6 @@
 import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
 from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
@@ -13,6 +12,7 @@
 from flatland.envs.rail_generators import rail_from_grid_transition_map
 from flatland.envs.line_generators import sparse_line_generator
 from flatland.utils.rendertools import RenderTool
 from flatland.utils.simple_rail import make_simple_rail
+from flatland.envs.step_utils.states import TrainState
 
 """Tests for `flatland` package."""
@@ -106,7 +106,7 @@ def test_reward_function_conflict(rendering=False):
     agent.initial_direction = 0  # north
     agent.target = (3, 9)  # east dead-end
     agent.moving = True
-    agent.status = RailAgentStatus.ACTIVE
+    agent._set_state(TrainState.MOVING)
 
     agent = env.agents[1]
     agent.position = (3, 8)  # east dead-end
@@ -115,13 +115,13 @@ def test_reward_function_conflict(rendering=False):
     agent.initial_direction = 3  # west
     agent.target = (6, 6)  # south dead-end
     agent.moving = True
-    agent.status = RailAgentStatus.ACTIVE
+    agent._set_state(TrainState.MOVING)
 
     env.reset(False, False)
     env.agents[0].moving = True
     env.agents[1].moving = True
-    env.agents[0].status = RailAgentStatus.ACTIVE
-    env.agents[1].status = RailAgentStatus.ACTIVE
+    env.agents[0]._set_state(TrainState.MOVING)
+    env.agents[1]._set_state(TrainState.MOVING)
     env.agents[0].position = (5, 6)
     env.agents[1].position = (3, 8)
     print("\n")
@@ -195,7 +195,7 @@ def test_reward_function_waiting(rendering=False):
     agent.initial_direction = 3  # west
     agent.target = (3, 1)  # west dead-end
     agent.moving = True
-    agent.status = RailAgentStatus.ACTIVE
+    agent._set_state(TrainState.MOVING)
 
     agent = env.agents[1]
     agent.initial_position = (5, 6)  # south dead-end
@@ -204,13 +204,13 @@ def test_reward_function_waiting(rendering=False):
     agent.initial_direction = 0  # north
     agent.target = (3, 8)  # east dead-end
     agent.moving = True
-    agent.status = RailAgentStatus.ACTIVE
+    agent._set_state(TrainState.MOVING)
 
     env.reset(False, False)
     env.agents[0].moving = True
     env.agents[1].moving = True
-    env.agents[0].status = RailAgentStatus.ACTIVE
-    env.agents[1].status = RailAgentStatus.ACTIVE
+    env.agents[0]._set_state(TrainState.MOVING)
+    env.agents[1]._set_state(TrainState.MOVING)
     env.agents[0].position = (3, 8)
     env.agents[1].position = (5, 6)
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index ad2187be4bad2df2b7a85438079aa7d1f2bb8a0e..399ec957c155715e30e2868f5bcc51a0c275bee3 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -5,7 +5,6 @@
 import pprint
 import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.observations import TreeObsForRailEnv, Node
 from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
@@ -16,6 +15,7 @@
 from flatland.envs.line_generators import sparse_line_generator
 from flatland.utils.rendertools import RenderTool
 from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail
 from flatland.envs.rail_env_action import RailEnvActions
+from flatland.envs.step_utils.states import TrainState
 
 """Test predictions for `flatland` package."""
@@ -135,7 +135,7 @@ def test_shortest_path_predictor(rendering=False):
     agent.initial_direction = 0  # north
     agent.target = (3, 9)  # east dead-end
     agent.moving = True
-    agent.status = RailAgentStatus.ACTIVE
+    agent._set_state(TrainState.MOVING)
 
     env.reset(False, False)
     env.distance_map._compute(env.agents, env.rail)
@@ -269,7 +269,7 @@ def test_shortest_path_predictor_conflicts(rendering=False):
     env.agents[0].initial_direction = 0  # north
     env.agents[0].target = (3, 9)  # east dead-end
     env.agents[0].moving = True
-    env.agents[0].status = RailAgentStatus.ACTIVE
+    env.agents[0]._set_state(TrainState.MOVING)
 
     env.agents[1].initial_position = (3, 8)  # east dead-end
     env.agents[1].position = (3, 8)  # east dead-end
@@ -277,7 +277,7 @@ def test_shortest_path_predictor_conflicts(rendering=False):
     env.agents[1].initial_direction = 3  # west
     env.agents[1].target = (6, 6)  # south dead-end
     env.agents[1].moving = True
-    env.agents[1].status = RailAgentStatus.ACTIVE
+    env.agents[1]._set_state(TrainState.MOVING)
 
     observations, info = env.reset(False, False)
 
@@ -285,8 +285,8 @@ def test_shortest_path_predictor_conflicts(rendering=False):
     env.agent_positions[env.agents[0].position] = 0
     env.agents[1].position = (3, 8)  # east dead-end
     env.agent_positions[env.agents[1].position] = 1
-    env.agents[0].status = RailAgentStatus.ACTIVE
-    env.agents[1].status = RailAgentStatus.ACTIVE
+    env.agents[0]._set_state(TrainState.MOVING)
+    env.agents[1]._set_state(TrainState.MOVING)
 
     observations = env._get_observations()
diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py
index 5c12336a1a612cccd3df8beab42a8dcdfe9cdb59..358839f9e7b7368ead6861a00efdb4f36c9e090c 100644
--- a/tests/test_flatland_envs_sparse_rail_generator.py
+++ b/tests/test_flatland_envs_sparse_rail_generator.py
@@ -1315,8 +1315,8 @@ def test_rail_env_action_required_info():
             if step == 0 or info_only_if_action_required['action_required'][a]:
                 action_dict_only_if_action_required.update({a: action})
             else:
-                print("[{}] not action_required {}, speed_data={}".format(step, a,
-                                                                          env_always_action.agents[a].speed_data))
+                print("[{}] not action_required {}, speed_counter={}".format(step, a,
+                                                                             env_always_action.agents[a].speed_counter))
 
         obs_always_action, rewards_always_action, done_always_action, info_always_action = env_always_action.step(
             action_dict_always_action)
@@ -1375,7 +1375,7 @@ def test_rail_env_malfunction_speed_info():
         for a in range(env.get_num_agents()):
             assert info['malfunction'][a] >= 0
            assert info['speed'][a] >= 0 and info['speed'][a] <= 1
-            assert info['speed'][a] == env.agents[a].speed_data['speed']
+            assert info['speed'][a] == env.agents[a].speed_counter.speed
 
         env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index e32e8d9f21120d7566cc027d7f9fa6cb36ded7be..d633351ed3624499aa2e30df9f09031b0b4cf581 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -6,14 +6,14 @@ import numpy as np
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import rail_from_grid_transition_map
 from flatland.envs.line_generators import sparse_line_generator
 from flatland.utils.simple_rail import make_simple_rail2
 from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
-
+from flatland.envs.step_utils.states import TrainState
+from flatland.envs.step_utils.speed_counter import SpeedCounter
 
 
 class SingleAgentNavigationObs(ObservationBuilder):
     """
@@ -32,11 +32,11 @@ class SingleAgentNavigationObs(ObservationBuilder):
     def get(self, handle: int = 0) -> List[int]:
         agent = self.env.agents[handle]
 
-        if agent.status == RailAgentStatus.READY_TO_DEPART:
+        if agent.state.is_off_map_state():
            agent_virtual_position = agent.initial_position
-        elif agent.status == RailAgentStatus.ACTIVE:
+        elif agent.state.is_on_map_state():
            agent_virtual_position = agent.position
-        elif agent.status == RailAgentStatus.DONE:
+        elif agent.state == TrainState.DONE:
            agent_virtual_position = agent.target
         else:
             return None
@@ -85,7 +85,7 @@ def test_malfunction_process():
     obs, info = env.reset(False, False, random_seed=10)
     for a_idx in range(len(env.agents)):
         env.agents[a_idx].position = env.agents[a_idx].initial_position
-        env.agents[a_idx].status = RailAgentStatus.ACTIVE
+        env.agents[a_idx]._set_state(TrainState.MOVING)
 
     agent_halts = 0
     total_down_time = 0
@@ -297,7 +297,7 @@ def test_initial_malfunction():
             reward=env.step_penalty  # running at speed 1.0
         )
     ],
-        speed=env.agents[0].speed_data['speed'],
+        speed=env.agents[0].speed_counter.speed,
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
@@ -315,7 +315,7 @@ def test_initial_malfunction_stop_moving():
 
     env._max_episode_steps = 1000
 
-    print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].status)
+    print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].state)
 
     set_penalties_for_replay(env)
     replay_config = ReplayConfig(
@@ -327,7 +327,7 @@ def test_initial_malfunction_stop_moving():
            set_malfunction=3,
            malfunction=3,
            reward=env.step_penalty,  # full step penalty when stopped
-           status=RailAgentStatus.READY_TO_DEPART
+           state=TrainState.READY_TO_DEPART
        ),
        Replay(
            position=(3, 2),
@@ -335,7 +335,7 @@ def test_initial_malfunction_stop_moving():
            action=RailEnvActions.DO_NOTHING,
            malfunction=2,
            reward=env.step_penalty,  # full step penalty when stopped
-           status=RailAgentStatus.ACTIVE
+           state=TrainState.READY_TO_DEPART
        ),
        # malfunction stops in the next step and we're still at the beginning of the cell
        # --> if we take action STOP_MOVING, agent should restart without moving
@@ -346,7 +346,7 @@ def test_initial_malfunction_stop_moving():
            action=RailEnvActions.STOP_MOVING,
            malfunction=1,
            reward=env.step_penalty,  # full step penalty while stopped
-           status=RailAgentStatus.ACTIVE
+           state=TrainState.STOPPED
        ),
        # we have stopped and do nothing --> should stand still
        Replay(
@@ -355,7 +355,7 @@ def test_initial_malfunction_stop_moving():
            action=RailEnvActions.DO_NOTHING,
            malfunction=0,
            reward=env.step_penalty,  # full step penalty while stopped
-           status=RailAgentStatus.ACTIVE
+           state=TrainState.STOPPED
        ),
        # we start to move forward --> should go to next cell now
        Replay(
@@ -364,7 +364,7 @@ def test_initial_malfunction_stop_moving():
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0,
            reward=env.start_penalty + env.step_penalty * 1.0,  # full step penalty while stopped
-           status=RailAgentStatus.ACTIVE
+           state=TrainState.STOPPED
        ),
        Replay(
            position=(3, 3),
@@ -372,10 +372,10 @@ def test_initial_malfunction_stop_moving():
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0,
            reward=env.step_penalty * 1.0,  # full step penalty while stopped
-           status=RailAgentStatus.ACTIVE
+           state=TrainState.STOPPED
        )
    ],
-        speed=env.agents[0].speed_data['speed'],
+        speed=env.agents[0].speed_counter.speed,
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
@@ -412,7 +412,7 @@ def test_initial_malfunction_do_nothing():
            set_malfunction=3,
            malfunction=3,
            reward=env.step_penalty,  # full step penalty while malfunctioning
-           status=RailAgentStatus.READY_TO_DEPART
+           state=TrainState.READY_TO_DEPART
        ),
        Replay(
            position=(3, 2),
@@ -420,7 +420,7 @@ def test_initial_malfunction_do_nothing():
            action=RailEnvActions.DO_NOTHING,
            malfunction=2,
            reward=env.step_penalty,  # full step penalty while malfunctioning
-           status=RailAgentStatus.ACTIVE
+           state=TrainState.READY_TO_DEPART
        ),
        # malfunction stops in the next step and we're still at the beginning of the cell
        # --> if we take action DO_NOTHING, agent should restart without moving
@@ -431,7 +431,7 @@ def test_initial_malfunction_do_nothing():
            action=RailEnvActions.DO_NOTHING,
            malfunction=1,
            reward=env.step_penalty,  # full step penalty while stopped
-           status=RailAgentStatus.ACTIVE
+           state=TrainState.STOPPED
        ),
        # we haven't started moving yet --> stay here
        Replay(
@@ -440,7 +440,7 @@ def test_initial_malfunction_do_nothing():
            action=RailEnvActions.DO_NOTHING,
            malfunction=0,
            reward=env.step_penalty,  # full step penalty while stopped
-           status=RailAgentStatus.ACTIVE
+           state=TrainState.STOPPED
        ),
 
        Replay(
@@ -449,7 +449,7 @@ def test_initial_malfunction_do_nothing():
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0,
            reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
-           status=RailAgentStatus.ACTIVE
+           state=TrainState.STOPPED
        ),
        # we start to move forward --> should go to next cell now
        Replay(
            position=(3, 3),
@@ -457,10 +457,10 @@ def test_initial_malfunction_do_nothing():
            action=RailEnvActions.MOVE_FORWARD,
            malfunction=0,
            reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
-           status=RailAgentStatus.ACTIVE
+           state=TrainState.STOPPED
        )
    ],
-        speed=env.agents[0].speed_data['speed'],
+        speed=env.agents[0].speed_counter.speed,
        target=env.agents[0].target,
        initial_position=(3, 2),
        initial_direction=Grid4TransitionsEnum.EAST,
@@ -475,7 +475,7 @@ def tests_random_interference_from_outside():
     env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
     env.reset()
-    env.agents[0].speed_data['speed'] = 0.33
+    env.agents[0].speed_counter = SpeedCounter(speed=0.33)
     env.reset(False, False, random_seed=10)
     env_data = []
@@ -501,7 +501,7 @@ def tests_random_interference_from_outside():
     env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
     env.reset()
-    env.agents[0].speed_data['speed'] = 0.33
+    env.agents[0].speed_counter = SpeedCounter(speed=0.33)
     env.reset(False, False, random_seed=10)
 
     dummy_list = [1, 2, 6, 7, 8, 9, 4, 5, 4]
@@ -536,7 +536,7 @@ def test_last_malfunction_step():
     env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                   line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
     env.reset()
-    env.agents[0].speed_data['speed'] = 1. / 3.
+    env.agents[0].speed_counter = SpeedCounter(speed=1./3.)
     env.agents[0].initial_position = (6, 6)
     env.agents[0].initial_direction = 2
     env.agents[0].target = (0, 3)
@@ -546,7 +546,7 @@ def test_last_malfunction_step():
     env.reset(False, False)
     for a_idx in range(len(env.agents)):
         env.agents[a_idx].position = env.agents[a_idx].initial_position
-        env.agents[a_idx].status = RailAgentStatus.ACTIVE
+        env.agents[a_idx]._set_state(TrainState.MOVING)
     # Force malfunction to be off at beginning and next malfunction to happen in 2 steps
     env.agents[0].malfunction_data['next_malfunction'] = 2
     env.agents[0].malfunction_data['malfunction'] = 0
@@ -565,13 +565,13 @@ def test_last_malfunction_step():
             if env.agents[0].malfunction_data['malfunction'] < 1:
                 agent_can_move = True
             # Store the position before and after the step
-            pre_position = env.agents[0].speed_data['position_fraction']
+            pre_position = env.agents[0].speed_counter.counter
             _, reward, _, _ = env.step(action_dict)
             # Check if the agent is still allowed to move in this step
             if env.agents[0].malfunction_data['malfunction'] > 0:
                 agent_can_move = False
-            post_position = env.agents[0].speed_data['position_fraction']
+            post_position = env.agents[0].speed_counter.counter
             # Assert that the agent moved while it was still allowed
             if agent_can_move:
                 assert pre_position != post_position
diff --git a/tests/test_generators.py b/tests/test_generators.py
index 67f883746f2767bc98a090285428d7d377c905a1..7d91bce89bd2d840f433de9f895b29e5a822cf3d 100644
--- a/tests/test_generators.py
+++ b/tests/test_generators.py
@@ -10,7 +10,7 @@ from flatland.envs.rail_generators import rail_from_grid_transition_map, rail_fr
 from flatland.envs.line_generators import sparse_line_generator, line_from_file
 from flatland.utils.simple_rail import make_simple_rail
 from flatland.envs.persistence import RailEnvPersister
-from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.step_utils.states import TrainState
 
 
 def test_empty_rail_generator():
@@ -35,7 +35,7 @@ def test_rail_from_grid_transition_map():
 
     for a_idx in range(len(env.agents)):
         env.agents[a_idx].position = env.agents[a_idx].initial_position
-        env.agents[a_idx].status = RailAgentStatus.ACTIVE
+        env.agents[a_idx]._set_state(TrainState.MOVING)
 
     nr_rail_elements = np.count_nonzero(env.rail.grid)
diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py
index 851d849d1246773d7d06b5f38ed0eef820f74a56..1ea959a251e9dd672db4a71a11e3bd76bfced433 100644
--- a/tests/test_global_observation.py
+++ b/tests/test_global_observation.py
@@ -1,10 +1,11 @@
 import numpy as np
 
-from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
+from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.line_generators import sparse_line_generator
+from flatland.envs.step_utils.states import TrainState
 
 
 def test_get_global_observation():
@@ -37,7 +38,7 @@ def test_get_global_observation():
     obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)})
     for i in range(len(env.agents)):
         agent: EnvAgent = env.agents[i]
-        print("[{}] status={}, position={}, target={}, initial_position={}".format(i, agent.status, agent.position,
+        print("[{}] state={}, position={}, target={}, initial_position={}".format(i, agent.state, agent.position,
                                                                                    agent.target, agent.initial_position))
 
@@ -65,19 +66,19 @@ def test_get_global_observation():
     # test first channel of obs_agents_state: direction at own position
     for r in range(env.height):
         for c in range(env.width):
-            if (agent.status == RailAgentStatus.ACTIVE or agent.status == RailAgentStatus.DONE) and (
+            if (agent.state.is_on_map_state() or agent.state == TrainState.DONE) and (
                 r, c) == agent.position:
                 assert np.isclose(obs_agents_state[(r, c)][0], agent.direction), \
-                    "agent {} in status {} at {} expected to contain own direction {}, found {}" \
-                        .format(i, agent.status, (r, c), agent.direction, obs_agents_state[(r, c)][0])
-            elif (agent.status == RailAgentStatus.READY_TO_DEPART) and (r, c) == agent.initial_position:
+                    "agent {} in state {} at {} expected to contain own direction {}, found {}" \
+                        .format(i, agent.state, (r, c), agent.direction, obs_agents_state[(r, c)][0])
+            elif (agent.state == TrainState.READY_TO_DEPART) and (r, c) == agent.initial_position:
                 assert np.isclose(obs_agents_state[(r, c)][0], agent.direction), \
-                    "agent {} in status {} at {} expected to contain own direction {}, found {}" \
-                        .format(i, agent.status, (r, c), agent.direction, obs_agents_state[(r, c)][0])
+                    "agent {} in state {} at {} expected to contain own direction {}, found {}" \
+                        .format(i, agent.state, (r, c), agent.direction, obs_agents_state[(r, c)][0])
             else:
                 assert np.isclose(obs_agents_state[(r, c)][0], -1), \
-                    "agent {} in status {} at {} expected contain -1 found {}" \
-                        .format(i, agent.status, (r, c), obs_agents_state[(r, c)][0])
+                    "agent {} in state {} at {} expected contain -1 found {}" \
+                        .format(i, agent.state, (r, c), obs_agents_state[(r, c)][0])
 
     # test second channel of obs_agents_state: direction at other agents position
     for r in range(env.height):
         for c in range(env.width):
@@ -86,45 +87,45 @@ def test_get_global_observation():
             for other_i, other_agent in enumerate(env.agents):
                 if i == other_i:
                     continue
-                if other_agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and (
+                if other_agent.state in [TrainState.MOVING, TrainState.MALFUNCTION, TrainState.STOPPED, TrainState.DONE] and (
                     r, c) == other_agent.position:
                     assert np.isclose(obs_agents_state[(r, c)][1], other_agent.direction), \
-                        "agent {} in status {} at {} should see other agent with direction {}, found = {}" \
-                            .format(i, agent.status, (r, c), other_agent.direction, obs_agents_state[(r, c)][1])
+                        "agent {} in state {} at {} should see other agent with direction {}, found = {}" \
+                            .format(i, agent.state, (r, c), other_agent.direction, obs_agents_state[(r, c)][1])
                     has_agent = True
             if not has_agent:
                 assert np.isclose(obs_agents_state[(r, c)][1], -1), \
-                    "agent {} in status {} at {} should see no other agent direction (-1), found = {}" \
-                        .format(i, agent.status, (r, c), obs_agents_state[(r, c)][1])
+                    "agent {} in state {} at {} should see no other agent direction (-1), found = {}" \
+                        .format(i, agent.state, (r, c), obs_agents_state[(r, c)][1])
 
     # test third and fourth channel of obs_agents_state: malfunction and speed of own or other agent in the grid
     for r in range(env.height):
         for c in range(env.width):
             has_agent = False
             for other_i, other_agent in enumerate(env.agents):
-                if other_agent.status in [RailAgentStatus.ACTIVE,
-                                          RailAgentStatus.DONE] and other_agent.position == (r, c):
+                if other_agent.state in [TrainState.MOVING, TrainState.MALFUNCTION, TrainState.STOPPED,
+                                         TrainState.DONE] and other_agent.position == (r, c):
                     assert np.isclose(obs_agents_state[(r, c)][2], other_agent.malfunction_data['malfunction']), \
-                        "agent {} in status {} at {} should see agent malfunction {}, found = {}" \
-                            .format(i, agent.status, (r, c), other_agent.malfunction_data['malfunction'],
+                        "agent {} in state {} at {} should see agent malfunction {}, found = {}" \
+                            .format(i, agent.state, (r, c), other_agent.malfunction_data['malfunction'],
                                     obs_agents_state[(r, c)][2])
-                    assert np.isclose(obs_agents_state[(r, c)][3], other_agent.speed_data['speed'])
+                    assert np.isclose(obs_agents_state[(r, c)][3], other_agent.speed_counter.speed)
                     has_agent = True
             if not has_agent:
                 assert np.isclose(obs_agents_state[(r, c)][2], -1), \
-                    "agent {} in status {} at {} should see no agent malfunction (-1), found = {}" \
-                        .format(i, agent.status, (r, c), obs_agents_state[(r, c)][2])
+                    "agent {} in state {} at {} should see no agent malfunction (-1), found = {}" \
+                        .format(i, agent.state, (r, c), obs_agents_state[(r, c)][2])
                 assert np.isclose(obs_agents_state[(r, c)][3], -1), \
-                    "agent {} in status {} at {} should see no agent speed (-1), found = {}" \
-                        .format(i, agent.status, (r, c), obs_agents_state[(r, c)][3])
+                    "agent {} in state {} at {} should see no agent speed (-1), found = {}" \
+                        .format(i, agent.state, (r, c), obs_agents_state[(r, c)][3])
 
     # test fifth channel of obs_agents_state: number of agents ready to depart in to this cell
     for r in range(env.height):
         for c in range(env.width):
             count = 0
             for other_i, other_agent in enumerate(env.agents):
-                if other_agent.status == RailAgentStatus.READY_TO_DEPART and other_agent.initial_position == (r, c):
+                if other_agent.state == TrainState.READY_TO_DEPART and other_agent.initial_position == (r, c):
                     count += 1
             assert np.isclose(obs_agents_state[(r, c)][4], count), \
-                "agent {} in status {} at {} should see {} agents ready to depart, found{}" \
-                    .format(i, agent.status, (r, c), count, obs_agents_state[(r, c)][4])
+                "agent {} in state {} at {} should see {} agents ready to depart, found{}" \
+                    .format(i, agent.state, (r, c), count, obs_agents_state[(r, c)][4])
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index 561057d81b431dfbb87b904f7a57e6fcbf84f84e..50565e96bc5716f032af189e25b79622d3ca3586 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -8,7 +8,8 @@ from flatland.envs.rail_generators import sparse_rail_generator, rail_from_grid_
 from flatland.envs.line_generators import sparse_line_generator
 from flatland.utils.simple_rail import make_simple_rail
 from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
-from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.step_utils.states import TrainState
+from flatland.envs.step_utils.speed_counter import SpeedCounter
 
 
 # Use the sparse_rail_generator to generate feasible network configurations with corresponding tasks
@@ -65,13 +65,13 @@ def test_multi_speed_init():
 
     for a_idx in range(len(env.agents)):
         env.agents[a_idx].position = env.agents[a_idx].initial_position
-        env.agents[a_idx].status = RailAgentStatus.ACTIVE
+        env.agents[a_idx]._set_state(TrainState.MOVING)
 
     # Here you can also further enhance the provided observation by means of normalization
     # See training navigation example in the baseline repository
     old_pos = []
     for i_agent in range(env.get_num_agents()):
-        env.agents[i_agent].speed_data['speed'] = 1. / (i_agent + 1)
+        env.agents[i_agent].speed_counter = SpeedCounter(speed=1. / (i_agent + 1))
         old_pos.append(env.agents[i_agent].position)
         print(env.agents[i_agent].position)
     # Run episode
diff --git a/tests/test_speed_classes.py b/tests/test_speed_classes.py
index 3cfe1b1c7f58786cf0caacde629fa3a6c704230d..66f1fbf06eaeb70ed39ac8aa35c93f0fa11c6a32 100644
--- a/tests/test_speed_classes.py
+++ b/tests/test_speed_classes.py
@@ -23,7 +23,7 @@ def test_rail_env_speed_intializer():
                   rail_generator=sparse_rail_generator(), line_generator=sparse_line_generator(),
                   number_of_agents=10)
     env.reset()
-    actual_speeds = list(map(lambda agent: agent.speed_data['speed'], env.agents))
+    actual_speeds = list(map(lambda agent: agent.speed_counter.speed, env.agents))
 
     expected_speed_set = set(speed_ratio_map.keys())
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 4b72679ed6a1ceac1f266760d1871c6fc405e6dc..85e6a2755ac66ffeb15a7a8b2d0f4c9de9652e80 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -5,13 +5,15 @@ import numpy as np
 from attr import attrs, attrib
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
+from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.malfunction_generators import MalfunctionParameters, malfunction_from_params
 from flatland.envs.rail_env import RailEnvActions, RailEnv
 from flatland.envs.rail_generators import RailGenerator
 from flatland.envs.line_generators import LineGenerator
 from flatland.utils.rendertools import RenderTool
 from flatland.envs.persistence import RailEnvPersister
+from flatland.envs.step_utils.states import TrainState
+from flatland.envs.step_utils.speed_counter import SpeedCounter
 
 
 @attrs
 class Replay(object):
@@ -21,7 +23,7 @@ class Replay(object):
     malfunction = attrib(default=0, type=int)
     set_malfunction = attrib(default=None, type=Optional[int])
     reward = attrib(default=None, type=Optional[float])
-    status = attrib(default=None, type=Optional[RailAgentStatus])
+    state = attrib(default=None, type=Optional[TrainState])
 
 
 @attrs
@@ -86,12 +88,12 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
         agent.initial_direction = test_config.initial_direction
         agent.direction = test_config.initial_direction
         agent.target = test_config.target
-        agent.speed_data['speed'] = test_config.speed
+        agent.speed_counter = SpeedCounter(speed=test_config.speed)
     env.reset(False, False)
 
     if activate_agents:
         for a_idx in range(len(env.agents)):
             env.agents[a_idx].position = env.agents[a_idx].initial_position
-            env.agents[a_idx].status = RailAgentStatus.ACTIVE
+            env.agents[a_idx]._set_state(TrainState.MOVING)
 
     def _assert(a, actual, expected, msg):
         print("[{}] verifying {} on agent {}: actual={}, expected={}".format(step, msg, a, actual, expected))
@@ -108,12 +110,12 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
             _assert(a, agent.position, replay.position, 'position')
             _assert(a, agent.direction, replay.direction, 'direction')
 
-            if replay.status is not None:
-                _assert(a, agent.status, replay.status, 'status')
+            if replay.state is not None:
+                _assert(a, agent.state, replay.state, 'state')
 
             if replay.action is not None:
                 assert info_dict['action_required'][
-                           a] == True or agent.status == RailAgentStatus.READY_TO_DEPART, "[{}] agent {} expecting action_required={} or agent status READY_TO_DEPART".format(
+                           a] == True or agent.state == TrainState.READY_TO_DEPART, "[{}] agent {} expecting action_required={} or agent state READY_TO_DEPART".format(
                     step, a, True)
                 action_dict[a] = replay.action
             else:
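
Reference sketch (appended after the patch, not applied by it): the recurring migration pattern in this change set is RailAgentStatus -> TrainState and agent.speed_data -> agent.speed_counter, with states forced through EnvAgent._set_state() and the speed changed by swapping in a fresh SpeedCounter rather than assigning to it. A minimal usage sketch, assuming flatland-rl 3.x; the grid size and generator arguments below are illustrative, not taken from the tests:

    # Python sketch, assumes flatland-rl 3.x with flatland.envs.step_utils available
    from flatland.envs.rail_env import RailEnv
    from flatland.envs.rail_generators import sparse_rail_generator
    from flatland.envs.line_generators import sparse_line_generator
    from flatland.envs.step_utils.states import TrainState
    from flatland.envs.step_utils.speed_counter import SpeedCounter

    env = RailEnv(width=30, height=30,
                  rail_generator=sparse_rail_generator(),
                  line_generator=sparse_line_generator(),
                  number_of_agents=2)
    env.reset()

    agent = env.agents[0]
    # old: agent.speed_data['speed'] = 0.5 -> swap in a new counter object
    agent.speed_counter = SpeedCounter(speed=0.5)
    # old: agent.status = RailAgentStatus.ACTIVE -> force a concrete on-map state
    agent._set_state(TrainState.MOVING)

    assert agent.speed_counter.speed == 0.5
    assert agent.state.is_on_map_state()  # MOVING / STOPPED / MALFUNCTION count as on-map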