Commit c9967bf8 authored by adrian_egli2's avatar adrian_egli2
Browse files

the performance improvement is ready

parent a3527ac2
Pipeline #11452 passed with stages
in 28 minutes and 8 seconds
......@@ -61,8 +61,8 @@ class Grid4Transitions(Transitions):
# maxsize=None can be used because the number of possible transition is limited (16 bit encoded) and the
# direction/orientation is also limited (2bit). Where the 16bit are only sparse used = number of rail types
# Those methods can be cached -> the are independant of the railways (env)
maxsize_allowed = 128 # if NONE -> unlimted cache size will be used
'''
maxsize_allowed = 128 # if NONE -> unlimted cache size will be used
self.get_transitions = \
lru_cache(maxsize=maxsize_allowed, typed=False)(self.get_transitions)
self.get_transition = \
......@@ -77,7 +77,7 @@ class Grid4Transitions(Transitions):
# These bits represent all the possible dead ends
@staticmethod
#@lru_cache()
@lru_cache()
def maskDeadEnds():
return 0b0010000110000100
......@@ -246,7 +246,7 @@ class Grid4Transitions(Transitions):
return Grid4TransitionsEnum
@staticmethod
#@lru_cache()
@lru_cache()
def has_deadend(cell_transition):
"""
Checks if one entry can only by exited by a turn-around.
......@@ -265,6 +265,6 @@ class Grid4Transitions(Transitions):
return cell_transition
@staticmethod
#@lru_cache()
@lru_cache()
def get_entry_directions(cell_transition) -> List[int]:
return [(cell_transition >> ((3 - orientation) * 4)) & 15 > 0 for orientation in range(4)]
......@@ -321,8 +321,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# Check number of possible transitions for agent and total number of transitions in cell (type)
cell_transitions = self.env.rail.get_transitions(*position, direction)
transition_bit = bin(self.env.rail.get_full_transitions(*position))
total_transitions = fast_argmax(cell_transitions)
total_transitions = transition_bit.count("1")
crossing_found = False
if int(transition_bit, 2) == int('1000010000100001', 2):
crossing_found = True
......
......@@ -136,11 +136,11 @@ class PILGL(GraphicsLayer):
def draw_image_xy(self, pil_img, xyPixLeftTop, layer=RAIL_LAYER, ):
# Resize all PIL images just before drawing them
# to ensure that resizing doesnt affect the
# to ensure that resizing doesnt affect the
# recolorizing strategies in place
#
# That said : All the code in this file needs
# some serious refactoring -_- to ensure the
#
# That said : All the code in this file needs
# some serious refactoring -_- to ensure the
# code style and structure is consitent.
# - Mohanty
pil_img = pil_img.resize(
......@@ -151,7 +151,7 @@ class PILGL(GraphicsLayer):
pil_mask = pil_img
else:
pil_mask = None
self.layers[layer].paste(pil_img, xyPixLeftTop, pil_mask)
def draw_image_row_col(self, pil_img, rcTopLeft, layer=RAIL_LAYER, ):
......@@ -550,7 +550,8 @@ class PILSVG(PILGL):
self.draw_image_row_col(pil_track, (row, col), layer=PILGL.RAIL_LAYER)
else:
print("Illegal rail:", row, col, format(binary_trans, "#018b")[2:], binary_trans)
print("Can't render - illegal rail or SVG element is undefined:", row, col,
format(binary_trans, "#018b")[2:], binary_trans)
if target is not None:
if is_selected:
......@@ -567,7 +568,7 @@ class PILSVG(PILGL):
xy_color_mask = np.all(rgbaImg[:, :, 0:3] - a3BaseColor != 0, axis=2)
else:
xy_color_mask = np.all(rgbaImg[:, :, 0:3] - a3BaseColor == 0, axis=2)
rgbaImg2 = np.copy(rgbaImg)
# Repaint the base color with the new color
......
......@@ -5,18 +5,17 @@ import pprint
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.observations import TreeObsForRailEnv, Node
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.rail_trainrun_data_structures import Waypoint
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.step_utils.states import TrainState
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.step_utils.states import TrainState
"""Test predictions for `flatland` package."""
......@@ -143,7 +142,7 @@ def test_shortest_path_predictor(rendering=False):
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
for _ in range(max([agent.earliest_departure for agent in env.agents])):
env.step({}) # DO_NOTHING for all agents
env.step({}) # DO_NOTHING for all agents
if rendering:
renderer = RenderTool(env, gl="PILSVG")
......@@ -252,7 +251,7 @@ def test_shortest_path_predictor(rendering=False):
"directions {}, expected {}".format(directions, expected_directions)
def test_shortest_path_predictor_conflicts(rendering=True):
def test_shortest_path_predictor_conflicts(rendering=False):
rail, rail_map, optionals = make_invalid_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment