Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
neurips2020-flatland-baselines
Commits
d8bfd096
Commit
d8bfd096
authored
Sep 27, 2020
by
nilabha
Browse files
Removed typing information to support latest flatland release
parent
0179a4a0
Pipeline
#5606
failed with stage
in 2 minutes and 3 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
envs/flatland/observations/tree_obs.py
View file @
d8bfd096
...
...
@@ -32,7 +32,7 @@ class TreeObservation(Observation):
return
gym
.
spaces
.
Box
(
low
=-
np
.
inf
,
high
=
np
.
inf
,
shape
=
(
num_features_per_node
*
nr_nodes
,))
def
_split_node_into_feature_groups
(
node
:
TreeObsForRailEnv
.
Node
)
->
(
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
):
def
_split_node_into_feature_groups
(
node
)
->
(
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
):
data
=
np
.
zeros
(
6
)
distance
=
np
.
zeros
(
1
)
agent_data
=
np
.
zeros
(
4
)
...
...
@@ -54,7 +54,7 @@ def _split_node_into_feature_groups(node: TreeObsForRailEnv.Node) -> (np.ndarray
return
data
,
distance
,
agent_data
def
_split_subtree_into_feature_groups
(
node
:
TreeObsForRailEnv
.
Node
,
current_tree_depth
:
int
,
max_tree_depth
:
int
)
->
(
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
):
def
_split_subtree_into_feature_groups
(
node
,
current_tree_depth
:
int
,
max_tree_depth
:
int
)
->
(
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
):
if
node
==
-
np
.
inf
:
remaining_depth
=
max_tree_depth
-
current_tree_depth
# reference: https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure
...
...
@@ -75,7 +75,7 @@ def _split_subtree_into_feature_groups(node: TreeObsForRailEnv.Node, current_tre
return
data
,
distance
,
agent_data
def
split_tree_into_feature_groups
(
tree
:
TreeObsForRailEnv
.
Node
,
max_tree_depth
:
int
)
->
(
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
):
def
split_tree_into_feature_groups
(
tree
,
max_tree_depth
:
int
)
->
(
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
):
"""
This function splits the tree into three difference arrays of values
"""
...
...
@@ -90,7 +90,7 @@ def split_tree_into_feature_groups(tree: TreeObsForRailEnv.Node, max_tree_depth:
return
data
,
distance
,
agent_data
def
normalize_observation
(
observation
:
TreeObsForRailEnv
.
Node
,
tree_depth
:
int
,
observation_radius
=
0
):
def
normalize_observation
(
observation
,
tree_depth
:
int
,
observation_radius
=
0
):
"""
This function normalizes the observation used by the RL algorithm
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment