From eb66349c041f2245341dab67467aa9ccdfeb0aa0 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Sat, 5 Oct 2019 09:07:04 -0400 Subject: [PATCH] code cleanup --- flatland/envs/distance_map.py | 1 - flatland/envs/grid4_generators_utils.py | 5 +++-- flatland/envs/observations.py | 3 +-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/flatland/envs/distance_map.py b/flatland/envs/distance_map.py index 22721407..2bc1a511 100644 --- a/flatland/envs/distance_map.py +++ b/flatland/envs/distance_map.py @@ -31,7 +31,6 @@ class DistanceMap: if self.reset_was_called: self.reset_was_called = False - nb_agents = len(self.agents) compute_distance_map = True # Don't compute the distance map if it was loaded if self.agents_previous_computation is None and self.distance_map is not None: diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index 0b866e33..04903d3a 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -19,7 +19,8 @@ def connect_rail_in_grid_map(grid_map: GridTransitionMap, start: IntVector2D, en rail_trans: RailEnvTransitions, a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, flip_start_node_trans: bool = False, flip_end_node_trans: bool = False, - respect_transition_validity: bool = True, forbidden_cells: IntVector2DArray = None) -> IntVector2DArray: + respect_transition_validity: bool = True, + forbidden_cells: IntVector2DArray = None) -> IntVector2DArray: """ Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and returns the path created as a list of positions. @@ -109,7 +110,7 @@ def connect_straight_line_in_grid_map(grid_map: GridTransitionMap, start: IntVec length = np.abs(end[0] - start[0]) + 1 cols = np.repeat(start[1], length) - else: # Grid4TransitionsEnum.EAST or Grid4TransitionsEnum.WEST + else: # Grid4TransitionsEnum.EAST or Grid4TransitionsEnum.WEST start_col = min(start[1], end[1]) end_col = max(start[1], end[1]) + 1 cols = np.arange(start_col, end_col) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index caa3329a..8cf4d875 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -67,7 +67,7 @@ class TreeObsForRailEnv(ObservationBuilder): self.predicted_dir = {} self.predictions = self.predictor.get() if self.predictions: - for t in range(self.predictor.max_depth+1): + for t in range(self.predictor.max_depth + 1): pos_list = [] dir_list = [] for a in handles: @@ -162,7 +162,6 @@ class TreeObsForRailEnv(ObservationBuilder): In case the target node is reached, the values are [0, 0, 0, 0, 0]. """ - # Update local lookup table for all agents' positions # ignore other agents not in the grid (only status active and done) self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents if -- GitLab