Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Flatland
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
yoogottamk
Flatland
Commits
8e29a411
Commit
8e29a411
authored
5 years ago
by
Egli Adrian (IT-SCI-API-PFI)
Browse files
Options
Downloads
Patches
Plain Diff
city generator moved into framework
parent
2dc8b5be
No related branches found
No related tags found
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
examples/simple_example_city_railway_generator.py
+17
-509
17 additions, 509 deletions
examples/simple_example_city_railway_generator.py
flatland/envs/rail_generators_city_generator.py
+499
-0
499 additions, 0 deletions
flatland/envs/rail_generators_city_generator.py
with
516 additions
and
509 deletions
examples/simple_example_city_railway_generator.py
+
17
−
509
View file @
8e29a411
import
copy
import
os
import
warnings
from
typing
import
Sequence
,
Optional
from
typing
import
Sequence
import
numpy
as
np
from
flatland.core.grid.grid_utils
import
Vec2dOperations
as
Vec2d
,
IntVector2DArray
,
IntVector2DDistance
,
\
IntVector2DArrayArray
from
flatland.core.grid.rail_env_grid
import
RailEnvTransitions
from
flatland.core.transition_map
import
GridTransitionMap
from
flatland.envs.grid4_generators_utils
import
connect_from_nodes
,
connect_nodes
,
connect_rail
from
flatland.core.grid.grid_utils
import
Vec2dOperations
as
Vec2d
from
flatland.envs.observations
import
GlobalObsForRailEnv
from
flatland.envs.rail_env
import
RailEnv
from
flatland.envs.rail_generators
import
RailGenerator
,
RailG
enerator
Product
from
flatland.envs.schedule_generators
import
sparse
_schedule_generator
from
flatland.envs.rail_generators
_city_generator
import
city_g
enerator
from
flatland.envs.schedule_generators
import
city
_schedule_generator
from
flatland.utils.rendertools
import
RenderTool
,
AgentRenderVariant
FloatArrayType
=
[]
def
realistic_rail_generator
(
num_cities
:
int
=
5
,
city_size
:
int
=
10
,
allowed_rotation_angles
:
Optional
[
Sequence
[
float
]]
=
None
,
max_number_of_station_tracks
:
int
=
4
,
nbr_of_switches_per_station_track
:
int
=
2
,
connect_max_nbr_of_shortes_city
:
int
=
4
,
do_random_connect_stations
:
bool
=
False
,
a_star_distance_function
:
IntVector2DDistance
=
Vec2d
.
get_manhattan_distance
,
seed
:
int
=
0
,
print_out_info
:
bool
=
True
)
->
RailGenerator
:
"""
This is a level generator which generates a realistic rail configurations
:param num_cities: Number of city node
:param city_size: Length of city measure in cells
:param allowed_rotation_angles: Rotate the city (around center)
:param max_number_of_station_tracks: max number of tracks per station
:param nbr_of_switches_per_station_track: number of switches per track (max)
:param connect_max_nbr_of_shortes_city: max number of connecting track between stations
:param do_random_connect_stations : if false connect the stations along the grid (top,left -> down,right), else rand
:param a_star_distance_function: Heuristic how the distance between two nodes get estimated in the
"
a-star
"
path
:param seed: Random Seed
:param print_out_info: print debug info if True
:return:
-------
numpy.ndarray of type numpy.uint16
The matrix with the correct 16-bit bitmaps for each cell.
"""
def
do_generate_city_locations
(
width
:
int
,
height
:
int
,
intern_city_size
:
int
,
intern_max_number_of_station_tracks
:
int
)
->
(
IntVector2DArray
,
int
):
X
=
int
(
np
.
floor
(
max
(
1
,
height
-
2
*
intern_max_number_of_station_tracks
-
1
)
/
intern_city_size
))
Y
=
int
(
np
.
floor
(
max
(
1
,
width
-
2
*
intern_max_number_of_station_tracks
-
1
)
/
intern_city_size
))
max_num_cities
=
min
(
num_cities
,
X
*
Y
)
cities_at
=
np
.
random
.
choice
(
X
*
Y
,
max_num_cities
,
False
)
cities_at
=
np
.
sort
(
cities_at
)
if
print_out_info
:
print
(
"
max nbr of cities with given configuration is:
"
,
max_num_cities
)
x
=
np
.
floor
(
cities_at
/
Y
)
y
=
cities_at
-
x
*
Y
xs
=
(
x
*
intern_city_size
+
intern_max_number_of_station_tracks
)
+
intern_city_size
/
2
ys
=
(
y
*
intern_city_size
+
intern_max_number_of_station_tracks
)
+
intern_city_size
/
2
generate_city_locations
=
[[(
int
(
xs
[
i
]),
int
(
ys
[
i
])),
(
int
(
xs
[
i
]),
int
(
ys
[
i
]))]
for
i
in
range
(
len
(
xs
))]
return
generate_city_locations
,
max_num_cities
def
do_orient_cities
(
generate_city_locations
:
IntVector2DArrayArray
,
intern_city_size
:
int
,
rotation_angles_set
:
FloatArrayType
):
for
i
in
range
(
len
(
generate_city_locations
)):
# station main orientation (horizontal or vertical
rot_angle
=
np
.
random
.
choice
(
rotation_angles_set
)
add_pos_val
=
Vec2d
.
scale
(
Vec2d
.
rotate
((
1
,
0
),
rot_angle
),
int
(
max
(
1.0
,
(
intern_city_size
-
3
)
/
2
)))
generate_city_locations
[
i
][
0
]
=
Vec2d
.
add
(
generate_city_locations
[
i
][
1
],
add_pos_val
)
add_pos_val
=
Vec2d
.
scale
(
Vec2d
.
rotate
((
1
,
0
),
180
+
rot_angle
),
int
(
max
(
1.0
,
(
intern_city_size
-
3
)
/
2
)))
generate_city_locations
[
i
][
1
]
=
Vec2d
.
add
(
generate_city_locations
[
i
][
1
],
add_pos_val
)
return
generate_city_locations
def
create_stations_from_city_locations
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
generate_city_locations
:
IntVector2DArray
,
intern_max_number_of_station_tracks
:
int
)
->
(
IntVector2DArray
,
IntVector2DArray
,
IntVector2DArray
,
IntVector2DArray
,
IntVector2DArray
):
nodes_added
=
[]
start_nodes_added
=
[[]
for
_
in
range
(
len
(
generate_city_locations
))]
end_nodes_added
=
[[]
for
_
in
range
(
len
(
generate_city_locations
))]
station_slots
=
[[]
for
_
in
range
(
len
(
generate_city_locations
))]
station_tracks
=
[[[]
for
_
in
range
(
intern_max_number_of_station_tracks
)]
for
_
in
range
(
len
(
generate_city_locations
))]
station_slots_cnt
=
0
for
city_loop
in
range
(
len
(
generate_city_locations
)):
# Connect train station to the correct node
number_of_connecting_tracks
=
np
.
random
.
choice
(
max
(
0
,
intern_max_number_of_station_tracks
))
+
1
track_id
=
0
for
ct
in
range
(
number_of_connecting_tracks
):
org_start_node
=
generate_city_locations
[
city_loop
][
0
]
org_end_node
=
generate_city_locations
[
city_loop
][
1
]
ortho_trans
=
Vec2d
.
make_orthogonal
(
Vec2d
.
normalize
(
Vec2d
.
subtract
(
org_start_node
,
org_end_node
)))
s
=
(
ct
-
number_of_connecting_tracks
/
2.0
)
start_node
=
Vec2d
.
ceil
(
Vec2d
.
add
(
org_start_node
,
Vec2d
.
scale
(
ortho_trans
,
s
)))
end_node
=
Vec2d
.
ceil
(
Vec2d
.
add
(
org_end_node
,
Vec2d
.
scale
(
ortho_trans
,
s
)))
connection
=
connect_from_nodes
(
rail_trans
,
grid_map
,
start_node
,
end_node
,
a_star_distance_function
)
if
len
(
connection
)
>
0
:
nodes_added
.
append
(
start_node
)
nodes_added
.
append
(
end_node
)
start_nodes_added
[
city_loop
].
append
(
start_node
)
end_nodes_added
[
city_loop
].
append
(
end_node
)
# place in the center of path a station slot
# station_slots[city_loop].append(connection[int(np.floor(len(connection) / 2))])
for
c_loop
in
range
(
len
(
connection
)):
station_slots
[
city_loop
].
append
(
connection
[
c_loop
])
station_slots_cnt
+=
len
(
connection
)
station_tracks
[
city_loop
][
track_id
]
=
connection
track_id
+=
1
else
:
if
print_out_info
:
print
(
"
create_stations_from_city_locations : connect_from_nodes -> no path found
"
)
if
print_out_info
:
print
(
"
max nbr of station slots with given configuration is:
"
,
station_slots_cnt
)
return
nodes_added
,
station_slots
,
start_nodes_added
,
end_nodes_added
,
station_tracks
def
create_switches_at_stations
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
station_tracks
:
IntVector2DArray
,
nodes_added
:
IntVector2DArray
,
intern_nbr_of_switches_per_station_track
:
int
)
->
IntVector2DArray
:
for
k_loop
in
range
(
intern_nbr_of_switches_per_station_track
):
for
city_loop
in
range
(
len
(
station_tracks
)):
k
=
k_loop
+
city_loop
datas
=
station_tracks
[
city_loop
]
if
len
(
datas
)
>
1
:
track
=
datas
[
0
]
if
len
(
track
)
>
0
:
if
k
%
2
==
0
:
x
=
int
(
np
.
random
.
choice
(
int
(
len
(
track
)
/
2
))
+
1
)
else
:
x
=
len
(
track
)
-
int
(
np
.
random
.
choice
(
int
(
len
(
track
)
/
2
))
+
1
)
start_node
=
track
[
x
]
for
i
in
np
.
arange
(
1
,
len
(
datas
)):
track
=
datas
[
i
]
if
len
(
track
)
>
1
:
if
k
%
2
==
0
:
x
=
x
+
2
if
len
(
track
)
<=
x
:
x
=
1
else
:
x
=
x
-
2
if
x
<
2
:
x
=
len
(
track
)
-
1
end_node
=
track
[
x
]
connection
=
connect_rail
(
rail_trans
,
grid_map
,
start_node
,
end_node
,
a_star_distance_function
)
if
len
(
connection
)
==
0
:
if
print_out_info
:
print
(
"
create_switches_at_stations : connect_rail -> no path found
"
)
start_node
=
datas
[
i
][
0
]
end_node
=
datas
[
i
-
1
][
0
]
connect_rail
(
rail_trans
,
grid_map
,
start_node
,
end_node
,
a_star_distance_function
)
nodes_added
.
append
(
start_node
)
nodes_added
.
append
(
end_node
)
if
k
%
2
==
0
:
x
=
x
+
2
if
len
(
track
)
<=
x
:
x
=
1
else
:
x
=
x
-
2
if
x
<
2
:
x
=
len
(
track
)
-
2
start_node
=
track
[
x
]
return
nodes_added
def
create_graph_edge
(
from_city_index
:
int
,
to_city_index
:
int
)
->
(
int
,
int
,
int
):
return
from_city_index
,
to_city_index
,
np
.
inf
def
calc_nbr_of_graphs
(
graph
:
[])
->
([],
[]):
for
i
in
range
(
len
(
graph
)):
for
j
in
range
(
len
(
graph
)):
a
=
graph
[
i
]
b
=
graph
[
j
]
connected
=
False
if
a
[
0
]
==
b
[
0
]
or
a
[
1
]
==
b
[
0
]:
connected
=
True
if
a
[
0
]
==
b
[
1
]
or
a
[
1
]
==
b
[
1
]:
connected
=
True
if
connected
:
a
=
[
graph
[
i
][
0
],
graph
[
i
][
1
],
graph
[
i
][
2
]]
b
=
[
graph
[
j
][
0
],
graph
[
j
][
1
],
graph
[
j
][
2
]]
graph
[
i
]
=
(
graph
[
i
][
0
],
graph
[
i
][
1
],
min
(
np
.
min
(
a
),
np
.
min
(
b
)))
graph
[
j
]
=
(
graph
[
j
][
0
],
graph
[
j
][
1
],
min
(
np
.
min
(
a
),
np
.
min
(
b
)))
else
:
a
=
[
graph
[
i
][
0
],
graph
[
i
][
1
],
graph
[
i
][
2
]]
graph
[
i
]
=
(
graph
[
i
][
0
],
graph
[
i
][
1
],
np
.
min
(
a
))
b
=
[
graph
[
j
][
0
],
graph
[
j
][
1
],
graph
[
j
][
2
]]
graph
[
j
]
=
(
graph
[
j
][
0
],
graph
[
j
][
1
],
np
.
min
(
b
))
graph_ids
=
[]
for
i
in
range
(
len
(
graph
)):
graph_ids
.
append
(
graph
[
i
][
2
])
if
print_out_info
:
print
(
"
************* NBR of graphs:
"
,
len
(
np
.
unique
(
graph_ids
)))
return
graph
,
np
.
unique
(
graph_ids
).
astype
(
int
)
def
connect_sub_graphs
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
org_s_nodes
:
IntVector2DArray
,
org_e_nodes
:
IntVector2DArray
,
city_edges
:
IntVector2DArray
,
nodes_added
:
IntVector2DArray
):
_
,
graphids
=
calc_nbr_of_graphs
(
city_edges
)
if
len
(
graphids
)
>
0
:
for
i
in
range
(
len
(
graphids
)
-
1
):
connection
=
[]
iteration_counter
=
0
while
len
(
connection
)
==
0
and
iteration_counter
<
100
:
s_nodes
=
copy
.
deepcopy
(
org_s_nodes
)
e_nodes
=
copy
.
deepcopy
(
org_e_nodes
)
start_nodes
=
s_nodes
[
graphids
[
i
]]
end_nodes
=
e_nodes
[
graphids
[
i
+
1
]]
start_node
=
start_nodes
[
np
.
random
.
choice
(
len
(
start_nodes
))]
end_node
=
end_nodes
[
np
.
random
.
choice
(
len
(
end_nodes
))]
# TODO : removing, what the hell is going on, why we have to set rail_array -> transition to zero
# TODO : before we can call connect_rail. If we don't reset the transistion to zero -> no rail
# TODO : will be generated.
grid_map
.
grid
[
start_node
]
=
0
grid_map
.
grid
[
end_node
]
=
0
connection
=
connect_rail
(
rail_trans
,
grid_map
,
start_node
,
end_node
,
a_star_distance_function
)
if
len
(
connection
)
>
0
:
nodes_added
.
append
(
start_node
)
nodes_added
.
append
(
end_node
)
else
:
if
print_out_info
:
print
(
"
connect_sub_graphs : connect_rail -> no path found
"
)
iteration_counter
+=
1
def
connect_stations
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
org_s_nodes
:
IntVector2DArray
,
org_e_nodes
:
IntVector2DArray
,
nodes_added
:
IntVector2DArray
,
intern_connect_max_nbr_of_shortes_city
:
int
):
city_edges
=
[]
s_nodes
=
copy
.
deepcopy
(
org_s_nodes
)
e_nodes
=
copy
.
deepcopy
(
org_e_nodes
)
for
nbr_connected
in
range
(
intern_connect_max_nbr_of_shortes_city
):
for
city_loop
in
range
(
len
(
s_nodes
)):
sns
=
s_nodes
[
city_loop
]
for
start_node
in
sns
:
min_distance
=
np
.
inf
end_node
=
None
cl
=
0
for
city_loop_find_shortest
in
range
(
len
(
e_nodes
)):
if
city_loop_find_shortest
==
city_loop
:
continue
ens
=
e_nodes
[
city_loop_find_shortest
]
for
en
in
ens
:
d
=
Vec2d
.
get_euclidean_distance
(
start_node
,
en
)
if
d
<
min_distance
:
min_distance
=
d
end_node
=
en
cl
=
city_loop_find_shortest
if
end_node
is
not
None
:
tmp_trans_sn
=
grid_map
.
grid
[
start_node
]
tmp_trans_en
=
grid_map
.
grid
[
end_node
]
grid_map
.
grid
[
start_node
]
=
0
grid_map
.
grid
[
end_node
]
=
0
connection
=
connect_rail
(
rail_trans
,
grid_map
,
start_node
,
end_node
,
a_star_distance_function
)
if
len
(
connection
)
>
0
:
s_nodes
[
city_loop
].
remove
(
start_node
)
e_nodes
[
cl
].
remove
(
end_node
)
edge
=
create_graph_edge
(
city_loop
,
cl
)
if
city_loop
>
cl
:
edge
=
create_graph_edge
(
cl
,
city_loop
)
if
not
(
edge
in
city_edges
):
city_edges
.
append
(
edge
)
nodes_added
.
append
(
start_node
)
nodes_added
.
append
(
end_node
)
else
:
if
print_out_info
:
print
(
"
connect_stations : connect_rail -> no path found
"
)
grid_map
.
grid
[
start_node
]
=
tmp_trans_sn
grid_map
.
grid
[
end_node
]
=
tmp_trans_en
connect_sub_graphs
(
rail_trans
,
grid_map
,
org_s_nodes
,
org_e_nodes
,
city_edges
,
nodes_added
)
def
connect_random_stations
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
start_nodes_added
:
IntVector2DArray
,
end_nodes_added
:
IntVector2DArray
,
nodes_added
:
IntVector2DArray
,
intern_connect_max_nbr_of_shortes_city
:
int
):
if
len
(
start_nodes_added
)
<
1
:
return
x
=
np
.
arange
(
len
(
start_nodes_added
))
random_city_idx
=
np
.
random
.
choice
(
x
,
len
(
x
),
False
)
# cyclic connection
random_city_idx
=
np
.
append
(
random_city_idx
,
random_city_idx
[
0
])
for
city_loop
in
range
(
len
(
random_city_idx
)
-
1
):
idx_a
=
random_city_idx
[
city_loop
+
1
]
idx_b
=
random_city_idx
[
city_loop
]
s_nodes
=
start_nodes_added
[
idx_a
]
e_nodes
=
end_nodes_added
[
idx_b
]
max_input_output
=
max
(
len
(
s_nodes
),
len
(
e_nodes
))
max_input_output
=
min
(
intern_connect_max_nbr_of_shortes_city
,
max_input_output
)
idx_s_nodes
=
np
.
random
.
choice
(
np
.
arange
(
len
(
s_nodes
)),
len
(
s_nodes
),
False
)
idx_e_nodes
=
np
.
random
.
choice
(
np
.
arange
(
len
(
e_nodes
)),
len
(
e_nodes
),
False
)
if
len
(
idx_s_nodes
)
<
max_input_output
:
idx_s_nodes
=
np
.
append
(
idx_s_nodes
,
np
.
random
.
choice
(
np
.
arange
(
len
(
s_nodes
)),
max_input_output
-
len
(
idx_s_nodes
)))
if
len
(
idx_e_nodes
)
<
max_input_output
:
idx_e_nodes
=
np
.
append
(
idx_e_nodes
,
np
.
random
.
choice
(
np
.
arange
(
len
(
idx_e_nodes
)),
max_input_output
-
len
(
idx_e_nodes
)))
if
len
(
idx_s_nodes
)
>
intern_connect_max_nbr_of_shortes_city
:
idx_s_nodes
=
np
.
random
.
choice
(
idx_s_nodes
,
intern_connect_max_nbr_of_shortes_city
,
False
)
if
len
(
idx_e_nodes
)
>
intern_connect_max_nbr_of_shortes_city
:
idx_e_nodes
=
np
.
random
.
choice
(
idx_e_nodes
,
intern_connect_max_nbr_of_shortes_city
,
False
)
for
i
in
range
(
max_input_output
):
start_node
=
s_nodes
[
idx_s_nodes
[
i
]]
end_node
=
e_nodes
[
idx_e_nodes
[
i
]]
grid_map
.
grid
[
start_node
]
=
0
grid_map
.
grid
[
end_node
]
=
0
connection
=
connect_nodes
(
rail_trans
,
grid_map
,
start_node
,
end_node
,
a_star_distance_function
)
if
len
(
connection
)
>
0
:
nodes_added
.
append
(
start_node
)
nodes_added
.
append
(
end_node
)
else
:
if
print_out_info
:
print
(
"
connect_random_stations : connect_nodes -> no path found
"
)
def
remove_switch_stations
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
train_stations
:
IntVector2DArray
):
tmp_train_stations
=
copy
.
deepcopy
(
train_stations
)
for
city_loop
in
range
(
len
(
train_stations
)):
for
n
in
tmp_train_stations
[
city_loop
]:
do_remove
=
True
trans
=
rail_trans
.
transition_list
[
1
]
for
_
in
range
(
4
):
trans
=
rail_trans
.
rotate_transition
(
trans
,
rotation
=
90
)
if
grid_map
.
grid
[
n
]
==
trans
:
do_remove
=
False
if
do_remove
:
train_stations
[
city_loop
].
remove
(
n
)
def
generator
(
width
:
int
,
height
:
int
,
num_agents
:
int
,
num_resets
:
int
=
0
)
->
RailGeneratorProduct
:
rail_trans
=
RailEnvTransitions
()
grid_map
=
GridTransitionMap
(
width
=
width
,
height
=
height
,
transitions
=
rail_trans
)
grid_map
.
grid
.
fill
(
0
)
np
.
random
.
seed
(
seed
+
num_resets
)
intern_city_size
=
city_size
if
city_size
<
3
:
warnings
.
warn
(
"
min city_size requried to be > 3!
"
)
intern_city_size
=
3
if
print_out_info
:
print
(
"
intern_city_size:
"
,
intern_city_size
)
intern_max_number_of_station_tracks
=
max_number_of_station_tracks
if
max_number_of_station_tracks
<
1
:
warnings
.
warn
(
"
min max_number_of_station_tracks requried to be > 1!
"
)
intern_max_number_of_station_tracks
=
1
if
print_out_info
:
print
(
"
intern_max_number_of_station_tracks:
"
,
intern_max_number_of_station_tracks
)
intern_nbr_of_switches_per_station_track
=
nbr_of_switches_per_station_track
if
nbr_of_switches_per_station_track
<
1
:
warnings
.
warn
(
"
min intern_nbr_of_switches_per_station_track requried to be > 2!
"
)
intern_nbr_of_switches_per_station_track
=
2
if
print_out_info
:
print
(
"
intern_nbr_of_switches_per_station_track:
"
,
intern_nbr_of_switches_per_station_track
)
intern_connect_max_nbr_of_shortes_city
=
connect_max_nbr_of_shortes_city
if
connect_max_nbr_of_shortes_city
<
1
:
warnings
.
warn
(
"
min intern_connect_max_nbr_of_shortes_city requried to be > 1!
"
)
intern_connect_max_nbr_of_shortes_city
=
1
if
print_out_info
:
print
(
"
intern_connect_max_nbr_of_shortes_city:
"
,
intern_connect_max_nbr_of_shortes_city
)
agent_start_targets_nodes
=
[]
# ----------------------------------------------------------------------------------
# generate city locations
generate_city_locations
,
max_num_cities
=
do_generate_city_locations
(
width
,
height
,
intern_city_size
,
intern_max_number_of_station_tracks
)
# ----------------------------------------------------------------------------------
# apply orientation to cities (horizontal, vertical)
generate_city_locations
=
do_orient_cities
(
generate_city_locations
,
intern_city_size
,
allowed_rotation_angles
)
# ----------------------------------------------------------------------------------
# generate city topology
nodes_added
,
train_stations
,
s_nodes
,
e_nodes
,
station_tracks
=
\
create_stations_from_city_locations
(
rail_trans
,
grid_map
,
generate_city_locations
,
intern_max_number_of_station_tracks
)
# build switches
# TODO remove true/false block
if
True
:
create_switches_at_stations
(
rail_trans
,
grid_map
,
station_tracks
,
nodes_added
,
intern_nbr_of_switches_per_station_track
)
# ----------------------------------------------------------------------------------
# connect stations
# TODO remove true/false block
if
True
:
if
do_random_connect_stations
:
connect_random_stations
(
rail_trans
,
grid_map
,
s_nodes
,
e_nodes
,
nodes_added
,
intern_connect_max_nbr_of_shortes_city
)
else
:
connect_stations
(
rail_trans
,
grid_map
,
s_nodes
,
e_nodes
,
nodes_added
,
intern_connect_max_nbr_of_shortes_city
)
# ----------------------------------------------------------------------------------
# fix all transition at starting / ending points (mostly add a dead end, if missing)
# TODO i would like to remove the fixing stuff.
for
i
in
range
(
len
(
nodes_added
)):
grid_map
.
fix_transitions
(
nodes_added
[
i
])
# ----------------------------------------------------------------------------------
# remove stations where rail is a switch
remove_switch_stations
(
rail_trans
,
grid_map
,
train_stations
)
# ----------------------------------------------------------------------------------
# Slot availability in node
node_available_start
=
[]
node_available_target
=
[]
for
node_idx
in
range
(
max_num_cities
):
node_available_start
.
append
(
len
(
train_stations
[
node_idx
]))
node_available_target
.
append
(
len
(
train_stations
[
node_idx
]))
# Assign agents to slots
for
agent_idx
in
range
(
num_agents
):
avail_start_nodes
=
[
idx
for
idx
,
val
in
enumerate
(
node_available_start
)
if
val
>
0
]
avail_target_nodes
=
[
idx
for
idx
,
val
in
enumerate
(
node_available_target
)
if
val
>
0
]
if
len
(
avail_target_nodes
)
==
0
:
num_agents
-=
1
continue
start_node
=
np
.
random
.
choice
(
avail_start_nodes
)
target_node
=
np
.
random
.
choice
(
avail_target_nodes
)
tries
=
0
found_agent_pair
=
True
while
target_node
==
start_node
:
target_node
=
np
.
random
.
choice
(
avail_target_nodes
)
tries
+=
1
# Test again with new start node if no pair is found (This code needs to be improved)
if
(
tries
+
1
)
%
10
==
0
:
start_node
=
np
.
random
.
choice
(
avail_start_nodes
)
if
tries
>
100
:
warnings
.
warn
(
"
Could not set train_stations, removing agent!
"
)
found_agent_pair
=
False
break
if
found_agent_pair
:
node_available_start
[
start_node
]
-=
1
node_available_target
[
target_node
]
-=
1
agent_start_targets_nodes
.
append
((
start_node
,
target_node
))
else
:
num_agents
-=
1
return
grid_map
,
{
'
agents_hints
'
:
{
'
num_agents
'
:
num_agents
,
'
agent_start_targets_nodes
'
:
agent_start_targets_nodes
,
'
train_stations
'
:
train_stations
}}
return
generator
FloatArrayType
=
Sequence
[
float
]
if
os
.
path
.
exists
(
"
./../render_output/
"
):
for
itrials
in
np
.
arange
(
1
,
1000
,
1
):
...
...
@@ -510,18 +18,18 @@ if os.path.exists("./../render_output/"):
np
.
random
.
seed
(
itrials
)
env
=
RailEnv
(
width
=
40
+
np
.
random
.
choice
(
100
),
height
=
40
+
np
.
random
.
choice
(
100
),
rail_generator
=
realistic_rail
_generator
(
num_cities
=
5
+
np
.
random
.
choice
(
10
),
city_size
=
10
+
np
.
random
.
choice
(
5
),
allowed_rotation_angles
=
np
.
arange
(
0
,
360
,
6
),
max_number_of_station_tracks
=
4
+
np
.
random
.
choice
(
4
),
nbr_of_switches_per_station_track
=
2
+
np
.
random
.
choice
(
2
),
connect_max_nbr_of_shortes_city
=
2
+
np
.
random
.
choice
(
4
),
do_random_connect_stations
=
itrials
%
2
==
0
,
a_star_distance_function
=
Vec2d
.
get_euclidean_distance
,
seed
=
itrials
,
print_out_info
=
False
),
schedule_generator
=
sparse
_schedule_generator
(),
rail_generator
=
city
_generator
(
num_cities
=
5
+
np
.
random
.
choice
(
10
),
city_size
=
10
+
np
.
random
.
choice
(
5
),
allowed_rotation_angles
=
np
.
arange
(
0
,
360
,
6
),
max_number_of_station_tracks
=
4
+
np
.
random
.
choice
(
4
),
nbr_of_switches_per_station_track
=
2
+
np
.
random
.
choice
(
2
),
connect_max_nbr_of_shortes_city
=
2
+
np
.
random
.
choice
(
4
),
do_random_connect_stations
=
itrials
%
2
==
0
,
a_star_distance_function
=
Vec2d
.
get_euclidean_distance
,
seed
=
itrials
,
print_out_info
=
False
),
schedule_generator
=
city
_schedule_generator
(),
number_of_agents
=
10000
,
obs_builder_object
=
GlobalObsForRailEnv
())
...
...
This diff is collapsed.
Click to expand it.
flatland/envs/rail_generators_city_generator.py
+
499
−
0
View file @
8e29a411
import
copy
import
warnings
from
typing
import
Sequence
,
Optional
import
numpy
as
np
from
flatland.core.grid.grid_utils
import
Vec2dOperations
as
Vec2d
,
IntVector2DArray
,
IntVector2DDistance
,
\
IntVector2DArrayArray
from
flatland.core.grid.rail_env_grid
import
RailEnvTransitions
from
flatland.core.transition_map
import
GridTransitionMap
from
flatland.envs.grid4_generators_utils
import
connect_from_nodes
,
connect_nodes
,
connect_rail
from
flatland.envs.rail_generators
import
RailGenerator
,
RailGeneratorProduct
FloatArrayType
=
Sequence
[
float
]
def
city_generator
(
num_cities
:
int
=
5
,
city_size
:
int
=
10
,
allowed_rotation_angles
:
Optional
[
Sequence
[
float
]]
=
None
,
max_number_of_station_tracks
:
int
=
4
,
nbr_of_switches_per_station_track
:
int
=
2
,
connect_max_nbr_of_shortes_city
:
int
=
4
,
do_random_connect_stations
:
bool
=
False
,
a_star_distance_function
:
IntVector2DDistance
=
Vec2d
.
get_manhattan_distance
,
seed
:
int
=
0
,
print_out_info
:
bool
=
True
)
->
RailGenerator
:
"""
This is a level generator which generates a realistic rail configurations
:param num_cities: Number of city node
:param city_size: Length of city measure in cells
:param allowed_rotation_angles: Rotate the city (around center)
:param max_number_of_station_tracks: max number of tracks per station
:param nbr_of_switches_per_station_track: number of switches per track (max)
:param connect_max_nbr_of_shortes_city: max number of connecting track between stations
:param do_random_connect_stations : if false connect the stations along the grid (top,left -> down,right), else rand
:param a_star_distance_function: Heuristic how the distance between two nodes get estimated in the
"
a-star
"
path
:param seed: Random Seed
:param print_out_info: print debug info if True
:return:
-------
numpy.ndarray of type numpy.uint16
The matrix with the correct 16-bit bitmaps for each cell.
"""
def
do_generate_city_locations
(
width
:
int
,
height
:
int
,
intern_city_size
:
int
,
intern_max_number_of_station_tracks
:
int
)
->
(
IntVector2DArray
,
int
):
X
=
int
(
np
.
floor
(
max
(
1
,
height
-
2
*
intern_max_number_of_station_tracks
-
1
)
/
intern_city_size
))
Y
=
int
(
np
.
floor
(
max
(
1
,
width
-
2
*
intern_max_number_of_station_tracks
-
1
)
/
intern_city_size
))
max_num_cities
=
min
(
num_cities
,
X
*
Y
)
cities_at
=
np
.
random
.
choice
(
X
*
Y
,
max_num_cities
,
False
)
cities_at
=
np
.
sort
(
cities_at
)
if
print_out_info
:
print
(
"
max nbr of cities with given configuration is:
"
,
max_num_cities
)
x
=
np
.
floor
(
cities_at
/
Y
)
y
=
cities_at
-
x
*
Y
xs
=
(
x
*
intern_city_size
+
intern_max_number_of_station_tracks
)
+
intern_city_size
/
2
ys
=
(
y
*
intern_city_size
+
intern_max_number_of_station_tracks
)
+
intern_city_size
/
2
generate_city_locations
=
[[(
int
(
xs
[
i
]),
int
(
ys
[
i
])),
(
int
(
xs
[
i
]),
int
(
ys
[
i
]))]
for
i
in
range
(
len
(
xs
))]
return
generate_city_locations
,
max_num_cities
def
do_orient_cities
(
generate_city_locations
:
IntVector2DArrayArray
,
intern_city_size
:
int
,
rotation_angles_set
:
FloatArrayType
):
for
i
in
range
(
len
(
generate_city_locations
)):
# station main orientation (horizontal or vertical
rot_angle
=
np
.
random
.
choice
(
rotation_angles_set
)
add_pos_val
=
Vec2d
.
scale
(
Vec2d
.
rotate
((
1
,
0
),
rot_angle
),
int
(
max
(
1.0
,
(
intern_city_size
-
3
)
/
2
)))
generate_city_locations
[
i
][
0
]
=
Vec2d
.
add
(
generate_city_locations
[
i
][
1
],
add_pos_val
)
add_pos_val
=
Vec2d
.
scale
(
Vec2d
.
rotate
((
1
,
0
),
180
+
rot_angle
),
int
(
max
(
1.0
,
(
intern_city_size
-
3
)
/
2
)))
generate_city_locations
[
i
][
1
]
=
Vec2d
.
add
(
generate_city_locations
[
i
][
1
],
add_pos_val
)
return
generate_city_locations
def
create_stations_from_city_locations
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
generate_city_locations
:
IntVector2DArray
,
intern_max_number_of_station_tracks
:
int
)
->
(
IntVector2DArray
,
IntVector2DArray
,
IntVector2DArray
,
IntVector2DArray
,
IntVector2DArray
):
nodes_added
=
[]
start_nodes_added
=
[[]
for
_
in
range
(
len
(
generate_city_locations
))]
end_nodes_added
=
[[]
for
_
in
range
(
len
(
generate_city_locations
))]
station_slots
=
[[]
for
_
in
range
(
len
(
generate_city_locations
))]
station_tracks
=
[[[]
for
_
in
range
(
intern_max_number_of_station_tracks
)]
for
_
in
range
(
len
(
generate_city_locations
))]
station_slots_cnt
=
0
for
city_loop
in
range
(
len
(
generate_city_locations
)):
# Connect train station to the correct node
number_of_connecting_tracks
=
np
.
random
.
choice
(
max
(
0
,
intern_max_number_of_station_tracks
))
+
1
track_id
=
0
for
ct
in
range
(
number_of_connecting_tracks
):
org_start_node
=
generate_city_locations
[
city_loop
][
0
]
org_end_node
=
generate_city_locations
[
city_loop
][
1
]
ortho_trans
=
Vec2d
.
make_orthogonal
(
Vec2d
.
normalize
(
Vec2d
.
subtract
(
org_start_node
,
org_end_node
)))
s
=
(
ct
-
number_of_connecting_tracks
/
2.0
)
start_node
=
Vec2d
.
ceil
(
Vec2d
.
add
(
org_start_node
,
Vec2d
.
scale
(
ortho_trans
,
s
)))
end_node
=
Vec2d
.
ceil
(
Vec2d
.
add
(
org_end_node
,
Vec2d
.
scale
(
ortho_trans
,
s
)))
connection
=
connect_from_nodes
(
rail_trans
,
grid_map
,
start_node
,
end_node
,
a_star_distance_function
)
if
len
(
connection
)
>
0
:
nodes_added
.
append
(
start_node
)
nodes_added
.
append
(
end_node
)
start_nodes_added
[
city_loop
].
append
(
start_node
)
end_nodes_added
[
city_loop
].
append
(
end_node
)
# place in the center of path a station slot
# station_slots[city_loop].append(connection[int(np.floor(len(connection) / 2))])
for
c_loop
in
range
(
len
(
connection
)):
station_slots
[
city_loop
].
append
(
connection
[
c_loop
])
station_slots_cnt
+=
len
(
connection
)
station_tracks
[
city_loop
][
track_id
]
=
connection
track_id
+=
1
else
:
if
print_out_info
:
print
(
"
create_stations_from_city_locations : connect_from_nodes -> no path found
"
)
if
print_out_info
:
print
(
"
max nbr of station slots with given configuration is:
"
,
station_slots_cnt
)
return
nodes_added
,
station_slots
,
start_nodes_added
,
end_nodes_added
,
station_tracks
def
create_switches_at_stations
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
station_tracks
:
IntVector2DArray
,
nodes_added
:
IntVector2DArray
,
intern_nbr_of_switches_per_station_track
:
int
)
->
IntVector2DArray
:
for
k_loop
in
range
(
intern_nbr_of_switches_per_station_track
):
for
city_loop
in
range
(
len
(
station_tracks
)):
k
=
k_loop
+
city_loop
datas
=
station_tracks
[
city_loop
]
if
len
(
datas
)
>
1
:
track
=
datas
[
0
]
if
len
(
track
)
>
0
:
if
k
%
2
==
0
:
x
=
int
(
np
.
random
.
choice
(
int
(
len
(
track
)
/
2
))
+
1
)
else
:
x
=
len
(
track
)
-
int
(
np
.
random
.
choice
(
int
(
len
(
track
)
/
2
))
+
1
)
start_node
=
track
[
x
]
for
i
in
np
.
arange
(
1
,
len
(
datas
)):
track
=
datas
[
i
]
if
len
(
track
)
>
1
:
if
k
%
2
==
0
:
x
=
x
+
2
if
len
(
track
)
<=
x
:
x
=
1
else
:
x
=
x
-
2
if
x
<
2
:
x
=
len
(
track
)
-
1
end_node
=
track
[
x
]
connection
=
connect_rail
(
rail_trans
,
grid_map
,
start_node
,
end_node
,
a_star_distance_function
)
if
len
(
connection
)
==
0
:
if
print_out_info
:
print
(
"
create_switches_at_stations : connect_rail -> no path found
"
)
start_node
=
datas
[
i
][
0
]
end_node
=
datas
[
i
-
1
][
0
]
connect_rail
(
rail_trans
,
grid_map
,
start_node
,
end_node
,
a_star_distance_function
)
nodes_added
.
append
(
start_node
)
nodes_added
.
append
(
end_node
)
if
k
%
2
==
0
:
x
=
x
+
2
if
len
(
track
)
<=
x
:
x
=
1
else
:
x
=
x
-
2
if
x
<
2
:
x
=
len
(
track
)
-
2
start_node
=
track
[
x
]
return
nodes_added
def
create_graph_edge
(
from_city_index
:
int
,
to_city_index
:
int
)
->
(
int
,
int
,
int
):
return
from_city_index
,
to_city_index
,
np
.
inf
def
calc_nbr_of_graphs
(
graph
:
[])
->
([],
[]):
for
i
in
range
(
len
(
graph
)):
for
j
in
range
(
len
(
graph
)):
a
=
graph
[
i
]
b
=
graph
[
j
]
connected
=
False
if
a
[
0
]
==
b
[
0
]
or
a
[
1
]
==
b
[
0
]:
connected
=
True
if
a
[
0
]
==
b
[
1
]
or
a
[
1
]
==
b
[
1
]:
connected
=
True
if
connected
:
a
=
[
graph
[
i
][
0
],
graph
[
i
][
1
],
graph
[
i
][
2
]]
b
=
[
graph
[
j
][
0
],
graph
[
j
][
1
],
graph
[
j
][
2
]]
graph
[
i
]
=
(
graph
[
i
][
0
],
graph
[
i
][
1
],
min
(
np
.
min
(
a
),
np
.
min
(
b
)))
graph
[
j
]
=
(
graph
[
j
][
0
],
graph
[
j
][
1
],
min
(
np
.
min
(
a
),
np
.
min
(
b
)))
else
:
a
=
[
graph
[
i
][
0
],
graph
[
i
][
1
],
graph
[
i
][
2
]]
graph
[
i
]
=
(
graph
[
i
][
0
],
graph
[
i
][
1
],
np
.
min
(
a
))
b
=
[
graph
[
j
][
0
],
graph
[
j
][
1
],
graph
[
j
][
2
]]
graph
[
j
]
=
(
graph
[
j
][
0
],
graph
[
j
][
1
],
np
.
min
(
b
))
graph_ids
=
[]
for
i
in
range
(
len
(
graph
)):
graph_ids
.
append
(
graph
[
i
][
2
])
if
print_out_info
:
print
(
"
************* NBR of graphs:
"
,
len
(
np
.
unique
(
graph_ids
)))
return
graph
,
np
.
unique
(
graph_ids
).
astype
(
int
)
def
connect_sub_graphs
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
org_s_nodes
:
IntVector2DArray
,
org_e_nodes
:
IntVector2DArray
,
city_edges
:
IntVector2DArray
,
nodes_added
:
IntVector2DArray
):
_
,
graphids
=
calc_nbr_of_graphs
(
city_edges
)
if
len
(
graphids
)
>
0
:
for
i
in
range
(
len
(
graphids
)
-
1
):
connection
=
[]
iteration_counter
=
0
while
len
(
connection
)
==
0
and
iteration_counter
<
100
:
s_nodes
=
copy
.
deepcopy
(
org_s_nodes
)
e_nodes
=
copy
.
deepcopy
(
org_e_nodes
)
start_nodes
=
s_nodes
[
graphids
[
i
]]
end_nodes
=
e_nodes
[
graphids
[
i
+
1
]]
start_node
=
start_nodes
[
np
.
random
.
choice
(
len
(
start_nodes
))]
end_node
=
end_nodes
[
np
.
random
.
choice
(
len
(
end_nodes
))]
# TODO : removing, what the hell is going on, why we have to set rail_array -> transition to zero
# TODO : before we can call connect_rail. If we don't reset the transistion to zero -> no rail
# TODO : will be generated.
grid_map
.
grid
[
start_node
]
=
0
grid_map
.
grid
[
end_node
]
=
0
connection
=
connect_rail
(
rail_trans
,
grid_map
,
start_node
,
end_node
,
a_star_distance_function
)
if
len
(
connection
)
>
0
:
nodes_added
.
append
(
start_node
)
nodes_added
.
append
(
end_node
)
else
:
if
print_out_info
:
print
(
"
connect_sub_graphs : connect_rail -> no path found
"
)
iteration_counter
+=
1
def
connect_stations
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
org_s_nodes
:
IntVector2DArray
,
org_e_nodes
:
IntVector2DArray
,
nodes_added
:
IntVector2DArray
,
intern_connect_max_nbr_of_shortes_city
:
int
):
city_edges
=
[]
s_nodes
=
copy
.
deepcopy
(
org_s_nodes
)
e_nodes
=
copy
.
deepcopy
(
org_e_nodes
)
for
nbr_connected
in
range
(
intern_connect_max_nbr_of_shortes_city
):
for
city_loop
in
range
(
len
(
s_nodes
)):
sns
=
s_nodes
[
city_loop
]
for
start_node
in
sns
:
min_distance
=
np
.
inf
end_node
=
None
cl
=
0
for
city_loop_find_shortest
in
range
(
len
(
e_nodes
)):
if
city_loop_find_shortest
==
city_loop
:
continue
ens
=
e_nodes
[
city_loop_find_shortest
]
for
en
in
ens
:
d
=
Vec2d
.
get_euclidean_distance
(
start_node
,
en
)
if
d
<
min_distance
:
min_distance
=
d
end_node
=
en
cl
=
city_loop_find_shortest
if
end_node
is
not
None
:
tmp_trans_sn
=
grid_map
.
grid
[
start_node
]
tmp_trans_en
=
grid_map
.
grid
[
end_node
]
grid_map
.
grid
[
start_node
]
=
0
grid_map
.
grid
[
end_node
]
=
0
connection
=
connect_rail
(
rail_trans
,
grid_map
,
start_node
,
end_node
,
a_star_distance_function
)
if
len
(
connection
)
>
0
:
s_nodes
[
city_loop
].
remove
(
start_node
)
e_nodes
[
cl
].
remove
(
end_node
)
edge
=
create_graph_edge
(
city_loop
,
cl
)
if
city_loop
>
cl
:
edge
=
create_graph_edge
(
cl
,
city_loop
)
if
not
(
edge
in
city_edges
):
city_edges
.
append
(
edge
)
nodes_added
.
append
(
start_node
)
nodes_added
.
append
(
end_node
)
else
:
if
print_out_info
:
print
(
"
connect_stations : connect_rail -> no path found
"
)
grid_map
.
grid
[
start_node
]
=
tmp_trans_sn
grid_map
.
grid
[
end_node
]
=
tmp_trans_en
connect_sub_graphs
(
rail_trans
,
grid_map
,
org_s_nodes
,
org_e_nodes
,
city_edges
,
nodes_added
)
def
connect_random_stations
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
start_nodes_added
:
IntVector2DArray
,
end_nodes_added
:
IntVector2DArray
,
nodes_added
:
IntVector2DArray
,
intern_connect_max_nbr_of_shortes_city
:
int
):
if
len
(
start_nodes_added
)
<
1
:
return
x
=
np
.
arange
(
len
(
start_nodes_added
))
random_city_idx
=
np
.
random
.
choice
(
x
,
len
(
x
),
False
)
# cyclic connection
random_city_idx
=
np
.
append
(
random_city_idx
,
random_city_idx
[
0
])
for
city_loop
in
range
(
len
(
random_city_idx
)
-
1
):
idx_a
=
random_city_idx
[
city_loop
+
1
]
idx_b
=
random_city_idx
[
city_loop
]
s_nodes
=
start_nodes_added
[
idx_a
]
e_nodes
=
end_nodes_added
[
idx_b
]
max_input_output
=
max
(
len
(
s_nodes
),
len
(
e_nodes
))
max_input_output
=
min
(
intern_connect_max_nbr_of_shortes_city
,
max_input_output
)
idx_s_nodes
=
np
.
random
.
choice
(
np
.
arange
(
len
(
s_nodes
)),
len
(
s_nodes
),
False
)
idx_e_nodes
=
np
.
random
.
choice
(
np
.
arange
(
len
(
e_nodes
)),
len
(
e_nodes
),
False
)
if
len
(
idx_s_nodes
)
<
max_input_output
:
idx_s_nodes
=
np
.
append
(
idx_s_nodes
,
np
.
random
.
choice
(
np
.
arange
(
len
(
s_nodes
)),
max_input_output
-
len
(
idx_s_nodes
)))
if
len
(
idx_e_nodes
)
<
max_input_output
:
idx_e_nodes
=
np
.
append
(
idx_e_nodes
,
np
.
random
.
choice
(
np
.
arange
(
len
(
idx_e_nodes
)),
max_input_output
-
len
(
idx_e_nodes
)))
if
len
(
idx_s_nodes
)
>
intern_connect_max_nbr_of_shortes_city
:
idx_s_nodes
=
np
.
random
.
choice
(
idx_s_nodes
,
intern_connect_max_nbr_of_shortes_city
,
False
)
if
len
(
idx_e_nodes
)
>
intern_connect_max_nbr_of_shortes_city
:
idx_e_nodes
=
np
.
random
.
choice
(
idx_e_nodes
,
intern_connect_max_nbr_of_shortes_city
,
False
)
for
i
in
range
(
max_input_output
):
start_node
=
s_nodes
[
idx_s_nodes
[
i
]]
end_node
=
e_nodes
[
idx_e_nodes
[
i
]]
grid_map
.
grid
[
start_node
]
=
0
grid_map
.
grid
[
end_node
]
=
0
connection
=
connect_nodes
(
rail_trans
,
grid_map
,
start_node
,
end_node
,
a_star_distance_function
)
if
len
(
connection
)
>
0
:
nodes_added
.
append
(
start_node
)
nodes_added
.
append
(
end_node
)
else
:
if
print_out_info
:
print
(
"
connect_random_stations : connect_nodes -> no path found
"
)
def
remove_switch_stations
(
rail_trans
:
RailEnvTransitions
,
grid_map
:
GridTransitionMap
,
train_stations
:
IntVector2DArray
):
tmp_train_stations
=
copy
.
deepcopy
(
train_stations
)
for
city_loop
in
range
(
len
(
train_stations
)):
for
n
in
tmp_train_stations
[
city_loop
]:
do_remove
=
True
trans
=
rail_trans
.
transition_list
[
1
]
for
_
in
range
(
4
):
trans
=
rail_trans
.
rotate_transition
(
trans
,
rotation
=
90
)
if
grid_map
.
grid
[
n
]
==
trans
:
do_remove
=
False
if
do_remove
:
train_stations
[
city_loop
].
remove
(
n
)
def
generator
(
width
:
int
,
height
:
int
,
num_agents
:
int
,
num_resets
:
int
=
0
)
->
RailGeneratorProduct
:
rail_trans
=
RailEnvTransitions
()
grid_map
=
GridTransitionMap
(
width
=
width
,
height
=
height
,
transitions
=
rail_trans
)
grid_map
.
grid
.
fill
(
0
)
np
.
random
.
seed
(
seed
+
num_resets
)
intern_city_size
=
city_size
if
city_size
<
3
:
warnings
.
warn
(
"
min city_size requried to be > 3!
"
)
intern_city_size
=
3
if
print_out_info
:
print
(
"
intern_city_size:
"
,
intern_city_size
)
intern_max_number_of_station_tracks
=
max_number_of_station_tracks
if
max_number_of_station_tracks
<
1
:
warnings
.
warn
(
"
min max_number_of_station_tracks requried to be > 1!
"
)
intern_max_number_of_station_tracks
=
1
if
print_out_info
:
print
(
"
intern_max_number_of_station_tracks:
"
,
intern_max_number_of_station_tracks
)
intern_nbr_of_switches_per_station_track
=
nbr_of_switches_per_station_track
if
nbr_of_switches_per_station_track
<
1
:
warnings
.
warn
(
"
min intern_nbr_of_switches_per_station_track requried to be > 2!
"
)
intern_nbr_of_switches_per_station_track
=
2
if
print_out_info
:
print
(
"
intern_nbr_of_switches_per_station_track:
"
,
intern_nbr_of_switches_per_station_track
)
intern_connect_max_nbr_of_shortes_city
=
connect_max_nbr_of_shortes_city
if
connect_max_nbr_of_shortes_city
<
1
:
warnings
.
warn
(
"
min intern_connect_max_nbr_of_shortes_city requried to be > 1!
"
)
intern_connect_max_nbr_of_shortes_city
=
1
if
print_out_info
:
print
(
"
intern_connect_max_nbr_of_shortes_city:
"
,
intern_connect_max_nbr_of_shortes_city
)
agent_start_targets_nodes
=
[]
# ----------------------------------------------------------------------------------
# generate city locations
generate_city_locations
,
max_num_cities
=
do_generate_city_locations
(
width
,
height
,
intern_city_size
,
intern_max_number_of_station_tracks
)
# ----------------------------------------------------------------------------------
# apply orientation to cities (horizontal, vertical)
generate_city_locations
=
do_orient_cities
(
generate_city_locations
,
intern_city_size
,
allowed_rotation_angles
)
# ----------------------------------------------------------------------------------
# generate city topology
nodes_added
,
train_stations
,
s_nodes
,
e_nodes
,
station_tracks
=
\
create_stations_from_city_locations
(
rail_trans
,
grid_map
,
generate_city_locations
,
intern_max_number_of_station_tracks
)
# build switches
# TODO remove true/false block
if
True
:
create_switches_at_stations
(
rail_trans
,
grid_map
,
station_tracks
,
nodes_added
,
intern_nbr_of_switches_per_station_track
)
# ----------------------------------------------------------------------------------
# connect stations
# TODO remove true/false block
if
True
:
if
do_random_connect_stations
:
connect_random_stations
(
rail_trans
,
grid_map
,
s_nodes
,
e_nodes
,
nodes_added
,
intern_connect_max_nbr_of_shortes_city
)
else
:
connect_stations
(
rail_trans
,
grid_map
,
s_nodes
,
e_nodes
,
nodes_added
,
intern_connect_max_nbr_of_shortes_city
)
# ----------------------------------------------------------------------------------
# fix all transition at starting / ending points (mostly add a dead end, if missing)
# TODO i would like to remove the fixing stuff.
for
i
in
range
(
len
(
nodes_added
)):
grid_map
.
fix_transitions
(
nodes_added
[
i
])
# ----------------------------------------------------------------------------------
# remove stations where rail is a switch
remove_switch_stations
(
rail_trans
,
grid_map
,
train_stations
)
# ----------------------------------------------------------------------------------
# Slot availability in node
node_available_start
=
[]
node_available_target
=
[]
for
node_idx
in
range
(
max_num_cities
):
node_available_start
.
append
(
len
(
train_stations
[
node_idx
]))
node_available_target
.
append
(
len
(
train_stations
[
node_idx
]))
# Assign agents to slots
for
agent_idx
in
range
(
num_agents
):
avail_start_nodes
=
[
idx
for
idx
,
val
in
enumerate
(
node_available_start
)
if
val
>
0
]
avail_target_nodes
=
[
idx
for
idx
,
val
in
enumerate
(
node_available_target
)
if
val
>
0
]
if
len
(
avail_target_nodes
)
==
0
:
num_agents
-=
1
continue
start_node
=
np
.
random
.
choice
(
avail_start_nodes
)
target_node
=
np
.
random
.
choice
(
avail_target_nodes
)
tries
=
0
found_agent_pair
=
True
while
target_node
==
start_node
:
target_node
=
np
.
random
.
choice
(
avail_target_nodes
)
tries
+=
1
# Test again with new start node if no pair is found (This code needs to be improved)
if
(
tries
+
1
)
%
10
==
0
:
start_node
=
np
.
random
.
choice
(
avail_start_nodes
)
if
tries
>
100
:
warnings
.
warn
(
"
Could not set train_stations, removing agent!
"
)
found_agent_pair
=
False
break
if
found_agent_pair
:
node_available_start
[
start_node
]
-=
1
node_available_target
[
target_node
]
-=
1
agent_start_targets_nodes
.
append
((
start_node
,
target_node
))
else
:
num_agents
-=
1
return
grid_map
,
{
'
agents_hints
'
:
{
'
num_agents
'
:
num_agents
,
'
agent_start_targets_nodes
'
:
agent_start_targets_nodes
,
'
train_stations
'
:
train_stations
}}
return
generator
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment