Skip to content
Snippets Groups Projects
Commit 4f517e56 authored by gmollard's avatar gmollard
Browse files

solved flake8 bugs

parent c07e281f
No related branches found
No related tags found
No related merge requests found
import random
import numpy as np
import matplotlib.pyplot as plt
from flatland.core.env import RailEnv
from flatland.utils.rail_env_generator import *
from flatland.utils.rendertools import *
......
......@@ -29,7 +29,6 @@ class TreeObsForRailEnv(ObservationBuilder):
for i in range(self.env.number_of_agents):
self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i)
def _distance_map_walker(self, position, target_nr):
# Returns max distance to target, from the farthest away node, while filling in distance_map
......@@ -55,9 +54,6 @@ class TreeObsForRailEnv(ObservationBuilder):
node_id = (node[0], node[1], node[2])
#print(node_id, visited, (node_id in visited))
#print(nodes_queue)
if node_id not in visited:
visited.add(node_id)
......@@ -70,20 +66,18 @@ class TreeObsForRailEnv(ObservationBuilder):
for n in valid_neighbors:
nodes_queue.append(n)
if len(valid_neighbors)>0:
if len(valid_neighbors) > 0:
max_distance = max(max_distance, node[3]+1)
return max_distance
def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
neighbors = []
for direction in range(4):
new_cell = self._new_position(position, (direction+2)%4)
new_cell = self._new_position(position, (direction+2) % 4)
if new_cell[0]>=0 and new_cell[0]<self.env.height and\
new_cell[1]>=0 and new_cell[1]<self.env.width:
if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width:
# Check if the two cells are connected by a valid transition
transitionValid = False
for orientation in range(4):
......@@ -99,7 +93,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# lands in the current cell with orientation `direction'; this only
# applies to cells that are not dead-ends!
directionMatch = True
if enforce_target_direction>=0:
if enforce_target_direction >= 0:
directionMatch = self.env.rail.get_transition(
(new_cell[0], new_cell[1], direction), enforce_target_direction)
......@@ -119,7 +113,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# Check if transition is possible in new_cell
# with orientation (direction+2)%4 in direction `direction'
directionMatch = directionMatch or self.env.rail.get_transition(
(new_cell[0], new_cell[1], (direction+2)%4), direction)
(new_cell[0], new_cell[1], (direction+2) % 4), direction)
if transitionValid and directionMatch:
new_distance = min(self.distance_map[target_nr,
......@@ -139,7 +133,6 @@ class TreeObsForRailEnv(ObservationBuilder):
elif movement == 3: # WEST
return (position[0], position[1] - 1)
def get(self, handle):
# TODO: compute the observation for agent `handle'
return []
......@@ -193,7 +186,52 @@ class GlobalObsForRailEnv(ObservationBuilder):
return self.rail_obs, obs_agents_targets_pos, direction
class Tree_State:
"""
Keep track of the current state while building the tree
"""
def __init__(self, agent, position, direction, depth, distance):
self.agent = agent
self.position = position
self.direction = direction
self.depth = depth
self.initial_direction = None
self.distance = distance
self.data = [np.inf, np.inf, np.inf, np.inf, np.inf]
class Node():
"""
Define a tree node to get populated during search
"""
def __init__(self, position, data):
self.n_children = 4
self.children = [None]*self.n_children
self.data = data
self.position = position
def insert(self, position, data, child_idx):
"""
Insert new node with data
@param data node data object to insert
"""
new_node = Node(position, data)
self.children[child_idx] = new_node
return new_node
def print_tree(self, i=0, depth=0):
"""
Print tree content inorder
"""
current_i = i
curr_depth = depth+1
if i < self.n_children:
if self.children[i] is not None:
self.children[i].print_tree(depth=curr_depth)
current_i += 1
if self.children[i] is not None:
self.children[i].print_tree(i, depth=curr_depth)
"""
......@@ -339,52 +377,3 @@ class GlobalObsForRailEnv(ObservationBuilder):
"""
class Tree_State:
"""
Keep track of the current state while building the tree
"""
def __init__(self, agent, position, direction, depth, distance):
self.agent = agent
self.position = position
self.direction = direction
self.depth = depth
self.initial_direction = None
self.distance = distance
self.data = [np.inf, np.inf, np.inf, np.inf, np.inf]
class Node():
"""
Define a tree node to get populated during search
"""
def __init__(self, position, data):
self.n_children = 4
self.children = [None]*self.n_children
self.data = data
self.position = position
def insert(self, position, data, child_idx):
"""
Insert new node with data
@param data node data object to insert
"""
new_node = Node(position, data)
self.children[child_idx] = new_node
return new_node
def print_tree(self, i=0, depth = 0):
"""
Print tree content inorder
"""
current_i = i
curr_depth = depth+1
if i < self.n_children:
if self.children[i] != None:
self.children[i].print_tree(depth=curr_depth)
current_i += 1
if self.children[i] != None:
self.children[i].print_tree(i, depth=curr_depth)
......@@ -2,10 +2,8 @@
# -*- coding: utf-8 -*-
from flatland.core.env_observation_builder import GlobalObsForRailEnv
# from flatland.core.transitions import Grid4Transitions
from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
from flatland.core.env import RailEnv
import numpy as np
from flatland.utils.rendertools import *
"""Tests for `flatland` package."""
......
......@@ -177,11 +177,3 @@ def test_dead_end():
rail_env.agents_position[0] = [2, 0]
rail_env.agents_direction[0] = 0
check_consistency(rail_env)
test_dead_end()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment