From 96f670b38e8d8b4ce039960755102e7e6e26f82c Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Thu, 4 Jul 2019 16:29:01 -0400 Subject: [PATCH] Major update to shortest path predictor major update and bugfix for collision detection --- flatland/core/transitions.py | 2 +- flatland/envs/observations.py | 60 ++++++------------------- flatland/envs/predictions.py | 14 ++---- tests/test_flatland_envs_predictions.py | 28 ------------ 4 files changed, 17 insertions(+), 87 deletions(-) diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py index 29b57c4..5049c23 100644 --- a/flatland/core/transitions.py +++ b/flatland/core/transitions.py @@ -12,7 +12,7 @@ class Transitions: Generic class that implements checks to control whether a certain transition is allowed (agent facing a direction - `orientation' and moving into direction `direction') + `orientation' and moving into direction `orientation') """ def get_type(self): diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 8222115..a490385 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -283,7 +283,7 @@ class TreeObsForRailEnv(ObservationBuilder): if possible_transitions[branch_direction]: new_cell = self._new_position(agent.position, branch_direction) branch_observation, branch_visited = \ - self._explore_branch(handle, new_cell, branch_direction, root_observation, 0, 1) + self._explore_branch(handle, new_cell, branch_direction, root_observation, 1, 1) observation = observation + branch_observation visited = visited.union(branch_visited) else: @@ -351,22 +351,23 @@ class TreeObsForRailEnv(ObservationBuilder): post_step = min(self.max_prediction_depth - 1, tot_dist + 1) # Look for opposing paths at distance num_step - if int_position in np.delete(self.predicted_pos[tot_dist], handle): - conflicting_agent = np.where(np.delete(self.predicted_pos[tot_dist], handle) == int_position) - for ca in conflicting_agent: - if direction != self.predicted_dir[tot_dist][ca[0]] and tot_dist < potential_conflict: + if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0): + conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position) + for ca in conflicting_agent[0]: + + if direction != self.predicted_dir[tot_dist][ca] and tot_dist < potential_conflict: potential_conflict = tot_dist # Look for opposing paths at distance num_step-1 - elif int_position in np.delete(self.predicted_pos[pre_step], handle): + elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0): conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position) - for ca in conflicting_agent: - if direction != self.predicted_dir[pre_step][ca[0]] and tot_dist < potential_conflict: + for ca in conflicting_agent[0]: + if direction != self.predicted_dir[pre_step][ca] and tot_dist < potential_conflict: potential_conflict = tot_dist # Look for opposing paths at distance num_step+1 - elif int_position in np.delete(self.predicted_pos[post_step], handle): - conflicting_agent = np.where(np.delete(self.predicted_pos[post_step], handle) == int_position) - for ca in conflicting_agent: - if direction != self.predicted_dir[post_step][ca[0]] and tot_dist < potential_conflict: + elif int_position in np.delete(self.predicted_pos[post_step], handle, 0): + conflicting_agent = np.where(self.predicted_pos[post_step] == int_position) + for ca in conflicting_agent[0]: + if direction != self.predicted_dir[post_step][ca] and tot_dist < potential_conflict: potential_conflict = tot_dist if position in self.location_has_target and position != agent.target: @@ -436,41 +437,6 @@ class TreeObsForRailEnv(ObservationBuilder): # ############################# # ############################# # Modify here to append new / different features for each visited cell! - """ - other_agent_same_direction = \ - 1 if position in self.location_has_agent and self.location_has_agent_direction[position] == direction else 0 - other_agent_opposite_direction = \ - 1 if position in self.location_has_agent and self.location_has_agent_direction[position] != direction else 0 - - if last_isTarget: - observation = [0, - other_target_encountered, - other_agent_encountered, - root_observation[3] + num_steps, - 0, - other_agent_same_direction, - other_agent_opposite_direction - ] - - elif last_isTerminal: - observation = [0, - other_target_encountered, - other_agent_encountered, - np.inf, - np.inf, - other_agent_same_direction, - other_agent_opposite_direction - ] - else: - observation = [0, - other_target_encountered, - other_agent_encountered, - root_observation[3] + num_steps, - self.distance_map[handle, position[0], position[1], direction], - other_agent_same_direction, - other_agent_opposite_direction - ] - """ if last_isTarget: observation = [own_target_encountered, diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 654f549..d471596 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -142,7 +142,8 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): min_dist = np.inf for direction in range(4): if cell_transitions[direction] == 1: - target_dist = distance_map[agent.handle, agent.position[0], agent.position[1], direction] + neighbour_cell = get_new_position(agent.position, direction) + target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction] if target_dist < min_dist: min_dist = target_dist new_direction = direction @@ -150,21 +151,12 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): else: raise Exception("No transition possible {}".format(cell_transitions)) - # which action to take for the transition? - action = None - for _action in [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_LEFT]: - _, _, _new_direction, _new_position, _ = self.env._check_action_on_agent(_action, agent) - if np.array_equal(_new_position, new_position): - action = _action - break - assert action is not None - # update the agent's position and direction agent.position = new_position agent.direction = new_direction # prediction is ready - prediction[index] = [index, *new_position, new_direction, action] + prediction[index] = [index, *new_position, new_direction, 0] prediction_dict[agent.handle] = prediction # cleanup: reset initial position diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index 1685067..c5514bd 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -9,7 +9,6 @@ from flatland.envs.generators import rail_from_GridTransitionMap_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv -from flatland.envs.rail_env import RailEnvActions from flatland.utils.rendertools import RenderTool """Test predictions for `flatland` package.""" @@ -187,7 +186,6 @@ def test_shortest_path_predictor(rendering=False): positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0]))) directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0]))) time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0]))) - actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0]))) expected_positions = [ [5, 6], @@ -260,35 +258,9 @@ def test_shortest_path_predictor(rendering=False): [20.], ]) - expected_actions = np.array([ - [RailEnvActions.DO_NOTHING], # next [5,6] - [RailEnvActions.MOVE_FORWARD], # next [4,6] - [RailEnvActions.MOVE_FORWARD], # next [3,6] - [RailEnvActions.MOVE_RIGHT], # next [3,7] - [RailEnvActions.MOVE_FORWARD], # next [3,8] - [RailEnvActions.MOVE_FORWARD], # next [3,9] - [RailEnvActions.STOP_MOVING], # at [3,9] == target - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - [RailEnvActions.STOP_MOVING], - ]) - assert np.array_equal(positions, expected_positions), \ "positions {}, expected {}".format(positions, expected_positions) assert np.array_equal(directions, expected_directions), \ "directions {}, expected {}".format(directions, expected_directions) assert np.array_equal(time_offsets, expected_time_offsets), \ "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets) - assert np.array_equal(actions, expected_actions), \ - "actions {}, expected {}".format(actions, expected_actions) -- GitLab