Commit e8ba8172 authored by Erik Nygren's avatar Erik Nygren
Browse files

code cleanup

parent 8cec2335
Pipeline #462 failed with stage
in 1 minute and 49 seconds
......@@ -17,6 +17,7 @@ class ObservationBuilder:
"""
ObservationBuilder base class.
"""
def __init__(self):
pass
......@@ -55,6 +56,7 @@ class TreeObsForRailEnv(ObservationBuilder):
The information is local to each agent and exploits the tree structure of the rail
network to simplify the representation of the state of the environment for each agent.
"""
def __init__(self, max_depth):
self.max_depth = max_depth
......@@ -135,7 +137,7 @@ class TreeObsForRailEnv(ObservationBuilder):
new_cell = self._new_position(position, neigh_direction)
if new_cell[0] >= 0 and new_cell[0] < self.env.height and \
new_cell[1] >= 0 and new_cell[1] < self.env.width:
new_cell[1] >= 0 and new_cell[1] < self.env.width:
desired_movement_from_new_cell = (neigh_direction + 2) % 4
......@@ -176,7 +178,7 @@ class TreeObsForRailEnv(ObservationBuilder):
"""
Utility function that converts a compass movement over a 2D grid to new positions (r, c).
"""
if movement == 0: # NORTH
if movement == 0: # NORTH
return (position[0] - 1, position[1])
elif movement == 1: # EAST
return (position[0], position[1] + 1)
......@@ -340,7 +342,8 @@ class TreeObsForRailEnv(ObservationBuilder):
elif num_transitions == 0:
# Wrong cell type, but let's cover it and treat it as a dead-end, just in case
print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell",position[0], position[1] )
print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
position[1], direction)
last_isTerminal = True
break
......@@ -394,7 +397,7 @@ class TreeObsForRailEnv(ObservationBuilder):
observation = observation + branch_observation
elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction),
(branch_direction + 2) % 4):
(branch_direction + 2) % 4):
new_cell = self._new_position(position, branch_direction)
branch_observation = self._explore_branch(handle,
......@@ -456,6 +459,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
- A 4 elements array with one of encoding of the direction.
"""
def __init__(self):
super(GlobalObsForRailEnv, self).__init__()
......
......@@ -999,6 +999,7 @@ class RailEnv(Environment):
for i in range(len(self.agents_handles)):
handle = self.agents_handles[i]
transition_isValid = None
if handle not in action_dict:
continue
......@@ -1093,6 +1094,7 @@ class RailEnv(Environment):
else:
new_cell_isValid = False
# If transition validity hasn't been checked yet.
if transition_isValid == None:
transition_isValid = self.rail.get_transition(
(pos[0], pos[1], direction),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment