diff --git a/flatland/__init__.py b/flatland/__init__.py index b0399d2ae671472306a95956a05bb63a58b07263..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/flatland/__init__.py +++ b/flatland/__init__.py @@ -1,7 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Top-level package for flatland.""" - -__author__ = """S.P. Mohanty""" -__email__ = 'mohanty@aicrowd.com' -__version__ = '0.1.1' diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py index 9f5e4dc5033ba3a789313b47d94b033251cb8276..060dbfc38ec6b035f3264db7ef394545e54387f6 100644 --- a/flatland/core/env_prediction_builder.py +++ b/flatland/core/env_prediction_builder.py @@ -29,7 +29,7 @@ class PredictionBuilder: def get(self, handle=0): """ - Called whenever step_prediction is called on the environment. + Called whenever predict is called on the environment. Parameters ------- diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 3338e68126ee9a198e0208e17b1986f1c9fde6c8..43f0a93f6494a8edb91b4d5a18a12f8fca0fd435 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -18,7 +18,7 @@ class DummyPredictorForRailEnv(PredictionBuilder): def get(self, handle=None): """ - Called whenever step_prediction is called on the environment. + Called whenever predict is called on the environment. Parameters ------- diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index d93b0f754507fe8981e68c34368e62d0c4461b43..a0dd2db7965f463351a47b99fd97894f96d6c593 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -1,12 +1,12 @@ import io import os -import site import time import tkinter as tk import numpy as np from PIL import Image, ImageDraw, ImageTk # , ImageFont from numpy import array +from pkg_resources import resource_string as resource_bytes from flatland.utils.graphics_layer import GraphicsLayer @@ -239,18 +239,9 @@ class PILSVG(PILGL): self.lwAgents = [] self.agents_prev = [] - def pilFromSvgFile(self, sfPath): - try: - with open(sfPath, "r") as fIn: - bytesPNG = svg2png(file_obj=fIn, output_height=self.nPixCell, output_width=self.nPixCell) - except: # noqa: E722 - newList = '' - for directory in site.getsitepackages(): - x = [word for word in os.listdir(directory) if word.startswith('flatland')] - if len(x) > 0: - newList = directory + '/' + x[0] - with open(newList + '/' + sfPath, "r") as fIn: - bytesPNG = svg2png(file_obj=fIn, output_height=self.nPixCell, output_width=self.nPixCell) + def pilFromSvgFile(self, package, resource): + bytestring = resource_bytes(package, resource) + bytesPNG = svg2png(bytestring=bytestring, output_height=self.nPixCell, output_width=self.nPixCell) with io.BytesIO(bytesPNG) as fIn: pil_img = Image.open(fIn) pil_img.load() @@ -313,10 +304,7 @@ class PILSVG(PILGL): lDirs = list("NESW") - # svgBG = SVG("./svg/Background_#91D1DD.svg") - for sTrans, sFile in dDirFile.items(): - sPathSvg = "./svg/" + sFile # Translate the ascii transition description in the format "NE WS" to the # binary list of transitions as per RailEnv - NESW (in) x NESW (out) @@ -330,7 +318,7 @@ class PILSVG(PILGL): sTrans16 = "".join(lTrans16) binTrans = int(sTrans16, 2) - pilRail = self.pilFromSvgFile(sPathSvg) + pilRail = self.pilFromSvgFile('svg', sFile) if rotate: # For rotations, we also store the base image @@ -367,7 +355,7 @@ class PILSVG(PILGL): print("Illegal target rail:", row, col, format(binTrans, "#018b")[2:]) if isSelected: - svgBG = self.pilFromSvgFile("./svg/Selected_Target.svg") + svgBG = self.pilFromSvgFile("svg", "Selected_Target.svg") self.clear_layer(3, 0) self.drawImageRC(svgBG, (row, col), layer=3) @@ -390,9 +378,9 @@ class PILSVG(PILGL): # Seed initial train/zug files indexed by tuple(iDirIn, iDirOut): dDirsFile = { - (0, 0): "svg/Zug_Gleis_#0091ea.svg", - (1, 2): "svg/Zug_1_Weiche_#0091ea.svg", - (0, 3): "svg/Zug_2_Weiche_#0091ea.svg" + (0, 0): "Zug_Gleis_#0091ea.svg", + (1, 2): "Zug_1_Weiche_#0091ea.svg", + (0, 3): "Zug_2_Weiche_#0091ea.svg" } # "paint" color of the train images we load @@ -403,7 +391,7 @@ class PILSVG(PILGL): for tDirs, sPathSvg in dDirsFile.items(): iDirIn, iDirOut = tDirs - pilZug = self.pilFromSvgFile(sPathSvg) + pilZug = self.pilFromSvgFile("svg", sPathSvg) # Rotate both the directions and the image and save in the dict for iDirRot in range(4): @@ -429,7 +417,7 @@ class PILSVG(PILGL): self.drawImageRC(pilZug, (row, col), layer=1) if isSelected: - svgBG = self.pilFromSvgFile("./svg/Selected_Agent.svg") + svgBG = self.pilFromSvgFile("svg", "Selected_Agent.svg") self.clear_layer(2, 0) self.drawImageRC(svgBG, (row, col), layer=2) diff --git a/setup.py b/setup.py index d517c279cb5d31a09174b71419b3615160abc39a..39bffd173a000a1fa40018cc5b856c9a26e2e253 100644 --- a/setup.py +++ b/setup.py @@ -63,12 +63,12 @@ else: def get_all_svg_files(directory='./svg/'): ret = [] for f in os.listdir(directory): - ret.append(directory + f) + if f != '__pycache__': + ret.append(directory + f) return ret # Gather requirements from requirements_dev.txt -# TODO : We could potentially split up the test/dev dependencies later install_reqs = [] requirements_path = 'requirements_dev.txt' with open(requirements_path, 'r') as f: diff --git a/svg/__init__.py b/svg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391