Commit 7d460229 authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

fix tests for activate_agents removal

parent ab6d03f6
Pipeline #8438 failed with stages
in 5 minutes and 58 seconds
......@@ -30,7 +30,7 @@ def test_action_plan(rendering: bool = False):
env.agents[1].initial_direction = Grid4TransitionsEnum.WEST
env.agents[1].target = (0, 3)
env.agents[1].speed_data['speed'] = 0.5 # two
env.reset(False, False, False)
env.reset(False, False)
for handle, agent in enumerate(env.agents):
print("[{}] {} -> {}".format(handle, agent.initial_position, agent.target))
......
......@@ -15,6 +15,7 @@ from flatland.envs.rail_trainrun_data_structures import Waypoint
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail
from flatland.envs.rail_env_action import RailEnvActions
"""Test predictions for `flatland` package."""
......@@ -38,7 +39,11 @@ def test_dummy_predictor(rendering=False):
env.agents[0].target = (3, 0)
env.reset(False, False)
env.set_agent_active(env.agents[0])
env.agents[0].earliest_departure = 1
env._max_episode_steps = 100
# Make Agent 0 active
env.step({})
env.step({0: RailEnvActions.MOVE_FORWARD})
if rendering:
renderer = RenderTool(env, gl="PILSVG")
......@@ -258,25 +263,33 @@ def test_shortest_path_predictor_conflicts(rendering=False):
env.reset()
# set the initial position
agent = env.agents[0]
agent.initial_position = (5, 6) # south dead-end
agent.position = (5, 6) # south dead-end
agent.direction = 0 # north
agent.initial_direction = 0 # north
agent.target = (3, 9) # east dead-end
agent.moving = True
agent.status = RailAgentStatus.ACTIVE
agent = env.agents[1]
agent.initial_position = (3, 8) # east dead-end
agent.position = (3, 8) # east dead-end
agent.direction = 3 # west
agent.initial_direction = 3 # west
agent.target = (6, 6) # south dead-end
agent.moving = True
agent.status = RailAgentStatus.ACTIVE
env.agents[0].initial_position = (5, 6) # south dead-end
env.agents[0].position = (5, 6) # south dead-end
env.agents[0].direction = 0 # north
env.agents[0].initial_direction = 0 # north
env.agents[0].target = (3, 9) # east dead-end
env.agents[0].moving = True
env.agents[0].status = RailAgentStatus.ACTIVE
env.agents[1].initial_position = (3, 8) # east dead-end
env.agents[1].position = (3, 8) # east dead-end
env.agents[1].direction = 3 # west
env.agents[1].initial_direction = 3 # west
env.agents[1].target = (6, 6) # south dead-end
env.agents[1].moving = True
env.agents[1].status = RailAgentStatus.ACTIVE
observations, info = env.reset(False, False)
env.agents[0].position = (5, 6) # south dead-end
env.agent_positions[env.agents[0].position] = 0
env.agents[1].position = (3, 8) # east dead-end
env.agent_positions[env.agents[1].position] = 1
env.agents[0].status = RailAgentStatus.ACTIVE
env.agents[1].status = RailAgentStatus.ACTIVE
observations = env._get_observations()
observations, info = env.reset(False, False, True)
if rendering:
renderer = RenderTool(env, gl="PILSVG")
......
......@@ -19,7 +19,7 @@ def test_sparse_rail_generator():
),
line_generator=sparse_line_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
env.reset(False, False, True)
env.reset(False, False)
# for r in range(env.height):
# for c in range(env.width):
# if env.rail.grid[r][c] > 0:
......@@ -1300,8 +1300,8 @@ def test_rail_env_action_required_info():
# Reset the envs
env_always_action.reset(False, False, True, random_seed=5)
env_only_if_action_required.reset(False, False, True, random_seed=5)
env_always_action.reset(False, False, random_seed=5)
env_only_if_action_required.reset(False, False, random_seed=5)
assert env_only_if_action_required.rail.grid.tolist() == env_always_action.rail.grid.tolist()
for step in range(50):
print("step {}".format(step))
......@@ -1358,7 +1358,7 @@ def test_rail_env_malfunction_speed_info():
),
line_generator=sparse_line_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
env.reset(False, False, True)
env.reset(False, False)
env_renderer = RenderTool(env, gl="PILSVG", )
for step in range(100):
......@@ -1432,5 +1432,5 @@ def test_sparse_generator_changes_to_grid_mode():
), line_generator=sparse_line_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
with warnings.catch_warnings(record=True) as w:
rail_env.reset(True, True, True, random_seed=15)
rail_env.reset(True, True, random_seed=15)
assert "[WARNING]" in str(w[-1].message)
......@@ -82,7 +82,10 @@ def test_malfunction_process():
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=SingleAgentNavigationObs()
)
obs, info = env.reset(False, False, True, random_seed=10)
obs, info = env.reset(False, False, random_seed=10)
for a_idx in range(len(env.agents)):
env.agents[a_idx].position = env.agents[a_idx].initial_position
env.agents[a_idx].status = RailAgentStatus.ACTIVE
agent_halts = 0
total_down_time = 0
......@@ -142,7 +145,7 @@ def test_malfunction_process_statistically():
obs_builder_object=SingleAgentNavigationObs()
)
env.reset(True, True, False, random_seed=10)
env.reset(True, True, random_seed=10)
env._max_episode_steps = 1000
env.agents[0].target = (0, 0)
......@@ -181,7 +184,7 @@ def test_malfunction_before_entry():
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=SingleAgentNavigationObs()
)
env.reset(False, False, False, random_seed=10)
env.reset(False, False, random_seed=10)
env.agents[0].target = (0, 0)
# Test initial malfunction values for all agents
......@@ -215,7 +218,7 @@ def test_malfunction_values_and_behavior():
obs_builder_object=SingleAgentNavigationObs()
)
env.reset(False, False, activate_agents=True, random_seed=10)
env.reset(False, False, random_seed=10)
# Assertions
assert_list = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 10, 9, 8, 7, 6, 5]
......@@ -247,7 +250,7 @@ def test_initial_malfunction():
obs_builder_object=SingleAgentNavigationObs()
)
# reset to initialize agents_static
env.reset(False, False, True, random_seed=10)
env.reset(False, False, random_seed=10)
env._max_episode_steps = 1000
print(env.agents[0].malfunction_data)
env.agents[0].target = (0, 5)
......@@ -473,7 +476,7 @@ def tests_random_interference_from_outside():
line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
env.reset()
env.agents[0].speed_data['speed'] = 0.33
env.reset(False, False, False, random_seed=10)
env.reset(False, False, random_seed=10)
env_data = []
for step in range(200):
......@@ -499,7 +502,7 @@ def tests_random_interference_from_outside():
line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
env.reset()
env.agents[0].speed_data['speed'] = 0.33
env.reset(False, False, False, random_seed=10)
env.reset(False, False, random_seed=10)
dummy_list = [1, 2, 6, 7, 8, 9, 4, 5, 4]
for step in range(200):
......@@ -540,7 +543,10 @@ def test_last_malfunction_step():
env._max_episode_steps = 1000
env.reset(False, False, True)
env.reset(False, False)
for a_idx in range(len(env.agents)):
env.agents[a_idx].position = env.agents[a_idx].initial_position
env.agents[a_idx].status = RailAgentStatus.ACTIVE
# Force malfunction to be off at beginning and next malfunction to happen in 2 steps
env.agents[0].malfunction_data['next_malfunction'] = 2
env.agents[0].malfunction_data['malfunction'] = 0
......
......@@ -10,6 +10,7 @@ from flatland.envs.rail_generators import rail_from_grid_transition_map, rail_fr
from flatland.envs.line_generators import sparse_line_generator, line_from_file
from flatland.utils.simple_rail import make_simple_rail
from flatland.envs.persistence import RailEnvPersister
from flatland.envs.agent_utils import RailAgentStatus
def test_empty_rail_generator():
......@@ -30,7 +31,12 @@ def test_rail_from_grid_transition_map():
n_agents = 2
env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(), number_of_agents=n_agents)
env.reset(False, False, True)
env.reset(False, False)
for a_idx in range(len(env.agents)):
env.agents[a_idx].position = env.agents[a_idx].initial_position
env.agents[a_idx].status = RailAgentStatus.ACTIVE
nr_rail_elements = np.count_nonzero(env.rail.grid)
# Check if the number of non-empty rail cells is ok
......
......@@ -60,7 +60,7 @@ def test_multi_speed_init():
# Set all the different speeds
# Reset environment and get initial observations for all agents
env.reset(False, False, True)
env.reset(False, False)
# Here you can also further enhance the provided observation by means of normalization
# See training navigation example in the baseline repository
......@@ -68,6 +68,7 @@ def test_multi_speed_init():
for i_agent in range(env.get_num_agents()):
env.agents[i_agent].speed_data['speed'] = 1. / (i_agent + 2)
old_pos.append(env.agents[i_agent].position)
# Run episode
for step in range(100):
......
......@@ -16,7 +16,7 @@ def ndom_seeding():
for idx in range(100):
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=12), number_of_agents=10)
env.reset(True, True, False, random_seed=1)
env.reset(True, True, random_seed=1)
env.agents[0].target = (0, 0)
for step in range(10):
......@@ -56,8 +56,8 @@ def test_seeding_and_observations():
line_generator=sparse_line_generator(seed=12), number_of_agents=10,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
env.reset(False, False, False, random_seed=12)
env2.reset(False, False, False, random_seed=12)
env.reset(False, False, random_seed=12)
env2.reset(False, False, random_seed=12)
# Check that both environments produce the same initial start positions
assert env.agents[0].initial_position == env2.agents[0].initial_position
assert env.agents[1].initial_position == env2.agents[1].initial_position
......@@ -112,8 +112,8 @@ def test_seeding_and_malfunction():
line_generator=sparse_line_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
env.reset(True, False, True, random_seed=tests)
env2.reset(True, False, True, random_seed=tests)
env.reset(True, False, random_seed=tests)
env2.reset(True, False, random_seed=tests)
# Check that both environments produce the same initial start positions
assert env.agents[0].initial_position == env2.agents[0].initial_position
......@@ -170,7 +170,7 @@ def test_reproducability_env():
grid_mode=True
),
line_generator=sparse_line_generator(speed_ration_map), number_of_agents=1)
env.reset(True, True, True, random_seed=10)
env.reset(True, True, random_seed=10)
excpeted_grid = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
......@@ -213,5 +213,5 @@ def test_reproducability_env():
np.random.seed(10)
for i in range(10):
np.random.randn()
env2.reset(True, True, True, random_seed=10)
env2.reset(True, True, random_seed=10)
assert env2.rail.grid.tolist() == excpeted_grid
......@@ -87,7 +87,11 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
agent.direction = test_config.initial_direction
agent.target = test_config.target
agent.speed_data['speed'] = test_config.speed
env.reset(False, False, activate_agents)
env.reset(False, False)
if activate_agents:
for a_idx in range(len(env.agents)):
env.agents[a_idx].position = env.agents[a_idx].initial_position
env.agents[a_idx].status = RailAgentStatus.ACTIVE
def _assert(a, actual, expected, msg):
print("[{}] verifying {} on agent {}: actual={}, expected={}".format(step, msg, a, actual, expected))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment