Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • flatland/flatland
  • stefan_otte/flatland
  • jiaodaxiaozi/flatland
  • sfwatergit/flatland
  • utozx126/flatland
  • ChenKuanSun/flatland
  • ashivani/flatland
  • minhhoa/flatland
  • pranjal_dhole/flatland
  • darthgera123/flatland
  • rivesunder/flatland
  • thomaslecat/flatland
  • joel_joseph/flatland
  • kchour/flatland
  • alex_zharichenko/flatland
  • yoogottamk/flatland
  • troye_fang/flatland
  • elrichgro/flatland
  • jun_jin/flatland
  • nimishsantosh107/flatland
20 results
Show changes
Showing
with 950 additions and 518 deletions
...@@ -10,29 +10,45 @@ from numpy import array ...@@ -10,29 +10,45 @@ from numpy import array
import flatland.utils.rendertools as rt import flatland.utils.rendertools as rt
from flatland.core.grid.grid4_utils import mirror from flatland.core.grid.grid4_utils import mirror
from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic from flatland.envs.agent_utils import EnvAgent
from flatland.envs.generators import complex_rail_generator, empty_rail_generator from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv, random_rail_generator from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator, empty_rail_generator
from flatland.envs.persistence import RailEnvPersister
class EditorMVC(object): class EditorMVC(object):
""" EditorMVC - a class to encompass and assemble the Jupyter Editor Model-View-Controller. """ EditorMVC - a class to encompass and assemble the Jupyter Editor Model-View-Controller.
""" """
def __init__(self, env=None, sGL="PIL"): def __init__(self, env=None, sGL="PIL", env_filename="temp.pkl"):
""" Create an Editor MVC assembly around a railenv, or create one if None. """ Create an Editor MVC assembly around a railenv, or create one if None.
""" """
if env is None: if env is None:
env = RailEnv(width=10, nAgents = 3
height=10, n_cities = 2
rail_generator=empty_rail_generator(), max_rails_between_cities = 2
number_of_agents=0, max_rails_in_city = 4
obs_builder_object=TreeObsForRailEnv(max_depth=2)) seed = 0
env = RailEnv(
width=20,
height=30,
rail_generator=sparse_rail_generator(
max_num_cities=n_cities,
seed=seed,
grid_mode=True,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rails_in_city
),
line_generator=sparse_line_generator(),
number_of_agents=nAgents,
obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
)
env.reset() env.reset()
self.editor = EditorModel(env) self.editor = EditorModel(env, env_filename=env_filename)
self.editor.view = self.view = View(self.editor, sGL=sGL) self.editor.view = self.view = View(self.editor, sGL=sGL)
self.view.controller = self.editor.controller = self.controller = Controller(self.editor, self.view) self.view.controller = self.editor.controller = self.controller = Controller(self.editor, self.view)
self.view.init_canvas() self.view.init_canvas()
...@@ -43,19 +59,20 @@ class View(object): ...@@ -43,19 +59,20 @@ class View(object):
""" The Jupyter Editor View - creates and holds the widgets comprising the Editor. """ The Jupyter Editor View - creates and holds the widgets comprising the Editor.
""" """
def __init__(self, editor, sGL="MPL"): def __init__(self, editor, sGL="MPL", screen_width=1200, screen_height=1200):
self.editor = self.model = editor self.editor = self.model = editor
self.sGL = sGL self.sGL = sGL
self.xyScreen = (screen_width, screen_height)
def display(self): def display(self):
self.wOutput.clear_output() self.output_generator.clear_output()
return self.wMain return self.wMain
def init_canvas(self): def init_canvas(self):
# update the rendertool with the env # update the rendertool with the env
self.new_env() self.new_env()
self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray", show=False) self.oRT.render_env(show=False)
img = self.oRT.getImage() img = self.oRT.get_image()
self.wImage = jpy_canvas.Canvas(img) self.wImage = jpy_canvas.Canvas(img)
self.yxSize = self.wImage.data.shape[:2] self.yxSize = self.wImage.data.shape[:2]
self.writableData = np.copy(self.wImage.data) # writable copy of image - wid_img.data is somehow readonly self.writableData = np.copy(self.wImage.data) # writable copy of image - wid_img.data is somehow readonly
...@@ -67,51 +84,51 @@ class View(object): ...@@ -67,51 +84,51 @@ class View(object):
def init_widgets(self): def init_widgets(self):
# Debug checkbox - enable logging in the Output widget # Debug checkbox - enable logging in the Output widget
self.wDebug = ipywidgets.Checkbox(description="Debug") self.debug = ipywidgets.Checkbox(description="Debug")
self.wDebug.observe(self.controller.setDebug, names="value") self.debug.observe(self.controller.set_debug, names="value")
# Separate checkbox for mouse move events - they are very verbose # Separate checkbox for mouse move events - they are very verbose
self.wDebug_move = Checkbox(description="Debug mouse move") self.debug_move = Checkbox(description="Debug mouse move")
self.wDebug_move.observe(self.controller.setDebugMove, names="value") self.debug_move.observe(self.controller.set_debug_move, names="value")
# This is like a cell widget where loggin goes # This is like a cell widget where loggin goes
self.wOutput = Output() self.output_generator = Output()
# Filename textbox # Filename textbox
self.wFilename = Text(description="Filename") self.filename = Text(description="Filename")
self.wFilename.value = self.model.env_filename self.filename.value = self.model.env_filename
self.wFilename.observe(self.controller.setFilename, names="value") self.filename.observe(self.controller.set_filename, names="value")
# Size of environment when regenerating # Size of environment when regenerating
self.wRegenSizeWidth = IntSlider(value=10, min=5, max=100, step=5, description="Regen Size (Width)", self.regen_width = IntSlider(value=10, min=5, max=100, step=5, description="Regen Size (Width)",
tip="Click Regenerate after changing this") tip="Click Regenerate after changing this")
self.wRegenSizeWidth.observe(self.controller.setRegenSizeWidth, names="value") self.regen_width.observe(self.controller.set_regen_width, names="value")
self.wRegenSizeHeight = IntSlider(value=10, min=5, max=100, step=5, description="Regen Size (Height)", self.regen_height = IntSlider(value=10, min=5, max=100, step=5, description="Regen Size (Height)",
tip="Click Regenerate after changing this") tip="Click Regenerate after changing this")
self.wRegenSizeHeight.observe(self.controller.setRegenSizeHeight, names="value") self.regen_height.observe(self.controller.set_regen_height, names="value")
# Number of Agents when regenerating # Number of Agents when regenerating
self.wRegenNAgents = IntSlider(value=1, min=0, max=5, step=1, description="# Agents", self.regen_n_agents = IntSlider(value=1, min=0, max=5, step=1, description="# Agents",
tip="Click regenerate or reset after changing this") tip="Click regenerate or reset after changing this")
self.wRegenMethod = RadioButtons(description="Regen\nMethod", options=["Empty", "Random Cell"]) self.regen_method = RadioButtons(description="Regen\nMethod", options=["Empty", "Sparse"])
self.wReplaceAgents = Checkbox(value=True, description="Replace Agents") self.replace_agents = Checkbox(value=True, description="Replace Agents")
self.wTab = Tab() self.wTab = Tab()
tab_contents = ["Regen", "Observation"] tab_contents = ["Regen", "Observation"]
for i, title in enumerate(tab_contents): for i, title in enumerate(tab_contents):
self.wTab.set_title(i, title) self.wTab.set_title(i, title)
self.wTab.children = [ self.wTab.children = [
VBox([self.wRegenSizeWidth, self.wRegenSizeHeight, self.wRegenNAgents, self.wRegenMethod]) VBox([self.regen_width, self.regen_height, self.regen_n_agents, self.regen_method])
] ]
# abbreviated description of buttons and the methods they call # abbreviated description of buttons and the methods they call
ldButtons = [ ldButtons = [
dict(name="Refresh", method=self.controller.refresh, tip="Redraw only"), dict(name="Refresh", method=self.controller.refresh, tip="Redraw only"),
dict(name="Rotate Agent", method=self.controller.rotate_agent, tip="Rotate selected agent"), dict(name="Rotate Agent", method=self.controller.rotate_agent, tip="Rotate selected agent"),
dict(name="Restart Agents", method=self.controller.restartAgents, dict(name="Restart Agents", method=self.controller.reset_agents,
tip="Move agents back to start positions"), tip="Move agents back to start positions"),
dict(name="Random", method=self.controller.reset, dict(name="Random", method=self.controller.reset,
tip="Generate a randomized scene, including regen rail + agents"), tip="Generate a randomized scene, including regen rail + agents"),
...@@ -119,7 +136,7 @@ class View(object): ...@@ -119,7 +136,7 @@ class View(object):
tip="Regenerate the rails using the method selected below"), tip="Regenerate the rails using the method selected below"),
dict(name="Load", method=self.controller.load), dict(name="Load", method=self.controller.load),
dict(name="Save", method=self.controller.save), dict(name="Save", method=self.controller.save),
dict(name="Save as image", method=self.controller.saveImage) dict(name="Save as image", method=self.controller.save_image)
] ]
self.lwButtons = [] self.lwButtons = []
...@@ -130,35 +147,38 @@ class View(object): ...@@ -130,35 +147,38 @@ class View(object):
self.lwButtons.append(wButton) self.lwButtons.append(wButton)
self.wVbox_controls = VBox([ self.wVbox_controls = VBox([
self.wFilename, self.filename,
*self.lwButtons, *self.lwButtons,
self.wTab]) self.wTab])
self.wMain = HBox([self.wImage, self.wVbox_controls]) self.wMain = HBox([self.wImage, self.wVbox_controls])
def drawStroke(self): def draw_stroke(self):
pass pass
def new_env(self): def new_env(self):
""" Tell the view to update its graphics when a new env is created. """ Tell the view to update its graphics when a new env is created.
""" """
self.oRT = rt.RenderTool(self.editor.env, gl=self.sGL) self.oRT = rt.RenderTool(self.editor.env, gl=self.sGL, show_debug=True,
screen_height=self.xyScreen[1], screen_width=self.xyScreen[0])
def redraw(self): def redraw(self):
with self.wOutput: with self.output_generator:
self.oRT.set_new_rail() self.oRT.set_new_rail()
self.model.env.reset_agents()
self.model.env.agents = self.model.env.agents_static
for a in self.model.env.agents: for a in self.model.env.agents:
if hasattr(a, 'old_position') is False: if hasattr(a, 'old_position') is False:
a.old_position = a.position a.old_position = a.position
if hasattr(a, 'old_direction') is False: if hasattr(a, 'old_direction') is False:
a.old_direction = a.direction a.old_direction = a.direction
self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray", agents=True, self.oRT.render_env(show_agents=True,
show=False, iSelectedAgent=self.model.iSelectedAgent, show_inactive_agents=True,
show_observations=False) show=False,
img = self.oRT.getImage() selected_agent=self.model.selected_agent,
show_observations=False,
)
img = self.oRT.get_image()
self.wImage.data = img self.wImage.data = img
self.writableData = np.copy(self.wImage.data) self.writableData = np.copy(self.wImage.data)
...@@ -167,7 +187,7 @@ class View(object): ...@@ -167,7 +187,7 @@ class View(object):
self.yxSize = self.wImage.data.shape[:2] self.yxSize = self.wImage.data.shape[:2]
return img return img
def redisplayImage(self): def redisplay_image(self):
if self.writableData is not None: if self.writableData is not None:
# This updates the image in the browser to be the new edited version # This updates the image in the browser to be the new edited version
self.wImage.data = self.writableData self.wImage.data = self.writableData
...@@ -178,16 +198,18 @@ class View(object): ...@@ -178,16 +198,18 @@ class View(object):
self.writableData[(y - 2):(y + 2), (x - 2):(x + 2), :3] = 0 self.writableData[(y - 2):(y + 2), (x - 2):(x + 2), :3] = 0
def xy_to_rc(self, x, y): def xy_to_rc(self, x, y):
rcCell = ((array([y, x]) - self.yxBase)) rc_cell = ((array([y, x]) - self.yxBase))
nX = np.floor((self.yxSize[0] - self.yxBase[0]) / self.model.env.height) nX = np.floor((self.yxSize[0] - self.yxBase[0]) / self.model.env.height)
nY = np.floor((self.yxSize[1] - self.yxBase[1]) / self.model.env.width) nY = np.floor((self.yxSize[1] - self.yxBase[1]) / self.model.env.width)
rcCell[0] = max(0, min(np.floor(rcCell[0] / nY), self.model.env.height - 1)) rc_cell[0] = max(0, min(np.floor(rc_cell[0] / nY), self.model.env.height - 1))
rcCell[1] = max(0, min(np.floor(rcCell[1] / nX), self.model.env.width - 1)) rc_cell[1] = max(0, min(np.floor(rc_cell[1] / nX), self.model.env.width - 1))
return rcCell
# Using numpy arrays for coords not currently supported downstream in the env, observations, etc
return tuple(rc_cell)
def log(self, *args, **kwargs): def log(self, *args, **kwargs):
if self.wOutput: if self.output_generator:
with self.wOutput: with self.output_generator:
print(*args, **kwargs) print(*args, **kwargs)
else: else:
print(*args, **kwargs) print(*args, **kwargs)
...@@ -204,10 +226,10 @@ class Controller(object): ...@@ -204,10 +226,10 @@ class Controller(object):
def __init__(self, model, view): def __init__(self, model, view):
self.editor = self.model = model self.editor = self.model = model
self.view = view self.view = view
self.qEvents = deque() self.q_events = deque()
self.drawMode = "Draw" self.drawMode = "Draw"
def setModel(self, model): def set_model(self, model):
self.model = model self.model = model
def on_click(self, wid, event): def on_click(self, wid, event):
...@@ -215,38 +237,38 @@ class Controller(object): ...@@ -215,38 +237,38 @@ class Controller(object):
y = event['canvasY'] y = event['canvasY']
self.debug("debug:", x, y) self.debug("debug:", x, y)
rcCell = self.view.xy_to_rc(x, y) rc_cell = self.view.xy_to_rc(x, y)
bShift = event["shiftKey"] bShift = event["shiftKey"]
bCtrl = event["ctrlKey"] bCtrl = event["ctrlKey"]
bAlt = event["altKey"] bAlt = event["altKey"]
if bCtrl and not bShift and not bAlt: if bCtrl and not bShift and not bAlt:
self.model.click_agent(rcCell) self.model.click_agent(rc_cell)
self.lrcStroke = [] self.lrcStroke = []
elif bShift and bCtrl: elif bShift and bCtrl:
self.model.add_target(rcCell) self.model.add_target(rc_cell)
self.lrcStroke = [] self.lrcStroke = []
elif bAlt and not bShift and not bCtrl: elif bAlt and not bShift and not bCtrl:
self.model.clearCell(rcCell) self.model.clear_cell(rc_cell)
self.lrcStroke = [] self.lrcStroke = []
self.debug("click in cell", rcCell) self.debug("click in cell", rc_cell)
self.model.debug_cell(rcCell) self.model.debug_cell(rc_cell)
if self.model.iSelectedAgent is not None: if self.model.selected_agent is not None:
self.lrcStroke = [] self.lrcStroke = []
def setDebug(self, dEvent): def set_debug(self, event):
self.model.setDebug(dEvent["new"]) self.model.set_debug(event["new"])
def setDebugMove(self, dEvent): def set_debug_move(self, event):
self.model.setDebug_move(dEvent["new"]) self.model.set_debug_move(event["new"])
def setDrawMode(self, dEvent): def set_draw_mode(self, event):
self.drawMode = dEvent["new"] self.set_draw_mode = event["new"]
def setFilename(self, event): def set_filename(self, event):
self.model.setFilename(event["new"]) self.model.set_filename(event["new"])
def on_mouse_move(self, wid, event): def on_mouse_move(self, wid, event):
"""Mouse motion event handler for drawing. """Mouse motion event handler for drawing.
...@@ -254,60 +276,63 @@ class Controller(object): ...@@ -254,60 +276,63 @@ class Controller(object):
x = event['canvasX'] x = event['canvasX']
y = event['canvasY'] y = event['canvasY']
qEvents = self.qEvents q_events = self.q_events
if self.model.bDebug and (event["buttons"] > 0 or self.model.bDebug_move): if self.model.debug_bool and (event["buttons"] > 0 or self.model.debug_move_bool):
self.debug("debug:", len(qEvents), event) self.debug("debug:", len(q_events), event)
# If the mouse is held down, enqueue an event in our own queue # If the mouse is held down, enqueue an event in our own queue
# The intention was to avoid too many redraws. # The intention was to avoid too many redraws.
# Reset the lrcStroke list, if ALT, CTRL or SHIFT pressed # Reset the lrcStroke list, if ALT, CTRL or SHIFT pressed
if event["buttons"] > 0: if event["buttons"] > 0:
qEvents.append((time.time(), x, y)) q_events.append((time.time(), x, y))
bShift = event["shiftKey"] bShift = event["shiftKey"]
bCtrl = event["ctrlKey"] bCtrl = event["ctrlKey"]
bAlt = event["altKey"] bAlt = event["altKey"]
if bShift: if bShift:
self.lrcStroke = [] self.lrcStroke = []
while len(qEvents) > 0: while len(q_events) > 0:
t, x, y = qEvents.popleft() t, x, y = q_events.popleft()
return return
if bCtrl: if bCtrl:
self.lrcStroke = [] self.lrcStroke = []
while len(qEvents) > 0: while len(q_events) > 0:
t, x, y = qEvents.popleft() t, x, y = q_events.popleft()
return return
if bAlt: if bAlt:
self.lrcStroke = [] self.lrcStroke = []
while len(qEvents) > 0: while len(q_events) > 0:
t, x, y = qEvents.popleft() t, x, y = q_events.popleft()
return return
else: else:
self.lrcStroke = [] self.lrcStroke = []
if self.model.iSelectedAgent is not None: # JW: I think this clause causes all editing to fail once an agent is selected.
self.lrcStroke = [] # I also can't see why it's necessary. So I've if-falsed it out.
while len(qEvents) > 0: if False:
t, x, y = qEvents.popleft() if self.model.selected_agent is not None:
return self.lrcStroke = []
while len(q_events) > 0:
t, x, y = q_events.popleft()
return
# Process the events in our queue: # Process the events in our queue:
# Draw a black square to indicate a trail # Draw a black square to indicate a trail
# Convert the xy position to a cell rc # Convert the xy position to a cell rc
# Enqueue transitions across cells in another queue # Enqueue transitions across cells in another queue
if len(qEvents) > 0: if len(q_events) > 0:
tNow = time.time() t_now = time.time()
if tNow - qEvents[0][0] > 0.1: # wait before trying to draw if t_now - q_events[0][0] > 0.1: # wait before trying to draw
while len(qEvents) > 0: while len(q_events) > 0:
t, x, y = qEvents.popleft() # get events from our queue t, x, y = q_events.popleft() # get events from our queue
self.view.drag_path_element(x, y) self.view.drag_path_element(x, y)
# Translate and scale from x,y to integer row,col (note order change) # Translate and scale from x,y to integer row,col (note order change)
rcCell = self.view.xy_to_rc(x, y) rc_cell = self.view.xy_to_rc(x, y)
self.editor.drag_path_element(rcCell) self.editor.drag_path_element(rc_cell)
self.view.redisplayImage() self.view.redisplay_image()
else: else:
self.model.mod_path(not event["shiftKey"]) self.model.mod_path(not event["shiftKey"])
...@@ -320,44 +345,39 @@ class Controller(object): ...@@ -320,44 +345,39 @@ class Controller(object):
self.model.clear() self.model.clear()
def reset(self, event): def reset(self, event):
self.log("Reset - nAgents:", self.view.wRegenNAgents.value) self.log("Reset - nAgents:", self.view.regen_n_agents.value)
self.log("Reset - size:", self.model.regen_size_width) self.log("Reset - size:", self.model.regen_size_width)
self.log("Reset - size:", self.model.regen_size_height) self.log("Reset - size:", self.model.regen_size_height)
self.model.reset(replace_agents=self.view.wReplaceAgents.value, self.model.reset(regenerate_schedule=self.view.replace_agents.value,
nAgents=self.view.wRegenNAgents.value) nAgents=self.view.regen_n_agents.value)
def rotate_agent(self, event): def rotate_agent(self, event):
self.log("Rotate Agent:", self.model.iSelectedAgent) self.log("Rotate Agent:", self.model.selected_agent)
if self.model.iSelectedAgent is not None: if self.model.selected_agent is not None:
for iAgent, agent in enumerate(self.model.env.agents_static): for agent_idx, agent in enumerate(self.model.env.agents):
if agent is None: if agent is None:
continue continue
if iAgent == self.model.iSelectedAgent: if agent_idx == self.model.selected_agent:
agent.direction = (agent.direction + 1) % 4 agent.initial_direction = (agent.initial_direction + 1) % 4
agent.direction = agent.initial_direction
agent.old_direction = agent.direction agent.old_direction = agent.direction
self.model.redraw() self.model.redraw()
def restartAgents(self, event): def reset_agents(self, event):
self.log("Restart Agents - nAgents:", self.view.wRegenNAgents.value) self.log("Restart Agents - nAgents:", self.view.regen_n_agents.value)
if self.model.init_agents_static is not None: self.model.env.reset(False, False)
self.model.env.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in
self.model.init_agents_static]
self.model.env.agents = None
self.model.init_agents_static = None
self.model.env.restart_agents()
self.model.env.reset(False, False)
self.refresh(event) self.refresh(event)
def regenerate(self, event): def regenerate(self, event):
method = self.view.wRegenMethod.value method = self.view.regen_method.value
nAgents = self.view.wRegenNAgents.value n_agents = self.view.regen_n_agents.value
self.model.regenerate(method, nAgents) self.model.regenerate(method, n_agents)
def setRegenSizeWidth(self, event): def set_regen_width(self, event):
self.model.setRegenSizeWidth(event["new"]) self.model.set_regen_width(event["new"])
def setRegenSizeHeight(self, event): def set_regen_height(self, event):
self.model.setRegenSizeHeight(event["new"]) self.model.set_regen_height(event["new"])
def load(self, event): def load(self, event):
self.model.load() self.model.load()
...@@ -365,8 +385,8 @@ class Controller(object): ...@@ -365,8 +385,8 @@ class Controller(object):
def save(self, event): def save(self, event):
self.model.save() self.model.save()
def saveImage(self, event): def save_image(self, event):
self.model.saveImage() self.model.save_image()
def step(self, event): def step(self, event):
self.model.step() self.model.step()
...@@ -382,7 +402,7 @@ class Controller(object): ...@@ -382,7 +402,7 @@ class Controller(object):
class EditorModel(object): class EditorModel(object):
def __init__(self, env): def __init__(self, env, env_filename="temp.pkl"):
self.view = None self.view = None
self.env = env self.env = env
self.regen_size_width = 10 self.regen_size_width = 10
...@@ -392,16 +412,15 @@ class EditorModel(object): ...@@ -392,16 +412,15 @@ class EditorModel(object):
self.iTransLast = -1 self.iTransLast = -1
self.gRCTrans = array([[-1, 0], [0, 1], [1, 0], [0, -1]]) # NESW in RC self.gRCTrans = array([[-1, 0], [0, 1], [1, 0], [0, -1]]) # NESW in RC
self.bDebug = False self.debug_bool = False
self.bDebug_move = False self.debug_move_bool = False
self.wid_output = None self.wid_output = None
self.drawMode = "Draw" self.draw_mode = "Draw"
self.env_filename = "temp.pkl" self.env_filename = env_filename
self.set_env(env) self.set_env(env)
self.iSelectedAgent = None self.selected_agent = None
self.init_agents_static = None
self.thread = None self.thread = None
self.saveImageCnt = 0 self.save_image_count = 0
def set_env(self, env): def set_env(self, env):
""" """
...@@ -409,23 +428,23 @@ class EditorModel(object): ...@@ -409,23 +428,23 @@ class EditorModel(object):
""" """
self.env = env self.env = env
def setDebug(self, bDebug): def set_debug(self, debug):
self.bDebug = bDebug self.debug_bool = debug
self.log("Set Debug:", self.bDebug) self.log("Set Debug:", self.debug_bool)
def setDebugMove(self, bDebug): def set_debug_move(self, debug):
self.bDebug_move = bDebug self.debug_move_bool = debug
self.log("Set DebugMove:", self.bDebug_move) self.log("Set DebugMove:", self.debug_move_bool)
def setDrawMode(self, sDrawMode): def set_draw_mode(self, draw_mode):
self.drawMode = sDrawMode self.draw_mode = draw_mode
def interpolate_path(self, rcLast, rcCell): def interpolate_pair(self, rcLast, rc_cell):
if np.array_equal(rcLast, rcCell): if np.array_equal(rcLast, rc_cell):
return [] return []
rcLast = array(rcLast) rcLast = array(rcLast)
rcCell = array(rcCell) rc_cell = array(rc_cell)
rcDelta = rcCell - rcLast rcDelta = rc_cell - rcLast
lrcInterp = [] # extra row,col points lrcInterp = [] # extra row,col points
...@@ -457,7 +476,16 @@ class EditorModel(object): ...@@ -457,7 +476,16 @@ class EditorModel(object):
lrcInterp = list(map(tuple, g2Interp)) lrcInterp = list(map(tuple, g2Interp))
return lrcInterp return lrcInterp
def drag_path_element(self, rcCell): def interpolate_path(self, lrcPath):
lrcPath2 = [] # interpolated version of the path
rcLast = None
for rcCell in lrcPath:
if rcLast is not None:
lrcPath2.extend(self.interpolate_pair(rcLast, rcCell))
rcLast = rcCell
return lrcPath2
def drag_path_element(self, rc_cell):
"""Mouse motion event handler for drawing. """Mouse motion event handler for drawing.
""" """
lrcStroke = self.lrcStroke lrcStroke = self.lrcStroke
...@@ -465,15 +493,15 @@ class EditorModel(object): ...@@ -465,15 +493,15 @@ class EditorModel(object):
# Store the row,col location of the click, if we have entered a new cell # Store the row,col location of the click, if we have entered a new cell
if len(lrcStroke) > 0: if len(lrcStroke) > 0:
rcLast = lrcStroke[-1] rcLast = lrcStroke[-1]
if not np.array_equal(rcLast, rcCell): # only save at transition if not np.array_equal(rcLast, rc_cell): # only save at transition
lrcInterp = self.interpolate_path(rcLast, rcCell) lrcInterp = self.interpolate_pair(rcLast, rc_cell)
lrcStroke.extend(lrcInterp) lrcStroke.extend(lrcInterp)
self.debug("lrcStroke ", len(lrcStroke), rcCell, "interp:", lrcInterp) self.debug("lrcStroke ", len(lrcStroke), rc_cell, "interp:", lrcInterp)
else: else:
# This is the first cell in a mouse stroke # This is the first cell in a mouse stroke
lrcStroke.append(rcCell) lrcStroke.append(rc_cell)
self.debug("lrcStroke ", len(lrcStroke), rcCell) self.debug("lrcStroke ", len(lrcStroke), rc_cell)
def mod_path(self, bAddRemove): def mod_path(self, bAddRemove):
# disabled functionality (no longer required) # disabled functionality (no longer required)
...@@ -492,6 +520,8 @@ class EditorModel(object): ...@@ -492,6 +520,8 @@ class EditorModel(object):
# If we have already touched 3 cells # If we have already touched 3 cells
# We have a transition into a cell, and out of it. # We have a transition into a cell, and out of it.
#print(lrcStroke)
if len(lrcStroke) >= 2: if len(lrcStroke) >= 2:
# If the first cell in a stroke is empty, add a deadend to cell 0 # If the first cell in a stroke is empty, add a deadend to cell 0
if self.env.rail.get_full_transitions(*lrcStroke[0]) == 0: if self.env.rail.get_full_transitions(*lrcStroke[0]) == 0:
...@@ -500,6 +530,7 @@ class EditorModel(object): ...@@ -500,6 +530,7 @@ class EditorModel(object):
# Add transitions for groups of 3 cells # Add transitions for groups of 3 cells
# hence inbound and outbound transitions for middle cell # hence inbound and outbound transitions for middle cell
while len(lrcStroke) >= 3: while len(lrcStroke) >= 3:
#print(lrcStroke)
self.mod_rail_3cells(lrcStroke, bAddRemove=bAddRemove) self.mod_rail_3cells(lrcStroke, bAddRemove=bAddRemove)
# If final cell empty, insert deadend: # If final cell empty, insert deadend:
...@@ -507,6 +538,8 @@ class EditorModel(object): ...@@ -507,6 +538,8 @@ class EditorModel(object):
if self.env.rail.get_full_transitions(*lrcStroke[1]) == 0: if self.env.rail.get_full_transitions(*lrcStroke[1]) == 0:
self.mod_rail_2cells(lrcStroke, bAddRemove, iCellToMod=1) self.mod_rail_2cells(lrcStroke, bAddRemove, iCellToMod=1)
#print("final:", lrcStroke)
# now empty out the final two cells from the queue # now empty out the final two cells from the queue
lrcStroke.clear() lrcStroke.clear()
...@@ -521,6 +554,7 @@ class EditorModel(object): ...@@ -521,6 +554,7 @@ class EditorModel(object):
eg rcCells [(3,4), (2,4), (2,5)] would result in the transitions eg rcCells [(3,4), (2,4), (2,5)] would result in the transitions
N->E and W->S in cell (2,4). N->E and W->S in cell (2,4).
""" """
rc3Cells = array(lrcStroke[:3]) # the 3 cells rc3Cells = array(lrcStroke[:3]) # the 3 cells
rcMiddle = rc3Cells[1] # the middle cell which we will update rcMiddle = rc3Cells[1] # the middle cell which we will update
bDeadend = np.all(lrcStroke[0] == lrcStroke[2]) # deadend means cell 0 == cell 2 bDeadend = np.all(lrcStroke[0] == lrcStroke[2]) # deadend means cell 0 == cell 2
...@@ -581,6 +615,8 @@ class EditorModel(object): ...@@ -581,6 +615,8 @@ class EditorModel(object):
iTrans = iTrans[0][0] iTrans = iTrans[0][0]
liTrans.append(iTrans) liTrans.append(iTrans)
#self.log("liTrans:", liTrans)
# check that we have one transition # check that we have one transition
if len(liTrans) == 1: if len(liTrans) == 1:
# Set the transition as a deadend # Set the transition as a deadend
...@@ -601,39 +637,38 @@ class EditorModel(object): ...@@ -601,39 +637,38 @@ class EditorModel(object):
def clear(self): def clear(self):
self.env.rail.grid[:, :] = 0 self.env.rail.grid[:, :] = 0
self.env.agents = [] self.env.agents = []
self.env.agents_static = []
self.redraw() self.redraw()
def clearCell(self, rcCell): def clear_cell(self, cell_row_col):
self.debug_cell(rcCell) self.debug_cell(cell_row_col)
self.env.rail.grid[rcCell[0], rcCell[1]] = 0 self.env.rail.grid[cell_row_col[0], cell_row_col[1]] = 0
self.redraw() self.redraw()
def reset(self, replace_agents=False, nAgents=0): def reset(self, regenerate_schedule=False, nAgents=0):
self.regenerate("complex", nAgents=nAgents) self.regenerate("complex", nAgents=nAgents)
self.redraw() self.redraw()
def restartAgents(self): def restart_agents(self):
self.env.agents = EnvAgent.list_from_static(self.env.agents_static) self.env.reset_agents()
self.redraw() self.redraw()
def setFilename(self, filename): def set_filename(self, filename):
self.env_filename = filename self.env_filename = filename
def load(self): def load(self):
if os.path.exists(self.env_filename): if os.path.exists(self.env_filename):
self.log("load file: ", self.env_filename) self.log("load file: ", self.env_filename)
self.env.load(self.env_filename) #self.env.load(self.env_filename)
RailEnvPersister.load(self.env, self.env_filename)
if not self.regen_size_height == self.env.height or not self.regen_size_width == self.env.width: if not self.regen_size_height == self.env.height or not self.regen_size_width == self.env.width:
self.regen_size_height = self.env.height self.regen_size_height = self.env.height
self.regen_size_width = self.env.width self.regen_size_width = self.env.width
self.regenerate(None, 0, self.env) self.regenerate(None, 0, self.env)
self.env.load(self.env_filename) RailEnvPersister.load(self.env, self.env_filename)
self.env.restart_agents() self.env.reset_agents()
self.env.reset(False, False) self.env.reset(False, False)
self.init_agents_static = None
self.view.oRT.update_background() self.view.oRT.update_background()
self.fix_env() self.fix_env()
self.set_env(self.env) self.set_env(self.env)
...@@ -643,16 +678,12 @@ class EditorModel(object): ...@@ -643,16 +678,12 @@ class EditorModel(object):
def save(self): def save(self):
self.log("save to ", self.env_filename, " working dir: ", os.getcwd()) self.log("save to ", self.env_filename, " working dir: ", os.getcwd())
temp_store = self.env.agents #self.env.save(self.env_filename)
# clear agents before save , because we want the "init" position of the agent to expert RailEnvPersister.save(self.env, self.env_filename)
self.env.agents = []
self.env.save(self.env_filename)
# reset agents current (current position)
self.env.agents = temp_store
def saveImage(self): def save_image(self):
self.view.oRT.gl.saveImage('frame_{:04d}.bmp'.format(self.saveImageCnt)) self.view.oRT.gl.save_image('frame_{:04d}.bmp'.format(self.save_image_count))
self.saveImageCnt += 1 self.save_image_count += 1
self.view.redraw() self.view.redraw()
def regenerate(self, method=None, nAgents=0, env=None): def regenerate(self, method=None, nAgents=0, env=None):
...@@ -662,78 +693,90 @@ class EditorModel(object): ...@@ -662,78 +693,90 @@ class EditorModel(object):
if method is None or method == "Empty": if method is None or method == "Empty":
fnMethod = empty_rail_generator() fnMethod = empty_rail_generator()
elif method == "Random Cell":
fnMethod = random_rail_generator(cell_type_relative_proportion=[1] * 11)
else: else:
fnMethod = complex_rail_generator(nr_start_goal=nAgents, nr_extra=20, min_dist=12, seed=int(time.time())) fnMethod = sparse_rail_generator(nr_start_goal=nAgents, nr_extra=20, min_dist=12, seed=int(time.time()))
if env is None: if env is None:
self.env = RailEnv(width=self.regen_size_width, self.env = RailEnv(width=self.regen_size_width, height=self.regen_size_height, rail_generator=fnMethod,
height=self.regen_size_height, number_of_agents=nAgents, obs_builder_object=TreeObsForRailEnv(max_depth=2))
rail_generator=fnMethod,
number_of_agents=nAgents,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
else: else:
self.env = env self.env = env
self.env.reset(regen_rail=True) self.env.reset(regenerate_rail=True)
self.fix_env() self.fix_env()
self.selected_agent = None # clear the selected agent.
self.set_env(self.env) self.set_env(self.env)
self.view.new_env() self.view.new_env()
self.redraw() self.redraw()
def setRegenSizeWidth(self, size): def set_regen_width(self, size):
self.regen_size_width = size self.regen_size_width = size
def setRegenSizeHeight(self, size): def set_regen_height(self, size):
self.regen_size_height = size self.regen_size_height = size
def find_agent_at(self, rcCell): def find_agent_at(self, cell_row_col):
for iAgent, agent in enumerate(self.env.agents_static): for agent_idx, agent in enumerate(self.env.agents):
if tuple(agent.position) == tuple(rcCell): if agent.position is None:
return iAgent rc_pos = agent.initial_position
else:
rc_pos = agent.position
if tuple(rc_pos) == tuple(cell_row_col):
return agent_idx
return None return None
def click_agent(self, rcCell): def click_agent(self, cell_row_col):
""" The user has clicked on a cell - """ The user has clicked on a cell -
- If there is an agent, select it * If there is an agent, select it
- If that agent was already selected, then deselect it * If that agent was already selected, then deselect it
- If there is no agent selected, and no agent in the cell, create one * If there is no agent selected, and no agent in the cell, create one
- If there is an agent selected, and no agent in the cell, move the selected agent to the cell * If there is an agent selected, and no agent in the cell, move the selected agent to the cell
""" """
# Has the user clicked on an existing agent? # Has the user clicked on an existing agent?
iAgent = self.find_agent_at(rcCell) agent_idx = self.find_agent_at(cell_row_col)
# This is in case we still have a selected agent even though the env has been recreated
# with no agents.
if (self.selected_agent is not None) and (self.selected_agent > len(self.env.agents)):
self.selected_agent = None
# Defensive coding below - for cell_row_col to be a tuple, not a numpy array:
# numpy array breaks various things when loading the env.
if iAgent is None: if agent_idx is None:
# No # No
if self.iSelectedAgent is None: if self.selected_agent is None:
# Create a new agent and select it. # Create a new agent and select it.
agent_static = EnvAgentStatic(position=rcCell, direction=0, target=rcCell, moving=False) agent = EnvAgent(initial_position=tuple(cell_row_col),
self.iSelectedAgent = self.env.add_agent_static(agent_static) initial_direction=0,
direction=0,
target=tuple(cell_row_col),
moving=False,
)
self.selected_agent = self.env.add_agent(agent)
# self.env.set_agent_active(agent)
self.view.oRT.update_background() self.view.oRT.update_background()
else: else:
# Move the selected agent to this cell # Move the selected agent to this cell
agent_static = self.env.agents_static[self.iSelectedAgent] agent = self.env.agents[self.selected_agent]
agent_static.position = rcCell agent.initial_position = tuple(cell_row_col)
agent_static.old_position = rcCell agent.position = tuple(cell_row_col)
self.env.agents = [] agent.old_position = tuple(cell_row_col)
else: else:
# Yes # Yes
# Have they clicked on the agent already selected? # Have they clicked on the agent already selected?
if self.iSelectedAgent is not None and iAgent == self.iSelectedAgent: if self.selected_agent is not None and agent_idx == self.selected_agent:
# Yes - deselect the agent # Yes - deselect the agent
self.iSelectedAgent = None self.selected_agent = None
else: else:
# No - select the agent # No - select the agent
self.iSelectedAgent = iAgent self.selected_agent = agent_idx
self.init_agents_static = None
self.redraw() self.redraw()
def add_target(self, rcCell): def add_target(self, rc_cell):
if self.iSelectedAgent is not None: if self.selected_agent is not None:
self.env.agents_static[self.iSelectedAgent].target = rcCell self.env.agents[self.selected_agent].target = tuple(rc_cell)
self.init_agents_static = None
self.view.oRT.update_background() self.view.oRT.update_background()
self.redraw() self.redraw()
...@@ -748,14 +791,14 @@ class EditorModel(object): ...@@ -748,14 +791,14 @@ class EditorModel(object):
self.view.log(*args, **kwargs) self.view.log(*args, **kwargs)
def debug(self, *args, **kwargs): def debug(self, *args, **kwargs):
if self.bDebug: if self.debug_bool:
self.log(*args, **kwargs) self.log(*args, **kwargs)
def debug_cell(self, rcCell): def debug_cell(self, rc_cell):
binTrans = self.env.rail.get_full_transitions(*rcCell) binTrans = self.env.rail.get_full_transitions(*rc_cell)
sbinTrans = format(binTrans, "#018b")[2:] sbinTrans = format(binTrans, "#018b")[2:]
self.debug("cell ", self.debug("cell ",
rcCell, rc_cell,
"Transitions: ", "Transitions: ",
binTrans, binTrans,
sbinTrans, sbinTrans,
......
from numpy.random import RandomState
import flatland.envs.observations as obs
import flatland.envs.rail_generators as rg
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.line_generators import BaseLineGen
from flatland.envs.rail_env import RailEnv
from flatland.envs.timetable_utils import Line
from flatland.utils import editor
# Start and end all agents at the same place
class SchedGen2(BaseLineGen):
def __init__(self, rcStart, rcEnd, iDir):
self.rcStart = rcStart
self.rcEnd = rcEnd
self.iDir = iDir
def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict = None, num_resets: int = None,
np_random: RandomState = None) -> Line:
return Line(agent_positions=[self.rcStart] * num_agents,
agent_directions=[self.iDir] * num_agents,
agent_targets=[self.rcEnd] * num_agents,
agent_speeds=[1.0] * num_agents)
# cycle through lists of start, end and direction
class SchedGen3(BaseLineGen):
def __init__(self, lrcStarts, lrcTargs, liDirs):
self.lrcStarts = lrcStarts
self.lrcTargs = lrcTargs
self.liDirs = liDirs
def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict = None, num_resets: int = None,
np_random: RandomState = None) -> Line:
return Line(agent_positions=[self.lrcStarts[i % len(self.lrcStarts)] for i in range(num_agents)],
agent_directions=[self.liDirs[i % len(self.liDirs)] for i in range(num_agents)],
agent_targets=[self.lrcTargs[i % len(self.lrcTargs)] for i in range(num_agents)],
agent_speeds=[1.0] * num_agents)
def makeEnv(nAg=2, width=20, height=10, oSG=None):
env = RailEnv(width=width, height=height, rail_generator=rg.empty_rail_generator(),
number_of_agents=nAg,
line_generator=oSG,
obs_builder_object=obs.TreeObsForRailEnv(max_depth=1))
envModel = editor.EditorModel(env)
env.reset()
return env, envModel
def makeEnv2(nAg=2, shape=(20, 10), llrcPaths=[], lrcStarts=[], lrcTargs=[], liDirs=[], remove_agents_at_target=True):
oSG = SchedGen3(lrcStarts, lrcTargs, liDirs)
env = RailEnv(width=shape[0], height=shape[1],
rail_generator=rg.empty_rail_generator(),
number_of_agents=nAg,
line_generator=oSG,
obs_builder_object=obs.TreeObsForRailEnv(max_depth=1),
remove_agents_at_target=remove_agents_at_target,
record_steps=True)
envModel = editor.EditorModel(env)
env.reset()
for lrcPath in llrcPaths:
envModel.mod_rail_cell_seq(envModel.interpolate_path(lrcPath))
return env, envModel
ddEnvSpecs = {
# opposing stations with single alternative path
"single_alternative": {
"llrcPaths": [
[(1, 0), (1, 15)], # across the top
[(1, 4), (1, 6), (3, 6), (3, 12), (1, 12), (1, 14)], # alternative loop below
],
"lrcStarts": [(1, 3), (1, 14)],
"lrcTargs": [(1, 14), (1, 3)],
"liDirs": [1, 3]
},
# single spur so one agent needs to wait
"single_spur": {
"llrcPaths": [
[(1, 0), (1, 15)],
[(4, 0), (4, 6), (1, 6), (1, 8)]],
"lrcStarts": [(1, 3), (1, 14)],
"lrcTargs": [(1, 14), (4, 2)],
"liDirs": [1, 3]
},
# single spur so one agent needs to wait
"merging_spurs": {
"llrcPaths": [
[(1, 0), (1, 15), (7, 15), (7, 0)],
[(4, 0), (4, 6), (1, 6), (1, 8)],
# [((1,14), (1,16), (7,16), )]
],
"lrcStarts": [(1, 2), (4, 2)],
"lrcTargs": [(7, 3)],
"liDirs": [1]
},
# Concentric Loops
"concentric_loops": {
"llrcPaths": [
[(1, 1), (1, 5), (8, 5), (8, 1), (1, 1), (1, 3)],
[(1, 3), (1, 10), (8, 10), (8, 3)]
],
"lrcStarts": [(1, 3)],
"lrcTargs": [(2, 1)],
"liDirs": [1]
},
# two loops
"loop_with_loops": {
"llrcPaths": [
# big outer loop Row 1, 8; Col 1, 15
[(1, 1), (1, 15), (8, 15), (8, 1), (1, 1), (1, 3)],
# alternative 1
[(1, 3), (1, 5), (3, 5), (3, 10), (1, 10), (1, 12)],
# alternative 2
[(8, 3), (8, 5), (6, 5), (6, 10), (8, 10), (8, 12)],
],
# list of row,col of agent start cells
"lrcStarts": [(1, 3), (8, 3)],
# list of row,col of targets
"lrcTargs": [(8, 2), (1, 2)],
# list of initial directions
"liDirs": [1, 1],
}
}
def makeTestEnv(sName="single_alternative", nAg=2, remove_agents_at_target=True):
global ddEnvSpecs
dSpec = ddEnvSpecs[sName]
return makeEnv2(nAg=nAg, remove_agents_at_target=remove_agents_at_target, **dSpec)
def getAgentState(env):
dAgState = {}
for iAg, ag in enumerate(env.agents):
dAgState[iAg] = (*ag.position, ag.direction)
return dAgState
...@@ -9,9 +9,6 @@ class GraphicsLayer(object): ...@@ -9,9 +9,6 @@ class GraphicsLayer(object):
def open_window(self): def open_window(self):
pass pass
def is_raster(self):
return True
def plot(self, *args, **kwargs): def plot(self, *args, **kwargs):
pass pass
...@@ -28,24 +25,31 @@ class GraphicsLayer(object): ...@@ -28,24 +25,31 @@ class GraphicsLayer(object):
pass pass
def pause(self, seconds=0.00001): def pause(self, seconds=0.00001):
""" deprecated """
pass
def idle(self, seconds=0.00001):
""" process any display events eg redraw, resize.
Return only after the given number of seconds, ie idle / loop until that number.
"""
pass pass
def clf(self): def clf(self):
pass pass
def beginFrame(self): def begin_frame(self):
pass pass
def endFrame(self): def endFrame(self):
pass pass
def getImage(self): def get_image(self):
pass pass
def saveImage(self, filename): def save_image(self, filename):
pass pass
def adaptColor(self, color, lighten=False): def adapt_color(self, color, lighten=False):
if type(color) is str: if type(color) is str:
if color == "red" or color == "r": if color == "red" or color == "r":
color = (255, 0, 0) color = (255, 0, 0)
...@@ -68,7 +72,7 @@ class GraphicsLayer(object): ...@@ -68,7 +72,7 @@ class GraphicsLayer(object):
def get_cmap(self, *args, **kwargs): def get_cmap(self, *args, **kwargs):
return plt.get_cmap(*args, **kwargs) return plt.get_cmap(*args, **kwargs)
def setRailAt(self, row, col, binTrans, iTarget=None, isSelected=False, rail_grid=None): def set_rail_at(self, row, col, binTrans, iTarget=None, isSelected=False, rail_grid=None, num_agents=None):
""" Set the rail at cell (row, col) to have transitions binTrans. """ Set the rail at cell (row, col) to have transitions binTrans.
The target argument can contain the index of the agent to indicate The target argument can contain the index of the agent to indicate
that agent's target is at that cell, so that a station can be that agent's target is at that cell, so that a station can be
...@@ -76,10 +80,11 @@ class GraphicsLayer(object): ...@@ -76,10 +80,11 @@ class GraphicsLayer(object):
""" """
pass pass
def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut, isSelected=False): def set_agent_at(self, iAgent, row, col, iDirIn, iDirOut, isSelected=False, rail_grid=None, show_debug=False,
clear_debug_text=True):
pass pass
def setCellOccupied(self, iAgent, row, col): def set_cell_occupied(self, iAgent, row, col):
pass pass
def resize(self, env): def resize(self, env):
......
import pyglet as pgl
import time
from PIL import Image
# from numpy import array
# from pkg_resources import resource_string as resource_bytes
# from flatland.utils.graphics_layer import GraphicsLayer
from flatland.utils.graphics_pil import PILSVG
class PGLGL(PILSVG):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.window_open = False # means the window has not yet been opened.
self.close_requested = False # user has clicked
self.closed = False # windows has been closed (currently, we leave the env still running)
def open_window(self):
print("open_window - pyglet")
assert self.window_open is False, "Window is already open!"
self.window = pgl.window.Window(resizable=True, vsync=False, width=1200, height=800)
#self.__class__.window.title("Flatland")
#self.__class__.window.configure(background='grey')
self.window_open = True
@self.window.event
def on_draw():
#print("pyglet draw event")
self.window.clear()
self.show(from_event=True)
#print("pyglet draw event done")
@self.window.event
def on_resize(width, height):
#print(f"The window was resized to {width}, {height}")
self.show(from_event=True)
self.window.dispatch_event("on_draw")
#print("pyglet resize event done")
@self.window.event
def on_close():
self.close_requested = True
def close_window(self):
self.window.close()
self.closed=True
def show(self, block=False, from_event=False):
if not self.window_open:
self.open_window()
if self.close_requested:
if not self.closed:
self.close_window()
return
#tStart = time.time()
self._processEvents()
pil_img = self.alpha_composite_layers()
pil_img_resized = pil_img.resize((self.window.width, self.window.height), resample=Image.NEAREST)
# convert our PIL image to pyglet:
bytes_image = pil_img_resized.tobytes()
pgl_image = pgl.image.ImageData(pil_img_resized.width, pil_img_resized.height,
#self.window.width, self.window.height,
'RGBA',
bytes_image, pitch=-pil_img_resized.width * 4)
pgl_image.blit(0,0)
#tEnd = time.time()
#print("show time: ", tEnd - tStart)
def _processEvents(self):
""" This is the replacement for a custom event loop for Pyglet.
The lines below are typical of Pyglet examples.
Manually resizing the window is still very clunky.
"""
#print("process events...", end="")
pgl.clock.tick()
#for window in pgl.app.windows:
if not self.closed:
self.window.switch_to()
self.window.dispatch_events()
self.window.flip()
#print(" events done")
def idle(self, seconds=0.00001):
tStart = time.time()
tEnd = tStart + seconds
while (time.time() < tEnd):
self._processEvents()
#self.show()
time.sleep(min(seconds, 0.1))
def test_pyglet():
oGL = PGLGL(400,300)
time.sleep(2)
def test_event_loop():
""" Shows how it should work with the standard event loop
Resizing is fairly smooth (ie runs at least 10-20x a second)
"""
window = pgl.window.Window(resizable=True)
pil_img = Image.open("notebooks/simple_example_3.png")
def show():
pil_img_resized = pil_img.resize((window.width, window.height), resample=Image.NEAREST)
bytes_image = pil_img_resized.tobytes()
pgl_image = pgl.image.ImageData(pil_img_resized.width, pil_img_resized.height,
#self.window.width, self.window.height,
'RGBA',
bytes_image, pitch=-pil_img_resized.width * 4)
pgl_image.blit(0,0)
@window.event
def on_draw():
print("pyglet draw event")
window.clear()
show()
print("pyglet draw event done")
@window.event
def on_resize(width, height):
print(f"The window was resized to {width}, {height}")
#show()
print("pyglet resize event done")
@window.event
def on_close():
#self.close_requested = True
print("close")
pgl.app.run()
if __name__=="__main__":
#test_pyglet()
test_event_loop()
\ No newline at end of file
import io import io
import os import os
import platform
import time import time
import tkinter as tk #import tkinter as tk
import numpy as np import numpy as np
from PIL import Image, ImageDraw, ImageTk # , ImageFont from PIL import Image, ImageDraw, ImageFont
from numpy import array from numpy import array
from pkg_resources import resource_string as resource_bytes from pkg_resources import resource_string as resource_bytes
from flatland.utils.graphics_layer import GraphicsLayer from flatland.utils.graphics_layer import GraphicsLayer
def enable_windows_cairo_support():
if os.name == 'nt':
import site
import ctypes.util
default_os_path = os.environ['PATH']
os.environ['PATH'] = ''
for s in site.getsitepackages():
os.environ['PATH'] = os.environ['PATH'] + ';' + s + '\\cairo'
os.environ['PATH'] = os.environ['PATH'] + ';' + default_os_path
if ctypes.util.find_library('cairo') is None:
print("Error: cairo not installed")
enable_windows_cairo_support()
from cairosvg import svg2png # noqa: E402
from screeninfo import get_monitors # noqa: E402
from flatland.core.grid.rail_env_grid import RailEnvTransitions # noqa: E402 from flatland.core.grid.rail_env_grid import RailEnvTransitions # noqa: E402
class PILGL(GraphicsLayer): class PILGL(GraphicsLayer):
# tk.Tk() must be a singleton! # tk.Tk() must be a singleton!
# https://stackoverflow.com/questions/26097811/image-pyimage2-doesnt-exist # https://stackoverflow.com/questions/26097811/image-pyimage2-doesnt-exist
window = tk.Tk() # window = tk.Tk()
RAIL_LAYER = 0
PREDICTION_PATH_LAYER = 1
TARGET_LAYER = 2
AGENT_LAYER = 3
SELECTED_AGENT_LAYER = 4
SELECTED_TARGET_LAYER = 5
def __init__(self, width, height, jupyter=False): def __init__(self, width, height, jupyter=False, screen_width=800, screen_height=600):
self.yxBase = (0, 0) self.yxBase = (0, 0)
self.linewidth = 4 self.linewidth = 4
self.nAgentColors = 1 # overridden in loadAgent self.n_agent_colors = 1 # overridden in loadAgent
self.width = width self.width = width
self.height = height self.height = height
...@@ -48,19 +36,13 @@ class PILGL(GraphicsLayer): ...@@ -48,19 +36,13 @@ class PILGL(GraphicsLayer):
self.background_grid = np.zeros(shape=(self.width, self.height)) self.background_grid = np.zeros(shape=(self.width, self.height))
if jupyter is False: if jupyter is False:
self.screen_width = 800 # NOTE: Currently removed the dependency on
self.screen_height = 600 # screeninfo. We have to find an alternate
# way to compute the screen width and height
if platform.system() == "Windows" or platform.system() == "Linux": # In the meantime, we are harcoding the 800x600
self.screen_width = 9999 # assumption
self.screen_height = 9999 self.screen_width = screen_width
for m in get_monitors(): self.screen_height = screen_height
self.screen_height = min(self.screen_height, m.height)
self.screen_width = min(self.screen_width, m.width)
# Note: screeninfo doesnot have proper support for
# OSX yet, hence the default values of 800,600
# will be used for the same.
w = (self.screen_width - self.width - 10) / (self.width + 1 + self.linewidth) w = (self.screen_width - self.width - 10) / (self.width + 1 + self.linewidth)
h = (self.screen_height - self.height - 10) / (self.height + 1 + self.linewidth) h = (self.screen_height - self.height - 10) / (self.height + 1 + self.linewidth)
self.nPixCell = int(max(1, np.ceil(min(w, h)))) self.nPixCell = int(max(1, np.ceil(min(w, h))))
...@@ -83,15 +65,15 @@ class PILGL(GraphicsLayer): ...@@ -83,15 +65,15 @@ class PILGL(GraphicsLayer):
sColors = "d50000#c51162#aa00ff#6200ea#304ffe#2962ff#0091ea#00b8d4#00bfa5#00c853" + \ sColors = "d50000#c51162#aa00ff#6200ea#304ffe#2962ff#0091ea#00b8d4#00bfa5#00c853" + \
"#64dd17#aeea00#ffd600#ffab00#ff6d00#ff3d00#5d4037#455a64" "#64dd17#aeea00#ffd600#ffab00#ff6d00#ff3d00#5d4037#455a64"
self.agent_colors = [self.rgb_s2i(sColor) for sColor in sColors.split("#")]
self.n_agent_colors = len(self.agent_colors)
self.ltAgentColors = [self.rgb_s2i(sColor) for sColor in sColors.split("#")]
self.nAgentColors = len(self.ltAgentColors)
self.window_open = False
self.firstFrame = True self.firstFrame = True
self.old_background_image = (None, None, None) self.old_background_image = (None, None, None)
self.create_layers() self.create_layers()
self.font = ImageFont.load_default()
def build_background_map(self, dTargets): def build_background_map(self, dTargets):
x = self.old_background_image x = self.old_background_image
rebuild = False rebuild = False
...@@ -109,14 +91,17 @@ class PILGL(GraphicsLayer): ...@@ -109,14 +91,17 @@ class PILGL(GraphicsLayer):
rebuild = True rebuild = True
if rebuild: if rebuild:
# rebuild background_grid to control the visualisation of buildings, trees, mountains, lakes and river
self.background_grid = np.zeros(shape=(self.width, self.height)) self.background_grid = np.zeros(shape=(self.width, self.height))
# build base distance map (distance to targets)
for x in range(self.width): for x in range(self.width):
for y in range(self.height): for y in range(self.height):
distance = int(np.ceil(np.sqrt(self.width ** 2.0 + self.height ** 2.0))) distance = int(np.ceil(np.sqrt(self.width ** 2.0 + self.height ** 2.0)))
for rc in dTargets: for rc in dTargets:
r = rc[1] r = rc[1]
c = rc[0] c = rc[0]
d = int(np.floor(np.sqrt((x - r) ** 2 + (y - c) ** 2))) d = int(np.floor(np.sqrt((x - r) ** 2 + (y - c) ** 2)) / 0.5)
distance = min(d, distance) distance = min(d, distance)
self.background_grid[x][y] = distance self.background_grid[x][y] = distance
...@@ -126,27 +111,42 @@ class PILGL(GraphicsLayer): ...@@ -126,27 +111,42 @@ class PILGL(GraphicsLayer):
""" convert a hex RGB string like 0091ea to 3-tuple of ints """ """ convert a hex RGB string like 0091ea to 3-tuple of ints """
return tuple(int(sRGB[iRGB * 2:iRGB * 2 + 2], 16) for iRGB in [0, 1, 2]) return tuple(int(sRGB[iRGB * 2:iRGB * 2 + 2], 16) for iRGB in [0, 1, 2])
def getAgentColor(self, iAgent): def get_agent_color(self, iAgent):
return self.ltAgentColors[iAgent % self.nAgentColors] return self.agent_colors[iAgent % self.n_agent_colors]
def plot(self, gX, gY, color=None, linewidth=3, layer=0, opacity=255, **kwargs): def plot(self, gX, gY, color=None, linewidth=3, layer=RAIL_LAYER, opacity=255, **kwargs):
color = self.adaptColor(color) """ Draw a line joining the points in gX, GY - each an"""
color = self.adapt_color(color)
if len(color) == 3: if len(color) == 3:
color += (opacity,) color += (opacity,)
elif len(color) == 4: elif len(color) == 4:
color = color[:3] + (opacity,) color = color[:3] + (opacity,)
gPoints = np.stack([array(gX), -array(gY)]).T * self.nPixCell gPoints = np.stack([array(gX), -array(gY)]).T * self.nPixCell
gPoints = list(gPoints.ravel()) gPoints = list(gPoints.ravel())
self.draws[layer].line(gPoints, fill=color, width=self.linewidth) # the width here was self.linewidth - not really sure of the implications
self.draws[layer].line(gPoints, fill=color, width=linewidth)
def scatter(self, gX, gY, color=None, marker="o", s=50, layer=0, opacity=255, *args, **kwargs): def scatter(self, gX, gY, color=None, marker="o", s=50, layer=RAIL_LAYER, opacity=255, *args, **kwargs):
color = self.adaptColor(color) color = self.adapt_color(color)
r = np.sqrt(s) r = np.sqrt(s)
gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.nPixCell gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.nPixCell
for x, y in gPoints: for x, y in gPoints:
self.draws[layer].rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color) self.draws[layer].rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color)
def drawImageXY(self, pil_img, xyPixLeftTop, layer=0): def draw_image_xy(self, pil_img, xyPixLeftTop, layer=RAIL_LAYER, ):
# Resize all PIL images just before drawing them
# to ensure that resizing doesnt affect the
# recolorizing strategies in place
#
# That said : All the code in this file needs
# some serious refactoring -_- to ensure the
# code style and structure is consitent.
# - Mohanty
pil_img = pil_img.resize(
(self.nPixCell, self.nPixCell)
)
if (pil_img.mode == "RGBA"): if (pil_img.mode == "RGBA"):
pil_mask = pil_img pil_mask = pil_img
else: else:
...@@ -154,74 +154,63 @@ class PILGL(GraphicsLayer): ...@@ -154,74 +154,63 @@ class PILGL(GraphicsLayer):
self.layers[layer].paste(pil_img, xyPixLeftTop, pil_mask) self.layers[layer].paste(pil_img, xyPixLeftTop, pil_mask)
def drawImageRC(self, pil_img, rcTopLeft, layer=0): def draw_image_row_col(self, pil_img, rcTopLeft, layer=RAIL_LAYER, ):
xyPixLeftTop = tuple((array(rcTopLeft) * self.nPixCell)[[1, 0]]) xyPixLeftTop = tuple((array(rcTopLeft) * self.nPixCell)[[1, 0]])
self.drawImageXY(pil_img, xyPixLeftTop, layer=layer) self.draw_image_xy(pil_img, xyPixLeftTop, layer=layer)
def open_window(self): def open_window(self):
assert self.window_open is False, "Window is already open!" pass
self.__class__.window.title("Flatland")
self.__class__.window.configure(background='grey')
self.window_open = True
def close_window(self): def close_window(self):
self.panel.destroy()
# quit but not destroy!
self.__class__.window.quit()
def text(self, *args, **kwargs):
pass pass
def text(self, xPx, yPx, strText, layer=RAIL_LAYER):
xyPixLeftTop = (xPx, yPx)
self.draws[layer].text(xyPixLeftTop, strText, font=self.font, fill=(0, 0, 0, 255))
def text_rowcol(self, rcTopLeft, strText, layer=AGENT_LAYER):
xyPixLeftTop = tuple((array(rcTopLeft) * self.nPixCell)[[1, 0]])
self.text(*xyPixLeftTop, strText, layer)
def prettify(self, *args, **kwargs): def prettify(self, *args, **kwargs):
pass pass
def prettify2(self, width, height, cell_size): def prettify2(self, width, height, cell_size):
pass pass
def beginFrame(self): def begin_frame(self):
# Create a new agent layer # Create a new agent layer
self.create_layer(iLayer=1, clear=True) self.create_layer(iLayer=PILGL.AGENT_LAYER, clear=True)
self.create_layer(iLayer=PILGL.PREDICTION_PATH_LAYER, clear=True)
def show(self, block=False): def show(self, block=False):
img = self.alpha_composite_layers() #print("show() - ", self.__class__)
pass
if not self.window_open:
self.open_window()
tkimg = ImageTk.PhotoImage(img)
if self.firstFrame:
# Do TK actions for a new panel (not sure what they really do)
self.panel = tk.Label(self.window, image=tkimg)
self.panel.pack(side="bottom", fill="both", expand="yes")
else:
# update the image in situ
self.panel.configure(image=tkimg)
self.panel.image = tkimg
self.__class__.window.update()
self.firstFrame = False
def pause(self, seconds=0.00001): def pause(self, seconds=0.00001):
pass pass
def idle(self, seconds=0.00001):
pass
def alpha_composite_layers(self): def alpha_composite_layers(self):
img = self.layers[0] img = self.layers[0]
for img2 in self.layers[1:]: for img2 in self.layers[1:]:
img = Image.alpha_composite(img, img2) img = Image.alpha_composite(img, img2)
return img return img
def getImage(self): def get_image(self):
""" return a blended / alpha composited image composed of all the layers, """ return a blended / alpha composited image composed of all the layers,
with layer 0 at the "back". with layer 0 at the "back".
""" """
img = self.alpha_composite_layers() img = self.alpha_composite_layers()
return array(img) return array(img)
def saveImage(self, filename): def save_image(self, filename):
""" """
Renders the current scene into a image file Renders the current scene into a image file
:param filename: filename where to store the rendering output (supported image format *.bmp , .. , *.png) :param filename: filename where to store the rendering output_generator
(supported image format *.bmp , .. , *.png)
""" """
img = self.alpha_composite_layers() img = self.alpha_composite_layers()
img.save(filename) img.save(filename)
...@@ -254,29 +243,33 @@ class PILGL(GraphicsLayer): ...@@ -254,29 +243,33 @@ class PILGL(GraphicsLayer):
self.clear_layer(iLayer) self.clear_layer(iLayer)
def create_layers(self, clear=True): def create_layers(self, clear=True):
self.create_layer(0, clear=clear) # rail / background (scene) self.create_layer(PILGL.RAIL_LAYER, clear=clear) # rail / background (scene)
self.create_layer(1, clear=clear) # agents self.create_layer(PILGL.AGENT_LAYER, clear=clear) # agents
self.create_layer(2, clear=clear) # drawing layer for selected agent self.create_layer(PILGL.TARGET_LAYER, clear=clear) # agents
self.create_layer(3, clear=clear) # drawing layer for selected agent's target self.create_layer(PILGL.PREDICTION_PATH_LAYER, clear=clear) # drawing layer for agent's prediction path
self.create_layer(PILGL.SELECTED_AGENT_LAYER, clear=clear) # drawing layer for selected agent
self.create_layer(PILGL.SELECTED_TARGET_LAYER, clear=clear) # drawing layer for selected agent's target
class PILSVG(PILGL): class PILSVG(PILGL):
def __init__(self, width, height, jupyter=False): """
Note : This class should now ideally be called as PILPNG,
but for backward compatibility, and to not introduce any breaking changes at this point
we are sticking to the legacy name of PILSVG (when in practice we are not using SVG anymore)
"""
def __init__(self, width, height, jupyter=False, screen_width=800, screen_height=600):
oSuper = super() oSuper = super()
oSuper.__init__(width, height, jupyter) oSuper.__init__(width, height, jupyter, screen_width, screen_height)
self.lwAgents = [] self.lwAgents = []
self.agents_prev = [] self.agents_prev = []
self.loadBuildingSVGs() self.load_buildings()
self.loadScenerySVGs() self.load_scenery()
self.loadRailSVGs() self.load_rail()
self.loadAgentSVGs() self.load_agent()
def is_raster(self): def process_events(self):
return False
def processEvents(self):
time.sleep(0.001) time.sleep(0.001)
def clear_rails(self): def clear_rails(self):
...@@ -289,310 +282,395 @@ class PILSVG(PILGL): ...@@ -289,310 +282,395 @@ class PILSVG(PILGL):
self.lwAgents = [] self.lwAgents = []
self.agents_prev = [] self.agents_prev = []
def pilFromSvgFile(self, package, resource): def pil_from_png_file(self, package, resource):
bytestring = resource_bytes(package, resource) bytestring = resource_bytes(package, resource)
bytesPNG = svg2png(bytestring=bytestring, output_height=self.nPixCell, output_width=self.nPixCell) with io.BytesIO(bytestring) as fIn:
with io.BytesIO(bytesPNG) as fIn:
pil_img = Image.open(fIn) pil_img = Image.open(fIn)
pil_img.load() pil_img.load()
return pil_img return pil_img
def pilFromSvgBytes(self, bytesSVG): def load_buildings(self):
bytesPNG = svg2png(bytesSVG, output_height=self.nPixCell, output_width=self.nPixCell) lBuildingFiles = [
with io.BytesIO(bytesPNG) as fIn: "Buildings-Bank.png",
pil_img = Image.open(fIn) "Buildings-Bar.png",
return pil_img "Buildings-Wohnhaus.png",
"Buildings-Hochhaus.png",
def loadBuildingSVGs(self): "Buildings-Hotel.png",
dBuildingFiles = [ "Buildings-Office.png",
"Buildings/Bank.svg", "Buildings-Polizei.png",
"Buildings/Bar.svg", "Buildings-Post.png",
"Buildings/Wohnhaus.svg", "Buildings-Supermarkt.png",
"Buildings/Hochhaus.svg", "Buildings-Tankstelle.png",
"Buildings/Hotel.svg", "Buildings-Fabrik_A.png",
"Buildings/Office.svg", "Buildings-Fabrik_B.png",
"Buildings/Polizei.svg", "Buildings-Fabrik_C.png",
"Buildings/Post.svg", "Buildings-Fabrik_D.png",
"Buildings/Supermarkt.svg", "Buildings-Fabrik_E.png",
"Buildings/Tankstelle.svg", "Buildings-Fabrik_F.png",
"Buildings/Fabrik_A.svg", "Buildings-Fabrik_G.png",
"Buildings/Fabrik_B.svg", "Buildings-Fabrik_H.png",
"Buildings/Fabrik_C.svg", "Buildings-Fabrik_I.png"
"Buildings/Fabrik_D.svg",
"Buildings/Fabrik_E.svg",
"Buildings/Fabrik_F.svg",
"Buildings/Fabrik_G.svg",
"Buildings/Fabrik_H.svg",
"Buildings/Fabrik_I.svg",
] ]
imgBg = self.pilFromSvgFile('svg', "Background_city.svg") imgBg = self.pil_from_png_file('flatland.png', "Background_city.png")
imgBg = imgBg.convert("RGBA")
self.dBuildings = [] self.lBuildings = []
for sFile in dBuildingFiles: for sFile in lBuildingFiles:
img = self.pilFromSvgFile('svg', sFile) img = self.pil_from_png_file('flatland.png', sFile)
img = Image.alpha_composite(imgBg, img) img = Image.alpha_composite(imgBg, img)
self.dBuildings.append(img) self.lBuildings.append(img)
def loadScenerySVGs(self): def load_scenery(self):
dSceneryFiles = [ scenery_files = [
"Scenery/Laubbaume_A.svg", "Scenery-Laubbaume_A.png",
"Scenery/Laubbaume_B.svg", "Scenery-Laubbaume_B.png",
"Scenery/Laubbaume_C.svg", "Scenery-Laubbaume_C.png",
"Scenery/Nadelbaume_A.svg", "Scenery-Nadelbaume_A.png",
"Scenery/Nadelbaume_B.svg", "Scenery-Nadelbaume_B.png",
"Scenery/Bergwelt_B.svg" "Scenery-Bergwelt_B.png"
] ]
dSceneryFilesDim2 = [ scenery_files_d2 = [
"Scenery/Bergwelt_C_Teil_1_links.svg", "Scenery-Bergwelt_C_Teil_1_links.png",
"Scenery/Bergwelt_C_Teil_2_rechts.svg" "Scenery-Bergwelt_C_Teil_2_rechts.png"
] ]
dSceneryFilesDim3 = [ scenery_files_d3 = [
"Scenery/Bergwelt_A_Teil_3_rechts.svg", "Scenery-Bergwelt_A_Teil_1_links.png",
"Scenery/Bergwelt_A_Teil_2_mitte.svg", "Scenery-Bergwelt_A_Teil_2_mitte.png",
"Scenery/Bergwelt_A_Teil_1_links.svg" "Scenery-Bergwelt_A_Teil_3_rechts.png"
] ]
imgBg = self.pilFromSvgFile('svg', "Background_Light_green.svg") scenery_files_water = [
"Scenery_Water.png"
]
self.dScenery = [] img_back_ground = self.pil_from_png_file('flatland.png', "Background_Light_green.png").convert("RGBA")
for sFile in dSceneryFiles:
img = self.pilFromSvgFile('svg', sFile)
img = Image.alpha_composite(imgBg, img)
self.dScenery.append(img)
self.dSceneryDim2 = [] self.scenery_background_white = self.pil_from_png_file('flatland.png', "Background_white.png").convert("RGBA")
for sFile in dSceneryFilesDim2:
img = self.pilFromSvgFile('svg', sFile)
img = Image.alpha_composite(imgBg, img)
self.dSceneryDim2.append(img)
self.dSceneryDim3 = [] self.scenery = []
for sFile in dSceneryFilesDim3: for file in scenery_files:
img = self.pilFromSvgFile('svg', sFile) img = self.pil_from_png_file('flatland.png', file)
img = Image.alpha_composite(imgBg, img) img = Image.alpha_composite(img_back_ground, img)
self.dSceneryDim3.append(img) self.scenery.append(img)
self.scenery_d2 = []
for file in scenery_files_d2:
img = self.pil_from_png_file('flatland.png', file)
img = Image.alpha_composite(img_back_ground, img)
self.scenery_d2.append(img)
def loadRailSVGs(self): self.scenery_d3 = []
for file in scenery_files_d3:
img = self.pil_from_png_file('flatland.png', file)
img = Image.alpha_composite(img_back_ground, img)
self.scenery_d3.append(img)
self.scenery_water = []
for file in scenery_files_water:
img = self.pil_from_png_file('flatland.png', file)
img = Image.alpha_composite(img_back_ground, img)
self.scenery_water.append(img)
def load_rail(self):
""" Load the rail SVG images, apply rotations, and store as PIL images. """ Load the rail SVG images, apply rotations, and store as PIL images.
""" """
dRailFiles = { rail_files = {
"": "Background_Light_green.svg", "": "Background_Light_green.png",
"WE": "Gleis_Deadend.svg", "WE": "Gleis_Deadend.png",
"WW EE NN SS": "Gleis_Diamond_Crossing.svg", "WW EE NN SS": "Gleis_Diamond_Crossing.png",
"WW EE": "Gleis_horizontal.svg", "WW EE": "Gleis_horizontal.png",
"EN SW": "Gleis_Kurve_oben_links.svg", "EN SW": "Gleis_Kurve_oben_links.png",
"WN SE": "Gleis_Kurve_oben_rechts.svg", "WN SE": "Gleis_Kurve_oben_rechts.png",
"ES NW": "Gleis_Kurve_unten_links.svg", "ES NW": "Gleis_Kurve_unten_links.png",
"NE WS": "Gleis_Kurve_unten_rechts.svg", "NE WS": "Gleis_Kurve_unten_rechts.png",
"NN SS": "Gleis_vertikal.svg", "NN SS": "Gleis_vertikal.png",
"NN SS EE WW ES NW SE WN": "Weiche_Double_Slip.svg", "NN SS EE WW ES NW SE WN": "Weiche_Double_Slip.png",
"EE WW EN SW": "Weiche_horizontal_oben_links.svg", "EE WW EN SW": "Weiche_horizontal_oben_links.png",
"EE WW SE WN": "Weiche_horizontal_oben_rechts.svg", "EE WW SE WN": "Weiche_horizontal_oben_rechts.png",
"EE WW ES NW": "Weiche_horizontal_unten_links.svg", "EE WW ES NW": "Weiche_horizontal_unten_links.png",
"EE WW NE WS": "Weiche_horizontal_unten_rechts.svg", "EE WW NE WS": "Weiche_horizontal_unten_rechts.png",
"NN SS EE WW NW ES": "Weiche_Single_Slip.svg", "NN SS EE WW NW ES": "Weiche_Single_Slip.png",
"NE NW ES WS": "Weiche_Symetrical.svg", "NE NW ES WS": "Weiche_Symetrical.png",
"NN SS EN SW": "Weiche_vertikal_oben_links.svg", "NN SS EN SW": "Weiche_vertikal_oben_links.png",
"NN SS SE WN": "Weiche_vertikal_oben_rechts.svg", "NN SS SE WN": "Weiche_vertikal_oben_rechts.png",
"NN SS NW ES": "Weiche_vertikal_unten_links.svg", "NN SS NW ES": "Weiche_vertikal_unten_links.png",
"NN SS NE WS": "Weiche_vertikal_unten_rechts.svg", "NN SS NE WS": "Weiche_vertikal_unten_rechts.png",
"NE NW ES WS SS NN": "Weiche_Symetrical_gerade.svg", "NE NW ES WS SS NN": "Weiche_Symetrical_gerade.png",
"NE EN SW WS": "Gleis_Kurve_oben_links_unten_rechts.svg" "NE EN SW WS": "Gleis_Kurve_oben_links_unten_rechts.png"
} }
dTargetFiles = { target_files = {
"EW": "Bahnhof_#d50000_Deadend_links.svg", "EW": "Bahnhof_#d50000_Deadend_links.png",
"NS": "Bahnhof_#d50000_Deadend_oben.svg", "NS": "Bahnhof_#d50000_Deadend_oben.png",
"WE": "Bahnhof_#d50000_Deadend_rechts.svg", "WE": "Bahnhof_#d50000_Deadend_rechts.png",
"SN": "Bahnhof_#d50000_Deadend_unten.svg", "SN": "Bahnhof_#d50000_Deadend_unten.png",
"EE WW": "Bahnhof_#d50000_Gleis_horizontal.svg", "EE WW": "Bahnhof_#d50000_Gleis_horizontal.png",
"NN SS": "Bahnhof_#d50000_Gleis_vertikal.svg"} "NN SS": "Bahnhof_#d50000_Gleis_vertikal.png"}
# Dict of rail cell images indexed by binary transitions # Dict of rail cell images indexed by binary transitions
dPilRailFiles = self.loadSVGs(dRailFiles, rotate=True, backgroundImage="Background_rail.svg", pil_rail_files_org = self.load_pngs(rail_files, rotate=True)
whitefilter="Background_white_filter.svg") pil_rail_files = self.load_pngs(rail_files, rotate=True, background_image="Background_rail.png",
whitefilter="Background_white_filter.png")
# Load the target files (which have rails and transitions of their own) # Load the target files (which have rails and transitions of their own)
# They are indexed by (binTrans, iAgent), ie a tuple of the binary transition and the agent index # They are indexed by (binTrans, iAgent), ie a tuple of the binary transition and the agent index
dPilTargetFiles = self.loadSVGs(dTargetFiles, rotate=False, agent_colors=self.ltAgentColors, pil_target_files_org = self.load_pngs(target_files, rotate=False, agent_colors=self.agent_colors)
backgroundImage="Background_rail.svg", pil_target_files = self.load_pngs(target_files, rotate=False, agent_colors=self.agent_colors,
whitefilter="Background_white_filter.svg") background_image="Background_rail.png",
whitefilter="Background_white_filter.png")
# Load station and recolorize them # Load station and recolorize them
station = self.pilFromSvgFile("svg", "Bahnhof_#d50000_target.svg") station = self.pil_from_png_file('flatland.png', "Bahnhof_#d50000_target.png")
self.ltStationColors = self.recolorImage(station, [0, 0, 0], self.ltAgentColors, False) self.station_colors = self.recolor_image(station, [0, 0, 0], self.agent_colors, False)
cellOccupied = self.pilFromSvgFile("svg", "Cell_occupied.svg") cell_occupied = self.pil_from_png_file('flatland.png', "Cell_occupied.png")
self.ltCellOccupied = self.recolorImage(cellOccupied, [0, 0, 0], self.ltAgentColors, False) self.cell_occupied = self.recolor_image(cell_occupied, [0, 0, 0], self.agent_colors, False)
# Merge them with the regular rails. # Merge them with the regular rails.
# https://stackoverflow.com/questions/38987/how-to-merge-two-dictionaries-in-a-single-expression # https://stackoverflow.com/questions/38987/how-to-merge-two-dictionaries-in-a-single-expression
self.dPilRail = {**dPilRailFiles, **dPilTargetFiles} self.pil_rail = {**pil_rail_files, **pil_target_files}
self.pil_rail_org = {**pil_rail_files_org, **pil_target_files_org}
def loadSVGs(self, dDirFile, rotate=False, agent_colors=False, backgroundImage=None, whitefilter=None): def load_pngs(self, file_directory, rotate=False, agent_colors=False, background_image=None, whitefilter=None):
dPil = {} pil = {}
transitions = RailEnvTransitions() transitions = RailEnvTransitions()
lDirs = list("NESW") directions = list("NESW")
for sTrans, sFile in dDirFile.items(): for transition, file in file_directory.items():
# Translate the ascii transition description in the format "NE WS" to the # Translate the ascii transition description in the format "NE WS" to the
# binary list of transitions as per RailEnv - NESW (in) x NESW (out) # binary list of transitions as per RailEnv - NESW (in) x NESW (out)
lTrans16 = ["0"] * 16 transition_16_bit = ["0"] * 16
for sTran in sTrans.split(" "): for sTran in transition.split(" "):
if len(sTran) == 2: if len(sTran) == 2:
iDirIn = lDirs.index(sTran[0]) in_direction = directions.index(sTran[0])
iDirOut = lDirs.index(sTran[1]) out_direction = directions.index(sTran[1])
iTrans = 4 * iDirIn + iDirOut transition_idx = 4 * in_direction + out_direction
lTrans16[iTrans] = "1" transition_16_bit[transition_idx] = "1"
sTrans16 = "".join(lTrans16) transition_16_bit_string = "".join(transition_16_bit)
binTrans = int(sTrans16, 2) binary_trans = int(transition_16_bit_string, 2)
pilRail = self.pilFromSvgFile('svg', sFile) pil_rail = self.pil_from_png_file('flatland.png', file).convert("RGBA")
if backgroundImage is not None: if background_image is not None:
imgBg = self.pilFromSvgFile('svg', backgroundImage) img_bg = self.pil_from_png_file('flatland.png', background_image).convert("RGBA")
pilRail = Image.alpha_composite(imgBg, pilRail) pil_rail = Image.alpha_composite(img_bg, pil_rail)
if whitefilter is not None: if whitefilter is not None:
imgBg = self.pilFromSvgFile('svg', whitefilter) img_bg = self.pil_from_png_file('flatland.png', whitefilter).convert("RGBA")
pilRail = Image.alpha_composite(pilRail, imgBg) pil_rail = Image.alpha_composite(pil_rail, img_bg)
if rotate: if rotate:
# For rotations, we also store the base image # For rotations, we also store the base image
dPil[binTrans] = pilRail pil[binary_trans] = pil_rail
# Rotate both the transition binary and the image and save in the dict # Rotate both the transition binary and the image and save in the dict
for nRot in [90, 180, 270]: for nRot in [90, 180, 270]:
binTrans2 = transitions.rotate_transition(binTrans, nRot) binary_trans_2 = transitions.rotate_transition(binary_trans, nRot)
# PIL rotates anticlockwise for positive theta # PIL rotates anticlockwise for positive theta
pilRail2 = pilRail.rotate(-nRot) pil_rail_2 = pil_rail.rotate(-nRot)
dPil[binTrans2] = pilRail2 pil[binary_trans_2] = pil_rail_2
if agent_colors: if agent_colors:
# For recoloring, we don't store the base image. # For recoloring, we don't store the base image.
a3BaseColor = self.rgb_s2i("d50000") base_color = self.rgb_s2i("d50000")
lPils = self.recolorImage(pilRail, a3BaseColor, self.ltAgentColors) pils = self.recolor_image(pil_rail, base_color, self.agent_colors)
for iColor, pilRail2 in enumerate(lPils): for color_idx, pil_rail_2 in enumerate(pils):
dPil[(binTrans, iColor)] = lPils[iColor] pil[(binary_trans, color_idx)] = pils[color_idx]
return dPil return pil
def setRailAt(self, row, col, binTrans, iTarget=None, isSelected=False, rail_grid=None): def set_predicion_path_at(self, row, col, binary_trans, agent_rail_color):
if binTrans in self.dPilRail: colored_rail = self.recolor_image(self.pil_rail_org[binary_trans],
pilTrack = self.dPilRail[binTrans] [61, 61, 61], [agent_rail_color],
if iTarget is not None: False)[0]
pilTrack = Image.alpha_composite(pilTrack, self.ltStationColors[iTarget % len(self.ltStationColors)]) self.draw_image_row_col(colored_rail, (row, col), layer=PILGL.PREDICTION_PATH_LAYER)
if binTrans == 0: def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None, num_agents=None,
if self.background_grid[col][row] <= 4: show_debug=True):
if binary_trans in self.pil_rail:
pil_track = self.pil_rail[binary_trans]
if target is not None:
target_img = self.station_colors[target % len(self.station_colors)]
target_img = Image.alpha_composite(pil_track, target_img)
self.draw_image_row_col(target_img, (row, col), layer=PILGL.TARGET_LAYER)
if show_debug:
self.text_rowcol((row + 0.8, col + 0.0), strText=str(target), layer=PILGL.TARGET_LAYER)
city_size = 1
if num_agents is not None:
city_size = max(1, np.log(1 + num_agents) / 2.5)
if binary_trans == 0:
if self.background_grid[col][row] <= 4 + np.ceil(((col * row + col) % 10) / city_size):
a = int(self.background_grid[col][row]) a = int(self.background_grid[col][row])
a = a % len(self.dBuildings) a = a % len(self.lBuildings)
if (col + row + col * row) % 13 > 11: if (col + row + col * row) % 13 > 11:
pilTrack = self.dScenery[a % len(self.dScenery)] pil_track = self.scenery[a % len(self.scenery)]
else: else:
if (col + row + col * row) % 3 == 0: if (col + row + col * row) % 3 == 0:
a = (a + (col + row + col * row)) % len(self.dBuildings) a = (a + (col + row + col * row)) % len(self.lBuildings)
pilTrack = self.dBuildings[a] pil_track = self.lBuildings[a]
elif (self.background_grid[col][row] > 4) or ((col ** 3 + row ** 2 + col * row) % 10 == 0): elif ((self.background_grid[col][row] > 5 + ((col * row + col) % 3)) or
((col ** 3 + row ** 2 + col * row) % 10 == 0)):
a = int(self.background_grid[col][row]) - 4 a = int(self.background_grid[col][row]) - 4
a2 = (a + (col + row + col * row + col ** 3 + row ** 4)) a2 = (a + (col + row + col * row + col ** 3 + row ** 4))
if a2 % 17 > 11: if a2 % 64 > 11:
a = a2 a = a2
pilTrack = self.dScenery[a % len(self.dScenery)] a_l = a % len(self.scenery)
if a2 % 50 == 49:
self.drawImageRC(pilTrack, (row, col)) pil_track = self.scenery_water[0]
else:
pil_track = self.scenery[a_l]
if rail_grid is not None:
if a2 % 11 > 3:
if a_l == len(self.scenery) - 1:
# mountain
if col > 1 and row % 7 == 1:
if rail_grid[row, col - 1] == 0:
self.draw_image_row_col(self.scenery_d2[0], (row, col - 1),
layer=PILGL.RAIL_LAYER)
pil_track = self.scenery_d2[1]
else:
if a_l == len(self.scenery) - 1:
# mountain
if col > 2 and not (row % 7 == 1):
if rail_grid[row, col - 2] == 0 and rail_grid[row, col - 1] == 0:
self.draw_image_row_col(self.scenery_d3[0], (row, col - 2),
layer=PILGL.RAIL_LAYER)
self.draw_image_row_col(self.scenery_d3[1], (row, col - 1),
layer=PILGL.RAIL_LAYER)
pil_track = self.scenery_d3[2]
self.draw_image_row_col(pil_track, (row, col), layer=PILGL.RAIL_LAYER)
else: else:
print("Illegal rail:", row, col, format(binTrans, "#018b")[2:], binTrans) print("Can't render - illegal rail or SVG element is undefined:", row, col,
format(binary_trans, "#018b")[2:], binary_trans)
if iTarget is not None: if target is not None:
if isSelected: if is_selected:
svgBG = self.pilFromSvgFile("svg", "Selected_Target.svg") svgBG = self.pil_from_png_file('flatland.png', "Selected_Target.png")
self.clear_layer(3, 0) self.clear_layer(PILGL.SELECTED_TARGET_LAYER, 0)
self.drawImageRC(svgBG, (row, col), layer=3) self.draw_image_row_col(svgBG, (row, col), layer=PILGL.SELECTED_TARGET_LAYER)
def recolorImage(self, pil, a3BaseColor, ltColors, invert=False): def recolor_image(self, pil, a3BaseColor, ltColors, invert=False):
rgbaImg = array(pil) rgbaImg = array(pil)
lPils = [] pils = []
for iColor, tnColor in enumerate(ltColors): for iColor, tnColor in enumerate(ltColors):
# find the pixels which match the base paint color # find the pixels which match the base paint color
if invert: if invert:
xy_color_mask = np.all(rgbaImg[:, :, 0:3] - a3BaseColor != 0, axis=2) xy_color_mask = np.all(rgbaImg[:, :, 0:3] - a3BaseColor != 0, axis=2)
else: else:
xy_color_mask = np.all(rgbaImg[:, :, 0:3] - a3BaseColor == 0, axis=2) xy_color_mask = np.all(rgbaImg[:, :, 0:3] - a3BaseColor == 0, axis=2)
rgbaImg2 = np.copy(rgbaImg) rgbaImg2 = np.copy(rgbaImg)
# Repaint the base color with the new color # Repaint the base color with the new color
rgbaImg2[xy_color_mask, 0:3] = tnColor rgbaImg2[xy_color_mask, 0:3] = tnColor
pil2 = Image.fromarray(rgbaImg2) pil2 = Image.fromarray(rgbaImg2)
lPils.append(pil2) pils.append(pil2)
return lPils return pils
def loadAgentSVGs(self): def load_agent(self):
# Seed initial train/zug files indexed by tuple(iDirIn, iDirOut): # Seed initial train/zug files indexed by tuple(iDirIn, iDirOut):
dDirsFile = { file_directory = {
(0, 0): "Zug_Gleis_#0091ea.svg", (0, 0): "Zug_Gleis_#0091ea.png",
(1, 2): "Zug_1_Weiche_#0091ea.svg", (1, 2): "Zug_1_Weiche_#0091ea.png",
(0, 3): "Zug_2_Weiche_#0091ea.svg" (0, 3): "Zug_2_Weiche_#0091ea.png"
} }
# "paint" color of the train images we load - this is the color we will change. # "paint" color of the train images we load - this is the color we will change.
# a3BaseColor = self.rgb_s2i("0091ea") \# noqa: E800 # base_color = self.rgb_s2i("0091ea") \# noqa: E800
# temporary workaround for trains / agents renamed with different colour: # temporary workaround for trains / agents renamed with different colour:
a3BaseColor = self.rgb_s2i("d50000") base_color = self.rgb_s2i("d50000")
self.dPilZug = {} self.pil_zug = {}
for tDirs, sPathSvg in dDirsFile.items(): for directions, path_svg in file_directory.items():
iDirIn, iDirOut = tDirs in_direction, out_direction = directions
pilZug = self.pilFromSvgFile("svg", sPathSvg) pil_zug = self.pil_from_png_file('flatland.png', path_svg)
# Rotate both the directions and the image and save in the dict # Rotate both the directions and the image and save in the dict
for iDirRot in range(4): for rot_direction in range(4):
nDegRot = iDirRot * 90 rotation_degree = rot_direction * 90
iDirIn2 = (iDirIn + iDirRot) % 4 in_direction_2 = (in_direction + rot_direction) % 4
iDirOut2 = (iDirOut + iDirRot) % 4 out_direction_2 = (out_direction + rot_direction) % 4
# PIL rotates anticlockwise for positive theta # PIL rotates anticlockwise for positive theta
pilZug2 = pilZug.rotate(-nDegRot) pil_zug_2 = pil_zug.rotate(-rotation_degree)
# Save colored versions of each rotation / variant # Save colored versions of each rotation / variant
lPils = self.recolorImage(pilZug2, a3BaseColor, self.ltAgentColors) pils = self.recolor_image(pil_zug_2, base_color, self.agent_colors)
for iColor, pilZug3 in enumerate(lPils): for color_idx, pil_zug_3 in enumerate(pils):
self.dPilZug[(iDirIn2, iDirOut2, iColor)] = lPils[iColor] self.pil_zug[(in_direction_2, out_direction_2, color_idx)] = pils[color_idx]
def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut, isSelected): def set_agent_at(self, agent_idx, row, col, in_direction, out_direction, is_selected,
delta_dir = (iDirOut - iDirIn) % 4 rail_grid=None, show_debug=False, clear_debug_text=True, malfunction=False):
iColor = iAgent % self.nAgentColors delta_dir = (out_direction - in_direction) % 4
# when flipping direction at a dead end, use the "iDirOut" direction. color_idx = agent_idx % self.n_agent_colors
# when flipping direction at a dead end, use the "out_direction" direction.
if delta_dir == 2: if delta_dir == 2:
iDirIn = iDirOut in_direction = out_direction
pilZug = self.dPilZug[(iDirIn % 4, iDirOut % 4, iColor)] pil_zug = self.pil_zug[(in_direction % 4, out_direction % 4, color_idx)]
self.drawImageRC(pilZug, (row, col), layer=1) self.draw_image_row_col(pil_zug, (row, col), layer=PILGL.AGENT_LAYER)
if rail_grid is not None:
if rail_grid[row, col] == 0.0:
self.draw_image_row_col(self.scenery_background_white, (row, col), layer=PILGL.RAIL_LAYER)
if is_selected:
bg_svg = self.pil_from_png_file('flatland.png', "Selected_Agent.png")
self.clear_layer(PILGL.SELECTED_AGENT_LAYER, 0)
self.draw_image_row_col(bg_svg, (row, col), layer=PILGL.SELECTED_AGENT_LAYER)
if show_debug:
if not clear_debug_text:
dr = 0.2
dc = 0.2
if in_direction == 0:
dr = 0.8
dc = 0.0
if in_direction == 1:
dr = 0.0
dc = 0.8
if in_direction == 2:
dr = 0.4
dc = 0.8
if in_direction == 3:
dr = 0.8
dc = 0.4
self.text_rowcol((row + dr, col + dc,), str(agent_idx), layer=PILGL.SELECTED_AGENT_LAYER)
else:
self.text_rowcol((row + 0.2, col + 0.2,), str(agent_idx))
if malfunction:
self.draw_malfunction(agent_idx, (row, col))
if isSelected: def set_cell_occupied(self, agent_idx, row, col):
svgBG = self.pilFromSvgFile("svg", "Selected_Agent.svg") occupied_im = self.cell_occupied[agent_idx % len(self.cell_occupied)]
self.clear_layer(2, 0) self.draw_image_row_col(occupied_im, (row, col), 1)
self.drawImageRC(svgBG, (row, col), layer=2)
def setCellOccupied(self, iAgent, row, col): def draw_malfunction(self, agent_idx, rcTopLeft):
occIm = self.ltCellOccupied[iAgent % len(self.ltCellOccupied)] # Roughly an "X" shape to indicate malfunction
self.drawImageRC(occIm, (row, col), 1) grcOffsets = np.array([[0, 0], [1, 1], [1, 0], [0, 1]])
grcPoints = np.array(rcTopLeft)[None] + grcOffsets
gxyPoints = grcPoints[:, [1, 0]]
gxPoints, gyPoints = gxyPoints.T
# print(agent_idx, rcTopLeft, gxyPoints, "X:", gxPoints, "Y:", gyPoints)
# plot(self, gX, gY, color=None, linewidth=3, layer=RAIL_LAYER, opacity=255, **kwargs):
self.plot(gxPoints, -gyPoints, color=(0, 0, 0, 255), layer=PILGL.AGENT_LAYER, linewidth=2)
def main2(): def main2():
gl = PILSVG(10, 10) gl = PILSVG(10, 10)
for i in range(10): for i in range(10):
gl.beginFrame() gl.begin_frame()
gl.plot([3 + i, 4], [-4 - i, -5], color="r") gl.plot([3 + i, 4], [-4 - i, -5], color="r")
gl.endFrame() gl.endFrame()
time.sleep(1) time.sleep(1)
...@@ -602,7 +680,7 @@ def main(): ...@@ -602,7 +680,7 @@ def main():
gl = PILSVG(width=10, height=10) gl = PILSVG(width=10, height=10)
for i in range(1000): for i in range(1000):
gl.processEvents() gl.process_events()
time.sleep(0.1) time.sleep(0.1)
time.sleep(1) time.sleep(1)
......