From efccb8e41c9174dae7fc2017360543f9f36c2cc2 Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Wed, 5 Jun 2019 07:31:06 +0200 Subject: [PATCH] QT renderer removed --- docs/gettingstarted.rst | 2 +- examples/custom_railmap_example.py | 9 +- examples/qt2.py | 76 ---------- examples/simple_example_1.py | 4 +- examples/simple_example_2.py | 7 +- examples/simple_example_3.py | 7 +- examples/tkplay.py | 2 +- flatland/utils/graphics_qt.py | 225 --------------------------- flatland/utils/render_qt.py | 236 ----------------------------- flatland/utils/rendertools.py | 15 +- 10 files changed, 21 insertions(+), 562 deletions(-) delete mode 100644 examples/qt2.py delete mode 100644 flatland/utils/graphics_qt.py delete mode 100644 flatland/utils/render_qt.py diff --git a/docs/gettingstarted.rst b/docs/gettingstarted.rst index 2d545b24..e64013aa 100644 --- a/docs/gettingstarted.rst +++ b/docs/gettingstarted.rst @@ -76,7 +76,7 @@ Environments can be rendered using the utils.rendertools utilities, for example: .. code-block:: python - env_renderer = RenderTool(env, gl="QT") + env_renderer = RenderTool(env) env_renderer.renderEnv(show=True) diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py index 9d483c0c..26beb61d 100644 --- a/examples/custom_railmap_example.py +++ b/examples/custom_railmap_example.py @@ -1,10 +1,11 @@ import random -from flatland.envs.rail_env import RailEnv -from flatland.core.transitions import RailEnvTransitions +import numpy as np + from flatland.core.transition_map import GridTransitionMap +from flatland.core.transitions import RailEnvTransitions +from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool -import numpy as np random.seed(100) np.random.seed(100) @@ -32,7 +33,7 @@ env = RailEnv(width=6, env.reset() -env_renderer = RenderTool(env, gl="QT") +env_renderer = RenderTool(env) env_renderer.renderEnv(show=True) input("Press Enter to continue...") diff --git a/examples/qt2.py b/examples/qt2.py deleted file mode 100644 index ee3ea0cd..00000000 --- a/examples/qt2.py +++ /dev/null @@ -1,76 +0,0 @@ -import sys - -from PyQt5 import QtSvg -from PyQt5.QtCore import Qt, QByteArray -from PyQt5.QtWidgets import QApplication, QLabel, QMainWindow, QGridLayout, QWidget - -from flatland.utils import svg - - -# Subclass QMainWindow to customise your application's main window -class MainWindow(QMainWindow): - - def __init__(self, *args, **kwargs): - super(MainWindow, self).__init__(*args, **kwargs) - - self.setWindowTitle("My Awesome App") - - layout = QGridLayout() - layout.setSpacing(0) - - wMain = QWidget(self) - - wMain.setLayout(layout) - - label = QLabel("This is a PyQt5 window!") - - # The `Qt` namespace has a lot of attributes to customise - # widgets. See: http://doc.qt.io/qt-5/qt.html - label.setAlignment(Qt.AlignCenter) - layout.addWidget(label, 0, 0) - - svgWidget = QtSvg.QSvgWidget("./svg/Gleis_vertikal.svg") - layout.addWidget(svgWidget, 1, 0) - - if True: - track = svg.Track() - - svgWidget = None - iRow = 0 - iCol = 2 - iArt = 0 - nCols = 3 - for binTrans in list(track.dSvg.keys())[:2]: - sSVG = track.dSvg[binTrans].to_string() - - bySVG = bytearray(sSVG, encoding='utf-8') - - # with open(sfPath, "r") as fIn: - # sSVG = fIn.read() - # bySVG = bytearray(sSVG, encoding='utf-8') - - svgWidget = QtSvg.QSvgWidget() - oQB = QByteArray(bySVG) - - bSuccess = svgWidget.renderer().load(oQB) - # print(x0, y0, x1, y1) - print(iRow, iCol, bSuccess) - print("\n\n\n", bySVG.decode("utf-8")) - # svgWidget.setGeometry(x0, y0, x1, y1) - layout.addWidget(svgWidget, iRow, iCol) - - iArt += 1 - iRow = int(iArt / nCols) - iCol = iArt % nCols - - # Set the central widget of the Window. Widget will expand - # to take up all the space in the window by default. - self.setCentralWidget(wMain) - - -app = QApplication(sys.argv) - -window = MainWindow() -window.show() - -app.exec_() diff --git a/examples/simple_example_1.py b/examples/simple_example_1.py index 7132b533..ca442873 100644 --- a/examples/simple_example_1.py +++ b/examples/simple_example_1.py @@ -1,6 +1,6 @@ from flatland.envs.generators import rail_from_manual_specifications_generator -from flatland.envs.rail_env import RailEnv from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool # Example generate a rail given a manual specification, @@ -24,7 +24,7 @@ env = RailEnv(width=6, env.reset() -env_renderer = RenderTool(env, gl="QT") +env_renderer = RenderTool(env, gl="PILSVG") env_renderer.renderEnv(show=True) input("Press Enter to continue...") diff --git a/examples/simple_example_2.py b/examples/simple_example_2.py index 70f1f8f4..f1a6a7c7 100644 --- a/examples/simple_example_2.py +++ b/examples/simple_example_2.py @@ -1,10 +1,11 @@ import random +import numpy as np + from flatland.envs.generators import random_rail_generator # , rail_from_list_of_saved_GridTransitionMap_generator -from flatland.envs.rail_env import RailEnv from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool -import numpy as np random.seed(100) np.random.seed(100) @@ -37,7 +38,7 @@ env = RailEnv(width=10, env.reset() -env_renderer = RenderTool(env, gl="QT") +env_renderer = RenderTool(env, gl="PILSVG") env_renderer.renderEnv(show=True) input("Press Enter to continue...") diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py index 1978e27c..8bf134c6 100644 --- a/examples/simple_example_3.py +++ b/examples/simple_example_3.py @@ -1,10 +1,11 @@ import random +import numpy as np + from flatland.envs.generators import random_rail_generator +from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool -from flatland.envs.observations import TreeObsForRailEnv -import numpy as np random.seed(100) np.random.seed(100) @@ -24,7 +25,7 @@ obs, all_rewards, done, _ = env.step({0: 0}) for i in range(env.get_num_agents()): env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5) -env_renderer = RenderTool(env, gl="QT") +env_renderer = RenderTool(env, gl="PILSVG") env_renderer.renderEnv(show=True) print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \ diff --git a/examples/tkplay.py b/examples/tkplay.py index c17ea519..9e37a26f 100644 --- a/examples/tkplay.py +++ b/examples/tkplay.py @@ -11,7 +11,7 @@ def tkmain(n_trials=2, n_steps=50, sGL="PIL"): rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12), number_of_agents=5) - env_renderer = RenderTool(env, gl=sGL, show=True) + env_renderer = RenderTool(env, gl=sGL) oPlayer = Player(env) n_trials = 1 diff --git a/flatland/utils/graphics_qt.py b/flatland/utils/graphics_qt.py deleted file mode 100644 index 09dc1fa4..00000000 --- a/flatland/utils/graphics_qt.py +++ /dev/null @@ -1,225 +0,0 @@ - -import numpy as np -from PyQt5.QtCore import Qt -from PyQt5.QtGui import QImage, QPixmap, QPainter, QColor, QPolygon -from PyQt5.QtCore import QPoint, QRect # QSize -from PyQt5.QtWidgets import QApplication, QMainWindow, QWidget, QTextEdit -from PyQt5.QtWidgets import QHBoxLayout, QVBoxLayout, QLabel, QFrame -import os - - -class Window(QMainWindow): - """ - Simple application window to render the environment into - """ - - def __init__(self): - super().__init__() - - self.setWindowTitle('MiniGrid Gym Environment') - - # Image label to display the rendering - self.imgLabel = QLabel() - self.imgLabel.setFrameStyle(QFrame.Panel | QFrame.Sunken) - - if False: - # Text box for the mission - self.missionBox = QTextEdit() - self.missionBox.setReadOnly(True) - self.missionBox.setMinimumSize(400, 100) - - # Center the image - hbox = QHBoxLayout() - hbox.addStretch(1) - hbox.addWidget(self.imgLabel) - hbox.addStretch(1) - - # Arrange widgets vertically - vbox = QVBoxLayout() - vbox.addLayout(hbox) - # vbox.addWidget(self.missionBox) - - # Create a main widget for the window - self.mainWidget = QWidget(self) - self.setCentralWidget(self.mainWidget) - self.mainWidget.setLayout(vbox) - - # Show the application window - self.show() - self.setFocus() - - self.closed = False - - # Callback for keyboard events - self.keyDownCb = None - - def closeEvent(self, event): - self.closed = True - - def setPixmap(self, pixmap): - self.imgLabel.setPixmap(pixmap) - - def setText(self, text): - # self.missionBox.setPlainText(text) - pass - - def setKeyDownCb(self, callback): - self.keyDownCb = callback - - def keyPressEvent(self, e): - if self.keyDownCb is None: - return - - keyName = None - if e.key() == Qt.Key_Left: - keyName = 'LEFT' - elif e.key() == Qt.Key_Right: - keyName = 'RIGHT' - elif e.key() == Qt.Key_Up: - keyName = 'UP' - elif e.key() == Qt.Key_Down: - keyName = 'DOWN' - elif e.key() == Qt.Key_Space: - keyName = 'SPACE' - elif e.key() == Qt.Key_Return: - keyName = 'RETURN' - elif e.key() == Qt.Key_Alt: - keyName = 'ALT' - elif e.key() == Qt.Key_Control: - keyName = 'CTRL' - elif e.key() == Qt.Key_PageUp: - keyName = 'PAGE_UP' - elif e.key() == Qt.Key_PageDown: - keyName = 'PAGE_DOWN' - elif e.key() == Qt.Key_Backspace: - keyName = 'BACKSPACE' - elif e.key() == Qt.Key_Escape: - keyName = 'ESCAPE' - - if keyName is None: - return - self.keyDownCb(keyName) - - -class QtRenderer(object): - def __init__(self, width, height, ownWindow=False): - self.width = width - self.height = height - - self.img = QImage(width, height, QImage.Format_RGB888) - self.painter = QPainter() - - self.window = None - if ownWindow: - self.app = QApplication([]) - self.window = Window() - self.iFrame = 0 # for movie capture - - def close(self): - """ - Deallocate resources used - """ - pass - - def beginFrame(self): - self.painter.begin(self.img) - # self.painter.setRenderHint(QPainter.Antialiasing, False) - - # Clear the background - self.painter.setBrush(QColor(0, 0, 0)) - self.painter.drawRect(0, 0, self.width - 1, self.height - 1) - - def endFrame(self): - self.painter.end() - - if self.window: - if self.window.closed: - self.window = None - else: - self.window.setPixmap(self.getPixmap()) - self.app.processEvents() - - def getPixmap(self): - return QPixmap.fromImage(self.img) - - def getArray(self): - """ - Get a numpy array of RGB pixel values. - The size argument should be (3,w,h) - """ - - width = self.width - height = self.height - shape = (width, height, 3) - - numBytes = self.width * self.height * 3 - buf = self.img.bits().asstring(numBytes) - output = np.frombuffer(buf, dtype='uint8') - output = output.reshape(shape) - - return output - - def push(self): - self.painter.save() - - def pop(self): - self.painter.restore() - - def rotate(self, degrees): - self.painter.rotate(degrees) - - def translate(self, x, y): - self.painter.translate(x, y) - - def scale(self, x, y): - self.painter.scale(x, y) - - def setLineColor(self, r, g, b, a=255): - self.painter.setPen(QColor(r, g, b, a)) - - def setColor(self, r, g, b, a=255): - self.painter.setBrush(QColor(r, g, b, a)) - - def setLineWidth(self, width): - pen = self.painter.pen() - pen.setWidthF(width) - self.painter.setPen(pen) - - def drawLine(self, x0, y0, x1, y1): - self.painter.drawLine(x0, y0, x1, y1) - - def drawCircle(self, x, y, r): - center = QPoint(x, y) - self.painter.drawEllipse(center, r, r) - - def drawPolygon(self, points): - """Takes a list of points (tuples) as input""" - points = map(lambda p: QPoint(p[0], p[1]), points) - self.painter.drawPolygon(QPolygon(points)) - - def drawRect(self, x, y, w, h): - self.painter.drawRect(x, y, w, h) - - def drawPolyline(self, points): - """Takes a list of points (tuples) as input""" - points = map(lambda p: QPoint(p[0], p[1]), points) - self.painter.drawPolyline(QPolygon(points)) - - def fillRect(self, x, y, width, height, r, g, b, a=255): - self.painter.fillRect(QRect(x, y, width, height), QColor(r, g, b, a)) - - def drawText(self, x, y, sText): - self.painter.drawText(x, y, sText) - - def takeSnapshot(self, sDir="./movie"): - oWidget = self.window.mainWidget - oPixmap = oWidget.grab() - - if not os.path.isdir(sDir): - os.mkdir(sDir) - - nRunIn = 30 - if self.iFrame > nRunIn: - sfImage = "%s/frame%05d.jpg" % (sDir, self.iFrame - nRunIn) - oPixmap.save(sfImage, "jpg") - self.iFrame += 1 diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py deleted file mode 100644 index 97a0799b..00000000 --- a/flatland/utils/render_qt.py +++ /dev/null @@ -1,236 +0,0 @@ -import time - -# from matplotlib import pyplot as plt -import numpy as np -from PyQt5 import QtSvg -from PyQt5.QtWidgets import QApplication, QMainWindow, QWidget, QGridLayout -from numpy import array - -from flatland.envs.agent_utils import EnvAgent -from flatland.utils.graphics_layer import GraphicsLayer -from flatland.utils.graphics_qt import QtRenderer -from flatland.utils.svg import Track, Zug - - -def transform_string_svg(sSVG): - sSVG = sSVG.replace("ASCII", "UTF-8") - bySVG = bytearray(sSVG, encoding='utf-8') - return bySVG - - -def create_QtSvgWidget_from_svg_string(sSVG): - svgWidget = QtSvg.QSvgWidget() - ret = svgWidget.renderer().load(transform_string_svg(sSVG)) - if ret is False: - print("create_QtSvgWidget_from_svg_string : failed to parse:", sSVG) - return svgWidget - - -class QTGL(GraphicsLayer): - def __init__(self, width, height): - self.cell_pixels = 60 - self.tile_size = self.cell_pixels - - self.width = width - self.height = height - - # Total grid size at native scale - self.widthPx = self.width * self.cell_pixels - self.heightPx = self.height * self.cell_pixels - self.qtr = QtRenderer(self.widthPx, self.heightPx, ownWindow=True) - - self.qtr.beginFrame() - self.qtr.push() - - # This comment comes from minigrid. Not sure if it's still true. Jeremy. - # Internally, we draw at the "large" full-grid resolution, but we - # use the renderer to scale back to the desired size - self.qtr.scale(self.tile_size / self.cell_pixels, self.tile_size / self.cell_pixels) - - self.tColBg = (255, 255, 255) # white background - # self.tColBg = (220, 120, 40) # background color - self.tColRail = (0, 0, 0) # black rails - self.tColGrid = (230,) * 3 # light grey for grid - - # Draw the background of the in-world cells - self.qtr.fillRect(0, 0, self.widthPx, self.heightPx, *self.tColBg) - self.qtr.pop() - self.qtr.endFrame() - - def plot(self, gX, gY, color=None, lw=2, **kwargs): - color = self.adaptColor(color) - - self.qtr.setLineColor(*color) - lastx = lasty = None - - if False: - for x, y in zip(gX, gY): - if lastx is not None: - # print("line", lastx, lasty, x, y) - self.qtr.drawLine( - lastx * self.cell_pixels, -lasty * self.cell_pixels, - x * self.cell_pixels, -y * self.cell_pixels) - lastx = x - lasty = y - else: - gPoints = np.stack([array(gX), -array(gY)]).T * self.cell_pixels - self.qtr.setLineWidth(5) - self.qtr.drawPolyline(gPoints) - - def scatter(self, gX, gY, color=None, marker="o", s=50, *args, **kwargs): - color = self.adaptColor(color) - self.qtr.setColor(*color) - self.qtr.setLineColor(*color) - r = np.sqrt(s) - gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.cell_pixels - for x, y in gPoints: - self.qtr.drawCircle(x, y, r) - - def text(self, x, y, sText): - self.qtr.drawText(x * self.cell_pixels, -y * self.cell_pixels, sText) - - def prettify(self, *args, **kwargs): - pass - - def prettify2(self, width, height, cell_size): - pass - - def show(self, block=False): - pass - - def pause(self, seconds=0.00001): - pass - - def beginFrame(self): - self.qtr.beginFrame() - self.qtr.push() - self.qtr.fillRect(0, 0, self.widthPx, self.heightPx, *self.tColBg) - - def endFrame(self): - self.qtr.pop() - self.qtr.endFrame() - - -class QTSVG(GraphicsLayer): - def __init__(self, width, height, jupyter=False): - self.app = QApplication([]) - self.wWinMain = QMainWindow() - - self.wMain = QWidget(self.wWinMain) - - self.wWinMain.setCentralWidget(self.wMain) - - self.layout = QGridLayout() - self.layout.setSpacing(0) - self.wMain.setLayout(self.layout) - self.wWinMain.resize(600, 600) - self.wWinMain.show() - self.wWinMain.setFocus() - - self.track = self.track = Track() - self.lwTrack = [] - self.zug = Zug() - - self.lwAgents = [] - self.agents_prev = [] - - # svgWidget = None - - def is_raster(self): - return False - - def processEvents(self): - self.app.processEvents() - time.sleep(0.001) - - def clear_rails(self): - # print("Clear rails: ", len(self.lwTrack)) - for wRail in self.lwTrack: - self.layout.removeWidget(wRail) - self.lwTrack = [] - self.clear_agents() - - def clear_agents(self): - # print("Clear Agents: ", len(self.lwAgents)) - for wAgent in self.lwAgents: - self.layout.removeWidget(wAgent) - self.lwAgents = [] - self.agents_prev = [] - - def setRailAt(self, row, col, binTrans, iTarget=None): - if binTrans in self.track.dSvg: - sSVG = self.track.dSvg[binTrans].to_string() - svgWidget = create_QtSvgWidget_from_svg_string(sSVG) - self.layout.addWidget(svgWidget, row, col) - self.lwTrack.append(svgWidget) - else: - print("Illegal rail:", row, col, format(binTrans, "#018b")[2:]) - - def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut, color=None): - if iAgent < len(self.lwAgents): - wAgent = self.lwAgents[iAgent] - agentPrev = self.agents_prev[iAgent] - - # If we have an existing agent widget, we can just move it - if wAgent is not None: - self.layout.removeWidget(wAgent) - self.layout.addWidget(wAgent, row, col) - - # We can only reuse the image if noth new and old are straight and the same: - if iDirIn == iDirOut and \ - agentPrev.direction == iDirIn and \ - agentPrev.old_direction == agentPrev.direction: - return - else: - # need to load new image - # print("new dir:", iAgent, row, col, agentPrev.direction, iDirIn) - agentPrev.direction = iDirOut - agentPrev.old_direction = iDirIn - sSVG = self.zug.getSvg(iAgent, iDirIn, iDirOut, color=color).to_string() - wAgent.renderer().load(transform_string_svg(sSVG)) - return - - # Ensure we have adequate slots in the list lwAgents - for i in range(len(self.lwAgents), iAgent + 1): - self.lwAgents.append(None) - self.agents_prev.append(None) - - # Create a new widget for the agent - sSVG = self.zug.getSvg(iAgent, iDirIn, iDirOut, color=color).to_string() - svgWidget = create_QtSvgWidget_from_svg_string(sSVG) - self.lwAgents[iAgent] = svgWidget - self.agents_prev[iAgent] = EnvAgent((row, col), iDirOut, (0, 0), old_direction=iDirIn) - self.layout.addWidget(svgWidget, row, col) - - def show(self, block=False): - self.wMain.update() - - def resize(self, env): - screen_resolution = self.app.desktop().screenGeometry() - width, height = screen_resolution.width(), screen_resolution.height() - w = np.ceil(width * 0.8 / env.width) - h = np.ceil(height * 0.8 / env.height) - self.wWinMain.resize(env.width * w, env.height * h) - self.wWinMain.move((width - env.width * w) / 2, (height - env.height * h) / 2) - - -def main2(): - gl = QTGL(10, 10) - for i in range(10): - gl.beginFrame() - gl.plot([3 + i, 4], [-4 - i, -5], color="r") - gl.endFrame() - time.sleep(1) - - -def main(): - gl = QTSVG() - - for i in range(1000): - gl.processEvents() - time.sleep(0.1) - time.sleep(1) - - -if __name__ == "__main__": - main() diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 52892478..e02e5b42 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -9,7 +9,6 @@ from recordtype import recordtype from flatland.utils.graphics_layer import GraphicsLayer from flatland.utils.graphics_pil import PILGL, PILSVG -from flatland.utils.render_qt import QTGL, QTSVG # TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv! @@ -128,14 +127,13 @@ class RenderTool(object): if gl == "MPL": self.gl = MPLGL(env.width, env.height, jupyter) - elif gl == "QT": - self.gl = QTGL(env.width, env.height, jupyter) elif gl == "PIL": self.gl = PILGL(env.width, env.height, jupyter) elif gl == "PILSVG": self.gl = PILSVG(env.width, env.height, jupyter) - elif gl == "QTSVG": - self.gl = QTSVG(env.width, env.height, jupyter) + else: + print("[", gl, "] not found, switch to PILSVG") + self.gl = PILSVG(env.width, env.height, jupyter) self.new_rail = True @@ -630,7 +628,7 @@ class RenderTool(object): iSelectedAgent=iSelectedAgent, action_dict=action_dict) return - if type(self.gl) in (QTGL, PILGL): + if type(self.gl) is PILGL: self.gl.beginFrame() if type(self.gl) is MPLGL: @@ -675,12 +673,7 @@ class RenderTool(object): self.gl.prettify2(env.width, env.height, self.nPixCell) # TODO: for MPL, we don't want to call clf (called by endframe) - # for QT, we need to call endFrame() # if not show: - if type(self.gl) is QTGL: - self.gl.endFrame() - if show: - self.gl.show(block=False) if type(self.gl) is MPLGL: if show: -- GitLab