Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • flatland/flatland
  • stefan_otte/flatland
  • jiaodaxiaozi/flatland
  • sfwatergit/flatland
  • utozx126/flatland
  • ChenKuanSun/flatland
  • ashivani/flatland
  • minhhoa/flatland
  • pranjal_dhole/flatland
  • darthgera123/flatland
  • rivesunder/flatland
  • thomaslecat/flatland
  • joel_joseph/flatland
  • kchour/flatland
  • alex_zharichenko/flatland
  • yoogottamk/flatland
  • troye_fang/flatland
  • elrichgro/flatland
  • jun_jin/flatland
  • nimishsantosh107/flatland
20 results
Show changes
Showing
with 1814 additions and 447 deletions
......@@ -10,29 +10,45 @@ from numpy import array
import flatland.utils.rendertools as rt
from flatland.core.grid.grid4_utils import mirror
from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic
from flatland.envs.generators import complex_rail_generator, empty_rail_generator
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv, random_rail_generator
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator, empty_rail_generator
from flatland.envs.persistence import RailEnvPersister
class EditorMVC(object):
""" EditorMVC - a class to encompass and assemble the Jupyter Editor Model-View-Controller.
"""
def __init__(self, env=None, sGL="PIL"):
def __init__(self, env=None, sGL="PIL", env_filename="temp.pkl"):
""" Create an Editor MVC assembly around a railenv, or create one if None.
"""
if env is None:
env = RailEnv(width=10,
height=10,
rail_generator=empty_rail_generator(),
number_of_agents=0,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
nAgents = 3
n_cities = 2
max_rails_between_cities = 2
max_rails_in_city = 4
seed = 0
env = RailEnv(
width=20,
height=30,
rail_generator=sparse_rail_generator(
max_num_cities=n_cities,
seed=seed,
grid_mode=True,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rails_in_city
),
line_generator=sparse_line_generator(),
number_of_agents=nAgents,
obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
)
env.reset()
self.editor = EditorModel(env)
self.editor = EditorModel(env, env_filename=env_filename)
self.editor.view = self.view = View(self.editor, sGL=sGL)
self.view.controller = self.editor.controller = self.controller = Controller(self.editor, self.view)
self.view.init_canvas()
......@@ -43,9 +59,10 @@ class View(object):
""" The Jupyter Editor View - creates and holds the widgets comprising the Editor.
"""
def __init__(self, editor, sGL="MPL"):
def __init__(self, editor, sGL="MPL", screen_width=1200, screen_height=1200):
self.editor = self.model = editor
self.sGL = sGL
self.xyScreen = (screen_width, screen_height)
def display(self):
self.output_generator.clear_output()
......@@ -95,7 +112,7 @@ class View(object):
# Number of Agents when regenerating
self.regen_n_agents = IntSlider(value=1, min=0, max=5, step=1, description="# Agents",
tip="Click regenerate or reset after changing this")
self.regen_method = RadioButtons(description="Regen\nMethod", options=["Empty", "Random Cell"])
self.regen_method = RadioButtons(description="Regen\nMethod", options=["Empty", "Sparse"])
self.replace_agents = Checkbox(value=True, description="Replace Agents")
......@@ -111,7 +128,7 @@ class View(object):
ldButtons = [
dict(name="Refresh", method=self.controller.refresh, tip="Redraw only"),
dict(name="Rotate Agent", method=self.controller.rotate_agent, tip="Rotate selected agent"),
dict(name="Restart Agents", method=self.controller.restart_agents,
dict(name="Restart Agents", method=self.controller.reset_agents,
tip="Move agents back to start positions"),
dict(name="Random", method=self.controller.reset,
tip="Generate a randomized scene, including regen rail + agents"),
......@@ -142,22 +159,25 @@ class View(object):
def new_env(self):
""" Tell the view to update its graphics when a new env is created.
"""
self.oRT = rt.RenderTool(self.editor.env, gl=self.sGL)
self.oRT = rt.RenderTool(self.editor.env, gl=self.sGL, show_debug=True,
screen_height=self.xyScreen[1], screen_width=self.xyScreen[0])
def redraw(self):
with self.output_generator:
self.oRT.set_new_rail()
self.model.env.agents = self.model.env.agents_static
self.model.env.reset_agents()
for a in self.model.env.agents:
if hasattr(a, 'old_position') is False:
a.old_position = a.position
if hasattr(a, 'old_direction') is False:
a.old_direction = a.direction
self.oRT.render_env(agents=True,
self.oRT.render_env(show_agents=True,
show_inactive_agents=True,
show=False,
selected_agent=self.model.selected_agent,
show_observations=False)
show_observations=False,
)
img = self.oRT.get_image()
self.wImage.data = img
......@@ -178,12 +198,14 @@ class View(object):
self.writableData[(y - 2):(y + 2), (x - 2):(x + 2), :3] = 0
def xy_to_rc(self, x, y):
rcCell = ((array([y, x]) - self.yxBase))
rc_cell = ((array([y, x]) - self.yxBase))
nX = np.floor((self.yxSize[0] - self.yxBase[0]) / self.model.env.height)
nY = np.floor((self.yxSize[1] - self.yxBase[1]) / self.model.env.width)
rcCell[0] = max(0, min(np.floor(rcCell[0] / nY), self.model.env.height - 1))
rcCell[1] = max(0, min(np.floor(rcCell[1] / nX), self.model.env.width - 1))
return rcCell
rc_cell[0] = max(0, min(np.floor(rc_cell[0] / nY), self.model.env.height - 1))
rc_cell[1] = max(0, min(np.floor(rc_cell[1] / nX), self.model.env.width - 1))
# Using numpy arrays for coords not currently supported downstream in the env, observations, etc
return tuple(rc_cell)
def log(self, *args, **kwargs):
if self.output_generator:
......@@ -215,23 +237,23 @@ class Controller(object):
y = event['canvasY']
self.debug("debug:", x, y)
rcCell = self.view.xy_to_rc(x, y)
rc_cell = self.view.xy_to_rc(x, y)
bShift = event["shiftKey"]
bCtrl = event["ctrlKey"]
bAlt = event["altKey"]
if bCtrl and not bShift and not bAlt:
self.model.click_agent(rcCell)
self.model.click_agent(rc_cell)
self.lrcStroke = []
elif bShift and bCtrl:
self.model.add_target(rcCell)
self.model.add_target(rc_cell)
self.lrcStroke = []
elif bAlt and not bShift and not bCtrl:
self.model.clear_cell(rcCell)
self.model.clear_cell(rc_cell)
self.lrcStroke = []
self.debug("click in cell", rcCell)
self.model.debug_cell(rcCell)
self.debug("click in cell", rc_cell)
self.model.debug_cell(rc_cell)
if self.model.selected_agent is not None:
self.lrcStroke = []
......@@ -285,11 +307,14 @@ class Controller(object):
else:
self.lrcStroke = []
if self.model.selected_agent is not None:
self.lrcStroke = []
while len(q_events) > 0:
t, x, y = q_events.popleft()
return
# JW: I think this clause causes all editing to fail once an agent is selected.
# I also can't see why it's necessary. So I've if-falsed it out.
if False:
if self.model.selected_agent is not None:
self.lrcStroke = []
while len(q_events) > 0:
t, x, y = q_events.popleft()
return
# Process the events in our queue:
# Draw a black square to indicate a trail
......@@ -304,8 +329,8 @@ class Controller(object):
self.view.drag_path_element(x, y)
# Translate and scale from x,y to integer row,col (note order change)
rcCell = self.view.xy_to_rc(x, y)
self.editor.drag_path_element(rcCell)
rc_cell = self.view.xy_to_rc(x, y)
self.editor.drag_path_element(rc_cell)
self.view.redisplay_image()
......@@ -323,29 +348,24 @@ class Controller(object):
self.log("Reset - nAgents:", self.view.regen_n_agents.value)
self.log("Reset - size:", self.model.regen_size_width)
self.log("Reset - size:", self.model.regen_size_height)
self.model.reset(replace_agents=self.view.replace_agents.value,
self.model.reset(regenerate_schedule=self.view.replace_agents.value,
nAgents=self.view.regen_n_agents.value)
def rotate_agent(self, event):
self.log("Rotate Agent:", self.model.selected_agent)
if self.model.selected_agent is not None:
for agent_idx, agent in enumerate(self.model.env.agents_static):
for agent_idx, agent in enumerate(self.model.env.agents):
if agent is None:
continue
if agent_idx == self.model.selected_agent:
agent.direction = (agent.direction + 1) % 4
agent.initial_direction = (agent.initial_direction + 1) % 4
agent.direction = agent.initial_direction
agent.old_direction = agent.direction
self.model.redraw()
def restart_agents(self, event):
def reset_agents(self, event):
self.log("Restart Agents - nAgents:", self.view.regen_n_agents.value)
if self.model.init_agents_static is not None:
self.model.env.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in
self.model.init_agents_static]
self.model.env.agents = None
self.model.init_agents_static = None
self.model.env.restart_agents()
self.model.env.reset(False, False)
self.model.env.reset(False, False)
self.refresh(event)
def regenerate(self, event):
......@@ -382,7 +402,7 @@ class Controller(object):
class EditorModel(object):
def __init__(self, env):
def __init__(self, env, env_filename="temp.pkl"):
self.view = None
self.env = env
self.regen_size_width = 10
......@@ -396,10 +416,9 @@ class EditorModel(object):
self.debug_move_bool = False
self.wid_output = None
self.draw_mode = "Draw"
self.env_filename = "temp.pkl"
self.env_filename = env_filename
self.set_env(env)
self.selected_agent = None
self.init_agents_static = None
self.thread = None
self.save_image_count = 0
......@@ -420,12 +439,12 @@ class EditorModel(object):
def set_draw_mode(self, draw_mode):
self.draw_mode = draw_mode
def interpolate_path(self, rcLast, rcCell):
if np.array_equal(rcLast, rcCell):
def interpolate_pair(self, rcLast, rc_cell):
if np.array_equal(rcLast, rc_cell):
return []
rcLast = array(rcLast)
rcCell = array(rcCell)
rcDelta = rcCell - rcLast
rc_cell = array(rc_cell)
rcDelta = rc_cell - rcLast
lrcInterp = [] # extra row,col points
......@@ -457,7 +476,16 @@ class EditorModel(object):
lrcInterp = list(map(tuple, g2Interp))
return lrcInterp
def drag_path_element(self, rcCell):
def interpolate_path(self, lrcPath):
lrcPath2 = [] # interpolated version of the path
rcLast = None
for rcCell in lrcPath:
if rcLast is not None:
lrcPath2.extend(self.interpolate_pair(rcLast, rcCell))
rcLast = rcCell
return lrcPath2
def drag_path_element(self, rc_cell):
"""Mouse motion event handler for drawing.
"""
lrcStroke = self.lrcStroke
......@@ -465,15 +493,15 @@ class EditorModel(object):
# Store the row,col location of the click, if we have entered a new cell
if len(lrcStroke) > 0:
rcLast = lrcStroke[-1]
if not np.array_equal(rcLast, rcCell): # only save at transition
lrcInterp = self.interpolate_path(rcLast, rcCell)
if not np.array_equal(rcLast, rc_cell): # only save at transition
lrcInterp = self.interpolate_pair(rcLast, rc_cell)
lrcStroke.extend(lrcInterp)
self.debug("lrcStroke ", len(lrcStroke), rcCell, "interp:", lrcInterp)
self.debug("lrcStroke ", len(lrcStroke), rc_cell, "interp:", lrcInterp)
else:
# This is the first cell in a mouse stroke
lrcStroke.append(rcCell)
self.debug("lrcStroke ", len(lrcStroke), rcCell)
lrcStroke.append(rc_cell)
self.debug("lrcStroke ", len(lrcStroke), rc_cell)
def mod_path(self, bAddRemove):
# disabled functionality (no longer required)
......@@ -492,6 +520,8 @@ class EditorModel(object):
# If we have already touched 3 cells
# We have a transition into a cell, and out of it.
#print(lrcStroke)
if len(lrcStroke) >= 2:
# If the first cell in a stroke is empty, add a deadend to cell 0
if self.env.rail.get_full_transitions(*lrcStroke[0]) == 0:
......@@ -500,6 +530,7 @@ class EditorModel(object):
# Add transitions for groups of 3 cells
# hence inbound and outbound transitions for middle cell
while len(lrcStroke) >= 3:
#print(lrcStroke)
self.mod_rail_3cells(lrcStroke, bAddRemove=bAddRemove)
# If final cell empty, insert deadend:
......@@ -507,6 +538,8 @@ class EditorModel(object):
if self.env.rail.get_full_transitions(*lrcStroke[1]) == 0:
self.mod_rail_2cells(lrcStroke, bAddRemove, iCellToMod=1)
#print("final:", lrcStroke)
# now empty out the final two cells from the queue
lrcStroke.clear()
......@@ -582,6 +615,8 @@ class EditorModel(object):
iTrans = iTrans[0][0]
liTrans.append(iTrans)
#self.log("liTrans:", liTrans)
# check that we have one transition
if len(liTrans) == 1:
# Set the transition as a deadend
......@@ -602,7 +637,6 @@ class EditorModel(object):
def clear(self):
self.env.rail.grid[:, :] = 0
self.env.agents = []
self.env.agents_static = []
self.redraw()
......@@ -611,12 +645,12 @@ class EditorModel(object):
self.env.rail.grid[cell_row_col[0], cell_row_col[1]] = 0
self.redraw()
def reset(self, replace_agents=False, nAgents=0):
def reset(self, regenerate_schedule=False, nAgents=0):
self.regenerate("complex", nAgents=nAgents)
self.redraw()
def restart_agents(self):
self.env.agents = EnvAgent.list_from_static(self.env.agents_static)
self.env.reset_agents()
self.redraw()
def set_filename(self, filename):
......@@ -625,16 +659,16 @@ class EditorModel(object):
def load(self):
if os.path.exists(self.env_filename):
self.log("load file: ", self.env_filename)
self.env.load(self.env_filename)
#self.env.load(self.env_filename)
RailEnvPersister.load(self.env, self.env_filename)
if not self.regen_size_height == self.env.height or not self.regen_size_width == self.env.width:
self.regen_size_height = self.env.height
self.regen_size_width = self.env.width
self.regenerate(None, 0, self.env)
self.env.load(self.env_filename)
RailEnvPersister.load(self.env, self.env_filename)
self.env.restart_agents()
self.env.reset_agents()
self.env.reset(False, False)
self.init_agents_static = None
self.view.oRT.update_background()
self.fix_env()
self.set_env(self.env)
......@@ -644,12 +678,8 @@ class EditorModel(object):
def save(self):
self.log("save to ", self.env_filename, " working dir: ", os.getcwd())
temp_store = self.env.agents
# clear agents before save , because we want the "init" position of the agent to expert
self.env.agents = []
self.env.save(self.env_filename)
# reset agents current (current position)
self.env.agents = temp_store
#self.env.save(self.env_filename)
RailEnvPersister.save(self.env, self.env_filename)
def save_image(self):
self.view.oRT.gl.save_image('frame_{:04d}.bmp'.format(self.save_image_count))
......@@ -663,21 +693,17 @@ class EditorModel(object):
if method is None or method == "Empty":
fnMethod = empty_rail_generator()
elif method == "Random Cell":
fnMethod = random_rail_generator(cell_type_relative_proportion=[1] * 11)
else:
fnMethod = complex_rail_generator(nr_start_goal=nAgents, nr_extra=20, min_dist=12, seed=int(time.time()))
fnMethod = sparse_rail_generator(nr_start_goal=nAgents, nr_extra=20, min_dist=12, seed=int(time.time()))
if env is None:
self.env = RailEnv(width=self.regen_size_width,
height=self.regen_size_height,
rail_generator=fnMethod,
number_of_agents=nAgents,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
self.env = RailEnv(width=self.regen_size_width, height=self.regen_size_height, rail_generator=fnMethod,
number_of_agents=nAgents, obs_builder_object=TreeObsForRailEnv(max_depth=2))
else:
self.env = env
self.env.reset(regen_rail=True)
self.env.reset(regenerate_rail=True)
self.fix_env()
self.selected_agent = None # clear the selected agent.
self.set_env(self.env)
self.view.new_env()
self.redraw()
......@@ -689,35 +715,53 @@ class EditorModel(object):
self.regen_size_height = size
def find_agent_at(self, cell_row_col):
for agent_idx, agent in enumerate(self.env.agents_static):
if tuple(agent.position) == tuple(cell_row_col):
for agent_idx, agent in enumerate(self.env.agents):
if agent.position is None:
rc_pos = agent.initial_position
else:
rc_pos = agent.position
if tuple(rc_pos) == tuple(cell_row_col):
return agent_idx
return None
def click_agent(self, cell_row_col):
""" The user has clicked on a cell -
- If there is an agent, select it
- If that agent was already selected, then deselect it
- If there is no agent selected, and no agent in the cell, create one
- If there is an agent selected, and no agent in the cell, move the selected agent to the cell
* If there is an agent, select it
* If that agent was already selected, then deselect it
* If there is no agent selected, and no agent in the cell, create one
* If there is an agent selected, and no agent in the cell, move the selected agent to the cell
"""
# Has the user clicked on an existing agent?
agent_idx = self.find_agent_at(cell_row_col)
# This is in case we still have a selected agent even though the env has been recreated
# with no agents.
if (self.selected_agent is not None) and (self.selected_agent > len(self.env.agents)):
self.selected_agent = None
# Defensive coding below - for cell_row_col to be a tuple, not a numpy array:
# numpy array breaks various things when loading the env.
if agent_idx is None:
# No
if self.selected_agent is None:
# Create a new agent and select it.
agent_static = EnvAgentStatic(position=cell_row_col, direction=0, target=cell_row_col, moving=False)
self.selected_agent = self.env.add_agent_static(agent_static)
agent = EnvAgent(initial_position=tuple(cell_row_col),
initial_direction=0,
direction=0,
target=tuple(cell_row_col),
moving=False,
)
self.selected_agent = self.env.add_agent(agent)
# self.env.set_agent_active(agent)
self.view.oRT.update_background()
else:
# Move the selected agent to this cell
agent_static = self.env.agents_static[self.selected_agent]
agent_static.position = cell_row_col
agent_static.old_position = cell_row_col
self.env.agents = []
agent = self.env.agents[self.selected_agent]
agent.initial_position = tuple(cell_row_col)
agent.position = tuple(cell_row_col)
agent.old_position = tuple(cell_row_col)
else:
# Yes
# Have they clicked on the agent already selected?
......@@ -728,13 +772,11 @@ class EditorModel(object):
# No - select the agent
self.selected_agent = agent_idx
self.init_agents_static = None
self.redraw()
def add_target(self, rcCell):
def add_target(self, rc_cell):
if self.selected_agent is not None:
self.env.agents_static[self.selected_agent].target = rcCell
self.init_agents_static = None
self.env.agents[self.selected_agent].target = tuple(rc_cell)
self.view.oRT.update_background()
self.redraw()
......@@ -752,11 +794,11 @@ class EditorModel(object):
if self.debug_bool:
self.log(*args, **kwargs)
def debug_cell(self, rcCell):
binTrans = self.env.rail.get_full_transitions(*rcCell)
def debug_cell(self, rc_cell):
binTrans = self.env.rail.get_full_transitions(*rc_cell)
sbinTrans = format(binTrans, "#018b")[2:]
self.debug("cell ",
rcCell,
rc_cell,
"Transitions: ",
binTrans,
sbinTrans,
......
from numpy.random import RandomState
import flatland.envs.observations as obs
import flatland.envs.rail_generators as rg
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.line_generators import BaseLineGen
from flatland.envs.rail_env import RailEnv
from flatland.envs.timetable_utils import Line
from flatland.utils import editor
# Start and end all agents at the same place
class SchedGen2(BaseLineGen):
def __init__(self, rcStart, rcEnd, iDir):
self.rcStart = rcStart
self.rcEnd = rcEnd
self.iDir = iDir
def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict = None, num_resets: int = None,
np_random: RandomState = None) -> Line:
return Line(agent_positions=[self.rcStart] * num_agents,
agent_directions=[self.iDir] * num_agents,
agent_targets=[self.rcEnd] * num_agents,
agent_speeds=[1.0] * num_agents)
# cycle through lists of start, end and direction
class SchedGen3(BaseLineGen):
def __init__(self, lrcStarts, lrcTargs, liDirs):
self.lrcStarts = lrcStarts
self.lrcTargs = lrcTargs
self.liDirs = liDirs
def generate(self, rail: GridTransitionMap, num_agents: int, hints: dict = None, num_resets: int = None,
np_random: RandomState = None) -> Line:
return Line(agent_positions=[self.lrcStarts[i % len(self.lrcStarts)] for i in range(num_agents)],
agent_directions=[self.liDirs[i % len(self.liDirs)] for i in range(num_agents)],
agent_targets=[self.lrcTargs[i % len(self.lrcTargs)] for i in range(num_agents)],
agent_speeds=[1.0] * num_agents)
def makeEnv(nAg=2, width=20, height=10, oSG=None):
env = RailEnv(width=width, height=height, rail_generator=rg.empty_rail_generator(),
number_of_agents=nAg,
line_generator=oSG,
obs_builder_object=obs.TreeObsForRailEnv(max_depth=1))
envModel = editor.EditorModel(env)
env.reset()
return env, envModel
def makeEnv2(nAg=2, shape=(20, 10), llrcPaths=[], lrcStarts=[], lrcTargs=[], liDirs=[], remove_agents_at_target=True):
oSG = SchedGen3(lrcStarts, lrcTargs, liDirs)
env = RailEnv(width=shape[0], height=shape[1],
rail_generator=rg.empty_rail_generator(),
number_of_agents=nAg,
line_generator=oSG,
obs_builder_object=obs.TreeObsForRailEnv(max_depth=1),
remove_agents_at_target=remove_agents_at_target,
record_steps=True)
envModel = editor.EditorModel(env)
env.reset()
for lrcPath in llrcPaths:
envModel.mod_rail_cell_seq(envModel.interpolate_path(lrcPath))
return env, envModel
ddEnvSpecs = {
# opposing stations with single alternative path
"single_alternative": {
"llrcPaths": [
[(1, 0), (1, 15)], # across the top
[(1, 4), (1, 6), (3, 6), (3, 12), (1, 12), (1, 14)], # alternative loop below
],
"lrcStarts": [(1, 3), (1, 14)],
"lrcTargs": [(1, 14), (1, 3)],
"liDirs": [1, 3]
},
# single spur so one agent needs to wait
"single_spur": {
"llrcPaths": [
[(1, 0), (1, 15)],
[(4, 0), (4, 6), (1, 6), (1, 8)]],
"lrcStarts": [(1, 3), (1, 14)],
"lrcTargs": [(1, 14), (4, 2)],
"liDirs": [1, 3]
},
# single spur so one agent needs to wait
"merging_spurs": {
"llrcPaths": [
[(1, 0), (1, 15), (7, 15), (7, 0)],
[(4, 0), (4, 6), (1, 6), (1, 8)],
# [((1,14), (1,16), (7,16), )]
],
"lrcStarts": [(1, 2), (4, 2)],
"lrcTargs": [(7, 3)],
"liDirs": [1]
},
# Concentric Loops
"concentric_loops": {
"llrcPaths": [
[(1, 1), (1, 5), (8, 5), (8, 1), (1, 1), (1, 3)],
[(1, 3), (1, 10), (8, 10), (8, 3)]
],
"lrcStarts": [(1, 3)],
"lrcTargs": [(2, 1)],
"liDirs": [1]
},
# two loops
"loop_with_loops": {
"llrcPaths": [
# big outer loop Row 1, 8; Col 1, 15
[(1, 1), (1, 15), (8, 15), (8, 1), (1, 1), (1, 3)],
# alternative 1
[(1, 3), (1, 5), (3, 5), (3, 10), (1, 10), (1, 12)],
# alternative 2
[(8, 3), (8, 5), (6, 5), (6, 10), (8, 10), (8, 12)],
],
# list of row,col of agent start cells
"lrcStarts": [(1, 3), (8, 3)],
# list of row,col of targets
"lrcTargs": [(8, 2), (1, 2)],
# list of initial directions
"liDirs": [1, 1],
}
}
def makeTestEnv(sName="single_alternative", nAg=2, remove_agents_at_target=True):
global ddEnvSpecs
dSpec = ddEnvSpecs[sName]
return makeEnv2(nAg=nAg, remove_agents_at_target=remove_agents_at_target, **dSpec)
def getAgentState(env):
dAgState = {}
for iAg, ag in enumerate(env.agents):
dAgState[iAg] = (*ag.position, ag.direction)
return dAgState
......@@ -25,6 +25,13 @@ class GraphicsLayer(object):
pass
def pause(self, seconds=0.00001):
""" deprecated """
pass
def idle(self, seconds=0.00001):
""" process any display events eg redraw, resize.
Return only after the given number of seconds, ie idle / loop until that number.
"""
pass
def clf(self):
......@@ -65,7 +72,7 @@ class GraphicsLayer(object):
def get_cmap(self, *args, **kwargs):
return plt.get_cmap(*args, **kwargs)
def set_rail_at(self, row, col, binTrans, iTarget=None, isSelected=False, rail_grid=None):
def set_rail_at(self, row, col, binTrans, iTarget=None, isSelected=False, rail_grid=None, num_agents=None):
""" Set the rail at cell (row, col) to have transitions binTrans.
The target argument can contain the index of the agent to indicate
that agent's target is at that cell, so that a station can be
......@@ -73,7 +80,8 @@ class GraphicsLayer(object):
"""
pass
def set_agent_at(self, iAgent, row, col, iDirIn, iDirOut, isSelected=False):
def set_agent_at(self, iAgent, row, col, iDirIn, iDirOut, isSelected=False, rail_grid=None, show_debug=False,
clear_debug_text=True):
pass
def set_cell_occupied(self, iAgent, row, col):
......
import pyglet as pgl
import time
from PIL import Image
# from numpy import array
# from pkg_resources import resource_string as resource_bytes
# from flatland.utils.graphics_layer import GraphicsLayer
from flatland.utils.graphics_pil import PILSVG
class PGLGL(PILSVG):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.window_open = False # means the window has not yet been opened.
self.close_requested = False # user has clicked
self.closed = False # windows has been closed (currently, we leave the env still running)
def open_window(self):
print("open_window - pyglet")
assert self.window_open is False, "Window is already open!"
self.window = pgl.window.Window(resizable=True, vsync=False, width=1200, height=800)
#self.__class__.window.title("Flatland")
#self.__class__.window.configure(background='grey')
self.window_open = True
@self.window.event
def on_draw():
#print("pyglet draw event")
self.window.clear()
self.show(from_event=True)
#print("pyglet draw event done")
@self.window.event
def on_resize(width, height):
#print(f"The window was resized to {width}, {height}")
self.show(from_event=True)
self.window.dispatch_event("on_draw")
#print("pyglet resize event done")
@self.window.event
def on_close():
self.close_requested = True
def close_window(self):
self.window.close()
self.closed=True
def show(self, block=False, from_event=False):
if not self.window_open:
self.open_window()
if self.close_requested:
if not self.closed:
self.close_window()
return
#tStart = time.time()
self._processEvents()
pil_img = self.alpha_composite_layers()
pil_img_resized = pil_img.resize((self.window.width, self.window.height), resample=Image.NEAREST)
# convert our PIL image to pyglet:
bytes_image = pil_img_resized.tobytes()
pgl_image = pgl.image.ImageData(pil_img_resized.width, pil_img_resized.height,
#self.window.width, self.window.height,
'RGBA',
bytes_image, pitch=-pil_img_resized.width * 4)
pgl_image.blit(0,0)
#tEnd = time.time()
#print("show time: ", tEnd - tStart)
def _processEvents(self):
""" This is the replacement for a custom event loop for Pyglet.
The lines below are typical of Pyglet examples.
Manually resizing the window is still very clunky.
"""
#print("process events...", end="")
pgl.clock.tick()
#for window in pgl.app.windows:
if not self.closed:
self.window.switch_to()
self.window.dispatch_events()
self.window.flip()
#print(" events done")
def idle(self, seconds=0.00001):
tStart = time.time()
tEnd = tStart + seconds
while (time.time() < tEnd):
self._processEvents()
#self.show()
time.sleep(min(seconds, 0.1))
def test_pyglet():
oGL = PGLGL(400,300)
time.sleep(2)
def test_event_loop():
""" Shows how it should work with the standard event loop
Resizing is fairly smooth (ie runs at least 10-20x a second)
"""
window = pgl.window.Window(resizable=True)
pil_img = Image.open("notebooks/simple_example_3.png")
def show():
pil_img_resized = pil_img.resize((window.width, window.height), resample=Image.NEAREST)
bytes_image = pil_img_resized.tobytes()
pgl_image = pgl.image.ImageData(pil_img_resized.width, pil_img_resized.height,
#self.window.width, self.window.height,
'RGBA',
bytes_image, pitch=-pil_img_resized.width * 4)
pgl_image.blit(0,0)
@window.event
def on_draw():
print("pyglet draw event")
window.clear()
show()
print("pyglet draw event done")
@window.event
def on_resize(width, height):
print(f"The window was resized to {width}, {height}")
#show()
print("pyglet resize event done")
@window.event
def on_close():
#self.close_requested = True
print("close")
pgl.app.run()
if __name__=="__main__":
#test_pyglet()
test_event_loop()
\ No newline at end of file
import io
import os
import time
import tkinter as tk
#import tkinter as tk
import numpy as np
from PIL import Image, ImageDraw, ImageTk # , ImageFont
from PIL import Image, ImageDraw, ImageFont
from numpy import array
from pkg_resources import resource_string as resource_bytes
from flatland.utils.graphics_layer import GraphicsLayer
def enable_windows_cairo_support():
if os.name == 'nt':
import site
import ctypes.util
default_os_path = os.environ['PATH']
os.environ['PATH'] = ''
for s in site.getsitepackages():
os.environ['PATH'] = os.environ['PATH'] + ';' + s + '\\cairo'
os.environ['PATH'] = os.environ['PATH'] + ';' + default_os_path
if ctypes.util.find_library('cairo') is None:
print("Error: cairo not installed")
enable_windows_cairo_support()
from cairosvg import svg2png # noqa: E402
from flatland.core.grid.rail_env_grid import RailEnvTransitions # noqa: E402
class PILGL(GraphicsLayer):
# tk.Tk() must be a singleton!
# https://stackoverflow.com/questions/26097811/image-pyimage2-doesnt-exist
window = tk.Tk()
# window = tk.Tk()
RAIL_LAYER = 0
PREDICTION_PATH_LAYER = 1
......@@ -41,7 +25,7 @@ class PILGL(GraphicsLayer):
SELECTED_AGENT_LAYER = 4
SELECTED_TARGET_LAYER = 5
def __init__(self, width, height, jupyter=False):
def __init__(self, width, height, jupyter=False, screen_width=800, screen_height=600):
self.yxBase = (0, 0)
self.linewidth = 4
self.n_agent_colors = 1 # overridden in loadAgent
......@@ -52,13 +36,13 @@ class PILGL(GraphicsLayer):
self.background_grid = np.zeros(shape=(self.width, self.height))
if jupyter is False:
# NOTE: Currently removed the dependency on
# screeninfo. We have to find an alternate
# NOTE: Currently removed the dependency on
# screeninfo. We have to find an alternate
# way to compute the screen width and height
# In the meantime, we are harcoding the 800x600
# In the meantime, we are harcoding the 800x600
# assumption
self.screen_width = 800
self.screen_height = 600
self.screen_width = screen_width
self.screen_height = screen_height
w = (self.screen_width - self.width - 10) / (self.width + 1 + self.linewidth)
h = (self.screen_height - self.height - 10) / (self.height + 1 + self.linewidth)
self.nPixCell = int(max(1, np.ceil(min(w, h))))
......@@ -81,15 +65,15 @@ class PILGL(GraphicsLayer):
sColors = "d50000#c51162#aa00ff#6200ea#304ffe#2962ff#0091ea#00b8d4#00bfa5#00c853" + \
"#64dd17#aeea00#ffd600#ffab00#ff6d00#ff3d00#5d4037#455a64"
self.agent_colors = [self.rgb_s2i(sColor) for sColor in sColors.split("#")]
self.n_agent_colors = len(self.agent_colors)
self.window_open = False
self.firstFrame = True
self.old_background_image = (None, None, None)
self.create_layers()
self.font = ImageFont.load_default()
def build_background_map(self, dTargets):
x = self.old_background_image
rebuild = False
......@@ -107,14 +91,17 @@ class PILGL(GraphicsLayer):
rebuild = True
if rebuild:
# rebuild background_grid to control the visualisation of buildings, trees, mountains, lakes and river
self.background_grid = np.zeros(shape=(self.width, self.height))
# build base distance map (distance to targets)
for x in range(self.width):
for y in range(self.height):
distance = int(np.ceil(np.sqrt(self.width ** 2.0 + self.height ** 2.0)))
for rc in dTargets:
r = rc[1]
c = rc[0]
d = int(np.floor(np.sqrt((x - r) ** 2 + (y - c) ** 2)))
d = int(np.floor(np.sqrt((x - r) ** 2 + (y - c) ** 2)) / 0.5)
distance = min(d, distance)
self.background_grid[x][y] = distance
......@@ -128,6 +115,7 @@ class PILGL(GraphicsLayer):
return self.agent_colors[iAgent % self.n_agent_colors]
def plot(self, gX, gY, color=None, linewidth=3, layer=RAIL_LAYER, opacity=255, **kwargs):
""" Draw a line joining the points in gX, GY - each an"""
color = self.adapt_color(color)
if len(color) == 3:
color += (opacity,)
......@@ -135,7 +123,8 @@ class PILGL(GraphicsLayer):
color = color[:3] + (opacity,)
gPoints = np.stack([array(gX), -array(gY)]).T * self.nPixCell
gPoints = list(gPoints.ravel())
self.draws[layer].line(gPoints, fill=color, width=self.linewidth)
# the width here was self.linewidth - not really sure of the implications
self.draws[layer].line(gPoints, fill=color, width=linewidth)
def scatter(self, gX, gY, color=None, marker="o", s=50, layer=RAIL_LAYER, opacity=255, *args, **kwargs):
color = self.adapt_color(color)
......@@ -145,6 +134,19 @@ class PILGL(GraphicsLayer):
self.draws[layer].rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color)
def draw_image_xy(self, pil_img, xyPixLeftTop, layer=RAIL_LAYER, ):
# Resize all PIL images just before drawing them
# to ensure that resizing doesnt affect the
# recolorizing strategies in place
#
# That said : All the code in this file needs
# some serious refactoring -_- to ensure the
# code style and structure is consitent.
# - Mohanty
pil_img = pil_img.resize(
(self.nPixCell, self.nPixCell)
)
if (pil_img.mode == "RGBA"):
pil_mask = pil_img
else:
......@@ -157,19 +159,19 @@ class PILGL(GraphicsLayer):
self.draw_image_xy(pil_img, xyPixLeftTop, layer=layer)
def open_window(self):
assert self.window_open is False, "Window is already open!"
self.__class__.window.title("Flatland")
self.__class__.window.configure(background='grey')
self.window_open = True
pass
def close_window(self):
self.panel.destroy()
# quit but not destroy!
self.__class__.window.quit()
def text(self, *args, **kwargs):
pass
def text(self, xPx, yPx, strText, layer=RAIL_LAYER):
xyPixLeftTop = (xPx, yPx)
self.draws[layer].text(xyPixLeftTop, strText, font=self.font, fill=(0, 0, 0, 255))
def text_rowcol(self, rcTopLeft, strText, layer=AGENT_LAYER):
xyPixLeftTop = tuple((array(rcTopLeft) * self.nPixCell)[[1, 0]])
self.text(*xyPixLeftTop, strText, layer)
def prettify(self, *args, **kwargs):
pass
......@@ -182,28 +184,15 @@ class PILGL(GraphicsLayer):
self.create_layer(iLayer=PILGL.PREDICTION_PATH_LAYER, clear=True)
def show(self, block=False):
img = self.alpha_composite_layers()
if not self.window_open:
self.open_window()
tkimg = ImageTk.PhotoImage(img)
if self.firstFrame:
# Do TK actions for a new panel (not sure what they really do)
self.panel = tk.Label(self.window, image=tkimg)
self.panel.pack(side="bottom", fill="both", expand="yes")
else:
# update the image in situ
self.panel.configure(image=tkimg)
self.panel.image = tkimg
self.__class__.window.update()
self.firstFrame = False
#print("show() - ", self.__class__)
pass
def pause(self, seconds=0.00001):
pass
def idle(self, seconds=0.00001):
pass
def alpha_composite_layers(self):
img = self.layers[0]
for img2 in self.layers[1:]:
......@@ -263,9 +252,14 @@ class PILGL(GraphicsLayer):
class PILSVG(PILGL):
def __init__(self, width, height, jupyter=False):
"""
Note : This class should now ideally be called as PILPNG,
but for backward compatibility, and to not introduce any breaking changes at this point
we are sticking to the legacy name of PILSVG (when in practice we are not using SVG anymore)
"""
def __init__(self, width, height, jupyter=False, screen_width=800, screen_height=600):
oSuper = super()
oSuper.__init__(width, height, jupyter)
oSuper.__init__(width, height, jupyter, screen_width, screen_height)
self.lwAgents = []
self.agents_prev = []
......@@ -288,146 +282,151 @@ class PILSVG(PILGL):
self.lwAgents = []
self.agents_prev = []
def pil_from_svg_file(self, package, resource):
def pil_from_png_file(self, package, resource):
bytestring = resource_bytes(package, resource)
bytesPNG = svg2png(bytestring=bytestring, output_height=self.nPixCell, output_width=self.nPixCell)
with io.BytesIO(bytesPNG) as fIn:
with io.BytesIO(bytestring) as fIn:
pil_img = Image.open(fIn)
pil_img.load()
return pil_img
def pil_from_svg_bytes(self, bytesSVG):
bytesPNG = svg2png(bytesSVG, output_height=self.nPixCell, output_width=self.nPixCell)
with io.BytesIO(bytesPNG) as fIn:
pil_img = Image.open(fIn)
return pil_img
def load_buildings(self):
dBuildingFiles = [
"Buildings-Bank.svg",
"Buildings-Bar.svg",
"Buildings-Wohnhaus.svg",
"Buildings-Hochhaus.svg",
"Buildings-Hotel.svg",
"Buildings-Office.svg",
"Buildings-Polizei.svg",
"Buildings-Post.svg",
"Buildings-Supermarkt.svg",
"Buildings-Tankstelle.svg",
"Buildings-Fabrik_A.svg",
"Buildings-Fabrik_B.svg",
"Buildings-Fabrik_C.svg",
"Buildings-Fabrik_D.svg",
"Buildings-Fabrik_E.svg",
"Buildings-Fabrik_F.svg",
"Buildings-Fabrik_G.svg",
"Buildings-Fabrik_H.svg",
"Buildings-Fabrik_I.svg"
lBuildingFiles = [
"Buildings-Bank.png",
"Buildings-Bar.png",
"Buildings-Wohnhaus.png",
"Buildings-Hochhaus.png",
"Buildings-Hotel.png",
"Buildings-Office.png",
"Buildings-Polizei.png",
"Buildings-Post.png",
"Buildings-Supermarkt.png",
"Buildings-Tankstelle.png",
"Buildings-Fabrik_A.png",
"Buildings-Fabrik_B.png",
"Buildings-Fabrik_C.png",
"Buildings-Fabrik_D.png",
"Buildings-Fabrik_E.png",
"Buildings-Fabrik_F.png",
"Buildings-Fabrik_G.png",
"Buildings-Fabrik_H.png",
"Buildings-Fabrik_I.png"
]
imgBg = self.pil_from_svg_file('svg', "Background_city.svg")
imgBg = self.pil_from_png_file('flatland.png', "Background_city.png")
imgBg = imgBg.convert("RGBA")
self.dBuildings = []
for sFile in dBuildingFiles:
img = self.pil_from_svg_file('svg', sFile)
self.lBuildings = []
for sFile in lBuildingFiles:
img = self.pil_from_png_file('flatland.png', sFile)
img = Image.alpha_composite(imgBg, img)
self.dBuildings.append(img)
self.lBuildings.append(img)
def load_scenery(self):
scenery_files = [
"Scenery-Laubbaume_A.svg",
"Scenery-Laubbaume_B.svg",
"Scenery-Laubbaume_C.svg",
"Scenery-Nadelbaume_A.svg",
"Scenery-Nadelbaume_B.svg",
"Scenery-Bergwelt_B.svg"
"Scenery-Laubbaume_A.png",
"Scenery-Laubbaume_B.png",
"Scenery-Laubbaume_C.png",
"Scenery-Nadelbaume_A.png",
"Scenery-Nadelbaume_B.png",
"Scenery-Bergwelt_B.png"
]
scenery_files_d2 = [
"Scenery-Bergwelt_C_Teil_1_links.svg",
"Scenery-Bergwelt_C_Teil_2_rechts.svg"
"Scenery-Bergwelt_C_Teil_1_links.png",
"Scenery-Bergwelt_C_Teil_2_rechts.png"
]
scenery_files_d3 = [
"Scenery-Bergwelt_A_Teil_3_rechts.svg",
"Scenery-Bergwelt_A_Teil_2_mitte.svg",
"Scenery-Bergwelt_A_Teil_1_links.svg"
"Scenery-Bergwelt_A_Teil_1_links.png",
"Scenery-Bergwelt_A_Teil_2_mitte.png",
"Scenery-Bergwelt_A_Teil_3_rechts.png"
]
scenery_files_water = [
"Scenery_Water.png"
]
img_back_ground = self.pil_from_svg_file('svg', "Background_Light_green.svg")
img_back_ground = self.pil_from_png_file('flatland.png', "Background_Light_green.png").convert("RGBA")
self.scenery_background_white = self.pil_from_png_file('flatland.png', "Background_white.png").convert("RGBA")
self.scenery = []
for file in scenery_files:
img = self.pil_from_svg_file('svg', file)
img = self.pil_from_png_file('flatland.png', file)
img = Image.alpha_composite(img_back_ground, img)
self.scenery.append(img)
self.scenery_d2 = []
for file in scenery_files_d2:
img = self.pil_from_svg_file('svg', file)
img = self.pil_from_png_file('flatland.png', file)
img = Image.alpha_composite(img_back_ground, img)
self.scenery_d2.append(img)
self.scenery_d3 = []
for file in scenery_files_d3:
img = self.pil_from_svg_file('svg', file)
img = self.pil_from_png_file('flatland.png', file)
img = Image.alpha_composite(img_back_ground, img)
self.scenery_d3.append(img)
self.scenery_water = []
for file in scenery_files_water:
img = self.pil_from_png_file('flatland.png', file)
img = Image.alpha_composite(img_back_ground, img)
self.scenery_water.append(img)
def load_rail(self):
""" Load the rail SVG images, apply rotations, and store as PIL images.
"""
rail_files = {
"": "Background_Light_green.svg",
"WE": "Gleis_Deadend.svg",
"WW EE NN SS": "Gleis_Diamond_Crossing.svg",
"WW EE": "Gleis_horizontal.svg",
"EN SW": "Gleis_Kurve_oben_links.svg",
"WN SE": "Gleis_Kurve_oben_rechts.svg",
"ES NW": "Gleis_Kurve_unten_links.svg",
"NE WS": "Gleis_Kurve_unten_rechts.svg",
"NN SS": "Gleis_vertikal.svg",
"NN SS EE WW ES NW SE WN": "Weiche_Double_Slip.svg",
"EE WW EN SW": "Weiche_horizontal_oben_links.svg",
"EE WW SE WN": "Weiche_horizontal_oben_rechts.svg",
"EE WW ES NW": "Weiche_horizontal_unten_links.svg",
"EE WW NE WS": "Weiche_horizontal_unten_rechts.svg",
"NN SS EE WW NW ES": "Weiche_Single_Slip.svg",
"NE NW ES WS": "Weiche_Symetrical.svg",
"NN SS EN SW": "Weiche_vertikal_oben_links.svg",
"NN SS SE WN": "Weiche_vertikal_oben_rechts.svg",
"NN SS NW ES": "Weiche_vertikal_unten_links.svg",
"NN SS NE WS": "Weiche_vertikal_unten_rechts.svg",
"NE NW ES WS SS NN": "Weiche_Symetrical_gerade.svg",
"NE EN SW WS": "Gleis_Kurve_oben_links_unten_rechts.svg"
"": "Background_Light_green.png",
"WE": "Gleis_Deadend.png",
"WW EE NN SS": "Gleis_Diamond_Crossing.png",
"WW EE": "Gleis_horizontal.png",
"EN SW": "Gleis_Kurve_oben_links.png",
"WN SE": "Gleis_Kurve_oben_rechts.png",
"ES NW": "Gleis_Kurve_unten_links.png",
"NE WS": "Gleis_Kurve_unten_rechts.png",
"NN SS": "Gleis_vertikal.png",
"NN SS EE WW ES NW SE WN": "Weiche_Double_Slip.png",
"EE WW EN SW": "Weiche_horizontal_oben_links.png",
"EE WW SE WN": "Weiche_horizontal_oben_rechts.png",
"EE WW ES NW": "Weiche_horizontal_unten_links.png",
"EE WW NE WS": "Weiche_horizontal_unten_rechts.png",
"NN SS EE WW NW ES": "Weiche_Single_Slip.png",
"NE NW ES WS": "Weiche_Symetrical.png",
"NN SS EN SW": "Weiche_vertikal_oben_links.png",
"NN SS SE WN": "Weiche_vertikal_oben_rechts.png",
"NN SS NW ES": "Weiche_vertikal_unten_links.png",
"NN SS NE WS": "Weiche_vertikal_unten_rechts.png",
"NE NW ES WS SS NN": "Weiche_Symetrical_gerade.png",
"NE EN SW WS": "Gleis_Kurve_oben_links_unten_rechts.png"
}
target_files = {
"EW": "Bahnhof_#d50000_Deadend_links.svg",
"NS": "Bahnhof_#d50000_Deadend_oben.svg",
"WE": "Bahnhof_#d50000_Deadend_rechts.svg",
"SN": "Bahnhof_#d50000_Deadend_unten.svg",
"EE WW": "Bahnhof_#d50000_Gleis_horizontal.svg",
"NN SS": "Bahnhof_#d50000_Gleis_vertikal.svg"}
"EW": "Bahnhof_#d50000_Deadend_links.png",
"NS": "Bahnhof_#d50000_Deadend_oben.png",
"WE": "Bahnhof_#d50000_Deadend_rechts.png",
"SN": "Bahnhof_#d50000_Deadend_unten.png",
"EE WW": "Bahnhof_#d50000_Gleis_horizontal.png",
"NN SS": "Bahnhof_#d50000_Gleis_vertikal.png"}
# Dict of rail cell images indexed by binary transitions
pil_rail_files_org = self.load_svgs(rail_files, rotate=True)
pil_rail_files = self.load_svgs(rail_files, rotate=True, background_image="Background_rail.svg",
whitefilter="Background_white_filter.svg")
pil_rail_files_org = self.load_pngs(rail_files, rotate=True)
pil_rail_files = self.load_pngs(rail_files, rotate=True, background_image="Background_rail.png",
whitefilter="Background_white_filter.png")
# Load the target files (which have rails and transitions of their own)
# They are indexed by (binTrans, iAgent), ie a tuple of the binary transition and the agent index
pil_target_files_org = self.load_svgs(target_files, rotate=False, agent_colors=self.agent_colors)
pil_target_files = self.load_svgs(target_files, rotate=False, agent_colors=self.agent_colors,
background_image="Background_rail.svg",
whitefilter="Background_white_filter.svg")
pil_target_files_org = self.load_pngs(target_files, rotate=False, agent_colors=self.agent_colors)
pil_target_files = self.load_pngs(target_files, rotate=False, agent_colors=self.agent_colors,
background_image="Background_rail.png",
whitefilter="Background_white_filter.png")
# Load station and recolorize them
station = self.pil_from_svg_file("svg", "Bahnhof_#d50000_target.svg")
station = self.pil_from_png_file('flatland.png', "Bahnhof_#d50000_target.png")
self.station_colors = self.recolor_image(station, [0, 0, 0], self.agent_colors, False)
cell_occupied = self.pil_from_svg_file("svg", "Cell_occupied.svg")
cell_occupied = self.pil_from_png_file('flatland.png', "Cell_occupied.png")
self.cell_occupied = self.recolor_image(cell_occupied, [0, 0, 0], self.agent_colors, False)
# Merge them with the regular rails.
......@@ -435,7 +434,7 @@ class PILSVG(PILGL):
self.pil_rail = {**pil_rail_files, **pil_target_files}
self.pil_rail_org = {**pil_rail_files_org, **pil_target_files_org}
def load_svgs(self, file_directory, rotate=False, agent_colors=False, background_image=None, whitefilter=None):
def load_pngs(self, file_directory, rotate=False, agent_colors=False, background_image=None, whitefilter=None):
pil = {}
transitions = RailEnvTransitions()
......@@ -444,7 +443,7 @@ class PILSVG(PILGL):
for transition, file in file_directory.items():
# Translate the ascii transition description in the format "NE WS" to the
# Translate the ascii transition description in the format "NE WS" to the
# binary list of transitions as per RailEnv - NESW (in) x NESW (out)
transition_16_bit = ["0"] * 16
for sTran in transition.split(" "):
......@@ -456,14 +455,14 @@ class PILSVG(PILGL):
transition_16_bit_string = "".join(transition_16_bit)
binary_trans = int(transition_16_bit_string, 2)
pil_rail = self.pil_from_svg_file('svg', file)
pil_rail = self.pil_from_png_file('flatland.png', file).convert("RGBA")
if background_image is not None:
img_bg = self.pil_from_svg_file('svg', background_image)
img_bg = self.pil_from_png_file('flatland.png', background_image).convert("RGBA")
pil_rail = Image.alpha_composite(img_bg, pil_rail)
if whitefilter is not None:
img_bg = self.pil_from_svg_file('svg', whitefilter)
img_bg = self.pil_from_png_file('flatland.png', whitefilter).convert("RGBA")
pil_rail = Image.alpha_composite(pil_rail, img_bg)
if rotate:
......@@ -492,38 +491,71 @@ class PILSVG(PILGL):
False)[0]
self.draw_image_row_col(colored_rail, (row, col), layer=PILGL.PREDICTION_PATH_LAYER)
def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None):
def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None, num_agents=None,
show_debug=True):
if binary_trans in self.pil_rail:
pil_track = self.pil_rail[binary_trans]
if target is not None:
target_img = self.station_colors[target % len(self.station_colors)]
target_img = Image.alpha_composite(pil_track, target_img)
self.draw_image_row_col(target_img, (row, col), layer=PILGL.TARGET_LAYER)
if show_debug:
self.text_rowcol((row + 0.8, col + 0.0), strText=str(target), layer=PILGL.TARGET_LAYER)
city_size = 1
if num_agents is not None:
city_size = max(1, np.log(1 + num_agents) / 2.5)
if binary_trans == 0:
if self.background_grid[col][row] <= 4:
if self.background_grid[col][row] <= 4 + np.ceil(((col * row + col) % 10) / city_size):
a = int(self.background_grid[col][row])
a = a % len(self.dBuildings)
a = a % len(self.lBuildings)
if (col + row + col * row) % 13 > 11:
pil_track = self.scenery[a % len(self.scenery)]
else:
if (col + row + col * row) % 3 == 0:
a = (a + (col + row + col * row)) % len(self.dBuildings)
pil_track = self.dBuildings[a]
elif (self.background_grid[col][row] > 4) or ((col ** 3 + row ** 2 + col * row) % 10 == 0):
a = (a + (col + row + col * row)) % len(self.lBuildings)
pil_track = self.lBuildings[a]
elif ((self.background_grid[col][row] > 5 + ((col * row + col) % 3)) or
((col ** 3 + row ** 2 + col * row) % 10 == 0)):
a = int(self.background_grid[col][row]) - 4
a2 = (a + (col + row + col * row + col ** 3 + row ** 4))
if a2 % 17 > 11:
if a2 % 64 > 11:
a = a2
pil_track = self.scenery[a % len(self.scenery)]
a_l = a % len(self.scenery)
if a2 % 50 == 49:
pil_track = self.scenery_water[0]
else:
pil_track = self.scenery[a_l]
if rail_grid is not None:
if a2 % 11 > 3:
if a_l == len(self.scenery) - 1:
# mountain
if col > 1 and row % 7 == 1:
if rail_grid[row, col - 1] == 0:
self.draw_image_row_col(self.scenery_d2[0], (row, col - 1),
layer=PILGL.RAIL_LAYER)
pil_track = self.scenery_d2[1]
else:
if a_l == len(self.scenery) - 1:
# mountain
if col > 2 and not (row % 7 == 1):
if rail_grid[row, col - 2] == 0 and rail_grid[row, col - 1] == 0:
self.draw_image_row_col(self.scenery_d3[0], (row, col - 2),
layer=PILGL.RAIL_LAYER)
self.draw_image_row_col(self.scenery_d3[1], (row, col - 1),
layer=PILGL.RAIL_LAYER)
pil_track = self.scenery_d3[2]
self.draw_image_row_col(pil_track, (row, col), layer=PILGL.RAIL_LAYER)
else:
print("Illegal rail:", row, col, format(binary_trans, "#018b")[2:], binary_trans)
print("Can't render - illegal rail or SVG element is undefined:", row, col,
format(binary_trans, "#018b")[2:], binary_trans)
if target is not None:
if is_selected:
svgBG = self.pil_from_svg_file("svg", "Selected_Target.svg")
svgBG = self.pil_from_png_file('flatland.png', "Selected_Target.png")
self.clear_layer(PILGL.SELECTED_TARGET_LAYER, 0)
self.draw_image_row_col(svgBG, (row, col), layer=PILGL.SELECTED_TARGET_LAYER)
......@@ -536,6 +568,7 @@ class PILSVG(PILGL):
xy_color_mask = np.all(rgbaImg[:, :, 0:3] - a3BaseColor != 0, axis=2)
else:
xy_color_mask = np.all(rgbaImg[:, :, 0:3] - a3BaseColor == 0, axis=2)
rgbaImg2 = np.copy(rgbaImg)
# Repaint the base color with the new color
......@@ -548,9 +581,9 @@ class PILSVG(PILGL):
# Seed initial train/zug files indexed by tuple(iDirIn, iDirOut):
file_directory = {
(0, 0): "Zug_Gleis_#0091ea.svg",
(1, 2): "Zug_1_Weiche_#0091ea.svg",
(0, 3): "Zug_2_Weiche_#0091ea.svg"
(0, 0): "Zug_Gleis_#0091ea.png",
(1, 2): "Zug_1_Weiche_#0091ea.png",
(0, 3): "Zug_2_Weiche_#0091ea.png"
}
# "paint" color of the train images we load - this is the color we will change.
......@@ -563,7 +596,7 @@ class PILSVG(PILGL):
for directions, path_svg in file_directory.items():
in_direction, out_direction = directions
pil_zug = self.pil_from_svg_file("svg", path_svg)
pil_zug = self.pil_from_png_file('flatland.png', path_svg)
# Rotate both the directions and the image and save in the dict
for rot_direction in range(4):
......@@ -579,7 +612,8 @@ class PILSVG(PILGL):
for color_idx, pil_zug_3 in enumerate(pils):
self.pil_zug[(in_direction_2, out_direction_2, color_idx)] = pils[color_idx]
def set_agent_at(self, agent_idx, row, col, in_direction, out_direction, is_selected):
def set_agent_at(self, agent_idx, row, col, in_direction, out_direction, is_selected,
rail_grid=None, show_debug=False, clear_debug_text=True, malfunction=False):
delta_dir = (out_direction - in_direction) % 4
color_idx = agent_idx % self.n_agent_colors
# when flipping direction at a dead end, use the "out_direction" direction.
......@@ -587,16 +621,51 @@ class PILSVG(PILGL):
in_direction = out_direction
pil_zug = self.pil_zug[(in_direction % 4, out_direction % 4, color_idx)]
self.draw_image_row_col(pil_zug, (row, col), layer=PILGL.AGENT_LAYER)
if rail_grid is not None:
if rail_grid[row, col] == 0.0:
self.draw_image_row_col(self.scenery_background_white, (row, col), layer=PILGL.RAIL_LAYER)
if is_selected:
bg_svg = self.pil_from_svg_file("svg", "Selected_Agent.svg")
bg_svg = self.pil_from_png_file('flatland.png', "Selected_Agent.png")
self.clear_layer(PILGL.SELECTED_AGENT_LAYER, 0)
self.draw_image_row_col(bg_svg, (row, col), layer=PILGL.SELECTED_AGENT_LAYER)
if show_debug:
if not clear_debug_text:
dr = 0.2
dc = 0.2
if in_direction == 0:
dr = 0.8
dc = 0.0
if in_direction == 1:
dr = 0.0
dc = 0.8
if in_direction == 2:
dr = 0.4
dc = 0.8
if in_direction == 3:
dr = 0.8
dc = 0.4
self.text_rowcol((row + dr, col + dc,), str(agent_idx), layer=PILGL.SELECTED_AGENT_LAYER)
else:
self.text_rowcol((row + 0.2, col + 0.2,), str(agent_idx))
if malfunction:
self.draw_malfunction(agent_idx, (row, col))
def set_cell_occupied(self, agent_idx, row, col):
occupied_im = self.cell_occupied[agent_idx % len(self.cell_occupied)]
self.draw_image_row_col(occupied_im, (row, col), 1)
def draw_malfunction(self, agent_idx, rcTopLeft):
# Roughly an "X" shape to indicate malfunction
grcOffsets = np.array([[0, 0], [1, 1], [1, 0], [0, 1]])
grcPoints = np.array(rcTopLeft)[None] + grcOffsets
gxyPoints = grcPoints[:, [1, 0]]
gxPoints, gyPoints = gxyPoints.T
# print(agent_idx, rcTopLeft, gxyPoints, "X:", gxPoints, "Y:", gyPoints)
# plot(self, gX, gY, color=None, linewidth=3, layer=RAIL_LAYER, opacity=255, **kwargs):
self.plot(gxPoints, -gyPoints, color=(0, 0, 0, 255), layer=PILGL.AGENT_LAYER, linewidth=2)
def main2():
gl = PILSVG(10, 10)
......
from typing import List, NamedTuple
import numpy as np
from IPython import display
from ipycanvas import canvas
from flatland.envs.rail_env import RailEnvActions
from flatland.utils.rendertools import RenderTool
class Behaviour():
def __init__(self, env):
self.env = env
self.nAg = len(env.agents)
def getActions(self):
return {}
class AlwaysForward(Behaviour):
def getActions(self):
return {i: RailEnvActions.MOVE_FORWARD for i in range(self.nAg)}
class DelayedStartForward(AlwaysForward):
def __init__(self, env, nStartDelay=2):
self.nStartDelay = nStartDelay
super().__init__(env)
def getActions(self):
iStep = self.env._elapsed_steps + 1
nAgentsMoving = min(self.nAg, iStep // self.nStartDelay)
return {i: RailEnvActions.MOVE_FORWARD for i in range(nAgentsMoving)}
AgentPause = NamedTuple("AgentPause",
[
("iAg", int),
("iPauseAt", int),
("iPauseFor", int)
])
class ForwardWithPause(Behaviour):
def __init__(self, env, lPauses: List[AgentPause]):
self.env = env
self.nAg = len(env.agents)
self.lPauses = lPauses
self.dAgPaused = {}
def getActions(self):
iStep = self.env._elapsed_steps + 1 # add one because this is called before step()
# new pauses starting this step
lNewPauses = [tPause for tPause in self.lPauses if tPause.iPauseAt == iStep]
# copy across the agent index and pause length
for pause in lNewPauses:
self.dAgPaused[pause.iAg] = pause.iPauseFor
# default action is move forward
dAction = {i: RailEnvActions.MOVE_FORWARD for i in range(self.nAg)}
# overwrite paused agents with stop
for iAg in self.dAgPaused:
dAction[iAg] = RailEnvActions.STOP_MOVING
# decrement the counters for each pause, and remove any expired pauses.
lFinished = []
for iAg in self.dAgPaused:
self.dAgPaused[iAg] -= 1
if self.dAgPaused[iAg] <= 0:
lFinished.append(iAg)
for iAg in lFinished:
self.dAgPaused.pop(iAg, None)
return dAction
class Deterministic(Behaviour):
def __init__(self, env, dAg_lActions):
super().__init__(env)
self.dAg_lActions = dAg_lActions
def getActions(self):
iStep = self.env._elapsed_steps
dAg_Action = {}
for iAg, lActions in self.dAg_lActions.items():
if iStep < len(lActions):
iAct = lActions[iStep]
else:
iAct = RailEnvActions.DO_NOTHING
dAg_Action[iAg] = iAct
# print(iStep, dAg_Action[0])
return dAg_Action
class EnvCanvas():
def __init__(self, env, behaviour: Behaviour = None):
self.env = env
self.iStep = 0
if behaviour is None:
behaviour = AlwaysForward(env)
self.behaviour = behaviour
self.oRT = RenderTool(env, show_debug=True)
self.oCan = canvas.Canvas(size=(600, 300))
self.render()
def render(self):
self.oRT.render_env(show_rowcols=True, show_inactive_agents=False, show_observations=False)
gIm = self.oRT.get_image()
red_channel = gIm[:, :, 0]
blue_channel = gIm[:, :, 1]
green_channel = gIm[:, :, 2]
image_data = np.stack((red_channel, blue_channel, green_channel), axis=2)
self.oCan.put_image_data(image_data)
def step(self):
dAction = self.behaviour.getActions()
self.env.step(dAction)
def show(self):
self.render()
display.display(self.oCan)
# https://stackoverflow.com/questions/715417/converting-from-a-string-to-boolean-in-python
def str2bool(v):
return v.lower() in ("yes", "true", "t", "1")
# in order for enumeration to be deterministic for testing purposes
# https://stackoverflow.com/questions/1653970/does-python-have-an-ordered-set
from collections import OrderedDict
from collections.abc import MutableSet
class OrderedSet(OrderedDict, MutableSet):
def update(self, *args, **kwargs):
if kwargs:
raise TypeError("update() takes no keyword arguments")
for s in args:
for e in s:
self.add(e)
def add(self, elem):
self[elem] = None
def discard(self, elem):
self.pop(elem, None)
def __le__(self, other):
return all(e in other for e in self)
def __lt__(self, other):
return self <= other and self != other
def __ge__(self, other):
return all(e in self for e in other)
def __gt__(self, other):
return self >= other and self != other
def __repr__(self):
return 'OrderedSet([%s])' % (', '.join(map(repr, self.keys())))
def __str__(self):
return '{%s}' % (', '.join(map(repr, self.keys())))
difference = property(lambda self: self.__sub__)
difference_update = property(lambda self: self.__isub__)
intersection = property(lambda self: self.__and__)
intersection_update = property(lambda self: self.__iand__)
issubset = property(lambda self: self.__le__)
issuperset = property(lambda self: self.__ge__)
symmetric_difference = property(lambda self: self.__xor__)
symmetric_difference_update = property(lambda self: self.__ixor__)
union = property(lambda self: self.__or__)
......@@ -7,7 +7,10 @@ import numpy as np
from numpy import array
from recordtype import recordtype
from flatland.envs.step_utils.states import TrainState
from flatland.utils.graphics_pil import PILGL, PILSVG
from flatland.utils.graphics_pgl import PGLGL
# TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv!
......@@ -21,6 +24,105 @@ class AgentRenderVariant(IntEnum):
class RenderTool(object):
""" RenderTool is a facade to a renderer.
(This was introduced for the Browser / JS renderer which has now been removed.)
"""
def __init__(self, env, gl="PGL", jupyter=False,
agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False, clear_debug_text=True, screen_width=800, screen_height=600,
host="localhost", port=None):
self.env = env
self.frame_nr = 0
self.start_time = time.time()
self.times_list = deque()
self.agent_render_variant = agent_render_variant
if gl in ["PIL", "PILSVG", "PGL"]:
self.renderer = RenderLocal(env, gl, jupyter,
agent_render_variant,
show_debug, clear_debug_text, screen_width, screen_height)
self.gl = self.renderer.gl
else:
print("[", gl, "] not found, switch to PGL")
def render_env(self,
show=False, # whether to call matplotlib show() or equivalent after completion
show_agents=True, # whether to include agents
show_inactive_agents=False, # whether to show agents before they start
show_observations=True, # whether to include observations
show_predictions=False, # whether to include predictions
show_rowcols=False, # label the rows and columns
frames=False, # frame counter to show (intended since invocation)
episode=None, # int episode number to show
step=None, # int step number to show in image
selected_agent=None, # indicate which agent is "selected" in the editor):
return_image=False): # indicate if image is returned for use in monitor:
return self.renderer.render_env(show, show_agents, show_inactive_agents, show_observations,
show_predictions, show_rowcols, frames, episode, step, selected_agent, return_image)
def close_window(self):
self.renderer.close_window()
def reset(self):
self.renderer.reset()
def set_new_rail(self):
self.renderer.set_new_rail()
self.renderer.env = self.env # bit of a hack - copy our env to the delegate
def update_background(self):
self.renderer.update_background()
def get_endpoint_URL(self):
""" Returns a string URL for the root of the HTTP server
TODO: Need to update this work work on a remote server! May be tricky...
"""
#return "http://localhost:{}".format(self.renderer.get_port())
if hasattr(self.renderer, "get_endpoint_url"):
return self.renderer.get_endpoint_url()
else:
print("Attempt to get_endpoint_url from RenderTool - only supported with BROWSER")
return None
def get_image(self):
"""
"""
if hasattr(self.renderer, "gl"):
return self.renderer.gl.get_image()
else:
print("Attempt to retrieve image from RenderTool - not supported with BROWSER")
return None
class RenderBase(object):
def __init__(self, env):
pass
def render_env(self):
pass
def close_window(self):
pass
def reset(self):
pass
def set_new_rail(self):
""" Signal to the renderer that the env has changed and will need re-rendering.
"""
pass
def update_background(self):
""" A lesser version of set_new_rail?
TODO: can update_background be pruned for simplicity?
"""
pass
class RenderLocal(RenderBase):
""" Class to render the RailEnv and agents.
Uses two layers, layer 0 for rails (mostly static), layer 1 for agents etc (dynamic)
The lower / rail layer 0 is only redrawn after set_new_rail() has been called.
......@@ -39,7 +141,10 @@ class RenderTool(object):
theta = np.linspace(0, np.pi / 2, 5)
arc = array([np.cos(theta), np.sin(theta)]).T # from [1,0] to [0,1]
def __init__(self, env, gl="PILSVG", jupyter=False, agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND):
def __init__(self, env, gl="PILSVG", jupyter=False,
agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False, clear_debug_text=True, screen_width=800, screen_height=600):
self.env = env
self.frame_nr = 0
self.start_time = time.time()
......@@ -47,15 +152,21 @@ class RenderTool(object):
self.agent_render_variant = agent_render_variant
self.gl_str = gl
if gl == "PIL":
self.gl = PILGL(env.width, env.height, jupyter)
self.gl = PILGL(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)
elif gl == "PILSVG":
self.gl = PILSVG(env.width, env.height, jupyter)
self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)
else:
print("[", gl, "] not found, switch to PILSVG")
self.gl = PILSVG(env.width, env.height, jupyter)
if gl != "PGL":
print("[", gl, "] not found, switch to PGL, PILSVG")
print("Using PGL")
self.gl = PGLGL(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height)
self.new_rail = True
self.show_debug = show_debug
self.clear_debug_text = clear_debug_text
self.update_background()
def reset(self):
......@@ -72,9 +183,10 @@ class RenderTool(object):
def update_background(self):
# create background map
targets = {}
for agent_idx, agent in enumerate(self.env.agents_static):
for agent_idx, agent in enumerate(self.env.agents):
if agent is None:
continue
#print(f"updatebg: {agent_idx} {agent.target}")
targets[tuple(agent.target)] = agent_idx
self.gl.build_background_map(targets)
......@@ -88,10 +200,9 @@ class RenderTool(object):
self.new_rail = True
def plot_agents(self, targets=True, selected_agent=None):
color_map = self.gl.get_cmap('hsv',
lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
color_map = self.gl.get_cmap('hsv', lut=(len(self.env.agents) + 1))
for agent_idx, agent in enumerate(self.env.agents_static):
for agent_idx, agent in enumerate(self.env.agents):
if agent is None:
continue
color = color_map(agent_idx)
......@@ -142,6 +253,9 @@ class RenderTool(object):
Plot a simple agent.
Assumes a working graphics layer context (cf a MPL figure).
"""
if position_row_col is None:
return
rt = self.__class__
direction_row_col = rt.transitions_row_col[direction] # agent direction in RC
......@@ -282,7 +396,7 @@ class RenderTool(object):
if len(observation_dict) < 1:
warnings.warn(
"Predictor did not provide any predicted cells to render. \
Observaiton builder needs to populate: env.dev_obs_dict")
Observation builder needs to populate: env.dev_obs_dict")
else:
for agent in agent_handles:
color = self.gl.get_agent_color(agent)
......@@ -393,31 +507,43 @@ class RenderTool(object):
def render_env(self,
show=False, # whether to call matplotlib show() or equivalent after completion
agents=True, # whether to include agents
show_agents=True, # whether to include agents
show_inactive_agents=False,
show_observations=True, # whether to include observations
show_predictions=False, # whether to include predictions
show_rowcols=False, # label the rows and columns
frames=False, # frame counter to show (intended since invocation)
episode=None, # int episode number to show
step=None, # int step number to show in image
selected_agent=None): # indicate which agent is "selected" in the editor
selected_agent=None, # indicate which agent is "selected" in the editor
return_image=False): # indicate if image is returned for use in monitor:
""" Draw the environment using the GraphicsLayer this RenderTool was created with.
(Use show=False from a Jupyter notebook with %matplotlib inline)
"""
if type(self.gl) is PILSVG:
self.render_env_svg(show=show,
# if type(self.gl) is PILSVG:
if self.gl_str in ["PILSVG", "PGL"]:
return self.render_env_svg(show=show,
show_observations=show_observations,
show_predictions=show_predictions,
selected_agent=selected_agent
selected_agent=selected_agent,
show_agents=show_agents,
show_inactive_agents=show_inactive_agents,
show_rowcols=show_rowcols,
return_image=return_image
)
else:
self.render_env_pil(show=show,
agents=agents,
return self.render_env_pil(show=show,
show_agents=show_agents,
show_inactive_agents=show_inactive_agents,
show_observations=show_observations,
show_predictions=show_predictions,
show_rowcols=show_rowcols,
frames=frames,
episode=episode,
step=step,
selected_agent=selected_agent
selected_agent=selected_agent,
return_image=return_image
)
def _draw_square(self, center, size, color, opacity=255, layer=0):
......@@ -433,13 +559,16 @@ class RenderTool(object):
def render_env_pil(self,
show=False, # whether to call matplotlib show() or equivalent after completion
# use false when calling from Jupyter. (and matplotlib no longer supported!)
agents=True, # whether to include agents
show_agents=True, # whether to include agents
show_inactive_agents=False,
show_observations=True, # whether to include observations
show_predictions=False, # whether to include predictions
show_rowcols=False, # label the rows and columns
frames=False, # frame counter to show (intended since invocation)
episode=None, # int episode number to show
step=None, # int step number to show in image
selected_agent=None # indicate which agent is "selected" in the editor
selected_agent=None, # indicate which agent is "selected" in the editor
return_image=False # indicate if image is returned for use in monitor:
):
if type(self.gl) is PILGL:
......@@ -450,7 +579,7 @@ class RenderTool(object):
self.render_rail()
# Draw each agent + its orientation + its target
if agents:
if show_agents:
self.plot_agents(targets=True, selected_agent=selected_agent)
if show_observations:
self.render_observation(range(env.get_num_agents()), env.dev_obs_dict)
......@@ -487,10 +616,13 @@ class RenderTool(object):
self.gl.pause(0.00001)
if return_image:
return self.get_image()
return
def render_env_svg(
self, show=False, show_observations=True, show_predictions=False, selected_agent=None
self, show=False, show_observations=True, show_predictions=False, selected_agent=None,
show_agents=True, show_inactive_agents=False, show_rowcols=False, return_image=False
):
"""
Renders the environment with SVG support (nice image)
......@@ -507,7 +639,7 @@ class RenderTool(object):
# store the targets
targets = {}
selected = {}
for agent_idx, agent in enumerate(self.env.agents_static):
for agent_idx, agent in enumerate(self.env.agents):
if agent is None:
continue
targets[tuple(agent.target)] = agent_idx
......@@ -525,60 +657,114 @@ class RenderTool(object):
is_selected = False
self.gl.set_rail_at(r, c, transitions, target=target, is_selected=is_selected,
rail_grid=env.rail.grid)
rail_grid=env.rail.grid, num_agents=env.get_num_agents(),
show_debug=self.show_debug)
self.gl.build_background_map(targets)
for agent_idx, agent in enumerate(self.env.agents):
if show_rowcols:
# label rows, cols
for iRow in range(env.height):
self.gl.text_rowcol((iRow, 0), str(iRow), layer=self.gl.RAIL_LAYER)
for iCol in range(env.width):
self.gl.text_rowcol((0, iCol), str(iCol), layer=self.gl.RAIL_LAYER)
if agent is None:
continue
if self.agent_render_variant == AgentRenderVariant.BOX_ONLY:
self.gl.set_cell_occupied(agent_idx, *(agent.position))
elif self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND or \
self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX: # noqa: E125
if agent.old_position is not None:
position = agent.old_position
direction = agent.direction
old_direction = agent.old_direction
if show_agents:
for agent_idx, agent in enumerate(self.env.agents):
if agent is None:
continue
# Show an agent even if it hasn't already started
if agent.position is None:
if show_inactive_agents:
# print("agent ", agent_idx, agent.position, agent.old_position, agent.initial_position)
self.gl.set_agent_at(agent_idx, *(agent.initial_position),
agent.initial_direction, agent.initial_direction,
is_selected=(selected_agent == agent_idx),
rail_grid=env.rail.grid,
show_debug=self.show_debug, clear_debug_text=self.clear_debug_text,
malfunction=False)
continue
is_malfunction = agent.malfunction_handler.malfunction_down_counter > 0
if self.agent_render_variant == AgentRenderVariant.BOX_ONLY:
self.gl.set_cell_occupied(agent_idx, *(agent.position))
elif self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND or \
self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX: # noqa: E125
# Most common case - the agent has been running for >1 steps
if agent.old_position is not None:
position = agent.old_position
direction = agent.direction
old_direction = agent.old_direction
# the agent's first step - it doesn't have an old position yet
elif agent.position is not None:
position = agent.position
direction = agent.direction
old_direction = agent.direction
# When the editor has just added an agent
elif agent.initial_position is not None:
position = agent.initial_position
direction = agent.initial_direction
old_direction = agent.initial_direction
# set_agent_at uses the agent index for the color
if self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX:
self.gl.set_cell_occupied(agent_idx, *(agent.position))
self.gl.set_agent_at(agent_idx, *position, old_direction, direction,
selected_agent == agent_idx, rail_grid=env.rail.grid,
show_debug=self.show_debug, clear_debug_text=self.clear_debug_text,
malfunction=is_malfunction)
else:
position = agent.position
direction = agent.direction
old_direction = agent.direction
for possible_direction in range(4):
# Is a transition along movement `desired_movement_from_new_cell` to the current cell possible?
isValid = env.rail.get_transition((*agent.position, agent.direction), possible_direction)
if isValid:
direction = possible_direction
# set_agent_at uses the agent index for the color
self.gl.set_agent_at(agent_idx, *position, agent.direction, direction,
selected_agent == agent_idx, rail_grid=env.rail.grid,
show_debug=self.show_debug, clear_debug_text=self.clear_debug_text,
malfunction=is_malfunction)
# set_agent_at uses the agent index for the color
if self.agent_render_variant == AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX:
self.gl.set_cell_occupied(agent_idx, *(agent.position))
if show_inactive_agents:
show_this_agent = True
else:
show_this_agent = agent.state.is_on_map_state()
# set_agent_at uses the agent index for the color
if self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX:
self.gl.set_cell_occupied(agent_idx, *(agent.position))
self.gl.set_agent_at(agent_idx, *position, old_direction, direction, selected_agent == agent_idx)
else:
position = agent.position
direction = agent.direction
for possible_directions in range(4):
# Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
isValid = env.rail.get_transition((*agent.position, agent.direction), possible_directions)
if isValid:
direction = possible_directions
# set_agent_at uses the agent index for the color
self.gl.set_agent_at(agent_idx, *position, agent.direction, direction,
selected_agent == agent_idx)
# set_agent_at uses the agent index for the color
if self.agent_render_variant == AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX:
self.gl.set_cell_occupied(agent_idx, *(agent.position))
self.gl.set_agent_at(agent_idx, *position, agent.direction, direction, selected_agent == agent_idx)
if show_this_agent:
self.gl.set_agent_at(agent_idx, *position, agent.direction, direction,
selected_agent == agent_idx,
rail_grid=env.rail.grid, malfunction=is_malfunction)
if show_observations:
self.render_observation(range(env.get_num_agents()), env.dev_obs_dict)
if show_predictions:
self.render_prediction(range(env.get_num_agents()), env.dev_pred_dict)
if show:
self.gl.show()
for i in range(3):
self.gl.process_events()
self.frame_nr += 1
if return_image:
return self.get_image()
return
def close_window(self):
......
from typing import Tuple
from typing import Tuple, Dict
import numpy as np
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
def make_simple_rail() -> Tuple[GridTransitionMap, np.array]:
# We instantiate a very simple rail network on a 7x10 grid:
# Note that that cells have invalid RailEnvTransitions!
# |
# |
# |
# _ _ _ _\ _ _ _ _ _ _
# /
# |
# |
# |
transitions = RailEnvTransitions()
cells = transitions.transition_list
empty = cells[0]
dead_end_from_south = cells[7]
dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
vertical_straight = cells[1]
horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
simple_switch_north_left = cells[2]
simple_switch_north_right = cells[10]
simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270)
simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270)
rail_map = np.array(
[[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
[[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
[[dead_end_from_east] + [horizontal_straight] * 2 +
[simple_switch_east_west_north] +
[horizontal_straight] * 2 + [simple_switch_east_west_south] +
[horizontal_straight] * 2 + [dead_end_from_west]] +
[[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
[[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
city_positions = [(0,3), (6, 6)]
train_stations = [
[( (0, 3), 0 ) ],
[( (6, 6), 0 ) ],
]
city_orientations = [0, 2]
agents_hints = {'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
optionals = {'agents_hints': agents_hints}
return rail, rail_map, optionals
def make_disconnected_simple_rail() -> Tuple[GridTransitionMap, np.array]:
# We instantiate a very simple rail network on a 7x10 grid:
# Note that that cells have invalid RailEnvTransitions!
# |
# |
# |
# _ _ _ _\ _ _ _ _ _
# /
# |
# |
# |
transitions = RailEnvTransitions()
cells = transitions.transition_list
empty = cells[0]
dead_end_from_south = cells[7]
dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
vertical_straight = cells[1]
horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
simple_switch_north_left = cells[2]
simple_switch_north_right = cells[10]
simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270)
simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270)
rail_map = np.array(
[[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
[[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
[[dead_end_from_east] + [horizontal_straight] * 2 +
[simple_switch_east_west_north] +
[dead_end_from_west] + [dead_end_from_east] + [simple_switch_east_west_south] +
[horizontal_straight] * 2 + [dead_end_from_west]] +
[[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
[[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
city_positions = [(0,3), (6, 6)]
train_stations = [
[( (0, 3), 0 ) ],
[( (6, 6), 0 ) ],
]
city_orientations = [0, 2]
agents_hints = {'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
optionals = {'agents_hints': agents_hints}
return rail, rail_map, optionals
def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]:
# We instantiate a very simple rail network on a 7x10 grid:
# |
# |
# |
# _ _ _ _\ _ _ _ _ _ _
# \
# |
# |
# |
transitions = RailEnvTransitions()
cells = transitions.transition_list
empty = cells[0]
dead_end_from_south = cells[7]
dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
vertical_straight = cells[1]
horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
simple_switch_north_right = cells[10]
simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270)
simple_switch_west_east_south = transitions.rotate_transition(simple_switch_north_right, 90)
rail_map = np.array(
[[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
[[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
[[dead_end_from_east] + [horizontal_straight] * 2 +
[simple_switch_east_west_north] +
[horizontal_straight] * 2 + [simple_switch_west_east_south] +
[horizontal_straight] * 2 + [dead_end_from_west]] +
[[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
[[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
city_positions = [(0,3), (6, 6)]
train_stations = [
[( (0, 3), 0 ) ],
[( (6, 6), 0 ) ],
]
city_orientations = [0, 2]
agents_hints = {'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
optionals = {'agents_hints': agents_hints}
return rail, rail_map, optionals
def make_simple_rail_unconnected() -> Tuple[GridTransitionMap, np.array]:
# We instantiate a very simple rail network on a 7x10 grid:
# Note that that cells have invalid RailEnvTransitions!
# |
# |
# |
# _ _ _ _ _ _ _ _ _ _
# /
# |
# |
# |
transitions = RailEnvTransitions()
cells = transitions.transition_list
empty = cells[0]
dead_end_from_south = cells[7]
dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
vertical_straight = cells[1]
horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
simple_switch_north_left = cells[2]
# simple_switch_north_right = cells[10]
# simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270)
simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270)
rail_map = np.array(
[[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
[[empty] * 3 + [vertical_straight] + [empty] * 6] +
[[empty] * 3 + [dead_end_from_north] + [empty] * 6] +
[[dead_end_from_east] + [horizontal_straight] * 5 + [simple_switch_east_west_south] +
[horizontal_straight] * 2 + [dead_end_from_west]] +
[[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
[[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
city_positions = [(0,3), (6, 6)]
train_stations = [
[( (0, 3), 0 ) ],
[( (6, 6), 0 ) ],
]
city_orientations = [0, 2]
agents_hints = {'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
optionals = {'agents_hints': agents_hints}
return rail, rail_map, optionals
def make_simple_rail_with_alternatives() -> Tuple[GridTransitionMap, np.array]:
# We instantiate a very simple rail network on a 7x10 grid:
# 0 1 2 3 4 5 6 7 8 9 10
# 0 /-------------\
# 1 | |
# 2 | |
# 3 _ _ _ /_ _ _ |
# 4 \ ___ /
# 5 |/
# 6 |
# 7 |
transitions = RailEnvTransitions()
cells = transitions.transition_list
empty = cells[0]
dead_end_from_south = cells[7]
right_turn_from_south = cells[8]
right_turn_from_west = transitions.rotate_transition(right_turn_from_south, 90)
right_turn_from_north = transitions.rotate_transition(right_turn_from_south, 180)
dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
vertical_straight = cells[1]
simple_switch_north_left = cells[2]
simple_switch_north_right = cells[10]
simple_switch_left_east = transitions.rotate_transition(simple_switch_north_left, 90)
horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
double_switch_south_horizontal_straight = horizontal_straight + cells[6]
double_switch_north_horizontal_straight = transitions.rotate_transition(
double_switch_south_horizontal_straight, 180)
rail_map = np.array(
[[empty] * 3 + [right_turn_from_south] + [horizontal_straight] * 5 + [right_turn_from_west]] +
[[empty] * 3 + [vertical_straight] + [empty] * 5 + [vertical_straight]] * 2 +
[[dead_end_from_east] + [horizontal_straight] * 2 + [simple_switch_left_east] + [horizontal_straight] * 2 + [
right_turn_from_west] + [empty] * 2 + [vertical_straight]] +
[[empty] * 6 + [simple_switch_north_right] + [horizontal_straight] * 2 + [right_turn_from_north]] +
[[empty] * 6 + [vertical_straight] + [empty] * 3] +
[[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
city_positions = [(0,3), (6, 6)]
train_stations = [
[( (0, 3), 0 ) ],
[( (6, 6), 0 ) ],
]
city_orientations = [0, 2]
agents_hints = {'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
optionals = {'agents_hints': agents_hints}
return rail, rail_map, optionals
def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.array, Dict[str, str]]:
# We instantiate a very simple rail network on a 7x10 grid:
# |
# |
......@@ -16,15 +270,9 @@ def make_simple_rail() -> Tuple[GridTransitionMap, np.array]:
# |
# |
# |
cells = [int('0000000000000000', 2), # empty cell - Case 0
int('1000000000100000', 2), # Case 1 - straight
int('1001001000100000', 2), # Case 2 - simple switch
int('1000010000100001', 2), # Case 3 - diamond drossing
int('1001011000100001', 2), # Case 4 - single slip switch
int('1100110000110011', 2), # Case 5 - double slip switch
int('0101001000000010', 2), # Case 6 - symmetrical switch
int('0010000000000000', 2)] # Case 7 - dead end
transitions = Grid4Transitions([])
transitions = RailEnvTransitions()
cells = transitions.transition_list
empty = cells[0]
dead_end_from_south = cells[7]
dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
......@@ -47,4 +295,51 @@ def make_simple_rail() -> Tuple[GridTransitionMap, np.array]:
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
return rail, rail_map
city_positions = [(0,3), (6, 6)]
train_stations = [
[( (0, 3), 0 ) ],
[( (6, 6), 0 ) ],
]
city_orientations = [0, 2]
agents_hints = {'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
optionals = {'agents_hints': agents_hints}
return rail, rail_map, optionals
def make_oval_rail() -> Tuple[GridTransitionMap, np.array]:
transitions = RailEnvTransitions()
cells = transitions.transition_list
empty = cells[0]
vertical_straight = cells[1]
horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
right_turn_from_south = cells[8]
right_turn_from_west = transitions.rotate_transition(right_turn_from_south, 90)
right_turn_from_north = transitions.rotate_transition(right_turn_from_south, 180)
right_turn_from_east = transitions.rotate_transition(right_turn_from_south, 270)
rail_map = np.array(
[[empty] * 9] +
[[empty] + [right_turn_from_south] + [horizontal_straight] * 5 + [right_turn_from_west] + [empty]] +
[[empty] + [vertical_straight] + [empty] * 5 + [vertical_straight] + [empty]]+
[[empty] + [vertical_straight] + [empty] * 5 + [vertical_straight] + [empty]] +
[[empty] + [right_turn_from_east] + [horizontal_straight] * 5 + [right_turn_from_north] + [empty]] +
[[empty] * 9], dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
city_positions = [(1, 4), (4, 4)]
train_stations = [
[((1, 4), 0)],
[((4, 4), 0)],
]
city_orientations = [1, 3]
agents_hints = {'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
optionals = {'agents_hints': agents_hints}
return rail, rail_map, optionals
Flatland 2.0 Introduction
=========================
## What's new?
In this version of **Flat**land, we are moving closer to realistic and more complex railway problems.
Earlier versions of **Flat**land introduced you to the concept of restricted transitions, but they were still too simplistic to give us feasible solutions for daily operations.
Thus the following changes are coming in the next version to be closer to real railway network challenges:
- **New Level Generator** provide less connections between different nodes in the network and thus agent densities on rails are much higher.
- **Stochastic Events** cause agents to stop and get stuck for different numbers of time steps.
- **Different Speed Classes** allow agents to move at different speeds and thus enhance complexity in the search for optimal solutions.
We explain these changes in more detail and how you can play with their parametrization in Tutorials 3--5:
* [Tutorials](https://gitlab.aicrowd.com/flatland/flatland/tree/master/docs/tutorials)
We appreciate *your feedback* on the performance and the difficulty on these levels to help us shape the best possible **Flat**land 2.0 environment.
## Example code
To see all the changes in action you can just run the
* [examples/flatland_example_2_0.py](https://gitlab.aicrowd.com/flatland/flatland/blob/master/examples/flatland_2_0_example.py)
example.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
#!/usr/bin/env python
import glob
import os
import shutil
import subprocess
import webbrowser
from urllib.request import pathname2url
......@@ -18,16 +19,43 @@ def remove_exists(filename):
# clean docs config and html files, and rebuild everything
remove_exists('docs/flatland.rst')
# wildcards do not work under Windows
for image_file in glob.glob(r'./docs/flatland*.rst'):
remove_exists(image_file)
remove_exists('docs/modules.rst')
subprocess.call(['sphinx-apidoc', '-o', 'docs/', 'flatland'])
for md_file in glob.glob(r'./*.md') + glob.glob(r'./docs/specifications/*.md') + glob.glob(r'./docs/tutorials/*.md') + glob.glob(r'./docs/interface/*.md'):
from m2r import parse_from_file
rst_content = parse_from_file(md_file)
rst_file = md_file.replace(".md", ".rst")
remove_exists(rst_file)
with open(rst_file, 'w') as out:
print("m2r {}->{}".format(md_file, rst_file))
out.write(rst_content)
out.flush()
img_dest = 'docs/images/'
if not os.path.exists(img_dest):
os.makedirs(img_dest)
for image_file in glob.glob(r'./images/*.png'):
shutil.copy(image_file, img_dest)
subprocess.call(['sphinx-apidoc', '--force', '-a', '-e', '-o', 'docs/', 'flatland', '-H', 'API Reference', '--tocfile',
'05_apidoc'])
os.environ["SPHINXPROJ"] = "flatland"
os.environ["SPHINXPROJ"] = "Flatland"
os.chdir('docs')
subprocess.call(['python', '-msphinx', '-M', 'clean', '.', '_build'])
# TODO fix sphinx warnings instead of suppressing them...
subprocess.call(['python', '-msphinx', '-M', 'html', '.', '_build', '-Q'])
subprocess.call(['python', '-mpydeps', '../flatland', '-o', '_build/html/flatland.svg', '--no-config', '--noshow'])
img_dest = '_build/html/img'
if not os.path.exists(img_dest):
os.makedirs(img_dest)
for image_file in glob.glob(r'./specifications/img/*'):
shutil.copy(image_file, img_dest)
subprocess.call(['python', '-msphinx', '-M', 'html', '.', '_build'])
# we do not currrently use pydeps, commented out https://gitlab.aicrowd.com/flatland/flatland/issues/149
# subprocess.call(['python', '-mpydeps', '../flatland', '-o', '_build/html/flatland.svg', '--no-config', '--noshow'])
browser('_build/html/index.html')
%% Cell type:markdown id: tags:
### Example 1 - generate a rail from a manual specification
From a map of tuples (cell_type, rotation)
# Simple Animation Demo
%% Cell type:code id: tags:
``` python
from flatland.envs.generators import rail_from_manual_specifications_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
from PIL import Image
%load_ext autoreload
%autoreload 2
```
%% Cell type:code id: tags:
``` python
specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)],
[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
[(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)],
[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)]]
env = RailEnv(width=6,
height=4,
rail_generator=rail_from_manual_specifications_generator(specs),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
import numpy as np
import time
from IPython import display
from ipycanvas import canvas
from flatland.utils.rendertools import RenderTool
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env import RailEnvActions as rea
from flatland.envs.persistence import RailEnvPersister
```
env.reset()
%% Cell type:code id: tags:
env_renderer = RenderTool(env, gl="PILSVG")
env_renderer.render_env(show=False)
``` python
env, env_dict = RailEnvPersister.load_new("complex_scene_2.pkl", load_from_package="env_data.railway")
_ = env.reset()
env._max_episode_steps = 100
```
%% Output
pickle failed to load file: complex_scene_2.pkl trying msgpack (deprecated)...
pickle failed to load file: complex_scene_2.pkl trying msgpack (deprecated)...
pickle failed to load file: complex_scene_2.pkl trying msgpack (deprecated)...
This env file has no max_episode_steps (deprecated) - setting to 100
%% Cell type:code id: tags:
``` python
Image.fromarray(env_renderer.gl.get_image())
oRT = RenderTool(env, gl="PILSVG", jupyter=False, show_debug=True)
image_arr = oRT.get_image()
oCanvas = canvas.Canvas()
oCanvas.put_image_data(image_arr[:,:,0:3])
display.display(oCanvas)
done={"__all__":False}
while not done["__all__"]:
actions = {}
for agent_handle, agents in enumerate(env.agents):
actions.update({agent_handle:rea.MOVE_FORWARD})
obs, rew, done, info = env.step(actions)
oRT.render_env(show_observations=False,show_predictions=False)
gIm = oRT.get_image()
oCanvas.put_image_data(gIm[:,:,0:3])
time.sleep(0.1)
```
%% Output
<PIL.Image.Image image mode=RGBA size=718x480 at 0x14DD8FD52E8>
......
# list of notebooks to include in run-all-notebooks.py test
simple_example_manual_control.ipynb
simple_rendering_demo.ipynb
flatland_animate.ipynb
render_episode.ipynb
scene_editor.ipynb
test_saved_envs.ipynb
test_service.ipynb
%% Cell type:markdown id: tags:
# Render Episode
Render a stored episode. Env file needs to have "episode" and "action" keys.
- creates a moving gif file of the episode
- displays the episode in a widget with a slider for the time steps.
%% Cell type:markdown id: tags:
# Setup
%% Cell type:code id: tags:
``` python
#!apt -qq install graphviz libgraphviz-dev pkg-config
#!pip install -qq git+https://gitlab.aicrowd.com/flatland/flatland.git
```
%% Cell type:code id: tags:
``` python
%load_ext autoreload
%autoreload 2
```
%% Cell type:code id: tags:
``` python
from IPython import display
```
%% Cell type:code id: tags:
``` python
import os
import pandas as pd
import PIL
import imageio
```
%% Cell type:code id: tags:
``` python
from flatland.utils.rendertools import RenderTool
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.malfunction_generators import malfunction_from_file, no_malfunction_generator
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.rail_env import RailEnvActions
from flatland.envs.step_utils.states import TrainState
from flatland.envs.persistence import RailEnvPersister
```
%% Cell type:code id: tags:
``` python
def render_env(env_renderer):
ag0= env_renderer.env.agents[0]
#print("render_env ag0: ",ag0.position, ag0.direction)
aImage = env_renderer.render_env(show_rowcols=True, return_image=True)
pil_image = PIL.Image.fromarray(aImage)
return pil_image
```
%% Cell type:markdown id: tags:
# Experiments
This has been mostly changed to load envs using `importlib_resources`. It's getting them from the package "envdata.tests`
%% Cell type:code id: tags:
``` python
env, env_dict = RailEnvPersister.load_new("complex_scene_2.pkl", load_from_package="env_data.railway")
_ = env.reset()
env._max_episode_steps = 100
```
%% Output
pickle failed to load file: complex_scene_2.pkl trying msgpack (deprecated)...
pickle failed to load file: complex_scene_2.pkl trying msgpack (deprecated)...
pickle failed to load file: complex_scene_2.pkl trying msgpack (deprecated)...
This env file has no max_episode_steps (deprecated) - setting to 100
%% Cell type:code id: tags:
``` python
# the seed has to match that used to record the episode, in order for the malfunctions to match.
oRT = RenderTool(env, show_debug=True)
aImg = oRT.render_env(show_rowcols=True, return_image=True, show_inactive_agents=True)
print(env._max_episode_steps)
```
%% Cell type:code id: tags:
``` python
loAgs = env_dict["agents"]
lCols = "initial_direction,direction,initial_position,position".split(",")
pd.DataFrame([ [getattr(oAg, sCol) for sCol in lCols]
for oAg in loAgs], columns=lCols)
```
%% Cell type:code id: tags:
``` python
pd.DataFrame([ [getattr(oAg, sCol) for sCol in lCols]
for oAg in env.agents], columns=lCols)
```
%% Cell type:code id: tags:
``` python
pd.DataFrame([ vars(oAg) for oAg in env.agents])
```
%% Cell type:code id: tags:
``` python
# from persistence.py
def get_agent_state(env):
list_agents_state = []
for iAg, oAg in enumerate(env.agents):
# the int cast is to avoid numpy types which may cause problems with msgpack
# in env v2, agents may have position None, before starting
if oAg.position is None:
pos = (0, 0)
else:
pos = (int(oAg.position[0]), int(oAg.position[1]))
# print("pos:", pos, type(pos[0]))
list_agents_state.append(
[*pos, int(oAg.direction), oAg.malfunction_handler])
return list_agents_state
```
%% Cell type:code id: tags:
``` python
pd.DataFrame([ vars(oAg) for oAg in env.agents])
```
%% Cell type:code id: tags:
``` python
expert_actions = []
action = {}
```
%% Cell type:code id: tags:
``` python
env_renderer = RenderTool(env, gl="PGL", show_debug=True)
n_agents = env.get_num_agents()
x_dim, y_dim = env.width, env.height
max_steps = env._max_episode_steps
action_dict = {}
frames = []
# log everything in original state
statuses = []
for a in range(n_agents):
statuses.append(env.agents[a].state)
pilImg = render_env(env_renderer)
frames.append({
'image': pilImg,
'statuses': statuses
})
step = 0
all_done = False
failed_action_check = False
print("Processing episode steps:")
while not all_done:
print(step, end=", ")
for agent_handle, agent in enumerate(env.agents):
action_dict.update({agent_handle: RailEnvActions.MOVE_FORWARD})
next_obs, all_rewards, done, info = env.step(action_dict)
statuses = []
for a in range(n_agents):
statuses.append(env.agents[a].state)
#clear_output(wait=True)
pilImg = render_env(env_renderer)
frames.append({
'image': pilImg,
'statuses': statuses
})
#print("Replaying {}/{}".format(step, max_steps))
if done['__all__']:
all_done = True
max_steps = step + 1
print("done")
step += 1
```
%% Cell type:code id: tags:
``` python
assert failed_action_check == False, "Realised states did not match stored states."
```
%% Cell type:code id: tags:
``` python
from ipywidgets import interact, interactive, fixed, interact_manual, Play
import ipywidgets as widgets
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from IPython.display import HTML
display.display(HTML('<link rel="stylesheet" href="//stackpath.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.min.css"/>'))
def plot_func(frame_idx):
frame = frames[int(frame_idx)]
display.display(frame['image'])
#print(frame['statuses'])
slider = widgets.FloatSlider(value=0, min=0, max=max_steps, step=1)
interact(plot_func, frame_idx = slider)
play = Play(
max=max_steps,
value=0,
step=1,
interval=250
)
widgets.link((play, 'value'), (slider, 'value'))
widgets.VBox([play])
```
%% Cell type:code id: tags:
``` python
```
......@@ -5,6 +5,7 @@ from subprocess import Popen, PIPE
import importlib_resources
import pkg_resources
from importlib_resources import path
import importlib_resources as ir
from ipython_genutils.py3compat import string_types, bytes_to_str
......@@ -38,17 +39,38 @@ def run_python(parameters, ignore_return_code=False, stdin=None):
return stdout.decode('utf8', 'replace'), stderr.decode('utf8', 'replace')
for entry in [entry for entry in importlib_resources.contents('notebooks') if
not pkg_resources.resource_isdir('notebooks', entry)
and entry.endswith(".ipynb")
]:
print("*****************************************************************")
print("Converting and running {}".format(entry))
print("*****************************************************************")
with path('notebooks', entry) as file_in:
out, err = run_python(" -m jupyter nbconvert --execute --to notebook --inplace " + str(file_in))
sys.stderr.write(err)
sys.stderr.flush()
sys.stdout.write(out)
sys.stdout.flush()
def main():
# If the file notebooks-list exists, use it as a definitive list of notebooks to run
# This in effect ignores any local notebooks you might be working on, so you can run tox
# without them causing the notebooks task / testenv to fail.
if importlib_resources.is_resource("notebooks", "notebook-list"):
print("Using the notebooks-list file to designate which notebooks to run")
lsNB = [
sLine for sLine in ir.read_text("notebooks", "notebook-list").split("\n")
if len(sLine) > 3 and not sLine.startswith("#")
]
else:
lsNB = [
entry for entry in importlib_resources.contents('notebooks') if
not pkg_resources.resource_isdir('notebooks', entry)
and entry.endswith(".ipynb")
]
print("Running notebooks:", " ".join(lsNB))
for entry in lsNB:
print("*****************************************************************")
print("Converting and running {}".format(entry))
print("*****************************************************************")
with path('notebooks', entry) as file_in:
out, err = run_python(" -m jupyter nbconvert --ExecutePreprocessor.timeout=120 " +
"--execute --to notebook --inplace " + str(file_in))
sys.stderr.write(err)
sys.stderr.flush()
sys.stdout.write(out)
sys.stdout.flush()
if __name__ == "__main__":
main()
\ No newline at end of file
%% Cell type:markdown id: tags:
# Railway Scene Editor
%% Cell type:code id: tags:
``` python
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:95% !important; }</style>"))
```
%% Output
%% Cell type:code id: tags:
``` python
from flatland.utils.editor import EditorMVC
mvc = EditorMVC(sGL="PILSVG" )
```
%% Cell type:markdown id: tags:
## Instructions
- Drag to draw (improved dead-ends)
- ctrl-click to add agent or select agent
- if agent is selected:
- ctrl-click to move agent position
- use rotate agent to rotate 90°
- ctrl-shift-click to set target for selected agent
- target can be moved by repeating
- to Resize the env (cannot preserve work):
- select "Regen" tab, set regen size slider, click regenerate.
- alt-click remove all rails from cell
Demo Scene: complex_scene.pkl
%% Cell type:code id: tags:
``` python
mvc.view.display()
```
%% Output
load file: temp.pkl
Regenerate size 5 5
load file: temp.pkl
load file: temp.pkl
Regenerate size 5 5
load file: temp.pkl
......
%% Cell type:markdown id: tags:
### Example 2 - Generate a random rail
%% Cell type:code id: tags:
``` python
import random
import numpy as np
from flatland.envs.generators import random_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
from PIL import Image
```
%% Cell type:code id: tags:
``` python
random.seed(100)
np.random.seed(100)
# Relative weights of each cell type to be used by the random rail generators.
transition_probability = [1.0, # empty cell - Case 0
1.0, # Case 1 - straight
1.0, # Case 2 - simple switch
0.3, # Case 3 - diamond drossing
0.5, # Case 4 - single slip
0.5, # Case 5 - double slip
0.2, # Case 6 - symmetrical
0.0, # Case 7 - dead end
0.2, # Case 8 - turn left
0.2, # Case 9 - turn right
1.0] # Case 10 - mirrored switch
# Example generate a random rail
env = RailEnv(width=10,
height=10,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=3,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
env.reset()
env_renderer = RenderTool(env, gl="PILSVG")
env_renderer.render_env(show=False)
Image.fromarray(env_renderer.gl.get_image())
```
%% Output
<PIL.Image.Image image mode=RGBA size=574x574 at 0x24ABDAF8E80>