From 4f517e569551cca0c1660be2179b3f6a73bc7751 Mon Sep 17 00:00:00 2001 From: Guillaume Mollard <guillaume.mollard2@gmail.com> Date: Tue, 23 Apr 2019 10:10:43 +0200 Subject: [PATCH] solved flake8 bugs --- examples/temporary_example.py | 4 - flatland/core/env_observation_builder.py | 111 ++++++++++------------- tests/test_env_observation_builder.py | 2 - tests/test_environments.py | 8 -- 4 files changed, 50 insertions(+), 75 deletions(-) diff --git a/examples/temporary_example.py b/examples/temporary_example.py index 2ea68cf..2444d38 100644 --- a/examples/temporary_example.py +++ b/examples/temporary_example.py @@ -1,7 +1,3 @@ -import random -import numpy as np -import matplotlib.pyplot as plt - from flatland.core.env import RailEnv from flatland.utils.rail_env_generator import * from flatland.utils.rendertools import * diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index 6e5dbbb..1f97ff2 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -29,7 +29,6 @@ class TreeObsForRailEnv(ObservationBuilder): for i in range(self.env.number_of_agents): self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i) - def _distance_map_walker(self, position, target_nr): # Returns max distance to target, from the farthest away node, while filling in distance_map @@ -55,9 +54,6 @@ class TreeObsForRailEnv(ObservationBuilder): node_id = (node[0], node[1], node[2]) - #print(node_id, visited, (node_id in visited)) - #print(nodes_queue) - if node_id not in visited: visited.add(node_id) @@ -70,20 +66,18 @@ class TreeObsForRailEnv(ObservationBuilder): for n in valid_neighbors: nodes_queue.append(n) - if len(valid_neighbors)>0: + if len(valid_neighbors) > 0: max_distance = max(max_distance, node[3]+1) return max_distance - def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1): neighbors = [] for direction in range(4): - new_cell = self._new_position(position, (direction+2)%4) + new_cell = self._new_position(position, (direction+2) % 4) - if new_cell[0]>=0 and new_cell[0]<self.env.height and\ - new_cell[1]>=0 and new_cell[1]<self.env.width: + if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width: # Check if the two cells are connected by a valid transition transitionValid = False for orientation in range(4): @@ -99,7 +93,7 @@ class TreeObsForRailEnv(ObservationBuilder): # lands in the current cell with orientation `direction'; this only # applies to cells that are not dead-ends! directionMatch = True - if enforce_target_direction>=0: + if enforce_target_direction >= 0: directionMatch = self.env.rail.get_transition( (new_cell[0], new_cell[1], direction), enforce_target_direction) @@ -119,7 +113,7 @@ class TreeObsForRailEnv(ObservationBuilder): # Check if transition is possible in new_cell # with orientation (direction+2)%4 in direction `direction' directionMatch = directionMatch or self.env.rail.get_transition( - (new_cell[0], new_cell[1], (direction+2)%4), direction) + (new_cell[0], new_cell[1], (direction+2) % 4), direction) if transitionValid and directionMatch: new_distance = min(self.distance_map[target_nr, @@ -139,7 +133,6 @@ class TreeObsForRailEnv(ObservationBuilder): elif movement == 3: # WEST return (position[0], position[1] - 1) - def get(self, handle): # TODO: compute the observation for agent `handle' return [] @@ -193,7 +186,52 @@ class GlobalObsForRailEnv(ObservationBuilder): return self.rail_obs, obs_agents_targets_pos, direction +class Tree_State: + """ + Keep track of the current state while building the tree + """ + def __init__(self, agent, position, direction, depth, distance): + self.agent = agent + self.position = position + self.direction = direction + self.depth = depth + self.initial_direction = None + self.distance = distance + self.data = [np.inf, np.inf, np.inf, np.inf, np.inf] + + +class Node(): + """ + Define a tree node to get populated during search + """ + def __init__(self, position, data): + self.n_children = 4 + self.children = [None]*self.n_children + self.data = data + self.position = position + + def insert(self, position, data, child_idx): + """ + Insert new node with data + + @param data node data object to insert + """ + new_node = Node(position, data) + self.children[child_idx] = new_node + return new_node + def print_tree(self, i=0, depth=0): + """ + Print tree content inorder + """ + current_i = i + curr_depth = depth+1 + if i < self.n_children: + if self.children[i] is not None: + self.children[i].print_tree(depth=curr_depth) + current_i += 1 + if self.children[i] is not None: + self.children[i].print_tree(i, depth=curr_depth) """ @@ -339,52 +377,3 @@ class GlobalObsForRailEnv(ObservationBuilder): """ - - -class Tree_State: - """ - Keep track of the current state while building the tree - """ - def __init__(self, agent, position, direction, depth, distance): - self.agent = agent - self.position = position - self.direction = direction - self.depth = depth - self.initial_direction = None - self.distance = distance - self.data = [np.inf, np.inf, np.inf, np.inf, np.inf] - -class Node(): - """ - Define a tree node to get populated during search - """ - def __init__(self, position, data): - self.n_children = 4 - self.children = [None]*self.n_children - self.data = data - self.position = position - - def insert(self, position, data, child_idx): - """ - Insert new node with data - - @param data node data object to insert - """ - new_node = Node(position, data) - self.children[child_idx] = new_node - return new_node - - def print_tree(self, i=0, depth = 0): - """ - Print tree content inorder - """ - current_i = i - curr_depth = depth+1 - if i < self.n_children: - if self.children[i] != None: - self.children[i].print_tree(depth=curr_depth) - current_i += 1 - if self.children[i] != None: - self.children[i].print_tree(i, depth=curr_depth) - - diff --git a/tests/test_env_observation_builder.py b/tests/test_env_observation_builder.py index a89df6c..1a797de 100644 --- a/tests/test_env_observation_builder.py +++ b/tests/test_env_observation_builder.py @@ -2,10 +2,8 @@ # -*- coding: utf-8 -*- from flatland.core.env_observation_builder import GlobalObsForRailEnv -# from flatland.core.transitions import Grid4Transitions from flatland.core.transition_map import GridTransitionMap, Grid4Transitions from flatland.core.env import RailEnv -import numpy as np from flatland.utils.rendertools import * """Tests for `flatland` package.""" diff --git a/tests/test_environments.py b/tests/test_environments.py index 03544b0..a89133d 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -177,11 +177,3 @@ def test_dead_end(): rail_env.agents_position[0] = [2, 0] rail_env.agents_direction[0] = 0 check_consistency(rail_env) - - - - - - -test_dead_end() - -- GitLab