diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 5ece03e9c56d672b76a453e0036f6b89c3a6ee77..f07cba14f9e1221e2b1544892dc71c4c39765e9c 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -43,7 +43,8 @@ env = RailEnv(width=100, number_of_agents=100, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=GlobalObsForRailEnv(), - remove_agents_at_target=True + remove_agents_at_target=True, + record_steps=True ) # RailEnv.DEPOT_POSITION = lambda agent, agent_handle : (agent_handle % env.height,0) @@ -132,3 +133,4 @@ for step in range(500): break print('Episode: Steps {}\t Score = {}'.format(step, score)) +env.save_episode("saved_episode_2.mpk") diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 0dd58813b5568ba95e11f984d15f6bd256100c0f..887616d3a1d8003ffe53db9a3f22a580c2bc7f27 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -120,7 +120,8 @@ class RailEnv(Environment): obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(), stochastic_data=None, remove_agents_at_target=True, - random_seed=1 + random_seed=1, + record_steps=False ): """ Environment init. @@ -217,6 +218,10 @@ class RailEnv(Environment): # global numpy array of agents position, -1 means that the cell is free, otherwise the agent handle is placed # inside the cell self.agent_positions: np.ndarray = np.zeros((height, width), dtype=int) - 1 + + # save episode timesteps ie agent positions, orientations. 
(not yet actions / observations) + self.record_steps = record_steps # whether to save timesteps + self.cur_episode = [] # save timesteps in here def _seed(self, seed=None): self.np_random, seed = seeding.np_random(seed) @@ -505,6 +510,8 @@ class RailEnv(Environment): self.dones["__all__"] = True for i_agent in range(self.get_num_agents()): self.dones[i_agent] = True + if self.record_steps: + self.record_timestep() return self._get_observations(), self.rewards_dict, self.dones, info_dict @@ -909,6 +916,13 @@ class RailEnv(Environment): with open(filename, "wb") as file_out: file_out.write(self.get_full_state_msg()) + def save_episode(self, filename): + """ Write the timesteps collected in self.cur_episode to filename as a msgpack dict {"episode": [...]} """ + dict_data = {"episode": self.cur_episode} + # serialise exactly once, and before the file is opened (and truncated) + episode_msg = msgpack.packb(dict_data) + with open(filename, "wb") as file_out: + file_out.write(episode_msg) def load(self, filename): """ Load environment with distance map from a file diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index 86d2e2495fb4e39622b167a01d47f784d5190ab6..8b7a57e899a24523135e828b42359627d350822f 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -133,6 +133,7 @@ class PILGL(GraphicsLayer): return self.agent_colors[iAgent % self.n_agent_colors] def plot(self, gX, gY, color=None, linewidth=3, layer=RAIL_LAYER, opacity=255, **kwargs): + """ Draw a line joining the points in gX, gY - each an array-like of coordinates """ color = self.adapt_color(color) if len(color) == 3: color += (opacity,) @@ -140,7 +141,8 @@ class PILGL(GraphicsLayer): color = color[:3] + (opacity,) gPoints = np.stack([array(gX), -array(gY)]).T * self.nPixCell gPoints = list(gPoints.ravel()) - self.draws[layer].line(gPoints, fill=color, width=self.linewidth) + # the width here was self.linewidth - not really sure of the implications + self.draws[layer].line(gPoints, fill=color, width=linewidth) def scatter(self, gX, gY, color=None, marker="o", s=50, 
layer=RAIL_LAYER, opacity=255, *args, **kwargs): color = self.adapt_color(color) @@ -540,9 +542,8 @@ class PILSVG(PILGL): if (col + row + col * row) % 3 == 0: a = (a + (col + row + col * row)) % len(self.dBuildings) pil_track = self.dBuildings[a] - elif (self.background_grid[col][row] > 5 + ((col * row + col) % 3)) or ( - (col ** 3 + row ** 2 + col * row) % - 10 == 0): + elif ((self.background_grid[col][row] > 5 + ((col * row + col) % 3)) or + ((col ** 3 + row ** 2 + col * row) % 10 == 0)): a = int(self.background_grid[col][row]) - 4 a2 = (a + (col + row + col * row + col ** 3 + row ** 4)) if a2 % 64 > 11: @@ -635,7 +636,7 @@ class PILSVG(PILGL): self.pil_zug[(in_direction_2, out_direction_2, color_idx)] = pils[color_idx] def set_agent_at(self, agent_idx, row, col, in_direction, out_direction, is_selected, - rail_grid=None, show_debug=False, clear_debug_text=True): + rail_grid=None, show_debug=False, clear_debug_text=True, malfunction=False): delta_dir = (out_direction - in_direction) % 4 color_idx = agent_idx % self.n_agent_colors # when flipping direction at a dead end, use the "out_direction" direction. 
@@ -671,11 +672,23 @@ class PILSVG(PILGL): self.text_rowcol((row + dr, col + dc,), str(agent_idx), layer=PILGL.SELECTED_AGENT_LAYER) else: self.text_rowcol((row + 0.2, col + 0.2,), str(agent_idx)) + if malfunction: + self.draw_malfunction(agent_idx, (row, col)) def set_cell_occupied(self, agent_idx, row, col): occupied_im = self.cell_occupied[agent_idx % len(self.cell_occupied)] self.draw_image_row_col(occupied_im, (row, col), 1) + def draw_malfunction(self, agent_idx, rcTopLeft): + """ Draw a rough "X" over the cell whose top-left corner is rcTopLeft = (row, col); agent_idx is not used """ + # one polyline through the corners: both diagonals, joined along the row-plus-one edge + grcOffsets = np.array([[0, 0], [1, 1], [1, 0], [0, 1]]) + grcPoints = np.array(rcTopLeft)[None] + grcOffsets + gxyPoints = grcPoints[:, [1, 0]] + gxPoints, gyPoints = gxyPoints.T + # PILGL.plot negates gY itself, so hand it -gyPoints to end up at row * nPixCell in pixels + self.plot(gxPoints, -gyPoints, color=(0, 0, 0, 255), layer=PILGL.AGENT_LAYER, linewidth=2) + def main2(): gl = PILSVG(10, 10) diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index cc496cb94cd2ba0d927749bf813cd449bd70e236..65672a11230a5df52ee05c60537ddcaf6cdbe46c 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -414,7 +414,8 @@ class RenderTool(object): self.render_env_svg(show=show, show_observations=show_observations, show_predictions=show_predictions, - selected_agent=selected_agent + selected_agent=selected_agent, + show_agents=agents ) else: self.render_env_pil(show=show, @@ -497,7 +498,8 @@ class RenderTool(object): return def render_env_svg( - self, show=False, show_observations=True, show_predictions=False, selected_agent=None + self, show=False, show_observations=True, show_predictions=False, selected_agent=None, + show_agents=True ): """ Renders the environment with SVG support (nice image) @@ -537,49 +539,54 @@ class RenderTool(object): self.gl.build_background_map(targets) - for agent_idx, agent in enumerate(self.env.agents): + if 
show_agents: + for agent_idx, agent in enumerate(self.env.agents): - if agent is None or agent.position is None: - continue + if agent is None or agent.position is None: + continue - if self.agent_render_variant == AgentRenderVariant.BOX_ONLY: - self.gl.set_cell_occupied(agent_idx, *(agent.position)) - elif self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND or \ - self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX: # noqa: E125 - if agent.old_position is not None: - position = agent.old_position - direction = agent.direction - old_direction = agent.old_direction + is_malfunction = agent.malfunction_data["malfunction"] > 0 + + if self.agent_render_variant == AgentRenderVariant.BOX_ONLY: + self.gl.set_cell_occupied(agent_idx, *(agent.position)) + elif self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND or \ + self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX: # noqa: E125 + if agent.old_position is not None: + position = agent.old_position + direction = agent.direction + old_direction = agent.old_direction + else: + position = agent.position + direction = agent.direction + old_direction = agent.direction + + # set_agent_at uses the agent index for the color + if self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX: + self.gl.set_cell_occupied(agent_idx, *(agent.position)) + self.gl.set_agent_at(agent_idx, *position, old_direction, direction, + selected_agent == agent_idx, rail_grid=env.rail.grid, + show_debug=self.show_debug, clear_debug_text=self.clear_debug_text, + malfunction=is_malfunction) else: position = agent.position direction = agent.direction - old_direction = agent.direction - - # set_agent_at uses the agent index for the color - if self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX: - self.gl.set_cell_occupied(agent_idx, *(agent.position)) - self.gl.set_agent_at(agent_idx, *position, old_direction, direction, - selected_agent == agent_idx, 
rail_grid=env.rail.grid, - show_debug=self.show_debug, clear_debug_text=self.clear_debug_text) - else: - position = agent.position - direction = agent.direction - for possible_directions in range(4): - # Is a transition along movement `desired_movement_from_new_cell` to the current cell possible? - isValid = env.rail.get_transition((*agent.position, agent.direction), possible_directions) - if isValid: - direction = possible_directions - - # set_agent_at uses the agent index for the color - self.gl.set_agent_at(agent_idx, *position, agent.direction, direction, - selected_agent == agent_idx, rail_grid=env.rail.grid, - show_debug=self.show_debug, clear_debug_text=self.clear_debug_text) - - # set_agent_at uses the agent index for the color - if self.agent_render_variant == AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX: - self.gl.set_cell_occupied(agent_idx, *(agent.position)) - self.gl.set_agent_at(agent_idx, *position, agent.direction, direction, selected_agent == agent_idx, - rail_grid=env.rail.grid) + for possible_direction in range(4): + # Is a transition along movement `desired_movement_from_new_cell` to the current cell possible? 
+ isValid = env.rail.get_transition((*agent.position, agent.direction), possible_direction) + if isValid: + direction = possible_direction + + # set_agent_at uses the agent index for the color + self.gl.set_agent_at(agent_idx, *position, agent.direction, direction, + selected_agent == agent_idx, rail_grid=env.rail.grid, + show_debug=self.show_debug, clear_debug_text=self.clear_debug_text, + malfunction=is_malfunction) + + # set_agent_at uses the agent index for the color + if self.agent_render_variant == AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX: + self.gl.set_cell_occupied(agent_idx, *(agent.position)) + self.gl.set_agent_at(agent_idx, *position, agent.direction, direction, selected_agent == agent_idx, + rail_grid=env.rail.grid, malfunction=is_malfunction) if show_observations: self.render_observation(range(env.get_num_agents()), env.dev_obs_dict)