Skip to content
Snippets Groups Projects
Commit 483792a3 authored by u214892's avatar u214892
Browse files

#57 PILSvg from resource

parent aa91c6ff
No related branches found
No related tags found
No related merge requests found
# -*- coding: utf-8 -*-
"""Top-level package for flatland."""
__author__ = """S.P. Mohanty"""
__email__ = 'mohanty@aicrowd.com'
__version__ = '0.1.1'
...@@ -29,7 +29,7 @@ class PredictionBuilder: ...@@ -29,7 +29,7 @@ class PredictionBuilder:
def get(self, handle=0): def get(self, handle=0):
""" """
Called whenever step_prediction is called on the environment. Called whenever predict is called on the environment.
Parameters Parameters
------- -------
......
...@@ -18,7 +18,7 @@ class DummyPredictorForRailEnv(PredictionBuilder): ...@@ -18,7 +18,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
def get(self, handle=None): def get(self, handle=None):
""" """
Called whenever step_prediction is called on the environment. Called whenever predict is called on the environment.
Parameters Parameters
------- -------
......
import io import io
import os import os
import site
import time import time
import tkinter as tk import tkinter as tk
import numpy as np import numpy as np
from PIL import Image, ImageDraw, ImageTk # , ImageFont from PIL import Image, ImageDraw, ImageTk # , ImageFont
from numpy import array from numpy import array
from pkg_resources import resource_string as resource_bytes
from flatland.utils.graphics_layer import GraphicsLayer from flatland.utils.graphics_layer import GraphicsLayer
...@@ -239,18 +239,9 @@ class PILSVG(PILGL): ...@@ -239,18 +239,9 @@ class PILSVG(PILGL):
self.lwAgents = [] self.lwAgents = []
self.agents_prev = [] self.agents_prev = []
def pilFromSvgFile(self, sfPath): def pilFromSvgFile(self, package, resource):
try: bytestring = resource_bytes(package, resource)
with open(sfPath, "r") as fIn: bytesPNG = svg2png(bytestring=bytestring, output_height=self.nPixCell, output_width=self.nPixCell)
bytesPNG = svg2png(file_obj=fIn, output_height=self.nPixCell, output_width=self.nPixCell)
except: # noqa: E722
newList = ''
for directory in site.getsitepackages():
x = [word for word in os.listdir(directory) if word.startswith('flatland')]
if len(x) > 0:
newList = directory + '/' + x[0]
with open(newList + '/' + sfPath, "r") as fIn:
bytesPNG = svg2png(file_obj=fIn, output_height=self.nPixCell, output_width=self.nPixCell)
with io.BytesIO(bytesPNG) as fIn: with io.BytesIO(bytesPNG) as fIn:
pil_img = Image.open(fIn) pil_img = Image.open(fIn)
pil_img.load() pil_img.load()
...@@ -313,10 +304,7 @@ class PILSVG(PILGL): ...@@ -313,10 +304,7 @@ class PILSVG(PILGL):
lDirs = list("NESW") lDirs = list("NESW")
# svgBG = SVG("./svg/Background_#91D1DD.svg")
for sTrans, sFile in dDirFile.items(): for sTrans, sFile in dDirFile.items():
sPathSvg = "./svg/" + sFile
# Translate the ascii transition description in the format "NE WS" to the # Translate the ascii transition description in the format "NE WS" to the
# binary list of transitions as per RailEnv - NESW (in) x NESW (out) # binary list of transitions as per RailEnv - NESW (in) x NESW (out)
...@@ -330,7 +318,7 @@ class PILSVG(PILGL): ...@@ -330,7 +318,7 @@ class PILSVG(PILGL):
sTrans16 = "".join(lTrans16) sTrans16 = "".join(lTrans16)
binTrans = int(sTrans16, 2) binTrans = int(sTrans16, 2)
pilRail = self.pilFromSvgFile(sPathSvg) pilRail = self.pilFromSvgFile('svg', sFile)
if rotate: if rotate:
# For rotations, we also store the base image # For rotations, we also store the base image
...@@ -367,7 +355,7 @@ class PILSVG(PILGL): ...@@ -367,7 +355,7 @@ class PILSVG(PILGL):
print("Illegal target rail:", row, col, format(binTrans, "#018b")[2:]) print("Illegal target rail:", row, col, format(binTrans, "#018b")[2:])
if isSelected: if isSelected:
svgBG = self.pilFromSvgFile("./svg/Selected_Target.svg") svgBG = self.pilFromSvgFile("svg", "Selected_Target.svg")
self.clear_layer(3, 0) self.clear_layer(3, 0)
self.drawImageRC(svgBG, (row, col), layer=3) self.drawImageRC(svgBG, (row, col), layer=3)
...@@ -390,9 +378,9 @@ class PILSVG(PILGL): ...@@ -390,9 +378,9 @@ class PILSVG(PILGL):
# Seed initial train/zug files indexed by tuple(iDirIn, iDirOut): # Seed initial train/zug files indexed by tuple(iDirIn, iDirOut):
dDirsFile = { dDirsFile = {
(0, 0): "svg/Zug_Gleis_#0091ea.svg", (0, 0): "Zug_Gleis_#0091ea.svg",
(1, 2): "svg/Zug_1_Weiche_#0091ea.svg", (1, 2): "Zug_1_Weiche_#0091ea.svg",
(0, 3): "svg/Zug_2_Weiche_#0091ea.svg" (0, 3): "Zug_2_Weiche_#0091ea.svg"
} }
# "paint" color of the train images we load # "paint" color of the train images we load
...@@ -403,7 +391,7 @@ class PILSVG(PILGL): ...@@ -403,7 +391,7 @@ class PILSVG(PILGL):
for tDirs, sPathSvg in dDirsFile.items(): for tDirs, sPathSvg in dDirsFile.items():
iDirIn, iDirOut = tDirs iDirIn, iDirOut = tDirs
pilZug = self.pilFromSvgFile(sPathSvg) pilZug = self.pilFromSvgFile("svg", sPathSvg)
# Rotate both the directions and the image and save in the dict # Rotate both the directions and the image and save in the dict
for iDirRot in range(4): for iDirRot in range(4):
...@@ -429,7 +417,7 @@ class PILSVG(PILGL): ...@@ -429,7 +417,7 @@ class PILSVG(PILGL):
self.drawImageRC(pilZug, (row, col), layer=1) self.drawImageRC(pilZug, (row, col), layer=1)
if isSelected: if isSelected:
svgBG = self.pilFromSvgFile("./svg/Selected_Agent.svg") svgBG = self.pilFromSvgFile("svg", "Selected_Agent.svg")
self.clear_layer(2, 0) self.clear_layer(2, 0)
self.drawImageRC(svgBG, (row, col), layer=2) self.drawImageRC(svgBG, (row, col), layer=2)
......
...@@ -63,12 +63,12 @@ else: ...@@ -63,12 +63,12 @@ else:
def get_all_svg_files(directory='./svg/'): def get_all_svg_files(directory='./svg/'):
ret = [] ret = []
for f in os.listdir(directory): for f in os.listdir(directory):
ret.append(directory + f) if f != '__pycache__':
ret.append(directory + f)
return ret return ret
# Gather requirements from requirements_dev.txt # Gather requirements from requirements_dev.txt
# TODO : We could potentially split up the test/dev dependencies later
install_reqs = [] install_reqs = []
requirements_path = 'requirements_dev.txt' requirements_path = 'requirements_dev.txt'
with open(requirements_path, 'r') as f: with open(requirements_path, 'r') as f:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment