Skip to content
Snippets Groups Projects
Commit f0d0db79 authored by hagrid67's avatar hagrid67
Browse files

fixed various editor bugs - issue #22

parent 379b7e48
No related branches found
No related tags found
No related merge requests found
...@@ -15,7 +15,7 @@ import os ...@@ -15,7 +15,7 @@ import os
# from ipywidgets import IntSlider, link, VBox # from ipywidgets import IntSlider, link, VBox
from flatland.envs.rail_env import RailEnv, random_rail_generator from flatland.envs.rail_env import RailEnv, random_rail_generator
from flatland.envs.generators import complex_rail_generator from flatland.envs.generators import complex_rail_generator, empty_rail_generator
# from flatland.core.transitions import RailEnvTransitions # from flatland.core.transitions import RailEnvTransitions
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
import flatland.utils.rendertools as rt import flatland.utils.rendertools as rt
...@@ -60,7 +60,7 @@ class View(object): ...@@ -60,7 +60,7 @@ class View(object):
self.new_env() self.new_env()
self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray", show=False) self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray", show=False)
img = self.oRT.getImage() img = self.oRT.getImage()
plt.clf() plt.clf() # TODO: remove this plt.clf() call
self.wImage = jpy_canvas.Canvas(img) self.wImage = jpy_canvas.Canvas(img)
self.yxSize = self.wImage.data.shape[:2] self.yxSize = self.wImage.data.shape[:2]
self.writableData = np.copy(self.wImage.data) # writable copy of image - wid_img.data is somehow readonly self.writableData = np.copy(self.wImage.data) # writable copy of image - wid_img.data is somehow readonly
...@@ -86,6 +86,9 @@ class View(object): ...@@ -86,6 +86,9 @@ class View(object):
self.wDebug_move = Checkbox(description="Debug mouse move") self.wDebug_move = Checkbox(description="Debug mouse move")
self.wDebug_move.observe(self.controller.setDebugMove, names="value") self.wDebug_move.observe(self.controller.setDebugMove, names="value")
# Checkbox for rendering observations
self.wShowObs = Checkbox(description="Show Agent Observations")
# This is like a cell widget where loggin goes # This is like a cell widget where loggin goes
self.wOutput = Output() self.wOutput = Output()
...@@ -95,13 +98,15 @@ class View(object): ...@@ -95,13 +98,15 @@ class View(object):
self.wFilename.observe(self.controller.setFilename, names="value") self.wFilename.observe(self.controller.setFilename, names="value")
# Size of environment when regenerating # Size of environment when regenerating
self.wSize = IntSlider(value=10, min=5, max=30, step=5, description="Regen Size") self.wRegenSize = IntSlider(value=10, min=5, max=100, step=5, description="Regen Size",
self.wSize.observe(self.controller.setRegenSize, names="value") tip="Click Regenerate after changing this")
self.wRegenSize.observe(self.controller.setRegenSize, names="value")
# Number of Agents when regenerating # Number of Agents when regenerating
self.wNAgents = IntSlider(value=1, min=0, max=20, step=1, description="# Agents") self.wRegenNAgents = IntSlider(value=1, min=0, max=20, step=1, description="# Agents",
tip="Click regenerate or reset after changing this")
self.wRegenMethod = RadioButtons(description="Regen\nMethod", options=["Random Cell", "Path-based"]) self.wRegenMethod = RadioButtons(description="Regen\nMethod", options=["Empty", "Random Cell", "Path-based"])
self.wReplaceAgents = Checkbox(value=True, description="Replace Agents") self.wReplaceAgents = Checkbox(value=True, description="Replace Agents")
self.wTab = Tab() self.wTab = Tab()
...@@ -109,8 +114,8 @@ class View(object): ...@@ -109,8 +114,8 @@ class View(object):
for i, title in enumerate(tab_contents): for i, title in enumerate(tab_contents):
self.wTab.set_title(i, title) self.wTab.set_title(i, title)
self.wTab.children = [ self.wTab.children = [
VBox([self.wDebug, self.wDebug_move]), VBox([self.wDebug, self.wDebug_move, self.wShowObs]),
VBox([self.wRegenMethod, self.wReplaceAgents])] VBox([self.wRegenSize, self.wRegenNAgents, self.wRegenMethod, self.wReplaceAgents])]
# Progress bar intended for stepping in the background (not yet working) # Progress bar intended for stepping in the background (not yet working)
self.wProg_steps = ipywidgets.IntProgress(value=0, min=0, max=20, step=1, description="Step") self.wProg_steps = ipywidgets.IntProgress(value=0, min=0, max=20, step=1, description="Step")
...@@ -140,8 +145,8 @@ class View(object): ...@@ -140,8 +145,8 @@ class View(object):
self.wVbox_controls = VBox([ self.wVbox_controls = VBox([
self.wFilename, # self.wDrawMode, self.wFilename, # self.wDrawMode,
*self.lwButtons, *self.lwButtons,
self.wSize, # self.wRegenSize,
self.wNAgents, # self.wRegenNAgents,
self.wProg_steps, self.wProg_steps,
self.wTab]) self.wTab])
...@@ -161,13 +166,17 @@ class View(object): ...@@ -161,13 +166,17 @@ class View(object):
with self.wOutput: with self.wOutput:
# plt.figure(figsize=(10, 10)) # plt.figure(figsize=(10, 10))
self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray", self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray",
show=False, iSelectedAgent=self.model.iSelectedAgent) show=False, iSelectedAgent=self.model.iSelectedAgent,
show_observations=self.show_observations())
img = self.oRT.getImage() img = self.oRT.getImage()
# plt.clf() # plt.clf()
# plt.close() # plt.close()
self.wImage.data = img self.wImage.data = img
self.writableData = np.copy(self.wImage.data) self.writableData = np.copy(self.wImage.data)
# the size should only be updated on regenerate at most
self.yxSize = self.wImage.data.shape[:2]
return img return img
def redisplayImage(self): def redisplayImage(self):
...@@ -191,6 +200,13 @@ class View(object): ...@@ -191,6 +200,13 @@ class View(object):
else: else:
print(*args, **kwargs) print(*args, **kwargs)
def show_observations(self):
''' returns whether to show observations - boolean '''
if self.wShowObs.value:
return True
else:
return False
class Controller(object): class Controller(object):
""" """
...@@ -297,17 +313,17 @@ class Controller(object): ...@@ -297,17 +313,17 @@ class Controller(object):
self.model.clear() self.model.clear()
def reset(self, event): def reset(self, event):
self.log("Reset - nAgents:", self.view.wNAgents.value) self.log("Reset - nAgents:", self.view.wRegenNAgents.value)
self.model.reset(replace_agents=self.view.wReplaceAgents.value, self.model.reset(replace_agents=self.view.wReplaceAgents.value,
nAgents=self.view.wNAgents.value) nAgents=self.view.wRegenNAgents.value)
def restartAgents(self, event): def restartAgents(self, event):
self.log("Restart Agents - nAgents:", self.view.wNAgents.value) self.log("Restart Agents - nAgents:", self.view.wRegenNAgents.value)
self.model.restartAgents() self.model.restartAgents()
def regenerate(self, event): def regenerate(self, event):
method = self.view.wRegenMethod.value method = self.view.wRegenMethod.value
nAgents = self.view.wNAgents.value nAgents = self.view.wRegenNAgents.value
self.model.regenerate(method, nAgents) self.model.regenerate(method, nAgents)
def setRegenSize(self, event): def setRegenSize(self, event):
...@@ -375,6 +391,43 @@ class EditorModel(object): ...@@ -375,6 +391,43 @@ class EditorModel(object):
def setDrawMode(self, sDrawMode): def setDrawMode(self, sDrawMode):
self.drawMode = sDrawMode self.drawMode = sDrawMode
def interpolate_path(self, rcLast, rcCell):
if np.array_equal(rcLast, rcCell):
return []
rcLast = array(rcLast)
rcCell = array(rcCell)
rcDelta = rcCell - rcLast
lrcInterp = [] # extra row,col points
if np.any(np.abs(rcDelta) >= 1):
iDim0 = np.argmax(np.abs(rcDelta)) # the dimension with the bigger move
iDim1 = 1 - iDim0 # the dim with the smaller move
rcRatio = rcDelta[iDim1] / rcDelta[iDim0]
delta0 = rcDelta[iDim0]
sgn0 = np.sign(delta0)
iDelta1 = 0
# count integers along the larger dimension
for iDelta0 in range(sgn0, delta0 + sgn0, sgn0):
rDelta1 = iDelta0 * rcRatio
if np.abs(rDelta1 - iDelta1) >= 1:
rcInterp = (iDelta0, iDelta1) # fill in the "corner" for "Manhattan interpolation"
lrcInterp.append(rcInterp)
iDelta1 = int(rDelta1)
rcInterp = (iDelta0, int(rDelta1))
lrcInterp.append(rcInterp)
g2Interp = array(lrcInterp)
if iDim0 == 1: # if necessary, swap c,r to make r,c
g2Interp = g2Interp[:, [1, 0]]
g2Interp += rcLast
# Convert the array to a list of tuples
lrcInterp = list(map(tuple, g2Interp))
return lrcInterp
def drag_path_element(self, rcCell): def drag_path_element(self, rcCell):
"""Mouse motion event handler for drawing. """Mouse motion event handler for drawing.
""" """
...@@ -384,8 +437,9 @@ class EditorModel(object): ...@@ -384,8 +437,9 @@ class EditorModel(object):
if len(lrcStroke) > 0: if len(lrcStroke) > 0:
rcLast = lrcStroke[-1] rcLast = lrcStroke[-1]
if not np.array_equal(rcLast, rcCell): # only save at transition if not np.array_equal(rcLast, rcCell): # only save at transition
lrcStroke.append(rcCell) lrcInterp = self.interpolate_path(rcLast, rcCell)
self.debug("lrcStroke ", len(lrcStroke), rcCell) lrcStroke.extend(lrcInterp)
self.debug("lrcStroke ", len(lrcStroke), rcCell, "interp:", lrcInterp)
else: else:
# This is the first cell in a mouse stroke # This is the first cell in a mouse stroke
...@@ -567,7 +621,9 @@ class EditorModel(object): ...@@ -567,7 +621,9 @@ class EditorModel(object):
def regenerate(self, method=None, nAgents=0): def regenerate(self, method=None, nAgents=0):
self.log("Regenerate size", self.regen_size) self.log("Regenerate size", self.regen_size)
if method is None or method == "Random Cell": if method is None or method == "Empty":
fnMethod = empty_rail_generator()
elif method == "Random Cell":
fnMethod = random_rail_generator(cell_type_relative_proportion=[1] * 11) fnMethod = random_rail_generator(cell_type_relative_proportion=[1] * 11)
else: else:
fnMethod = complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12) fnMethod = complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12)
...@@ -583,6 +639,7 @@ class EditorModel(object): ...@@ -583,6 +639,7 @@ class EditorModel(object):
self.set_env(self.env) self.set_env(self.env)
self.player = Player(self.env) self.player = Player(self.env)
self.view.new_env() self.view.new_env()
# self.view.init_canvas() # Can't do init_canvas - need to keep the same canvas widget!
self.redraw() self.redraw()
def setRegenSize(self, size): def setRegenSize(self, size):
......
...@@ -51,7 +51,7 @@ class GraphicsLayer(object): ...@@ -51,7 +51,7 @@ class GraphicsLayer(object):
elif type(color) is tuple: elif type(color) is tuple:
if type(color[0]) is not int: if type(color[0]) is not int:
gcolor = array(color) gcolor = array(color)
color = tuple((gcolor[:4] * 255).astype(int)) color = tuple((gcolor[:3] * 255).astype(int))
else: else:
color = self.tColGrid color = self.tColGrid
......
...@@ -606,7 +606,7 @@ class RenderTool(object): ...@@ -606,7 +606,7 @@ class RenderTool(object):
def renderEnv( def renderEnv(
self, show=False, curves=True, spacing=False, self, show=False, curves=True, spacing=False,
arrows=False, agents=True, obsrender=True, sRailColor="gray", frames=False, arrows=False, agents=True, show_observations=True, sRailColor="gray", frames=False,
iEpisode=None, iStep=None, iEpisode=None, iStep=None,
iSelectedAgent=None, action_dict=None): iSelectedAgent=None, action_dict=None):
""" """
...@@ -643,7 +643,7 @@ class RenderTool(object): ...@@ -643,7 +643,7 @@ class RenderTool(object):
# Draw each agent + its orientation + its target # Draw each agent + its orientation + its target
if agents: if agents:
self.plotAgents(targets=True, iSelectedAgent=iSelectedAgent) self.plotAgents(targets=True, iSelectedAgent=iSelectedAgent)
if obsrender: if show_observations:
self.renderObs(range(env.get_num_agents()), env.dev_obs_dict) self.renderObs(range(env.get_num_agents()), env.dev_obs_dict)
# Draw some textual information like fps # Draw some textual information like fps
yText = [-0.3, -0.6, -0.9] yText = [-0.3, -0.6, -0.9]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment