Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
neurips2020-flatland-baselines
Commits
7effaf44
Commit
7effaf44
authored
May 01, 2020
by
nilabha
Browse files
Added Typing and some documentation
parent
7ae8390f
Changes
2
Hide whitespace changes
Inline
Side-by-side
envs/flatland/observations/local_conflict_obs.py
View file @
7effaf44
from
typing
import
Optional
,
List
,
Dict
from
typing
import
Optional
,
List
,
Dict
,
Union
,
Tuple
import
gym
import
numpy
as
np
from
flatland.core.env_observation_builder
import
ObservationBuilder
from
flatland.envs.observations
import
TreeObsForRailEnv
from
flatland.envs.predictions
import
ShortestPathPredictorForRailEnv
from
flatland.utils.ordered_set
import
OrderedSet
from
envs.flatland.observations
import
Observation
,
register_obs
# noqa
from
itertools
import
combinations
...
...
@@ -14,27 +15,6 @@ from flatland.core.env_prediction_builder import PredictionBuilder
from
flatland.envs.agent_utils
import
RailAgentStatus
from
flatland.core.grid.grid4_utils
import
get_new_position
# from flatland.envs.rail_env import action_required
def
action_required
(
agent
):
"""
Check if an agent needs to provide an action
Parameters
----------
agent: RailEnvAgent
Agent we want to check
Returns
-------
True: Agent needs to provide an action
False: Agent cannot provide an action
"""
return
(
agent
.
status
==
RailAgentStatus
.
READY_TO_DEPART
or
(
agent
.
status
==
RailAgentStatus
.
ACTIVE
and
np
.
isclose
(
agent
.
speed_data
[
'position_fraction'
],
0.0
,
rtol
=
1e-03
)))
@
register_obs
(
"localConflict"
)
...
...
@@ -46,7 +26,8 @@ class LocalConflictObservation(Observation):
LocalConflictObsForRailEnv
(
max_depth
=
config
[
'max_depth'
],
predictor
=
ShortestPathPredictorForRailEnv
(
config
[
'shortest_path_max_depth'
]))
config
[
'shortest_path_max_depth'
]),
n_local
=
config
[
'n_local'
])
)
def
builder
(
self
)
->
ObservationBuilder
:
...
...
@@ -59,10 +40,21 @@ class LocalConflictObservation(Observation):
class
LocalConflictObsForRailEnvRLLibWrapper
(
ObservationBuilder
):
"""
The information is for each agent but uses the full set of
observations for all agents to come up with set of local
(Default: 5) most conflicting agents.
The observation set is based on the current agent and these local
identified agents. We also information about conflicts.
"""
def
__init__
(
self
,
local_conflict_obs_builder
:
TreeObsForRailEnv
):
super
().
__init__
()
self
.
_builder
=
local_conflict_obs_builder
self
.
agent_states
=
None
# To cache calculated agent states
# This is only computed once and reused for all other agents
self
.
agent_states
:
Optional
[
Dict
]
=
None
@
property
def
observation_dim
(
self
):
...
...
@@ -87,24 +79,16 @@ class LocalConflictObsForRailEnvRLLibWrapper(ObservationBuilder):
def
get_many
(
self
,
handles
:
Optional
[
List
[
int
]]
=
None
):
all_agent_observations
=
self
.
_builder
.
get_many
(
handles
)
o
=
dict
()
obs
=
dict
()
if
handles
is
None
:
handles
=
[]
for
k
in
handles
:
if
not
self
.
agent_states
:
self
.
agent_states
=
create_agent_states
(
all_agent_observations
,
self
.
_builder
.
predictor
.
max_depth
)
o
[
k
]
=
self
.
agent_states
[
k
]
return
o
# return {k: create_agent_states(o, self._builder.max_depth)
# for k, o in self._builder.get_many(handles).items()
# if o is not None}
obs
[
k
]
=
self
.
agent_states
[
k
]
# def util_print_obs_subtree(self, tree):
# self._builder.util_print_obs_subtree(tree)
# def print_subtree(self, node, label, indent):
# self._builder.print_subtree(node, label, indent)
return
obs
def
set_env
(
self
,
env
):
self
.
_builder
.
set_env
(
env
)
...
...
@@ -112,7 +96,14 @@ class LocalConflictObsForRailEnvRLLibWrapper(ObservationBuilder):
class
LocalConflictObsForRailEnv
(
TreeObsForRailEnv
):
"""
LocalConflict object made from TreeObsForRailEnv object.
This object returns observation vectors for agents in the RailEnv.
For details about the features in the observation
see the get() function.
We normalise all observations based on the grid size
"""
Node
=
collections
.
namedtuple
(
'Node'
,
'distance_target '
'observation_shortest '
...
...
@@ -136,7 +127,7 @@ class LocalConflictObsForRailEnv(TreeObsForRailEnv):
'predicted_pos'
)
def
__init__
(
self
,
max_depth
:
int
,
predictor
:
PredictionBuilder
=
None
,
n_local
=
5
):
n_local
:
int
=
5
):
super
().
__init__
(
max_depth
,
predictor
)
self
.
observation_dim
=
1
+
3
*
(
n_local
-
1
)
+
22
*
n_local
...
...
@@ -145,14 +136,7 @@ class LocalConflictObsForRailEnv(TreeObsForRailEnv):
def
get_many
(
self
,
handles
:
Optional
[
List
[
int
]]
=
None
):
# observations = {}
# if handles is None:
# handles = []
# for h in handles:
# observations[h] = self.get(h)
# return observations
observations
=
super
().
get_many
(
handles
)
# observations = list(observations.values())
return
observations
def
get
(
self
,
handle
:
int
=
0
):
...
...
@@ -186,7 +170,7 @@ class LocalConflictObsForRailEnv(TreeObsForRailEnv):
max_distance
=
self
.
env
.
width
+
self
.
env
.
height
# max_steps = int(4 * 2 * (20 + self.env.height + self.env.width))
visited
=
s
et
()
visited
=
OrderedS
et
()
for
_idx
in
range
(
10
):
# Check if any of the other prediction overlap
# with agents own predictions
...
...
@@ -200,6 +184,8 @@ class LocalConflictObsForRailEnv(TreeObsForRailEnv):
# visualize the observation
self
.
env
.
dev_obs_dict
[
handle
]
=
visited
# min_distance stores the distance to target in shortest path
# and any alternate path if exists
min_distances
=
[]
for
direction
in
[(
agent
.
direction
+
i
)
%
4
for
i
in
range
(
-
1
,
2
)]:
if
possible_transitions
[
direction
]:
...
...
@@ -266,16 +252,18 @@ class LocalConflictObsForRailEnv(TreeObsForRailEnv):
return
self
.
env
.
get_num_agents
()
def
create_agent_states
(
obs
,
max_depth
:
int
,
n_local
:
int
=
5
):
def
create_agent_states
(
obs
:
Union
[
Dict
,
List
],
max_depth
:
int
,
n_local
:
int
=
5
)
->
Dict
:
"""
Identifies local agent conflicts and adds information from
conflict prediction matrix.
"""
n_agents
=
len
(
obs
)
x_dim
=
0
y_dim
=
0
print
(
" N Agents:"
,
n_agents
)
for
i
in
range
(
n_agents
):
if
obs
[
i
]
is
not
None
:
custom_observations
=
obs
[
i
]
# n_agents, x_dim, y_dim = custom_observations.n_agents,
x_dim
=
custom_observations
.
width
y_dim
=
custom_observations
.
height
break
...
...
@@ -315,22 +303,21 @@ def create_agent_states(obs,
info_action_required
[
i
]
=
int
(
custom_observations
.
action_required
)
predicted_pos
=
custom_observations
.
predicted_pos
agent_conflicts_count_path
,
agent_conflicts_step_path
,
agent_total_step_conflicts
=
get_agent_conflict_prediction_matrix
(
n_agents
,
max_depth
,
predicted_pos
)
agent_conflicts_count_path
,
agent_conflicts_step_path
,
\
agent_total_step_conflicts
=
get_agent_conflict_prediction_matrix
(
n_agents
,
max_depth
,
predicted_pos
)
# Normalise based on average grid dimensions
avg_dim
=
(
x_dim
*
y_dim
)
**
0.5
depth
=
int
(
n_local
*
avg_dim
/
n_agents
)
agent_conflict_steps
=
min
(
max_depth
-
1
,
depth
)
agent_conflicts
=
agent_conflicts_step_path
[
agent_conflict_steps
]
# agent_counts = agent_conflicts_count_path[agent_conflict_steps]
agent_conflicts_avg_step_count
=
np
.
average
(
agent_total_step_conflicts
)
/
n_agents
for
i
in
range
(
n_agents
):
# if obs is None or obs[i] is None:
# # action_dict.update({i: 2})
if
obs
[
i
]
is
not
None
:
n_upd_local
=
min
(
n_local
,
n_agents
-
1
)
if
n_upd_local
<
n_local
:
...
...
@@ -393,7 +380,8 @@ def create_agent_states(obs,
return
local_agent_states_all
def
get_agent_conflict_prediction_matrix
(
n_agents
,
max_depth
,
predicted_pos
):
def
get_agent_conflict_prediction_matrix
(
n_agents
,
max_depth
,
predicted_pos
)
->
Tuple
[
List
,
List
,
List
]:
agent_total_step_conflicts
=
[]
agent_conflicts_step_path
=
[]
agent_conflicts_count_path
=
[]
...
...
@@ -439,4 +427,25 @@ def get_agent_conflict_prediction_matrix(n_agents, max_depth, predicted_pos):
agent_total_step_conflicts
.
append
(
sum
(
agent_conflicts_step_current
[
i
,
:]))
return
agent_conflicts_count_path
,
agent_conflicts_step_path
,
agent_total_step_conflicts
return
agent_conflicts_count_path
,
agent_conflicts_step_path
,
\
agent_total_step_conflicts
def
action_required
(
agent
):
"""
Check if an agent needs to provide an action
Parameters
----------
agent: RailEnvAgent
Agent we want to check
Returns
-------
True: Agent needs to provide an action
False: Agent cannot provide an action
"""
return
(
agent
.
status
==
RailAgentStatus
.
READY_TO_DEPART
or
(
agent
.
status
==
RailAgentStatus
.
ACTIVE
and
np
.
isclose
(
agent
.
speed_data
[
'position_fraction'
],
0.0
,
rtol
=
1e-03
)))
experiments/flatland_random_sparse_small/local_conflict_obs_fc_net/ppo.yaml
View file @
7effaf44
...
...
@@ -37,6 +37,7 @@ flatland-random-sparse-small-local-conflict-fc-ppo:
observation_config
:
max_depth
:
2
shortest_path_max_depth
:
30
n_local
:
5
regenerate_rail_on_reset
:
True
regenerate_schedule_on_reset
:
True
render
:
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment