diff --git a/examples/play_model.py b/examples/play_model.py index 5e2d898db02969138bfa2112055f531967320306..95c37349396ca28677701621bacc8b940f547fae 100644 --- a/examples/play_model.py +++ b/examples/play_model.py @@ -136,144 +136,12 @@ def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="PILSVG"): for step in range(n_steps): oPlayer.step() if render: - env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step, - action_dict=oPlayer.action_dict) + env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step) # time.sleep(10) if delay > 0: time.sleep(delay) - -def main_old(render=True, delay=0.0): - ''' DEPRECATED main which drives agent directly - Please use the new main() which creates a Player object which is also used by the Editor. - Please fix any bugs in main() and Player rather than here. - Will delete this one shortly. - ''' - - random.seed(1) - np.random.seed(1) - - # Example generate a random rail - env = RailEnv(width=15, height=15, - rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12), - number_of_agents=5) - - if render: - env_renderer = RenderTool(env, gl="PIL") - # env_renderer = RenderTool(env, gl="QT") - - n_trials = 9999 - eps = 1. - eps_end = 0.005 - eps_decay = 0.998 - action_dict = dict() - scores_window = deque(maxlen=100) - done_window = deque(maxlen=100) - scores = [] - dones_list = [] - action_prob = [0] * 4 - - # Real Agent - # state_size = 105 - # action_size = 4 - # agent = Agent(state_size, action_size, "FC", 0) - # agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth')) - - def max_lt(seq, val): - """ - Return greatest item in seq for which item < val applies. - None is returned if seq was empty or all items in seq were >= val. - """ - - idx = len(seq) - 1 - while idx >= 0: - if seq[idx] < val and seq[idx] >= 0: - return seq[idx] - idx -= 1 - return None - - iFrame = 0 - tStart = time.time() - for trials in range(1, n_trials + 1): - - # Reset environment - obs = env.reset() - if render: - env_renderer.set_new_rail() - - for a in range(env.get_num_agents()): - norm = max(1, max_lt(obs[a], np.inf)) - obs[a] = np.clip(np.array(obs[a]) / norm, -1, 1) - - # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5) - - score = 0 - env_done = 0 - - # Run episode - for step in range(100): - # if trials > 114: - # env_renderer.renderEnv(show=True) - # print(step) - # Action - for a in range(env.get_num_agents()): - action = random.randint(0, 3) # agent.act(np.array(obs[a]), eps=eps) - action_prob[action] += 1 - action_dict.update({a: action}) - - if render: - env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step, action_dict=action_dict) - if delay > 0: - time.sleep(delay) - - iFrame += 1 - - # Environment step - next_obs, all_rewards, done, _ = env.step(action_dict) - - for a in range(env.get_num_agents()): - norm = max(1, max_lt(next_obs[a], np.inf)) - next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1) - - # Update replay buffer and train agent - # only needed for "real" agent - # for a in range(env.get_num_agents()): - # agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]) - # score += all_rewards[a] - - obs = next_obs.copy() - if done['__all__']: - env_done = 1 - break - # Epsilon decay - eps = max(eps_end, eps_decay * eps) # decrease epsilon - - done_window.append(env_done) - scores_window.append(score) # save most recent score - scores.append(np.mean(scores_window)) - dones_list.append((np.mean(done_window))) - - print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' + - '\tEpsilon: {:.2f} \t Action Probabilities: \t {}').format( - env.get_num_agents(), - trials, - np.mean(scores_window), - 100 * np.mean(done_window), - eps, action_prob / np.sum(action_prob)), - end=" ") - if trials % 100 == 0: - tNow = time.time() - rFps = iFrame / (tNow - tStart) - print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' + - '\tEpsilon: {:.2f} fps: {:.2f} \t Action Probabilities: \t {}').format( - env.get_num_agents(), - trials, - np.mean(scores_window), - 100 * np.mean(done_window), - eps, rFps, action_prob / np.sum(action_prob))) - # torch.save(agent.qnetwork_local.state_dict(), - # '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth') - action_prob = [1] * 4 + env_renderer.gl.close_window() if __name__ == "__main__": diff --git a/examples/tkplay.py b/examples/tkplay.py index 05078fadc225602dccc30bd5159b50a1ac7a8713..a46dcbc6f3aca2b84a9b35f33c339c3f03291c32 100644 --- a/examples/tkplay.py +++ b/examples/tkplay.py @@ -1,7 +1,3 @@ -import time -import tkinter as tk - -from PIL import ImageTk, Image from examples.play_model import Player from flatland.envs.generators import complex_rail_generator @@ -9,51 +5,30 @@ from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool -def tkmain(n_trials=2, n_steps=50): - # This creates the main window of an application - window = tk.Tk() - window.title("Join") - window.configure(background='grey') - +def tkmain(n_trials=2, n_steps=50, sGL="PIL"): # Example generate a random rail env = RailEnv(width=15, height=15, rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12), number_of_agents=5) - env_renderer = RenderTool(env, gl="PIL") + env_renderer = RenderTool(env, gl=sGL, show=True) oPlayer = Player(env) n_trials = 1 - delay = 0 for trials in range(1, n_trials + 1): # Reset environment8 oPlayer.reset() env_renderer.set_new_rail() - first = True - for step in range(n_steps): oPlayer.step() env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step, action_dict=oPlayer.action_dict) - img = env_renderer.getImage() - img = Image.fromarray(img) - tkimg = ImageTk.PhotoImage(img) - - if first: - panel = tk.Label(window, image=tkimg) - panel.pack(side="bottom", fill="both", expand="yes") - else: - # update the image in situ - panel.configure(image=tkimg) - panel.image = tkimg - window.update() - if delay > 0: - time.sleep(delay) - first = False + env_renderer.gl.close_window() if __name__ == "__main__": - tkmain() + tkmain(sGL="PIL") + tkmain(sGL="PILSVG") diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index 84718813faf284fe90455aa02d6d1039a29d11f7..fe5e74b541b3d958df7058acc5b127c45d1a8ebd 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -77,6 +77,11 @@ class PILGL(GraphicsLayer): self.window.configure(background='grey') self.window_open = True + def close_window(self): + self.panel.destroy() + self.window.quit() + self.window.destroy() + def text(self, *args, **kwargs): pass @@ -211,7 +216,7 @@ class PILSVG(PILGL): def loadRailSVGs(self): """ Load the rail SVG images, apply rotations, and store as PIL images. """ - dFiles = { + dRailFiles = { "": "Background_#91D1DD.svg", "WE": "Gleis_Deadend.svg", "WW EE NN SS": "Gleis_Diamond_Crossing.svg", @@ -233,7 +238,19 @@ class PILSVG(PILGL): "NN SS NW ES": "Weiche_vertikal_unten_links.svg", "NN SS NE WS": "Weiche_vertikal_unten_rechts.svg"} - self.dPil = {} + dTargetFiles = { + "EW": "Bahnhof_#d50000_Deadend_links.svg", + "NS": "Bahnhof_#d50000_Deadend_oben.svg", + "WE": "Bahnhof_#d50000_Deadend_rechts.svg", + "SN": "Bahnhof_#d50000_Deadend_unten.svg", + "EE WW": "Bahnhof_#d50000_Gleis_horizontal.svg", + "NN SS": "Bahnhof_#d50000_Gleis_vertikal.svg"} + + self.dPilRail = self.loadSVGs(dRailFiles, rotate=True) + self.dPilTarget = self.loadSVGs(dTargetFiles, rotate=False) + + def loadSVGs(self, dDirFile, rotate=False): + dPil = {} transitions = RailEnvTransitions() @@ -241,10 +258,10 @@ class PILSVG(PILGL): # svgBG = SVG("./svg/Background_#91D1DD.svg") - for sTrans, sFile in dFiles.items(): + for sTrans, sFile in dDirFile.items(): sPathSvg = "./svg/" + sFile - # Translate the ascii transition descption in the format "NE WS" to the + # Translate the ascii transition description in the format "NE WS" to the # binary list of transitions as per RailEnv - NESW (in) x NESW (out) lTrans16 = ["0"] * 16 for sTran in sTrans.split(" "): @@ -263,22 +280,31 @@ class PILSVG(PILGL): # svg = svg.merge(svgBG) pilRail = self.pilFromSvgFile(sPathSvg) - self.dPil[binTrans] = pilRail - - # Rotate both the transition binary and the image and save in the dict - for nRot in [90, 180, 270]: - binTrans2 = transitions.rotate_transition(binTrans, nRot) - - # PIL rotates anticlockwise for positive theta - pilRail2 = pilRail.rotate(-nRot) - self.dPil[binTrans2] = pilRail2 - - def setRailAt(self, row, col, binTrans): - if binTrans in self.dPil: - pilTrack = self.dPil[binTrans] - self.drawImageRC(pilTrack, (row, col)) + dPil[binTrans] = pilRail + + if rotate: + # Rotate both the transition binary and the image and save in the dict + for nRot in [90, 180, 270]: + binTrans2 = transitions.rotate_transition(binTrans, nRot) + + # PIL rotates anticlockwise for positive theta + pilRail2 = pilRail.rotate(-nRot) + dPil[binTrans2] = pilRail2 + return dPil + + def setRailAt(self, row, col, binTrans, target=None): + if target is None: + if binTrans in self.dPilRail: + pilTrack = self.dPilRail[binTrans] + self.drawImageRC(pilTrack, (row, col)) + else: + print("Illegal rail:", row, col, format(binTrans, "#018b")[2:]) else: - print("Illegal rail:", row, col, format(binTrans, "#018b")[2:]) + if binTrans in self.dPilTarget: + pilTrack = self.dPilTarget[binTrans] + self.drawImageRC(pilTrack, (row, col)) + else: + print("Illegal target rail:", row, col, format(binTrans, "#018b")[2:]) def rgb_s2i(self, sRGB): """ convert a hex RGB string like 0091ea to 3-tuple of ints """ diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index a8f9f1cbc5158275ae45539d41a1607b775b22ef..fbe8e1ecf7624fd170b2e4f23b5bdd5d40a975c1 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -751,11 +751,23 @@ class RenderTool(object): if self.new_rail: self.new_rail = False self.gl.clear_rails() + + # store the targets + dTargets = {} + for iAgent, agent in enumerate(self.env.agents_static): + if agent is None: + continue + dTargets[agent.target] = iAgent + # Draw each cell independently for r in range(env.height): for c in range(env.width): binTrans = env.rail.grid[r, c] - self.gl.setRailAt(r, c, binTrans) + if (r, c) in dTargets: + target = dTargets[(r, c)] + else: + target = None + self.gl.setRailAt(r, c, binTrans, target) for iAgent, agent in enumerate(self.env.agents): if agent is None: diff --git a/tests/test_player.py b/tests/test_player.py index 668afb96a502ce5c25fa1e6d6fe9383a1facbf87..876db6b3d9d7a258451385f377b45dc2e4c2a4fa 100644 --- a/tests/test_player.py +++ b/tests/test_player.py @@ -4,8 +4,9 @@ from examples.play_model import main def test_main(): + main(render=True, n_steps=20, n_trials=2, sGL="PIL") main(render=True, n_steps=20, n_trials=2, sGL="PILSVG") - # main(render=True, n_steps=20, n_trials=2, sGL="PIL") + if __name__ == "__main__":