Commit 63bd0b96 authored by hagrid67's avatar hagrid67
Browse files

moved Node out of TreeObsForRailEnv to avoid pickle problem

parent b4c64e45
Pipeline #4990 failed with stages
in 33 minutes and 2 seconds
......@@ -15,6 +15,7 @@ from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
from flatland.utils.ordered_set import OrderedSet
Node = collections.namedtuple('Node', 'dist_own_target_encountered '
'dist_other_target_encountered '
'dist_other_agent_encountered '
......@@ -219,6 +220,7 @@ class TreeObsForRailEnv(ObservationBuilder):
speed_min_fractional=agent.speed_data['speed'],
num_agents_ready_to_depart=0,
childs={})
print("root node type:", type(root_node_observation))
visited = OrderedSet()
......
......@@ -57,7 +57,7 @@ class EmptyRailGen(RailGen):
Primarily used by the editor
"""
def generate(width: int, height: int, num_agents: int, num_resets: int = 0,
def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
np_random: RandomState = None) -> RailGenerator:
rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
......
......@@ -6,7 +6,7 @@ import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.observations import TreeObsForRailEnv, Node
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
......@@ -294,7 +294,7 @@ def test_shortest_path_predictor_conflicts(rendering=False):
_check_expected_conflicts(expected_conflicts_1, obs_builder, tree_1, "agent[1]: ")
def _check_expected_conflicts(expected_conflicts, obs_builder, tree: TreeObsForRailEnv.Node, prompt=''):
def _check_expected_conflicts(expected_conflicts, obs_builder, tree: Node, prompt=''):
assert (tree.num_agents_opposite_direction > 0) == (() in expected_conflicts), "{}[]".format(prompt)
for a_1 in obs_builder.tree_explored_actions_char:
if tree.childs[a_1] == -np.inf:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment