From 07cf76ac951a33e0ce7c3f7e65b1ad6bdf4419b3 Mon Sep 17 00:00:00 2001 From: Mattias Ljungstrom <mattias.ljungstrom@gmail.com> Date: Sun, 28 Apr 2019 17:28:04 +0200 Subject: [PATCH] level gen, A* validates rail transitions --- examples/play_model.py | 6 +- flatland/core/transitions.py | 16 +++- flatland/envs/rail_env.py | 156 +++++++++++++++++++++++------------ tests/test_transitions.py | 50 +++++++++++ 4 files changed, 172 insertions(+), 56 deletions(-) diff --git a/examples/play_model.py b/examples/play_model.py index 46a3fb5..c51c897 100644 --- a/examples/play_model.py +++ b/examples/play_model.py @@ -6,7 +6,7 @@ from collections import deque import torch import random import numpy as np -import matplotlib.pyplot as plt +#import matplotlib.pyplot as plt import time @@ -28,12 +28,12 @@ def main(render=True, delay=0.0): # Example generate a random rail env = RailEnv(width=15, height=15, - rail_generator=complex_rail_generator(), + rail_generator=complex_rail_generator(nr_start_goal=20), number_of_agents=1) if render: env_renderer = RenderTool(env, gl="QT") - plt.figure(figsize=(5,5)) + # plt.figure(figsize=(5,5)) # fRedis = redis.Redis() handle = env.get_agent_handles() diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py index a0becdd..ec9586e 100644 --- a/flatland/core/transitions.py +++ b/flatland/core/transitions.py @@ -531,13 +531,22 @@ class RailEnvTransitions(Grid4Transitions): int('1001011000100001', 2), # Case 4 - single slip int('1100110000110011', 2), # Case 5 - double slip int('0101001000000010', 2), # Case 6 - symmetrical - int('0010000000000000', 2)] # Case 7 - dead end + int('0010000000000000', 2), # Case 7 - dead end + int('0100000000000010', 2), # Case 1b - simple turn right + int('0001001000000000', 2)] # Case 1c - simple turn left def __init__(self): super(RailEnvTransitions, self).__init__( transitions=self.transition_list ) + def print(self, cell_transition): + print(" NESW") + print("N", format(cell_transition>>(3*4) & 0xF, '04b')) + print("E", format(cell_transition>>(2*4) & 0xF, '04b')) + print("S", format(cell_transition>>(1*4) & 0xF, '04b')) + print("W", format(cell_transition>>(0*4) & 0xF, '04b')) + def is_valid(self, cell_transition): """ Checks if a cell transition is a valid cell setup. @@ -552,11 +561,16 @@ class RailEnvTransitions(Grid4Transitions): Boolean True or False """ + # i = 0 for trans in self.transitions: + # print(">", i) + # i += 1 + # self.print(trans) if cell_transition == trans: return True for _ in range(3): trans = self.rotate_transition(trans, rotation=90) + # self.print(trans) if cell_transition == trans: return True return False diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index aeb3621..86582f5 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -34,7 +34,61 @@ class AStarNode(): self.f = other.f -def a_star(rail_array, start, end): +def get_direction(pos1, pos2): + """ + Assumes pos1 and pos2 are adjacent location on grid. + Returns direction (int) that can be used with transitions. + """ + diff_0 = pos2[0] - pos1[0] + diff_1 = pos2[1] - pos1[1] + if diff_0 < 0: + return 0 + if diff_0 > 0: + return 2 + if diff_1 > 0: + return 1 + if diff_1 < 0: + return 3 + return 0 + + +def mirror(dir): + return (dir + 2) % 4 + + +def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_pos, end_pos): + # start by getting direction used to get to current node + # and direction from current node to possible child node + new_dir = get_direction(current_pos, new_pos) + if prev_pos is not None: + current_dir = get_direction(prev_pos, current_pos) + else: + current_dir = new_dir + # create new transition that would go to child + new_trans = rail_array[current_pos] + if prev_pos is None: + # need to flip direction because of how end points are defined + new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) + else: + # set the forward path + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + # set the backwards path + new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + if new_pos == end_pos: + # need to validate end pos setup as well + new_trans_e = rail_array[end_pos] + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + # print("========> end trans") + # rail_trans.print(new_trans_e) + if not rail_trans.is_valid(new_trans_e): + return False + # is transition is valid? + # print("=======> trans") + # rail_trans.print(new_trans) + return rail_trans.is_valid(new_trans) + + +def a_star(rail_trans, rail_array, start, end): """ Returns a list of tuples as a path from the given start to end. If no path is found, returns path to closest point to end. @@ -83,6 +137,10 @@ def a_star(rail_array, start, end): # generate children children = [] + if current_node.parent is not None: + prev_pos = current_node.parent.pos + else: + prev_pos = None for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]: node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1]) if node_pos[0] >= rail_shape[0] or \ @@ -96,6 +154,11 @@ def a_star(rail_array, start, end): # if rail_array.item(node_pos) != 0: # continue + # validate positions + if not validate_new_transition(rail_trans, rail_array, prev_pos, current_node.pos, node_pos, end_node.pos): + # print("A*: transition invalid") + continue + # create new node new_node = AStarNode(current_node, node_pos) children.append(new_node) @@ -136,7 +199,43 @@ def a_star(rail_array, start, end): return path[::-1] -def complex_rail_generator(nr_start_goal=10, min_dist=0, max_dist=99999, seed=0): +def connect_rail(rail_trans, rail_array, start, end): + """ + Creates a new path [start,end] in rail_array, based on rail_trans. + """ + # in the worst case we will need to do a A* search, so we might as well set that up + path = a_star(rail_trans, rail_array, start, end) + # print("connecting path", path) + if len(path) < 2: + return + current_dir = get_direction(path[0], path[1]) + end_pos = path[-1] + for index in range(len(path) - 1): + current_pos = path[index] + new_pos = path[index+1] + new_dir = get_direction(current_pos, new_pos) + + new_trans = rail_array[current_pos] + if index == 0: + # need to flip direction because of how end points are defined + new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) + else: + # set the forward path + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + # set the backwards path + new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + rail_array[current_pos] = new_trans + + if new_pos == end_pos: + # need to validate end pos setup as well + new_trans_e = rail_array[end_pos] + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + rail_array[end_pos] = new_trans_e + + current_dir = new_dir + + +def complex_rail_generator(nr_start_goal=1, min_dist=0, max_dist=99999, seed=0): """ Parameters ------- @@ -155,7 +254,7 @@ def complex_rail_generator(nr_start_goal=10, min_dist=0, max_dist=99999, seed=0) rail_trans = RailEnvTransitions() rail_array = np.zeros(shape=(width, height), dtype=np.uint16) - np.random.seed(seed) + np.random.seed(seed + num_resets) # generate rail array # step 1: @@ -204,55 +303,8 @@ def complex_rail_generator(nr_start_goal=10, min_dist=0, max_dist=99999, seed=0) # TODO: make sure min/max distance condition is met start_goal.append([start, goal]) - def get_direction(pos1, pos2): - diff_0 = pos2[0] - pos1[0] - diff_1 = pos2[1] - pos1[1] - if diff_0 < 0: - return 0 - if diff_0 > 0: - return 2 - if diff_1 > 0: - return 1 - if diff_1 < 0: - return 3 - return 0 - - def connect_two_cells(pos1, pos2): - # connect two adjacent cells - direction = get_direction(pos1, pos2) - rail_array[pos1] = rail_trans.set_transition(rail_array[pos1], direction, direction, 1) - o_dir = (direction + 2) % 4 - rail_array[pos2] = rail_trans.set_transition(rail_array[pos2], o_dir, o_dir, 1) - - def connect_rail(start, end): - # in the worst case we will need to do a A* search, so we might as well set that up - # TODO: need to check transitions in A* to see if new path is valid - path = a_star(rail_array, start, end) - print("connecting path", path) - if len(path) < 2: - return - if len(path) == 2: - connect_two_cells(path[0], path[1]) - return - current_dir = get_direction(path[0], path[1]) - for index in range(len(path)): - pos1 = path[index] - if index+1 < len(path): - new_dir = get_direction(pos1, path[index+1]) - else: - new_dir = current_dir - cell_trans = rail_array[pos1] - if index != len(path)-1: - # set the forward path - cell_trans = rail_trans.set_transition(cell_trans, current_dir, new_dir, 1) - if index != 0: - # set the backwards path - cell_trans = rail_trans.set_transition(cell_trans, (new_dir+2) % 4, (current_dir+2) % 4, 1) - rail_array[pos1] = cell_trans - current_dir = new_dir - for sg in start_goal: - connect_rail(sg[0], sg[1]) + connect_rail(rail_trans, rail_array, sg[0], sg[1]) return_rail = GridTransitionMap(width=width, height=height, transitions=rail_trans) return_rail.grid = rail_array @@ -396,7 +448,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): transitions_templates_ = [] transition_probabilities = [] - for i in range(len(t_utils.transitions) - 1): # don't include dead-ends + for i in range(len(t_utils.transitions) - 3): # don't include dead-ends all_transitions = 0 for dir_ in range(4): trans = t_utils.get_transitions(t_utils.transitions[i], dir_) diff --git a/tests/test_transitions.py b/tests/test_transitions.py index 0f56e88..1d6ea96 100644 --- a/tests/test_transitions.py +++ b/tests/test_transitions.py @@ -3,6 +3,56 @@ """Tests for `flatland` package.""" from flatland.core.transitions import RailEnvTransitions, Grid8Transitions +from flatland.envs.rail_env import validate_new_transition +import numpy as np + + +def test_is_valid_railenv_transitions(): + rail_env_trans = RailEnvTransitions() + transition_list = rail_env_trans.transitions + + for t in transition_list: + assert(rail_env_trans.is_valid(t) == True) + for i in range(3): + rot_trans = rail_env_trans.rotate_transition(t, 90 * i) + assert(rail_env_trans.is_valid(rot_trans) == True) + + assert(rail_env_trans.is_valid(int('1111111111110010', 2)) == False) + assert(rail_env_trans.is_valid(int('1001111111110010', 2)) == False) + assert(rail_env_trans.is_valid(int('1001111001110110', 2)) == False) + + +def test_adding_new_valid_transition(): + rail_trans = RailEnvTransitions() + rail_array = np.zeros(shape=(15, 15), dtype=np.uint16) + + # adding straight + assert(validate_new_transition(rail_trans, rail_array, (4,5), (5,5), (6,5), (10,10)) == True) + + # adding valid right turn + assert(validate_new_transition(rail_trans, rail_array, (5,4), (5,5), (5,6), (10,10)) == True) + # adding valid left turn + assert(validate_new_transition(rail_trans, rail_array, (5,6), (5,5), (5,6), (10,10)) == True) + + # adding invalid turn + rail_array[(5,5)] = rail_trans.transitions[2] + assert(validate_new_transition(rail_trans, rail_array, (4,5), (5,5), (5,6), (10,10)) == False) + + # should create #4 -> valid + rail_array[(5,5)] = rail_trans.transitions[3] + assert(validate_new_transition(rail_trans, rail_array, (4,5), (5,5), (5,6), (10,10)) == True) + + # adding invalid turn + rail_array[(5,5)] = rail_trans.transitions[7] + assert(validate_new_transition(rail_trans, rail_array, (4,5), (5,5), (5,6), (10,10)) == False) + + # test path start condition + rail_array[(5,5)] = rail_trans.transitions[0] + assert(validate_new_transition(rail_trans, rail_array, None, (5,5), (5,6), (10,10)) == True) + + # test path end condition + rail_array[(5,5)] = rail_trans.transitions[0] + assert(validate_new_transition(rail_trans, rail_array, (5,4), (5,5), (6,5), (6,5)) == True) def test_valid_railenv_transitions(): -- GitLab