diff --git a/tests/simple_rail.py b/tests/simple_rail.py index 894864acc8d7435681636f2b0c60da238265194e..fb739f29942d0552b41b602046030c76031e7539 100644 --- a/tests/simple_rail.py +++ b/tests/simple_rail.py @@ -1,10 +1,11 @@ import numpy as np +from typing import Tuple from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.transition_map import GridTransitionMap -def make_simple_rail(): +def make_simple_rail() -> Tuple[GridTransitionMap,np.array]: # We instantiate a very simple rail network on a 7x10 grid: # | # | diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py index 79f4bab164312f757ad584bd3708a7af3fb7a97e..742e841c1a699f849a19c1f27f1b084b440a155a 100644 --- a/tests/test_distance_map.py +++ b/tests/test_distance_map.py @@ -53,4 +53,4 @@ def test_walker(): print(obs_builder.distance_map[(0, *[0, 1], 1)]) assert obs_builder.distance_map[(0, *[0, 1], 1)] == 3 print(obs_builder.distance_map[(0, *[0, 2], 3)]) - assert obs_builder.distance_map[(0, *[0, 2], 1)] == 2 # does not work yet, Erik's proposal. + assert obs_builder.distance_map[(0, *[0, 2], 1)] == 2 diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index fb45e5fa162d1323533a62349922b569a415395f..4eb8d633c5a246c5a755a922cae02484bb1ac696 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -3,9 +3,13 @@ import numpy as np +from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.envs.agent_utils import EnvAgent from flatland.envs.generators import rail_from_GridTransitionMap_generator -from flatland.envs.observations import GlobalObsForRailEnv -from flatland.envs.rail_env import RailEnv +from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.rail_env import RailEnv, RailEnvActions +from flatland.utils.rendertools import RenderTool from tests.simple_rail import make_simple_rail """Tests for `flatland` package.""" @@ -35,3 +39,253 @@ def test_global_obs(): # If this assertion is wrong, it means that the observation returned # places the agent on an empty cell assert (np.sum(rail_map * global_obs[0][1][:, :, :4].sum(2)) > 0) + + +def _step_along_shortest_path(env, obs_builder, rail): + actions = {} + expected_next_position = {} + for agent in env.agents: + agent: EnvAgent + shortest_distance = np.inf + + for exit_direction in range(4): + neighbour = obs_builder._new_position(agent.position, exit_direction) + + if neighbour[0] >= 0 and neighbour[0] < env.height and neighbour[1] >= 0 and neighbour[1] < env.width: + desired_movement_from_new_cell = (exit_direction + 2) % 4 + + # Check all possible transitions in new_cell + for agent_orientation in range(4): + # Is a transition along movement `entry_direction' to the neighbour possible? + is_valid = obs_builder.env.rail.get_transition((neighbour[0], neighbour[1], agent_orientation), + desired_movement_from_new_cell) + if is_valid: + distance_to_target = obs_builder.distance_map[ + (agent.handle, *agent.position, exit_direction)] + print("agent {} at {} facing {} taking {} distance {}".format(agent.handle, agent.position, + agent.direction, + exit_direction, + distance_to_target)) + + if distance_to_target < shortest_distance: + shortest_distance = distance_to_target + actions_to_be_taken_when_facing_north = { + Grid4TransitionsEnum.NORTH: RailEnvActions.MOVE_FORWARD, + Grid4TransitionsEnum.EAST: RailEnvActions.MOVE_RIGHT, + Grid4TransitionsEnum.WEST: RailEnvActions.MOVE_LEFT, + Grid4TransitionsEnum.SOUTH: RailEnvActions.DO_NOTHING, + } + print(" improved (direction) -> {}".format(exit_direction)) + + actions[agent.handle] = actions_to_be_taken_when_facing_north[ + (exit_direction - agent.direction) % len(rail.transitions.get_direction_enum())] + expected_next_position[agent.handle] = neighbour + print(" improved (action) -> {}".format(actions[agent.handle])) + _, rewards, dones, _ = env.step(actions) + return rewards + + +def test_reward_function_conflict(rendering=False): + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_GridTransitionMap_generator(rail), + number_of_agents=2, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + obs_builder: TreeObsForRailEnv = env.obs_builder + # initialize agents_static + env.reset() + + # set the initial position + agent = env.agents_static[0] + agent.position = (5, 6) # south dead-end + agent.direction = 0 # north + agent.target = (3, 9) # east dead-end + agent.moving = True + + agent = env.agents_static[1] + agent.position = (3, 8) # east dead-end + agent.direction = 3 # west + agent.target = (6, 6) # south dead-end + agent.moving = True + + # reset to set agents from agents_static + env.reset(False, False) + + if rendering: + renderer = RenderTool(env, gl="PILSVG") + renderer.renderEnv(show=True, show_observations=True) + + iteration = -1 + expected_positions = { + 0: { + 0: (5, 6), + 1: (3, 8) + }, + # both can move + 1: { + 0: (4, 6), + 1: (3, 7) + }, + # first can move, second stuck + 2: { + 0: (3, 6), + 1: (3, 7) + }, + # both stuck from now on + 3: { + 0: (3, 6), + 1: (3, 7) + }, + 4: { + 0: (3, 6), + 1: (3, 7) + }, + 5: { + 0: (3, 6), + 1: (3, 7) + }, + } + while not env.dones["__all__"] and iteration + 1 < 5: + iteration += 1 + rewards = _step_along_shortest_path(env, obs_builder, rail) + + for agent in env.agents: + assert rewards[agent.handle] == -1 + expected_position = expected_positions[iteration + 1][agent.handle] + assert agent.position == expected_position, "[{}] agent {} at {}, expected {}".format(iteration + 1, + agent.handle, + agent.position, + expected_position) + if rendering: + renderer.renderEnv(show=True, show_observations=True) + + +def test_reward_function_waiting(rendering=False): + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_GridTransitionMap_generator(rail), + number_of_agents=2, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + obs_builder: TreeObsForRailEnv = env.obs_builder + # initialize agents_static + env.reset() + + # set the initial position + agent = env.agents_static[0] + agent.position = (3, 8) # east dead-end + agent.direction = 3 # west + agent.target = (3, 1) # west dead-end + agent.moving = True + + agent = env.agents_static[1] + agent.position = (5, 6) # south dead-end + agent.direction = 0 # north + agent.target = (3, 8) # east dead-end + agent.moving = True + + # reset to set agents from agents_static + env.reset(False, False) + + if rendering: + renderer = RenderTool(env, gl="PILSVG") + renderer.renderEnv(show=True, show_observations=True) + + iteration = -1 + expectations = { + 0: { + 'positions': { + 0: (3, 8), + 1: (5, 6), + }, + 'rewards': [-1, -1], + }, + 1: { + 'positions': { + 0: (3, 7), + 1: (4, 6), + }, + 'rewards': [-1, -1], + }, + # second agent has to wait for first, first can continue + 2: { + 'positions': { + 0: (3, 6), + 1: (4, 6), + }, + 'rewards': [-1, -1], + }, + # both can move again + 3: { + 'positions': { + 0: (3, 5), + 1: (3, 6), + }, + 'rewards': [-1, -1], + }, + 4: { + 'positions': { + 0: (3, 4), + 1: (3, 7), + }, + 'rewards': [-1, -1], + }, + # second reached target + 5: { + 'positions': { + 0: (3, 3), + 1: (3, 8), + }, + 'rewards': [-1, 0], + }, + 6: { + 'positions': { + 0: (3, 2), + 1: (3, 8), + }, + 'rewards': [-1, 0], + }, + # first reaches, target too + 7: { + 'positions': { + 0: (3, 1), + 1: (5, 6), + }, + 'rewards': [1, 1], + }, + 8: { + 'positions': { + 0: (3, 1), + 1: (5, 6), + }, + 'rewards': [1, 1], + }, + } + while not env.dones["__all__"] and iteration + 1 < 5: + iteration += 1 + rewards = _step_along_shortest_path(env, obs_builder, rail) + + if rendering: + renderer.renderEnv(show=True, show_observations=True) + + print(env.dones["__all__"]) + for agent in env.agents: + agent: EnvAgent + print("[{}] agent {} at {}, target {} ".format(iteration + 1, agent.handle, agent.position, agent.target)) + print(np.all([np.array_equal(agent2.position, agent2.target) for agent2 in env.agents])) + for agent in env.agents: + expected_position = expectations[iteration + 1]['positions'][agent.handle] + assert agent.position == expected_position, \ + "[{}] agent {} at {}, expected {}".format(iteration + 1, + agent.handle, + agent.position, + expected_position) + expected_reward = expectations[iteration + 1]['rewards'][agent.handle] + actual_reward = rewards[agent.handle] + assert expected_reward == actual_reward, "[{}] agent {} reward {}, expected {}".format(iteration + 1, + agent.handle, + actual_reward, + expected_reward)