# cell_graph.py
 nilabha committed Jul 14, 2020 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 ``````import numpy as np from collections import deque from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnvActions class Vertex: def __init__(self, y, x, idx): self.point = (y, x) self.idx = idx self.out = [[], [], [], []] self.in_edges = [[], [], [], []] class Edge: def __init__(self, start_v, end_v, start_dir, end_dir, action_type): self.start_v = start_v self.end_v = end_v self.start_direction = start_dir self.end_direction = end_dir self.action_type = action_type class CellGraph: def __init__(self, env : RailEnv): self.env = env self._build_graph() def _build_graph(self): width = self.env.width height = self.env.height self.vertex_idx = np.zeros((height, width), dtype=np.int) self.vertex_idx.fill(-1) self.vertexes = [] for y in range(height): for x in range(width): if self._is_rail(y, x): idx = len(self.vertexes) self.vertexes.append(Vertex(y, x, idx)) self.vertex_idx[y, x] = idx # print('vertexes:', len(self.vertexes)) edges_cnt = 0 for v_idx, v in enumerate(self.vertexes): start_point = v.point for direction in range(4): directions = self._possible_directions(start_point, direction) # assert len(directions) <= 2 for end_direction in directions: next_point = self._next_point(start_point, end_direction) end_v = self._vertex_idx_from_point(next_point) action_type = self._action_from_directions(direction, end_direction) e = Edge(v_idx, end_v, direction, end_direction, action_type) v.out[direction].append(e) self.vertexes[end_v].in_edges[end_direction].append(e) edges_cnt += 1 # print('edges_cnt', edges_cnt) def _is_rail(self, y, x): return self.env.rail.grid[y, x] != 0 def _next_point(self, point, 
direction): if direction==0: return (point[0]-1, point[1]) elif direction==1: return (point[0], point[1]+1) elif direction==2: return (point[0]+1, point[1]) else: return (point[0], point[1]-1) def _possible_directions(self, point, in_direction): return np.flatnonzero(self.env.rail.get_transitions(point[0], point[1], in_direction)) def _vertex_idx_from_point(self, point): assert (point[0] >= 0) and (point[0] < self.vertex_idx.shape[0]) assert (point[1] >= 0) and (point[1] < self.vertex_idx.shape[1]) return self.vertex_idx[point[0], point[1]] def position_from_vertexid(self, vertexid: int): return self.vertexes[vertexid].point def _action_from_directions(self, in_direction, new_direction): if in_direction==new_direction: return RailEnvActions.MOVE_FORWARD if (in_direction+1)%4 == new_direction: return RailEnvActions.MOVE_RIGHT elif (in_direction-1)%4 == new_direction: return RailEnvActions.MOVE_LEFT else: return RailEnvActions.MOVE_FORWARD ``````