From 63bd0b9610f58f815bce5c5f2685809f506423e2 Mon Sep 17 00:00:00 2001 From: hagrid67 <jdhwatson@gmail.com> Date: Mon, 20 Jul 2020 17:50:48 +0100 Subject: [PATCH] moved Node out of TreeObsForRailEnv to avoid pickle problem --- flatland/envs/observations.py | 2 ++ flatland/envs/rail_generators.py | 2 +- tests/test_flatland_envs_predictions.py | 4 ++-- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 97e5a050..3738edd3 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -15,6 +15,7 @@ from flatland.envs.agent_utils import RailAgentStatus, EnvAgent from flatland.utils.ordered_set import OrderedSet + Node = collections.namedtuple('Node', 'dist_own_target_encountered ' 'dist_other_target_encountered ' 'dist_other_agent_encountered ' @@ -219,6 +220,7 @@ class TreeObsForRailEnv(ObservationBuilder): speed_min_fractional=agent.speed_data['speed'], num_agents_ready_to_depart=0, childs={}) + print("root node type:", type(root_node_observation)) visited = OrderedSet() diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 5fc74da3..750a8c4f 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -57,7 +57,7 @@ class EmptyRailGen(RailGen): Primarily used by the editor """ - def generate(width: int, height: int, num_agents: int, num_resets: int = 0, + def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, np_random: RandomState = None) -> RailGenerator: rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index 3aa9d8a5..c6495171 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -6,7 +6,7 @@ import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.agent_utils import RailAgentStatus -from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.observations import TreeObsForRailEnv, Node from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env_shortest_paths import get_shortest_paths @@ -294,7 +294,7 @@ def test_shortest_path_predictor_conflicts(rendering=False): _check_expected_conflicts(expected_conflicts_1, obs_builder, tree_1, "agent[1]: ") -def _check_expected_conflicts(expected_conflicts, obs_builder, tree: TreeObsForRailEnv.Node, prompt=''): +def _check_expected_conflicts(expected_conflicts, obs_builder, tree: Node, prompt=''): assert (tree.num_agents_opposite_direction > 0) == (() in expected_conflicts), "{}[]".format(prompt) for a_1 in obs_builder.tree_explored_actions_char: if tree.childs[a_1] == -np.inf: -- GitLab