Skip to content
Snippets Groups Projects
Commit 06c520f0 authored by spiglerg's avatar spiglerg
Browse files

fixed pylint errors

parent 1e2d8f26
No related branches found
No related tags found
No related merge requests found
...@@ -12,8 +12,6 @@ import numpy as np ...@@ -12,8 +12,6 @@ import numpy as np
from collections import deque from collections import deque
# TODO: add docstrings, pylint, etc...
class ObservationBuilder: class ObservationBuilder:
""" """
...@@ -273,8 +271,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -273,8 +271,7 @@ class TreeObsForRailEnv(ObservationBuilder):
exploring = True exploring = True
last_isSwitch = False last_isSwitch = False
last_isDeadEnd = False last_isDeadEnd = False
last_isTerminal = False # wrong cell encountered OR cycle encountered; either way, we don't want the agent last_isTerminal = False # wrong cell OR cycle; either way, we don't want the agent to land here
# to land here
last_isTarget = False last_isTarget = False
visited = set() visited = set()
...@@ -302,7 +299,6 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -302,7 +299,6 @@ class TreeObsForRailEnv(ObservationBuilder):
break break
visited.add((position[0], position[1], direction)) visited.add((position[0], position[1], direction))
# If the target node is encountered, pick that as node. Also, no further branching is possible. # If the target node is encountered, pick that as node. Also, no further branching is possible.
if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]: if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
last_isTarget = True last_isTarget = True
...@@ -390,14 +386,22 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -390,14 +386,22 @@ class TreeObsForRailEnv(ObservationBuilder):
# it back # it back
new_cell = self._new_position(position, (branch_direction+2) % 4) new_cell = self._new_position(position, (branch_direction+2) % 4)
branch_observation = self._explore_branch(handle, new_cell, (branch_direction+2) % 4, new_root_observation, depth+1) branch_observation = self._explore_branch(handle,
new_cell,
(branch_direction+2) % 4,
new_root_observation,
depth+1)
observation = observation + branch_observation observation = observation + branch_observation
elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction), elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction),
branch_direction): branch_direction):
new_cell = self._new_position(position, branch_direction) new_cell = self._new_position(position, branch_direction)
branch_observation = self._explore_branch(handle, new_cell, branch_direction, new_root_observation, depth+1) branch_observation = self._explore_branch(handle,
new_cell,
branch_direction,
new_root_observation,
depth+1)
observation = observation + branch_observation observation = observation + branch_observation
else: else:
......
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import numpy as np
from flatland.core.env_observation_builder import GlobalObsForRailEnv from flatland.core.env_observation_builder import GlobalObsForRailEnv
from flatland.core.transition_map import GridTransitionMap, Grid4Transitions from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator
from flatland.utils.rendertools import *
"""Tests for `flatland` package.""" """Tests for `flatland` package."""
...@@ -45,14 +46,14 @@ def test_global_obs(): ...@@ -45,14 +46,14 @@ def test_global_obs():
double_switch_south_horizontal_straight, 180) double_switch_south_horizontal_straight, 180)
rail_map = np.array( rail_map = np.array(
[[empty] * 3 + [dead_end_from_south] + [empty] * 6] + [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
[[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
[[dead_end_from_east] + [horizontal_straight] * 2 + [[dead_end_from_east] + [horizontal_straight] * 2 +
[double_switch_north_horizontal_straight] + [double_switch_north_horizontal_straight] +
[horizontal_straight] * 2 + [double_switch_south_horizontal_straight] + [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
[horizontal_straight] * 2 + [dead_end_from_west]] + [horizontal_straight] * 2 + [dead_end_from_west]] +
[[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
[[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1], rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions) height=rail_map.shape[0], transitions=transitions)
...@@ -81,17 +82,3 @@ def test_global_obs(): ...@@ -81,17 +82,3 @@ def test_global_obs():
# If this assertion is wrong, it means that the observation returned # If this assertion is wrong, it means that the observation returned
# places the agent on an empty cell # places the agent on an empty cell
assert(np.sum(rail_map * global_obs[0][1][0]) > 0) assert(np.sum(rail_map * global_obs[0][1][0]) > 0)
test_global_obs()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment