From bcd44bd68530583b036e32f82ae7deeb821fe92a Mon Sep 17 00:00:00 2001 From: spiglerg <spiglerg@gmail.com> Date: Wed, 24 Apr 2019 13:58:34 +0200 Subject: [PATCH] fixed new lint errors --- flatland/core/env_observation_builder.py | 50 +++++----- flatland/core/transitions.py | 43 +++----- flatland/envs/rail_env.py | 78 +++++++-------- flatland/utils/graphics_layer.py | 7 +- flatland/utils/graphics_qt.py | 5 +- flatland/utils/render_qt.py | 16 ++- flatland/utils/rendertools.py | 119 +++++++++++------------ tests/test_env_observation_builder.py | 16 +-- tests/test_environments.py | 3 +- tests/test_rendertools.py | 4 +- tests/test_transitions.py | 43 ++++---- 11 files changed, 173 insertions(+), 211 deletions(-) diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index 86485ec..8737862 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -114,7 +114,7 @@ class TreeObsForRailEnv(ObservationBuilder): nodes_queue.append(n) if len(valid_neighbors) > 0: - max_distance = max(max_distance, node[3]+1) + max_distance = max(max_distance, node[3] + 1) return max_distance @@ -129,7 +129,7 @@ class TreeObsForRailEnv(ObservationBuilder): if enforce_target_direction >= 0: # The agent must land into the current cell with orientation `enforce_target_direction'. # This is only possible if the agent has arrived from the cell in the opposite direction! - possible_directions = [(enforce_target_direction+2) % 4] + possible_directions = [(enforce_target_direction + 2) % 4] for neigh_direction in possible_directions: new_cell = self._new_position(position, neigh_direction) @@ -137,7 +137,7 @@ class TreeObsForRailEnv(ObservationBuilder): if new_cell[0] >= 0 and new_cell[0] < self.env.height and \ new_cell[1] >= 0 and new_cell[1] < self.env.width: - desired_movement_from_new_cell = (neigh_direction+2) % 4 + desired_movement_from_new_cell = (neigh_direction + 2) % 4 """ # Is the next cell a dead-end? @@ -166,7 +166,7 @@ class TreeObsForRailEnv(ObservationBuilder): movement = (desired_movement_from_new_cell+2) % 4 """ new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation], - current_distance+1) + current_distance + 1) neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance)) self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance @@ -177,11 +177,11 @@ class TreeObsForRailEnv(ObservationBuilder): Utility function that converts a compass movement over a 2D grid to new positions (r, c). """ if movement == 0: # NORTH - return (position[0]-1, position[1]) + return (position[0] - 1, position[1]) elif movement == 1: # EAST return (position[0], position[1] + 1) elif movement == 2: # SOUTH - return (position[0]+1, position[1]) + return (position[0] + 1, position[1]) elif movement == 3: # WEST return (position[0], position[1] - 1) @@ -241,7 +241,7 @@ class TreeObsForRailEnv(ObservationBuilder): # Start from the current orientation, and see which transitions are available; # organize them as [left, forward, right, back], relative to the current orientation - for branch_direction in [(orientation+4+i) % 4 for i in range(-1, 3)]: + for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]: if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction): new_cell = self._new_position(position, branch_direction) @@ -253,7 +253,7 @@ class TreeObsForRailEnv(ObservationBuilder): for i in range(self.max_depth): num_cells_to_fill_in += pow4 pow4 *= 4 - observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf]*num_cells_to_fill_in + observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in return observation @@ -262,7 +262,7 @@ class TreeObsForRailEnv(ObservationBuilder): Utility function to compute tree-based observations. """ # [Recursive branch opened] - if depth >= self.max_depth+1: + if depth >= self.max_depth + 1: return [] # Continue along direction until next switch or @@ -356,7 +356,7 @@ class TreeObsForRailEnv(ObservationBuilder): observation = [0, 1 if other_target_encountered else 0, 1 if other_agent_encountered else 0, - root_observation[3]+num_steps, + root_observation[3] + num_steps, 0] elif last_isTerminal: @@ -369,7 +369,7 @@ class TreeObsForRailEnv(ObservationBuilder): observation = [0, 1 if other_target_encountered else 0, 1 if other_agent_encountered else 0, - root_observation[3]+num_steps, + root_observation[3] + num_steps, self.distance_map[handle, position[0], position[1], direction]] # ############################# @@ -379,18 +379,18 @@ class TreeObsForRailEnv(ObservationBuilder): # Start from the current orientation, and see which transitions are available; # organize them as [left, forward, right, back], relative to the current orientation - for branch_direction in [(direction+4+i) % 4 for i in range(-1, 3)]: + for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]: if last_isDeadEnd and self.env.rail.get_transition((position[0], position[1], direction), - (branch_direction+2) % 4): + (branch_direction + 2) % 4): # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes # it back - new_cell = self._new_position(position, (branch_direction+2) % 4) + new_cell = self._new_position(position, (branch_direction + 2) % 4) branch_observation = self._explore_branch(handle, new_cell, - (branch_direction+2) % 4, + (branch_direction + 2) % 4, new_root_observation, - depth+1) + depth + 1) observation = observation + branch_observation elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction), @@ -401,16 +401,16 @@ class TreeObsForRailEnv(ObservationBuilder): new_cell, branch_direction, new_root_observation, - depth+1) + depth + 1) observation = observation + branch_observation else: num_cells_to_fill_in = 0 pow4 = 1 - for i in range(self.max_depth-depth): + for i in range(self.max_depth - depth): num_cells_to_fill_in += pow4 pow4 *= 4 - observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf]*num_cells_to_fill_in + observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in return observation @@ -422,7 +422,7 @@ class TreeObsForRailEnv(ObservationBuilder): return depth = 0 - tmp = len(tree)/num_features_per_node-1 + tmp = len(tree) / num_features_per_node - 1 pow4 = 4 while tmp > 0: tmp -= pow4 @@ -431,15 +431,15 @@ class TreeObsForRailEnv(ObservationBuilder): prompt_ = ['L:', 'F:', 'R:', 'B:'] - print(" "*current_depth + prompt, tree[0:num_features_per_node]) - child_size = (len(tree)-num_features_per_node)//4 + print(" " * current_depth + prompt, tree[0:num_features_per_node]) + child_size = (len(tree) - num_features_per_node) // 4 for children in range(4): - child_tree = tree[(num_features_per_node+children*child_size): - (num_features_per_node+(children+1)*child_size)] + child_tree = tree[(num_features_per_node + children * child_size): + (num_features_per_node + (children + 1) * child_size)] self.util_print_obs_subtree(child_tree, num_features_per_node, prompt=prompt_[children], - current_depth=current_depth+1) + current_depth=current_depth + 1) class GlobalObsForRailEnv(ObservationBuilder): diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py index eb4cb8e..a8cb8d6 100644 --- a/flatland/core/transitions.py +++ b/flatland/core/transitions.py @@ -180,7 +180,7 @@ class Grid4Transitions(Transitions): List of the validity of transitions in the cell. """ - bits = (cell_transition >> ((3-orientation)*4)) + bits = (cell_transition >> ((3 - orientation) * 4)) return ((bits >> 3) & 1, (bits >> 2) & 1, (bits >> 1) & 1, (bits) & 1) def set_transitions(self, cell_transition, orientation, new_transitions): @@ -208,7 +208,7 @@ class Grid4Transitions(Transitions): `orientation'. """ - mask = (1 << ((4-orientation)*4)) - (1 << ((3-orientation)*4)) + mask = (1 << ((4 - orientation) * 4)) - (1 << ((3 - orientation) * 4)) negmask = ~mask new_transitions = \ @@ -217,9 +217,7 @@ class Grid4Transitions(Transitions): (new_transitions[2] & 1) << 1 | \ (new_transitions[3] & 1) - cell_transition = \ - (cell_transition & negmask) | \ - (new_transitions << ((3-orientation)*4)) + cell_transition = (cell_transition & negmask) | (new_transitions << ((3 - orientation) * 4)) return cell_transition @@ -245,8 +243,7 @@ class Grid4Transitions(Transitions): Validity of the requested transition: 0/1 allowed/not allowed. """ - return ((cell_transition >> ((4-1-orientation) * 4)) >> - (4-1-direction)) & 1 + return ((cell_transition >> ((4 - 1 - orientation) * 4)) >> (4 - 1 - direction)) & 1 def set_transition(self, cell_transition, orientation, direction, new_transition): @@ -276,12 +273,9 @@ class Grid4Transitions(Transitions): """ if new_transition: - cell_transition |= (1 << ((4-1-orientation) * 4 + - (4-1-direction))) + cell_transition |= (1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction))) else: - cell_transition &= \ - ~(1 << ((4-1-orientation) * 4 + - (4-1-direction))) + cell_transition &= ~(1 << ((4 - 1 - orientation) * 4 + (4 - 1 - direction))) return cell_transition @@ -310,13 +304,11 @@ class Grid4Transitions(Transitions): rotation = rotation // 90 for i in range(4): block_tuple = self.get_transitions(value, i) - block_tuple = block_tuple[( - 4-rotation):] + block_tuple[:(4-rotation)] + block_tuple = block_tuple[(4 - rotation):] + block_tuple[:(4 - rotation)] value = self.set_transitions(value, i, block_tuple) # Rotate the 4-bits blocks - value = ((value & (2**(rotation*4)-1)) << - ((4-rotation)*4)) | (value >> (rotation*4)) + value = ((value & (2**(rotation * 4) - 1)) << ((4 - rotation) * 4)) | (value >> (rotation * 4)) cell_transition = value return cell_transition @@ -355,7 +347,7 @@ class Grid8Transitions(Transitions): List of the validity of transitions in the cell. """ - bits = (cell_transition >> ((7-orientation)*8)) + bits = (cell_transition >> ((7 - orientation) * 8)) cell_transition = ( (bits >> 7) & 1, (bits >> 6) & 1, @@ -389,7 +381,7 @@ class Grid8Transitions(Transitions): `orientation'. """ - mask = (1 << ((8-orientation)*8)) - (1 << ((7-orientation)*8)) + mask = (1 << ((8 - orientation) * 8)) - (1 << ((7 - orientation) * 8)) negmask = ~mask new_transitions = \ @@ -402,8 +394,7 @@ class Grid8Transitions(Transitions): (new_transitions[6] & 1) << 1 | \ (new_transitions[7] & 1) - cell_transition = (cell_transition & negmask) | ( - new_transitions << ((7-orientation)*8)) + cell_transition = (cell_transition & negmask) | (new_transitions << ((7 - orientation) * 8)) return cell_transition @@ -429,8 +420,7 @@ class Grid8Transitions(Transitions): Validity of the requested transition: 0/1 allowed/not allowed. """ - return ((cell_transition >> ((8-1-orientation) * 8)) >> - (8-1-direction)) & 1 + return ((cell_transition >> ((8 - 1 - orientation) * 8)) >> (8 - 1 - direction)) & 1 def set_transition(self, cell_transition, orientation, direction, new_transition): @@ -460,11 +450,9 @@ class Grid8Transitions(Transitions): """ if new_transition: - cell_transition |= (1 << ((8-1-orientation) * 8 + - (8 - 1 - direction))) + cell_transition |= (1 << ((8 - 1 - orientation) * 8 + (8 - 1 - direction))) else: - cell_transition &= ~(1 << ((8-1-orientation) * 8 + - (8 - 1 - direction))) + cell_transition &= ~(1 << ((8 - 1 - orientation) * 8 + (8 - 1 - direction))) return cell_transition @@ -500,8 +488,7 @@ class Grid8Transitions(Transitions): value = self.set_transitions(value, i, block_tuple) # Rotate the 8bits blocks - value = ((value & (2**(rotation*8)-1)) << - ((8-rotation)*8)) | (value >> (rotation*8)) + value = ((value & (2**(rotation * 8) - 1)) << ((8 - rotation) * 8)) | (value >> (rotation * 8)) cell_transition = value diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index cf20603..3fadf66 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -45,8 +45,7 @@ def rail_from_manual_specifications_generator(rail_spec): if cell[0] < 0 or cell[0] >= len(t_utils.transitions): print("ERROR - invalid cell type=", cell[0]) return [] - rail.set_transitions((r, c), t_utils.rotate_transition( - t_utils.transitions[cell[0]], cell[1])) + rail.set_transitions((r, c), t_utils.rotate_transition(t_utils.transitions[cell[0]], cell[1])) return rail @@ -110,7 +109,7 @@ def generate_rail_from_list_of_manual_specifications(list_of_specifications) """ -def random_rail_generator(cell_type_relative_proportion=[1.0]*8): +def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): """ Dummy random level generator: - fill in cells at random in [width-2, height-2] @@ -149,7 +148,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): transitions_templates_ = [] transition_probabilities = [] - for i in range(len(t_utils.transitions)-1): # don't include dead-ends + for i in range(len(t_utils.transitions) - 1): # don't include dead-ends all_transitions = 0 for dir_ in range(4): trans = t_utils.get_transitions(t_utils.transitions[i], dir_) @@ -159,7 +158,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): (trans[3]) template = [int(x) for x in bin(all_transitions)[2:]] - template = [0]*(4-len(template)) + template + template = [0] * (4 - len(template)) + template # add all rotations for rot in [0, 90, 180, 270]: @@ -168,7 +167,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): t_utils.transitions[i], rot))) transition_probabilities.append(transition_probability[i]) - template = [template[-1]]+template[:-1] + template = [template[-1]] + template[:-1] def get_matching_templates(template): ret = [] @@ -183,7 +182,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): ret.append((transitions_templates_[i][1], transition_probabilities[i])) return ret - MAX_INSERTIONS = (width-2) * (height-2) * 10 + MAX_INSERTIONS = (width - 2) * (height - 2) * 10 MAX_ATTEMPTS_FROM_SCRATCH = 10 attempt_number = 0 @@ -191,10 +190,9 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): cells_to_fill = [] rail = [] for r in range(height): - rail.append([None]*width) - if r > 0 and r < height-1: - cells_to_fill = cells_to_fill \ - + [(r, c) for c in range(1, width-1)] + rail.append([None] * width) + if r > 0 and r < height - 1: + cells_to_fill = cells_to_fill + [(r, c) for c in range(1, width - 1)] num_insertions = 0 while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0: @@ -212,14 +210,13 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): (1, 3, (0, 1)), (2, 0, (1, 0)), (3, 1, (0, -1))]: # N, E, S, W - neigh_trans = rail[row+el[2][0]][col+el[2][1]] + neigh_trans = rail[row + el[2][0]][col + el[2][1]] if neigh_trans is not None: # select transition coming from facing direction el[1] and # moving to direction el[1] max_bit = 0 for k in range(4): - max_bit |= \ - t_utils.get_transition(neigh_trans, k, el[1]) + max_bit |= t_utils.get_transition(neigh_trans, k, el[1]) if max_bit: valid_template[el[0]] = 1 @@ -244,8 +241,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): elif k == 3: rot = 90 - rail[row][col] = t_utils.rotate_transition( - int('0010000000000000', 2), rot) + rail[row][col] = t_utils.rotate_transition(int('0010000000000000', 2), rot) num_insertions += 1 break @@ -258,8 +254,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): for k in range(4): tmp_template = valid_template[:] tmp_template[k] = -1 - possible_cell_transitions = get_matching_templates( - tmp_template) + possible_cell_transitions = get_matching_templates(tmp_template) if len(possible_cell_transitions) > len(besttrans): besttrans = possible_cell_transitions bestk = k @@ -284,7 +279,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): rail[replace_row][replace_col] = None possible_transitions, possible_probabilities = zip(*besttrans) - possible_probabilities = [p/sum(possible_probabilities) for p in possible_probabilities] + possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities] rail[row][col] = np.random.choice(possible_transitions, p=possible_probabilities) @@ -298,7 +293,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): else: possible_transitions, possible_probabilities = zip(*possible_cell_transitions) - possible_probabilities = [p/sum(possible_probabilities) for p in possible_probabilities] + possible_probabilities = [p / sum(possible_probabilities) for p in possible_probabilities] rail[row][col] = np.random.choice(possible_transitions, p=possible_probabilities) @@ -321,12 +316,10 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): neigh_trans = rail[r][1] if neigh_trans is not None: for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ - & (2**4-1) + neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1) max_bit = max_bit | (neigh_trans_from_direction & 1) if max_bit: - rail[r][0] = t_utils.rotate_transition( - int('0010000000000000', 2), 270) + rail[r][0] = t_utils.rotate_transition(int('0010000000000000', 2), 270) else: rail[r][0] = int('0000000000000000', 2) @@ -335,8 +328,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): neigh_trans = rail[r][-2] if neigh_trans is not None: for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ - & (2**4-1) + neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1) max_bit = max_bit | (neigh_trans_from_direction & (1 << 2)) if max_bit: rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2), @@ -350,8 +342,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): neigh_trans = rail[1][c] if neigh_trans is not None: for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ - & (2**4-1) + neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1) max_bit = max_bit | (neigh_trans_from_direction & (1 << 3)) if max_bit: rail[0][c] = int('0010000000000000', 2) @@ -363,12 +354,10 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): neigh_trans = rail[-2][c] if neigh_trans is not None: for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ - & (2**4-1) + neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1) max_bit = max_bit | (neigh_trans_from_direction & (1 << 1)) if max_bit: - rail[-1][c] = t_utils.rotate_transition( - int('0010000000000000', 2), 180) + rail[-1][c] = t_utils.rotate_transition(int('0010000000000000', 2), 180) else: rail[-1][c] = int('0000000000000000', 2) @@ -458,8 +447,8 @@ class RailEnv(Environment): self.obs_builder = obs_builder_object self.obs_builder._set_env(self) - self.actions = [0]*self.number_of_agents - self.rewards = [0]*self.number_of_agents + self.actions = [0] * self.number_of_agents + self.rewards = [0] * self.number_of_agents self.done = False self.dones = {"__all__": False} @@ -507,14 +496,13 @@ class RailEnv(Environment): # agents_direction must be a direction for which a solution is # guaranteed. - self.agents_direction = [0]*self.number_of_agents + self.agents_direction = [0] * self.number_of_agents re_generate = False for i in range(self.number_of_agents): valid_movements = [] for direction in range(4): position = self.agents_position[i] - moves = self.rail.get_transitions( - (position[0], position[1], direction)) + moves = self.rail.get_transitions((position[0], position[1], direction)) for move_index in range(4): if moves[move_index]: valid_movements.append((direction, move_index)) @@ -608,8 +596,8 @@ class RailEnv(Environment): reverse_direction = 1 valid_transition = self.rail.get_transition( - (pos[0], pos[1], direction), - reverse_direction) + (pos[0], pos[1], direction), + reverse_direction) if valid_transition: direction = reverse_direction movement = reverse_direction @@ -629,8 +617,8 @@ class RailEnv(Environment): new_cell_isValid = False transition_isValid = self.rail.get_transition( - (pos[0], pos[1], direction), - movement) or is_deadend + (pos[0], pos[1], direction), + movement) or is_deadend cell_isFree = True for j in range(self.number_of_agents): @@ -663,20 +651,20 @@ class RailEnv(Environment): if num_agents_in_target_position == self.number_of_agents: self.dones["__all__"] = True - self.rewards_dict = [r+global_reward for r in self.rewards_dict] + self.rewards_dict = [r + global_reward for r in self.rewards_dict] # Reset the step actions (in case some agent doesn't 'register_action' # on the next step) - self.actions = [0]*self.number_of_agents + self.actions = [0] * self.number_of_agents return self._get_observations(), self.rewards_dict, self.dones, {} def _new_position(self, position, movement): if movement == 0: # NORTH - return (position[0]-1, position[1]) + return (position[0] - 1, position[1]) elif movement == 1: # EAST return (position[0], position[1] + 1) elif movement == 2: # SOUTH - return (position[0]+1, position[1]) + return (position[0] + 1, position[1]) elif movement == 3: # WEST return (position[0], position[1] - 1) diff --git a/flatland/utils/graphics_layer.py b/flatland/utils/graphics_layer.py index b199a3b..5ace85c 100644 --- a/flatland/utils/graphics_layer.py +++ b/flatland/utils/graphics_layer.py @@ -18,16 +18,15 @@ class GraphicsLayer(object): def show(self, block=False): pass - + def pause(self, seconds=0.00001): pass def clf(self): pass - + def beginFrame(self): pass - + def endFrame(self): pass - diff --git a/flatland/utils/graphics_qt.py b/flatland/utils/graphics_qt.py index 6571f4d..8c5b37f 100644 --- a/flatland/utils/graphics_qt.py +++ b/flatland/utils/graphics_qt.py @@ -214,13 +214,12 @@ class QtRenderer(object): def takeSnapshot(self, sDir="./movie"): oWidget = self.window.mainWidget oPixmap = oWidget.grab() - + if not os.path.isdir(sDir): os.mkdir(sDir) - + nRunIn = 30 if self.iFrame > nRunIn: sfImage = "%s/frame%05d.jpg" % (sDir, self.iFrame - nRunIn) oPixmap.save(sfImage, "jpg") self.iFrame += 1 - diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py index 60b8d29..f7e75ea 100644 --- a/flatland/utils/render_qt.py +++ b/flatland/utils/render_qt.py @@ -55,8 +55,8 @@ class QTGL(GraphicsLayer): if lastx is not None: # print("line", lastx, lasty, x, y) self.qtr.drawLine( - lastx*self.cell_pixels, -lasty*self.cell_pixels, - x*self.cell_pixels, -y*self.cell_pixels) + lastx * self.cell_pixels, -lasty * self.cell_pixels, + x * self.cell_pixels, -y * self.cell_pixels) lastx = x lasty = y @@ -64,20 +64,20 @@ class QTGL(GraphicsLayer): print("scatter not yet implemented in ", self.__class__) def text(self, x, y, sText): - self.qtr.drawText(x*self.cell_pixels, -y*self.cell_pixels, sText) - + self.qtr.drawText(x * self.cell_pixels, -y * self.cell_pixels, sText) + def prettify(self, *args, **kwargs): pass def prettify2(self, width, height, cell_size): pass - + def show(self, block=False): pass def pause(self, seconds=0.00001): pass - + def clf(self): pass @@ -88,9 +88,7 @@ class QTGL(GraphicsLayer): self.qtr.beginFrame() self.qtr.push() self.qtr.fillRect(0, 0, self.widthPx, self.heightPx, *self.tColBg) - + def endFrame(self): self.qtr.pop() self.qtr.endFrame() - - diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index f9ebf53..997cee6 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -24,11 +24,11 @@ class MPLGL(GraphicsLayer): def text(self, *args, **kwargs): plt.text(*args, **kwargs) - + def prettify(self, *args, **kwargs): ax = plt.gca() - plt.xticks(range(int(ax.get_xlim()[1])+1)) - plt.yticks(range(int(ax.get_ylim()[1])+1)) + plt.xticks(range(int(ax.get_xlim()[1]) + 1)) + plt.yticks(range(int(ax.get_ylim()[1]) + 1)) plt.grid() plt.xlabel("Euclidean distance") plt.ylabel("Tree / Transition Depth") @@ -41,28 +41,28 @@ class MPLGL(GraphicsLayer): gLabels = np.arange(0, height) plt.xticks(gTicks, gLabels) - gTicks = np.arange(-height * cell_size, 0) + cell_size/2 - gLabels = np.arange(height-1, -1, -1) + gTicks = np.arange(-height * cell_size, 0) + cell_size / 2 + gLabels = np.arange(height - 1, -1, -1) plt.yticks(gTicks, gLabels) plt.xlim([0, width * cell_size]) plt.ylim([-height * cell_size, 0]) - + def show(self, block=False): plt.show(block=block) def pause(self, seconds=0.00001): plt.pause(seconds) - + def clf(self): plt.clf() - + def get_cmap(self, *args, **kwargs): return plt.get_cmap(*args, **kwargs) def beginFrame(self): pass - + def endFrame(self): pass @@ -85,7 +85,7 @@ class RenderTool(object): gCentres = xr.DataArray(gGrid, dims=["xy", "p1", "p2"], coords={"xy": ["x", "y"]}) + xyPixHalf - gTheta = np.linspace(0, np.pi/2, 10) + gTheta = np.linspace(0, np.pi / 2, 10) gArc = array([np.cos(gTheta), np.sin(gTheta)]).T # from [1,0] to [0,1] def __init__(self, env, gl="MPL"): @@ -197,7 +197,7 @@ class RenderTool(object): rcDir = rt.gTransRC[iDir] # agent direction in RC xyDir = np.matmul(rcDir, rt.grc2xy) # agent direction in xy - xyDirLine = array([xyPos, xyPos+xyDir/2]).T # line for agent orient. + xyDirLine = array([xyPos, xyPos + xyDir / 2]).T # line for agent orient. self.gl.plot(*xyDirLine, color=sColor, lw=5, ms=0, alpha=0.6) # just mark the next cell we're heading into @@ -215,7 +215,7 @@ class RenderTool(object): rt = self.__class__ xyPos = np.matmul(rcPos, rt.grc2xy) + rt.xyHalf - gxyTrans = xyPos + np.matmul(gTransRCAg, rt.grc2xy/2.4) + gxyTrans = xyPos + np.matmul(gTransRCAg, rt.grc2xy / 2.4) self.gl.scatter(*gxyTrans.T, color=color, marker="o", s=50, alpha=0.2) if depth is not None: for x, y in gxyTrans: @@ -255,7 +255,7 @@ class RenderTool(object): # print("Trans:", gTransRC2) visitNext = rt.Visit(tuple(visit.rc + gTransRC2), iTrans, - visit.iDepth+1, + visit.iDepth + 1, visit) # print("node2: ", node2) stack.append(visitNext) @@ -294,7 +294,7 @@ class RenderTool(object): xLoc = rDist + visit.iDir / 4 # point labelled with distance - self.gl.scatter(xLoc, visit.iDepth, color="k", s=2) + self.gl.scatter(xLoc, visit.iDepth, color="k", s=2) # plt.text(xLoc, visit.iDepth, sDist, color="k", rotation=45) self.gl.text(xLoc, visit.iDepth, visit.rc, color="k", rotation=45) @@ -312,8 +312,8 @@ class RenderTool(object): # line from prev node self.gl.plot([xLocPrev, xLoc], - [visit.iDepth-1, visit.iDepth], - color="k", alpha=0.5, lw=1) + [visit.iDepth - 1, visit.iDepth], + color="k", alpha=0.5, lw=1) if rDist < 0.1: visitDest = visit @@ -326,8 +326,8 @@ class RenderTool(object): rDist = np.linalg.norm(array(visit.rc) - array(xyTarg)) xLoc = rDist + visit.iDir / 4 if xLocPrev is not None: - self.gl.plot([xLoc, xLocPrev], [visit.iDepth, visit.iDepth+1], - color="r", alpha=0.5, lw=2) + self.gl.plot([xLoc, xLocPrev], [visit.iDepth, visit.iDepth + 1], + color="r", alpha=0.5, lw=2) xLocPrev = xLoc visit = visit.prev # prev = prev.prev @@ -360,13 +360,12 @@ class RenderTool(object): self.gl.plot(*xyLine.T, color="r", alpha=0.5, lw=1) - xyMid = np.sum(xyLine * [[1/4], [3/4]], axis=0) + xyMid = np.sum(xyLine * [[1 / 4], [3 / 4]], axis=0) xyArrow = array([ - xyMid + [-dx-dy, +dx-dy], + xyMid + [-dx - dy, +dx - dy], xyMid, - xyMid + [-dx+dy, -dx-dy] - ]) + xyMid + [-dx + dy, -dx - dy]]) self.gl.plot(*xyArrow.T, color="r") visit = visit.prev @@ -411,13 +410,12 @@ class RenderTool(object): self.gl.plot(*xyLine2.T, color=sColor) if bArrow: - xyMid = np.sum(xyLine2 * [[1/4], [3/4]], axis=0) + xyMid = np.sum(xyLine2 * [[1 / 4], [3 / 4]], axis=0) xyArrow = array([ - xyMid + [-dx-dy, +dx-dy], + xyMid + [-dx - dy, +dx - dy], xyMid, - xyMid + [-dx+dy, -dx-dy] - ]) + xyMid + [-dx + dy, -dx - dy]]) self.gl.plot(*xyArrow.T, color=sColor) else: @@ -443,10 +441,9 @@ class RenderTool(object): iArc = int(len(rt.gArc) / 2) xyMid = xyCorner + rt.gArc[iArc] * dxy2 xyArrow = array([ - xyMid + [-dx-dy, +dx-dy], + xyMid + [-dx - dy, +dx - dy], xyMid, - xyMid + [-dx+dy, -dx-dy] - ]) + xyMid + [-dx + dy, -dx - dy]]) self.gl.plot(*xyArrow.T, color=sColor) def renderEnv( @@ -480,14 +477,14 @@ class RenderTool(object): # Draw cells grid grid_color = [0.95, 0.95, 0.95] - for r in range(env.height+1): - self.gl.plot([0, (env.width+1)*cell_size], - [-r*cell_size, -r*cell_size], - color=grid_color) - for c in range(env.width+1): - self.gl.plot([c*cell_size, c*cell_size], - [0, -(env.height+1)*cell_size], - color=grid_color) + for r in range(env.height + 1): + self.gl.plot([0, (env.width + 1) * cell_size], + [-r * cell_size, -r * cell_size], + color=grid_color) + for c in range(env.width + 1): + self.gl.plot([c * cell_size, c * cell_size], + [0, -(env.height + 1) * cell_size], + color=grid_color) # Draw each cell independently for r in range(env.height): @@ -495,16 +492,16 @@ class RenderTool(object): # bounding box of the grid cell x0 = cell_size * c # left - x1 = cell_size * (c+1) # right + x1 = cell_size * (c + 1) # right y0 = cell_size * -r # top - y1 = cell_size * -(r+1) # bottom + y1 = cell_size * -(r + 1) # bottom # centres of cell edges coords = [ - ((x0+x1)/2.0, y0), # N middle top - (x1, (y0+y1)/2.0), # E middle right - ((x0+x1)/2.0, y1), # S middle bottom - (x0, (y0+y1)/2.0) # W middle left + ((x0 + x1) / 2.0, y0), # N middle top + (x1, (y0 + y1) / 2.0), # E middle right + ((x0 + x1) / 2.0, y1), # S middle bottom + (x0, (y0 + y1) / 2.0) # W middle left ] # cell centre @@ -570,21 +567,21 @@ class RenderTool(object): # Draw each agent + its orientation + its target if agents: - cmap = self.gl.get_cmap('hsv', lut=env.number_of_agents+1) + cmap = self.gl.get_cmap('hsv', lut=env.number_of_agents + 1) for i in range(env.number_of_agents): self._draw_square(( - env.agents_position[i][1] * - cell_size+cell_size/2, - -env.agents_position[i][0] * - cell_size-cell_size/2), - cell_size/8, cmap(i)) + env.agents_position[i][1] * + cell_size + cell_size / 2, + -env.agents_position[i][0] * + cell_size - cell_size / 2), + cell_size / 8, cmap(i)) for i in range(env.number_of_agents): self._draw_square(( - env.agents_target[i][1] * - cell_size+cell_size/2, - -env.agents_target[i][0] * - cell_size-cell_size/2), - cell_size/3, [c for c in cmap(i)]) + env.agents_target[i][1] * + cell_size + cell_size / 2, + -env.agents_target[i][0] * + cell_size - cell_size / 2), + cell_size / 3, [c for c in cmap(i)]) # orientation is a line connecting the center of the cell to the # side of the square of the agent @@ -594,8 +591,8 @@ class RenderTool(object): (new_position[1] + env.agents_position[i][1]) / 2 * cell_size) self.gl.plot( - [env.agents_position[i][1] * cell_size+cell_size/2, new_position[1]+cell_size/2], - [-env.agents_position[i][0] * cell_size-cell_size/2, -new_position[0]-cell_size/2], + [env.agents_position[i][1] * cell_size + cell_size / 2, new_position[1] + cell_size / 2], + [-env.agents_position[i][0] * cell_size - cell_size / 2, -new_position[0] - cell_size / 2], color=cmap(i), linewidth=2.0) @@ -604,7 +601,7 @@ class RenderTool(object): if frames: self.gl.text(0.1, yText[2], "Frame:{:}".format(self.iFrame)) self.iFrame += 1 - + if iEpisode is not None: self.gl.text(0.1, yText[1], "Ep:{}".format(iEpisode)) @@ -631,8 +628,8 @@ class RenderTool(object): return def _draw_square(self, center, size, color): - x0 = center[0]-size/2 - x1 = center[0]+size/2 - y0 = center[1]-size/2 - y1 = center[1]+size/2 + x0 = center[0] - size / 2 + x1 = center[0] + size / 2 + y0 = center[1] - size / 2 + y1 = center[1] + size / 2 self.gl.plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color) diff --git a/tests/test_env_observation_builder.py b/tests/test_env_observation_builder.py index db264c2..55c229e 100644 --- a/tests/test_env_observation_builder.py +++ b/tests/test_env_observation_builder.py @@ -46,14 +46,14 @@ def test_global_obs(): double_switch_south_horizontal_straight, 180) rail_map = np.array( - [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + - [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + - [[dead_end_from_east] + [horizontal_straight] * 2 + - [double_switch_north_horizontal_straight] + - [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] + - [horizontal_straight] * 2 + [dead_end_from_west]] + - [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + - [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) + [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + + [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + + [[dead_end_from_east] + [horizontal_straight] * 2 + + [double_switch_north_horizontal_straight] + + [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] + + [horizontal_straight] * 2 + [dead_end_from_west]] + + [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + + [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) rail = GridTransitionMap(width=rail_map.shape[1], height=rail_map.shape[0], transitions=transitions) diff --git a/tests/test_environments.py b/tests/test_environments.py index ea8748b..210f1c7 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -30,8 +30,7 @@ def test_rail_environment_single_agent(): transitions = Grid4Transitions([]) vertical_line = cells[1] south_symmetrical_switch = cells[6] - north_symmetrical_switch = transitions.rotate_transition( - south_symmetrical_switch, 180) + north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180) # Simple turn not in the base transitions ? south_east_turn = int('0100000000000010', 2) south_west_turn = transitions.rotate_transition(south_east_turn, 90) diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py index 528cc59..1f5c317 100644 --- a/tests/test_rendertools.py +++ b/tests/test_rendertools.py @@ -11,7 +11,7 @@ import os import matplotlib.pyplot as plt import flatland.utils.rendertools as rt -from flatland.core.env_observation_builder import GlobalObsForRailEnv, TreeObsForRailEnv +from flatland.core.env_observation_builder import TreeObsForRailEnv def checkFrozenImage(sFileImage): @@ -31,7 +31,7 @@ def checkFrozenImage(sFileImage): bytesFrozenImage = bytesImage else: assert(bytesFrozenImage.shape == bytesImage.shape) - assert((np.sum(np.square(bytesFrozenImage-bytesImage)) / bytesFrozenImage.size) < 1e-3) + assert((np.sum(np.square(bytesFrozenImage - bytesImage)) / bytesFrozenImage.size) < 1e-3) def test_render_env(): diff --git a/tests/test_transitions.py b/tests/test_transitions.py index f68b836..0f56e88 100644 --- a/tests/test_transitions.py +++ b/tests/test_transitions.py @@ -15,65 +15,60 @@ def test_valid_railenv_transitions(): for i in range(2): assert(rail_env_trans.get_transitions( - int('1100110000110011', 2), i) == (1, 1, 0, 0)) + int('1100110000110011', 2), i) == (1, 1, 0, 0)) assert(rail_env_trans.get_transitions( - int('1100110000110011', 2), 2+i) == (0, 0, 1, 1)) + int('1100110000110011', 2), 2 + i) == (0, 0, 1, 1)) no_transition_cell = int('0000000000000000', 2) for i in range(4): assert(rail_env_trans.get_transitions( - no_transition_cell, i) == (0, 0, 0, 0)) + no_transition_cell, i) == (0, 0, 0, 0)) # Facing south, going south - north_south_transition = rail_env_trans.set_transitions( - no_transition_cell, 2, (0, 0, 1, 0)) + north_south_transition = rail_env_trans.set_transitions(no_transition_cell, 2, (0, 0, 1, 0)) assert(rail_env_trans.set_transition( - north_south_transition, 2, 2, 0) == no_transition_cell) + north_south_transition, 2, 2, 0) == no_transition_cell) assert(rail_env_trans.get_transition( - north_south_transition, 2, 2)) + north_south_transition, 2, 2)) # Facing north, going east south_east_transition = \ - rail_env_trans.set_transition( - no_transition_cell, 0, 1, 1) + rail_env_trans.set_transition(no_transition_cell, 0, 1, 1) assert(rail_env_trans.get_transition( - south_east_transition, 0, 1)) + south_east_transition, 0, 1)) # The opposite transitions are not feasible assert(not rail_env_trans.get_transition( - north_south_transition, 2, 0)) + north_south_transition, 2, 0)) assert(not rail_env_trans.get_transition( - south_east_transition, 2, 1)) + south_east_transition, 2, 1)) - east_west_transition = rail_env_trans.rotate_transition( - north_south_transition, 90) - north_west_transition = rail_env_trans.rotate_transition( - south_east_transition, 180) + east_west_transition = rail_env_trans.rotate_transition(north_south_transition, 90) + north_west_transition = rail_env_trans.rotate_transition(south_east_transition, 180) # Facing west, going west assert(rail_env_trans.get_transition( - east_west_transition, 3, 3)) + east_west_transition, 3, 3)) # Facing south, going west assert(rail_env_trans.get_transition( - north_west_transition, 2, 3)) + north_west_transition, 2, 3)) assert(south_east_transition == rail_env_trans.rotate_transition( - south_east_transition, 360)) + south_east_transition, 360)) def test_diagonal_transitions(): diagonal_trans_env = Grid8Transitions([]) # Facing north, going north-east - south_northeast_transition = int('01000000' + '0'*8*7, 2) + south_northeast_transition = int('01000000' + '0' * 8 * 7, 2) assert(diagonal_trans_env.get_transitions( - south_northeast_transition, 0) == (0, 1, 0, 0, 0, 0, 0, 0)) + south_northeast_transition, 0) == (0, 1, 0, 0, 0, 0, 0, 0)) # Allowing transition from north to southwest: Facing south, going SW north_southwest_transition = \ - diagonal_trans_env.set_transitions( - int('0' * 64, 2), 4, (0, 0, 0, 0, 0, 1, 0, 0)) + diagonal_trans_env.set_transitions(int('0' * 64, 2), 4, (0, 0, 0, 0, 0, 1, 0, 0)) assert(diagonal_trans_env.rotate_transition( - south_northeast_transition, 180) == north_southwest_transition) + south_northeast_transition, 180) == north_southwest_transition) -- GitLab