Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Flatland
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
pranjal_dhole
Flatland
Commits
55957edf
Commit
55957edf
authored
5 years ago
by
hagrid67
Browse files
Options
Downloads
Patches
Plain Diff
added env_utils and generators.py
parent
0a151b32
No related branches found
No related tags found
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
flatland/envs/env_utils.py
+274
-0
274 additions, 0 deletions
flatland/envs/env_utils.py
flatland/envs/generators.py
+478
-0
478 additions, 0 deletions
flatland/envs/generators.py
with
752 additions
and
0 deletions
flatland/envs/env_utils.py
0 → 100644
+
274
−
0
View file @
55957edf
"""
Definition of the RailEnv environment and related level-generation functions.
Generator functions are functions that take width, height and num_resets as arguments and return
a GridTransitionMap object.
"""
# import numpy as np
# from flatland.core.env import Environment
# from flatland.core.env_observation_builder import TreeObsForRailEnv
# from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
# from flatland.core.transition_map import GridTransitionMap
class
AStarNode
():
"""
A node class for A* Pathfinding
"""
def
__init__
(
self
,
parent
=
None
,
pos
=
None
):
self
.
parent
=
parent
self
.
pos
=
pos
self
.
g
=
0
self
.
h
=
0
self
.
f
=
0
def
__eq__
(
self
,
other
):
return
self
.
pos
==
other
.
pos
def
update_if_better
(
self
,
other
):
if
other
.
g
<
self
.
g
:
self
.
parent
=
other
.
parent
self
.
g
=
other
.
g
self
.
h
=
other
.
h
self
.
f
=
other
.
f
def
get_direction
(
pos1
,
pos2
):
"""
Assumes pos1 and pos2 are adjacent location on grid.
Returns direction (int) that can be used with transitions.
"""
diff_0
=
pos2
[
0
]
-
pos1
[
0
]
diff_1
=
pos2
[
1
]
-
pos1
[
1
]
if
diff_0
<
0
:
return
0
if
diff_0
>
0
:
return
2
if
diff_1
>
0
:
return
1
if
diff_1
<
0
:
return
3
return
0
def
mirror
(
dir
):
return
(
dir
+
2
)
%
4
def
validate_new_transition
(
rail_trans
,
rail_array
,
prev_pos
,
current_pos
,
new_pos
,
end_pos
):
# start by getting direction used to get to current node
# and direction from current node to possible child node
new_dir
=
get_direction
(
current_pos
,
new_pos
)
if
prev_pos
is
not
None
:
current_dir
=
get_direction
(
prev_pos
,
current_pos
)
else
:
current_dir
=
new_dir
# create new transition that would go to child
new_trans
=
rail_array
[
current_pos
]
if
prev_pos
is
None
:
if
new_trans
==
0
:
# need to flip direction because of how end points are defined
new_trans
=
rail_trans
.
set_transition
(
new_trans
,
mirror
(
current_dir
),
new_dir
,
1
)
else
:
# check if matches existing layout
new_trans
=
rail_trans
.
set_transition
(
new_trans
,
current_dir
,
new_dir
,
1
)
# new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
# rail_trans.print(new_trans)
else
:
# set the forward path
new_trans
=
rail_trans
.
set_transition
(
new_trans
,
current_dir
,
new_dir
,
1
)
# set the backwards path
new_trans
=
rail_trans
.
set_transition
(
new_trans
,
mirror
(
new_dir
),
mirror
(
current_dir
),
1
)
if
new_pos
==
end_pos
:
# need to validate end pos setup as well
new_trans_e
=
rail_array
[
end_pos
]
if
new_trans_e
==
0
:
# need to flip direction because of how end points are defined
new_trans_e
=
rail_trans
.
set_transition
(
new_trans_e
,
new_dir
,
mirror
(
new_dir
),
1
)
else
:
# check if matches existing layout
new_trans_e
=
rail_trans
.
set_transition
(
new_trans_e
,
new_dir
,
new_dir
,
1
)
# new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1)
# print("end:", end_pos, current_pos)
# rail_trans.print(new_trans_e)
# print("========> end trans")
# rail_trans.print(new_trans_e)
if
not
rail_trans
.
is_valid
(
new_trans_e
):
# print("end failed", end_pos, current_pos)
return
False
# else:
# print("end ok!", end_pos, current_pos)
# is transition is valid?
# print("=======> trans")
# rail_trans.print(new_trans)
return
rail_trans
.
is_valid
(
new_trans
)
def
a_star
(
rail_trans
,
rail_array
,
start
,
end
):
"""
Returns a list of tuples as a path from the given start to end.
If no path is found, returns path to closest point to end.
"""
rail_shape
=
rail_array
.
shape
start_node
=
AStarNode
(
None
,
start
)
end_node
=
AStarNode
(
None
,
end
)
open_list
=
[]
closed_list
=
[]
open_list
.
append
(
start_node
)
# this could be optimized
def
is_node_in_list
(
node
,
the_list
):
for
o_node
in
the_list
:
if
node
==
o_node
:
return
o_node
return
None
while
len
(
open_list
)
>
0
:
# get node with current shortest est. path (lowest f)
current_node
=
open_list
[
0
]
current_index
=
0
for
index
,
item
in
enumerate
(
open_list
):
if
item
.
f
<
current_node
.
f
:
current_node
=
item
current_index
=
index
# pop current off open list, add to closed list
open_list
.
pop
(
current_index
)
closed_list
.
append
(
current_node
)
# print("a*:", current_node.pos)
# for cn in closed_list:
# print("closed:", cn.pos)
# found the goal
if
current_node
==
end_node
:
path
=
[]
current
=
current_node
while
current
is
not
None
:
path
.
append
(
current
.
pos
)
current
=
current
.
parent
# return reversed path
return
path
[::
-
1
]
# generate children
children
=
[]
if
current_node
.
parent
is
not
None
:
prev_pos
=
current_node
.
parent
.
pos
else
:
prev_pos
=
None
for
new_pos
in
[(
0
,
-
1
),
(
0
,
1
),
(
-
1
,
0
),
(
1
,
0
)]:
node_pos
=
(
current_node
.
pos
[
0
]
+
new_pos
[
0
],
current_node
.
pos
[
1
]
+
new_pos
[
1
])
if
node_pos
[
0
]
>=
rail_shape
[
0
]
or
\
node_pos
[
0
]
<
0
or
\
node_pos
[
1
]
>=
rail_shape
[
1
]
or
\
node_pos
[
1
]
<
0
:
continue
# validate positions
# debug: avoid all current rails
# if rail_array.item(node_pos) != 0:
# continue
# validate positions
if
not
validate_new_transition
(
rail_trans
,
rail_array
,
prev_pos
,
current_node
.
pos
,
node_pos
,
end_node
.
pos
):
# print("A*: transition invalid")
continue
# create new node
new_node
=
AStarNode
(
current_node
,
node_pos
)
children
.
append
(
new_node
)
# loop through children
for
child
in
children
:
# already in closed list?
closed_node
=
is_node_in_list
(
child
,
closed_list
)
if
closed_node
is
not
None
:
continue
# create the f, g, and h values
child
.
g
=
current_node
.
g
+
1
# this heuristic favors diagonal paths
# child.h = ((child.pos[0] - end_node.pos[0]) ** 2) + \
# ((child.pos[1] - end_node.pos[1]) ** 2)
# this heuristic avoids diagonal paths
child
.
h
=
abs
(
child
.
pos
[
0
]
-
end_node
.
pos
[
0
])
+
abs
(
child
.
pos
[
1
]
-
end_node
.
pos
[
1
])
child
.
f
=
child
.
g
+
child
.
h
# already in the open list?
open_node
=
is_node_in_list
(
child
,
open_list
)
if
open_node
is
not
None
:
open_node
.
update_if_better
(
child
)
continue
# add the child to the open list
open_list
.
append
(
child
)
# no full path found, return partial path
if
len
(
open_list
)
==
0
:
path
=
[]
current
=
current_node
while
current
is
not
None
:
path
.
append
(
current
.
pos
)
current
=
current
.
parent
# return reversed path
print
(
"
partial:
"
,
start
,
end
,
path
[::
-
1
])
return
path
[::
-
1
]
def
connect_rail
(
rail_trans
,
rail_array
,
start
,
end
):
"""
Creates a new path [start,end] in rail_array, based on rail_trans.
"""
# in the worst case we will need to do a A* search, so we might as well set that up
path
=
a_star
(
rail_trans
,
rail_array
,
start
,
end
)
# print("connecting path", path)
if
len
(
path
)
<
2
:
return
current_dir
=
get_direction
(
path
[
0
],
path
[
1
])
end_pos
=
path
[
-
1
]
for
index
in
range
(
len
(
path
)
-
1
):
current_pos
=
path
[
index
]
new_pos
=
path
[
index
+
1
]
new_dir
=
get_direction
(
current_pos
,
new_pos
)
new_trans
=
rail_array
[
current_pos
]
if
index
==
0
:
if
new_trans
==
0
:
# end-point
# need to flip direction because of how end points are defined
new_trans
=
rail_trans
.
set_transition
(
new_trans
,
mirror
(
current_dir
),
new_dir
,
1
)
else
:
# into existing rail
new_trans
=
rail_trans
.
set_transition
(
new_trans
,
current_dir
,
new_dir
,
1
)
# new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
pass
else
:
# set the forward path
new_trans
=
rail_trans
.
set_transition
(
new_trans
,
current_dir
,
new_dir
,
1
)
# set the backwards path
new_trans
=
rail_trans
.
set_transition
(
new_trans
,
mirror
(
new_dir
),
mirror
(
current_dir
),
1
)
rail_array
[
current_pos
]
=
new_trans
if
new_pos
==
end_pos
:
# setup end pos setup
new_trans_e
=
rail_array
[
end_pos
]
if
new_trans_e
==
0
:
# end-point
new_trans_e
=
rail_trans
.
set_transition
(
new_trans_e
,
new_dir
,
mirror
(
new_dir
),
1
)
else
:
# into existing rail
new_trans_e
=
rail_trans
.
set_transition
(
new_trans_e
,
new_dir
,
new_dir
,
1
)
# new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1)
rail_array
[
end_pos
]
=
new_trans_e
current_dir
=
new_dir
def
distance_on_rail
(
pos1
,
pos2
):
return
abs
(
pos1
[
0
]
-
pos2
[
0
])
+
abs
(
pos1
[
1
]
-
pos2
[
1
])
This diff is collapsed.
Click to expand it.
flatland/envs/generators.py
0 → 100644
+
478
−
0
View file @
55957edf
import
numpy
as
np
# from flatland.core.env import Environment
# from flatland.core.env_observation_builder import TreeObsForRailEnv
from
flatland.core.transitions
import
Grid8Transitions
,
RailEnvTransitions
from
flatland.core.transition_map
import
GridTransitionMap
from
flatland.envs.env_utils
import
distance_on_rail
,
connect_rail
def
complex_rail_generator
(
nr_start_goal
=
1
,
min_dist
=
2
,
max_dist
=
99999
,
seed
=
0
):
"""
Parameters
-------
width : int
The width (number of cells) of the grid to generate.
height : int
The height (number of cells) of the grid to generate.
Returns
-------
numpy.ndarray of type numpy.uint16
The matrix with the correct 16-bit bitmaps for each cell.
"""
def
generator
(
width
,
height
,
num_resets
=
0
):
rail_trans
=
RailEnvTransitions
()
rail_array
=
np
.
zeros
(
shape
=
(
width
,
height
),
dtype
=
np
.
uint16
)
np
.
random
.
seed
(
seed
+
num_resets
)
# generate rail array
# step 1:
# - generate a list of start and goal positions
# - use a min/max distance allowed as input for this
# - validate that start/goals are not placed too close to other start/goals
#
# step 2: (optional)
# - place random elements on rails array
# - for instance "train station", etc.
#
# step 3:
# - iterate over all [start, goal] pairs:
# - [first X pairs]
# - draw a rail from [start,goal]
# - draw either vertical or horizontal part first (randomly)
# - if rail crosses existing rail then validate new connection
# - if new connection is invalid turn 90 degrees to left/right
# - possibility that this fails to create a path to goal
# - on failure goto step1 and retry with seed+1
# - [avoid crossing other start,goal positions] (optional)
#
# - [after X pairs]
# - find closest rail from start (Pa)
# - iterating outwards in a "circle" from start until an existing rail cell is hit
# - connect [start, Pa]
# - validate crossing rails
# - Do A* from Pa to find closest point on rail (Pb) to goal point
# - Basically normal A* but find point on rail which is closest to goal
# - since full path to goal is unlikely
# - connect [Pb, goal]
# - validate crossing rails
#
# step 4: (optional)
# - add more rails to map randomly
#
# step 5:
# - return transition map + list of [start, goal] points
#
start_goal
=
[]
for
_
in
range
(
nr_start_goal
):
sanity_max
=
9000
for
_
in
range
(
sanity_max
):
start
=
(
np
.
random
.
randint
(
0
,
width
),
np
.
random
.
randint
(
0
,
height
))
goal
=
(
np
.
random
.
randint
(
0
,
height
),
np
.
random
.
randint
(
0
,
height
))
# check to make sure start,goal pos is empty?
if
rail_array
[
goal
]
!=
0
or
rail_array
[
start
]
!=
0
:
continue
# check min/max distance
dist_sg
=
distance_on_rail
(
start
,
goal
)
if
dist_sg
<
min_dist
:
continue
if
dist_sg
>
max_dist
:
continue
# check distance to existing points
sg_new
=
[
start
,
goal
]
def
check_all_dist
(
sg_new
):
for
sg
in
start_goal
:
for
i
in
range
(
2
):
for
j
in
range
(
2
):
dist
=
distance_on_rail
(
sg_new
[
i
],
sg
[
j
])
if
dist
<
2
:
# print("too close:", dist, sg_new[i], sg[j])
return
False
return
True
if
check_all_dist
(
sg_new
):
break
start_goal
.
append
([
start
,
goal
])
connect_rail
(
rail_trans
,
rail_array
,
start
,
goal
)
print
(
"
Created #
"
,
len
(
start_goal
),
"
pairs
"
)
# print(start_goal)
return_rail
=
GridTransitionMap
(
width
=
width
,
height
=
height
,
transitions
=
rail_trans
)
return_rail
.
grid
=
rail_array
# TODO: return start_goal
return
return_rail
return
generator
def
rail_from_manual_specifications_generator
(
rail_spec
):
"""
Utility to convert a rail given by manual specification as a map of tuples
(cell_type, rotation), to a transition map with the correct 16-bit
transitions specifications.
Parameters
-------
rail_spec : list of list of tuples
List (rows) of lists (columns) of tuples, each specifying a cell for
the RailEnv environment as (cell_type, rotation), with rotation being
clock-wise and in [0, 90, 180, 270].
Returns
-------
function
Generator function that always returns a GridTransitionMap object with
the matrix of correct 16-bit bitmaps for each cell.
"""
def
generator
(
width
,
height
,
num_resets
=
0
):
t_utils
=
RailEnvTransitions
()
height
=
len
(
rail_spec
)
width
=
len
(
rail_spec
[
0
])
rail
=
GridTransitionMap
(
width
=
width
,
height
=
height
,
transitions
=
t_utils
)
for
r
in
range
(
height
):
for
c
in
range
(
width
):
cell
=
rail_spec
[
r
][
c
]
if
cell
[
0
]
<
0
or
cell
[
0
]
>=
len
(
t_utils
.
transitions
):
print
(
"
ERROR - invalid cell type=
"
,
cell
[
0
])
return
[]
rail
.
set_transitions
((
r
,
c
),
t_utils
.
rotate_transition
(
t_utils
.
transitions
[
cell
[
0
]],
cell
[
1
]))
return
rail
return
generator
def
rail_from_GridTransitionMap_generator
(
rail_map
):
"""
Utility to convert a rail given by a GridTransitionMap map with the correct
16-bit transitions specifications.
Parameters
-------
rail_map : GridTransitionMap object
GridTransitionMap object to return when the generator is called.
Returns
-------
function
Generator function that always returns the given `rail_map
'
object.
"""
def
generator
(
width
,
height
,
num_resets
=
0
):
return
rail_map
return
generator
def
rail_from_list_of_saved_GridTransitionMap_generator
(
list_of_filenames
):
"""
Utility to sequentially and cyclically return GridTransitionMap-s from a list of files, on each environment reset.
Parameters
-------
list_of_filenames : list
List of filenames with the saved grids to load.
Returns
-------
function
Generator function that always returns the given `rail_map
'
object.
"""
def
generator
(
width
,
height
,
num_resets
=
0
):
t_utils
=
RailEnvTransitions
()
rail_map
=
GridTransitionMap
(
width
=
width
,
height
=
height
,
transitions
=
t_utils
)
rail_map
.
load_transition_map
(
list_of_filenames
[
num_resets
%
len
(
list_of_filenames
)],
override_gridsize
=
False
)
if
rail_map
.
grid
.
dtype
==
np
.
uint64
:
rail_map
.
transitions
=
Grid8Transitions
()
return
rail_map
return
generator
"""
def generate_rail_from_list_of_manual_specifications(list_of_specifications)
def generator(width, height, num_resets=0):
return generate_rail_from_manual_specifications(list_of_specifications)
return generator
"""
def
random_rail_generator
(
cell_type_relative_proportion
=
[
1.0
]
*
8
):
"""
Dummy random level generator:
- fill in cells at random in [width-2, height-2]
- keep filling cells in among the unfilled ones, such that all transitions
are legit; if no cell can be filled in without violating some
transitions, pick one among those that can satisfy most transitions
(1,2,3 or 4), and delete (+mark to be re-filled) the cells that were
incompatible.
- keep trying for a total number of insertions
(e.g., (W-2)*(H-2)*MAX_REPETITIONS ); if no solution is found, empty the
board and try again from scratch.
- finally pad the border of the map with dead-ends to avoid border issues.
Dead-ends are not allowed inside the grid, only at the border; however, if
no cell type can be inserted in a given cell (because of the neighboring
transitions), deadends are allowed if they solve the problem. This was
found to turn most un-genereatable levels into valid ones.
Parameters
-------
width : int
The width (number of cells) of the grid to generate.
height : int
The height (number of cells) of the grid to generate.
Returns
-------
numpy.ndarray of type numpy.uint16
The matrix with the correct 16-bit bitmaps for each cell.
"""
def
generator
(
width
,
height
,
num_resets
=
0
):
t_utils
=
RailEnvTransitions
()
transition_probability
=
cell_type_relative_proportion
transitions_templates_
=
[]
transition_probabilities
=
[]
for
i
in
range
(
len
(
t_utils
.
transitions
)
-
4
):
# don't include dead-ends
all_transitions
=
0
for
dir_
in
range
(
4
):
trans
=
t_utils
.
get_transitions
(
t_utils
.
transitions
[
i
],
dir_
)
all_transitions
|=
(
trans
[
0
]
<<
3
)
|
\
(
trans
[
1
]
<<
2
)
|
\
(
trans
[
2
]
<<
1
)
|
\
(
trans
[
3
])
template
=
[
int
(
x
)
for
x
in
bin
(
all_transitions
)[
2
:]]
template
=
[
0
]
*
(
4
-
len
(
template
))
+
template
# add all rotations
for
rot
in
[
0
,
90
,
180
,
270
]:
transitions_templates_
.
append
((
template
,
t_utils
.
rotate_transition
(
t_utils
.
transitions
[
i
],
rot
)))
transition_probabilities
.
append
(
transition_probability
[
i
])
template
=
[
template
[
-
1
]]
+
template
[:
-
1
]
def
get_matching_templates
(
template
):
ret
=
[]
for
i
in
range
(
len
(
transitions_templates_
)):
is_match
=
True
for
j
in
range
(
4
):
if
template
[
j
]
>=
0
and
template
[
j
]
!=
transitions_templates_
[
i
][
0
][
j
]:
is_match
=
False
break
if
is_match
:
ret
.
append
((
transitions_templates_
[
i
][
1
],
transition_probabilities
[
i
]))
return
ret
MAX_INSERTIONS
=
(
width
-
2
)
*
(
height
-
2
)
*
10
MAX_ATTEMPTS_FROM_SCRATCH
=
10
attempt_number
=
0
while
attempt_number
<
MAX_ATTEMPTS_FROM_SCRATCH
:
cells_to_fill
=
[]
rail
=
[]
for
r
in
range
(
height
):
rail
.
append
([
None
]
*
width
)
if
r
>
0
and
r
<
height
-
1
:
cells_to_fill
=
cells_to_fill
+
[(
r
,
c
)
for
c
in
range
(
1
,
width
-
1
)]
num_insertions
=
0
while
num_insertions
<
MAX_INSERTIONS
and
len
(
cells_to_fill
)
>
0
:
# cell = random.sample(cells_to_fill, 1)[0]
cell
=
cells_to_fill
[
np
.
random
.
choice
(
len
(
cells_to_fill
),
1
)[
0
]]
cells_to_fill
.
remove
(
cell
)
row
=
cell
[
0
]
col
=
cell
[
1
]
# look at its neighbors and see what are the possible transitions
# that can be chosen from, if any.
valid_template
=
[
-
1
,
-
1
,
-
1
,
-
1
]
for
el
in
[(
0
,
2
,
(
-
1
,
0
)),
(
1
,
3
,
(
0
,
1
)),
(
2
,
0
,
(
1
,
0
)),
(
3
,
1
,
(
0
,
-
1
))]:
# N, E, S, W
neigh_trans
=
rail
[
row
+
el
[
2
][
0
]][
col
+
el
[
2
][
1
]]
if
neigh_trans
is
not
None
:
# select transition coming from facing direction el[1] and
# moving to direction el[1]
max_bit
=
0
for
k
in
range
(
4
):
max_bit
|=
t_utils
.
get_transition
(
neigh_trans
,
k
,
el
[
1
])
if
max_bit
:
valid_template
[
el
[
0
]]
=
1
else
:
valid_template
[
el
[
0
]]
=
0
possible_cell_transitions
=
get_matching_templates
(
valid_template
)
if
len
(
possible_cell_transitions
)
==
0
:
# NO VALID TRANSITIONS
# no cell can be filled in without violating some transitions
# can a dead-end solve the problem?
if
valid_template
.
count
(
1
)
==
1
:
for
k
in
range
(
4
):
if
valid_template
[
k
]
==
1
:
rot
=
0
if
k
==
0
:
rot
=
180
elif
k
==
1
:
rot
=
270
elif
k
==
2
:
rot
=
0
elif
k
==
3
:
rot
=
90
rail
[
row
][
col
]
=
t_utils
.
rotate_transition
(
int
(
'
0010000000000000
'
,
2
),
rot
)
num_insertions
+=
1
break
else
:
# can I get valid transitions by removing a single
# neighboring cell?
bestk
=
-
1
besttrans
=
[]
for
k
in
range
(
4
):
tmp_template
=
valid_template
[:]
tmp_template
[
k
]
=
-
1
possible_cell_transitions
=
get_matching_templates
(
tmp_template
)
if
len
(
possible_cell_transitions
)
>
len
(
besttrans
):
besttrans
=
possible_cell_transitions
bestk
=
k
if
bestk
>=
0
:
# Replace the corresponding cell with None, append it
# to cells to fill, fill in a transition in the current
# cell.
replace_row
=
row
-
1
replace_col
=
col
if
bestk
==
1
:
replace_row
=
row
replace_col
=
col
+
1
elif
bestk
==
2
:
replace_row
=
row
+
1
replace_col
=
col
elif
bestk
==
3
:
replace_row
=
row
replace_col
=
col
-
1
cells_to_fill
.
append
((
replace_row
,
replace_col
))
rail
[
replace_row
][
replace_col
]
=
None
possible_transitions
,
possible_probabilities
=
zip
(
*
besttrans
)
possible_probabilities
=
[
p
/
sum
(
possible_probabilities
)
for
p
in
possible_probabilities
]
rail
[
row
][
col
]
=
np
.
random
.
choice
(
possible_transitions
,
p
=
possible_probabilities
)
num_insertions
+=
1
else
:
print
(
'
WARNING: still nothing!
'
)
rail
[
row
][
col
]
=
int
(
'
0000000000000000
'
,
2
)
num_insertions
+=
1
pass
else
:
possible_transitions
,
possible_probabilities
=
zip
(
*
possible_cell_transitions
)
possible_probabilities
=
[
p
/
sum
(
possible_probabilities
)
for
p
in
possible_probabilities
]
rail
[
row
][
col
]
=
np
.
random
.
choice
(
possible_transitions
,
p
=
possible_probabilities
)
num_insertions
+=
1
if
num_insertions
==
MAX_INSERTIONS
:
# Failed to generate a valid level; try again for a number of times
attempt_number
+=
1
else
:
break
if
attempt_number
==
MAX_ATTEMPTS_FROM_SCRATCH
:
print
(
'
ERROR: failed to generate level
'
)
# Finally pad the border of the map with dead-ends to avoid border issues;
# at most 1 transition in the neigh cell
for
r
in
range
(
height
):
# Check for transitions coming from [r][1] to WEST
max_bit
=
0
neigh_trans
=
rail
[
r
][
1
]
if
neigh_trans
is
not
None
:
for
k
in
range
(
4
):
neigh_trans_from_direction
=
(
neigh_trans
>>
((
3
-
k
)
*
4
))
&
(
2
**
4
-
1
)
max_bit
=
max_bit
|
(
neigh_trans_from_direction
&
1
)
if
max_bit
:
rail
[
r
][
0
]
=
t_utils
.
rotate_transition
(
int
(
'
0010000000000000
'
,
2
),
270
)
else
:
rail
[
r
][
0
]
=
int
(
'
0000000000000000
'
,
2
)
# Check for transitions coming from [r][-2] to EAST
max_bit
=
0
neigh_trans
=
rail
[
r
][
-
2
]
if
neigh_trans
is
not
None
:
for
k
in
range
(
4
):
neigh_trans_from_direction
=
(
neigh_trans
>>
((
3
-
k
)
*
4
))
&
(
2
**
4
-
1
)
max_bit
=
max_bit
|
(
neigh_trans_from_direction
&
(
1
<<
2
))
if
max_bit
:
rail
[
r
][
-
1
]
=
t_utils
.
rotate_transition
(
int
(
'
0010000000000000
'
,
2
),
90
)
else
:
rail
[
r
][
-
1
]
=
int
(
'
0000000000000000
'
,
2
)
for
c
in
range
(
width
):
# Check for transitions coming from [1][c] to NORTH
max_bit
=
0
neigh_trans
=
rail
[
1
][
c
]
if
neigh_trans
is
not
None
:
for
k
in
range
(
4
):
neigh_trans_from_direction
=
(
neigh_trans
>>
((
3
-
k
)
*
4
))
&
(
2
**
4
-
1
)
max_bit
=
max_bit
|
(
neigh_trans_from_direction
&
(
1
<<
3
))
if
max_bit
:
rail
[
0
][
c
]
=
int
(
'
0010000000000000
'
,
2
)
else
:
rail
[
0
][
c
]
=
int
(
'
0000000000000000
'
,
2
)
# Check for transitions coming from [-2][c] to SOUTH
max_bit
=
0
neigh_trans
=
rail
[
-
2
][
c
]
if
neigh_trans
is
not
None
:
for
k
in
range
(
4
):
neigh_trans_from_direction
=
(
neigh_trans
>>
((
3
-
k
)
*
4
))
&
(
2
**
4
-
1
)
max_bit
=
max_bit
|
(
neigh_trans_from_direction
&
(
1
<<
1
))
if
max_bit
:
rail
[
-
1
][
c
]
=
t_utils
.
rotate_transition
(
int
(
'
0010000000000000
'
,
2
),
180
)
else
:
rail
[
-
1
][
c
]
=
int
(
'
0000000000000000
'
,
2
)
# For display only, wrong levels
for
r
in
range
(
height
):
for
c
in
range
(
width
):
if
rail
[
r
][
c
]
is
None
:
rail
[
r
][
c
]
=
int
(
'
0000000000000000
'
,
2
)
tmp_rail
=
np
.
asarray
(
rail
,
dtype
=
np
.
uint16
)
return_rail
=
GridTransitionMap
(
width
=
width
,
height
=
height
,
transitions
=
t_utils
)
return_rail
.
grid
=
tmp_rail
return
return_rail
return
generator
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment