From a8fc6b59ef1327143e970082ca7b1892febfcb08 Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Fri, 26 Apr 2019 23:30:49 +0100
Subject: [PATCH] Add basic agents to editor

---
 examples/play_model.py        | 76 +++++++++++++++++++++++++++++++++++
 flatland/utils/editor.py      | 61 +++++++++++++++++++++++++---
 flatland/utils/rendertools.py |  9 +----
 3 files changed, 132 insertions(+), 14 deletions(-)

diff --git a/examples/play_model.py b/examples/play_model.py
index 6a67397..68530e6 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -10,6 +10,82 @@ import matplotlib.pyplot as plt
 import time
 
 
+class Player(object):
+    def __init__(self, env):
+        self.env = env
+        self.handle = env.get_agent_handles()
+
+        self.state_size = 105
+        self.action_size = 4
+        self.n_trials = 9999
+        self.eps = 1.
+        self.eps_end = 0.005
+        self.eps_decay = 0.998
+        self.action_dict = dict()
+        self.scores_window = deque(maxlen=100)
+        self.done_window = deque(maxlen=100)
+        self.scores = []
+        self.dones_list = []
+        self.action_prob = [0] * 4
+        self.agent = Agent(self.state_size, self.action_size, "FC", 0)
+        self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
+
+        self.iFrame = 0
+        self.tStart = time.time()
+
+        # Reset environment
+        self.obs = self.env.reset()
+        for a in range(self.env.number_of_agents):
+            norm = max(1, max_lt(self.obs[a], np.inf))
+            self.obs[a] = np.clip(np.array(self.obs[a]) / norm, -1, 1)
+
+        # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
+
+        self.score = 0
+        self.env_done = 0
+
+    def step(self):
+        env = self.env
+        for a in range(env.number_of_agents):
+            action = self.agent.act(np.array(self.obs[a]), eps=self.eps)
+            self.action_prob[action] += 1
+            self.action_dict.update({a: action})
+
+        # Environment step
+        next_obs, all_rewards, done, _ = self.env.step(self.action_dict)
+
+        for a in range(env.number_of_agents):
+            norm = max(1, max_lt(next_obs[a], np.inf))
+            next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1)
+
+        # Update replay buffer and train agent
+        for a in range(self.env.number_of_agents):
+            self.agent.step(self.obs[a], self.action_dict[a], all_rewards[a], next_obs[a], done[a])
+            self.score += all_rewards[a]
+
+        self.iFrame += 1
+
+        self.obs = next_obs.copy()
+        if done['__all__']:
+            self.env_done = 1
+
+
+def max_lt(seq, val):
+    """
+    Return the greatest item in seq for which 0 <= item < val applies.
+    0 is returned if seq is empty or no item satisfies the condition.
+    """
+
+    # Track the largest value in [0, val) seen so far.
+    maxval = 0
+    for item in seq:
+        if val > item >= 0 and item > maxval:
+            maxval = item
+    return maxval
+
+
+
 def main(render=True, delay=0.0):
     random.seed(1)
 
diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py
index 0a0bbc4..935e40d 100644
--- a/flatland/utils/editor.py
+++ b/flatland/utils/editor.py
@@ -15,6 +15,19 @@ from flatland.envs.rail_env import RailEnv, random_rail_generator
 # from flatland.core.transitions import RailEnvTransitions
 from flatland.core.env_observation_builder import TreeObsForRailEnv
 import flatland.utils.rendertools as rt
+from examples.play_model import Player
+
+
+class View(object):
+    def __init__(self, editor):
+        self.editor = editor
+        self.oRT = rt.RenderTool(self.editor.env)
+        plt.figure(figsize=(10, 10))
+        self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray", show=False)
+        img = self.oRT.getImage()
+        plt.clf()
+        import jpy_canvas
+        self.wid_img = jpy_canvas.Canvas(img)
 
 
 class JupEditor(object):
@@ -39,6 +52,8 @@ class JupEditor(object):
         self.drawMode = "Draw"
         self.env_filename = "temp.npy"
         self.set_env(env)
+        self.iAgent = None
+        self.player = None
 
     def set_env(self, env):
         self.env = env
@@ -56,6 +71,28 @@ class JupEditor(object):
     def setDrawMode(self, dEvent):
         self.drawMode = dEvent["new"]
 
+    def on_click(self, wid, event):
+        x = event['canvasX']
+        y = event['canvasY']
+        rcCell = ((array([y, x]) - self.yxBase) / self.nPixCell).astype(int)
+
+        if self.drawMode == "Origin":
+            self.iAgent = len(self.env.agents_position)
+            self.env.agents_position.append(rcCell)
+            self.env.agents_handles.append(max(self.env.agents_handles + [-1]) + 1)
+            self.env.agents_direction.append(0)
+            self.env.agents_target.append(rcCell)  # set the target to the origin initially
+            self.env.number_of_agents = self.iAgent + 1
+            self.drawMode = "Destination"
+
+        elif self.drawMode == "Destination" and self.iAgent is not None:
+            self.env.agents_target[self.iAgent] = rcCell
+            self.drawMode = "Origin"
+
+        self.log("agent", self.drawMode, self.iAgent, rcCell)
+
+        self.redraw()
+
     def event_handler(self, wid, event):
         """Mouse motion event handler
         """
@@ -150,9 +187,6 @@ class JupEditor(object):
         # This updates the image in the browser to be the new edited version
         self.wid_img.data = writableData
 
-    def on_click(self, event):
-        pass
-
     def redraw(self, hide_stdout=True, update=True):
 
         if hide_stdout:
@@ -161,7 +195,8 @@ class JupEditor(object):
             stdout_dest = sys.stdout
 
         # TODO: bit of a hack - can we suppress the console messages from MPL at source?
-        with redirect_stdout(stdout_dest):
+        #with redirect_stdout(stdout_dest):
+        with self.wid_output:
             plt.figure(figsize=(10, 10))
             self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray", show=False)
             img = self.oRT.getImage()
@@ -178,6 +213,13 @@ class JupEditor(object):
 
     def clear(self, event):
         self.env.rail.grid[:, :] = 0
+        self.env.number_of_agents = 0
+        self.env.agents_position = []
+        self.env.agents_direction = []
+        self.env.agents_handles = []
+        self.env.agents_target = []
+        self.player = None
+
         self.redraw_event(event)
 
     def setFilename(self, filename):
@@ -201,15 +243,22 @@ class JupEditor(object):
         self.env = RailEnv(width=self.regen_size,
                            height=self.regen_size,
                            rail_generator=random_rail_generator(cell_type_relative_proportion=[1, 1] + [0.5] * 6),
-                           number_of_agents=0,
+                           number_of_agents=self.env.number_of_agents,
                            obs_builder_object=TreeObsForRailEnv(max_depth=2))
         self.env.reset()
         self.set_env(self.env)
+        self.player = Player(self.env)
         self.redraw()
 
     def setRegenSize_event(self, event):
         self.regen_size = event["new"]
-
+
+    def step_event(self, event=None):
+        if self.player is None:
+            self.player = Player(self.env)
+        self.player.step()
+        self.redraw()
+
     def fix_env(self):
         self.env.width = self.env.rail.width
         self.env.height = self.env.rail.height

diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 1444645..6c5a175 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -21,7 +21,6 @@ class MPLGL(GraphicsLayer):
         plt.plot(*args, **kwargs)
 
     def scatter(self, *args, **kwargs):
-        print(args, kwargs)
         plt.scatter(*args, **kwargs)
 
     def text(self, *args, **kwargs):
@@ -209,7 +208,7 @@ class RenderTool(object):
         xyDir = np.matmul(rcDir, rt.grc2xy)  # agent direction in xy
         xyPos = np.matmul(rcPos - rcDir / 2, rt.grc2xy) + rt.xyHalf
 
-        self.gl.scatter(*xyPos, color=color, s=40)  # agent location
+        self.gl.scatter(*xyPos, color=color, marker="o", s=100)  # agent location
         xyDirLine = array([xyPos, xyPos + xyDir/2]).T  # line for agent orient.
         self.gl.plot(*xyDirLine, color=color, lw=5, ms=0, alpha=0.6)
 
@@ -219,12 +218,6 @@ class RenderTool(object):
             xyTarget = np.matmul(rcTarget, rt.grc2xy) + rt.xyHalf
             self._draw_square(xyTarget, 1/3, color)
 
-        if False:
-            # mark the next cell we're heading into
-            rcNext = rcPos + rcDir
-            xyNext = np.matmul(rcNext, rt.grc2xy) + rt.xyHalf
-            self.gl.scatter(*xyNext, color=color)
-
     def plotTrans(self, rcPos, gTransRCAg, color="r", depth=None):
         """
         plot the transitions in gTransRCAg at position rcPos.
-- 
GitLab
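
Usage sketch (not part of the patch): the new Player loop can also be driven
headlessly, outside the editor. This is a minimal example assembled only from
calls that appear in the diff above; it assumes the DQN checkpoint hard-coded
in Player.__init__ ('../flatland/baselines/Nets/avoid_checkpoint9900.pth')
exists relative to the working directory, and the grid size, agent count and
step budget below are arbitrary placeholders.

    # Build a small random environment, mirroring what regenerate() does
    # in the editor.
    from flatland.envs.rail_env import RailEnv, random_rail_generator
    from flatland.core.env_observation_builder import TreeObsForRailEnv
    from examples.play_model import Player

    env = RailEnv(width=10, height=10,
                  rail_generator=random_rail_generator(
                      cell_type_relative_proportion=[1, 1] + [0.5] * 6),
                  number_of_agents=2,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2))
    env.reset()

    # Player loads the checkpoint and normalises observations in __init__;
    # each step() picks an action per agent, steps the env and trains.
    player = Player(env)
    for _ in range(100):  # arbitrary step budget
        player.step()
        if player.env_done:
            break

This is the same sequence the editor triggers interactively: regenerate()
constructs the Player, and each step_event() call advances it by one step
before redrawing.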