Commit d8bfd096 authored by nilabha's avatar nilabha
Browse files

Removed typing information to support latest flatland release

parent 0179a4a0
Pipeline #5606 failed with stage
in 2 minutes and 3 seconds
...@@ -32,7 +32,7 @@ class TreeObservation(Observation): ...@@ -32,7 +32,7 @@ class TreeObservation(Observation):
return gym.spaces.Box(low=-np.inf, high=np.inf, shape=(num_features_per_node * nr_nodes,)) return gym.spaces.Box(low=-np.inf, high=np.inf, shape=(num_features_per_node * nr_nodes,))
def _split_node_into_feature_groups(node: TreeObsForRailEnv.Node) -> (np.ndarray, np.ndarray, np.ndarray): def _split_node_into_feature_groups(node) -> (np.ndarray, np.ndarray, np.ndarray):
data = np.zeros(6) data = np.zeros(6)
distance = np.zeros(1) distance = np.zeros(1)
agent_data = np.zeros(4) agent_data = np.zeros(4)
...@@ -54,7 +54,7 @@ def _split_node_into_feature_groups(node: TreeObsForRailEnv.Node) -> (np.ndarray ...@@ -54,7 +54,7 @@ def _split_node_into_feature_groups(node: TreeObsForRailEnv.Node) -> (np.ndarray
return data, distance, agent_data return data, distance, agent_data
def _split_subtree_into_feature_groups(node: TreeObsForRailEnv.Node, current_tree_depth: int, max_tree_depth: int) -> (np.ndarray, np.ndarray, np.ndarray): def _split_subtree_into_feature_groups(node, current_tree_depth: int, max_tree_depth: int) -> (np.ndarray, np.ndarray, np.ndarray):
if node == -np.inf: if node == -np.inf:
remaining_depth = max_tree_depth - current_tree_depth remaining_depth = max_tree_depth - current_tree_depth
# reference: https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure # reference: https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure
...@@ -75,7 +75,7 @@ def _split_subtree_into_feature_groups(node: TreeObsForRailEnv.Node, current_tre ...@@ -75,7 +75,7 @@ def _split_subtree_into_feature_groups(node: TreeObsForRailEnv.Node, current_tre
return data, distance, agent_data return data, distance, agent_data
def split_tree_into_feature_groups(tree: TreeObsForRailEnv.Node, max_tree_depth: int) -> (np.ndarray, np.ndarray, np.ndarray): def split_tree_into_feature_groups(tree, max_tree_depth: int) -> (np.ndarray, np.ndarray, np.ndarray):
""" """
This function splits the tree into three difference arrays of values This function splits the tree into three difference arrays of values
""" """
...@@ -90,7 +90,7 @@ def split_tree_into_feature_groups(tree: TreeObsForRailEnv.Node, max_tree_depth: ...@@ -90,7 +90,7 @@ def split_tree_into_feature_groups(tree: TreeObsForRailEnv.Node, max_tree_depth:
return data, distance, agent_data return data, distance, agent_data
def normalize_observation(observation: TreeObsForRailEnv.Node, tree_depth: int, observation_radius=0): def normalize_observation(observation, tree_depth: int, observation_radius=0):
""" """
This function normalizes the observation used by the RL algorithm This function normalizes the observation used by the RL algorithm
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment