Commit e4695483 authored by spiglerg's avatar spiglerg
Browse files

cleaned up some deadends references in railenv + fixed pylint errors

parent b604cad1
Pipeline #488 passed with stage
in 2 minutes and 42 seconds
......@@ -3,6 +3,7 @@ import numpy as np
import matplotlib.pyplot as plt
from flatland.envs.rail_env import *
from flatland.envs.generators import *
from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.utils.rendertools import *
......@@ -62,6 +63,11 @@ specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
[(0, 0), (0, 0), (0, 0), (2, 180), (2, 90), (7, 90)],
[(0, 0), (0, 0), (0, 0), (7, 180), (0, 0), (0, 0)]]
# CURVED RAIL + DEAD-ENDS TEST
specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)],
[(7, 270), (1, 90), (1, 90), (8, 90), (0, 0), (0, 0)],
[(0, 0), (7, 270),(1, 90), (8, 180), (0, 00), (0, 0)]]
env = RailEnv(width=6,
height=4,
rail_generator=rail_from_manual_specifications_generator(specs),
......
......@@ -314,14 +314,13 @@ class GridTransitionMap(TransitionMap):
Checks that:
- surrounding cells have inbound transitions for all the
outbound transitions of this cell.
These are NOT checked - see transition.is_valid:
- all transitions have the mirror transitions (N->E <=> W->S)
- Reverse transitions (N -> S) only exist for a dead-end
- a cell contains either no dead-ends or exactly one
Returns: True (valid) or False (invalid)
"""
cell_transition = self.grid[tuple(rcPos)]
......
......@@ -551,7 +551,7 @@ class RailEnvTransitions(Grid4Transitions):
super(RailEnvTransitions, self).__init__(
transitions=self.transition_list
)
# These bits represent all the possible dead ends
self.maskDeadEnds = 0b0010000110000100
......@@ -569,10 +569,10 @@ class RailEnvTransitions(Grid4Transitions):
def print(self, cell_transition):
print(" NESW")
print("N", format(cell_transition >> (3*4) & 0xF, '04b'))
print("E", format(cell_transition >> (2*4) & 0xF, '04b'))
print("S", format(cell_transition >> (1*4) & 0xF, '04b'))
print("W", format(cell_transition >> (0*4) & 0xF, '04b'))
print("N", format(cell_transition >> (3 * 4) & 0xF, '04b'))
print("E", format(cell_transition >> (2 * 4) & 0xF, '04b'))
print("S", format(cell_transition >> (1 * 4) & 0xF, '04b'))
print("W", format(cell_transition >> (0 * 4) & 0xF, '04b'))
def repr(self, cell_transition, version=0):
"""
......@@ -585,25 +585,23 @@ class RailEnvTransitions(Grid4Transitions):
sbinTrans = format(cell_transition, "#018b")[2:]
if version == 0:
sRepr = " ".join([
"{}:{}".format(sDir, sbinTrans[i:i+4])
"{}:{}".format(sDir, sbinTrans[i:(i + 4)])
for i, sDir in
zip(
range(0, len(sbinTrans), 4),
self.lsDirs # NESW
)])
self.lsDirs)]) # NESW
return sRepr
if version == 1:
lsRepr = []
for iDirIn in range(0, 4):
sDirTrans = sbinTrans[iDirIn*4:iDirIn*4+4]
sDirTrans = sbinTrans[(iDirIn * 4):(iDirIn * 4 + 4)]
if sDirTrans == "0000":
continue
sDirsOut = [
self.lsDirs[iDirOut]
for iDirOut in range(0, 4)
if sDirTrans[iDirOut] == "1"
]
if sDirTrans[iDirOut] == "1"]
lsRepr.append(self.lsDirs[iDirIn] + ":" + "".join(sDirsOut))
return ", ".join(lsRepr)
......
......@@ -271,4 +271,3 @@ def connect_rail(rail_trans, rail_array, start, end):
def distance_on_rail(pos1, pos2):
return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])
......@@ -250,7 +250,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
transitions_templates_ = []
transition_probabilities = []
for i in range(len(t_utils.transitions)-4): # don't include dead-ends
for i in range(len(t_utils.transitions) - 4): # don't include dead-ends
all_transitions = 0
for dir_ in range(4):
trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
......@@ -475,4 +475,3 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
return return_rail
return generator
......@@ -268,7 +268,7 @@ class RailEnv(Environment):
break
else:
self.agents_direction[i] = direction
# Reset the state of the observation builder with the new environment
self.obs_builder.reset()
......@@ -314,60 +314,33 @@ class RailEnv(Environment):
# compute number of possible transitions in the current
# cell used to check for invalid actions
nbits = 0
tmp = self.rail.get_transitions((pos[0], pos[1]))
possible_transitions = self.rail.get_transitions((pos[0], pos[1], direction))
# print(np.sum(self.rail.get_transitions((pos[0], pos[1],direction))),
# self.rail.get_transitions((pos[0], pos[1],direction)),
# self.rail.get_transitions((pos[0], pos[1])),
# (pos[0], pos[1],direction))
num_transitions = np.count_nonzero(possible_transitions)
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
movement = direction
# print(nbits,np.sum(possible_transitions))
if action == 1:
movement = direction - 1
if nbits <= 2 or np.sum(possible_transitions) <= 1:
if num_transitions <= 1:
transition_isValid = False
elif action == 3:
movement = direction + 1
if nbits <= 2 or np.sum(possible_transitions) <= 1:
if num_transitions <= 1:
transition_isValid = False
if movement < 0:
movement += 4
if movement >= 4:
movement -= 4
is_deadend = False
if action == 2:
if nbits == 1:
# dead-end; assuming the rail network is consistent,
# this should match the direction the agent has come
# from. But it's better to check in any case.
reverse_direction = 0
if direction == 0:
reverse_direction = 2
elif direction == 1:
reverse_direction = 3
elif direction == 2:
reverse_direction = 0
elif direction == 3:
reverse_direction = 1
valid_transition = self.rail.get_transition(
(pos[0], pos[1], direction),
reverse_direction)
if valid_transition:
direction = reverse_direction
movement = reverse_direction
is_deadend = True
if np.sum(possible_transitions) == 1:
# Take only available transition
if num_transitions == 1:
# - dead-end, straight line or curved line;
# movement will be the only valid transition
# - take only available transition
movement = np.argmax(possible_transitions)
transition_isValid = True
new_position = self._new_position(pos, movement)
# Is it a legal move? 1) transition allows the movement in the
......@@ -388,7 +361,7 @@ class RailEnv(Environment):
if transition_isValid is None:
transition_isValid = self.rail.get_transition(
(pos[0], pos[1], direction),
movement) or is_deadend
movement)
cell_isFree = True
for j in range(self.number_of_agents):
......
......@@ -31,10 +31,10 @@ class EditorMVC(object):
if env is None:
env = RailEnv(width=10,
height=10,
rail_generator=random_rail_generator(cell_type_relative_proportion=[1, 1] + [0.5] * 6),
number_of_agents=0,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
height=10,
rail_generator=random_rail_generator(cell_type_relative_proportion=[1, 1] + [0.5] * 6),
number_of_agents=0,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
env.reset()
......@@ -113,20 +113,19 @@ class View(object):
wButton = ipywidgets.Button(description=dButton["name"])
wButton.on_click(dButton["method"])
self.lwButtons.append(wButton)
self.wVbox_controls = VBox([
self.wFilename, # self.wDrawMode,
*self.lwButtons,
self.wSize,
self.wDebug, self.wDebug_move,
self.wProg_steps
])
self.wProg_steps])
self.wMain = HBox([self.wImage, self.wVbox_controls])
def drawStroke(self):
pass
def new_env(self):
self.oRT = rt.RenderTool(self.editor.env)
......@@ -139,11 +138,11 @@ class View(object):
img = self.oRT.getImage()
plt.clf()
plt.close()
self.wImage.data = img
self.writableData = np.copy(self.wImage.data)
return img
def redisplayImage(self):
if self.writableData is not None:
# This updates the image in the browser to be the new edited version
......@@ -152,7 +151,7 @@ class View(object):
def drag_path_element(self, x, y):
# Draw a black square on the in-memory copy of the image
if x > 10 and x < self.yxSize[1] and y > 10 and y < self.yxSize[0]:
self.writableData[y-2:y+2, x-2:x+2, :] = 0
self.writableData[(y - 2):(y + 2), (x - 2):(x + 2), :] = 0
def xy_to_rc(self, x, y):
rcCell = ((array([y, x]) - self.yxBase) / self.nPixCell).astype(int)
......@@ -227,7 +226,7 @@ class Controller(object):
# The intention was to avoid too many redraws.
if event["buttons"] > 0:
qEvents.append((time.time(), x, y))
# Process the events in our queue:
# Draw a black square to indicate a trail
# TODO: infer a vector of moves between these squares to avoid gaps
......@@ -238,13 +237,13 @@ class Controller(object):
if tNow - qEvents[0][0] > 0.1: # wait before trying to draw
# height, width = wid.data.shape[:2]
# writableData = np.copy(self.wid_img.data) # writable copy of image - wid_img.data is somehow readonly
# with self.wid_img.hold_sync():
while len(qEvents) > 0:
t, x, y = qEvents.popleft() # get events from our queue
self.view.drag_path_element(x, y)
# Translate and scale from x,y to integer row,col (note order change)
# rcCell = ((array([y, x]) - self.yxBase) / self.nPixCell).astype(int)
rcCell = self.view.xy_to_rc(x, y)
......@@ -266,13 +265,13 @@ class Controller(object):
def refresh(self, event):
self.debug("refresh")
self.view.redraw()
def clear(self, event):
self.model.clear()
def regenerate(self, event):
self.model.regenerate()
def setRegenSize(self, event):
self.model.setRegenSize(event["new"])
......@@ -342,14 +341,14 @@ class EditorModel(object):
"""Mouse motion event handler for drawing.
"""
lrcStroke = self.lrcStroke
# Store the row,col location of the click, if we have entered a new cell
if len(lrcStroke) > 0:
rcLast = lrcStroke[-1]
if not np.array_equal(rcLast, rcCell): # only save at transition
lrcStroke.append(rcCell)
self.debug("lrcStroke ", len(lrcStroke), rcCell)
else:
# This is the first cell in a mouse stroke
lrcStroke.append(rcCell)
......@@ -427,14 +426,16 @@ class EditorModel(object):
# Set the transition
# If this transition spans 3 cells, it is not a deadend, so remove any deadends.
# The user will need to resolve any conflicts.
self.env.rail.set_transition((*rcMiddle, liTrans[0]), liTrans[1], bAddRemove,
remove_deadends=not bDeadend)
self.env.rail.set_transition((*rcMiddle, liTrans[0]),
liTrans[1],
bAddRemove,
remove_deadends=not bDeadend)
# Also set the reverse transition
# use the reversed outbound transition for inbound
# and the reversed inbound transition for outbound
self.env.rail.set_transition((*rcMiddle, mirror(liTrans[1])),
mirror(liTrans[0]), bAddRemove, remove_deadends=not bDeadend)
mirror(liTrans[0]), bAddRemove, remove_deadends=not bDeadend)
# bValid = self.env.rail.is_cell_valid(rcMiddle)
# if not bValid:
......@@ -457,7 +458,7 @@ class EditorModel(object):
# get the row, col delta between the 2 cells, eg [-1,0] = North
rc1Trans = np.diff(rc2Cells, axis=0)
# get the direction index for the transition
liTrans = []
for rcTrans in rc1Trans:
......@@ -491,7 +492,7 @@ class EditorModel(object):
self.env.agents_handles = []
self.env.agents_target = []
self.player = None
self.redraw()
def setFilename(self, filename):
......@@ -507,7 +508,7 @@ class EditorModel(object):
self.redraw()
else:
self.log("File does not exist:", self.env_filename, " Working directory: ", os.getcwd())
def save(self):
self.log("save to ", self.env_filename, " working dir: ", os.getcwd())
self.env.rail.save_transition_map(self.env_filename)
......@@ -515,17 +516,17 @@ class EditorModel(object):
def regenerate(self):
self.log("Regenerate size", self.regen_size)
self.env = RailEnv(width=self.regen_size,
height=self.regen_size,
rail_generator=random_rail_generator(cell_type_relative_proportion=[1, 1] + [0.5] * 6),
number_of_agents=self.env.number_of_agents,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
height=self.regen_size,
rail_generator=random_rail_generator(cell_type_relative_proportion=[1, 1] + [0.5] * 6),
number_of_agents=self.env.number_of_agents,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
self.env.reset(regen_rail=True)
self.fix_env()
self.set_env(self.env)
self.player = Player(self.env)
self.view.new_env()
self.redraw()
def setRegenSize(self, size):
self.regen_size = size
......@@ -579,5 +580,9 @@ class EditorModel(object):
def debug_cell(self, rcCell):
binTrans = self.env.rail.get_transitions(rcCell)
sbinTrans = format(binTrans, "#018b")[2:]
self.debug("cell ", rcCell, "Transitions: ", binTrans, sbinTrans,
[sbinTrans[i:i+4] for i in range(0, len(sbinTrans), 4)])
self.debug("cell ",
rcCell,
"Transitions: ",
binTrans,
sbinTrans,
[sbinTrans[i:(i + 4)] for i in range(0, len(sbinTrans), 4)])
......@@ -61,8 +61,8 @@ class QTGL(GraphicsLayer):
if lastx is not None:
# print("line", lastx, lasty, x, y)
self.qtr.drawLine(
lastx*self.cell_pixels, -lasty*self.cell_pixels,
x*self.cell_pixels, -y*self.cell_pixels)
lastx * self.cell_pixels, -lasty * self.cell_pixels,
x * self.cell_pixels, -y * self.cell_pixels)
lastx = x
lasty = y
else:
......
......@@ -145,7 +145,7 @@ class RenderTool(object):
self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color)
def plotAgents(self, targets=True):
cmap = self.gl.get_cmap('hsv', lut=self.env.number_of_agents+1)
cmap = self.gl.get_cmap('hsv', lut=self.env.number_of_agents + 1)
for iAgent in range(self.env.number_of_agents):
oColor = cmap(iAgent)
......@@ -208,16 +208,16 @@ class RenderTool(object):
xyDir = np.matmul(rcDir, rt.grc2xy) # agent direction in xy
xyPos = np.matmul(rcPos - rcDir / 2, rt.grc2xy) + rt.xyHalf
#print("Agent:", rcPos, iDir, rcDir, xyDir, xyPos)
# print("Agent:", rcPos, iDir, rcDir, xyDir, xyPos)
self.gl.scatter(*xyPos, color=color, marker="o", s=100) # agent location
xyDirLine = array([xyPos, xyPos + xyDir/2]).T # line for agent orient.
xyDirLine = array([xyPos, xyPos + xyDir / 2]).T # line for agent orient.
self.gl.plot(*xyDirLine, color=color, lw=5, ms=0, alpha=0.6)
if target is not None:
rcTarget = array(target)
xyTarget = np.matmul(rcTarget, rt.grc2xy) + rt.xyHalf
self._draw_square(xyTarget, 1/3, color)
self._draw_square(xyTarget, 1 / 3, color)
def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None):
"""
......@@ -587,7 +587,7 @@ class RenderTool(object):
# Draw each agent + its orientation + its target
if agents:
cmap = self.gl.get_cmap('hsv', lut=env.number_of_agents+1)
cmap = self.gl.get_cmap('hsv', lut=env.number_of_agents + 1)
self.plotAgents(targets=True)
if False:
......@@ -659,4 +659,3 @@ class RenderTool(object):
def getImage(self):
return self.gl.getImage()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment