Commit 2754bf15 authored by hagrid67

added play_model.py - supposed to be a main prog to run renderer

loads a pre-saved model from Nets (not sure if working)
added some frame, time, episode text info to RenderEnv (still in matplotlib)
added some flake8 ignores to tox.ini - mostly hanging indent-related.
parent 7778c247
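Condensed, the new script does roughly the following (an illustrative sketch only, reusing the imports, values and the hard-coded checkpoint path from play_model.py shown in full below; eps=0. is used here just to show a greedy rollout, whereas the full script anneals eps from 1.0):

# Build a small random RailEnv, restore the pre-trained DQN weights from Nets,
# then step and render with the new frame/episode/step overlay in RenderTool.
transition_probability = [0.5, 1.0, 1.0, 0.3, 0.5, 0.5, 0.2, 0.0]
env = RailEnv(width=7, height=7,
              rail_generator=random_rail_generator(
                  cell_type_relative_proportion=transition_probability),
              number_of_agents=1)
renderer = RenderTool(env)
agent = Agent(105, 4, "FC", 0)  # same constructor arguments as in the script below
agent.qnetwork_local.load_state_dict(
    torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))

obs = env.reset()
action = agent.act(np.array(obs[0]), eps=0.)   # greedy action for agent handle 0
next_obs, all_rewards, done, _ = env.step({0: action})
renderer.renderEnv(show=True, frames=True, iEpisode=1, iStep=0)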
from flatland.envs.rail_env import RailEnv, random_rail_generator
# from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.utils.rendertools import RenderTool
from flatland.baselines.dueling_double_dqn import Agent
from collections import deque
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
def main():
    random.seed(1)
    np.random.seed(1)

    # Relative proportion of each transition/cell type used by the random
    # rail generator below
    transition_probability = [0.5,  # empty cell - Case 0
                              1.0,  # Case 1 - straight
                              1.0,  # Case 2 - simple switch
                              0.3,  # Case 3 - diamond crossing
                              0.5,  # Case 4 - single slip
                              0.5,  # Case 5 - double slip
                              0.2,  # Case 6 - symmetrical
                              0.0]  # Case 7 - dead end
    # Example: generate a random rail
    env = RailEnv(width=7,
                  height=7,
                  rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
                  number_of_agents=1)

    env_renderer = RenderTool(env)
    plt.figure(figsize=(5, 5))

    handle = env.get_agent_handles()

    state_size = 105
    action_size = 4
    n_trials = 9999
    eps = 1.
    eps_end = 0.005
    eps_decay = 0.998
    action_dict = dict()
    scores_window = deque(maxlen=100)
    done_window = deque(maxlen=100)
    scores = []
    dones_list = []
    action_prob = [0] * 4

    agent = Agent(state_size, action_size, "FC", 0)
    agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
    def max_lt(seq, val):
        """
        Return greatest item in seq for which item < val applies.
        None is returned if seq was empty or all items in seq were >= val.
        """
        best = None
        idx = len(seq) - 1
        while idx >= 0:
            if seq[idx] < val and seq[idx] >= 0 and (best is None or seq[idx] > best):
                best = seq[idx]
            idx -= 1
        return best
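
    # Illustrative examples (not part of the original script):
    #   max_lt([3.0, 7.5, 2.0], np.inf) -> 7.5  (largest non-negative entry)
    #   max_lt([3.0, 7.5, 2.0], 5.0)    -> 3.0  (largest non-negative entry below 5.0)
    # The training loop below uses it to pick a per-agent normalisation factor.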
    for trials in range(1, n_trials + 1):

        # Reset environment
        obs = env.reset()
        for a in range(env.number_of_agents):
            norm = max(1, max_lt(obs[a], np.inf))
            obs[a] = np.clip(np.array(obs[a]) / norm, -1, 1)
        # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)

        score = 0
        env_done = 0

        # Run episode
        for step in range(50):
            # if trials > 114:
            #     env_renderer.renderEnv(show=True)
            #     print(step)

            # Action
            for a in range(env.number_of_agents):
                action = agent.act(np.array(obs[a]), eps=eps)
                action_prob[action] += 1
                action_dict.update({a: action})

            # Environment step
            next_obs, all_rewards, done, _ = env.step(action_dict)
            for a in range(env.number_of_agents):
                norm = max(1, max_lt(next_obs[a], np.inf))
                next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1)

            # Update replay buffer and train agent
            for a in range(env.number_of_agents):
                agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
                score += all_rewards[a]

            env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step)

            obs = next_obs.copy()
            if done['__all__']:
                env_done = 1
                break

        # Epsilon decay
        eps = max(eps_end, eps_decay * eps)  # decrease epsilon

        done_window.append(env_done)
        scores_window.append(score)  # save most recent score
        scores.append(np.mean(scores_window))
        dones_list.append(np.mean(done_window))
        print('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%'
              '\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
                  env.number_of_agents,
                  trials,
                  np.mean(scores_window),
                  100 * np.mean(done_window),
                  eps,
                  action_prob / np.sum(action_prob)),
              end=" ")
        if trials % 100 == 0:
            print('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%'
                  '\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
                      env.number_of_agents,
                      trials,
                      np.mean(scores_window),
                      100 * np.mean(done_window),
                      eps,
                      action_prob / np.sum(action_prob)))
            torch.save(agent.qnetwork_local.state_dict(),
                       '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
            action_prob = [1] * 4
if __name__ == "__main__":
    main()
\ No newline at end of file
flatland/utils/rendertools.py
@@ -4,7 +4,8 @@ import numpy as np
from numpy import array
import xarray as xr
import matplotlib.pyplot as plt
import time
from collections import deque
# TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv!
@@ -31,6 +32,9 @@ class RenderTool(object):
    def __init__(self, env):
        self.env = env
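        # iFrame counts rendered frames; time1 and lTimes back the new
        # elapsed-time and fps read-outs drawn by renderEnv below.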
        self.iFrame = 0
        self.time1 = time.time()
        self.lTimes = deque()
    def plotTreeOnRail(self, lVisits, color="r"):
        """
@@ -391,7 +395,8 @@ class RenderTool(object):
    def renderEnv(
            self, show=False, curves=True, spacing=False,
            arrows=False, agents=True, sRailColor="gray"):
            arrows=False, agents=True, sRailColor="gray",
            frames=False, iEpisode=None, iStep=None):
        """
        Draw the environment using matplotlib.
        Draw into the figure if provided.
@@ -537,6 +542,27 @@ class RenderTool(object):
                         color=cmap(i),
                         linewidth=2.0)

        # Draw some textual information like fps
        yText = [0.1, 0.4, 0.7]
        if frames:
            plt.text(0.1, yText[2], "Frame:{:}".format(self.iFrame))
        self.iFrame += 1
        if iEpisode is not None:
            plt.text(0.1, yText[1], "Ep:{}".format(iEpisode))
        if iStep is not None:
            plt.text(0.1, yText[0], "Step:{}".format(iStep))
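        # Elapsed time since the renderer was created, plus a sliding window of
        # recent frame timestamps used to estimate the current rendering fps.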
        tNow = time.time()
        plt.text(2, yText[2], "elapsed:{:.2f}s".format(tNow - self.time1))
        self.lTimes.append(tNow)
        if len(self.lTimes) > 20:
            self.lTimes.popleft()
        if len(self.lTimes) > 1:
            rFps = (len(self.lTimes) - 1) / (self.lTimes[-1] - self.lTimes[0])
            plt.text(2, yText[1], "fps:{:.2f}".format(rFps))

        plt.xlim([0, env.width * cell_size])
        plt.ylim([-env.height * cell_size, 0])
tox.ini
@@ -8,6 +8,7 @@ python =
[flake8]
max-line-length = 120
ignore = E128 E121 E126 E123 E133 E226 E241 E242 W504 W
[testenv:flake8]
basepython = python