diff --git a/AUTHORS.rst b/AUTHORS.rst index 461cefa51b1a61703bef43ed611550ac4e4a4c91..39a8017090e88e8a6b52119c3cbdc1f10e09ce41 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -11,7 +11,7 @@ Development * A Egli <adrian.egli@sbb.ch> -* Mattias Ljungström <ml@mljx.io> +* Mattias Ljungström Contributors ------------ diff --git a/README.rst b/README.rst index 7bd9627abbe022482047b552b2f43eb3370b7358..60e8276427c5cce82a01047755a9677d878c3da1 100644 --- a/README.rst +++ b/README.rst @@ -1,5 +1,5 @@ ======== -flatland +Flatland ======== @@ -15,45 +15,86 @@ flatland Multi Agent Reinforcement Learning on Trains +Getting Started +=============== + Generate Docs -------------- - | The docs have a lot more details about how to interact with this codebase. - | **TODO**: Mohanty to add atleast a neat outline herefor the contents to the docs here. - - .. code-block:: bash - git clone git@gitlab.aicrowd.com:flatland/flatland.git - cd flatland - pip install -r requirements_dev.txt +The docs have a lot more details about how to interact with this codebase. - * Linux and macOS +**TODO**: Mohanty to add atleast a neat outline herefor the contents to the docs here :: - .. code-block:: bash + git clone git@gitlab.aicrowd.com:flatland/flatland.git + cd flatland + pip install -r requirements_dev.txt - make docs +* On, Linux and macOS :: + make docs - * Windows - .. code-block:: bash +* On, Windows :: - python setup.py develop (or) - python setup.py install - python make_docs.py + python setup.py develop (or) + python setup.py install + python make_docs.py Features -------- -* TODO +TODO + + +Installation +============ + +Stable Release +-------------- + +To install flatland, run this command in your terminal :: + + pip install flatland-rl + +This is the preferred method to install flatland, as it will always install the most recent stable release. + +If you don’t have `pip <https://pip.pypa.io/en/stable/>`_ installed, this `Python installation guide <https://docs.python-guide.org/starting/installation/>`_ can guide you through the process. + + +From Sources +------------ +The sources for flatland can be downloaded from the `Gitlab repo <https://gitlab.aicrowd.com/flatland/flatland>`_. + +You can clone the public repository :: + + $ git clone git@gitlab.aicrowd.com:flatland/flatland.git + +Once you have a copy of the source, you can install it with :: + + $ python setup.py install + + +Usage +===== +To use flatland in a project :: + + import flatland + +flatland +======== +TODO: explain the interface here + Authors -------- * Sharada Mohanty <mohanty@aicrowd.com> * Giacomo Spigler <giacomo.spigler@gmail.com> -* Mattias Ljungström <ml@mljx.io> +* Mattias Ljungström * Jeremy Watson * Erik Nygren <erik.nygren@sbb.ch> * Adrian Egli <adrian.egli@sbb.ch> +* Vaibhav Agrawal <theinfamouswayne@gmail.com> + <please fill yourself in> diff --git a/env-data/tests/test1.npy b/env-data/tests/test1.npy index 77e0288589171b8b03d828423ca456f2ac8395e3..f0cff3c9a1260facf073b88702da3f0557ab32f0 100644 Binary files a/env-data/tests/test1.npy and b/env-data/tests/test1.npy differ diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/play_model.py b/examples/play_model.py index 2c18c3e3fbf5e54320f3382ae158f542a2130080..1f654c15c590801fcfcaa518deb8289af9a95d99 100644 --- a/examples/play_model.py +++ b/examples/play_model.py @@ -1,4 +1,4 @@ -from flatland.envs.rail_env import RailEnv, random_rail_generator +from flatland.envs.rail_env import RailEnv, complex_rail_generator # from flatland.core.env_observation_builder import TreeObsForRailEnv from flatland.utils.rendertools import RenderTool from flatland.baselines.dueling_double_dqn import Agent @@ -6,7 +6,6 @@ from collections import deque import torch import random import numpy as np -import matplotlib.pyplot as plt import time @@ -34,7 +33,7 @@ class Player(object): self.tStart = time.time() # Reset environment - #self.obs = self.env.reset() + # self.obs = self.env.reset() self.env.obs_builder.reset() self.obs = self.env._get_observations() for a in range(self.env.number_of_agents): @@ -86,7 +85,6 @@ def max_lt(seq, val): return None - def main(render=True, delay=0.0): random.seed(1) @@ -94,27 +92,26 @@ def main(render=True, delay=0.0): # Example generate a rail given a manual specification, # a map of tuples (cell_type, rotation) - transition_probability = [0.5, # empty cell - Case 0 - 1.0, # Case 1 - straight - 1.0, # Case 2 - simple switch - 0.3, # Case 3 - diamond drossing - 0.5, # Case 4 - single slip - 0.5, # Case 5 - double slip - 0.2, # Case 6 - symmetrical - 0.0] # Case 7 - dead end + # transition_probability = [0.5, # empty cell - Case 0 + # 1.0, # Case 1 - straight + # 1.0, # Case 2 - simple switch + # 0.3, # Case 3 - diamond crossing + # 0.5, # Case 4 - single slip + # 0.5, # Case 5 - double slip + # 0.2, # Case 6 - symmetrical + # 0.0] # Case 7 - dead end # Example generate a random rail - env = RailEnv(width=15, - height=15, - rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), - number_of_agents=5) + env = RailEnv(width=15, height=15, + rail_generator=complex_rail_generator(nr_start_goal=15, min_dist=5), + number_of_agents=1) if render: env_renderer = RenderTool(env, gl="QT") - plt.figure(figsize=(5,5)) + # plt.figure(figsize=(5,5)) # fRedis = redis.Redis() - handle = env.get_agent_handles() + # handle = env.get_agent_handles() state_size = 105 action_size = 4 @@ -152,7 +149,7 @@ def main(render=True, delay=0.0): obs = env.reset() for a in range(env.number_of_agents): - norm = max(1, max_lt(obs[a],np.inf)) + norm = max(1, max_lt(obs[a], np.inf)) obs[a] = np.clip(np.array(obs[a]) / norm, -1, 1) # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5) @@ -162,9 +159,9 @@ def main(render=True, delay=0.0): # Run episode for step in range(50): - #if trials > 114: - #env_renderer.renderEnv(show=True) - #print(step) + # if trials > 114: + # env_renderer.renderEnv(show=True) + # print(step) # Action for a in range(env.number_of_agents): action = agent.act(np.array(obs[a]), eps=eps) @@ -188,7 +185,6 @@ def main(render=True, delay=0.0): iFrame += 1 - obs = next_obs.copy() if done['__all__']: env_done = 1 @@ -202,23 +198,23 @@ def main(render=True, delay=0.0): dones_list.append((np.mean(done_window))) print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' + - '\tEpsilon: {:.2f} \t Action Probabilities: \t {}').format( - env.number_of_agents, - trials, - np.mean(scores_window), - 100 * np.mean(done_window), - eps, action_prob/np.sum(action_prob)), + '\tEpsilon: {:.2f} \t Action Probabilities: \t {}').format( + env.number_of_agents, + trials, + np.mean(scores_window), + 100 * np.mean(done_window), + eps, action_prob/np.sum(action_prob)), end=" ") if trials % 100 == 0: tNow = time.time() rFps = iFrame / (tNow - tStart) print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' + - '\tEpsilon: {:.2f} fps: {:.2f} \t Action Probabilities: \t {}').format( - env.number_of_agents, - trials, - np.mean(scores_window), - 100 * np.mean(done_window), - eps, rFps, action_prob / np.sum(action_prob))) + '\tEpsilon: {:.2f} fps: {:.2f} \t Action Probabilities: \t {}').format( + env.number_of_agents, + trials, + np.mean(scores_window), + 100 * np.mean(done_window), + eps, rFps, action_prob / np.sum(action_prob))) torch.save(agent.qnetwork_local.state_dict(), '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth') action_prob = [1]*4 diff --git a/examples/training_navigation.py b/examples/training_navigation.py index 6a3885280fb1cdc60816ecb11340c6023d48a1b3..c1182074731725091630df0f3753b2c55d0280b2 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -25,12 +25,19 @@ env = RailEnv(width=10, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), number_of_agents=1) """ +env = RailEnv(width=20, + height=20, + rail_generator=complex_rail_generator(nr_start_goal=20, min_dist=10, max_dist=99999, seed=0), + number_of_agents=5) + +""" env = RailEnv(width=20, height=20, rail_generator=rail_from_list_of_saved_GridTransitionMap_generator( - ['../flatland/baselines/test-editor.npy']), + ['../env-data/tests/circle.npy']), number_of_agents=1) -""" + + env_renderer = RenderTool(env, gl="QT") handle = env.get_agent_handles() @@ -47,7 +54,7 @@ scores = [] dones_list = [] action_prob = [0] * 4 agent = Agent(state_size, action_size, "FC", 0) -agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth')) +agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint14900.pth')) demo = True @@ -102,10 +109,10 @@ for trials in range(1, n_trials + 1): for a in range(env.number_of_agents): if demo: eps = 0 - action = agent.act(np.array(obs[a]), eps=eps) + action = 2 #agent.act(np.array(obs[a]), eps=eps) action_prob[action] += 1 action_dict.update({a: action}) - + #env.obs_builder.util_print_obs_subtree(tree=obs[a], num_features_per_node=5) # Environment step next_obs, all_rewards, done, _ = env.step(action_dict) for a in range(env.number_of_agents): @@ -120,7 +127,7 @@ for trials in range(1, n_trials + 1): if done['__all__']: env_done = 1 break - # Epsioln decay + # Epsilon decay eps = max(eps_end, eps_decay * eps) # decrease epsilon done_window.append(env_done) diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index 8737862b60d9c330cb95e7679b3e39c2ad897da6..1ae2819d28170e4831a104c5149681a5e7d2fafe 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -17,6 +17,7 @@ class ObservationBuilder: """ ObservationBuilder base class. """ + def __init__(self): pass @@ -55,6 +56,7 @@ class TreeObsForRailEnv(ObservationBuilder): The information is local to each agent and exploits the tree structure of the rail network to simplify the representation of the state of the environment for each agent. """ + def __init__(self, max_depth): self.max_depth = max_depth @@ -135,7 +137,7 @@ class TreeObsForRailEnv(ObservationBuilder): new_cell = self._new_position(position, neigh_direction) if new_cell[0] >= 0 and new_cell[0] < self.env.height and \ - new_cell[1] >= 0 and new_cell[1] < self.env.width: + new_cell[1] >= 0 and new_cell[1] < self.env.width: desired_movement_from_new_cell = (neigh_direction + 2) % 4 @@ -176,7 +178,7 @@ class TreeObsForRailEnv(ObservationBuilder): """ Utility function that converts a compass movement over a 2D grid to new positions (r, c). """ - if movement == 0: # NORTH + if movement == 0: # NORTH return (position[0] - 1, position[1]) elif movement == 1: # EAST return (position[0], position[1] + 1) @@ -325,7 +327,7 @@ class TreeObsForRailEnv(ObservationBuilder): if not last_isDeadEnd: # Keep walking through the tree along `direction' exploring = True - + # TODO: Remove below calculation, this is computed already above and could be reused for i in range(4): if cell_transitions[i]: position = self._new_position(position, i) @@ -340,7 +342,8 @@ class TreeObsForRailEnv(ObservationBuilder): elif num_transitions == 0: # Wrong cell type, but let's cover it and treat it as a dead-end, just in case - print("WRONG CELL TYPE detected in tree-search (0 transitions possible)") + print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0], + position[1], direction) last_isTerminal = True break @@ -394,7 +397,7 @@ class TreeObsForRailEnv(ObservationBuilder): observation = observation + branch_observation elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction), - branch_direction): + (branch_direction + 2) % 4): new_cell = self._new_position(position, branch_direction) branch_observation = self._explore_branch(handle, @@ -456,6 +459,7 @@ class GlobalObsForRailEnv(ObservationBuilder): - A 4 elements array with one of encoding of the direction. """ + def __init__(self): super(GlobalObsForRailEnv, self).__init__() diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py index a8cb8d6f49157bafbb65551b53c6612c45565c88..cdc657cf0e10fc41e0a2bbec465bf39979e04819 100644 --- a/flatland/core/transitions.py +++ b/flatland/core/transitions.py @@ -531,9 +531,50 @@ class RailEnvTransitions(Grid4Transitions): int('1001011000100001', 2), # Case 4 - single slip int('1100110000110011', 2), # Case 5 - double slip int('0101001000000010', 2), # Case 6 - symmetrical - int('0010000000000000', 2)] # Case 7 - dead end + int('0010000000000000', 2), # Case 7 - dead end + int('0100000000000010', 2), # Case 1b (8) - simple turn right + int('0001001000000000', 2), # Case 1c (9) - simple turn left + int('1100000000100010', 2)] # Case 2b (10) - simple switch mirrored def __init__(self): super(RailEnvTransitions, self).__init__( transitions=self.transition_list ) + # create this to make validation faster + self.transitions_all = [] + for index, trans in enumerate(self.transitions): + self.transitions_all.append(trans) + if index in (2, 4, 6, 7, 8, 9, 10): + for _ in range(3): + trans = self.rotate_transition(trans, rotation=90) + self.transitions_all.append(trans) + elif index in (1, 5): + trans = self.rotate_transition(trans, rotation=90) + self.transitions_all.append(trans) + + def print(self, cell_transition): + print(" NESW") + print("N", format(cell_transition >> (3*4) & 0xF, '04b')) + print("E", format(cell_transition >> (2*4) & 0xF, '04b')) + print("S", format(cell_transition >> (1*4) & 0xF, '04b')) + print("W", format(cell_transition >> (0*4) & 0xF, '04b')) + + def is_valid(self, cell_transition): + """ + Checks if a cell transition is a valid cell setup. + + Parameters + ---------- + cell_transition : int + 64 bits used to encode the valid transitions for a cell. + + Returns + ------- + Boolean + True or False + """ + for trans in self.transitions_all: + if cell_transition == trans: + return True + return False + diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index ce05ce02ccaec6ca5a8add3adef24fdcead02924..f1a3d20cc9873ca11b604148a8639d908ce11486 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -13,6 +13,369 @@ from flatland.core.transitions import Grid8Transitions, RailEnvTransitions from flatland.core.transition_map import GridTransitionMap +class AStarNode(): + """A node class for A* Pathfinding""" + + def __init__(self, parent=None, pos=None): + self.parent = parent + self.pos = pos + self.g = 0 + self.h = 0 + self.f = 0 + + def __eq__(self, other): + return self.pos == other.pos + + def update_if_better(self, other): + if other.g < self.g: + self.parent = other.parent + self.g = other.g + self.h = other.h + self.f = other.f + + +def get_direction(pos1, pos2): + """ + Assumes pos1 and pos2 are adjacent location on grid. + Returns direction (int) that can be used with transitions. + """ + diff_0 = pos2[0] - pos1[0] + diff_1 = pos2[1] - pos1[1] + if diff_0 < 0: + return 0 + if diff_0 > 0: + return 2 + if diff_1 > 0: + return 1 + if diff_1 < 0: + return 3 + return 0 + + +def mirror(dir): + return (dir + 2) % 4 + + +def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_pos, end_pos): + # start by getting direction used to get to current node + # and direction from current node to possible child node + new_dir = get_direction(current_pos, new_pos) + if prev_pos is not None: + current_dir = get_direction(prev_pos, current_pos) + else: + current_dir = new_dir + # create new transition that would go to child + new_trans = rail_array[current_pos] + if prev_pos is None: + if new_trans == 0: + # need to flip direction because of how end points are defined + new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) + else: + # check if matches existing layout + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + # new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + # rail_trans.print(new_trans) + else: + # set the forward path + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + # set the backwards path + new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + if new_pos == end_pos: + # need to validate end pos setup as well + new_trans_e = rail_array[end_pos] + if new_trans_e == 0: + # need to flip direction because of how end points are defined + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + else: + # check if matches existing layout + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) + # new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1) + # print("end:", end_pos, current_pos) + # rail_trans.print(new_trans_e) + + # print("========> end trans") + # rail_trans.print(new_trans_e) + if not rail_trans.is_valid(new_trans_e): + # print("end failed", end_pos, current_pos) + return False + # else: + # print("end ok!", end_pos, current_pos) + + # is transition is valid? + # print("=======> trans") + # rail_trans.print(new_trans) + return rail_trans.is_valid(new_trans) + + +def a_star(rail_trans, rail_array, start, end): + """ + Returns a list of tuples as a path from the given start to end. + If no path is found, returns path to closest point to end. + """ + rail_shape = rail_array.shape + start_node = AStarNode(None, start) + end_node = AStarNode(None, end) + open_list = [] + closed_list = [] + + open_list.append(start_node) + + # this could be optimized + def is_node_in_list(node, the_list): + for o_node in the_list: + if node == o_node: + return o_node + return None + + while len(open_list) > 0: + # get node with current shortest est. path (lowest f) + current_node = open_list[0] + current_index = 0 + for index, item in enumerate(open_list): + if item.f < current_node.f: + current_node = item + current_index = index + + # pop current off open list, add to closed list + open_list.pop(current_index) + closed_list.append(current_node) + + # print("a*:", current_node.pos) + # for cn in closed_list: + # print("closed:", cn.pos) + + # found the goal + if current_node == end_node: + path = [] + current = current_node + while current is not None: + path.append(current.pos) + current = current.parent + # return reversed path + return path[::-1] + + # generate children + children = [] + if current_node.parent is not None: + prev_pos = current_node.parent.pos + else: + prev_pos = None + for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]: + node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1]) + if node_pos[0] >= rail_shape[0] or \ + node_pos[0] < 0 or \ + node_pos[1] >= rail_shape[1] or \ + node_pos[1] < 0: + continue + + # validate positions + # debug: avoid all current rails + # if rail_array.item(node_pos) != 0: + # continue + + # validate positions + if not validate_new_transition(rail_trans, rail_array, prev_pos, current_node.pos, node_pos, end_node.pos): + # print("A*: transition invalid") + continue + + # create new node + new_node = AStarNode(current_node, node_pos) + children.append(new_node) + + # loop through children + for child in children: + # already in closed list? + closed_node = is_node_in_list(child, closed_list) + if closed_node is not None: + continue + + # create the f, g, and h values + child.g = current_node.g + 1 + # this heuristic favors diagonal paths + # child.h = ((child.pos[0] - end_node.pos[0]) ** 2) + \ + # ((child.pos[1] - end_node.pos[1]) ** 2) + # this heuristic avoids diagonal paths + child.h = abs(child.pos[0] - end_node.pos[0]) + abs(child.pos[1] - end_node.pos[1]) + child.f = child.g + child.h + + # already in the open list? + open_node = is_node_in_list(child, open_list) + if open_node is not None: + open_node.update_if_better(child) + continue + + # add the child to the open list + open_list.append(child) + + # no full path found, return partial path + if len(open_list) == 0: + path = [] + current = current_node + while current is not None: + path.append(current.pos) + current = current.parent + # return reversed path + print("partial:", start, end, path[::-1]) + return path[::-1] + + +def connect_rail(rail_trans, rail_array, start, end): + """ + Creates a new path [start,end] in rail_array, based on rail_trans. + """ + # in the worst case we will need to do a A* search, so we might as well set that up + path = a_star(rail_trans, rail_array, start, end) + # print("connecting path", path) + if len(path) < 2: + return + current_dir = get_direction(path[0], path[1]) + end_pos = path[-1] + for index in range(len(path) - 1): + current_pos = path[index] + new_pos = path[index + 1] + new_dir = get_direction(current_pos, new_pos) + + new_trans = rail_array[current_pos] + if index == 0: + if new_trans == 0: + # end-point + # need to flip direction because of how end points are defined + new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) + else: + # into existing rail + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + # new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + pass + else: + # set the forward path + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + # set the backwards path + new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + rail_array[current_pos] = new_trans + + if new_pos == end_pos: + # setup end pos setup + new_trans_e = rail_array[end_pos] + if new_trans_e == 0: + # end-point + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + else: + # into existing rail + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) + # new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1) + rail_array[end_pos] = new_trans_e + + current_dir = new_dir + + +def distance_on_rail(pos1, pos2): + return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1]) + + +def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): + """ + Parameters + ------- + width : int + The width (number of cells) of the grid to generate. + height : int + The height (number of cells) of the grid to generate. + + Returns + ------- + numpy.ndarray of type numpy.uint16 + The matrix with the correct 16-bit bitmaps for each cell. + """ + + def generator(width, height, num_resets=0): + rail_trans = RailEnvTransitions() + rail_array = np.zeros(shape=(width, height), dtype=np.uint16) + + np.random.seed(seed + num_resets) + + # generate rail array + # step 1: + # - generate a list of start and goal positions + # - use a min/max distance allowed as input for this + # - validate that start/goals are not placed too close to other start/goals + # + # step 2: (optional) + # - place random elements on rails array + # - for instance "train station", etc. + # + # step 3: + # - iterate over all [start, goal] pairs: + # - [first X pairs] + # - draw a rail from [start,goal] + # - draw either vertical or horizontal part first (randomly) + # - if rail crosses existing rail then validate new connection + # - if new connection is invalid turn 90 degrees to left/right + # - possibility that this fails to create a path to goal + # - on failure goto step1 and retry with seed+1 + # - [avoid crossing other start,goal positions] (optional) + # + # - [after X pairs] + # - find closest rail from start (Pa) + # - iterating outwards in a "circle" from start until an existing rail cell is hit + # - connect [start, Pa] + # - validate crossing rails + # - Do A* from Pa to find closest point on rail (Pb) to goal point + # - Basically normal A* but find point on rail which is closest to goal + # - since full path to goal is unlikely + # - connect [Pb, goal] + # - validate crossing rails + # + # step 4: (optional) + # - add more rails to map randomly + # + # step 5: + # - return transition map + list of [start, goal] points + # + + start_goal = [] + for _ in range(nr_start_goal): + sanity_max = 9000 + for _ in range(sanity_max): + start = (np.random.randint(0, width), np.random.randint(0, height)) + goal = (np.random.randint(0, height), np.random.randint(0, height)) + # check to make sure start,goal pos is empty? + if rail_array[goal] != 0 or rail_array[start] != 0: + continue + # check min/max distance + dist_sg = distance_on_rail(start, goal) + if dist_sg < min_dist: + continue + if dist_sg > max_dist: + continue + # check distance to existing points + sg_new = [start, goal] + + def check_all_dist(sg_new): + for sg in start_goal: + for i in range(2): + for j in range(2): + dist = distance_on_rail(sg_new[i], sg[j]) + if dist < 2: + # print("too close:", dist, sg_new[i], sg[j]) + return False + return True + + if check_all_dist(sg_new): + break + start_goal.append([start, goal]) + connect_rail(rail_trans, rail_array, start, goal) + + print("Created #", len(start_goal), "pairs") + # print(start_goal) + + return_rail = GridTransitionMap(width=width, height=height, transitions=rail_trans) + return_rail.grid = rail_array + # TODO: return start_goal + return return_rail + + return generator + + def rail_from_manual_specifications_generator(rail_spec): """ Utility to convert a rail given by manual specification as a map of tuples @@ -32,6 +395,7 @@ def rail_from_manual_specifications_generator(rail_spec): Generator function that always returns a GridTransitionMap object with the matrix of correct 16-bit bitmaps for each cell. """ + def generator(width, height, num_resets=0): t_utils = RailEnvTransitions() @@ -67,6 +431,7 @@ def rail_from_GridTransitionMap_generator(rail_map): function Generator function that always returns the given `rail_map' object. """ + def generator(width, height, num_resets=0): return rail_map @@ -87,6 +452,7 @@ def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames): function Generator function that always returns the given `rail_map' object. """ + def generator(width, height, num_resets=0): t_utils = RailEnvTransitions() rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils) @@ -148,7 +514,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): transitions_templates_ = [] transition_probabilities = [] - for i in range(len(t_utils.transitions) - 1): # don't include dead-ends + for i in range(len(t_utils.transitions) - 4): # don't include dead-ends all_transitions = 0 for dir_ in range(4): trans = t_utils.get_transitions(t_utils.transitions[i], dir_) @@ -163,9 +529,9 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): # add all rotations for rot in [0, 90, 180, 270]: transitions_templates_.append((template, - t_utils.rotate_transition( - t_utils.transitions[i], - rot))) + t_utils.rotate_transition( + t_utils.transitions[i], + rot))) transition_probabilities.append(transition_probability[i]) template = [template[-1]] + template[:-1] @@ -175,7 +541,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): is_match = True for j in range(4): if template[j] >= 0 and \ - template[j] != transitions_templates_[i][0][j]: + template[j] != transitions_templates_[i][0][j]: is_match = False break if is_match: @@ -316,7 +682,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): neigh_trans = rail[r][1] if neigh_trans is not None: for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1) + neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1) max_bit = max_bit | (neigh_trans_from_direction & 1) if max_bit: rail[r][0] = t_utils.rotate_transition(int('0010000000000000', 2), 270) @@ -328,7 +694,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): neigh_trans = rail[r][-2] if neigh_trans is not None: for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1) + neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1) max_bit = max_bit | (neigh_trans_from_direction & (1 << 2)) if max_bit: rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2), @@ -342,7 +708,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): neigh_trans = rail[1][c] if neigh_trans is not None: for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1) + neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1) max_bit = max_bit | (neigh_trans_from_direction & (1 << 3)) if max_bit: rail[0][c] = int('0010000000000000', 2) @@ -354,7 +720,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): neigh_trans = rail[-2][c] if neigh_trans is not None: for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1) + neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1) max_bit = max_bit | (neigh_trans_from_direction & (1 << 1)) if max_bit: rail[-1][c] = t_utils.rotate_transition(int('0010000000000000', 2), 180) @@ -478,9 +844,9 @@ class RailEnv(Environment): def check_agent_lists(self): for lAgents, name in zip( - [self.agents_handles, self.agents_position, self.agents_direction], - ["handles", "positions", "directions"]): - assert self.number_of_agents == len(lAgents), "Inconsistent agent list:"+name + [self.agents_handles, self.agents_position, self.agents_direction], + ["handles", "positions", "directions"]): + assert self.number_of_agents == len(lAgents), "Inconsistent agent list:" + name def check_agent_locdirpath(self, iAgent): valid_movements = [] @@ -495,7 +861,7 @@ class RailEnv(Environment): for m in valid_movements: new_position = self._new_position(self.agents_position[iAgent], m[1]) if m[0] not in valid_starting_directions and \ - self._path_exists(new_position, m[0], self.agents_target[iAgent]): + self._path_exists(new_position, m[0], self.agents_target[iAgent]): valid_starting_directions.append(m[0]) if len(valid_starting_directions) == 0: @@ -514,7 +880,7 @@ class RailEnv(Environment): for m in valid_movements: new_position = self._new_position(rcPos, m[1]) if m[0] not in valid_starting_directions and \ - self._path_exists(new_position, m[0], rcTarget): + self._path_exists(new_position, m[0], rcTarget): valid_starting_directions.append(m[0]) if len(valid_starting_directions) == 0: @@ -529,7 +895,7 @@ class RailEnv(Environment): rcPos = np.random.choice(len(self.valid_positions)) iAgent = self.number_of_agents - + self.agents_position.append(tuple(rcPos)) # ensure it's a tuple not a list self.agents_handles.append(max(self.agents_handles + [-1]) + 1) # max(handles) + 1, starting at 0 @@ -540,9 +906,10 @@ class RailEnv(Environment): self.number_of_agents += 1 self.check_agent_lists() return iAgent - + def reset(self, regen_rail=True, replace_agents=True): if regen_rail or self.rail is None: + # TODO: Import not only rail information but also start and goal positions self.rail = self.rail_generator(self.width, self.height, self.num_resets) self.fill_valid_positions() @@ -554,7 +921,7 @@ class RailEnv(Environment): # Use a TreeObsForRailEnv to compute distance maps to each agent's target, to sample initial # agent's orientations that allow a valid solution. - + # TODO: Possibility ot fill valid positions from list of goals and start self.fill_valid_positions() if replace_agents: @@ -598,7 +965,7 @@ class RailEnv(Environment): for m in valid_movements: new_position = self._new_position(self.agents_position[i], m[1]) if m[0] not in valid_starting_directions and \ - self._path_exists(new_position, m[0], self.agents_target[i]): + self._path_exists(new_position, m[0], self.agents_target[i]): valid_starting_directions.append(m[0]) if len(valid_starting_directions) == 0: @@ -631,6 +998,7 @@ class RailEnv(Environment): for i in range(len(self.agents_handles)): handle = self.agents_handles[i] + transition_isValid = None if handle not in action_dict: continue @@ -648,12 +1016,24 @@ class RailEnv(Environment): pos = self.agents_position[i] direction = self.agents_direction[i] + # compute number of possible transitions in the current + # cell used to check for invalid actions + + nbits = 0 + tmp = self.rail.get_transitions((pos[0], pos[1])) + while tmp > 0: + nbits += (tmp & 1) + tmp = tmp >> 1 movement = direction if action == 1: movement = direction - 1 + if nbits <= 2: + transition_isValid = False + elif action == 3: movement = direction + 1 - + if nbits <= 2: + transition_isValid = False if movement < 0: movement += 4 if movement >= 4: @@ -661,13 +1041,6 @@ class RailEnv(Environment): is_deadend = False if action == 2: - # compute number of possible transitions in the current - # cell - nbits = 0 - tmp = self.rail.get_transitions((pos[0], pos[1])) - while tmp > 0: - nbits += (tmp & 1) - tmp = tmp >> 1 if nbits == 1: # dead-end; assuming the rail network is consistent, # this should match the direction the agent has come @@ -689,13 +1062,31 @@ class RailEnv(Environment): direction = reverse_direction movement = reverse_direction is_deadend = True + if nbits == 2: + # Checking for curves + + valid_transition = self.rail.get_transition( + (pos[0], pos[1], direction), + movement) + reverse_direction = (direction + 2) % 4 + curv_dir = (movement + 1) % 4 + while not valid_transition: + if curv_dir != reverse_direction: + valid_transition = self.rail.get_transition( + (pos[0], pos[1], direction), + curv_dir) + if valid_transition: + movement = curv_dir + curv_dir = (curv_dir + 1) % 4 + + new_position = self._new_position(pos, movement) # Is it a legal move? 1) transition allows the movement in the # cell, 2) the new cell is not empty (case 0), 3) the cell is # free, i.e., no agent is currently in that cell - if new_position[1] >= self.width or\ - new_position[0] >= self.height or\ - new_position[0] < 0 or new_position[1] < 0: + if new_position[1] >= self.width or \ + new_position[0] >= self.height or \ + new_position[0] < 0 or new_position[1] < 0: new_cell_isValid = False elif self.rail.get_transitions((new_position[0], new_position[1])) > 0: @@ -703,9 +1094,11 @@ class RailEnv(Environment): else: new_cell_isValid = False - transition_isValid = self.rail.get_transition( - (pos[0], pos[1], direction), - movement) or is_deadend + # If transition validity hasn't been checked yet. + if transition_isValid == None: + transition_isValid = self.rail.get_transition( + (pos[0], pos[1], direction), + movement) or is_deadend cell_isFree = True for j in range(self.number_of_agents): @@ -724,7 +1117,7 @@ class RailEnv(Environment): # if agent is not in target position, add step penalty if self.agents_position[i][0] == self.agents_target[i][0] and \ - self.agents_position[i][1] == self.agents_target[i][1]: + self.agents_position[i][1] == self.agents_target[i][1]: self.dones[handle] = True else: self.rewards_dict[handle] += step_penalty @@ -733,7 +1126,7 @@ class RailEnv(Environment): num_agents_in_target_position = 0 for i in range(self.number_of_agents): if self.agents_position[i][0] == self.agents_target[i][0] and \ - self.agents_position[i][1] == self.agents_target[i][1]: + self.agents_position[i][1] == self.agents_target[i][1]: num_agents_in_target_position += 1 if num_agents_in_target_position == self.number_of_agents: @@ -746,7 +1139,7 @@ class RailEnv(Environment): return self._get_observations(), self.rewards_dict, self.dones, {} def _new_position(self, position, movement): - if movement == 0: # NORTH + if movement == 0: # NORTH return (position[0] - 1, position[1]) elif movement == 1: # EAST return (position[0], position[1] + 1) diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index 9c2afe7af954df837a7ee4deff1bb014e4e6b508..ebaf905bed2c92993da01f7d7b02c353b1ad593f 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -33,6 +33,7 @@ class View(object): class JupEditor(object): def __init__(self, env, wid_img): + print("Correct editor") self.env = env self.wid_img = wid_img diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py index 34e198566fdd8df8cff7e1822934e1f04cf3945c..05c436dc9f8723abdf2c013b9fad5181e695b58f 100644 --- a/flatland/utils/render_qt.py +++ b/flatland/utils/render_qt.py @@ -7,7 +7,7 @@ import numpy as np class QTGL(GraphicsLayer): def __init__(self, width, height): - self.cell_pixels = 50 + self.cell_pixels = 60 self.tile_size = self.cell_pixels self.width = width diff --git a/notebooks/CanvasEditor.ipynb b/notebooks/CanvasEditor.ipynb index ea9bcf018e5edbe1a507ad3546751f44f1fdd049..9773d174ff35e18e843cd4a1609c7d7b6529b569 100644 --- a/notebooks/CanvasEditor.ipynb +++ b/notebooks/CanvasEditor.ipynb @@ -120,7 +120,7 @@ "metadata": {}, "outputs": [], "source": [ - "sfEnv = \"../flatland/env-data/tests/test1.npy\"\n", + "sfEnv = \"C:/Users/u224870/Projekte_Git/flatland/env-data/tests/test1.npy\"\n", "if True:\n", " oEnv.rail.load_transition_map(sfEnv)\n", " oEnv.width = oEnv.rail.width\n", @@ -258,10 +258,18 @@ "cell_type": "code", "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Correct editor\n" + ] + } + ], "source": [ "wid_img.unregister_all()\n", - "oEditor = JupEditor(oEnv, wid_img)\n", + "oEditor = JupEditor(oEnv,wid_img)\n", "wid_img.register_move(oEditor.event_handler)\n", "wid_img.register_click(oEditor.on_click)\n" ] @@ -344,7 +352,7 @@ "version_minor": 0 }, "text/plain": [ - "HBox(children=(Canvas(), VBox(children=(Text(value='../flatland/env-data/tests/test1.npy', description='Filena…" + "HBox(children=(Canvas(), VBox(children=(Text(value='C:/Users/u224870/Projekte_Git/flatland/env-data/tests/test…" ] }, "metadata": {}, @@ -436,7 +444,21 @@ "cell_type": "code", "execution_count": 20, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "FileNotFoundError", + "evalue": "[Errno 2] No such file or directory: '../flatland/env-data/tests/test-editor.npy'", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m<ipython-input-20-a40691809d2c>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0moEnv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrail\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msave_transition_map\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"../flatland/env-data/tests/test-editor.npy\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[1;32mc:\\users\\u224870\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\flatland_rl-0.1.1-py3.6.egg\\flatland\\core\\transition_map.py\u001b[0m in \u001b[0;36msave_transition_map\u001b[1;34m(self, filename)\u001b[0m\n\u001b[0;32m 259\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 260\u001b[0m \"\"\"\n\u001b[1;32m--> 261\u001b[1;33m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msave\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgrid\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 262\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 263\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mload_transition_map\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moverride_gridsize\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\users\\u224870\\appdata\\local\\programs\\python\\python36\\lib\\site-packages\\numpy\\lib\\npyio.py\u001b[0m in \u001b[0;36msave\u001b[1;34m(file, arr, allow_pickle, fix_imports)\u001b[0m\n\u001b[0;32m 490\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mfile\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mendswith\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'.npy'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 491\u001b[0m \u001b[0mfile\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mfile\u001b[0m \u001b[1;33m+\u001b[0m \u001b[1;34m'.npy'\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 492\u001b[1;33m \u001b[0mfid\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mopen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfile\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m\"wb\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 493\u001b[0m \u001b[0mown_fid\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 494\u001b[0m \u001b[1;32melif\u001b[0m \u001b[0mis_pathlib_path\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfile\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '../flatland/env-data/tests/test-editor.npy'" + ] + } + ], "source": [ "if False: \n", " oEnv.rail.save_transition_map(\"../flatland/env-data/tests/test-editor.npy\")" @@ -451,7 +473,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -467,7 +489,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -477,7 +499,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -486,7 +508,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -510,7 +532,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -519,7 +541,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -528,7 +550,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -537,7 +559,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -554,7 +576,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -587,7 +609,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -625,7 +647,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -657,9 +679,9 @@ "metadata": { "hide_input": false, "kernelspec": { - "display_name": "ve367", + "display_name": "Python 3", "language": "python", - "name": "ve367" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -671,7 +693,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.7" + "version": "3.6.5" }, "latex_envs": { "LaTeX_envs_menu_present": true, diff --git a/tests/test_transitions.py b/tests/test_transitions.py index 0f56e886071fd1d217be03b9a7e875c20d1a0e8a..2ebfc462cd62bee167b5c9f742d159e8567ed8b4 100644 --- a/tests/test_transitions.py +++ b/tests/test_transitions.py @@ -3,6 +3,56 @@ """Tests for `flatland` package.""" from flatland.core.transitions import RailEnvTransitions, Grid8Transitions +from flatland.envs.rail_env import validate_new_transition +import numpy as np + + +def test_is_valid_railenv_transitions(): + rail_env_trans = RailEnvTransitions() + transition_list = rail_env_trans.transitions + + for t in transition_list: + assert(rail_env_trans.is_valid(t) is True) + for i in range(3): + rot_trans = rail_env_trans.rotate_transition(t, 90 * i) + assert(rail_env_trans.is_valid(rot_trans) is True) + + assert(rail_env_trans.is_valid(int('1111111111110010', 2)) is False) + assert(rail_env_trans.is_valid(int('1001111111110010', 2)) is False) + assert(rail_env_trans.is_valid(int('1001111001110110', 2)) is False) + + +def test_adding_new_valid_transition(): + rail_trans = RailEnvTransitions() + rail_array = np.zeros(shape=(15, 15), dtype=np.uint16) + + # adding straight + assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (6, 5), (10, 10)) is True) + + # adding valid right turn + assert(validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (5, 6), (10, 10)) is True) + # adding valid left turn + assert(validate_new_transition(rail_trans, rail_array, (5, 6), (5, 5), (5, 6), (10, 10)) is True) + + # adding invalid turn + rail_array[(5, 5)] = rail_trans.transitions[2] + assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False) + + # should create #4 -> valid + rail_array[(5, 5)] = rail_trans.transitions[3] + assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is True) + + # adding invalid turn + rail_array[(5, 5)] = rail_trans.transitions[7] + assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False) + + # test path start condition + rail_array[(5, 5)] = rail_trans.transitions[0] + assert(validate_new_transition(rail_trans, rail_array, None, (5, 5), (5, 6), (10, 10)) is True) + + # test path end condition + rail_array[(5, 5)] = rail_trans.transitions[0] + assert(validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (6, 5), (6, 5)) is True) def test_valid_railenv_transitions():