diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 6778fe8d4f081efa5baa518ecfb73a37a9878837..3d97119e7f7289008525265d412ae5255dbafc69 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -20,10 +20,18 @@ from flatland.core.grid.grid_utils import IntVector2D from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgent, RailAgentStatus from flatland.envs.distance_map import DistanceMap -from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData +from flatland.envs import malfunction_generators as mal_gen +from flatland.envs import rail_generators as rail_gen +from flatland.envs import schedule_generators as sched_gen +# Direct import of objects / classes does not work with circular imports. +# from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData +# from flatland.envs.observations import GlobalObsForRailEnv +# from flatland.envs.rail_generators import random_rail_generator, RailGenerator +# from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator from flatland.envs.observations import GlobalObsForRailEnv -from flatland.envs.rail_generators import random_rail_generator, RailGenerator -from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator + + +import pickle m.patch() @@ -116,11 +124,11 @@ class RailEnv(Environment): def __init__(self, width, height, - rail_generator: RailGenerator = random_rail_generator(), - schedule_generator: ScheduleGenerator = random_schedule_generator(), + rail_generator: rail_gen.RailGenerator = rail_gen.random_rail_generator(), + schedule_generator: sched_gen.ScheduleGenerator = sched_gen.random_schedule_generator(), number_of_agents=1, obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(), - malfunction_generator_and_process_data=no_malfunction_generator(), + malfunction_generator_and_process_data=None, #mal_gen.no_malfunction_generator(), remove_agents_at_target=True, random_seed=1, record_steps=False @@ -162,6 +170,8 @@ class RailEnv(Environment): """ super().__init__() + if malfunction_generator_and_process_data is None: + malfunction_generator_and_process_data = mal_gen.no_malfunction_generator() self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data self.rail_generator: RailGenerator = rail_generator self.schedule_generator: ScheduleGenerator = schedule_generator @@ -206,7 +216,9 @@ class RailEnv(Environment): # save episode timesteps ie agent positions, orientations. (not yet actions / observations) self.record_steps = record_steps # whether to save timesteps - self.cur_episode = [] # save timesteps in here + # save timesteps in here: [[[row, col, dir, malfunction],...nAgents], ...nSteps] + self.cur_episode = [] + self.list_actions = [] # save actions in here def _seed(self, seed=None): self.np_random, seed = seeding.np_random(seed) @@ -346,6 +358,9 @@ class RailEnv(Environment): # Reset the malfunction generator self.malfunction_generator(reset=True) + # Empty the episode store of agent positions + self.cur_episode = [] + info_dict: Dict = { 'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)}, 'malfunction': { @@ -470,7 +485,7 @@ class RailEnv(Environment): for i_agent in range(self.get_num_agents()): self.dones[i_agent] = True if self.record_steps: - self.record_timestep() + self.record_timestep(action_dict_) return self._get_observations(), self.rewards_dict, self.dones, info_dict @@ -694,7 +709,7 @@ class RailEnv(Environment): cell_free = False return cell_free, new_cell_valid, new_direction, new_position, transition_valid - def record_timestep(self): + def record_timestep(self, dActions): ''' Record the positions and orientations of all agents in memory, in the cur_episode ''' list_agents_state = [] @@ -707,8 +722,11 @@ class RailEnv(Environment): else: pos = (int(agent.position[0]), int(agent.position[1])) # print("pos:", pos, type(pos[0])) - list_agents_state.append([*pos, int(agent.direction)]) + list_agents_state.append( + [*pos, int(agent.direction), agent.malfunction_data["malfunction"] ]) + self.cur_episode.append(list_agents_state) + self.list_actions.append(dActions) def cell_free(self, position: IntVector2D) -> bool: """ @@ -792,155 +810,7 @@ class RailEnv(Environment): """ return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col)) - def get_full_state_msg(self) -> Packer: - """ - Returns state of environment in msgpack object - """ - grid_data = self.rail.grid.tolist() - agent_data = [agent.to_agent() for agent in self.agents] - malfunction_data: MalfunctionProcessData = self.malfunction_process_data - msgpack.packb(grid_data, use_bin_type=True) - msgpack.packb(agent_data, use_bin_type=True) - msg_data = { - "grid": grid_data, - "agents": agent_data, - "malfunction": malfunction_data, - "max_episode_steps": self._max_episode_steps} - return msgpack.packb(msg_data, use_bin_type=True) - - def get_agent_state_msg(self) -> Packer: - """ - Returns agents information in msgpack object - """ - agent_data = [agent.to_agent() for agent in self.agents] - msg_data = { - "agents": agent_data} - return msgpack.packb(msg_data, use_bin_type=True) - - def get_full_state_dist_msg(self) -> Packer: - """ - Returns environment information with distance map information as msgpack object - """ - grid_data = self.rail.grid.tolist() - agent_data = [agent.to_agent() for agent in self.agents] - msgpack.packb(grid_data, use_bin_type=True) - msgpack.packb(agent_data, use_bin_type=True) - distance_map_data = self.distance_map.get() - malfunction_data: MalfunctionProcessData = self.malfunction_process_data - msgpack.packb(distance_map_data, use_bin_type=True) - msg_data = { - "grid": grid_data, - "agents": agent_data, - "distance_map": distance_map_data, - "malfunction": malfunction_data, - "max_episode_steps": self._max_episode_steps} - return msgpack.packb(msg_data, use_bin_type=True) - - def set_full_state_msg(self, msg_data): - """ - Sets environment state with msgdata object passed as argument - Parameters - ------- - msg_data: msgpack object - """ - data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') - self.rail.grid = np.array(data["grid"]) - # agents are always reset as not moving - if "agents_static" in data: - self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"]) - else: - self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]] - # setup with loaded data - self.height, self.width = self.rail.grid.shape - self.rail.height = self.height - self.rail.width = self.width - self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) - - def set_full_state_dist_msg(self, msg_data): - """ - Sets environment grid state and distance map with msgdata object passed as argument - - Parameters - ------- - msg_data: msgpack object - """ - data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') - self.rail.grid = np.array(data["grid"]) - # agents are always reset as not moving - if "agents_static" in data: - self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"]) - else: - self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]] - if "distance_map" in data.keys(): - self.distance_map.set(data["distance_map"]) - # setup with loaded data - self.height, self.width = self.rail.grid.shape - self.rail.height = self.height - self.rail.width = self.width - self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) - - def save(self, filename, save_distance_maps=False): - """ - Saves environment and distance map information in a file - - Parameters: - --------- - filename: string - save_distance_maps: bool - """ - if save_distance_maps is True: - if self.distance_map.get() is not None: - if len(self.distance_map.get()) > 0: - with open(filename, "wb") as file_out: - file_out.write(self.get_full_state_dist_msg()) - else: - print("[WARNING] Unable to save the distance map for this environment, as none was found !") - - else: - print("[WARNING] Unable to save the distance map for this environment, as none was found !") - - else: - with open(filename, "wb") as file_out: - file_out.write(self.get_full_state_msg()) - - def save_episode(self, filename): - episode_data = self.cur_episode - msgpack.packb(episode_data, use_bin_type=True) - dict_data = {"episode": episode_data} - # msgpack.packb(msg_data, use_bin_type=True) - with open(filename, "wb") as file_out: - file_out.write(msgpack.packb(dict_data)) - - def load(self, filename): - """ - Load environment with distance map from a file - - Parameters: - ------- - filename: string - """ - with open(filename, "rb") as file_in: - load_data = file_in.read() - self.set_full_state_dist_msg(load_data) - - def load_pkl(self, pkl_data): - """ - Load environment with distance map from a pickle file - - Parameters: - ------- - pkl_data: pickle file - """ - self.set_full_state_msg(pkl_data) - - def load_resource(self, package, resource): - """ - Load environment with distance map from a binary - """ - from importlib_resources import read_binary - load_data = read_binary(package, resource) - self.set_full_state_msg(load_data) def _exp_distirbution_synced(self, rate: float) -> float: """ diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py index 823fe4f11555298c6ab854c40320dfc9d37d9f4d..9143a86334401889fbd9d2494b2f97a8e6ef5435 100644 --- a/flatland/envs/rail_env_utils.py +++ b/flatland/envs/rail_env_utils.py @@ -8,7 +8,9 @@ from flatland.envs.schedule_generators import schedule_from_file def load_flatland_environment_from_file(file_name: str, load_from_package: str = None, - obs_builder_object: ObservationBuilder = None) -> RailEnv: + obs_builder_object: ObservationBuilder = None, + record_steps = False, + ) -> RailEnv: """ Parameters ---------- @@ -31,6 +33,9 @@ def load_flatland_environment_from_file(file_name: str, max_depth=2, predictor=ShortestPathPredictorForRailEnv(max_depth=10)) environment = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name, load_from_package), - schedule_generator=schedule_from_file(file_name, load_from_package), number_of_agents=1, - obs_builder_object=obs_builder_object) + schedule_generator=schedule_from_file(file_name, load_from_package), + number_of_agents=1, + obs_builder_object=obs_builder_object, + record_steps=record_steps, + ) return environment diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index cb8569643698f68612720a175192af885a222d6d..1a73acb93908fda477be25637631a465df8b9183 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -3,7 +3,6 @@ import sys import warnings from typing import Callable, Tuple, Optional, Dict, List -import msgpack import numpy as np from numpy.random.mtrand import RandomState @@ -16,6 +15,7 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, connect_straight_line_in_grid_map, \ fix_inner_nodes, align_cell_to_city +from flatland.envs import persistence RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]] RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct] @@ -240,21 +240,15 @@ def rail_from_file(filename, load_from_package=None) -> RailGenerator: """ def generator(width: int, height: int, num_agents: int, num_resets: int = 0, - np_random: RandomState = None) -> RailGenerator: + np_random: RandomState = None) -> List: + env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package) rail_env_transitions = RailEnvTransitions() - if load_from_package is not None: - from importlib_resources import read_binary - load_data = read_binary(load_from_package, filename) - else: - with open(filename, "rb") as file_in: - load_data = file_in.read() - data = msgpack.unpackb(load_data, use_list=False) - grid = np.array(data[b"grid"]) + grid = np.array(env_dict["grid"]) rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions) rail.grid = grid - if b"distance_map" in data.keys(): - distance_map = data[b"distance_map"] + if "distance_map" in env_dict: + distance_map = env_dict["distance_map"] if len(distance_map) > 0: return rail, {'distance_map': distance_map} return [rail, None] @@ -634,7 +628,8 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ max_feasible_cities = min(max_num_cities, ((height - 2) // (2 * (city_radius + 1))) * ((width - 2) // (2 * (city_radius + 1)))) if max_feasible_cities < 2: - sys.exit("[ABORT] Cannot fit more than one city in this map, no feasible environment possible! Aborting.") + # sys.exit("[ABORT] Cannot fit more than one city in this map, no feasible environment possible! Aborting.") + raise ValueError("ERROR: Cannot fit more than one city in this map, no feasible environment possible!") # Evenly distribute cities if grid_mode: diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index a84264e04fa2039ea09f9a53df52ab692800ff30..d04047495b8ebf79a82faa472a5608f6da2a0ac2 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -2,7 +2,6 @@ import warnings from typing import Tuple, List, Callable, Mapping, Optional, Any -import msgpack import numpy as np from numpy.random.mtrand import RandomState @@ -10,6 +9,7 @@ from flatland.core.grid.grid4_utils import get_new_position from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgent from flatland.envs.schedule_utils import Schedule +from flatland.envs import persistence AgentPosition = Tuple[int, int] ScheduleGenerator = Callable[[GridTransitionMap, int, Optional[Any], Optional[int]], Schedule] @@ -299,22 +299,24 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator: def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0, np_random: RandomState = None) -> Schedule: - if load_from_package is not None: - from importlib_resources import read_binary - load_data = read_binary(load_from_package, filename) - else: - with open(filename, "rb") as file_in: - load_data = file_in.read() - data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8') - if "agents_static" in data: - agents = EnvAgent.load_legacy_static_agent(data["agents_static"]) - else: - agents = [EnvAgent(*d[0:12]) for d in data["agents"]] - if "max_episode_steps" in data: - max_episode_steps = data["max_episode_steps"] - else: - # If no max time was found return 0. - max_episode_steps = 0 + # if load_from_package is not None: + # from importlib_resources import read_binary + # load_data = read_binary(load_from_package, filename) + # else: + # with open(filename, "rb") as file_in: + # load_data = file_in.read() + # data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8') + # if "agents_static" in data: + # agents = EnvAgent.load_legacy_static_agent(data["agents_static"]) + # else: + # agents = [EnvAgent(*d[0:12]) for d in data["agents"]] + + env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package) + + max_episode_steps = env_dict.get("max_episode_steps", 0) + + agents = env_dict["agents"] + # setup with loaded data agents_position = [a.initial_position for a in agents] agents_direction = [a.direction for a in agents] diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index cb779b6fbb0fec0a0bed3b36d5cb5b7358d18925..1438a3ebea1738cf1feedb5a5ff38bcd308dee0e 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -20,7 +20,7 @@ class EditorMVC(object): """ EditorMVC - a class to encompass and assemble the Jupyter Editor Model-View-Controller. """ - def __init__(self, env=None, sGL="PIL"): + def __init__(self, env=None, sGL="PIL", env_filename="temp.mpk"): """ Create an Editor MVC assembly around a railenv, or create one if None. """ if env is None: @@ -29,7 +29,7 @@ class EditorMVC(object): env.reset() - self.editor = EditorModel(env) + self.editor = EditorModel(env, env_filename=env_filename) self.editor.view = self.view = View(self.editor, sGL=sGL) self.view.controller = self.editor.controller = self.controller = Controller(self.editor, self.view) self.view.init_canvas() @@ -40,9 +40,10 @@ class View(object): """ The Jupyter Editor View - creates and holds the widgets comprising the Editor. """ - def __init__(self, editor, sGL="MPL"): + def __init__(self, editor, sGL="MPL", screen_width=1200, screen_height=1200): self.editor = self.model = editor self.sGL = sGL + self.xyScreen = (screen_width, screen_height) def display(self): self.output_generator.clear_output() @@ -139,7 +140,8 @@ class View(object): def new_env(self): """ Tell the view to update its graphics when a new env is created. """ - self.oRT = rt.RenderTool(self.editor.env, gl=self.sGL) + self.oRT = rt.RenderTool(self.editor.env, gl=self.sGL, show_debug=True, + screen_height=self.xyScreen[1], screen_width=self.xyScreen[0]) def redraw(self): with self.output_generator: @@ -151,10 +153,12 @@ class View(object): if hasattr(a, 'old_direction') is False: a.old_direction = a.direction - self.oRT.render_env(agents=True, + self.oRT.render_env(show_agents=True, + show_inactive_agents=True, show=False, selected_agent=self.model.selected_agent, - show_observations=False) + show_observations=False, + ) img = self.oRT.get_image() self.wImage.data = img @@ -180,7 +184,9 @@ class View(object): nY = np.floor((self.yxSize[1] - self.yxBase[1]) / self.model.env.width) rc_cell[0] = max(0, min(np.floor(rc_cell[0] / nY), self.model.env.height - 1)) rc_cell[1] = max(0, min(np.floor(rc_cell[1] / nX), self.model.env.width - 1)) - return rc_cell + + # Using numpy arrays for coords not currently supported downstream in the env, observations, etc + return tuple(rc_cell) def log(self, *args, **kwargs): if self.output_generator: @@ -282,11 +288,14 @@ class Controller(object): else: self.lrcStroke = [] - if self.model.selected_agent is not None: - self.lrcStroke = [] - while len(q_events) > 0: - t, x, y = q_events.popleft() - return + # JW: I think this clause causes all editing to fail once an agent is selected. + # I also can't see why it's necessary. So I've if-falsed it out. + if False: + if self.model.selected_agent is not None: + self.lrcStroke = [] + while len(q_events) > 0: + t, x, y = q_events.popleft() + return # Process the events in our queue: # Draw a black square to indicate a trail @@ -330,7 +339,8 @@ class Controller(object): if agent is None: continue if agent_idx == self.model.selected_agent: - agent.direction = (agent.direction + 1) % 4 + agent.initial_direction = (agent.initial_direction + 1) % 4 + agent.direction = agent.initial_direction agent.old_direction = agent.direction self.model.redraw() @@ -373,7 +383,7 @@ class Controller(object): class EditorModel(object): - def __init__(self, env): + def __init__(self, env, env_filename="temp.mpk"): self.view = None self.env = env self.regen_size_width = 10 @@ -387,7 +397,7 @@ class EditorModel(object): self.debug_move_bool = False self.wid_output = None self.draw_mode = "Draw" - self.env_filename = "temp.pkl" + self.env_filename = env_filename self.set_env(env) self.selected_agent = None self.thread = None @@ -658,6 +668,7 @@ class EditorModel(object): self.env = env self.env.reset(regenerate_rail=True) self.fix_env() + self.selected_agent = None # clear the selected agent. self.set_env(self.env) self.view.new_env() self.redraw() @@ -670,7 +681,11 @@ class EditorModel(object): def find_agent_at(self, cell_row_col): for agent_idx, agent in enumerate(self.env.agents): - if tuple(agent.position) == tuple(cell_row_col): + if agent.position is None: + rc_pos = agent.initial_position + else: + rc_pos = agent.position + if tuple(rc_pos) == tuple(cell_row_col): return agent_idx return None @@ -685,18 +700,33 @@ class EditorModel(object): # Has the user clicked on an existing agent? agent_idx = self.find_agent_at(cell_row_col) + # This is in case we still have a selected agent even though the env has been recreated + # with no agents. + if (self.selected_agent is not None) and (self.selected_agent > len(self.env.agents)): + self.selected_agent = None + + # Defensive coding below - for cell_row_col to be a tuple, not a numpy array: + # numpy array breaks various things when loading the env. + if agent_idx is None: # No if self.selected_agent is None: # Create a new agent and select it. - agent = EnvAgent(position=cell_row_col, direction=0, target=cell_row_col, moving=False) + agent = EnvAgent(initial_position=tuple(cell_row_col), + initial_direction=0, + direction=0, + target=tuple(cell_row_col), + moving=False, + ) self.selected_agent = self.env.add_agent(agent) + # self.env.set_agent_active(agent) self.view.oRT.update_background() else: # Move the selected agent to this cell agent = self.env.agents[self.selected_agent] - agent.position = cell_row_col - agent.old_position = cell_row_col + agent.initial_position = tuple(cell_row_col) + agent.position = tuple(cell_row_col) + agent.old_position = tuple(cell_row_col) else: # Yes # Have they clicked on the agent already selected? @@ -711,7 +741,7 @@ class EditorModel(object): def add_target(self, rc_cell): if self.selected_agent is not None: - self.env.agents[self.selected_agent].target = rc_cell + self.env.agents[self.selected_agent].target = tuple(rc_cell) self.view.oRT.update_background() self.redraw() diff --git a/flatland/utils/flask_util.py b/flatland/utils/flask_util.py new file mode 100644 index 0000000000000000000000000000000000000000..e30fd72dfa30073f1404217257b20f8ee582c617 --- /dev/null +++ b/flatland/utils/flask_util.py @@ -0,0 +1,270 @@ + + +from flask import Flask, request, redirect, Response +from flask_socketio import SocketIO, emit +from flask_cors import CORS, cross_origin +import threading +import os +import time +import webbrowser +import numpy as np +import typing +import socket + +from flatland.envs.rail_env import RailEnv, RailEnvActions + + +#async_mode = None + + +class simple_flask_server(object): + """ I wanted to wrap the flask server in a class but this seems to be quite hard; + eg see: https://stackoverflow.com/questions/40460846/using-flask-inside-class + I have made a messy sort of singleton pattern. + It might be easier to revert to the "standard" flask global functions + decorators. + """ + + static_folder = os.path.join(os.getcwd(), "static") + print("Flask static folder: ", static_folder) + app = Flask(__name__, + static_url_path='', + static_folder=static_folder) + + socketio = SocketIO(app, cors_allowed_origins='*') + + # This is the original format for the I/O. + # It comes from the format used in the msgpack saved episode. + # The lists here are truncated from the original - see CK's original main.py, in flatland-render. + gridmap = [ + # list of rows (?). Each cell is a 16-char binary string. Yes this is inefficient! + ["0000000000000000", "0010000000000000", "0000000000000000", "0000000000000000", "0010000000000000", "0000000000000000", "0000000000000000", "0000000000000000", "0010000000000000", "0000000000000000"], + ["0000000000000000", "1000000000100000", "0000000000000000", "0000000000000000", "0000000001001000", "0001001000000000", "0010000000000000", "0000000000000000", "1000000000100000", "0000000000000000"], # ... + ] + agents_static = [ + # [initial position], initial direction, [target], 0 (?) + [[7, 9], 2, [3, 5], 0, + # Speed and malfunction params + {"position_fraction": 0, "speed": 1, "transition_action_on_cellexit": 3}, + {"malfunction": 0, "malfunction_rate": 0, "next_malfunction": 0, "nr_malfunctions": 0}], + [[8, 8], 1, [1, 6], 0, + {"position_fraction": 0, "speed": 1, "transition_action_on_cellexit": 2}, + {"malfunction": 0, "malfunction_rate": 0, "next_malfunction": 0, "nr_malfunctions": 0}], + [[3, 7], 2, [0, 1], 0, + {"position_fraction": 0, "speed": 1, "transition_action_on_cellexit": 2}, + {"malfunction": 0, "malfunction_rate": 0, "next_malfunction": 0, "nr_malfunctions": 0}] + ] + + # "actions" are not really actions, but [row, col, direction] for each agent, at each time step + # This format does not yet handle agents which are in states inactive or done_removed + actions= [ + [[7, 9, 2], [8, 8, 1], [3, 7, 2]], [[7, 9, 2], [8, 7, 3], [2, 7, 0]], # ... + ] + + def __init__(self, env): + # Some ugly stuff with cls and self here + cls = self.__class__ + cls.instance = self # intended as singleton + + cls.app.config['CORS_HEADERS'] = 'Content-Type' + cls.app.config['SECRET_KEY'] = 'secret!' + + self.app = cls.app + self.socketio = cls.socketio + self.env = env + self.renderer_ready = False # to indicate env background not yet drawn + self.port = None # we only assign a port once we start the background server... + self.host = None + + def run_flask_server(self, host='127.0.0.1', port=None): + self.host = host + + if port is None: + self.port = self._find_available_port(host) + else: + self.port = port + + self.socketio.run(simple_flask_server.app, host=host, port=self.port) + + def run_flask_server_in_thread(self, host="127.0.0.1", port=None): + # daemon=True so that this thread exits when the main / foreground thread exits, + # usually when the episode finishes. + self.thread = threading.Thread( + target=self.run_flask_server, + kwargs={"host": host, "port": port}, + daemon=True) + self.thread.start() + # short sleep to allow thread to start (may be unnnecessary) + time.sleep(1) + + def open_browser(self): + webbrowser.open("http://localhost:{}".format(self.port)) + # short sleep to allow browser to request the page etc (may be unnecessary) + time.sleep(1) + + def _test_listen_port(self, host: str, port: int): + oSock = socket.socket() + try: + oSock.bind((host, port)) + except OSError: + return False # The port is not available + + del oSock # This should release the port + return True # The port is available + + def _find_available_port(self, host: str, port_start: int = 8080): + for nPort in range(port_start, port_start+100): + if self._test_listen_port(host, nPort): + return nPort + print("Could not find an available port for Flask to listen on!") + return None + + def get_endpoint_url(self): + return "http://{}:{}".format(self.host, self.port) + + @app.route('/', methods=['GET']) + def home(): + # redirects from "/" to "/index.html" which is then served from static. + # print("Here - / - cwd:", os.getcwd()) + return redirect("index.html") + + @staticmethod + @socketio.on('connect') + def connected(): + ''' + When the JS Renderer connects, + this method will send the env and agent information + ''' + cls = simple_flask_server + print('Client connected') + + # Do we really need this? + cls.socketio.emit('message', {'message': 'Connected'}) + + print('Send Env grid and agents') + # cls.socketio.emit('grid', {'grid': cls.gridmap, 'agents_static': cls.agents_static}, broadcast=False) + cls.instance.send_env() + print("Env and agents sent") + + @staticmethod + @socketio.on('disconnect') + def disconnected(): + print('Client disconnected') + + def send_actions(self, dict_actions): + ''' Sends the agent positions and directions, not really actions. + ''' + llAgents = self.agents_to_list() + self.socketio.emit('agentsAction', {'actions': llAgents}) + + def send_observation(self, agent_handles, dict_obs): + """ Send an observation. + TODO: format observation message. + """ + self.socketio.emit("observation", {"agents": agent_handles, "observations": dict_obs}) + + def send_env(self): + """ Sends the env, ie the rail grid, and the agents (static) information + """ + # convert 2d array of int into 2d array of 16char strings + g2sGrid = np.vectorize(np.binary_repr)(self.env.rail.grid, width=16) + llGrid = g2sGrid.tolist() + llAgents = self.agents_to_list_dict() + self.socketio.emit('grid', { + 'grid': llGrid, + 'agents_static': llAgents + }, + broadcast=False) + + def send_env_and_wait(self): + for iAttempt in range(30): + if self.is_renderer_ready(): + print("Background Render complete") + break + else: + print("Waiting for browser to signal that rendering complete") + time.sleep(1) + + @staticmethod + @socketio.on('renderEvent') + def handle_render_event(data): + cls=simple_flask_server + self = cls.instance + print('RenderEvent!!') + print('status: ' + data['status']) + print('message: ' + data['message']) + + if data['status'] == 'listening': + self.renderer_ready = True + + def is_renderer_ready(self): + return self.renderer_ready + + def agents_to_list_dict(self): + ''' Create a list of lists / dicts for serialisation + Maps from the internal representation in EnvAgent to + the schema used by the Javascript renderer. + ''' + llAgents = [] + for agent in self.env.agents: + if agent.position is None: + # the int()s are to convert from numpy int64 which causes problems in serialization + # to plain old python int + lPos = [int(agent.initial_position[0]), int(agent.initial_position[1])] + else: + lPos = [int(agent.position[0]), int(agent.position[1])] + + lAgent = [ + lPos, + int(agent.direction), + [int(agent.target[0]), int(agent.target[1])], 0, + { # dummy values: + "position_fraction": 0, + "speed": 1, + "transition_action_on_cellexit": 3 + }, + { + "malfunction": 0, + "malfunction_rate": 0, + "next_malfunction": 0, + "nr_malfunctions": 0 + } + ] + llAgents.append(lAgent) + return llAgents + + def agents_to_list(self): + llAgents = [] + for agent in self.env.agents: + if agent.position is None: + lPos = [int(agent.initial_position[0]), int(agent.initial_position[1])] + else: + lPos = [int(agent.position[0]), int(agent.position[1])] + iDir = int(agent.direction) + + lAgent = [*lPos, iDir] + + llAgents.append(lAgent) + return llAgents + + + +def main1(): + + print('Run Flask SocketIO Server') + server = simple_flask_server() + threading.Thread(target=server.run_flask_server).start() + # Open Browser + webbrowser.open('http://127.0.0.1:8080') + + print('Send Action') + for i in server.actions: + time.sleep(1) + print('send action') + server.socketio.emit('agentsAction', {'actions': i}) + + + + + +if __name__ == "__main__": + main1() \ No newline at end of file diff --git a/flatland/utils/graphics_layer.py b/flatland/utils/graphics_layer.py index 2f51df2a1486adb6159a809b6f49e18bc8c873d9..e38a3694417a4cbaffce3a8ee26ac3fbd47a521d 100644 --- a/flatland/utils/graphics_layer.py +++ b/flatland/utils/graphics_layer.py @@ -25,6 +25,13 @@ class GraphicsLayer(object): pass def pause(self, seconds=0.00001): + """ deprecated """ + pass + + def idle(self, seconds=0.00001): + """ process any display events eg redraw, resize. + Return only after the given number of seconds, ie idle / loop until that number. + """ pass def clf(self): diff --git a/flatland/utils/graphics_pgl.py b/flatland/utils/graphics_pgl.py new file mode 100644 index 0000000000000000000000000000000000000000..299459cc5106566f78e94e0b961bfe13d3ee6a7c --- /dev/null +++ b/flatland/utils/graphics_pgl.py @@ -0,0 +1,152 @@ + +import pyglet as pgl +import time + +from PIL import Image +# from numpy import array +# from pkg_resources import resource_string as resource_bytes + +# from flatland.utils.graphics_layer import GraphicsLayer +from flatland.utils.graphics_pil import PILSVG + + +class PGLGL(PILSVG): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.window_open = False # means the window has not yet been opened. + self.close_requested = False # user has clicked + self.closed = False # windows has been closed (currently, we leave the env still running) + + def open_window(self): + print("open_window - pyglet") + assert self.window_open is False, "Window is already open!" + self.window = pgl.window.Window(resizable=True, vsync=False) + #self.__class__.window.title("Flatland") + #self.__class__.window.configure(background='grey') + self.window_open = True + + + @self.window.event + def on_draw(): + #print("pyglet draw event") + self.window.clear() + self.show(from_event=True) + #print("pyglet draw event done") + + + @self.window.event + def on_resize(width, height): + #print(f"The window was resized to {width}, {height}") + self.show(from_event=True) + self.window.dispatch_event("on_draw") + #print("pyglet resize event done") + + @self.window.event + def on_close(): + self.close_requested = True + + + def close_window(self): + self.window.close() + self.closed=True + + def show(self, block=False, from_event=False): + if not self.window_open: + self.open_window() + + if self.close_requested: + if not self.closed: + self.close_window() + return + + #tStart = time.time() + self._processEvents() + + pil_img = self.alpha_composite_layers() + pil_img_resized = pil_img.resize((self.window.width, self.window.height), resample=Image.NEAREST) + + # convert our PIL image to pyglet: + bytes_image = pil_img_resized.tobytes() + pgl_image = pgl.image.ImageData(pil_img_resized.width, pil_img_resized.height, + #self.window.width, self.window.height, + 'RGBA', + bytes_image, pitch=-pil_img_resized.width * 4) + + pgl_image.blit(0,0) + #tEnd = time.time() + #print("show time: ", tEnd - tStart) + + def _processEvents(self): + """ This is the replacement for a custom event loop for Pyglet. + The lines below are typical of Pyglet examples. + Manually resizing the window is still very clunky. + """ + #print("process events...", end="") + pgl.clock.tick() + #for window in pgl.app.windows: + if not self.closed: + self.window.switch_to() + self.window.dispatch_events() + self.window.flip() + #print(" events done") + + + + def idle(self, seconds=0.00001): + tStart = time.time() + tEnd = tStart + seconds + while (time.time() < tEnd): + self._processEvents() + #self.show() + time.sleep(min(seconds, 0.1)) + + +def test_pyglet(): + oGL = PGLGL(400,300) + time.sleep(2) + + +def test_event_loop(): + """ Shows how it should work with the standard event loop + Resizing is fairly smooth (ie runs at least 10-20x a second) + """ + + + window = pgl.window.Window(resizable=True) + pil_img = Image.open("notebooks/simple_example_3.png") + + def show(): + pil_img_resized = pil_img.resize((window.width, window.height), resample=Image.NEAREST) + bytes_image = pil_img_resized.tobytes() + pgl_image = pgl.image.ImageData(pil_img_resized.width, pil_img_resized.height, + #self.window.width, self.window.height, + 'RGBA', + bytes_image, pitch=-pil_img_resized.width * 4) + pgl_image.blit(0,0) + + @window.event + def on_draw(): + print("pyglet draw event") + window.clear() + show() + print("pyglet draw event done") + + + @window.event + def on_resize(width, height): + print(f"The window was resized to {width}, {height}") + #show() + print("pyglet resize event done") + + @window.event + def on_close(): + #self.close_requested = True + print("close") + + pgl.app.run() + + +if __name__=="__main__": + #test_pyglet() + test_event_loop() \ No newline at end of file diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index 0e464f03d5060b8d8db4462ed5d6c640a967b2e3..7c5d2fe3f3fe5da9812c8919897a08b115064019 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -1,10 +1,10 @@ import io import os import time -import tkinter as tk +#import tkinter as tk import numpy as np -from PIL import Image, ImageDraw, ImageTk, ImageFont +from PIL import Image, ImageDraw, ImageFont from numpy import array from pkg_resources import resource_string as resource_bytes @@ -32,7 +32,7 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions # noqa: E402 class PILGL(GraphicsLayer): # tk.Tk() must be a singleton! # https://stackoverflow.com/questions/26097811/image-pyimage2-doesnt-exist - window = tk.Tk() + # window = tk.Tk() RAIL_LAYER = 0 PREDICTION_PATH_LAYER = 1 @@ -85,7 +85,7 @@ class PILGL(GraphicsLayer): self.agent_colors = [self.rgb_s2i(sColor) for sColor in sColors.split("#")] self.n_agent_colors = len(self.agent_colors) - self.window_open = False + # self.window_open = False self.firstFrame = True self.old_background_image = (None, None, None) self.create_layers() @@ -164,15 +164,10 @@ class PILGL(GraphicsLayer): self.draw_image_xy(pil_img, xyPixLeftTop, layer=layer) def open_window(self): - assert self.window_open is False, "Window is already open!" - self.__class__.window.title("Flatland") - self.__class__.window.configure(background='grey') - self.window_open = True + pass def close_window(self): - self.panel.destroy() - # quit but not destroy! - self.__class__.window.quit() + pass def text(self, xPx, yPx, strText, layer=RAIL_LAYER): xyPixLeftTop = (xPx, yPx) @@ -194,28 +189,14 @@ class PILGL(GraphicsLayer): self.create_layer(iLayer=PILGL.PREDICTION_PATH_LAYER, clear=True) def show(self, block=False): - img = self.alpha_composite_layers() - - if not self.window_open: - self.open_window() - - tkimg = ImageTk.PhotoImage(img) - - if self.firstFrame: - # Do TK actions for a new panel (not sure what they really do) - self.panel = tk.Label(self.window, image=tkimg) - self.panel.pack(side="bottom", fill="both", expand="yes") - else: - # update the image in situ - self.panel.configure(image=tkimg) - self.panel.image = tkimg - - self.__class__.window.update() - self.firstFrame = False + print("show() - ", self.__class__) def pause(self, seconds=0.00001): pass + def idle(self, seconds=0.00001): + pass + def alpha_composite_layers(self): img = self.layers[0] for img2 in self.layers[1:]: @@ -316,7 +297,7 @@ class PILSVG(PILGL): return pil_img def load_buildings(self): - dBuildingFiles = [ + lBuildingFiles = [ "Buildings-Bank.svg", "Buildings-Bar.svg", "Buildings-Wohnhaus.svg", @@ -338,13 +319,17 @@ class PILSVG(PILGL): "Buildings-Fabrik_I.svg" ] - imgBg = self.pil_from_svg_file('svg', "Background_city.svg") + imgBg = self.pil_from_svg_file('flatland.svg', "Background_city.svg") + imgBg = imgBg.convert("RGBA") + #print("imgBg mode:", imgBg.mode) - self.dBuildings = [] - for sFile in dBuildingFiles: - img = self.pil_from_svg_file('svg', sFile) + self.lBuildings = [] + for sFile in lBuildingFiles: + #print("Loading:", sFile) + img = self.pil_from_svg_file('flatland.svg', sFile) + #print("img mode:", img.mode) img = Image.alpha_composite(imgBg, img) - self.dBuildings.append(img) + self.lBuildings.append(img) def load_scenery(self): scenery_files = [ @@ -371,31 +356,31 @@ class PILSVG(PILGL): "Scenery_Water.svg" ] - img_back_ground = self.pil_from_svg_file('svg', "Background_Light_green.svg") + img_back_ground = self.pil_from_svg_file('flatland.svg', "Background_Light_green.svg").convert("RGBA") - self.scenery_background_white = self.pil_from_svg_file('svg', "Background_white.svg") + self.scenery_background_white = self.pil_from_svg_file('flatland.svg', "Background_white.svg").convert("RGBA") self.scenery = [] for file in scenery_files: - img = self.pil_from_svg_file('svg', file) + img = self.pil_from_svg_file('flatland.svg', file) img = Image.alpha_composite(img_back_ground, img) self.scenery.append(img) self.scenery_d2 = [] for file in scenery_files_d2: - img = self.pil_from_svg_file('svg', file) + img = self.pil_from_svg_file('flatland.svg', file) img = Image.alpha_composite(img_back_ground, img) self.scenery_d2.append(img) self.scenery_d3 = [] for file in scenery_files_d3: - img = self.pil_from_svg_file('svg', file) + img = self.pil_from_svg_file('flatland.svg', file) img = Image.alpha_composite(img_back_ground, img) self.scenery_d3.append(img) self.scenery_water = [] for file in scenery_files_water: - img = self.pil_from_svg_file('svg', file) + img = self.pil_from_svg_file('flatland.svg', file) img = Image.alpha_composite(img_back_ground, img) self.scenery_water.append(img) @@ -448,10 +433,10 @@ class PILSVG(PILGL): whitefilter="Background_white_filter.svg") # Load station and recolorize them - station = self.pil_from_svg_file("svg", "Bahnhof_#d50000_target.svg") + station = self.pil_from_svg_file('flatland.svg', "Bahnhof_#d50000_target.svg") self.station_colors = self.recolor_image(station, [0, 0, 0], self.agent_colors, False) - cell_occupied = self.pil_from_svg_file("svg", "Cell_occupied.svg") + cell_occupied = self.pil_from_svg_file('flatland.svg', "Cell_occupied.svg") self.cell_occupied = self.recolor_image(cell_occupied, [0, 0, 0], self.agent_colors, False) # Merge them with the regular rails. @@ -480,14 +465,14 @@ class PILSVG(PILGL): transition_16_bit_string = "".join(transition_16_bit) binary_trans = int(transition_16_bit_string, 2) - pil_rail = self.pil_from_svg_file('svg', file) + pil_rail = self.pil_from_svg_file('flatland.svg', file).convert("RGBA") if background_image is not None: - img_bg = self.pil_from_svg_file('svg', background_image) + img_bg = self.pil_from_svg_file('flatland.svg', background_image).convert("RGBA") pil_rail = Image.alpha_composite(img_bg, pil_rail) if whitefilter is not None: - img_bg = self.pil_from_svg_file('svg', whitefilter) + img_bg = self.pil_from_svg_file('flatland.svg', whitefilter).convert("RGBA") pil_rail = Image.alpha_composite(pil_rail, img_bg) if rotate: @@ -535,13 +520,13 @@ class PILSVG(PILGL): if binary_trans == 0: if self.background_grid[col][row] <= 4 + np.ceil(((col * row + col) % 10) / city_size): a = int(self.background_grid[col][row]) - a = a % len(self.dBuildings) + a = a % len(self.lBuildings) if (col + row + col * row) % 13 > 11: pil_track = self.scenery[a % len(self.scenery)] else: if (col + row + col * row) % 3 == 0: - a = (a + (col + row + col * row)) % len(self.dBuildings) - pil_track = self.dBuildings[a] + a = (a + (col + row + col * row)) % len(self.lBuildings) + pil_track = self.lBuildings[a] elif ((self.background_grid[col][row] > 5 + ((col * row + col) % 3)) or ((col ** 3 + row ** 2 + col * row) % 10 == 0)): a = int(self.background_grid[col][row]) - 4 @@ -579,7 +564,7 @@ class PILSVG(PILGL): if target is not None: if is_selected: - svgBG = self.pil_from_svg_file("svg", "Selected_Target.svg") + svgBG = self.pil_from_svg_file('flatland.svg', "Selected_Target.svg") self.clear_layer(PILGL.SELECTED_TARGET_LAYER, 0) self.draw_image_row_col(svgBG, (row, col), layer=PILGL.SELECTED_TARGET_LAYER) @@ -619,7 +604,7 @@ class PILSVG(PILGL): for directions, path_svg in file_directory.items(): in_direction, out_direction = directions - pil_zug = self.pil_from_svg_file("svg", path_svg) + pil_zug = self.pil_from_svg_file('flatland.svg', path_svg) # Rotate both the directions and the image and save in the dict for rot_direction in range(4): @@ -649,7 +634,7 @@ class PILSVG(PILGL): self.draw_image_row_col(self.scenery_background_white, (row, col), layer=PILGL.RAIL_LAYER) if is_selected: - bg_svg = self.pil_from_svg_file("svg", "Selected_Agent.svg") + bg_svg = self.pil_from_svg_file('flatland.svg', "Selected_Agent.svg") self.clear_layer(PILGL.SELECTED_AGENT_LAYER, 0) self.draw_image_row_col(bg_svg, (row, col), layer=PILGL.SELECTED_AGENT_LAYER) if show_debug: diff --git a/flatland/utils/graphics_tkpil.py b/flatland/utils/graphics_tkpil.py new file mode 100644 index 0000000000000000000000000000000000000000..7e89e734a170e746c07a4779f4f12e2ff88cf42e --- /dev/null +++ b/flatland/utils/graphics_tkpil.py @@ -0,0 +1,52 @@ + +import tkinter as tk + +from PIL import ImageTk +# from numpy import array +# from pkg_resources import resource_string as resource_bytes + +# from flatland.utils.graphics_layer import GraphicsLayer +from flatland.utils.graphics_pil import PILSVG + + +class TKPILGL(PILSVG): + # tk.Tk() must be a singleton! + # https://stackoverflow.com/questions/26097811/image-pyimage2-doesnt-exist + window = tk.Tk() + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.window_open = False + + def open_window(self): + print("open_window - tk") + assert self.window_open is False, "Window is already open!" + self.__class__.window.title("Flatland") + self.__class__.window.configure(background='grey') + self.window_open = True + + def close_window(self): + self.panel.destroy() + # quit but not destroy! + self.__class__.window.quit() + + def show(self, block=False): + # print("show - ", self.__class__) + img = self.alpha_composite_layers() + + if not self.window_open: + self.open_window() + + tkimg = ImageTk.PhotoImage(img) + + if self.firstFrame: + # Do TK actions for a new panel (not sure what they really do) + self.panel = tk.Label(self.window, image=tkimg) + self.panel.pack(side="bottom", fill="both", expand="yes") + else: + # update the image in situ + self.panel.configure(image=tkimg) + self.panel.image = tkimg + + self.__class__.window.update() + self.firstFrame = False diff --git a/requirements_dev.txt b/requirements_dev.txt index b71ad94971c59c7a2ef18f038f9c498330a92b1a..e971cdd466203b5253e9700001ba0733dbf419c0 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -6,11 +6,10 @@ Click>=7.0 crowdai-api>=0.1.21 numpy>=1.16.2 recordtype>=1.3 -xarray>=0.11.3 matplotlib>=3.0.2 Pillow>=5.4.1 CairoSVG>=2.3.1 -msgpack>=0.6.1 +msgpack>=1.0.0 msgpack-numpy>=0.4.4.0 svgutils>=0.3.1 pyarrow>=0.13.0