diff --git a/examples/play_model.py b/examples/play_model.py new file mode 100644 index 0000000000000000000000000000000000000000..713d831e4936b0803f7d69f4eae8253db0685ef3 --- /dev/null +++ b/examples/play_model.py @@ -0,0 +1,141 @@ +from flatland.envs.rail_env import RailEnv, random_rail_generator +# from flatland.core.env_observation_builder import TreeObsForRailEnv +from flatland.utils.rendertools import RenderTool +from flatland.baselines.dueling_double_dqn import Agent +from collections import deque +import torch +import random +import numpy as np +import matplotlib.pyplot as plt + + +def main(): + + random.seed(1) + np.random.seed(1) + + # Example generate a rail given a manual specification, + # a map of tuples (cell_type, rotation) + transition_probability = [0.5, # empty cell - Case 0 + 1.0, # Case 1 - straight + 1.0, # Case 2 - simple switch + 0.3, # Case 3 - diamond drossing + 0.5, # Case 4 - single slip + 0.5, # Case 5 - double slip + 0.2, # Case 6 - symmetrical + 0.0] # Case 7 - dead end + + # Example generate a random rail + env = RailEnv(width=7, + height=7, + rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), + number_of_agents=1) + env_renderer = RenderTool(env) + plt.figure(figsize=(5,5)) + handle = env.get_agent_handles() + + state_size = 105 + action_size = 4 + n_trials = 9999 + eps = 1. + eps_end = 0.005 + eps_decay = 0.998 + action_dict = dict() + scores_window = deque(maxlen=100) + done_window = deque(maxlen=100) + scores = [] + dones_list = [] + action_prob = [0]*4 + agent = Agent(state_size, action_size, "FC", 0) + agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth')) + + def max_lt(seq, val): + """ + Return greatest item in seq for which item < val applies. + None is returned if seq was empty or all items in seq were >= val. + """ + + idx = len(seq)-1 + while idx >= 0: + if seq[idx] < val and seq[idx] >= 0: + return seq[idx] + idx -= 1 + return None + + for trials in range(1, n_trials + 1): + + # Reset environment + obs = env.reset() + + for a in range(env.number_of_agents): + norm = max(1, max_lt(obs[a],np.inf)) + obs[a] = np.clip(np.array(obs[a]) / norm, -1, 1) + + # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5) + + score = 0 + env_done = 0 + + # Run episode + for step in range(50): + #if trials > 114: + #env_renderer.renderEnv(show=True) + #print(step) + # Action + for a in range(env.number_of_agents): + action = agent.act(np.array(obs[a]), eps=eps) + action_prob[action] += 1 + action_dict.update({a: action}) + + # Environment step + next_obs, all_rewards, done, _ = env.step(action_dict) + for a in range(env.number_of_agents): + norm = max(1, max_lt(next_obs[a], np.inf)) + next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1) + # Update replay buffer and train agent + for a in range(env.number_of_agents): + agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]) + score += all_rewards[a] + + env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step) + + obs = next_obs.copy() + if done['__all__']: + env_done = 1 + break + # Epsilon decay + eps = max(eps_end, eps_decay * eps) # decrease epsilon + + done_window.append(env_done) + scores_window.append(score) # save most recent score + scores.append(np.mean(scores_window)) + dones_list.append((np.mean(done_window))) + + print('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' + + '\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( + env.number_of_agents, + trials, + np.mean( + scores_window), + 100 * np.mean( + done_window), + eps, action_prob/np.sum(action_prob)), + end=" ") + if trials % 100 == 0: + + print( + '\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( + env.number_of_agents, + trials, + np.mean( + scores_window), + 100 * np.mean( + done_window), + eps, action_prob / np.sum(action_prob))) + torch.save(agent.qnetwork_local.state_dict(), + '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth') + action_prob = [1]*4 + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 9f008f00c261f76e71771732aa02c3c3071f9542..0ce87ff91dd8e19faee1554a89fb80ec7eecbac4 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -4,7 +4,8 @@ import numpy as np from numpy import array import xarray as xr import matplotlib.pyplot as plt - +import time +from collections import deque # TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv! @@ -31,6 +32,9 @@ class RenderTool(object): def __init__(self, env): self.env = env + self.iFrame = 0 + self.time1 = time.time() + self.lTimes = deque() def plotTreeOnRail(self, lVisits, color="r"): """ @@ -391,7 +395,8 @@ class RenderTool(object): def renderEnv( self, show=False, curves=True, spacing=False, - arrows=False, agents=True, sRailColor="gray"): + arrows=False, agents=True, sRailColor="gray", + frames=False, iEpisode=None, iStep=None): """ Draw the environment using matplotlib. Draw into the figure if provided. @@ -537,6 +542,27 @@ class RenderTool(object): color=cmap(i), linewidth=2.0) + # Draw some textual information like fps + yText = [0.1, 0.4, 0.7] + if frames: + plt.text(0.1, yText[2], "Frame:{:}".format(self.iFrame)) + self.iFrame += 1 + + if iEpisode is not None: + plt.text(0.1, yText[1], "Ep:{}".format(iEpisode)) + + if iStep is not None: + plt.text(0.1, yText[0], "Step:{}".format(iStep)) + + tNow = time.time() + plt.text(2, yText[2], "elapsed:{:.2f}s".format(tNow - self.time1)) + self.lTimes.append(tNow) + if len(self.lTimes) > 20: + self.lTimes.popleft() + if len(self.lTimes) > 1: + rFps = (len(self.lTimes) - 1) / (self.lTimes[-1] - self.lTimes[0]) + plt.text(2, yText[1], "fps:{:.2f}".format(rFps)) + plt.xlim([0, env.width * cell_size]) plt.ylim([-env.height * cell_size, 0]) diff --git a/tox.ini b/tox.ini index c8e28dd38c145803acfa4a5272ac7f7da9fbbf89..54bc00406686c1ba45a1ded15b10e147c2919394 100644 --- a/tox.ini +++ b/tox.ini @@ -8,6 +8,7 @@ python = [flake8] max-line-length = 120 +ignore = E128 E121 E126 E123 E133 E226 E241 E242 W504 W [testenv:flake8] basepython = python