......@@ -60,7 +60,7 @@ For training purposes the tree is flattend into a single array.
## Training
### Setting up the environment
Before you get started with the training, make sure that you have [PyTorch](https://pytorch.org/) installed.
Let us now train a simple double dueling DQN agent to navigate to its target on flatland. We start by importing flatland
Let us now train a simple double dueling DQN agent to navigate to its target on flatland. We start by importing flatland
from flatland.envs.generators import complex_rail_generator
......@@ -111,7 +111,7 @@ env_renderer = RenderTool(env, gl="PILSVG", )
To set up an appropriate agent we need the state and action space sizes. From the discussion above about the tree observation we end up with:
[**Adrian**: I just wonder why this is not done in a separate method in the observation: get_state_size, then we don't have to write down much more. And the user doesn't need to
understand anything about the observation. I suggest moving this into the observation, base ObservationBuilder declare it as an abstract method. ... ]
understand anything about the observation. I suggest moving this into the observation, base ObservationBuilder declare it as an abstract method. ... ]
# Given the depth of the tree observation and the number of features per node we get the following state_size
......@@ -150,7 +150,7 @@ We now use the normalized `agent_obs` for our training loop:
for trials in range(1, n_trials + 1):
# Reset environment
obs = env.reset(True, True)
obs, info = env.reset(True, True)
if not Training:
......@@ -218,7 +218,7 @@ for trials in range(1, n_trials + 1):
eps = max(eps_end, eps_decay * eps) # decrease epsilon
Running the `` file trains a simple agent to navigate to any random target within the railway network. After running you should see a learning curve similar to this one:
Running the `` file trains a simple agent to navigate to any random target within the railway network. After running you should see a learning curve similar to this one:
......@@ -174,7 +174,7 @@ We now use the normalized `agent_obs` for our training loop:
agent_next_obs = [None] * env.get_num_agents()
# Reset environment
obs = env.reset(True, True)
obs, info = env.reset(True, True)
# Setup placeholder for finals observation of a single agent. This is necessary because agents terminate at
# different times during an episode
......@@ -8,7 +8,7 @@ import torch
import torch.nn.functional as F
import torch.optim as optim
from torch_training.model import QNetwork, QNetwork2
from torch_training.model import QNetwork
BUFFER_SIZE = int(1e5) # replay buffer size
BATCH_SIZE = 512 # minibatch size
......@@ -16,43 +16,33 @@ GAMMA = 0.99 # discount factor 0.99
TAU = 1e-3 # for soft update of target parameters
LR = 0.5e-4 # learning rate 0.5e-4 works
UPDATE_EVERY = 10 # how often to update the network
double_dqn = True # If using double dqn algorithm
input_channels = 5 # Number of Input channels
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
class Agent:
"""Interacts with and learns from the environment."""
def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5):
def __init__(self, state_size, action_size, double_dqn=True):
"""Initialize an Agent object.
state_size (int): dimension of each state
action_size (int): dimension of each action
seed (int): random seed
self.state_size = state_size
self.action_size = action_size
self.seed = random.seed(seed)
self.version = net_type
self.double_dqn = double_dqn
# Q-Network
if self.version == "Conv":
self.qnetwork_local = QNetwork2(state_size, action_size, seed, input_channels).to(device)
self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device)
self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
self.qnetwork_local = QNetwork(state_size, action_size).to(device)
self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)
# Replay memory
self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE)
# Initialize time step (for updating every UPDATE_EVERY steps)
self.t_step = 0
......@@ -152,7 +142,7 @@ class Agent:
class ReplayBuffer:
"""Fixed-size buffer to store experience tuples."""
def __init__(self, action_size, buffer_size, batch_size, seed):
def __init__(self, action_size, buffer_size, batch_size):
"""Initialize a ReplayBuffer object.
......@@ -160,13 +150,11 @@ class ReplayBuffer:
action_size (int): dimension of each action
buffer_size (int): maximum size of buffer
batch_size (int): size of each training batch
seed (int): random seed
self.action_size = action_size
self.memory = deque(maxlen=buffer_size)
self.batch_size = batch_size
self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
self.seed = random.seed(seed)
def add(self, state, action, reward, next_state, done):
"""Add a new experience to memory."""
......@@ -188,7 +176,7 @@ class ReplayBuffer:
dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \
return (states, actions, rewards, next_states, dones)
return states, actions, rewards, next_states, dones
def __len__(self):
"""Return the current size of internal memory."""
......@@ -3,7 +3,7 @@ import torch.nn.functional as F
class QNetwork(nn.Module):
def __init__(self, state_size, action_size, seed, hidsize1=128, hidsize2=128):
def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128):
super(QNetwork, self).__init__()
self.fc1_val = nn.Linear(state_size, hidsize1)
......@@ -24,38 +24,3 @@ class QNetwork(nn.Module):
adv = F.relu(self.fc2_adv(adv))
adv = self.fc3_adv(adv)
return val + adv - adv.mean()
class QNetwork2(nn.Module):
    """Convolutional dueling Q-network.

    Three convolution layers feed two fully connected heads: a value head
    with a single output and an advantage head with one output per action.
    The heads are combined as ``val + adv - mean(adv)``, where the mean is
    taken per sample over the action dimension.

    Args:
        state_size: Accepted for signature compatibility; not used here.
        action_size: Number of outputs of the advantage head.
        seed: Accepted for signature compatibility; not used here.
        input_channels: Number of channels of the input passed to ``conv1``.
        hidsize1: Width of the first hidden layer of each head.
        hidsize2: Width of the second hidden layer of each head.

    Note:
        The first layer of each head is ``nn.Linear(6400, hidsize1)``, so the
        output of ``conv3`` must flatten to exactly 6400 features per sample.
    """

    def __init__(self, state_size, action_size, seed, input_channels, hidsize1=128, hidsize2=64):
        super(QNetwork2, self).__init__()

        # Convolutional feature extractor. The batch-norm layers are created
        # but not applied in forward(); they are kept so that the module's
        # parameter names (state_dict keys) stay unchanged.
        self.conv1 = nn.Conv2d(input_channels, 16, kernel_size=3, stride=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=3)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=5, stride=3)
        self.bn3 = nn.BatchNorm2d(64)

        # Value head: flattened features -> single scalar.
        self.fc1_val = nn.Linear(6400, hidsize1)
        self.fc2_val = nn.Linear(hidsize1, hidsize2)
        self.fc3_val = nn.Linear(hidsize2, 1)

        # Advantage head: flattened features -> one output per action.
        self.fc1_adv = nn.Linear(6400, hidsize1)
        self.fc2_adv = nn.Linear(hidsize1, hidsize2)
        self.fc3_adv = nn.Linear(hidsize2, action_size)

    def forward(self, x):
        """Return the combined head output of shape ``(batch, action_size)``.

        Args:
            x: Input tensor fed to ``conv1``; its first dimension is treated
                as the batch dimension when flattening.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        # Flatten once and share the features between both heads.
        features = x.view(x.size(0), -1)

        # value function approximation
        val = F.relu(self.fc1_val(features))
        val = F.relu(self.fc2_val(val))
        val = self.fc3_val(val)

        # advantage calculation
        adv = F.relu(self.fc1_adv(features))
        adv = F.relu(self.fc2_adv(adv))
        adv = self.fc3_adv(adv)

        # Subtract the mean advantage per sample (over actions only). A plain
        # adv.mean() would also average across the batch, making each sample's
        # output depend on the other samples in the same batch.
        return val + adv - adv.mean(dim=1, keepdim=True)
......@@ -3,37 +3,49 @@ from collections import deque
import numpy as np
import torch
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.utils.rendertools import RenderTool
from importlib_resources import path
from observation_builders.observations import TreeObsForRailEnv
from predictors.predictions import ShortestPathPredictorForRailEnv
import torch_training.Nets
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file, sparse_rail_generator
from flatland.envs.schedule_generators import schedule_from_file, sparse_schedule_generator
from flatland.utils.rendertools import RenderTool
from torch_training.dueling_double_dqn import Agent
from utils.observation_utils import normalize_observation
file_name = "./railway/complex_scene.pkl"
env = RailEnv(width=10,
obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))
x_dim = env.width
y_dim = env.height
# Parameters for the Environment
x_dim = 20
y_dim = 20
n_agents = 5
tree_depth = 2
x_dim = 25
y_dim = 25
n_agents = 10
# We are training an Agent using the Tree Observation with depth 2
observation_builder = TreeObsForRailEnv(max_depth=2)
# Use a the malfunction generator to break agents from time to time
stochastic_data = {'prop_malfunction': 0.1, # Percentage of defective agents
'malfunction_rate': 30, # Rate of malfunction occurence
'min_duration': 3, # Minimal duration of malfunction
'max_duration': 20 # Max duration of malfunction
stochastic_data = MalfunctionParameters(malfunction_rate=1./10000, # Rate of malfunction occurence
min_duration=15, # Minimal duration of malfunction
max_duration=50 # Max duration of malfunction
# Custom observation builder
predictor = ShortestPathPredictorForRailEnv()
observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=predictor)
TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
# Different agent types (trains) with different speeds.
speed_ration_map = {1.: 0.25, # Fast passenger train
......@@ -43,37 +55,33 @@ speed_ration_map = {1.: 0.25, # Fast passenger train
env = RailEnv(width=x_dim,
# Number of cities in map (where train stations are)
# Number of intersections (no start / target)
num_trainstations=10, # Number of possible start/targets on map
min_node_dist=3, # Minimal distance of nodes
node_radius=2, # Proximity of stations to city center
# Number of connections to other cities/intersections
seed=15, # Random seed
seed=1, # Random seed
stochastic_data=stochastic_data, # Malfunction data generator
env.reset(True, True)
observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
env_renderer = RenderTool(env, gl="PILSVG", )
handle = env.get_agent_handles()
num_features_per_node = env.obs_builder.observation_dim
tree_depth = 2
nr_nodes = 0
for i in range(tree_depth + 1):
nr_nodes += np.power(4, i)
state_size = num_features_per_node * nr_nodes
action_size = 5
n_trials = 10
observation_radius = 10
max_steps = int(3 * (env.height + env.width))
# We set the number of episodes we would like to train on
if 'n_trials' not in locals():
n_trials = 60000
max_steps = int(4 * 2 * (20 + env.height + env.width))
eps = 1.
eps_end = 0.005
eps_decay = 0.9995
......@@ -81,14 +89,13 @@ action_dict = dict()
final_action_dict = dict()
scores_window = deque(maxlen=100)
done_window = deque(maxlen=100)
time_obs = deque(maxlen=2)
scores = []
dones_list = []
action_prob = [0] * action_size
agent_obs = [None] * env.get_num_agents()
agent_next_obs = [None] * env.get_num_agents()
agent = Agent(state_size, action_size, "FC", 0)
with path(torch_training.Nets, "avoid_checkpoint100.pth") as file_in:
agent = Agent(state_size, action_size)
with path(torch_training.Nets, "navigator_checkpoint1200.pth") as file_in:
record_images = False
......@@ -97,30 +104,36 @@ frame_step = 0
for trials in range(1, n_trials + 1):
# Reset environment
obs = env.reset(True, True)
obs, info = env.reset(True, True)
# Build agent specific observations
for a in range(env.get_num_agents()):
agent_obs[a] = normalize_observation(obs[a], observation_radius=10)
agent_obs[a] = agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
# Reset score and done
score = 0
env_done = 0
# Run episode
for step in range(max_steps):
env_renderer.render_env(show=True, show_observations=False, show_predictions=True)
if record_images:"./Images/Avoiding/flatland_frame_{:04d}.bmp".format(frame_step))
frame_step += 1
# time.sleep(1.5)
# Action
for a in range(env.get_num_agents()):
action = agent.act(agent_obs[a], eps=0)
if info['action_required'][a]:
action = agent.act(agent_obs[a], eps=0.)
action = 0
action_prob[action] += 1
action_dict.update({a: action})
# Environment step
next_obs, all_rewards, done, _ = env.step(action_dict)
obs, all_rewards, done, _ = env.step(action_dict)
env_renderer.render_env(show=True, show_predictions=True, show_observations=False)
# Build agent specific observations and normalize
for a in range(env.get_num_agents()):
agent_obs[a] = normalize_observation(next_obs[a], observation_radius=10)
if obs[a]:
agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
if done['__all__']:
