diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 97e5a050aa299ae5a1e37763c2ac75cc85f946c2..3738edd392492cb8ecbedc96f65a42643b4d20d9 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -15,6 +15,7 @@ from flatland.envs.agent_utils import RailAgentStatus, EnvAgent from flatland.utils.ordered_set import OrderedSet + Node = collections.namedtuple('Node', 'dist_own_target_encountered ' 'dist_other_target_encountered ' 'dist_other_agent_encountered ' @@ -219,6 +220,7 @@ class TreeObsForRailEnv(ObservationBuilder): speed_min_fractional=agent.speed_data['speed'], num_agents_ready_to_depart=0, childs={}) + print("root node type:", type(root_node_observation)) visited = OrderedSet() diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 5fc74da30e8626281a1654345ee7d242e7ab98d9..750a8c4f32fa60cef15c66590fe0fba3c1823aa9 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -57,7 +57,7 @@ class EmptyRailGen(RailGen): Primarily used by the editor """ - def generate(width: int, height: int, num_agents: int, num_resets: int = 0, + def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0, np_random: RandomState = None) -> RailGenerator: rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index 3aa9d8a54e9e8db628f4fbc1c9a0e8db4f1b0305..c649517108597de87dd8195169b4672434590c0f 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -6,7 +6,7 @@ import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.agent_utils import RailAgentStatus -from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.observations import TreeObsForRailEnv, Node from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env_shortest_paths import get_shortest_paths @@ -294,7 +294,7 @@ def test_shortest_path_predictor_conflicts(rendering=False): _check_expected_conflicts(expected_conflicts_1, obs_builder, tree_1, "agent[1]: ") -def _check_expected_conflicts(expected_conflicts, obs_builder, tree: TreeObsForRailEnv.Node, prompt=''): +def _check_expected_conflicts(expected_conflicts, obs_builder, tree: Node, prompt=''): assert (tree.num_agents_opposite_direction > 0) == (() in expected_conflicts), "{}[]".format(prompt) for a_1 in obs_builder.tree_explored_actions_char: if tree.childs[a_1] == -np.inf: