Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Flatland
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
elrichgro
Flatland
Commits
6e2c1c9d
Commit
6e2c1c9d
authored
5 years ago
by
Erik Nygren
Browse files
Options
Downloads
Patches
Plain Diff
updating local obs for rail env to be better suited for the task
parent
c9c5e411
No related branches found
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
examples/training_example.py
+8
-2
8 additions, 2 deletions
examples/training_example.py
flatland/core/env_observation_builder.py
+1
-0
1 addition, 0 deletions
flatland/core/env_observation_builder.py
flatland/envs/observations.py
+77
-33
77 additions, 33 deletions
flatland/envs/observations.py
with
86 additions
and
35 deletions
examples/training_example.py
+
8
−
2
View file @
6e2c1c9d
import
numpy
as
np
from
flatland.envs.generators
import
complex_rail_generator
from
flatland.envs.observations
import
TreeObsForRailEnv
from
flatland.envs.observations
import
TreeObsForRailEnv
,
LocalObsForRailEnv
from
flatland.envs.predictions
import
ShortestPathPredictorForRailEnv
from
flatland.envs.rail_env
import
RailEnv
from
flatland.utils.rendertools
import
RenderTool
np
.
random
.
seed
(
1
)
...
...
@@ -12,12 +13,14 @@ np.random.seed(1)
#
TreeObservation
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
ShortestPathPredictorForRailEnv
())
LocalGridObs
=
LocalObsForRailEnv
(
view_height
=
10
,
view_width
=
2
,
center
=
2
)
env
=
RailEnv
(
width
=
20
,
height
=
20
,
rail_generator
=
complex_rail_generator
(
nr_start_goal
=
10
,
nr_extra
=
1
,
min_dist
=
8
,
max_dist
=
99999
,
seed
=
0
),
obs_builder_object
=
TreeObservation
,
obs_builder_object
=
LocalGridObs
,
number_of_agents
=
2
)
env_renderer
=
RenderTool
(
env
,
gl
=
"
PILSVG
"
,
)
# Import your own Agent or use RLlib to train agents on Flatland
# As an example we use a random agent here
...
...
@@ -66,6 +69,7 @@ for trials in range(1, n_trials + 1):
# Reset environment and get initial observations for all agents
obs
=
env
.
reset
()
env_renderer
.
reset
()
# Here you can also further enhance the provided observation by means of normalization
# See training navigation example in the baseline repository
...
...
@@ -80,6 +84,8 @@ for trials in range(1, n_trials + 1):
# Environment step which returns the observations for all agents, their corresponding
# reward and whether their are done
next_obs
,
all_rewards
,
done
,
_
=
env
.
step
(
action_dict
)
env_renderer
.
render_env
(
show
=
True
,
show_observations
=
True
,
show_predictions
=
False
)
# Update replay buffer and train agent
for
a
in
range
(
env
.
get_num_agents
()):
agent
.
step
((
obs
[
a
],
action_dict
[
a
],
all_rewards
[
a
],
next_obs
[
a
],
done
[
a
]))
...
...
This diff is collapsed.
Click to expand it.
flatland/core/env_observation_builder.py
+
1
−
0
View file @
6e2c1c9d
...
...
@@ -74,6 +74,7 @@ class ObservationBuilder:
direction
[
agent
.
direction
]
=
1
return
direction
class
DummyObservationBuilder
(
ObservationBuilder
):
"""
DummyObservationBuilder class which returns dummy observations
...
...
This diff is collapsed.
Click to expand it.
flatland/envs/observations.py
+
77
−
33
View file @
6e2c1c9d
...
...
@@ -698,71 +698,115 @@ class LocalObsForRailEnv(ObservationBuilder):
The observation is composed of the following elements:
- transition map array of the local environment around the given agent,
with dimensions (
2*
view_
radius + 1, 2*view_radius +
1, 16),
with dimensions (view_
height,2*view_width+
1, 16),
assuming 16 bits encoding of transitions.
- Two
2
D arrays (
2*
view_
radius + 1, 2*view_radius +
1, 2) containing respectively,
- Two
3
D arrays (view_
height,2*view_width+
1, 2) containing respectively,
if they are in the agent
'
s vision range, its target position, the positions of the other targets.
- A 3D array (
2*
view_
radius + 1, 2*view_radius +
1, 4) containing the one hot encoding of directions
- A 3D array (view_
height,2*view_width+
1, 4) containing the one hot encoding of directions
of the other agents at their position coordinates, if they are in the agent
'
s vision range.
- A 4 elements array with one hot encoding of the direction.
"""
def
__init__
(
self
,
view_
radius
):
def
__init__
(
self
,
view_
width
,
view_height
,
center
):
"""
:param view_radius:
"""
super
(
LocalObsForRailEnv
,
self
).
__init__
()
self
.
view_radius
=
view_radius
self
.
view_width
=
view_width
self
.
view_height
=
view_height
self
.
center
=
center
self
.
max_padding
=
max
(
self
.
view_width
,
self
.
view_height
-
self
.
center
)
def
reset
(
self
):
# We build the transition map with a view_radius empty cells expansion on each side.
# This helps to collect the local transition map view when the agent is close to a border.
self
.
rail_obs
=
np
.
zeros
((
self
.
env
.
height
+
2
*
self
.
view_radius
,
self
.
env
.
width
+
2
*
self
.
view_radius
,
16
))
self
.
max_padding
=
max
(
self
.
view_width
,
self
.
view_height
)
self
.
rail_obs
=
np
.
zeros
((
self
.
env
.
height
+
2
*
self
.
max_padding
,
self
.
env
.
width
+
2
*
self
.
max_padding
,
16
))
for
i
in
range
(
self
.
env
.
height
):
for
j
in
range
(
self
.
env
.
width
):
bitlist
=
[
int
(
digit
)
for
digit
in
bin
(
self
.
env
.
rail
.
get_full_transitions
(
i
,
j
))[
2
:]]
bitlist
=
[
0
]
*
(
16
-
len
(
bitlist
))
+
bitlist
self
.
rail_obs
[
i
+
self
.
view_
radius
,
j
+
self
.
view_
radius
]
=
np
.
array
(
bitlist
)
self
.
rail_obs
[
i
+
self
.
view_
height
,
j
+
self
.
view_
width
]
=
np
.
array
(
bitlist
)
def
get
(
self
,
handle
):
agents
=
self
.
env
.
agents
agent
=
agents
[
handle
]
agent_rel_pos
=
[
0
,
0
]
local_rail_obs
=
self
.
rail_obs
[
agent
.
position
[
0
]:
agent
.
position
[
0
]
+
2
*
self
.
view_radius
+
1
,
agent
.
position
[
1
]:
agent
.
position
[
1
]
+
2
*
self
.
view_radius
+
1
]
obs_map_state
=
np
.
zeros
((
2
*
self
.
view_radius
+
1
,
2
*
self
.
view_radius
+
1
,
2
))
# Correct agents position for padding
agent_rel_pos
[
0
]
=
agent
.
position
[
0
]
+
self
.
max_padding
agent_rel_pos
[
1
]
=
agent
.
position
[
1
]
+
self
.
max_padding
obs_other_agents_state
=
np
.
zeros
((
2
*
self
.
view_radius
+
1
,
2
*
self
.
view_radius
+
1
,
4
))
# Collect the rail information in the local field of view
local_rail_obs
=
self
.
field_of_view
(
agent_rel_pos
,
agent
.
direction
,
state
=
self
.
rail_obs
)
def
relative_pos
(
pos
):
return
[
agent
.
position
[
0
]
-
pos
[
0
],
agent
.
position
[
1
]
-
pos
[
1
]]
# Locate observed agents and their coresponding targets
obs_map_state
=
np
.
zeros
((
self
.
view_height
+
1
,
2
*
self
.
view_width
+
1
,
2
))
obs_other_agents_state
=
np
.
zeros
((
self
.
view_height
+
1
,
2
*
self
.
view_width
+
1
,
4
))
def
is_in
(
rel_pos
):
return
(
abs
(
rel_pos
[
0
])
<
=
self
.
v
ie
w_radius
)
and
(
abs
(
rel_pos
[
1
])
<=
self
.
view_radius
)
# Collect visible cells as set to be plotted
visited
=
self
.
f
ie
ld_of_view
(
agent
.
position
,
agent
.
direction
)
target_rel_pos
=
relative_pos
(
agent
.
target
)
if
is_in
(
target_rel_pos
):
obs_map_state
[
self
.
view_radius
+
np
.
array
(
target_rel_pos
)][
0
]
+=
1
# Add the visible cells to the observed cells
self
.
env
.
dev_obs_dict
[
handle
]
=
visited
for
i
in
range
(
len
(
agents
)):
if
i
!=
handle
:
# TODO: handle used as index...?
agent2
=
agents
[
i
]
direction
=
self
.
_get_one_hot_for_agent_direction
(
agent
)
agent_2_rel_pos
=
relative_pos
(
agent2
.
position
)
if
is_in
(
agent_2_rel_pos
):
obs_other_agents_state
[
self
.
view_radius
+
agent_2_rel_pos
[
0
],
self
.
view_radius
+
agent_2_rel_pos
[
1
]][
agent2
.
direction
]
+=
1
return
local_rail_obs
,
obs_map_state
,
obs_other_agents_state
,
direction
target_rel_pos_2
=
relative_pos
(
agent2
.
position
)
if
is_in
(
target_rel_pos_2
):
obs_map_state
[
self
.
view_radius
+
np
.
array
(
target_rel_pos_2
)][
1
]
+=
1
def
get_many
(
self
,
handles
=
None
):
"""
Called whenever an observation has to be computed for the `env
'
environment, for each agent with handle
in the `handles
'
list.
"""
direction
=
self
.
_get_one_hot_for_agent_direction
(
agent
)
observations
=
{}
for
h
in
handles
:
observations
[
h
]
=
self
.
get
(
h
)
return
observations
return
local_rail_obs
,
obs_map_state
,
obs_other_agents_state
,
direction
def
field_of_view
(
self
,
position
,
direction
,
state
=
None
):
# Compute the local field of view for an agent in the environment
data_collection
=
False
if
state
is
not
None
:
temp_visible_data
=
np
.
zeros
(
shape
=
(
self
.
view_height
,
2
*
self
.
view_width
+
1
,
16
))
data_collection
=
True
if
direction
==
0
:
origin
=
(
position
[
0
]
+
self
.
center
,
position
[
1
]
-
self
.
view_width
)
elif
direction
==
1
:
origin
=
(
position
[
0
]
-
self
.
view_width
,
position
[
1
]
-
self
.
center
)
elif
direction
==
2
:
origin
=
(
position
[
0
]
-
self
.
center
,
position
[
1
]
+
self
.
view_width
)
else
:
origin
=
(
position
[
0
]
+
self
.
view_width
,
position
[
1
]
+
self
.
center
)
visible
=
set
()
for
h
in
range
(
self
.
view_height
):
for
w
in
range
(
2
*
self
.
view_width
+
1
):
if
direction
==
0
:
if
0
<=
origin
[
0
]
-
h
<
self
.
env
.
height
and
0
<=
origin
[
1
]
+
w
<
self
.
env
.
width
:
visible
.
add
((
origin
[
0
]
-
h
,
origin
[
1
]
+
w
))
if
data_collection
:
temp_visible_data
[
h
,
w
,
:]
=
state
[
origin
[
0
]
-
h
,
origin
[
1
]
+
w
,
:]
elif
direction
==
1
:
if
0
<=
origin
[
0
]
+
w
<
self
.
env
.
height
and
0
<=
origin
[
1
]
+
h
<
self
.
env
.
width
:
visible
.
add
((
origin
[
0
]
+
w
,
origin
[
1
]
+
h
))
if
data_collection
:
temp_visible_data
[
h
,
w
,
:]
=
state
[
origin
[
0
]
+
w
,
origin
[
1
]
+
h
,
:]
elif
direction
==
2
:
if
0
<=
origin
[
0
]
-
h
<
self
.
env
.
height
and
0
<=
origin
[
1
]
+
w
<
self
.
env
.
width
:
visible
.
add
((
origin
[
0
]
+
h
,
origin
[
1
]
-
w
))
if
data_collection
:
temp_visible_data
[
h
,
w
,
:]
=
state
[
origin
[
0
]
+
h
,
origin
[
1
]
-
w
,
:]
else
:
if
0
<=
origin
[
0
]
-
h
<
self
.
env
.
height
and
0
<=
origin
[
1
]
+
w
<
self
.
env
.
width
:
visible
.
add
((
origin
[
0
]
-
w
,
origin
[
1
]
-
h
))
if
data_collection
:
temp_visible_data
[
h
,
w
,
:]
=
state
[
origin
[
0
]
-
w
,
origin
[
1
]
-
h
,
:]
if
data_collection
:
return
temp_visible_data
else
:
return
visible
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment