Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Flatland
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
pranjal_dhole
Flatland
Commits
aadff790
Commit
aadff790
authored
5 years ago
by
u214892
Browse files
Options
Downloads
Patches
Plain Diff
47 agent directions in observation
parent
f5022411
No related branches found
No related tags found
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
flatland/core/env_observation_builder.py
+8
-1
8 additions, 1 deletion
flatland/core/env_observation_builder.py
flatland/envs/observations.py
+41
-28
41 additions, 28 deletions
flatland/envs/observations.py
with
49 additions
and
29 deletions
flatland/core/env_observation_builder.py
+
8
−
1
View file @
aadff790
...
@@ -7,13 +7,14 @@ The ObservationBuilder-derived custom classes implement 2 functions, reset() and
...
@@ -7,13 +7,14 @@ The ObservationBuilder-derived custom classes implement 2 functions, reset() and
+ Get() is called whenever an observation has to be computed, potentially for each agent independently in
+ Get() is called whenever an observation has to be computed, potentially for each agent independently in
case of multi-agent environments.
case of multi-agent environments.
"""
"""
import
numpy
as
np
class
ObservationBuilder
:
class
ObservationBuilder
:
"""
"""
ObservationBuilder base class.
ObservationBuilder base class.
Derived objects must implement and `observation_space
'
attribute as a tuple with the dimens
u
ions of the returned
Derived objects must implement and `observation_space
'
attribute as a tuple with the dimensions of the returned
observations.
observations.
"""
"""
...
@@ -45,3 +46,9 @@ class ObservationBuilder:
...
@@ -45,3 +46,9 @@ class ObservationBuilder:
An observation structure, specific to the corresponding environment.
An observation structure, specific to the corresponding environment.
"""
"""
raise
NotImplementedError
()
raise
NotImplementedError
()
def
_get_one_hot_for_agent_direction
(
self
,
agent
):
"""
Retuns the agent
'
s direction to one-hot encoding.
"""
direction
=
np
.
zeros
(
4
)
direction
[
agent
.
direction
]
=
1
return
direction
This diff is collapsed.
Click to expand it.
flatland/envs/observations.py
+
41
−
28
View file @
aadff790
"""
"""
Collection of environment-specific ObservationBuilder.
Collection of environment-specific ObservationBuilder.
"""
"""
import
numpy
as
np
from
collections
import
deque
from
collections
import
deque
import
numpy
as
np
from
flatland.core.env_observation_builder
import
ObservationBuilder
from
flatland.core.env_observation_builder
import
ObservationBuilder
...
@@ -22,10 +23,10 @@ class TreeObsForRailEnv(ObservationBuilder):
...
@@ -22,10 +23,10 @@ class TreeObsForRailEnv(ObservationBuilder):
# Compute the size of the returned observation vector
# Compute the size of the returned observation vector
size
=
0
size
=
0
pow4
=
1
pow4
=
1
for
i
in
range
(
self
.
max_depth
+
1
):
for
i
in
range
(
self
.
max_depth
+
1
):
size
+=
pow4
size
+=
pow4
pow4
*=
4
pow4
*=
4
self
.
observation_space
=
[
size
*
5
]
self
.
observation_space
=
[
size
*
6
]
def
reset
(
self
):
def
reset
(
self
):
agents
=
self
.
env
.
agents
agents
=
self
.
env
.
agents
...
@@ -186,6 +187,10 @@ class TreeObsForRailEnv(ObservationBuilder):
...
@@ -186,6 +187,10 @@ class TreeObsForRailEnv(ObservationBuilder):
#5: minimum distance from node to the agent
'
s target (when landing to the node following the corresponding
#5: minimum distance from node to the agent
'
s target (when landing to the node following the corresponding
branch.
branch.
#6: agent direction
Missing/padding nodes are filled in with -inf (truncated).
Missing/padding nodes are filled in with -inf (truncated).
Missing values in present node are filled in with +inf (truncated).
Missing values in present node are filled in with +inf (truncated).
...
@@ -202,13 +207,10 @@ class TreeObsForRailEnv(ObservationBuilder):
...
@@ -202,13 +207,10 @@ class TreeObsForRailEnv(ObservationBuilder):
if
handle
>
len
(
self
.
env
.
agents
):
if
handle
>
len
(
self
.
env
.
agents
):
print
(
"
ERROR: obs _get - handle
"
,
handle
,
"
len(agents)
"
,
len
(
self
.
env
.
agents
))
print
(
"
ERROR: obs _get - handle
"
,
handle
,
"
len(agents)
"
,
len
(
self
.
env
.
agents
))
agent
=
self
.
env
.
agents
[
handle
]
# TODO: handle being treated as index
agent
=
self
.
env
.
agents
[
handle
]
# TODO: handle being treated as index
# position = self.env.agents_position[handle]
# orientation = self.env.agents_direction[handle]
possible_transitions
=
self
.
env
.
rail
.
get_transitions
((
*
agent
.
position
,
agent
.
direction
))
possible_transitions
=
self
.
env
.
rail
.
get_transitions
((
*
agent
.
position
,
agent
.
direction
))
num_transitions
=
np
.
count_nonzero
(
possible_transitions
)
num_transitions
=
np
.
count_nonzero
(
possible_transitions
)
# Root node - current position
# Root node - current position
# observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
observation
=
[
0
,
0
,
0
,
0
,
self
.
distance_map
[(
handle
,
*
agent
.
position
,
agent
.
direction
)],
agent
.
direction
]
observation
=
[
0
,
0
,
0
,
0
,
self
.
distance_map
[(
handle
,
*
agent
.
position
,
agent
.
direction
)]]
root_observation
=
observation
[:]
root_observation
=
observation
[:]
visited
=
set
()
visited
=
set
()
# Start from the current orientation, and see which transitions are available;
# Start from the current orientation, and see which transitions are available;
...
@@ -337,40 +339,49 @@ class TreeObsForRailEnv(ObservationBuilder):
...
@@ -337,40 +339,49 @@ class TreeObsForRailEnv(ObservationBuilder):
1 if other_target_encountered else 0,
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
1 if other_agent_encountered else 0,
root_observation[3] + num_steps,
root_observation[3] + num_steps,
0]
0,
direction]
elif last_isTerminal:
elif last_isTerminal:
observation = [0,
observation = [0,
1 if other_target_encountered else 0,
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
1 if other_agent_encountered else 0,
np.inf,
np.inf,
np.inf]
np.inf,
direction]
else:
else:
observation = [0,
observation = [0,
1 if other_target_encountered else 0,
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
1 if other_agent_encountered else 0,
root_observation[3] + num_steps,
root_observation[3] + num_steps,
self.distance_map[handle, position[0], position[1], direction]]
self.distance_map[handle, position[0], position[1], direction],
direction]
"""
"""
if
last_isTarget
:
if
last_isTarget
:
observation
=
[
0
,
observation
=
[
0
,
other_target_encountered
,
other_target_encountered
,
other_agent_encountered
,
other_agent_encountered
,
root_observation
[
3
]
+
num_steps
,
root_observation
[
3
]
+
num_steps
,
0
]
0
,
direction
]
elif
last_isTerminal
:
elif
last_isTerminal
:
observation
=
[
0
,
observation
=
[
0
,
other_target_encountered
,
other_target_encountered
,
other_agent_encountered
,
other_agent_encountered
,
np
.
inf
,
np
.
inf
,
np
.
inf
]
np
.
inf
,
direction
]
else
:
else
:
observation
=
[
0
,
observation
=
[
0
,
other_target_encountered
,
other_target_encountered
,
other_agent_encountered
,
other_agent_encountered
,
root_observation
[
3
]
+
num_steps
,
root_observation
[
3
]
+
num_steps
,
self
.
distance_map
[
handle
,
position
[
0
],
position
[
1
],
direction
]]
self
.
distance_map
[
handle
,
position
[
0
],
position
[
1
],
direction
],
direction
]
# #############################
# #############################
# #############################
# #############################
...
@@ -409,7 +420,7 @@ class TreeObsForRailEnv(ObservationBuilder):
...
@@ -409,7 +420,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for
i
in
range
(
self
.
max_depth
-
depth
):
for
i
in
range
(
self
.
max_depth
-
depth
):
num_cells_to_fill_in
+=
pow4
num_cells_to_fill_in
+=
pow4
pow4
*=
4
pow4
*=
4
observation
=
observation
+
[
-
np
.
inf
,
-
np
.
inf
,
-
np
.
inf
,
-
np
.
inf
,
-
np
.
inf
]
*
num_cells_to_fill_in
observation
=
observation
+
[
-
np
.
inf
,
-
np
.
inf
,
-
np
.
inf
,
-
np
.
inf
,
-
np
.
inf
,
-
np
.
inf
]
*
num_cells_to_fill_in
return
observation
,
visited
return
observation
,
visited
...
@@ -532,7 +543,9 @@ class GlobalObsForRailEnv(ObservationBuilder):
...
@@ -532,7 +543,9 @@ class GlobalObsForRailEnv(ObservationBuilder):
obs_agents_state
[
agent2
.
position
][
4
+
agent2
.
direction
]
=
1
obs_agents_state
[
agent2
.
position
][
4
+
agent2
.
direction
]
=
1
obs_targets
[
agent2
.
target
][
1
]
+=
1
obs_targets
[
agent2
.
target
][
1
]
+=
1
return
self
.
rail_obs
,
obs_agents_state
,
obs_targets
direction
=
self
.
_get_one_hot_for_agent_direction
(
agent
)
return
self
.
rail_obs
,
obs_agents_state
,
obs_targets
,
direction
class
GlobalObsForRailEnvDirectionDependent
(
ObservationBuilder
):
class
GlobalObsForRailEnvDirectionDependent
(
ObservationBuilder
):
...
@@ -542,13 +555,15 @@ class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
...
@@ -542,13 +555,15 @@ class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
- transition map array with dimensions (env.height, env.width, 16),
- transition map array with dimensions (env.height, env.width, 16),
assuming 16 bits encoding of transitions, flipped in the direction of the agent
assuming 16 bits encoding of transitions, flipped in the direction of the agent
(the agent is always heding north on the flipped view).
(the agent is always he
a
ding north on the flipped view).
- Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent
- Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent
target and the positions of the other agents targets, also flipped depending on the agent
'
s direction.
target and the positions of the other agents targets, also flipped depending on the agent
'
s direction.
- A 3D array (map_height, map_width, 5) containing the one hot encoding of the direction of the other
- A 3D array (map_height, map_width, 5) containing the one hot encoding of the direction of the other
agents at their position coordinates, and the last channel containing the position of the given agent.
agents at their position coordinates, and the last channel containing the position of the given agent.
- A 4 elements array with one hot encoding of the direction.
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -603,7 +618,9 @@ class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
...
@@ -603,7 +618,9 @@ class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
obs_agents_state
[
agent2
.
position
][
1
+
idx
[
4
+
(
agent2
.
direction
-
direction
)]]
=
1
obs_agents_state
[
agent2
.
position
][
1
+
idx
[
4
+
(
agent2
.
direction
-
direction
)]]
=
1
obs_targets
[
agent2
.
target
][
1
]
+=
1
obs_targets
[
agent2
.
target
][
1
]
+=
1
return
rail_obs
,
obs_agents_state
,
obs_targets
direction
=
self
.
_get_one_hot_for_agent_direction
(
agent
)
return
rail_obs
,
obs_agents_state
,
obs_targets
,
direction
class
LocalObsForRailEnv
(
ObservationBuilder
):
class
LocalObsForRailEnv
(
ObservationBuilder
):
...
@@ -635,8 +652,8 @@ class LocalObsForRailEnv(ObservationBuilder):
...
@@ -635,8 +652,8 @@ class LocalObsForRailEnv(ObservationBuilder):
# We build the transition map with a view_radius empty cells expansion on each side.
# We build the transition map with a view_radius empty cells expansion on each side.
# This helps to collect the local transition map view when the agent is close to a border.
# This helps to collect the local transition map view when the agent is close to a border.
self
.
rail_obs
=
np
.
zeros
((
self
.
env
.
height
+
2
*
self
.
view_radius
,
self
.
rail_obs
=
np
.
zeros
((
self
.
env
.
height
+
2
*
self
.
view_radius
,
self
.
env
.
width
+
2
*
self
.
view_radius
,
16
))
self
.
env
.
width
+
2
*
self
.
view_radius
,
16
))
for
i
in
range
(
self
.
env
.
height
):
for
i
in
range
(
self
.
env
.
height
):
for
j
in
range
(
self
.
env
.
width
):
for
j
in
range
(
self
.
env
.
width
):
bitlist
=
[
int
(
digit
)
for
digit
in
bin
(
self
.
env
.
rail
.
get_transitions
((
i
,
j
)))[
2
:]]
bitlist
=
[
int
(
digit
)
for
digit
in
bin
(
self
.
env
.
rail
.
get_transitions
((
i
,
j
)))[
2
:]]
...
@@ -654,12 +671,12 @@ class LocalObsForRailEnv(ObservationBuilder):
...
@@ -654,12 +671,12 @@ class LocalObsForRailEnv(ObservationBuilder):
# top_offset = max(0, agent.position[0] - 1 - self.view_radius)
# top_offset = max(0, agent.position[0] - 1 - self.view_radius)
# bottom_offset = min(0, agent.position[0] + 1 + self.view_radius)
# bottom_offset = min(0, agent.position[0] + 1 + self.view_radius)
local_rail_obs
=
self
.
rail_obs
[
agent
.
position
[
0
]:
agent
.
position
[
0
]
+
2
*
self
.
view_radius
+
1
,
local_rail_obs
=
self
.
rail_obs
[
agent
.
position
[
0
]:
agent
.
position
[
0
]
+
2
*
self
.
view_radius
+
1
,
agent
.
position
[
1
]:
agent
.
position
[
1
]
+
2
*
self
.
view_radius
+
1
]
agent
.
position
[
1
]:
agent
.
position
[
1
]
+
2
*
self
.
view_radius
+
1
]
obs_map_state
=
np
.
zeros
((
2
*
self
.
view_radius
+
1
,
2
*
self
.
view_radius
+
1
,
2
))
obs_map_state
=
np
.
zeros
((
2
*
self
.
view_radius
+
1
,
2
*
self
.
view_radius
+
1
,
2
))
obs_other_agents_state
=
np
.
zeros
((
2
*
self
.
view_radius
+
1
,
2
*
self
.
view_radius
+
1
,
4
))
obs_other_agents_state
=
np
.
zeros
((
2
*
self
.
view_radius
+
1
,
2
*
self
.
view_radius
+
1
,
4
))
def
relative_pos
(
pos
):
def
relative_pos
(
pos
):
return
[
agent
.
position
[
0
]
-
pos
[
0
],
agent
.
position
[
1
]
-
pos
[
1
]]
return
[
agent
.
position
[
0
]
-
pos
[
0
],
agent
.
position
[
1
]
-
pos
[
1
]]
...
@@ -684,15 +701,11 @@ class LocalObsForRailEnv(ObservationBuilder):
...
@@ -684,15 +701,11 @@ class LocalObsForRailEnv(ObservationBuilder):
if
is_in
(
target_rel_pos_2
):
if
is_in
(
target_rel_pos_2
):
obs_map_state
[
self
.
view_radius
+
np
.
array
(
target_rel_pos_2
)][
1
]
+=
1
obs_map_state
[
self
.
view_radius
+
np
.
array
(
target_rel_pos_2
)][
1
]
+=
1
direction
=
np
.
zeros
(
4
)
direction
=
self
.
_get_one_hot_for_agent_direction
(
agent
)
direction
[
agent
.
direction
]
=
1
return
local_rail_obs
,
obs_map_state
,
obs_other_agents_state
,
direction
return
local_rail_obs
,
obs_map_state
,
obs_other_agents_state
,
direction
# class LocalObsForRailEnvImproved(ObservationBuilder):
# class LocalObsForRailEnvImproved(ObservationBuilder):
# """
# """
# Returns a local observation around the given agent
# Returns a local observation around the given agent
# """
# """
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment