Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
Flatland
Commits
9a6efb05
Commit
9a6efb05
authored
Apr 21, 2019
by
spiglerg
Browse files
added save/load gridmap for GridTransitionMap
parent
13ebd009
Pipeline
#316
failed with stage
in 2 minutes and 8 seconds
Changes
3
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
examples/temporary_example.py
View file @
9a6efb05
...
@@ -9,24 +9,14 @@ from flatland.utils.rendertools import *
...
@@ -9,24 +9,14 @@ from flatland.utils.rendertools import *
random
.
seed
(
0
)
random
.
seed
(
0
)
np
.
random
.
seed
(
0
)
np
.
random
.
seed
(
0
)
"""
transition_probability = [1.0, # empty cell - Case 0
3.0, # Case 1 - straight
1.0, # Case 2 - simple switch
3.0, # Case 3 - diamond drossing
2.0, # Case 4 - single slip
1.0, # Case 5 - double slip
1.0, # Case 6 - symmetrical
1.0] # Case 7 - dead end
"""
transition_probability
=
[
1.0
,
# empty cell - Case 0
transition_probability
=
[
1.0
,
# empty cell - Case 0
1.0
,
# Case 1 - straight
1.0
,
# Case 1 - straight
0.5
,
# Case 2 - simple switch
1.0
,
# Case 2 - simple switch
0.
2
,
# Case 3 - diamond drossing
0.
3
,
# Case 3 - diamond drossing
0.5
,
# Case 4 - single slip
0.5
,
# Case 4 - single slip
0.
1
,
# Case 5 - double slip
0.
5
,
# Case 5 - double slip
0.2
,
# Case 6 - symmetrical
0.2
,
# Case 6 - symmetrical
0.0
1
]
# Case 7 - dead end
0.0
]
# Case 7 - dead end
# Example generate a random rail
# Example generate a random rail
env
=
RailEnv
(
width
=
20
,
env
=
RailEnv
(
width
=
20
,
...
@@ -38,12 +28,12 @@ env.reset()
...
@@ -38,12 +28,12 @@ env.reset()
env_renderer
=
RenderTool
(
env
)
env_renderer
=
RenderTool
(
env
)
env_renderer
.
renderEnv
(
show
=
True
)
env_renderer
.
renderEnv
(
show
=
True
)
"""
# Example generate a rail given a manual specification,
# Example generate a rail given a manual specification,
# a map of tuples (cell_type, rotation)
# a map of tuples (cell_type, rotation)
specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
[(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)]]
[(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)]]
"""
env = RailEnv(width=6,
env = RailEnv(width=6,
height=2,
height=2,
rail_generator=rail_from_manual_specifications_generator(specs),
rail_generator=rail_from_manual_specifications_generator(specs),
...
@@ -56,20 +46,20 @@ env.agents_target[0] = [1, 1]
...
@@ -56,20 +46,20 @@ env.agents_target[0] = [1, 1]
env.agents_direction[0] = 1
env.agents_direction[0] = 1
# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
env.obs_builder.reset()
env.obs_builder.reset()
#"""
"""
env
=
RailEnv
(
width
=
7
,
env
=
RailEnv
(
width
=
7
,
height
=
7
,
height
=
7
,
rail_generator
=
random_rail_generator
(
cell_type_relative_proportion
=
transition_probability
),
rail_generator
=
random_rail_generator
(
cell_type_relative_proportion
=
transition_probability
),
number_of_agents
=
2
)
number_of_agents
=
2
)
#
TODO: delete next line
#
Print the distance map of each cell to the target of the first agent
#for i in range(4):
#
for i in range(4):
# print(env.obs_builder.distance_map[0, :, :, i])
#
print(env.obs_builder.distance_map[0, :, :, i])
# Print the observation vector for agent 0
obs
,
all_rewards
,
done
,
_
=
env
.
step
({
0
:
0
})
obs
,
all_rewards
,
done
,
_
=
env
.
step
({
0
:
0
})
for
i
in
range
(
env
.
number_of_agents
):
for
i
in
range
(
env
.
number_of_agents
):
env
.
obs_builder
.
util_print_obs_subtree
(
tree
=
obs
[
i
],
num_
element
s_per_node
=
5
)
env
.
obs_builder
.
util_print_obs_subtree
(
tree
=
obs
[
i
],
num_
feature
s_per_node
=
5
)
env_renderer
=
RenderTool
(
env
)
env_renderer
=
RenderTool
(
env
)
env_renderer
.
renderEnv
(
show
=
True
)
env_renderer
.
renderEnv
(
show
=
True
)
...
...
flatland/core/env_observation_builder.py
View file @
9a6efb05
...
@@ -383,15 +383,15 @@ class TreeObsForRailEnv(ObservationBuilder):
...
@@ -383,15 +383,15 @@ class TreeObsForRailEnv(ObservationBuilder):
return
observation
return
observation
def
util_print_obs_subtree
(
self
,
tree
,
num_
element
s_per_node
=
5
,
prompt
=
''
,
current_depth
=
0
):
def
util_print_obs_subtree
(
self
,
tree
,
num_
feature
s_per_node
=
5
,
prompt
=
''
,
current_depth
=
0
):
"""
"""
Utility function to pretty-print tree observations returned by this object.
Utility function to pretty-print tree observations returned by this object.
"""
"""
if
len
(
tree
)
<
num_
element
s_per_node
:
if
len
(
tree
)
<
num_
feature
s_per_node
:
return
return
depth
=
0
depth
=
0
tmp
=
len
(
tree
)
/
num_
element
s_per_node
-
1
tmp
=
len
(
tree
)
/
num_
feature
s_per_node
-
1
pow4
=
4
pow4
=
4
while
tmp
>
0
:
while
tmp
>
0
:
tmp
-=
pow4
tmp
-=
pow4
...
@@ -400,12 +400,12 @@ class TreeObsForRailEnv(ObservationBuilder):
...
@@ -400,12 +400,12 @@ class TreeObsForRailEnv(ObservationBuilder):
prompt_
=
[
'L:'
,
'F:'
,
'R:'
,
'B:'
]
prompt_
=
[
'L:'
,
'F:'
,
'R:'
,
'B:'
]
print
(
" "
*
current_depth
+
prompt
,
tree
[
0
:
num_
element
s_per_node
])
print
(
" "
*
current_depth
+
prompt
,
tree
[
0
:
num_
feature
s_per_node
])
child_size
=
(
len
(
tree
)
-
num_
element
s_per_node
)
//
4
child_size
=
(
len
(
tree
)
-
num_
feature
s_per_node
)
//
4
for
children
in
range
(
4
):
for
children
in
range
(
4
):
child_tree
=
tree
[(
num_
element
s_per_node
+
children
*
child_size
):
child_tree
=
tree
[(
num_
feature
s_per_node
+
children
*
child_size
):
(
num_
element
s_per_node
+
(
children
+
1
)
*
child_size
)]
(
num_
feature
s_per_node
+
(
children
+
1
)
*
child_size
)]
self
.
util_print_obs_subtree
(
child_tree
,
self
.
util_print_obs_subtree
(
child_tree
,
num_
element
s_per_node
,
num_
feature
s_per_node
,
prompt
=
prompt_
[
children
],
prompt
=
prompt_
[
children
],
current_depth
=
current_depth
+
1
)
current_depth
=
current_depth
+
1
)
flatland/core/transition_map.py
View file @
9a6efb05
...
@@ -118,7 +118,7 @@ class GridTransitionMap(TransitionMap):
...
@@ -118,7 +118,7 @@ class GridTransitionMap(TransitionMap):
Width of the grid.
Width of the grid.
height : int
height : int
Height of the grid.
Height of the grid.
transitions
_class
: Transitions object
transitions : Transitions object
The Transitions object to use to encode/decode transitions over the
The Transitions object to use to encode/decode transitions over the
grid.
grid.
...
@@ -243,6 +243,47 @@ class GridTransitionMap(TransitionMap):
...
@@ -243,6 +243,47 @@ class GridTransitionMap(TransitionMap):
return
return
self
.
transitions
.
set_transition
(
self
.
grid
[
cell_id
[
0
]][
cell_id
[
1
]],
cell_id
[
2
],
transition_index
,
new_transition
)
self
.
transitions
.
set_transition
(
self
.
grid
[
cell_id
[
0
]][
cell_id
[
1
]],
cell_id
[
2
],
transition_index
,
new_transition
)
def
save_transition_map
(
self
,
filename
):
"""
Save the transitions grid as `filename', in npy format.
Parameters
----------
filename : string
Name of the file to which to save the transitions grid.
"""
np
.
save
(
filename
,
self
.
grid
)
def
load_transition_map
(
self
,
filename
,
override_gridsize
=
True
):
"""
Load the transitions grid from `filename' (npy format).
The load function only updates the transitions grid, and possibly width and height, but the object has to be
initialized with the correct `transitions' object anyway.
Parameters
----------
filename : string
Name of the file from which to load the transitions grid.
override_gridsize : bool
If override_gridsize=True, the width and height of the GridTransitionMap object are replaced with the size
of the map loaded from `filename'. If override_gridsize=False, the transitions grid is either cropped (if
the grid size is larger than (height,width) ) or padded with zeros (if the grid size is smaller than (height,width) )
"""
new_grid
=
np
.
load
(
filename
)
new_height
=
new_grid
.
shape
[
0
]
new_width
=
new_grid
.
shape
[
1
]
if
override_gridsize
:
self
.
width
=
new_width
self
.
height
=
new_height
self
.
grid
=
new_grid
else
:
self
.
grid
=
self
.
grid
*
0
self
.
grid
[
0
:
min
(
self
.
height
,
new_height
),
0
:
min
(
self
.
width
,
new_width
)]
=
new_grid
[
0
:
min
(
self
.
height
,
new_height
),
0
:
min
(
self
.
width
,
new_width
)]
# TODO: GIACOMO: is it better to provide those methods with lists of cell_ids
# TODO: GIACOMO: is it better to provide those methods with lists of cell_ids
# (most general implementation) or to make Grid-class specific methods for
# (most general implementation) or to make Grid-class specific methods for
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment