diff --git a/examples/temporary_example.py b/examples/temporary_example.py index a5c97660e4d22ebaffeef27b320e45b21f29dfe5..6bed439cb21f611353763da36a890739c77e5866 100644 --- a/examples/temporary_example.py +++ b/examples/temporary_example.py @@ -3,6 +3,7 @@ import numpy as np import matplotlib.pyplot as plt from flatland.envs.rail_env import * +from flatland.envs.generators import * from flatland.core.env_observation_builder import TreeObsForRailEnv from flatland.utils.rendertools import * @@ -62,6 +63,11 @@ specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)], [(0, 0), (0, 0), (0, 0), (2, 180), (2, 90), (7, 90)], [(0, 0), (0, 0), (0, 0), (7, 180), (0, 0), (0, 0)]] +# CURVED RAIL + DEAD-ENDS TEST +specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)], + [(7, 270), (1, 90), (1, 90), (8, 90), (0, 0), (0, 0)], + [(0, 0), (7, 270),(1, 90), (8, 180), (0, 00), (0, 0)]] + env = RailEnv(width=6, height=4, rail_generator=rail_from_manual_specifications_generator(specs), diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index fd42787a68c9ee007368fb3ced1ff5162c9926eb..2e34b4236943b36b167268828e3010e727108760 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -314,14 +314,13 @@ class GridTransitionMap(TransitionMap): Checks that: - surrounding cells have inbound transitions for all the outbound transitions of this cell. - + These are NOT checked - see transition.is_valid: - all transitions have the mirror transitions (N->E <=> W->S) - Reverse transitions (N -> S) only exist for a dead-end - a cell contains either no dead-ends or exactly one Returns: True (valid) or False (invalid) - """ cell_transition = self.grid[tuple(rcPos)] diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py index a862c6a81dbd71a0aa95f2305158f940e964c0ad..622d900598bba6bbc48750d6bb48923975af9b5e 100644 --- a/flatland/core/transitions.py +++ b/flatland/core/transitions.py @@ -551,7 +551,7 @@ class RailEnvTransitions(Grid4Transitions): super(RailEnvTransitions, self).__init__( transitions=self.transition_list ) - + # These bits represent all the possible dead ends self.maskDeadEnds = 0b0010000110000100 @@ -569,10 +569,10 @@ class RailEnvTransitions(Grid4Transitions): def print(self, cell_transition): print(" NESW") - print("N", format(cell_transition >> (3*4) & 0xF, '04b')) - print("E", format(cell_transition >> (2*4) & 0xF, '04b')) - print("S", format(cell_transition >> (1*4) & 0xF, '04b')) - print("W", format(cell_transition >> (0*4) & 0xF, '04b')) + print("N", format(cell_transition >> (3 * 4) & 0xF, '04b')) + print("E", format(cell_transition >> (2 * 4) & 0xF, '04b')) + print("S", format(cell_transition >> (1 * 4) & 0xF, '04b')) + print("W", format(cell_transition >> (0 * 4) & 0xF, '04b')) def repr(self, cell_transition, version=0): """ @@ -585,25 +585,23 @@ class RailEnvTransitions(Grid4Transitions): sbinTrans = format(cell_transition, "#018b")[2:] if version == 0: sRepr = " ".join([ - "{}:{}".format(sDir, sbinTrans[i:i+4]) + "{}:{}".format(sDir, sbinTrans[i:(i + 4)]) for i, sDir in zip( range(0, len(sbinTrans), 4), - self.lsDirs # NESW - )]) + self.lsDirs)]) # NESW return sRepr if version == 1: lsRepr = [] for iDirIn in range(0, 4): - sDirTrans = sbinTrans[iDirIn*4:iDirIn*4+4] + sDirTrans = sbinTrans[(iDirIn * 4):(iDirIn * 4 + 4)] if sDirTrans == "0000": continue sDirsOut = [ self.lsDirs[iDirOut] for iDirOut in range(0, 4) - if sDirTrans[iDirOut] == "1" - ] + if sDirTrans[iDirOut] == "1"] lsRepr.append(self.lsDirs[iDirIn] + ":" + "".join(sDirsOut)) return ", ".join(lsRepr) diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py index 79f5eb0d23b8b8a50bc1afd9ba1d34bba76c4ffd..382bdd7d9c483aa491f9c11b06b62b13b45490a5 100644 --- a/flatland/envs/env_utils.py +++ b/flatland/envs/env_utils.py @@ -271,4 +271,3 @@ def connect_rail(rail_trans, rail_array, start, end): def distance_on_rail(pos1, pos2): return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1]) - diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index baba6baac4a4efe76844147d7c49435a88ad79af..021c63f31a958790c04e17f49c7a1da57777cb9c 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -250,7 +250,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): transitions_templates_ = [] transition_probabilities = [] - for i in range(len(t_utils.transitions)-4): # don't include dead-ends + for i in range(len(t_utils.transitions) - 4): # don't include dead-ends all_transitions = 0 for dir_ in range(4): trans = t_utils.get_transitions(t_utils.transitions[i], dir_) @@ -475,4 +475,3 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): return return_rail return generator - diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 9a48915bce8386f0ffe0237c6bf9512b0609d82a..cbb11c2db70dece00199bfae6bb051504f64a9f2 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -268,7 +268,7 @@ class RailEnv(Environment): break else: self.agents_direction[i] = direction - + # Reset the state of the observation builder with the new environment self.obs_builder.reset() @@ -314,60 +314,33 @@ class RailEnv(Environment): # compute number of possible transitions in the current # cell used to check for invalid actions - nbits = 0 - tmp = self.rail.get_transitions((pos[0], pos[1])) possible_transitions = self.rail.get_transitions((pos[0], pos[1], direction)) - # print(np.sum(self.rail.get_transitions((pos[0], pos[1],direction))), - # self.rail.get_transitions((pos[0], pos[1],direction)), - # self.rail.get_transitions((pos[0], pos[1])), - # (pos[0], pos[1],direction)) + num_transitions = np.count_nonzero(possible_transitions) - while tmp > 0: - nbits += (tmp & 1) - tmp = tmp >> 1 movement = direction # print(nbits,np.sum(possible_transitions)) if action == 1: movement = direction - 1 - if nbits <= 2 or np.sum(possible_transitions) <= 1: + if num_transitions <= 1: transition_isValid = False elif action == 3: movement = direction + 1 - if nbits <= 2 or np.sum(possible_transitions) <= 1: + if num_transitions <= 1: transition_isValid = False + if movement < 0: movement += 4 if movement >= 4: movement -= 4 - is_deadend = False if action == 2: - if nbits == 1: - # dead-end; assuming the rail network is consistent, - # this should match the direction the agent has come - # from. But it's better to check in any case. - reverse_direction = 0 - if direction == 0: - reverse_direction = 2 - elif direction == 1: - reverse_direction = 3 - elif direction == 2: - reverse_direction = 0 - elif direction == 3: - reverse_direction = 1 - - valid_transition = self.rail.get_transition( - (pos[0], pos[1], direction), - reverse_direction) - if valid_transition: - direction = reverse_direction - movement = reverse_direction - is_deadend = True - - if np.sum(possible_transitions) == 1: - # Take only available transition + if num_transitions == 1: + # - dead-end, straight line or curved line; + # movement will be the only valid transition + # - take only available transition movement = np.argmax(possible_transitions) + transition_isValid = True new_position = self._new_position(pos, movement) # Is it a legal move? 1) transition allows the movement in the @@ -388,7 +361,7 @@ class RailEnv(Environment): if transition_isValid is None: transition_isValid = self.rail.get_transition( (pos[0], pos[1], direction), - movement) or is_deadend + movement) cell_isFree = True for j in range(self.number_of_agents): diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index 21c1e64b8ce23b70d21dae032eb75f410dbddcca..543b793c5f7a0e0a6a9e07c46902735ec21c7158 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -31,10 +31,10 @@ class EditorMVC(object): if env is None: env = RailEnv(width=10, - height=10, - rail_generator=random_rail_generator(cell_type_relative_proportion=[1, 1] + [0.5] * 6), - number_of_agents=0, - obs_builder_object=TreeObsForRailEnv(max_depth=2)) + height=10, + rail_generator=random_rail_generator(cell_type_relative_proportion=[1, 1] + [0.5] * 6), + number_of_agents=0, + obs_builder_object=TreeObsForRailEnv(max_depth=2)) env.reset() @@ -113,20 +113,19 @@ class View(object): wButton = ipywidgets.Button(description=dButton["name"]) wButton.on_click(dButton["method"]) self.lwButtons.append(wButton) - + self.wVbox_controls = VBox([ self.wFilename, # self.wDrawMode, *self.lwButtons, self.wSize, self.wDebug, self.wDebug_move, - self.wProg_steps - ]) - + self.wProg_steps]) + self.wMain = HBox([self.wImage, self.wVbox_controls]) def drawStroke(self): pass - + def new_env(self): self.oRT = rt.RenderTool(self.editor.env) @@ -139,11 +138,11 @@ class View(object): img = self.oRT.getImage() plt.clf() plt.close() - + self.wImage.data = img self.writableData = np.copy(self.wImage.data) return img - + def redisplayImage(self): if self.writableData is not None: # This updates the image in the browser to be the new edited version @@ -152,7 +151,7 @@ class View(object): def drag_path_element(self, x, y): # Draw a black square on the in-memory copy of the image if x > 10 and x < self.yxSize[1] and y > 10 and y < self.yxSize[0]: - self.writableData[y-2:y+2, x-2:x+2, :] = 0 + self.writableData[(y - 2):(y + 2), (x - 2):(x + 2), :] = 0 def xy_to_rc(self, x, y): rcCell = ((array([y, x]) - self.yxBase) / self.nPixCell).astype(int) @@ -227,7 +226,7 @@ class Controller(object): # The intention was to avoid too many redraws. if event["buttons"] > 0: qEvents.append((time.time(), x, y)) - + # Process the events in our queue: # Draw a black square to indicate a trail # TODO: infer a vector of moves between these squares to avoid gaps @@ -238,13 +237,13 @@ class Controller(object): if tNow - qEvents[0][0] > 0.1: # wait before trying to draw # height, width = wid.data.shape[:2] # writableData = np.copy(self.wid_img.data) # writable copy of image - wid_img.data is somehow readonly - + # with self.wid_img.hold_sync(): - + while len(qEvents) > 0: t, x, y = qEvents.popleft() # get events from our queue self.view.drag_path_element(x, y) - + # Translate and scale from x,y to integer row,col (note order change) # rcCell = ((array([y, x]) - self.yxBase) / self.nPixCell).astype(int) rcCell = self.view.xy_to_rc(x, y) @@ -266,13 +265,13 @@ class Controller(object): def refresh(self, event): self.debug("refresh") self.view.redraw() - + def clear(self, event): self.model.clear() def regenerate(self, event): self.model.regenerate() - + def setRegenSize(self, event): self.model.setRegenSize(event["new"]) @@ -342,14 +341,14 @@ class EditorModel(object): """Mouse motion event handler for drawing. """ lrcStroke = self.lrcStroke - + # Store the row,col location of the click, if we have entered a new cell if len(lrcStroke) > 0: rcLast = lrcStroke[-1] if not np.array_equal(rcLast, rcCell): # only save at transition lrcStroke.append(rcCell) self.debug("lrcStroke ", len(lrcStroke), rcCell) - + else: # This is the first cell in a mouse stroke lrcStroke.append(rcCell) @@ -427,14 +426,16 @@ class EditorModel(object): # Set the transition # If this transition spans 3 cells, it is not a deadend, so remove any deadends. # The user will need to resolve any conflicts. - self.env.rail.set_transition((*rcMiddle, liTrans[0]), liTrans[1], bAddRemove, - remove_deadends=not bDeadend) + self.env.rail.set_transition((*rcMiddle, liTrans[0]), + liTrans[1], + bAddRemove, + remove_deadends=not bDeadend) # Also set the reverse transition # use the reversed outbound transition for inbound # and the reversed inbound transition for outbound self.env.rail.set_transition((*rcMiddle, mirror(liTrans[1])), - mirror(liTrans[0]), bAddRemove, remove_deadends=not bDeadend) + mirror(liTrans[0]), bAddRemove, remove_deadends=not bDeadend) # bValid = self.env.rail.is_cell_valid(rcMiddle) # if not bValid: @@ -457,7 +458,7 @@ class EditorModel(object): # get the row, col delta between the 2 cells, eg [-1,0] = North rc1Trans = np.diff(rc2Cells, axis=0) - + # get the direction index for the transition liTrans = [] for rcTrans in rc1Trans: @@ -491,7 +492,7 @@ class EditorModel(object): self.env.agents_handles = [] self.env.agents_target = [] self.player = None - + self.redraw() def setFilename(self, filename): @@ -507,7 +508,7 @@ class EditorModel(object): self.redraw() else: self.log("File does not exist:", self.env_filename, " Working directory: ", os.getcwd()) - + def save(self): self.log("save to ", self.env_filename, " working dir: ", os.getcwd()) self.env.rail.save_transition_map(self.env_filename) @@ -515,17 +516,17 @@ class EditorModel(object): def regenerate(self): self.log("Regenerate size", self.regen_size) self.env = RailEnv(width=self.regen_size, - height=self.regen_size, - rail_generator=random_rail_generator(cell_type_relative_proportion=[1, 1] + [0.5] * 6), - number_of_agents=self.env.number_of_agents, - obs_builder_object=TreeObsForRailEnv(max_depth=2)) + height=self.regen_size, + rail_generator=random_rail_generator(cell_type_relative_proportion=[1, 1] + [0.5] * 6), + number_of_agents=self.env.number_of_agents, + obs_builder_object=TreeObsForRailEnv(max_depth=2)) self.env.reset(regen_rail=True) self.fix_env() self.set_env(self.env) self.player = Player(self.env) self.view.new_env() self.redraw() - + def setRegenSize(self, size): self.regen_size = size @@ -579,5 +580,9 @@ class EditorModel(object): def debug_cell(self, rcCell): binTrans = self.env.rail.get_transitions(rcCell) sbinTrans = format(binTrans, "#018b")[2:] - self.debug("cell ", rcCell, "Transitions: ", binTrans, sbinTrans, - [sbinTrans[i:i+4] for i in range(0, len(sbinTrans), 4)]) + self.debug("cell ", + rcCell, + "Transitions: ", + binTrans, + sbinTrans, + [sbinTrans[i:(i + 4)] for i in range(0, len(sbinTrans), 4)]) diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py index 05c436dc9f8723abdf2c013b9fad5181e695b58f..0804c9dd27b191b54f796b180165e1d73180ec16 100644 --- a/flatland/utils/render_qt.py +++ b/flatland/utils/render_qt.py @@ -61,8 +61,8 @@ class QTGL(GraphicsLayer): if lastx is not None: # print("line", lastx, lasty, x, y) self.qtr.drawLine( - lastx*self.cell_pixels, -lasty*self.cell_pixels, - x*self.cell_pixels, -y*self.cell_pixels) + lastx * self.cell_pixels, -lasty * self.cell_pixels, + x * self.cell_pixels, -y * self.cell_pixels) lastx = x lasty = y else: diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 09ac48997bcb1c420be146ee011042c45b6b597f..e84ff46bb7fa7590bc043a3bc9e3c109670a8879 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -145,7 +145,7 @@ class RenderTool(object): self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color) def plotAgents(self, targets=True): - cmap = self.gl.get_cmap('hsv', lut=self.env.number_of_agents+1) + cmap = self.gl.get_cmap('hsv', lut=self.env.number_of_agents + 1) for iAgent in range(self.env.number_of_agents): oColor = cmap(iAgent) @@ -208,16 +208,16 @@ class RenderTool(object): xyDir = np.matmul(rcDir, rt.grc2xy) # agent direction in xy xyPos = np.matmul(rcPos - rcDir / 2, rt.grc2xy) + rt.xyHalf - #print("Agent:", rcPos, iDir, rcDir, xyDir, xyPos) + # print("Agent:", rcPos, iDir, rcDir, xyDir, xyPos) self.gl.scatter(*xyPos, color=color, marker="o", s=100) # agent location - xyDirLine = array([xyPos, xyPos + xyDir/2]).T # line for agent orient. + xyDirLine = array([xyPos, xyPos + xyDir / 2]).T # line for agent orient. self.gl.plot(*xyDirLine, color=color, lw=5, ms=0, alpha=0.6) if target is not None: rcTarget = array(target) xyTarget = np.matmul(rcTarget, rt.grc2xy) + rt.xyHalf - self._draw_square(xyTarget, 1/3, color) + self._draw_square(xyTarget, 1 / 3, color) def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None): """ @@ -587,7 +587,7 @@ class RenderTool(object): # Draw each agent + its orientation + its target if agents: - cmap = self.gl.get_cmap('hsv', lut=env.number_of_agents+1) + cmap = self.gl.get_cmap('hsv', lut=env.number_of_agents + 1) self.plotAgents(targets=True) if False: @@ -659,4 +659,3 @@ class RenderTool(object): def getImage(self): return self.gl.getImage() -