From e8ba817244bede7d321e10de692177229d5bbdc3 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Wed, 1 May 2019 08:45:30 +0200 Subject: [PATCH] code cleanup --- flatland/core/env_observation_builder.py | 12 ++++++++---- flatland/envs/rail_env.py | 2 ++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index aaf5177f..1ae2819d 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -17,6 +17,7 @@ class ObservationBuilder: """ ObservationBuilder base class. """ + def __init__(self): pass @@ -55,6 +56,7 @@ class TreeObsForRailEnv(ObservationBuilder): The information is local to each agent and exploits the tree structure of the rail network to simplify the representation of the state of the environment for each agent. """ + def __init__(self, max_depth): self.max_depth = max_depth @@ -135,7 +137,7 @@ class TreeObsForRailEnv(ObservationBuilder): new_cell = self._new_position(position, neigh_direction) if new_cell[0] >= 0 and new_cell[0] < self.env.height and \ - new_cell[1] >= 0 and new_cell[1] < self.env.width: + new_cell[1] >= 0 and new_cell[1] < self.env.width: desired_movement_from_new_cell = (neigh_direction + 2) % 4 @@ -176,7 +178,7 @@ class TreeObsForRailEnv(ObservationBuilder): """ Utility function that converts a compass movement over a 2D grid to new positions (r, c). """ - if movement == 0: # NORTH + if movement == 0: # NORTH return (position[0] - 1, position[1]) elif movement == 1: # EAST return (position[0], position[1] + 1) @@ -340,7 +342,8 @@ class TreeObsForRailEnv(ObservationBuilder): elif num_transitions == 0: # Wrong cell type, but let's cover it and treat it as a dead-end, just in case - print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell",position[0], position[1] ) + print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0], + position[1], direction) last_isTerminal = True break @@ -394,7 +397,7 @@ class TreeObsForRailEnv(ObservationBuilder): observation = observation + branch_observation elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction), - (branch_direction + 2) % 4): + (branch_direction + 2) % 4): new_cell = self._new_position(position, branch_direction) branch_observation = self._explore_branch(handle, @@ -456,6 +459,7 @@ class GlobalObsForRailEnv(ObservationBuilder): - A 4 elements array with one of encoding of the direction. """ + def __init__(self): super(GlobalObsForRailEnv, self).__init__() diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 878892d7..5f43847f 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -999,6 +999,7 @@ class RailEnv(Environment): for i in range(len(self.agents_handles)): handle = self.agents_handles[i] transition_isValid = None + if handle not in action_dict: continue @@ -1093,6 +1094,7 @@ class RailEnv(Environment): else: new_cell_isValid = False + # If transition validity hasn't been checked yet. if transition_isValid == None: transition_isValid = self.rail.get_transition( (pos[0], pos[1], direction), -- GitLab