Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Flatland
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
yoogottamk
Flatland
Commits
69063b3b
Commit
69063b3b
authored
5 years ago
by
maljx
Browse files
Options
Downloads
Patches
Plain Diff
new level gen, work in progress
parent
5ab5adfd
No related branches found
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
examples/play_model.py
+12
-13
12 additions, 13 deletions
examples/play_model.py
flatland/core/transitions.py
+24
-0
24 additions, 0 deletions
flatland/core/transitions.py
flatland/envs/rail_env.py
+248
-0
248 additions, 0 deletions
flatland/envs/rail_env.py
with
284 additions
and
13 deletions
examples/play_model.py
+
12
−
13
View file @
69063b3b
from
flatland.envs.rail_env
import
RailEnv
,
random_rail_generator
from
flatland.envs.rail_env
import
RailEnv
,
random_rail_generator
,
complex_rail_generator
# from flatland.core.env_observation_builder import TreeObsForRailEnv
# from flatland.core.env_observation_builder import TreeObsForRailEnv
from
flatland.utils.rendertools
import
RenderTool
from
flatland.utils.rendertools
import
RenderTool
from
flatland.baselines.dueling_double_dqn
import
Agent
from
flatland.baselines.dueling_double_dqn
import
Agent
...
@@ -17,20 +17,19 @@ def main(render=True, delay=0.0):
...
@@ -17,20 +17,19 @@ def main(render=True, delay=0.0):
# Example generate a rail given a manual specification,
# Example generate a rail given a manual specification,
# a map of tuples (cell_type, rotation)
# a map of tuples (cell_type, rotation)
transition_probability
=
[
0.5
,
# empty cell - Case 0
#
transition_probability = [0.5, # empty cell - Case 0
1.0
,
# Case 1 - straight
#
1.0, # Case 1 - straight
1.0
,
# Case 2 - simple switch
#
1.0, # Case 2 - simple switch
0.3
,
# Case 3 - diamond drossing
#
0.3, # Case 3 - diamond drossing
0.5
,
# Case 4 - single slip
#
0.5, # Case 4 - single slip
0.5
,
# Case 5 - double slip
#
0.5, # Case 5 - double slip
0.2
,
# Case 6 - symmetrical
#
0.2, # Case 6 - symmetrical
0.0
]
# Case 7 - dead end
#
0.0] # Case 7 - dead end
# Example generate a random rail
# Example generate a random rail
env
=
RailEnv
(
width
=
15
,
env
=
RailEnv
(
width
=
15
,
height
=
15
,
height
=
15
,
rail_generator
=
complex_rail_generator
(),
rail_generator
=
random_rail_generator
(
cell_type_relative_proportion
=
transition_probability
),
number_of_agents
=
1
)
number_of_agents
=
5
)
if
render
:
if
render
:
env_renderer
=
RenderTool
(
env
,
gl
=
"
QT
"
)
env_renderer
=
RenderTool
(
env
,
gl
=
"
QT
"
)
...
...
This diff is collapsed.
Click to expand it.
flatland/core/transitions.py
+
24
−
0
View file @
69063b3b
...
@@ -537,3 +537,27 @@ class RailEnvTransitions(Grid4Transitions):
...
@@ -537,3 +537,27 @@ class RailEnvTransitions(Grid4Transitions):
super
(
RailEnvTransitions
,
self
).
__init__
(
super
(
RailEnvTransitions
,
self
).
__init__
(
transitions
=
self
.
transition_list
transitions
=
self
.
transition_list
)
)
def
is_valid
(
self
,
cell_transition
):
"""
Checks if a cell transition is a valid cell setup.
Parameters
----------
cell_transition : int
64 bits used to encode the valid transitions for a cell.
Returns
-------
Boolean
True or False
"""
for
trans
in
self
.
transitions
:
if
cell_transition
==
trans
:
return
True
for
_
in
range
(
3
):
trans
=
self
.
rotate_transition
(
trans
,
rotation
=
90
)
if
cell_transition
==
trans
:
return
True
return
False
This diff is collapsed.
Click to expand it.
flatland/envs/rail_env.py
+
248
−
0
View file @
69063b3b
...
@@ -13,6 +13,254 @@ from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
...
@@ -13,6 +13,254 @@ from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
from
flatland.core.transition_map
import
GridTransitionMap
from
flatland.core.transition_map
import
GridTransitionMap
class
AStarNode
():
"""
A node class for A* Pathfinding
"""
def
__init__
(
self
,
parent
=
None
,
pos
=
None
):
self
.
parent
=
parent
self
.
pos
=
pos
self
.
g
=
0
self
.
h
=
0
self
.
f
=
0
def
__eq__
(
self
,
other
):
return
self
.
pos
==
other
.
pos
def
update_if_better
(
self
,
other
):
if
other
.
g
<
self
.
g
:
self
.
parent
=
other
.
parent
self
.
g
=
other
.
g
self
.
h
=
other
.
h
self
.
f
=
other
.
f
def
a_star
(
rail_array
,
start
,
end
):
"""
Returns a list of tuples as a path from the given start to end.
If no path is found, returns path to closest point to end.
"""
rail_shape
=
rail_array
.
shape
start_node
=
AStarNode
(
None
,
start
)
end_node
=
AStarNode
(
None
,
end
)
open_list
=
[]
closed_list
=
[]
open_list
.
append
(
start_node
)
# this could be optimized
def
is_node_in_list
(
node
,
the_list
):
for
o_node
in
the_list
:
if
node
==
o_node
:
return
o_node
return
None
while
len
(
open_list
)
>
0
:
# get node with current shortest est. path (lowest f)
current_node
=
open_list
[
0
]
current_index
=
0
for
index
,
item
in
enumerate
(
open_list
):
if
item
.
f
<
current_node
.
f
:
current_node
=
item
current_index
=
index
# pop current off open list, add to closed list
open_list
.
pop
(
current_index
)
closed_list
.
append
(
current_node
)
# print("a*:", current_node.pos)
# for cn in closed_list:
# print("closed:", cn.pos)
# found the goal
if
current_node
==
end_node
:
path
=
[]
current
=
current_node
while
current
is
not
None
:
path
.
append
(
current
.
pos
)
current
=
current
.
parent
# return reversed path
return
path
[::
-
1
]
# generate children
children
=
[]
for
new_pos
in
[(
0
,
-
1
),
(
0
,
1
),
(
-
1
,
0
),
(
1
,
0
)]:
node_pos
=
(
current_node
.
pos
[
0
]
+
new_pos
[
0
],
current_node
.
pos
[
1
]
+
new_pos
[
1
])
if
node_pos
[
0
]
>=
rail_shape
[
0
]
or
\
node_pos
[
0
]
<
0
or
\
node_pos
[
1
]
>=
rail_shape
[
1
]
or
\
node_pos
[
1
]
<
0
:
continue
# validate positions
# debug: avoid all current rails
# if rail_array.item(node_pos) != 0:
# continue
# create new node
new_node
=
AStarNode
(
current_node
,
node_pos
)
children
.
append
(
new_node
)
# loop through children
for
child
in
children
:
# already in closed list?
closed_node
=
is_node_in_list
(
child
,
closed_list
)
if
closed_node
is
not
None
:
continue
# create the f, g, and h values
child
.
g
=
current_node
.
g
+
1
# this heuristic favors diagonal paths
# child.h = ((child.pos[0] - end_node.pos[0]) ** 2) + \
# ((child.pos[1] - end_node.pos[1]) ** 2)
# this heuristic avoids diagonal paths
child
.
h
=
abs
(
child
.
pos
[
0
]
-
end_node
.
pos
[
0
])
+
abs
(
child
.
pos
[
1
]
-
end_node
.
pos
[
1
])
child
.
f
=
child
.
g
+
child
.
h
# already in the open list?
open_node
=
is_node_in_list
(
child
,
open_list
)
if
open_node
is
not
None
:
open_node
.
update_if_better
(
child
)
continue
# add the child to the open list
open_list
.
append
(
child
)
# no full path found, return partial path
if
len
(
open_list
)
==
0
:
path
=
[]
current
=
current_node
while
current
is
not
None
:
path
.
append
(
current
.
pos
)
current
=
current
.
parent
# return reversed path
return
path
[::
-
1
]
def
complex_rail_generator
(
nr_start_goal
=
10
,
min_dist
=
0
,
max_dist
=
99999
,
seed
=
0
):
"""
Parameters
-------
width : int
The width (number of cells) of the grid to generate.
height : int
The height (number of cells) of the grid to generate.
Returns
-------
numpy.ndarray of type numpy.uint16
The matrix with the correct 16-bit bitmaps for each cell.
"""
def
generator
(
width
,
height
,
num_resets
=
0
):
rail_trans
=
RailEnvTransitions
()
rail_array
=
np
.
zeros
(
shape
=
(
width
,
height
),
dtype
=
np
.
uint16
)
np
.
random
.
seed
(
seed
)
# generate rail array
# step 1:
# - generate a list of start and goal positions
# - use a min/max distance allowed as input for this
# - validate that start/goals are not placed too close to other start/goals
#
# step 2: (optional)
# - place random elements on rails array
# - for instance "train station", etc.
#
# step 3:
# - iterate over all [start, goal] pairs:
# - [first X pairs]
# - draw a rail from [start,goal]
# - draw either vertical or horizontal part first (randomly)
# - if rail crosses existing rail then validate new connection
# - if new connection is invalid turn 90 degrees to left/right
# - possibility that this fails to create a path to goal
# - on failure goto step1 and retry with seed+1
# - [avoid crossing other start,goal positions] (optional)
#
# - [after X pairs]
# - find closest rail from start (Pa)
# - iterating outwards in a "circle" from start until an existing rail cell is hit
# - connect [start, Pa]
# - validate crossing rails
# - Do A* from Pa to find closest point on rail (Pb) to goal point
# - Basically normal A* but find point on rail which is closest to goal
# - since full path to goal is unlikely
# - connect [Pb, goal]
# - validate crossing rails
#
# step 4: (optional)
# - add more rails to map randomly
#
# step 5:
# - return transition map + list of [start, goal] points
#
start_goal
=
[]
for
_
in
range
(
nr_start_goal
):
start
=
(
np
.
random
.
randint
(
0
,
width
),
np
.
random
.
randint
(
0
,
height
))
goal
=
(
np
.
random
.
randint
(
0
,
height
),
np
.
random
.
randint
(
0
,
height
))
# TODO: validate closeness with existing points
# TODO: make sure min/max distance condition is met
start_goal
.
append
([
start
,
goal
])
def
get_direction
(
pos1
,
pos2
):
diff_0
=
pos2
[
0
]
-
pos1
[
0
]
diff_1
=
pos2
[
1
]
-
pos1
[
1
]
if
diff_0
<
0
:
return
0
if
diff_0
>
0
:
return
2
if
diff_1
>
0
:
return
1
if
diff_1
<
0
:
return
3
return
0
def
connect_two_cells
(
pos1
,
pos2
):
# connect two adjacent cells
direction
=
get_direction
(
pos1
,
pos2
)
rail_array
[
pos1
]
=
rail_trans
.
set_transition
(
rail_array
[
pos1
],
direction
,
direction
,
1
)
o_dir
=
(
direction
+
2
)
%
4
rail_array
[
pos2
]
=
rail_trans
.
set_transition
(
rail_array
[
pos2
],
o_dir
,
o_dir
,
1
)
def
connect_rail
(
start
,
end
):
# in the worst case we will need to do a A* search, so we might as well set that up
# TODO: need to check transitions in A* to see if new path is valid
path
=
a_star
(
rail_array
,
start
,
end
)
print
(
"
connecting path
"
,
path
)
if
len
(
path
)
<
2
:
return
if
len
(
path
)
==
2
:
connect_two_cells
(
path
[
0
],
path
[
1
])
return
current_dir
=
get_direction
(
path
[
0
],
path
[
1
])
for
index
in
range
(
len
(
path
)):
pos1
=
path
[
index
]
if
index
+
1
<
len
(
path
):
new_dir
=
get_direction
(
pos1
,
path
[
index
+
1
])
else
:
new_dir
=
current_dir
cell_trans
=
rail_array
[
pos1
]
if
index
!=
len
(
path
)
-
1
:
# set the forward path
cell_trans
=
rail_trans
.
set_transition
(
cell_trans
,
current_dir
,
new_dir
,
1
)
if
index
!=
0
:
# set the backwards path
cell_trans
=
rail_trans
.
set_transition
(
cell_trans
,
(
new_dir
+
2
)
%
4
,
(
current_dir
+
2
)
%
4
,
1
)
rail_array
[
pos1
]
=
cell_trans
current_dir
=
new_dir
for
sg
in
start_goal
:
connect_rail
(
sg
[
0
],
sg
[
1
])
return_rail
=
GridTransitionMap
(
width
=
width
,
height
=
height
,
transitions
=
rail_trans
)
return_rail
.
grid
=
rail_array
return
return_rail
return
generator
def
rail_from_manual_specifications_generator
(
rail_spec
):
def
rail_from_manual_specifications_generator
(
rail_spec
):
"""
"""
Utility to convert a rail given by manual specification as a map of tuples
Utility to convert a rail given by manual specification as a map of tuples
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment