Skip to content
Snippets Groups Projects
Commit 0a151b32 authored by hagrid67's avatar hagrid67
Browse files

split up rail_env.py into rail_env, generators and env_utils.py

parent cf1cc9b7
No related branches found
No related tags found
No related merge requests found
from flatland.envs.rail_env import RailEnv, complex_rail_generator
from flatland.envs.rail_env import RailEnv
from flatland.envs.generators import complex_rail_generator
# from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.utils.rendertools import RenderTool
from flatland.baselines.dueling_double_dqn import Agent
......
......@@ -8,737 +8,10 @@ import numpy as np
from flatland.core.env import Environment
from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.envs.generators import random_rail_generator
from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
class AStarNode():
"""A node class for A* Pathfinding"""
def __init__(self, parent=None, pos=None):
self.parent = parent
self.pos = pos
self.g = 0
self.h = 0
self.f = 0
def __eq__(self, other):
return self.pos == other.pos
def update_if_better(self, other):
if other.g < self.g:
self.parent = other.parent
self.g = other.g
self.h = other.h
self.f = other.f
def get_direction(pos1, pos2):
"""
Assumes pos1 and pos2 are adjacent location on grid.
Returns direction (int) that can be used with transitions.
"""
diff_0 = pos2[0] - pos1[0]
diff_1 = pos2[1] - pos1[1]
if diff_0 < 0:
return 0
if diff_0 > 0:
return 2
if diff_1 > 0:
return 1
if diff_1 < 0:
return 3
return 0
def mirror(dir):
return (dir + 2) % 4
def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_pos, end_pos):
# start by getting direction used to get to current node
# and direction from current node to possible child node
new_dir = get_direction(current_pos, new_pos)
if prev_pos is not None:
current_dir = get_direction(prev_pos, current_pos)
else:
current_dir = new_dir
# create new transition that would go to child
new_trans = rail_array[current_pos]
if prev_pos is None:
if new_trans == 0:
# need to flip direction because of how end points are defined
new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
else:
# check if matches existing layout
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
# new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
# rail_trans.print(new_trans)
else:
# set the forward path
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
# set the backwards path
new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
if new_pos == end_pos:
# need to validate end pos setup as well
new_trans_e = rail_array[end_pos]
if new_trans_e == 0:
# need to flip direction because of how end points are defined
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
else:
# check if matches existing layout
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
# new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1)
# print("end:", end_pos, current_pos)
# rail_trans.print(new_trans_e)
# print("========> end trans")
# rail_trans.print(new_trans_e)
if not rail_trans.is_valid(new_trans_e):
# print("end failed", end_pos, current_pos)
return False
# else:
# print("end ok!", end_pos, current_pos)
# is transition is valid?
# print("=======> trans")
# rail_trans.print(new_trans)
return rail_trans.is_valid(new_trans)
def a_star(rail_trans, rail_array, start, end):
"""
Returns a list of tuples as a path from the given start to end.
If no path is found, returns path to closest point to end.
"""
rail_shape = rail_array.shape
start_node = AStarNode(None, start)
end_node = AStarNode(None, end)
open_list = []
closed_list = []
open_list.append(start_node)
# this could be optimized
def is_node_in_list(node, the_list):
for o_node in the_list:
if node == o_node:
return o_node
return None
while len(open_list) > 0:
# get node with current shortest est. path (lowest f)
current_node = open_list[0]
current_index = 0
for index, item in enumerate(open_list):
if item.f < current_node.f:
current_node = item
current_index = index
# pop current off open list, add to closed list
open_list.pop(current_index)
closed_list.append(current_node)
# print("a*:", current_node.pos)
# for cn in closed_list:
# print("closed:", cn.pos)
# found the goal
if current_node == end_node:
path = []
current = current_node
while current is not None:
path.append(current.pos)
current = current.parent
# return reversed path
return path[::-1]
# generate children
children = []
if current_node.parent is not None:
prev_pos = current_node.parent.pos
else:
prev_pos = None
for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1])
if node_pos[0] >= rail_shape[0] or \
node_pos[0] < 0 or \
node_pos[1] >= rail_shape[1] or \
node_pos[1] < 0:
continue
# validate positions
# debug: avoid all current rails
# if rail_array.item(node_pos) != 0:
# continue
# validate positions
if not validate_new_transition(rail_trans, rail_array, prev_pos, current_node.pos, node_pos, end_node.pos):
# print("A*: transition invalid")
continue
# create new node
new_node = AStarNode(current_node, node_pos)
children.append(new_node)
# loop through children
for child in children:
# already in closed list?
closed_node = is_node_in_list(child, closed_list)
if closed_node is not None:
continue
# create the f, g, and h values
child.g = current_node.g + 1
# this heuristic favors diagonal paths
# child.h = ((child.pos[0] - end_node.pos[0]) ** 2) + \
# ((child.pos[1] - end_node.pos[1]) ** 2)
# this heuristic avoids diagonal paths
child.h = abs(child.pos[0] - end_node.pos[0]) + abs(child.pos[1] - end_node.pos[1])
child.f = child.g + child.h
# already in the open list?
open_node = is_node_in_list(child, open_list)
if open_node is not None:
open_node.update_if_better(child)
continue
# add the child to the open list
open_list.append(child)
# no full path found, return partial path
if len(open_list) == 0:
path = []
current = current_node
while current is not None:
path.append(current.pos)
current = current.parent
# return reversed path
print("partial:", start, end, path[::-1])
return path[::-1]
def connect_rail(rail_trans, rail_array, start, end):
"""
Creates a new path [start,end] in rail_array, based on rail_trans.
"""
# in the worst case we will need to do a A* search, so we might as well set that up
path = a_star(rail_trans, rail_array, start, end)
# print("connecting path", path)
if len(path) < 2:
return
current_dir = get_direction(path[0], path[1])
end_pos = path[-1]
for index in range(len(path) - 1):
current_pos = path[index]
new_pos = path[index + 1]
new_dir = get_direction(current_pos, new_pos)
new_trans = rail_array[current_pos]
if index == 0:
if new_trans == 0:
# end-point
# need to flip direction because of how end points are defined
new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
else:
# into existing rail
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
# new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
pass
else:
# set the forward path
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
# set the backwards path
new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
rail_array[current_pos] = new_trans
if new_pos == end_pos:
# setup end pos setup
new_trans_e = rail_array[end_pos]
if new_trans_e == 0:
# end-point
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
else:
# into existing rail
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
# new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1)
rail_array[end_pos] = new_trans_e
current_dir = new_dir
def distance_on_rail(pos1, pos2):
return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])
def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
"""
Parameters
-------
width : int
The width (number of cells) of the grid to generate.
height : int
The height (number of cells) of the grid to generate.
Returns
-------
numpy.ndarray of type numpy.uint16
The matrix with the correct 16-bit bitmaps for each cell.
"""
def generator(width, height, num_resets=0):
rail_trans = RailEnvTransitions()
rail_array = np.zeros(shape=(width, height), dtype=np.uint16)
np.random.seed(seed + num_resets)
# generate rail array
# step 1:
# - generate a list of start and goal positions
# - use a min/max distance allowed as input for this
# - validate that start/goals are not placed too close to other start/goals
#
# step 2: (optional)
# - place random elements on rails array
# - for instance "train station", etc.
#
# step 3:
# - iterate over all [start, goal] pairs:
# - [first X pairs]
# - draw a rail from [start,goal]
# - draw either vertical or horizontal part first (randomly)
# - if rail crosses existing rail then validate new connection
# - if new connection is invalid turn 90 degrees to left/right
# - possibility that this fails to create a path to goal
# - on failure goto step1 and retry with seed+1
# - [avoid crossing other start,goal positions] (optional)
#
# - [after X pairs]
# - find closest rail from start (Pa)
# - iterating outwards in a "circle" from start until an existing rail cell is hit
# - connect [start, Pa]
# - validate crossing rails
# - Do A* from Pa to find closest point on rail (Pb) to goal point
# - Basically normal A* but find point on rail which is closest to goal
# - since full path to goal is unlikely
# - connect [Pb, goal]
# - validate crossing rails
#
# step 4: (optional)
# - add more rails to map randomly
#
# step 5:
# - return transition map + list of [start, goal] points
#
start_goal = []
for _ in range(nr_start_goal):
sanity_max = 9000
for _ in range(sanity_max):
start = (np.random.randint(0, width), np.random.randint(0, height))
goal = (np.random.randint(0, height), np.random.randint(0, height))
# check to make sure start,goal pos is empty?
if rail_array[goal] != 0 or rail_array[start] != 0:
continue
# check min/max distance
dist_sg = distance_on_rail(start, goal)
if dist_sg < min_dist:
continue
if dist_sg > max_dist:
continue
# check distance to existing points
sg_new = [start, goal]
def check_all_dist(sg_new):
for sg in start_goal:
for i in range(2):
for j in range(2):
dist = distance_on_rail(sg_new[i], sg[j])
if dist < 2:
# print("too close:", dist, sg_new[i], sg[j])
return False
return True
if check_all_dist(sg_new):
break
start_goal.append([start, goal])
connect_rail(rail_trans, rail_array, start, goal)
print("Created #", len(start_goal), "pairs")
# print(start_goal)
return_rail = GridTransitionMap(width=width, height=height, transitions=rail_trans)
return_rail.grid = rail_array
# TODO: return start_goal
return return_rail
return generator
def rail_from_manual_specifications_generator(rail_spec):
"""
Utility to convert a rail given by manual specification as a map of tuples
(cell_type, rotation), to a transition map with the correct 16-bit
transitions specifications.
Parameters
-------
rail_spec : list of list of tuples
List (rows) of lists (columns) of tuples, each specifying a cell for
the RailEnv environment as (cell_type, rotation), with rotation being
clock-wise and in [0, 90, 180, 270].
Returns
-------
function
Generator function that always returns a GridTransitionMap object with
the matrix of correct 16-bit bitmaps for each cell.
"""
def generator(width, height, num_resets=0):
t_utils = RailEnvTransitions()
height = len(rail_spec)
width = len(rail_spec[0])
rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
for r in range(height):
for c in range(width):
cell = rail_spec[r][c]
if cell[0] < 0 or cell[0] >= len(t_utils.transitions):
print("ERROR - invalid cell type=", cell[0])
return []
rail.set_transitions((r, c), t_utils.rotate_transition(t_utils.transitions[cell[0]], cell[1]))
return rail
return generator
def rail_from_GridTransitionMap_generator(rail_map):
"""
Utility to convert a rail given by a GridTransitionMap map with the correct
16-bit transitions specifications.
Parameters
-------
rail_map : GridTransitionMap object
GridTransitionMap object to return when the generator is called.
Returns
-------
function
Generator function that always returns the given `rail_map' object.
"""
def generator(width, height, num_resets=0):
return rail_map
return generator
def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames):
"""
Utility to sequentially and cyclically return GridTransitionMap-s from a list of files, on each environment reset.
Parameters
-------
list_of_filenames : list
List of filenames with the saved grids to load.
Returns
-------
function
Generator function that always returns the given `rail_map' object.
"""
def generator(width, height, num_resets=0):
t_utils = RailEnvTransitions()
rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils)
rail_map.load_transition_map(list_of_filenames[num_resets % len(list_of_filenames)], override_gridsize=False)
if rail_map.grid.dtype == np.uint64:
rail_map.transitions = Grid8Transitions()
return rail_map
return generator
"""
def generate_rail_from_list_of_manual_specifications(list_of_specifications)
def generator(width, height, num_resets=0):
return generate_rail_from_manual_specifications(list_of_specifications)
return generator
"""
def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
"""
Dummy random level generator:
- fill in cells at random in [width-2, height-2]
- keep filling cells in among the unfilled ones, such that all transitions
are legit; if no cell can be filled in without violating some
transitions, pick one among those that can satisfy most transitions
(1,2,3 or 4), and delete (+mark to be re-filled) the cells that were
incompatible.
- keep trying for a total number of insertions
(e.g., (W-2)*(H-2)*MAX_REPETITIONS ); if no solution is found, empty the
board and try again from scratch.
- finally pad the border of the map with dead-ends to avoid border issues.
Dead-ends are not allowed inside the grid, only at the border; however, if
no cell type can be inserted in a given cell (because of the neighboring
transitions), deadends are allowed if they solve the problem. This was
found to turn most un-genereatable levels into valid ones.
Parameters
-------
width : int
The width (number of cells) of the grid to generate.
height : int
The height (number of cells) of the grid to generate.
Returns
-------
numpy.ndarray of type numpy.uint16
The matrix with the correct 16-bit bitmaps for each cell.
"""
def generator(width, height, num_resets=0):
t_utils = RailEnvTransitions()
transition_probability = cell_type_relative_proportion
transitions_templates_ = []
transition_probabilities = []
for i in range(len(t_utils.transitions) - 4): # don't include dead-ends
all_transitions = 0
for dir_ in range(4):
trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
all_transitions |= (trans[0] << 3) | \
(trans[1] << 2) | \
(trans[2] << 1) | \
(trans[3])
template = [int(x) for x in bin(all_transitions)[2:]]
template = [0] * (4 - len(template)) + template
# add all rotations
for rot in [0, 90, 180, 270]:
transitions_templates_.append((template,
t_utils.rotate_transition(
t_utils.transitions[i],
rot)))
transition_probabilities.append(transition_probability[i])
template = [template[-1]] + template[:-1]
def get_matching_templates(template):
ret = []
for i in range(len(transitions_templates_)):
is_match = True
for j in range(4):
if template[j] >= 0 and template[j] != transitions_templates_[i][0][j]:
is_match = False
break
if is_match:
ret.append((transitions_templates_[i][1], transition_probabilities[i]))
return ret
MAX_INSERTIONS = (width - 2) * (height - 2) * 10
MAX_ATTEMPTS_FROM_SCRATCH = 10
attempt_number = 0
while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH:
cells_to_fill = []
rail = []
for r in range(height):
rail.append([None] * width)
if r > 0 and r < height - 1:
cells_to_fill = cells_to_fill + [(r, c) for c in range(1, width - 1)]
num_insertions = 0
while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0:
# cell = random.sample(cells_to_fill, 1)[0]
cell = cells_to_fill[np.random.choice(len(cells_to_fill), 1)[0]]
cells_to_fill.remove(cell)
row = cell[0]
col = cell[1]
# look at its neighbors and see what are the possible transitions
# that can be chosen from, if any.
valid_template = [-1, -1, -1, -1]
for el in [(0, 2, (-1, 0)),
(1, 3, (0, 1)),
(2, 0, (1, 0)),
(3, 1, (0, -1))]: # N, E, S, W
neigh_trans = rail[row + el[2][0]][col + el[2][1]]
if neigh_trans is not None:
# select transition coming from facing direction el[1] and
# moving to direction el[1]
max_bit = 0
for k in range(4):
max_bit |= t_utils.get_transition(neigh_trans, k, el[1])
if max_bit:
valid_template[el[0]] = 1
else:
valid_template[el[0]] = 0
possible_cell_transitions = get_matching_templates(valid_template)
if len(possible_cell_transitions) == 0: # NO VALID TRANSITIONS
# no cell can be filled in without violating some transitions
# can a dead-end solve the problem?
if valid_template.count(1) == 1:
for k in range(4):
if valid_template[k] == 1:
rot = 0
if k == 0:
rot = 180
elif k == 1:
rot = 270
elif k == 2:
rot = 0
elif k == 3:
rot = 90
rail[row][col] = t_utils.rotate_transition(int('0010000000000000', 2), rot)
num_insertions += 1
break
else:
# can I get valid transitions by removing a single
# neighboring cell?
bestk = -1
besttrans = []
for k in range(4):
tmp_template = valid_template[:]
tmp_template[k] = -1
possible_cell_transitions = get_matching_templates(tmp_template)
if len(possible_cell_transitions) > len(besttrans):
besttrans = possible_cell_transitions
bestk = k
if bestk >= 0:
# Replace the corresponding cell with None, append it
# to cells to fill, fill in a transition in the current
# cell.
replace_row = row - 1
replace_col = col
if bestk == 1:
replace_row = row
replace_col = col + 1
elif bestk == 2:
replace_row = row + 1
replace_col = col
elif bestk == 3:
replace_row = row
replace_col = col - 1
cells_to_fill.append((replace_row, replace_col))
rail[replace_row][replace_col] = None
possible_transitions, possible_probabilities = zip(*besttrans)
possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities]
rail[row][col] = np.random.choice(possible_transitions,
p=possible_probabilities)
num_insertions += 1
else:
print('WARNING: still nothing!')
rail[row][col] = int('0000000000000000', 2)
num_insertions += 1
pass
else:
possible_transitions, possible_probabilities = zip(*possible_cell_transitions)
possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities]
rail[row][col] = np.random.choice(possible_transitions,
p=possible_probabilities)
num_insertions += 1
if num_insertions == MAX_INSERTIONS:
# Failed to generate a valid level; try again for a number of times
attempt_number += 1
else:
break
if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH:
print('ERROR: failed to generate level')
# Finally pad the border of the map with dead-ends to avoid border issues;
# at most 1 transition in the neigh cell
for r in range(height):
# Check for transitions coming from [r][1] to WEST
max_bit = 0
neigh_trans = rail[r][1]
if neigh_trans is not None:
for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & 1)
if max_bit:
rail[r][0] = t_utils.rotate_transition(int('0010000000000000', 2), 270)
else:
rail[r][0] = int('0000000000000000', 2)
# Check for transitions coming from [r][-2] to EAST
max_bit = 0
neigh_trans = rail[r][-2]
if neigh_trans is not None:
for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & (1 << 2))
if max_bit:
rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2),
90)
else:
rail[r][-1] = int('0000000000000000', 2)
for c in range(width):
# Check for transitions coming from [1][c] to NORTH
max_bit = 0
neigh_trans = rail[1][c]
if neigh_trans is not None:
for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & (1 << 3))
if max_bit:
rail[0][c] = int('0010000000000000', 2)
else:
rail[0][c] = int('0000000000000000', 2)
# Check for transitions coming from [-2][c] to SOUTH
max_bit = 0
neigh_trans = rail[-2][c]
if neigh_trans is not None:
for k in range(4):
neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
max_bit = max_bit | (neigh_trans_from_direction & (1 << 1))
if max_bit:
rail[-1][c] = t_utils.rotate_transition(int('0010000000000000', 2), 180)
else:
rail[-1][c] = int('0000000000000000', 2)
# For display only, wrong levels
for r in range(height):
for c in range(width):
if rail[r][c] is None:
rail[r][c] = int('0000000000000000', 2)
tmp_rail = np.asarray(rail, dtype=np.uint16)
return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
return_rail.grid = tmp_rail
return return_rail
return generator
# from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
# from flatland.core.transition_map import GridTransitionMap
class EnvAgentStatic(object):
......
......@@ -167,26 +167,32 @@ class JupEditor(object):
# get the direction index for the 2 transitions
liTrans = []
for rcTrans in rc2Trans:
# gRCTrans - rcTrans gives an array of vector differences between our rcTrans
# and the 4 directions stored in gRCTrans.
# Where the vector difference is zero, we have a match...
# np.all detects where the whole row,col vector is zero.
# argwhere gives the index of the zero vector, ie the direction index
iTrans = np.argwhere(np.all(self.gRCTrans - rcTrans == 0, axis=1))
if len(iTrans) > 0:
iTrans = iTrans[0][0]
liTrans.append(iTrans)
# check that we have two transitions
if len(liTrans) == 2:
# Set the transition
# oEnv.rail.set_transition((*rcLast, iTransLast), iTrans, True) # does nothing
iValCell = env.rail.transitions.set_transition(
env.rail.grid[tuple(rcMiddle)], liTrans[0], liTrans[1], bTransition)
env.rail.set_transition((*rcMiddle, liTrans[0]), liTrans[1], True)
# iValCell = env.rail.transitions.set_transition(
# env.rail.grid[tuple(rcMiddle)], liTrans[0], liTrans[1], bTransition)
# Also set the reverse transition
iValCell = env.rail.transitions.set_transition(
iValCell,
(liTrans[1] + 2) % 4,
(liTrans[0] + 2) % 4,
bTransition)
# iValCell = env.rail.transitions.set_transition(
# iValCell,
# (liTrans[1] + 2) % 4, # use the reversed outbound transition for inbound
# (liTrans[0] + 2) % 4, # use the reversed inbound transition for outbound
# bTransition)
# Write the cell transition value back into the grid
env.rail.grid[tuple(rcMiddle)] = iValCell
# env.rail.grid[tuple(rcMiddle)] = iValCell
rcHistory.pop(0) # remove the last-but-one
......
......@@ -5,7 +5,8 @@ import numpy as np
from flatland.core.env_observation_builder import GlobalObsForRailEnv
from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator
from flatland.envs.rail_env import RailEnv
from flatland.envs.generators import rail_from_GridTransitionMap_generator
"""Tests for `flatland` package."""
......
......@@ -2,11 +2,13 @@
# -*- coding: utf-8 -*-
import numpy as np
from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator
from flatland.envs.rail_env import RailEnv
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.core.transitions import Grid4Transitions
from flatland.core.transition_map import GridTransitionMap
from flatland.core.env_observation_builder import GlobalObsForRailEnv
"""Tests for `flatland` package."""
......
......@@ -3,7 +3,8 @@
"""Tests for `flatland` package."""
from flatland.core.transitions import RailEnvTransitions, Grid8Transitions
from flatland.envs.rail_env import validate_new_transition
# from flatland.envs.rail_env import validate_new_transition
from flatland.envs.env_utils import validate_new_transition
import numpy as np
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment