diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index 7271f6823222ca8e94257140418e03de456a7049..86485ec2c068ea410d5d27997f6f037d3aab6c23 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -12,8 +12,6 @@ import numpy as np from collections import deque -# TODO: add docstrings, pylint, etc... - class ObservationBuilder: """ @@ -273,8 +271,7 @@ class TreeObsForRailEnv(ObservationBuilder): exploring = True last_isSwitch = False last_isDeadEnd = False - last_isTerminal = False # wrong cell encountered OR cycle encountered; either way, we don't want the agent - # to land here + last_isTerminal = False # wrong cell OR cycle; either way, we don't want the agent to land here last_isTarget = False visited = set() @@ -302,7 +299,6 @@ class TreeObsForRailEnv(ObservationBuilder): break visited.add((position[0], position[1], direction)) - # If the target node is encountered, pick that as node. Also, no further branching is possible. if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]: last_isTarget = True @@ -390,14 +386,22 @@ class TreeObsForRailEnv(ObservationBuilder): # it back new_cell = self._new_position(position, (branch_direction+2) % 4) - branch_observation = self._explore_branch(handle, new_cell, (branch_direction+2) % 4, new_root_observation, depth+1) + branch_observation = self._explore_branch(handle, + new_cell, + (branch_direction+2) % 4, + new_root_observation, + depth+1) observation = observation + branch_observation elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction), branch_direction): new_cell = self._new_position(position, branch_direction) - branch_observation = self._explore_branch(handle, new_cell, branch_direction, new_root_observation, depth+1) + branch_observation = self._explore_branch(handle, + new_cell, + branch_direction, + new_root_observation, + depth+1) observation = observation + branch_observation else: diff --git a/tests/test_env_observation_builder.py b/tests/test_env_observation_builder.py index 44dbfc6d8f14a7293e148f125b47eed66c9ca08d..db264c2975b75f32c2612aa19c0511076460ec6b 100644 --- a/tests/test_env_observation_builder.py +++ b/tests/test_env_observation_builder.py @@ -1,10 +1,11 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +import numpy as np + from flatland.core.env_observation_builder import GlobalObsForRailEnv from flatland.core.transition_map import GridTransitionMap, Grid4Transitions from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator -from flatland.utils.rendertools import * """Tests for `flatland` package.""" @@ -45,14 +46,14 @@ def test_global_obs(): double_switch_south_horizontal_straight, 180) rail_map = np.array( - [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + - [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + - [[dead_end_from_east] + [horizontal_straight] * 2 + - [double_switch_north_horizontal_straight] + - [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] + - [horizontal_straight] * 2 + [dead_end_from_west]] + - [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + - [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) + [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + + [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + + [[dead_end_from_east] + [horizontal_straight] * 2 + + [double_switch_north_horizontal_straight] + + [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] + + [horizontal_straight] * 2 + [dead_end_from_west]] + + [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + + [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) rail = GridTransitionMap(width=rail_map.shape[1], height=rail_map.shape[0], transitions=transitions) @@ -81,17 +82,3 @@ def test_global_obs(): # If this assertion is wrong, it means that the observation returned # places the agent on an empty cell assert(np.sum(rail_map * global_obs[0][1][0]) > 0) - - - -test_global_obs() - - - - - - - - - -