Skip to content
Snippets Groups Projects
Commit e8ba8172 authored by Erik Nygren's avatar Erik Nygren
Browse files

code cleanup

parent 8cec2335
No related branches found
No related tags found
No related merge requests found
Pipeline #462 failed
...@@ -17,6 +17,7 @@ class ObservationBuilder: ...@@ -17,6 +17,7 @@ class ObservationBuilder:
""" """
ObservationBuilder base class. ObservationBuilder base class.
""" """
def __init__(self): def __init__(self):
pass pass
...@@ -55,6 +56,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -55,6 +56,7 @@ class TreeObsForRailEnv(ObservationBuilder):
The information is local to each agent and exploits the tree structure of the rail The information is local to each agent and exploits the tree structure of the rail
network to simplify the representation of the state of the environment for each agent. network to simplify the representation of the state of the environment for each agent.
""" """
def __init__(self, max_depth): def __init__(self, max_depth):
self.max_depth = max_depth self.max_depth = max_depth
...@@ -135,7 +137,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -135,7 +137,7 @@ class TreeObsForRailEnv(ObservationBuilder):
new_cell = self._new_position(position, neigh_direction) new_cell = self._new_position(position, neigh_direction)
if new_cell[0] >= 0 and new_cell[0] < self.env.height and \ if new_cell[0] >= 0 and new_cell[0] < self.env.height and \
new_cell[1] >= 0 and new_cell[1] < self.env.width: new_cell[1] >= 0 and new_cell[1] < self.env.width:
desired_movement_from_new_cell = (neigh_direction + 2) % 4 desired_movement_from_new_cell = (neigh_direction + 2) % 4
...@@ -176,7 +178,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -176,7 +178,7 @@ class TreeObsForRailEnv(ObservationBuilder):
""" """
Utility function that converts a compass movement over a 2D grid to new positions (r, c). Utility function that converts a compass movement over a 2D grid to new positions (r, c).
""" """
if movement == 0: # NORTH if movement == 0: # NORTH
return (position[0] - 1, position[1]) return (position[0] - 1, position[1])
elif movement == 1: # EAST elif movement == 1: # EAST
return (position[0], position[1] + 1) return (position[0], position[1] + 1)
...@@ -340,7 +342,8 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -340,7 +342,8 @@ class TreeObsForRailEnv(ObservationBuilder):
elif num_transitions == 0: elif num_transitions == 0:
# Wrong cell type, but let's cover it and treat it as a dead-end, just in case # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell",position[0], position[1] ) print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
position[1], direction)
last_isTerminal = True last_isTerminal = True
break break
...@@ -394,7 +397,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -394,7 +397,7 @@ class TreeObsForRailEnv(ObservationBuilder):
observation = observation + branch_observation observation = observation + branch_observation
elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction), elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction),
(branch_direction + 2) % 4): (branch_direction + 2) % 4):
new_cell = self._new_position(position, branch_direction) new_cell = self._new_position(position, branch_direction)
branch_observation = self._explore_branch(handle, branch_observation = self._explore_branch(handle,
...@@ -456,6 +459,7 @@ class GlobalObsForRailEnv(ObservationBuilder): ...@@ -456,6 +459,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
- A 4 elements array with one of encoding of the direction. - A 4 elements array with one of encoding of the direction.
""" """
def __init__(self): def __init__(self):
super(GlobalObsForRailEnv, self).__init__() super(GlobalObsForRailEnv, self).__init__()
......
...@@ -999,6 +999,7 @@ class RailEnv(Environment): ...@@ -999,6 +999,7 @@ class RailEnv(Environment):
for i in range(len(self.agents_handles)): for i in range(len(self.agents_handles)):
handle = self.agents_handles[i] handle = self.agents_handles[i]
transition_isValid = None transition_isValid = None
if handle not in action_dict: if handle not in action_dict:
continue continue
...@@ -1093,6 +1094,7 @@ class RailEnv(Environment): ...@@ -1093,6 +1094,7 @@ class RailEnv(Environment):
else: else:
new_cell_isValid = False new_cell_isValid = False
# If transition validity hasn't been checked yet.
if transition_isValid == None: if transition_isValid == None:
transition_isValid = self.rail.get_transition( transition_isValid = self.rail.get_transition(
(pos[0], pos[1], direction), (pos[0], pos[1], direction),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment