diff --git a/examples/demo.py b/examples/demo.py index dae2118c9a80680fbf89b38481c988ec7d404ed7..02fae475039578e09fbd885d0bde0275b8aa6ac9 100644 --- a/examples/demo.py +++ b/examples/demo.py @@ -1,8 +1,8 @@ import os import time import random - import numpy as np +from datetime import datetime from flatland.envs.generators import complex_rail_generator # from flatland.envs.generators import rail_from_list_of_saved_GridTransitionMap_generator @@ -125,21 +125,27 @@ class Demo: self.env = env self.create_renderer() self.action_size = 4 + self.max_frame_rate = 60 def create_renderer(self): - self.renderer = RenderTool(self.env, gl="QTSVG") + self.renderer = RenderTool(self.env, gl="PILSVG") handle = self.env.get_agent_handles() return handle + def set_max_framerate(self,max_frame_rate): + self.max_frame_rate = max_frame_rate + def run_demo(self, max_nbr_of_steps=30): action_dict = dict() # Reset environment _ = self.env.reset(False, False) + time.sleep(0.0001) # to satisfy lint... + for step in range(max_nbr_of_steps): - # time.sleep(.1) + begin_frame_time_stamp = datetime.now() # Action for iAgent in range(self.env.get_num_agents()): @@ -173,7 +179,19 @@ class Demo: break -if True: + # ensure that the rendering is not faster then the maximal allowed frame rate + end_frame_time_stamp = datetime.now() + frame_exe_time = end_frame_time_stamp - begin_frame_time_stamp + max_time = 1/self.max_frame_rate + delta = (max_time - frame_exe_time.total_seconds()) + if delta > 0.0: + time.sleep(delta) + + + self.renderer.close_window() + + +if False: demo_000 = Demo(Scenario_Generator.generate_random_scenario()) demo_000.run_demo() demo_000 = None @@ -194,18 +212,18 @@ if True: demo_002.run_demo() demo_002 = None - demo_flatland_000 = Demo(Scenario_Generator.load_scenario('./env-data/railway/example_flatland_000.pkl')) - demo_flatland_000.renderer.resize() - demo_flatland_000.run_demo(300) - demo_flatland_000 = None - - demo_flatland_000 = Demo(Scenario_Generator.load_scenario('./env-data/railway/example_flatland_001.pkl')) - demo_flatland_000.renderer.resize() - demo_flatland_000.run_demo(300) - demo_flatland_000 = None +demo_flatland_000 = Demo(Scenario_Generator.load_scenario('./env-data/railway/example_flatland_000.pkl')) +demo_flatland_000.renderer.resize() +demo_flatland_000.run_demo(60) +demo_flatland_000 = None +demo_flatland_000 = Demo(Scenario_Generator.load_scenario('./env-data/railway/example_flatland_001.pkl')) +demo_flatland_000.renderer.resize() +demo_flatland_000.run_demo(60) +demo_flatland_000 = None demo_flatland_000 = Demo(Scenario_Generator.load_scenario('./env-data/railway/example_network_003.pkl')) demo_flatland_000.renderer.resize() -demo_flatland_000.run_demo(1800) +demo_flatland_000.set_max_framerate(5) +demo_flatland_000.run_demo(30) demo_flatland_000 = None diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py index b5283df627b89294b899a0dc3f4652d5d2375152..1978e27c8511154d317ef5c96c3a9bc6816168c5 100644 --- a/examples/simple_example_3.py +++ b/examples/simple_example_3.py @@ -3,7 +3,7 @@ import random from flatland.envs.generators import random_rail_generator from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool -from flatland.core.env_observation_builder import TreeObsForRailEnv +from flatland.envs.observations import TreeObsForRailEnv import numpy as np random.seed(100) diff --git a/examples/tkplay.py b/examples/tkplay.py index a46dcbc6f3aca2b84a9b35f33c339c3f03291c32..c17ea519014ff304f654036f111ac953f328a118 100644 --- a/examples/tkplay.py +++ b/examples/tkplay.py @@ -26,7 +26,7 @@ def tkmain(n_trials=2, n_steps=50, sGL="PIL"): env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step, action_dict=oPlayer.action_dict) - env_renderer.gl.close_window() + env_renderer.close_window() if __name__ == "__main__": diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index 7c9e0687b15f6d599ad8338906b082fc9bac7a8c..1f4a604770575f5bd24d62ade3092cbd251081be 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -174,10 +174,9 @@ class View(object): self.oRT = rt.RenderTool(self.editor.env, gl=self.sGL) def redraw(self): - # TODO: bit of a hack - can we suppress the console messages from MPL at source? - # with redirect_stdout(stdout_dest): with self.wOutput: # plt.figure(figsize=(10, 10)) + self.oRT.set_new_rail() self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray", show=False, iSelectedAgent=self.model.iSelectedAgent, show_observations=self.show_observations()) diff --git a/flatland/utils/graphics_layer.py b/flatland/utils/graphics_layer.py index a1de818fe57ed2f91a3b025f652234c578dc11e8..2598a8893cf47ac9be526a7c8809f55ca2bb8ddf 100644 --- a/flatland/utils/graphics_layer.py +++ b/flatland/utils/graphics_layer.py @@ -66,7 +66,12 @@ class GraphicsLayer(object): def get_cmap(self, *args, **kwargs): return plt.get_cmap(*args, **kwargs) - def setRailAt(self, row, col, binTrans): + def setRailAt(self, row, col, binTrans, iTarget=None): + """ Set the rail at cell (row, col) to have transitions binTrans. + The target argument can contain the index of the agent to indicate + that agent's target is at that cell, so that a station can be + rendered in the static rail layer. + """ pass def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut): diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index fe5e74b541b3d958df7058acc5b127c45d1a8ebd..40e0a9270ede53de13174203a40972fdf540b61e 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -10,10 +10,10 @@ from cairosvg import svg2png from flatland.core.transitions import RailEnvTransitions # from copy import copy +from screeninfo import get_monitors class PILGL(GraphicsLayer): def __init__(self, width, height, nPixCell=60): - self.nPixCell = 60 self.yxBase = (0, 0) self.linewidth = 4 self.nAgentColors = 1 # overridden in loadAgent @@ -22,9 +22,23 @@ class PILGL(GraphicsLayer): self.width = width self.height = height + self.screen_width = 99999 + self.screen_height = 99999 + for m in get_monitors(): + self.screen_height = min(self.screen_height,m.height) + self.screen_width = min(self.screen_width,m.width) + + w = (self.screen_width-self.width-10)/(self.width + 1 + self.linewidth) + h = (self.screen_height-self.height-10)/(self.height + 1 + self.linewidth) + self.nPixCell = int(max(1,np.ceil(min(w,h)))) + # Total grid size at native scale self.widthPx = self.width * self.nPixCell + self.linewidth self.heightPx = self.height * self.nPixCell + self.linewidth + + self.xPx = int((self.screen_width - self.widthPx) / 2.0) + self.yPx = int((self.screen_height - self.heightPx) / 2.0) + self.layers = [] self.draws = [] @@ -33,12 +47,25 @@ class PILGL(GraphicsLayer): self.tColRail = (0, 0, 0) # black rails self.tColGrid = (230,) * 3 # light grey for grid + sColors = "d50000#c51162#aa00ff#6200ea#304ffe#2962ff#0091ea#00b8d4#00bfa5#00c853" + \ + "#64dd17#aeea00#ffd600#ffab00#ff6d00#ff3d00#5d4037#455a64" + + self.ltAgentColors = [self.rgb_s2i(sColor) for sColor in sColors.split("#")] + self.nAgentColors = len(self.ltAgentColors) + self.window_open = False # self.bShow = show self.firstFrame = True self.create_layers() # self.beginFrame() + def rgb_s2i(self, sRGB): + """ convert a hex RGB string like 0091ea to 3-tuple of ints """ + return tuple(int(sRGB[iRGB * 2:iRGB * 2 + 2], 16) for iRGB in [0, 1, 2]) + + def getAgentColor(self, iAgent): + return self.ltAgentColors[iAgent % self.nAgentColors] + def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs): color = self.adaptColor(color) if len(color) == 3: @@ -75,6 +102,7 @@ class PILGL(GraphicsLayer): self.window = tk.Tk() self.window.title("Flatland") self.window.configure(background='grey') + self.window.geometry('%dx%d+%d+%d' % (self.widthPx, self.heightPx, self.xPx, self.yPx)) self.window_open = True def close_window(self): @@ -246,10 +274,17 @@ class PILSVG(PILGL): "EE WW": "Bahnhof_#d50000_Gleis_horizontal.svg", "NN SS": "Bahnhof_#d50000_Gleis_vertikal.svg"} + # Dict of rail cell images indexed by binary transitions self.dPilRail = self.loadSVGs(dRailFiles, rotate=True) - self.dPilTarget = self.loadSVGs(dTargetFiles, rotate=False) - def loadSVGs(self, dDirFile, rotate=False): + # Load the target files (which have rails and transitions of their own) + # They are indexed by (binTrans, iAgent), ie a tuple of the binary transition and the agent index + dPilRail2 = self.loadSVGs(dTargetFiles, rotate=False, agent_colors=self.ltAgentColors) + # Merge them with the regular rails. + # https://stackoverflow.com/questions/38987/how-to-merge-two-dictionaries-in-a-single-expression + self.dPilRail = {**self.dPilRail, **dPilRail2} + + def loadSVGs(self, dDirFile, rotate=False, agent_colors=False): dPil = {} transitions = RailEnvTransitions() @@ -280,9 +315,10 @@ class PILSVG(PILGL): # svg = svg.merge(svgBG) pilRail = self.pilFromSvgFile(sPathSvg) - dPil[binTrans] = pilRail - + if rotate: + # For rotations, we also store the base image + dPil[binTrans] = pilRail # Rotate both the transition binary and the image and save in the dict for nRot in [90, 180, 270]: binTrans2 = transitions.rotate_transition(binTrans, nRot) @@ -290,25 +326,44 @@ class PILSVG(PILGL): # PIL rotates anticlockwise for positive theta pilRail2 = pilRail.rotate(-nRot) dPil[binTrans2] = pilRail2 + + if agent_colors: + # For recoloring, we don't store the base image. + a3BaseColor = self.rgb_s2i("d50000") + lPils = self.recolorImage(pilRail, a3BaseColor, self.ltAgentColors) + for iColor, pilRail2 in enumerate(lPils): + dPil[(binTrans, iColor)] = lPils[iColor] + return dPil - def setRailAt(self, row, col, binTrans, target=None): - if target is None: + def setRailAt(self, row, col, binTrans, iTarget=None): + if iTarget is None: if binTrans in self.dPilRail: pilTrack = self.dPilRail[binTrans] self.drawImageRC(pilTrack, (row, col)) else: print("Illegal rail:", row, col, format(binTrans, "#018b")[2:]) else: - if binTrans in self.dPilTarget: - pilTrack = self.dPilTarget[binTrans] + if (binTrans, iTarget) in self.dPilRail: + pilTrack = self.dPilRail[(binTrans, iTarget)] self.drawImageRC(pilTrack, (row, col)) else: print("Illegal target rail:", row, col, format(binTrans, "#018b")[2:]) - def rgb_s2i(self, sRGB): - """ convert a hex RGB string like 0091ea to 3-tuple of ints """ - return tuple(int(sRGB[iRGB * 2:iRGB * 2 + 2], 16) for iRGB in [0, 1, 2]) + def recolorImage(self, pil, a3BaseColor, ltColors): + rgbaImg = array(pil) + lPils = [] + + for iColor, tnColor in enumerate(ltColors): + # find the pixels which match the base paint color + xy_color_mask = np.all(rgbaImg[:, :, 0:3] - a3BaseColor == 0, axis=2) + rgbaImg2 = np.copy(rgbaImg) + + # Repaint the base color with the new color + rgbaImg2[xy_color_mask, 0:3] = tnColor + pil2 = Image.fromarray(rgbaImg2) + lPils.append(pil2) + return lPils def loadAgentSVGs(self): @@ -319,13 +374,8 @@ class PILSVG(PILGL): (0, 3): "svg/Zug_2_Weiche_#0091ea.svg" } - sColors = "d50000#c51162#aa00ff#6200ea#304ffe#2962ff#0091ea#00b8d4#00bfa5#00c853" + \ - "#64dd17#aeea00#ffd600#ffab00#ff6d00#ff3d00#5d4037#455a64" - lColors = sColors.split("#") - self.nAgentColors = len(lColors) - # "paint" color of the train images we load - a_base_color = self.rgb_s2i("0091ea") + a3BaseColor = self.rgb_s2i("0091ea") self.dPilZug = {} @@ -342,14 +392,11 @@ class PILSVG(PILGL): # PIL rotates anticlockwise for positive theta pilZug2 = pilZug.rotate(-nDegRot) - rgbaZug2 = array(pilZug2) - - for iColor, sColor in enumerate(lColors): - tnNewColor = self.rgb_s2i(sColor) - xy_color_mask = np.all(rgbaZug2[:, :, 0:3] - a_base_color == 0, axis=2) - rgbaZug3 = np.copy(rgbaZug2) - rgbaZug3[xy_color_mask, 0:3] = tnNewColor - self.dPilZug[(iDirIn2, iDirOut2, iColor)] = Image.fromarray(rgbaZug3) + + # Save colored versions of each rotation / variant + lPils = self.recolorImage(pilZug2, a3BaseColor, self.ltAgentColors) + for iColor, pilZug3 in enumerate(lPils): + self.dPilZug[(iDirIn2, iDirOut2, iColor)] = lPils[iColor] def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut): delta_dir = (iDirOut - iDirIn) % 4 diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py index 233f07bca474204ea31aec5d75b530911071bfad..dc2217c3ee9041caf39c88d8898bf3fb0c7cb24e 100644 --- a/flatland/utils/render_qt.py +++ b/flatland/utils/render_qt.py @@ -157,7 +157,7 @@ class QTSVG(GraphicsLayer): self.lwAgents = [] self.agents_prev = [] - def setRailAt(self, row, col, binTrans): + def setRailAt(self, row, col, binTrans, iTarget=None): if binTrans in self.track.dSvg: sSVG = self.track.dSvg[binTrans].to_string() svgWidget = create_QtSvgWidget_from_svg_string(sSVG) diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index f2200b97b03577e4dcdb2f34575b0330ff0693fe..aa8731b1135326586117be2bd8e8bf6da97ea146 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -498,10 +498,11 @@ class RenderTool(object): """ rt = self.__class__ - cmap = self.gl.get_cmap('hsv', lut=max(len(self.env.agents), len(self.env.agents_static) + 1)) + # cmap = self.gl.get_cmap('hsv', lut=max(len(self.env.agents), len(self.env.agents_static) + 1)) for agent in agent_handles: - color = cmap(agent) + # color = cmap(agent) + color = self.gl.getAgentColor(agent) for visited_cell in observation_dict[agent]: cell_coord = array(visited_cell[:2]) cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf @@ -757,7 +758,7 @@ class RenderTool(object): for iAgent, agent in enumerate(self.env.agents_static): if agent is None: continue - dTargets[agent.target] = iAgent + dTargets[tuple(agent.target)] = iAgent # Draw each cell independently for r in range(env.height): @@ -767,7 +768,7 @@ class RenderTool(object): target = dTargets[(r, c)] else: target = None - self.gl.setRailAt(r, c, binTrans) + self.gl.setRailAt(r, c, binTrans, iTarget=target) for iAgent, agent in enumerate(self.env.agents): if agent is None: @@ -782,8 +783,12 @@ class RenderTool(object): direction = agent.direction old_direction = agent.direction - cmap = self.gl.get_cmap('hsv', lut=max(len(self.env.agents), len(self.env.agents_static) + 1)) - self.gl.setAgentAt(iAgent, *position, old_direction, direction,color=cmap(iAgent)) + # setAgentAt uses the agent index for the color + # cmap = self.gl.get_cmap('hsv', lut=max(len(self.env.agents), len(self.env.agents_static) + 1)) + self.gl.setAgentAt(iAgent, *position, old_direction, direction) # ,color=cmap(iAgent)) + + if show_observations: + self.renderObs(range(env.get_num_agents()), env.dev_obs_dict) if show: self.gl.show() @@ -792,3 +797,6 @@ class RenderTool(object): self.iFrame += 1 return + + def close_window(self): + self.gl.close_window() diff --git a/flatland/utils/svg.py b/flatland/utils/svg.py index 0a6b895dcdaa5e9cd4d99cfa5861ecc187dad905..e7b2cebb606089f5937d37c8e4d6381ada9fc211 100644 --- a/flatland/utils/svg.py +++ b/flatland/utils/svg.py @@ -60,7 +60,7 @@ class SVG(object): sNewStyles = "\n" for sKey, sValue in self.dStyles.items(): if sKey == style_name: - sValue = "fill:#" + "".join([ ('{:#04x}'.format(int(255.0*col))[2:4]) for col in color[0:3]]) + ";" + sValue = "fill:#" + "".join([('{:#04x}'.format(int(255.0*col))[2:4]) for col in color[0:3]]) + ";" sNewStyle = "\t.st" + sKey + "{" + sValue + "}\n" sNewStyles += sNewStyle diff --git a/notebooks/Editor2.ipynb b/notebooks/Editor2.ipynb index 367bd1299173165ab5264ab7d6fb0077491c2959..e6e9100523e8eff39f3c191945419bad18f2bc32 100644 --- a/notebooks/Editor2.ipynb +++ b/notebooks/Editor2.ipynb @@ -106,7 +106,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "95562eba9d7b4a2794d8d65721413e04", + "model_id": "c6a9337bfb5a49b19941bcf7b643aaad", "version_major": 2, "version_minor": 0 }, @@ -132,7 +132,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "0fe475488df94151afbfe8dc1e1050e8", + "model_id": "b632ca8286e94419b9ea6c4bab7f84ae", "version_major": 2, "version_minor": 0 }, diff --git a/requirements_dev.txt b/requirements_dev.txt index 0bc1c76c25c5fc43a1f43ecaf484d23931f5bf7c..db80383febe488d4b0df2fc54407a976eb84b49b 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -20,8 +20,10 @@ xarray==0.11.3 matplotlib==3.0.2 PyQt5==5.12 Pillow==5.4.1 -CairoSVG==2.3.1 -pycairo==1.18.1 +# CairoSVG==2.3.1 +# pycairo==1.18.1 msgpack==0.6.1 svgutils==0.3.1 + +screeninfo==0.3.1 \ No newline at end of file diff --git a/tests/test_env_edit.py b/tests/test_env_edit.py new file mode 100644 index 0000000000000000000000000000000000000000..2d8a4e087a0ab1d4425c05cf3153abcded8a8ceb --- /dev/null +++ b/tests/test_env_edit.py @@ -0,0 +1,15 @@ + +from flatland.envs.rail_env import RailEnv +# from flatland.envs.agent_utils import EnvAgent +from flatland.envs.agent_utils import EnvAgentStatic + + +def test_load_env(): + env = RailEnv(10, 10) + env.load("env-data/tests/test-10x10.mpk") + + agent_static = EnvAgentStatic((0, 0), 2, (5, 5)) + env.add_agent_static(agent_static) + assert env.get_num_agents() == 1 + +