diff --git a/examples/introduction_flatland_2_1.py b/examples/introduction_flatland_2_1.py index 3ec2d0d4cf2c63bd916b74d06c21a73d589c8c60..8ea267c18accf0a28f964bff5aecdbc70b50016e 100644 --- a/examples/introduction_flatland_2_1.py +++ b/examples/introduction_flatland_2_1.py @@ -63,7 +63,7 @@ schedule_generator = sparse_schedule_generator(speed_ration_map) # We can furthermore pass stochastic data to the RailEnv constructor which will allow for stochastic malfunctions # during an episode. -stochastic_data = MalfunctionParameters(malfunction_rate=10000, # Rate of malfunction occurence +stochastic_data = MalfunctionParameters(malfunction_rate=1/10000, # Rate of malfunction occurence min_duration=15, # Minimal duration of malfunction max_duration=50 # Max duration of malfunction ) diff --git a/flatland/utils/flask_util.py b/flatland/utils/flask_util.py deleted file mode 100644 index e30fd72dfa30073f1404217257b20f8ee582c617..0000000000000000000000000000000000000000 --- a/flatland/utils/flask_util.py +++ /dev/null @@ -1,270 +0,0 @@ - - -from flask import Flask, request, redirect, Response -from flask_socketio import SocketIO, emit -from flask_cors import CORS, cross_origin -import threading -import os -import time -import webbrowser -import numpy as np -import typing -import socket - -from flatland.envs.rail_env import RailEnv, RailEnvActions - - -#async_mode = None - - -class simple_flask_server(object): - """ I wanted to wrap the flask server in a class but this seems to be quite hard; - eg see: https://stackoverflow.com/questions/40460846/using-flask-inside-class - I have made a messy sort of singleton pattern. - It might be easier to revert to the "standard" flask global functions + decorators. - """ - - static_folder = os.path.join(os.getcwd(), "static") - print("Flask static folder: ", static_folder) - app = Flask(__name__, - static_url_path='', - static_folder=static_folder) - - socketio = SocketIO(app, cors_allowed_origins='*') - - # This is the original format for the I/O. - # It comes from the format used in the msgpack saved episode. - # The lists here are truncated from the original - see CK's original main.py, in flatland-render. - gridmap = [ - # list of rows (?). Each cell is a 16-char binary string. Yes this is inefficient! - ["0000000000000000", "0010000000000000", "0000000000000000", "0000000000000000", "0010000000000000", "0000000000000000", "0000000000000000", "0000000000000000", "0010000000000000", "0000000000000000"], - ["0000000000000000", "1000000000100000", "0000000000000000", "0000000000000000", "0000000001001000", "0001001000000000", "0010000000000000", "0000000000000000", "1000000000100000", "0000000000000000"], # ... - ] - agents_static = [ - # [initial position], initial direction, [target], 0 (?) - [[7, 9], 2, [3, 5], 0, - # Speed and malfunction params - {"position_fraction": 0, "speed": 1, "transition_action_on_cellexit": 3}, - {"malfunction": 0, "malfunction_rate": 0, "next_malfunction": 0, "nr_malfunctions": 0}], - [[8, 8], 1, [1, 6], 0, - {"position_fraction": 0, "speed": 1, "transition_action_on_cellexit": 2}, - {"malfunction": 0, "malfunction_rate": 0, "next_malfunction": 0, "nr_malfunctions": 0}], - [[3, 7], 2, [0, 1], 0, - {"position_fraction": 0, "speed": 1, "transition_action_on_cellexit": 2}, - {"malfunction": 0, "malfunction_rate": 0, "next_malfunction": 0, "nr_malfunctions": 0}] - ] - - # "actions" are not really actions, but [row, col, direction] for each agent, at each time step - # This format does not yet handle agents which are in states inactive or done_removed - actions= [ - [[7, 9, 2], [8, 8, 1], [3, 7, 2]], [[7, 9, 2], [8, 7, 3], [2, 7, 0]], # ... - ] - - def __init__(self, env): - # Some ugly stuff with cls and self here - cls = self.__class__ - cls.instance = self # intended as singleton - - cls.app.config['CORS_HEADERS'] = 'Content-Type' - cls.app.config['SECRET_KEY'] = 'secret!' - - self.app = cls.app - self.socketio = cls.socketio - self.env = env - self.renderer_ready = False # to indicate env background not yet drawn - self.port = None # we only assign a port once we start the background server... - self.host = None - - def run_flask_server(self, host='127.0.0.1', port=None): - self.host = host - - if port is None: - self.port = self._find_available_port(host) - else: - self.port = port - - self.socketio.run(simple_flask_server.app, host=host, port=self.port) - - def run_flask_server_in_thread(self, host="127.0.0.1", port=None): - # daemon=True so that this thread exits when the main / foreground thread exits, - # usually when the episode finishes. - self.thread = threading.Thread( - target=self.run_flask_server, - kwargs={"host": host, "port": port}, - daemon=True) - self.thread.start() - # short sleep to allow thread to start (may be unnnecessary) - time.sleep(1) - - def open_browser(self): - webbrowser.open("http://localhost:{}".format(self.port)) - # short sleep to allow browser to request the page etc (may be unnecessary) - time.sleep(1) - - def _test_listen_port(self, host: str, port: int): - oSock = socket.socket() - try: - oSock.bind((host, port)) - except OSError: - return False # The port is not available - - del oSock # This should release the port - return True # The port is available - - def _find_available_port(self, host: str, port_start: int = 8080): - for nPort in range(port_start, port_start+100): - if self._test_listen_port(host, nPort): - return nPort - print("Could not find an available port for Flask to listen on!") - return None - - def get_endpoint_url(self): - return "http://{}:{}".format(self.host, self.port) - - @app.route('/', methods=['GET']) - def home(): - # redirects from "/" to "/index.html" which is then served from static. - # print("Here - / - cwd:", os.getcwd()) - return redirect("index.html") - - @staticmethod - @socketio.on('connect') - def connected(): - ''' - When the JS Renderer connects, - this method will send the env and agent information - ''' - cls = simple_flask_server - print('Client connected') - - # Do we really need this? - cls.socketio.emit('message', {'message': 'Connected'}) - - print('Send Env grid and agents') - # cls.socketio.emit('grid', {'grid': cls.gridmap, 'agents_static': cls.agents_static}, broadcast=False) - cls.instance.send_env() - print("Env and agents sent") - - @staticmethod - @socketio.on('disconnect') - def disconnected(): - print('Client disconnected') - - def send_actions(self, dict_actions): - ''' Sends the agent positions and directions, not really actions. - ''' - llAgents = self.agents_to_list() - self.socketio.emit('agentsAction', {'actions': llAgents}) - - def send_observation(self, agent_handles, dict_obs): - """ Send an observation. - TODO: format observation message. - """ - self.socketio.emit("observation", {"agents": agent_handles, "observations": dict_obs}) - - def send_env(self): - """ Sends the env, ie the rail grid, and the agents (static) information - """ - # convert 2d array of int into 2d array of 16char strings - g2sGrid = np.vectorize(np.binary_repr)(self.env.rail.grid, width=16) - llGrid = g2sGrid.tolist() - llAgents = self.agents_to_list_dict() - self.socketio.emit('grid', { - 'grid': llGrid, - 'agents_static': llAgents - }, - broadcast=False) - - def send_env_and_wait(self): - for iAttempt in range(30): - if self.is_renderer_ready(): - print("Background Render complete") - break - else: - print("Waiting for browser to signal that rendering complete") - time.sleep(1) - - @staticmethod - @socketio.on('renderEvent') - def handle_render_event(data): - cls=simple_flask_server - self = cls.instance - print('RenderEvent!!') - print('status: ' + data['status']) - print('message: ' + data['message']) - - if data['status'] == 'listening': - self.renderer_ready = True - - def is_renderer_ready(self): - return self.renderer_ready - - def agents_to_list_dict(self): - ''' Create a list of lists / dicts for serialisation - Maps from the internal representation in EnvAgent to - the schema used by the Javascript renderer. - ''' - llAgents = [] - for agent in self.env.agents: - if agent.position is None: - # the int()s are to convert from numpy int64 which causes problems in serialization - # to plain old python int - lPos = [int(agent.initial_position[0]), int(agent.initial_position[1])] - else: - lPos = [int(agent.position[0]), int(agent.position[1])] - - lAgent = [ - lPos, - int(agent.direction), - [int(agent.target[0]), int(agent.target[1])], 0, - { # dummy values: - "position_fraction": 0, - "speed": 1, - "transition_action_on_cellexit": 3 - }, - { - "malfunction": 0, - "malfunction_rate": 0, - "next_malfunction": 0, - "nr_malfunctions": 0 - } - ] - llAgents.append(lAgent) - return llAgents - - def agents_to_list(self): - llAgents = [] - for agent in self.env.agents: - if agent.position is None: - lPos = [int(agent.initial_position[0]), int(agent.initial_position[1])] - else: - lPos = [int(agent.position[0]), int(agent.position[1])] - iDir = int(agent.direction) - - lAgent = [*lPos, iDir] - - llAgents.append(lAgent) - return llAgents - - - -def main1(): - - print('Run Flask SocketIO Server') - server = simple_flask_server() - threading.Thread(target=server.run_flask_server).start() - # Open Browser - webbrowser.open('http://127.0.0.1:8080') - - print('Send Action') - for i in server.actions: - time.sleep(1) - print('send action') - server.socketio.emit('agentsAction', {'actions': i}) - - - - - -if __name__ == "__main__": - main1() \ No newline at end of file diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index 2d0713b118344b8e26c041a120d9fdaed49e21ab..595b6ed61e84bf7e677745d378b918a36eb44a45 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -1,7 +1,6 @@ import io import os import time -#import tkinter as tk import numpy as np from PIL import Image, ImageDraw, ImageFont @@ -14,10 +13,6 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions # noqa: E402 class PILGL(GraphicsLayer): - # tk.Tk() must be a singleton! - # https://stackoverflow.com/questions/26097811/image-pyimage2-doesnt-exist - # window = tk.Tk() - RAIL_LAYER = 0 PREDICTION_PATH_LAYER = 1 TARGET_LAYER = 2 diff --git a/flatland/utils/graphics_tkpil.py b/flatland/utils/graphics_tkpil.py deleted file mode 100644 index 7e89e734a170e746c07a4779f4f12e2ff88cf42e..0000000000000000000000000000000000000000 --- a/flatland/utils/graphics_tkpil.py +++ /dev/null @@ -1,52 +0,0 @@ - -import tkinter as tk - -from PIL import ImageTk -# from numpy import array -# from pkg_resources import resource_string as resource_bytes - -# from flatland.utils.graphics_layer import GraphicsLayer -from flatland.utils.graphics_pil import PILSVG - - -class TKPILGL(PILSVG): - # tk.Tk() must be a singleton! - # https://stackoverflow.com/questions/26097811/image-pyimage2-doesnt-exist - window = tk.Tk() - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.window_open = False - - def open_window(self): - print("open_window - tk") - assert self.window_open is False, "Window is already open!" - self.__class__.window.title("Flatland") - self.__class__.window.configure(background='grey') - self.window_open = True - - def close_window(self): - self.panel.destroy() - # quit but not destroy! - self.__class__.window.quit() - - def show(self, block=False): - # print("show - ", self.__class__) - img = self.alpha_composite_layers() - - if not self.window_open: - self.open_window() - - tkimg = ImageTk.PhotoImage(img) - - if self.firstFrame: - # Do TK actions for a new panel (not sure what they really do) - self.panel = tk.Label(self.window, image=tkimg) - self.panel.pack(side="bottom", fill="both", expand="yes") - else: - # update the image in situ - self.panel.configure(image=tkimg) - self.panel.image = tkimg - - self.__class__.window.update() - self.firstFrame = False diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 196f4a382ac0b92f975384f9aeaa3d4127d6e93a..cfcb530fb75614b84155f85b08086bfd82535d0f 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -8,6 +8,7 @@ from numpy import array from recordtype import recordtype from flatland.utils.graphics_pil import PILGL, PILSVG +from flatland.utils.graphics_pgl import PGLGL # TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv! @@ -21,7 +22,8 @@ class AgentRenderVariant(IntEnum): class RenderTool(object): - """ RenderTool is a facade to a renderer, either local or browser + """ RenderTool is a facade to a renderer. + (This was introduced for the Browser / JS renderer which has now been removed.) """ def __init__(self, env, gl="PGL", jupyter=False, agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, @@ -35,20 +37,13 @@ class RenderTool(object): self.agent_render_variant = agent_render_variant - if gl in ["PIL", "PILSVG", "TKPIL", "TKPILSVG", "PGL"]: + if gl in ["PIL", "PILSVG", "PGL"]: self.renderer = RenderLocal(env, gl, jupyter, agent_render_variant, show_debug, clear_debug_text, screen_width, screen_height) - - # To support legacy access to the GraphicsLayer (gl) - # DEPRECATED - TODO: remove these calls! self.gl = self.renderer.gl - - elif gl == "BROWSER": - from flatland.utils.flask_util import simple_flask_server - self.renderer = RenderBrowser(env, host=host, port=port) else: - print("[", gl, "] not found, switch to PILSVG or BROWSER") + print("[", gl, "] not found, switch to PGL") def render_env(self, show=False, # whether to call matplotlib show() or equivalent after completion @@ -98,7 +93,6 @@ class RenderTool(object): return None - class RenderBase(object): def __init__(self, env): pass @@ -124,56 +118,6 @@ class RenderBase(object): pass -class RenderBrowser(RenderBase): - def __init__(self, env, host="localhost", port=None): - self.server = simple_flask_server(env) - self.server.run_flask_server_in_thread(host=host, port=port) - self.env = env - self.background_rendered = False - - def render_env(self, - show=False, # whether to call matplotlib show() or equivalent after completion - show_agents=True, # whether to include agents - show_inactive_agents=False, - show_observations=True, # whether to include observations - show_predictions=False, # whether to include predictions - frames=False, # frame counter to show (intended since invocation) - episode=None, # int episode number to show - step=None, # int step number to show in image - selected_agent=None, # indicate which agent is "selected" in the editor): - return_image=False): # indicate if image is returned for use in monitor: - - if not self.background_rendered: - self.server.send_env_and_wait() - self.background_rendered = True - - self.server.send_actions({}) - - if show_observations: - self.render_observation(range(self.env.get_num_agents()), self.env.dev_obs_dict) - - def render_observation(self, agent_handles, dict_observation): - # Change keys to strings, and OrderedSet to list (of tuples) - dict_obs2 = {str(item[0]): list(item[1]) for item in self.env.dev_obs_dict.items()} - # Convert any ranges into a list - list_handles = list(agent_handles) - self.server.send_observation(list_handles, dict_obs2) - - def get_port(self): - return self.server.port - - def get_endpoint_url(self): - return self.server.get_endpoint_url() - - def close_window(self): - pass - - def reset(self): - pass - - def set_new_rail(self): - pass - class RenderLocal(RenderBase): """ Class to render the RailEnv and agents. @@ -211,19 +155,11 @@ class RenderLocal(RenderBase): self.gl = PILGL(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height) elif gl == "PILSVG": self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height) - elif gl in ["TKPILSVG", "TKPIL"]: - # Conditional import to avoid importing tkinter unless required. - print("Importing TKPILGL - requires a local display!") - from flatland.utils.graphics_tkpil import TKPILGL - self.gl = TKPILGL(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height) - elif gl in ["PGL"]: - # Conditional import - from flatland.utils.graphics_pgl import PGLGL - self.gl = PGLGL(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height) else: - print("[", gl, "] not found, switch to PGL, PILSVG, TKPIL (deprecated) or BROWSER") - print("Using PILSVG.") - self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height) + if gl != "PGL": + print("[", gl, "] not found, switch to PGL, PILSVG") + print("Using PGL") + self.gl = PGLGL(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height) self.new_rail = True self.show_debug = show_debug @@ -582,7 +518,7 @@ class RenderLocal(RenderBase): """ # if type(self.gl) is PILSVG: - if self.gl_str in ["PILSVG", "TKPIL", "TKPILSVG", "PGL"]: + if self.gl_str in ["PILSVG", "PGL"]: return self.render_env_svg(show=show, show_observations=show_observations, show_predictions=show_predictions,