diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py index 29b57c40faf567e6a9aa4b679df6af6fcf0909ba..5049c23bff3667a08375fee270a8867ca013c467 100644 --- a/flatland/core/transitions.py +++ b/flatland/core/transitions.py @@ -12,7 +12,7 @@ class Transitions: Generic class that implements checks to control whether a certain transition is allowed (agent facing a direction - `orientation' and moving into direction `direction') + `orientation' and moving into direction `orientation') """ def get_type(self): diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 8ed455cac857d292b1f63f6edfbfd4a68f4adf8e..a4903853bab30d185293892010c714d050f63439 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -28,7 +28,7 @@ class TreeObsForRailEnv(ObservationBuilder): for i in range(self.max_depth + 1): size += pow4 pow4 *= 4 - self.observation_dim = 8 + self.observation_dim = 9 self.observation_space = [size * self.observation_dim] self.location_has_agent = {} self.location_has_agent_direction = {} @@ -223,24 +223,29 @@ class TreeObsForRailEnv(ObservationBuilder): #3: if another agent is detected the distance in number of cells from current agent position is stored. - #4: This feature stores the distance in number of cells to the next branching store (current node) + #4: possible conflict detected + tot_dist = Other agent predicts to pass along this cell at the same time as the agent, we store the + distance in number of cells from current agent position - #5: minimum distance from node to the agent's target given the direction of the agent if this path is chosen + 0 = No other agent reserve the same cell at similar time + + #5: if an not usable switch (for agent) is detected we store the distance. + + #6: This feature stores the distance in number of cells to the next branching (current node) - #6: agent in the same direction + #7: minimum distance from node to the agent's target given the direction of the agent if this path is chosen + + #8: agent in the same direction n = number of agents present same direction (possible future use: number of other agents in the same direction in this branch) 0 = no agent present same direction - #7: agent in the opposite drection + #9: agent in the opposite drection n = number of agents present other direction than myself (so conflict) (possible future use: number of other agents in other direction in this branch, ie. number of conflicts) 0 = no agent present other direction than myself - #8: possible conflict detected - 1 = Other agent predicts to pass along this cell at the same time as the agent - 0 = No other agent reserve the same cell at similar time Missing/padding nodes are filled in with -inf (truncated). @@ -261,7 +266,7 @@ class TreeObsForRailEnv(ObservationBuilder): num_transitions = np.count_nonzero(possible_transitions) # Root node - current position - observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0, 0] + observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0] root_observation = observation[:] visited = set() @@ -278,7 +283,7 @@ class TreeObsForRailEnv(ObservationBuilder): if possible_transitions[branch_direction]: new_cell = self._new_position(agent.position, branch_direction) branch_observation, branch_visited = \ - self._explore_branch(handle, new_cell, branch_direction, root_observation, 0, 1) + self._explore_branch(handle, new_cell, branch_direction, root_observation, 1, 1) observation = observation + branch_observation visited = visited.union(branch_visited) else: @@ -294,6 +299,8 @@ class TreeObsForRailEnv(ObservationBuilder): def _explore_branch(self, handle, position, direction, root_observation, tot_dist, depth): """ Utility function to compute tree-based observations. + We walk along the branch and collect the information documented in the get() function. + If there is a branching point a new node is created and each possible branch is explored. """ # [Recursive branch opened] if depth >= self.max_depth + 1: @@ -313,9 +320,11 @@ class TreeObsForRailEnv(ObservationBuilder): own_target_encountered = np.inf other_agent_encountered = np.inf other_target_encountered = np.inf + potential_conflict = np.inf + unusable_switch = np.inf other_agent_same_direction = 0 other_agent_opposite_direction = 0 - potential_conflict = 0 + num_steps = 1 while exploring: # ############################# @@ -323,8 +332,8 @@ class TreeObsForRailEnv(ObservationBuilder): # Modify here to compute any useful data required to build the end node's features. This code is called # for each cell visited between the previous branching node and the next switch / target / dead-end. if position in self.location_has_agent: - if num_steps < other_agent_encountered: - other_agent_encountered = num_steps + if tot_dist < other_agent_encountered: + other_agent_encountered = tot_dist if self.location_has_agent_direction[position] == direction: # Cummulate the number of agents on branch with same direction @@ -342,31 +351,32 @@ class TreeObsForRailEnv(ObservationBuilder): post_step = min(self.max_prediction_depth - 1, tot_dist + 1) # Look for opposing paths at distance num_step - if int_position in np.delete(self.predicted_pos[tot_dist], handle): - conflicting_agent = np.where(np.delete(self.predicted_pos[tot_dist], handle) == int_position) - for ca in conflicting_agent: - if direction != self.predicted_dir[tot_dist][ca[0]]: - potential_conflict = 1 + if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0): + conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position) + for ca in conflicting_agent[0]: + + if direction != self.predicted_dir[tot_dist][ca] and tot_dist < potential_conflict: + potential_conflict = tot_dist # Look for opposing paths at distance num_step-1 - elif int_position in np.delete(self.predicted_pos[pre_step], handle): + elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0): conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position) - for ca in conflicting_agent: - if direction != self.predicted_dir[pre_step][ca[0]]: - potential_conflict = 1 + for ca in conflicting_agent[0]: + if direction != self.predicted_dir[pre_step][ca] and tot_dist < potential_conflict: + potential_conflict = tot_dist # Look for opposing paths at distance num_step+1 - elif int_position in np.delete(self.predicted_pos[post_step], handle): - conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position) - for ca in conflicting_agent: - if direction != self.predicted_dir[post_step][ca[0]]: - potential_conflict = 1 + elif int_position in np.delete(self.predicted_pos[post_step], handle, 0): + conflicting_agent = np.where(self.predicted_pos[post_step] == int_position) + for ca in conflicting_agent[0]: + if direction != self.predicted_dir[post_step][ca] and tot_dist < potential_conflict: + potential_conflict = tot_dist if position in self.location_has_target and position != agent.target: - if num_steps < other_target_encountered: - other_target_encountered = num_steps + if tot_dist < other_target_encountered: + other_target_encountered = tot_dist if position == agent.target: - if num_steps < own_target_encountered: - own_target_encountered = num_steps + if tot_dist < own_target_encountered: + own_target_encountered = tot_dist # ############################# # ############################# @@ -382,8 +392,13 @@ class TreeObsForRailEnv(ObservationBuilder): break cell_transitions = self.env.rail.get_transitions((*position, direction)) + total_transitions = bin(self.env.rail.get_transitions(position)).count("1") num_transitions = np.count_nonzero(cell_transitions) exploring = False + # Detect Switches that can only be used by other agents. + if total_transitions > 2 > num_transitions: + unusable_switch = tot_dist + if num_transitions == 1: # Check if dead-end, or if we can go forward along direction nbits = 0 @@ -422,72 +437,40 @@ class TreeObsForRailEnv(ObservationBuilder): # ############################# # ############################# # Modify here to append new / different features for each visited cell! - """ - other_agent_same_direction = \ - 1 if position in self.location_has_agent and self.location_has_agent_direction[position] == direction else 0 - other_agent_opposite_direction = \ - 1 if position in self.location_has_agent and self.location_has_agent_direction[position] != direction else 0 if last_isTarget: - observation = [0, + observation = [own_target_encountered, other_target_encountered, other_agent_encountered, - root_observation[3] + num_steps, + potential_conflict, + unusable_switch, + tot_dist, 0, other_agent_same_direction, other_agent_opposite_direction ] elif last_isTerminal: - observation = [0, + observation = [own_target_encountered, other_target_encountered, other_agent_encountered, + potential_conflict, + unusable_switch, np.inf, - np.inf, - other_agent_same_direction, - other_agent_opposite_direction - ] - else: - observation = [0, - other_target_encountered, - other_agent_encountered, - root_observation[3] + num_steps, self.distance_map[handle, position[0], position[1], direction], other_agent_same_direction, other_agent_opposite_direction ] - """ - - if last_isTarget: - observation = [own_target_encountered, - other_target_encountered, - other_agent_encountered, - root_observation[3] + num_steps, - 0, - other_agent_same_direction, - other_agent_opposite_direction, - potential_conflict - ] - - elif last_isTerminal: - observation = [own_target_encountered, - other_target_encountered, - other_agent_encountered, - np.inf, - np.inf, - other_agent_same_direction, - other_agent_opposite_direction, - potential_conflict - ] else: observation = [own_target_encountered, other_target_encountered, other_agent_encountered, - root_observation[3] + num_steps, + potential_conflict, + unusable_switch, + tot_dist, self.distance_map[handle, position[0], position[1], direction], other_agent_same_direction, other_agent_opposite_direction, - potential_conflict ] # ############################# # ############################# @@ -531,7 +514,7 @@ class TreeObsForRailEnv(ObservationBuilder): return observation, visited - def util_print_obs_subtree(self, tree, num_features_per_node=8, prompt='', current_depth=0): + def util_print_obs_subtree(self, tree, num_features_per_node=9, prompt='', current_depth=0): """ Utility function to pretty-print tree observations returned by this object. """ diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 654f5490da875afe746d516892a773308b466589..d471596bba9e6128ea14f78bd8b625d321c255fe 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -142,7 +142,8 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): min_dist = np.inf for direction in range(4): if cell_transitions[direction] == 1: - target_dist = distance_map[agent.handle, agent.position[0], agent.position[1], direction] + neighbour_cell = get_new_position(agent.position, direction) + target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction] if target_dist < min_dist: min_dist = target_dist new_direction = direction @@ -150,21 +151,12 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): else: raise Exception("No transition possible {}".format(cell_transitions)) - # which action to take for the transition? - action = None - for _action in [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_LEFT]: - _, _, _new_direction, _new_position, _ = self.env._check_action_on_agent(_action, agent) - if np.array_equal(_new_position, new_position): - action = _action - break - assert action is not None - # update the agent's position and direction agent.position = new_position agent.direction = new_direction # prediction is ready - prediction[index] = [index, *new_position, new_direction, action] + prediction[index] = [index, *new_position, new_direction, 0] prediction_dict[agent.handle] = prediction # cleanup: reset initial position diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 7932aacafe5c67e188afa9f52a90ec47f73ff8da..537d8be832e706f2ec48a89ec006a8fc96806724 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -38,7 +38,7 @@ class RenderTool(object): gTheta = np.linspace(0, np.pi / 2, 5) gArc = array([np.cos(gTheta), np.sin(gTheta)]).T # from [1,0] to [0,1] - def __init__(self, env, gl="PILSVG", jupyter=False, agentRenderVariant=AgentRenderVariant.AGENT_SHOWS_OPTIONS): + def __init__(self, env, gl="PILSVG", jupyter=False, agentRenderVariant=AgentRenderVariant.ONE_STEP_BEHIND): self.env = env self.iFrame = 0 self.time1 = time.time() diff --git a/notebooks/Scene_Editor.ipynb b/notebooks/Scene_Editor.ipynb index a36b4e159a75edd2d18a5f7ed42ec41dbc5b9515..64bbad77d3a9b63caea062795cd0ab8f0cce72f3 100644 --- a/notebooks/Scene_Editor.ipynb +++ b/notebooks/Scene_Editor.ipynb @@ -70,7 +70,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "4153480c3fbe4fd9864a4d52b416477c", + "model_id": "0b92e7084e37450cbfb3855dc3a58543", "version_major": 2, "version_minor": 0 }, diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index 16850672c1f5a479f1cf86ca3e3f6c547c4a4ca7..c5514bd09b296796c93cb4586c5640eb05569164 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -9,7 +9,6 @@ from flatland.envs.generators import rail_from_GridTransitionMap_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv -from flatland.envs.rail_env import RailEnvActions from flatland.utils.rendertools import RenderTool """Test predictions for `flatland` package.""" @@ -187,7 +186,6 @@ def test_shortest_path_predictor(rendering=False): positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0]))) directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0]))) time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0]))) - actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0]))) expected_positions = [ [5, 6], @@ -260,35 +258,9 @@ def test_shortest_path_predictor(rendering=False): [20.], ]) - expected_actions = np.array([ - [RailEnvActions.DO_NOTHING], # next [5,6] - [RailEnvActions.MOVE_FORWARD], # next [4,6] - [RailEnvActions.MOVE_FORWARD], # next [3,6] - [RailEnvActions.MOVE_RIGHT], # next [3,7] - [RailEnvActions.MOVE_FORWARD], # next [3,8] - [RailEnvActions.MOVE_FORWARD], # next [3,9] - [RailEnvActions.STOP_MOVING], # at [3,9] == target - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - ]) - assert np.array_equal(positions, expected_positions), \ "positions {}, expected {}".format(positions, expected_positions) assert np.array_equal(directions, expected_directions), \ "directions {}, expected {}".format(directions, expected_directions) assert np.array_equal(time_offsets, expected_time_offsets), \ "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets) - assert np.array_equal(actions, expected_actions), \ - "actions {}, expected {}".format(actions, expected_actions)