Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
Flatland
Commits
854c9726
Commit
854c9726
authored
Sep 18, 2019
by
Egli Adrian (IT-SCI-API-PFI)
Browse files
refactoring and preparation (city generator).
parent
0b4c3f90
Changes
7
Hide whitespace changes
Inline
Side-by-side
examples/
S
imple_
Realistic_R
ailway_
G
enerator.py
→
examples/
s
imple_
example_city_r
ailway_
g
enerator.py
View file @
854c9726
import
copy
import
os
import
warnings
from
typing
import
Sequence
,
Optional
import
numpy
as
np
from
flatland.core.grid.grid_utils
import
Vec2dOperations
as
Vec2d
,
IntVector2DArrayType
from
flatland.core.grid.grid_utils
import
Vec2dOperations
as
Vec2d
,
IntVector2DArray
,
IntVector2DDistance
,
\
IntVector2DArrayArray
from
flatland.core.grid.rail_env_grid
import
RailEnvTransitions
from
flatland.core.transition_map
import
GridTransitionMap
from
flatland.envs.grid4_generators_utils
import
connect_from_nodes
,
connect_nodes
,
connect_rail
...
...
@@ -17,19 +19,19 @@ from flatland.utils.rendertools import RenderTool, AgentRenderVariant
FloatArrayType
=
[]
def
realistic_rail_generator
(
num_cities
=
5
,
city_size
=
10
,
allowed_rotation_angles
=
None
,
max_number_of_station_tracks
=
4
,
nbr_of_switches_per_station_track
=
2
,
connect_max_nbr_of_shortes_city
=
4
,
do_random_connect_stations
=
False
,
seed
=
0
,
print_out_info
=
True
)
->
RailGenerator
:
def
realistic_rail_generator
(
num_cities
:
int
=
5
,
city_size
:
int
=
10
,
allowed_rotation_angles
:
Optional
[
Sequence
[
float
]]
=
None
,
max_number_of_station_tracks
:
int
=
4
,
nbr_of_switches_per_station_track
:
int
=
2
,
connect_max_nbr_of_shortes_city
:
int
=
4
,
do_random_connect_stations
:
bool
=
False
,
a_star_distance_function
:
IntVector2DDistance
=
Vec2d
.
get_manhattan_distance
,
seed
:
int
=
0
,
print_out_info
:
bool
=
True
)
->
RailGenerator
:
"""
This is a level generator which generates a realistic rail configurations
:param print_out_info:
:param num_cities: Number of city node
:param city_size: Length of city measure in cells
:param allowed_rotation_angles: Rotate the city (around center)
...
...
@@ -37,8 +39,9 @@ def realistic_rail_generator(num_cities=5,
:param nbr_of_switches_per_station_track: number of switches per track (max)
:param connect_max_nbr_of_shortes_city: max number of connecting track between stations
:param do_random_connect_stations : if false connect the stations along the grid (top,left -> down,right), else rand
:param a_star_distance_function: Heuristic how the distance between two nodes get estimated in the "a-star" path
:param seed: Random Seed
:print_out_info
: print debug info
:
param
print_out_info: print debug info
if True
:return:
-------
numpy.ndarray of type numpy.uint16
...
...
@@ -48,7 +51,7 @@ def realistic_rail_generator(num_cities=5,
def
do_generate_city_locations
(
width
:
int
,
height
:
int
,
intern_city_size
:
int
,
intern_max_number_of_station_tracks
:
int
)
->
(
IntVector2DArray
Type
,
int
):
intern_max_number_of_station_tracks
:
int
)
->
(
IntVector2DArray
,
int
):
X
=
int
(
np
.
floor
(
max
(
1
,
height
-
2
*
intern_max_number_of_station_tracks
-
1
)
/
intern_city_size
))
Y
=
int
(
np
.
floor
(
max
(
1
,
width
-
2
*
intern_max_number_of_station_tracks
-
1
)
/
intern_city_size
))
...
...
@@ -68,7 +71,7 @@ def realistic_rail_generator(num_cities=5,
generate_city_locations
=
[[(
int
(
xs
[
i
]),
int
(
ys
[
i
])),
(
int
(
xs
[
i
]),
int
(
ys
[
i
]))]
for
i
in
range
(
len
(
xs
))]
return
generate_city_locations
,
max_num_cities
def
do_orient_cities
(
generate_city_locations
:
IntVector2DArray
Type
,
intern_city_size
:
int
,
def
do_orient_cities
(
generate_city_locations
:
IntVector2DArray
Array
,
intern_city_size
:
int
,
rotation_angles_set
:
FloatArrayType
):
for
i
in
range
(
len
(
generate_city_locations
)):
# station main orientation (horizontal or vertical
...
...
@@ -83,12 +86,12 @@ def realistic_rail_generator(num_cities=5,
def
create_stations_from_city_locations
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
generate_city_locations
:
IntVector2DArray
Type
,
intern_max_number_of_station_tracks
:
int
)
->
(
IntVector2DArray
Type
,
IntVector2DArray
Type
,
IntVector2DArray
Type
,
IntVector2DArray
Type
,
IntVector2DArray
Type
):
generate_city_locations
:
IntVector2DArray
,
intern_max_number_of_station_tracks
:
int
)
->
(
IntVector2DArray
,
IntVector2DArray
,
IntVector2DArray
,
IntVector2DArray
,
IntVector2DArray
):
nodes_added
=
[]
start_nodes_added
=
[[]
for
_
in
range
(
len
(
generate_city_locations
))]
...
...
@@ -115,7 +118,7 @@ def realistic_rail_generator(num_cities=5,
end_node
=
Vec2d
.
ceil
(
Vec2d
.
add
(
org_end_node
,
Vec2d
.
scale
(
ortho_trans
,
s
)))
connection
=
connect_from_nodes
(
rail_trans
,
grid_map
,
start_node
,
end_node
)
connection
=
connect_from_nodes
(
rail_trans
,
grid_map
,
start_node
,
end_node
,
a_star_distance_function
)
if
len
(
connection
)
>
0
:
nodes_added
.
append
(
start_node
)
nodes_added
.
append
(
end_node
)
...
...
@@ -142,9 +145,9 @@ def realistic_rail_generator(num_cities=5,
def
create_switches_at_stations
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
station_tracks
:
IntVector2DArray
Type
,
nodes_added
:
IntVector2DArray
Type
,
intern_nbr_of_switches_per_station_track
:
int
)
->
IntVector2DArray
Type
:
station_tracks
:
IntVector2DArray
,
nodes_added
:
IntVector2DArray
,
intern_nbr_of_switches_per_station_track
:
int
)
->
IntVector2DArray
:
for
k_loop
in
range
(
intern_nbr_of_switches_per_station_track
):
for
city_loop
in
range
(
len
(
station_tracks
)):
...
...
@@ -170,13 +173,14 @@ def realistic_rail_generator(num_cities=5,
if
x
<
2
:
x
=
len
(
track
)
-
1
end_node
=
track
[
x
]
connection
=
connect_rail
(
rail_trans
,
grid_map
,
start_node
,
end_node
)
connection
=
connect_rail
(
rail_trans
,
grid_map
,
start_node
,
end_node
,
a_star_distance_function
)
if
len
(
connection
)
==
0
:
if
print_out_info
:
print
(
"create_switches_at_stations : connect_rail -> no path found"
)
start_node
=
datas
[
i
][
0
]
end_node
=
datas
[
i
-
1
][
0
]
connect_rail
(
rail_trans
,
grid_map
,
start_node
,
end_node
)
connect_rail
(
rail_trans
,
grid_map
,
start_node
,
end_node
,
a_star_distance_function
)
nodes_added
.
append
(
start_node
)
nodes_added
.
append
(
end_node
)
...
...
@@ -226,10 +230,10 @@ def realistic_rail_generator(num_cities=5,
return
graph
,
np
.
unique
(
graph_ids
).
astype
(
int
)
def
connect_sub_graphs
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
org_s_nodes
:
IntVector2DArray
Type
,
org_e_nodes
:
IntVector2DArray
Type
,
city_edges
:
IntVector2DArray
Type
,
nodes_added
:
IntVector2DArray
Type
):
org_s_nodes
:
IntVector2DArray
,
org_e_nodes
:
IntVector2DArray
,
city_edges
:
IntVector2DArray
,
nodes_added
:
IntVector2DArray
):
_
,
graphids
=
calc_nbr_of_graphs
(
city_edges
)
if
len
(
graphids
)
>
0
:
for
i
in
range
(
len
(
graphids
)
-
1
):
...
...
@@ -247,7 +251,7 @@ def realistic_rail_generator(num_cities=5,
# TODO : will be generated.
grid_map
.
grid
[
start_node
]
=
0
grid_map
.
grid
[
end_node
]
=
0
connection
=
connect_rail
(
rail_trans
,
grid_map
,
start_node
,
end_node
)
connection
=
connect_rail
(
rail_trans
,
grid_map
,
start_node
,
end_node
,
a_star_distance_function
)
if
len
(
connection
)
>
0
:
nodes_added
.
append
(
start_node
)
nodes_added
.
append
(
end_node
)
...
...
@@ -259,9 +263,9 @@ def realistic_rail_generator(num_cities=5,
def
connect_stations
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
org_s_nodes
:
IntVector2DArray
Type
,
org_e_nodes
:
IntVector2DArray
Type
,
nodes_added
:
IntVector2DArray
Type
,
org_s_nodes
:
IntVector2DArray
,
org_e_nodes
:
IntVector2DArray
,
nodes_added
:
IntVector2DArray
,
intern_connect_max_nbr_of_shortes_city
:
int
):
city_edges
=
[]
...
...
@@ -291,7 +295,7 @@ def realistic_rail_generator(num_cities=5,
tmp_trans_en
=
grid_map
.
grid
[
end_node
]
grid_map
.
grid
[
start_node
]
=
0
grid_map
.
grid
[
end_node
]
=
0
connection
=
connect_rail
(
rail_trans
,
grid_map
,
start_node
,
end_node
)
connection
=
connect_rail
(
rail_trans
,
grid_map
,
start_node
,
end_node
,
a_star_distance_function
)
if
len
(
connection
)
>
0
:
s_nodes
[
city_loop
].
remove
(
start_node
)
e_nodes
[
cl
].
remove
(
end_node
)
...
...
@@ -313,9 +317,9 @@ def realistic_rail_generator(num_cities=5,
connect_sub_graphs
(
rail_trans
,
grid_map
,
org_s_nodes
,
org_e_nodes
,
city_edges
,
nodes_added
)
def
connect_random_stations
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
start_nodes_added
:
IntVector2DArray
Type
,
end_nodes_added
:
IntVector2DArray
Type
,
nodes_added
:
IntVector2DArray
Type
,
start_nodes_added
:
IntVector2DArray
,
end_nodes_added
:
IntVector2DArray
,
nodes_added
:
IntVector2DArray
,
intern_connect_max_nbr_of_shortes_city
:
int
):
if
len
(
start_nodes_added
)
<
1
:
return
...
...
@@ -355,7 +359,7 @@ def realistic_rail_generator(num_cities=5,
end_node
=
e_nodes
[
idx_e_nodes
[
i
]]
grid_map
.
grid
[
start_node
]
=
0
grid_map
.
grid
[
end_node
]
=
0
connection
=
connect_nodes
(
rail_trans
,
grid_map
,
start_node
,
end_node
)
connection
=
connect_nodes
(
rail_trans
,
grid_map
,
start_node
,
end_node
,
a_star_distance_function
)
if
len
(
connection
)
>
0
:
nodes_added
.
append
(
start_node
)
nodes_added
.
append
(
end_node
)
...
...
@@ -364,7 +368,7 @@ def realistic_rail_generator(num_cities=5,
print
(
"connect_random_stations : connect_nodes -> no path found"
)
def
remove_switch_stations
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
train_stations
:
IntVector2DArray
Type
):
train_stations
:
IntVector2DArray
):
tmp_train_stations
=
copy
.
deepcopy
(
train_stations
)
for
city_loop
in
range
(
len
(
train_stations
)):
for
n
in
tmp_train_stations
[
city_loop
]:
...
...
@@ -481,7 +485,7 @@ def realistic_rail_generator(num_cities=5,
if
(
tries
+
1
)
%
10
==
0
:
start_node
=
np
.
random
.
choice
(
avail_start_nodes
)
if
tries
>
100
:
warnings
.
warn
(
"Could not set trainstations, removing agent!"
)
warnings
.
warn
(
"Could not set train
_
stations, removing agent!"
)
found_agent_pair
=
False
break
if
found_agent_pair
:
...
...
@@ -508,13 +512,13 @@ if os.path.exists("./../render_output/"):
height
=
40
+
np
.
random
.
choice
(
100
),
rail_generator
=
realistic_rail_generator
(
num_cities
=
5
+
np
.
random
.
choice
(
10
),
city_size
=
10
+
np
.
random
.
choice
(
5
),
allowed_rotation_angles
=
np
.
arange
(
0
,
360
,
90
),
max_number_of_station_tracks
=
1
+
np
.
random
.
choice
(
4
),
allowed_rotation_angles
=
np
.
arange
(
0
,
360
,
6
),
max_number_of_station_tracks
=
4
+
np
.
random
.
choice
(
4
),
nbr_of_switches_per_station_track
=
2
+
np
.
random
.
choice
(
2
),
connect_max_nbr_of_shortes_city
=
2
+
np
.
random
.
choice
(
4
),
do_random_connect_stations
=
itrials
%
2
==
0
,
# Number of cities in map
seed
=
itrials
,
# Random seed
a_star_distance_function
=
Vec2d
.
get_euclidean_distance
,
seed
=
itrials
,
print_out_info
=
False
),
schedule_generator
=
sparse_schedule_generator
(),
...
...
flatland/core/grid/grid4_astar.py
View file @
854c9726
import
numpy
as
np
from
matplotlib
import
pyplot
as
plt
from
flatland.core.grid.grid_utils
import
IntVector2D
from
flatland.core.grid.grid_utils
import
IntVector2DArray
Type
from
flatland.core.grid.grid_utils
import
IntVector2D
,
IntVector2DDistance
from
flatland.core.grid.grid_utils
import
IntVector2DArray
from
flatland.core.grid.grid_utils
import
Vec2dOperations
as
Vec2d
from
flatland.core.grid.rail_env_grid
import
RailEnvTransitions
from
flatland.core.transition_map
import
GridTransitionMap
class
AStarNode
:
"""A node class for A* Pathfinding"""
def
__init__
(
self
,
p
arent
:
IntVector2D
=
None
,
pos
:
IntVector2D
=
None
):
self
.
parent
:
IntVector2D
=
parent
def
__init__
(
self
,
p
os
:
IntVector2D
,
parent
=
None
):
self
.
parent
=
parent
self
.
pos
:
IntVector2D
=
pos
self
.
g
=
0.0
self
.
h
=
0.0
self
.
f
=
0.0
def
__eq__
(
self
,
other
:
IntVector2D
):
def
__eq__
(
self
,
other
):
"""
Parameters
----------
other : AStarNode
"""
return
self
.
pos
==
other
.
pos
def
__hash__
(
self
):
...
...
@@ -32,10 +36,9 @@ class AStarNode:
self
.
f
=
other
.
f
def
a_star
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
def
a_star
(
grid_map
:
GridTransitionMap
,
start
:
IntVector2D
,
end
:
IntVector2D
,
a_star_distance_function
=
Vec2d
.
get_manhattan_distance
)
->
IntVector2DArray
Type
:
a_star_distance_function
:
IntVector2DDistance
=
Vec2d
.
get_manhattan_distance
)
->
IntVector2DArray
:
"""
Returns a list of tuples as a path from the given start to end.
If no path is found, returns path to closest point to end.
...
...
@@ -44,8 +47,8 @@ def a_star(rail_trans: RailEnvTransitions,
tmp
=
np
.
zeros
(
rail_shape
)
-
10
start_node
=
AStarNode
(
None
,
start
)
end_node
=
AStarNode
(
None
,
end
)
start_node
=
AStarNode
(
start
,
None
)
end_node
=
AStarNode
(
end
,
None
)
open_nodes
=
set
()
closed_nodes
=
set
()
open_nodes
.
add
(
start_node
)
...
...
@@ -72,13 +75,6 @@ def a_star(rail_trans: RailEnvTransitions,
path
.
append
(
current
.
pos
)
current
=
current
.
parent
if
False
:
plt
.
ion
()
plt
.
clf
()
plt
.
imshow
(
tmp
,
interpolation
=
'nearest'
)
plt
.
draw
()
plt
.
pause
(
1e-17
)
# return reversed path
return
path
[::
-
1
]
...
...
@@ -91,7 +87,7 @@ def a_star(rail_trans: RailEnvTransitions,
for
new_pos
in
[(
0
,
-
1
),
(
0
,
1
),
(
-
1
,
0
),
(
1
,
0
)]:
# update the "current" pos
node_pos
=
Vec2d
.
add
(
current_node
.
pos
,
new_pos
)
node_pos
:
IntVector2D
=
Vec2d
.
add
(
current_node
.
pos
,
new_pos
)
# is node_pos inside the grid?
if
node_pos
[
0
]
>=
rail_shape
[
0
]
or
node_pos
[
0
]
<
0
or
node_pos
[
1
]
>=
rail_shape
[
1
]
or
node_pos
[
1
]
<
0
:
...
...
@@ -102,7 +98,7 @@ def a_star(rail_trans: RailEnvTransitions,
continue
# create new node
new_node
=
AStarNode
(
current_node
,
node_pos
)
new_node
=
AStarNode
(
node_pos
,
current_node
)
children
.
append
(
new_node
)
# loop through children
...
...
flatland/core/grid/grid4_utils.py
View file @
854c9726
from
flatland.core.grid.grid4
import
Grid4TransitionsEnum
from
flatland.core.grid.grid_utils
import
IntVector2DArray
Type
from
flatland.core.grid.grid_utils
import
IntVector2DArray
def
get_direction
(
pos1
:
IntVector2DArray
Type
,
pos2
:
IntVector2DArray
Type
)
->
Grid4TransitionsEnum
:
def
get_direction
(
pos1
:
IntVector2DArray
,
pos2
:
IntVector2DArray
)
->
Grid4TransitionsEnum
:
"""
Assumes pos1 and pos2 are adjacent location on grid.
Returns direction (int) that can be used with transitions.
...
...
flatland/core/grid/grid_utils.py
View file @
854c9726
from
typing
import
Tuple
from
typing
import
Tuple
,
Callable
,
List
import
numpy
as
np
Vector2D
=
Tuple
[
float
,
float
]
IntVector2D
=
Tuple
[
int
,
int
]
IntVector2DArrayType
=
[]
IntVector2DArray
=
List
[
IntVector2D
]
IntVector2DArrayArray
=
List
[
List
[
IntVector2D
]]
Vector2DArray
=
List
[
Vector2D
]
Vector2DArrayArray
=
List
[
List
[
Vector2D
]]
IntVector2DDistance
=
Callable
[[
IntVector2D
,
IntVector2D
],
float
]
class
Vec2dOperations
:
...
...
@@ -73,42 +79,30 @@ class Vec2dOperations:
"""
return
np
.
sqrt
(
node
[
0
]
*
node
[
0
]
+
node
[
1
]
*
node
[
1
])
@
staticmethod
def
get_
manhattan_norm
(
node
:
Vector2D
)
->
float
:
def
get_
euclidean_distance
(
node_a
:
Vector2D
,
node
_b
:
Vector2D
)
->
float
:
"""
calculates the euclidean norm of the 2d vector
:param node: tuple with coordinate (x,y) or 2d vector
:return:
-------
returns the
manhatten norm
returns the
euclidean distance
"""
return
abs
(
node
[
0
]
*
node
[
0
])
+
abs
(
node
[
1
]
*
node
[
1
])
@
staticmethod
def
get_euclidean_distance
(
node_a
:
Vector2D
,
node_b
:
Vector2D
)
->
float
:
"""
calculates the euclidean norm of the 2d vector
:param node: tuple with coordinate (x,y) or 2d vector
:return:
-------
returnss the manhatten distance
"""
return
Vec2dOperations
.
get_norm
(
Vec2dOperations
.
subtract
(
node_b
,
node_a
))
return
Vec2dOperations
.
get_norm
(
Vec2dOperations
.
subtract
(
node_b
,
node_a
))
@
staticmethod
def
get_manhattan_distance
(
node_a
:
Vector2D
,
node_b
:
Vector2D
)
->
float
:
"""
calculates the
euclidean norm
of the 2d vector
calculates the
manhattan distance
of the 2d vector
:param node: tuple with coordinate (x,y) or 2d vector
:return:
-------
returns
s
the manhatt
e
n distance
returns the manhatt
a
n distance
"""
return
Vec2dOperations
.
get_manhattan_norm
(
Vec2dOperations
.
subtract
(
node_b
,
node_a
))
delta
=
(
Vec2dOperations
.
subtract
(
node_b
,
node_a
))
return
np
.
abs
(
delta
[
0
])
+
np
.
abs
(
delta
[
1
])
@
staticmethod
def
normalize
(
node
:
Vector2D
)
->
Tuple
[
float
,
float
]:
...
...
flatland/core/transition_map.py
View file @
854c9726
...
...
@@ -8,7 +8,7 @@ from numpy import array
from
flatland.core.grid.grid4
import
Grid4Transitions
from
flatland.core.grid.grid4_utils
import
get_new_position
,
get_direction
from
flatland.core.grid.grid_utils
import
IntVector2DArray
Type
from
flatland.core.grid.grid_utils
import
IntVector2DArray
,
IntVector2D
from
flatland.core.grid.grid_utils
import
Vec2dOperations
as
Vec2d
from
flatland.core.grid.rail_env_grid
import
RailEnvTransitions
from
flatland.core.transitions
import
Transitions
...
...
@@ -302,7 +302,7 @@ class GridTransitionMap(TransitionMap):
self
.
height
=
new_height
self
.
grid
=
new_grid
def
is_dead_end
(
self
,
rcPos
:
IntVector2DArray
Type
):
def
is_dead_end
(
self
,
rcPos
:
IntVector2DArray
):
"""
Check if the cell is a dead-end.
...
...
@@ -322,7 +322,7 @@ class GridTransitionMap(TransitionMap):
tmp
=
tmp
>>
1
return
nbits
==
1
def
is_simple_turn
(
self
,
rcPos
:
IntVector2DArray
Type
):
def
is_simple_turn
(
self
,
rcPos
:
IntVector2DArray
):
"""
Check if the cell is a left/right simple turn
...
...
@@ -349,7 +349,7 @@ class GridTransitionMap(TransitionMap):
return
is_simple_turn
(
tmp
)
def
check_path_exists
(
self
,
start
:
IntVector2DArray
Type
,
direction
:
int
,
end
:
IntVector2DArray
Type
):
def
check_path_exists
(
self
,
start
:
IntVector2DArray
,
direction
:
int
,
end
:
IntVector2DArray
):
# print("_path_exists({},{},{}".format(start, direction, end))
# BFS - Check if a path exists between the 2 nodes
...
...
@@ -373,7 +373,7 @@ class GridTransitionMap(TransitionMap):
return
False
def
cell_neighbours_valid
(
self
,
rcPos
:
IntVector2DArray
Type
,
check_this_cell
=
False
):
def
cell_neighbours_valid
(
self
,
rcPos
:
IntVector2DArray
,
check_this_cell
=
False
):
"""
Check validity of cell at rcPos = tuple(row, column)
Checks that:
...
...
@@ -425,7 +425,7 @@ class GridTransitionMap(TransitionMap):
return
True
def
fix_neighbours
(
self
,
rcPos
:
IntVector2DArray
Type
,
check_this_cell
=
False
):
def
fix_neighbours
(
self
,
rcPos
:
IntVector2DArray
,
check_this_cell
=
False
):
"""
Check validity of cell at rcPos = tuple(row, column)
Checks that:
...
...
@@ -478,7 +478,7 @@ class GridTransitionMap(TransitionMap):
return
True
def
fix_transitions
(
self
,
rcPos
:
IntVector2DArray
Type
):
def
fix_transitions
(
self
,
rcPos
:
IntVector2DArray
):
"""
Fixes broken transitions
"""
...
...
@@ -543,8 +543,8 @@ class GridTransitionMap(TransitionMap):
self
.
set_transitions
((
rcPos
[
0
],
rcPos
[
1
]),
transition
)
return
True
def
validate_new_transition
(
self
,
prev_pos
:
IntVector2D
ArrayType
,
current_pos
:
IntVector2D
ArrayType
,
new_pos
:
IntVector2D
ArrayType
,
end_pos
:
IntVector2D
ArrayType
):
def
validate_new_transition
(
self
,
prev_pos
:
IntVector2D
,
current_pos
:
IntVector2D
,
new_pos
:
IntVector2D
,
end_pos
:
IntVector2D
):
# start by getting direction used to get to current node
# and direction from current node to possible child node
...
...
flatland/envs/grid4_generators_utils.py
View file @
854c9726
...
...
@@ -7,7 +7,8 @@ a GridTransitionMap object.
from
flatland.core.grid.grid4_astar
import
a_star
from
flatland.core.grid.grid4_utils
import
get_direction
,
mirror
from
flatland.core.grid.grid_utils
import
IntVector2D
from
flatland.core.grid.grid_utils
import
IntVector2D
,
IntVector2DDistance
from
flatland.core.grid.grid_utils
import
Vec2dOperations
as
Vec2d
from
flatland.core.transition_map
import
GridTransitionMap
,
RailEnvTransitions
...
...
@@ -15,12 +16,13 @@ def connect_basic_operation(rail_trans: RailEnvTransitions, grid_map: GridTransi
start
:
IntVector2D
,
end
:
IntVector2D
,
flip_start_node_trans
=
False
,
flip_end_node_trans
=
False
):
flip_end_node_trans
=
False
,
a_star_distance_function
:
IntVector2DDistance
=
Vec2d
.
get_manhattan_distance
):
"""
Creates a new path [start,end] in grid_map, based on rail_trans.
"""
# in the worst case we will need to do a A* search, so we might as well set that up
path
=
a_star
(
rail_trans
,
grid_map
,
start
,
end
)
path
=
a_star
(
grid_map
,
start
,
end
,
a_star_distance_function
)
if
len
(
path
)
<
2
:
return
[]
current_dir
=
get_direction
(
path
[
0
],
path
[
1
])
...
...
@@ -67,18 +69,25 @@ def connect_basic_operation(rail_trans: RailEnvTransitions, grid_map: GridTransi
return
path
def
connect_rail
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
start
:
IntVector2D
,
end
:
IntVector2D
):
return
connect_basic_operation
(
rail_trans
,
grid_map
,
start
,
end
,
True
,
True
)
def
connect_rail
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
start
:
IntVector2D
,
end
:
IntVector2D
,
a_star_distance_function
:
IntVector2DDistance
=
Vec2d
.
get_manhattan_distance
):
return
connect_basic_operation
(
rail_trans
,
grid_map
,
start
,
end
,
True
,
True
,
a_star_distance_function
)
def
connect_nodes
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
start
:
IntVector2D
,
end
:
IntVector2D
):
return
connect_basic_operation
(
rail_trans
,
grid_map
,
start
,
end
,
False
,
False
)
def
connect_nodes
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
start
:
IntVector2D
,
end
:
IntVector2D
,
a_star_distance_function
:
IntVector2DDistance
=
Vec2d
.
get_manhattan_distance
):
return
connect_basic_operation
(
rail_trans
,
grid_map
,
start
,
end
,
False
,
False
,
a_star_distance_function
)
def
connect_from_nodes
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
start
:
IntVector2D
,
end
:
IntVector2D
):
return
connect_basic_operation
(
rail_trans
,
grid_map
,
start
,
end
,
False
,
True
)
def
connect_from_nodes
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
start
:
IntVector2D
,
end
:
IntVector2D
,
a_star_distance_function
:
IntVector2DDistance
=
Vec2d
.
get_manhattan_distance
):
return
connect_basic_operation
(
rail_trans
,
grid_map
,
start
,
end
,
False
,
True
,
a_star_distance_function
)
def
connect_to_nodes
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
start
:
IntVector2D
,
end
:
IntVector2D
):
return
connect_basic_operation
(
rail_trans
,
grid_map
,
start
,
end
,
True
,
False
)
def
connect_to_nodes
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
start
:
IntVector2D
,
end
:
IntVector2D
,
a_star_distance_function
:
IntVector2DDistance
=
Vec2d
.
get_manhattan_distance
):
return
connect_basic_operation
(
rail_trans
,
grid_map
,
start
,
end
,
True
,
False
,
a_star_distance_function
)
flatland/envs/rail_generators_city_generator.py
0 → 100644
View file @
854c9726
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment