Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • jack_bruck/baselines
  • rivesunder/baselines
  • xzhaoma/baselines
  • giulia_cantini/baselines
  • sfwatergit/baselines
  • jiaodaxiaozi/baselines
  • flatland/baselines
7 results
Show changes
Showing
with 37727 additions and 231 deletions
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -59,7 +59,8 @@ For training purposes the tree is flattend into a single array. ...@@ -59,7 +59,8 @@ For training purposes the tree is flattend into a single array.
## Training ## Training
### Setting up the environment ### Setting up the environment
Let us now train a simle double dueling DQN agent to navigate to its target on flatland. We start by importing flatland Before you get started with the training make sure that you have [pytorch](https://pytorch.org/get-started/locally/) installed.
Let us now train a simPle double dueling DQN agent to navigate to its target on flatland. We start by importing flatland
``` ```
from flatland.envs.generators import complex_rail_generator from flatland.envs.generators import complex_rail_generator
...@@ -105,12 +106,12 @@ We have no successfully set up the environment for training. To visualize it in ...@@ -105,12 +106,12 @@ We have no successfully set up the environment for training. To visualize it in
env_renderer = RenderTool(env, gl="PILSVG", ) env_renderer = RenderTool(env, gl="PILSVG", )
``` ```
###Setting up the agent ### Setting up the agent
To set up a appropriate agent we need the state and action space sizes. From the discussion above about the tree observation we end up with: To set up a appropriate agent we need the state and action space sizes. From the discussion above about the tree observation we end up with:
[**Adrian**: I just wonder, why this is not done in seperate method in the the observation: get_state_size, then we don't have to write down much more. And the user don't need to [**Adrian**: I just wonder, why this is not done in seperate method in the the observation: get_state_size, then we don't have to write down much more. And the user don't need to
understand anything about the oberservation. I suggest moving this into the obersvation, base ObservationBuilder declare it as an abstract method. ... ] understand anything about the observation. I suggest moving this into the observation, base ObservationBuilder declare it as an abstract method. ... ]
``` ```
# Given the depth of the tree observation and the number of features per node we get the following state_size # Given the depth of the tree observation and the number of features per node we get the following state_size
...@@ -149,7 +150,7 @@ We now use the normalized `agent_obs` for our training loop: ...@@ -149,7 +150,7 @@ We now use the normalized `agent_obs` for our training loop:
for trials in range(1, n_trials + 1): for trials in range(1, n_trials + 1):
# Reset environment # Reset environment
obs = env.reset(True, True) obs, info = env.reset(True, True)
if not Training: if not Training:
env_renderer.set_new_rail() env_renderer.set_new_rail()
...@@ -217,7 +218,7 @@ for trials in range(1, n_trials + 1): ...@@ -217,7 +218,7 @@ for trials in range(1, n_trials + 1):
eps = max(eps_end, eps_decay * eps) # decrease epsilon eps = max(eps_end, eps_decay * eps) # decrease epsilon
``` ```
Running the `navigation_training.py` file trains a simple agent to navigate to any random target within the railway network. After running you should see a learning curve similiar to this one: Running the `training_navigation.py` file trains a simple agent to navigate to any random target within the railway network. After running you should see a learning curve similiar to this one:
![Learning_curve](https://i.imgur.com/yVGXpUy.png) ![Learning_curve](https://i.imgur.com/yVGXpUy.png)
......
...@@ -174,7 +174,7 @@ We now use the normalized `agent_obs` for our training loop: ...@@ -174,7 +174,7 @@ We now use the normalized `agent_obs` for our training loop:
agent_next_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents()
# Reset environment # Reset environment
obs = env.reset(True, True) obs, info = env.reset(True, True)
# Setup placeholder for finals observation of a single agent. This is necessary because agents terminate at # Setup placeholder for finals observation of a single agent. This is necessary because agents terminate at
# different times during an episode # different times during an episode
......
File deleted
File deleted
File deleted
...@@ -8,51 +8,41 @@ import torch ...@@ -8,51 +8,41 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from torch_training.model import QNetwork, QNetwork2 from torch_training.model import QNetwork
BUFFER_SIZE = int(1e5) # replay buffer size BUFFER_SIZE = int(1e5) # replay buffer size
BATCH_SIZE = 512 # minibatch size BATCH_SIZE = 512 # minibatch size
GAMMA = 0.99 # discount factor 0.99 GAMMA = 0.99 # discount factor 0.99
TAU = 1e-3 # for soft update of target parameters TAU = 1e-3 # for soft update of target parameters
LR = 0.5e-4 # learning rate 5 LR = 0.5e-4 # learning rate 0.5e-4 works
UPDATE_EVERY = 10 # how often to update the network UPDATE_EVERY = 10 # how often to update the network
double_dqn = True # If using double dqn algorithm
input_channels = 5 # Number of Input channels
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
print(device) print(device)
class Agent: class Agent:
"""Interacts with and learns from the environment.""" """Interacts with and learns from the environment."""
def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5): def __init__(self, state_size, action_size, double_dqn=True):
"""Initialize an Agent object. """Initialize an Agent object.
Params Params
====== ======
state_size (int): dimension of each state state_size (int): dimension of each state
action_size (int): dimension of each action action_size (int): dimension of each action
seed (int): random seed
""" """
self.state_size = state_size self.state_size = state_size
self.action_size = action_size self.action_size = action_size
self.seed = random.seed(seed)
self.version = net_type
self.double_dqn = double_dqn self.double_dqn = double_dqn
# Q-Network # Q-Network
if self.version == "Conv": self.qnetwork_local = QNetwork(state_size, action_size).to(device)
self.qnetwork_local = QNetwork2(state_size, action_size, seed, input_channels).to(device) self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
else:
self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device)
self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR) self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)
# Replay memory # Replay memory
self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed) self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE)
# Initialize time step (for updating every UPDATE_EVERY steps) # Initialize time step (for updating every UPDATE_EVERY steps)
self.t_step = 0 self.t_step = 0
...@@ -152,7 +142,7 @@ class Agent: ...@@ -152,7 +142,7 @@ class Agent:
class ReplayBuffer: class ReplayBuffer:
"""Fixed-size buffer to store experience tuples.""" """Fixed-size buffer to store experience tuples."""
def __init__(self, action_size, buffer_size, batch_size, seed): def __init__(self, action_size, buffer_size, batch_size):
"""Initialize a ReplayBuffer object. """Initialize a ReplayBuffer object.
Params Params
...@@ -160,13 +150,11 @@ class ReplayBuffer: ...@@ -160,13 +150,11 @@ class ReplayBuffer:
action_size (int): dimension of each action action_size (int): dimension of each action
buffer_size (int): maximum size of buffer buffer_size (int): maximum size of buffer
batch_size (int): size of each training batch batch_size (int): size of each training batch
seed (int): random seed
""" """
self.action_size = action_size self.action_size = action_size
self.memory = deque(maxlen=buffer_size) self.memory = deque(maxlen=buffer_size)
self.batch_size = batch_size self.batch_size = batch_size
self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"]) self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
self.seed = random.seed(seed)
def add(self, state, action, reward, next_state, done): def add(self, state, action, reward, next_state, done):
"""Add a new experience to memory.""" """Add a new experience to memory."""
...@@ -188,7 +176,7 @@ class ReplayBuffer: ...@@ -188,7 +176,7 @@ class ReplayBuffer:
dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \ dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \
.float().to(device) .float().to(device)
return (states, actions, rewards, next_states, dones) return states, actions, rewards, next_states, dones
def __len__(self): def __len__(self):
"""Return the current size of internal memory.""" """Return the current size of internal memory."""
......
...@@ -3,7 +3,7 @@ import torch.nn.functional as F ...@@ -3,7 +3,7 @@ import torch.nn.functional as F
class QNetwork(nn.Module): class QNetwork(nn.Module):
def __init__(self, state_size, action_size, seed, hidsize1=128, hidsize2=128): def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128):
super(QNetwork, self).__init__() super(QNetwork, self).__init__()
self.fc1_val = nn.Linear(state_size, hidsize1) self.fc1_val = nn.Linear(state_size, hidsize1)
...@@ -24,38 +24,3 @@ class QNetwork(nn.Module): ...@@ -24,38 +24,3 @@ class QNetwork(nn.Module):
adv = F.relu(self.fc2_adv(adv)) adv = F.relu(self.fc2_adv(adv))
adv = self.fc3_adv(adv) adv = self.fc3_adv(adv)
return val + adv - adv.mean() return val + adv - adv.mean()
class QNetwork2(nn.Module):
def __init__(self, state_size, action_size, seed, input_channels, hidsize1=128, hidsize2=64):
super(QNetwork2, self).__init__()
self.conv1 = nn.Conv2d(input_channels, 16, kernel_size=3, stride=1)
self.bn1 = nn.BatchNorm2d(16)
self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=3)
self.bn2 = nn.BatchNorm2d(32)
self.conv3 = nn.Conv2d(32, 64, kernel_size=5, stride=3)
self.bn3 = nn.BatchNorm2d(64)
self.fc1_val = nn.Linear(6400, hidsize1)
self.fc2_val = nn.Linear(hidsize1, hidsize2)
self.fc3_val = nn.Linear(hidsize2, 1)
self.fc1_adv = nn.Linear(6400, hidsize1)
self.fc2_adv = nn.Linear(hidsize1, hidsize2)
self.fc3_adv = nn.Linear(hidsize2, action_size)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
# value function approximation
val = F.relu(self.fc1_val(x.view(x.size(0), -1)))
val = F.relu(self.fc2_val(val))
val = self.fc3_val(val)
# advantage calculation
adv = F.relu(self.fc1_adv(x.view(x.size(0), -1)))
adv = F.relu(self.fc2_adv(adv))
adv = self.fc3_adv(adv)
return val + adv - adv.mean()
This diff is collapsed.
This diff is collapsed.
...@@ -7,17 +7,18 @@ from collections import deque ...@@ -7,17 +7,18 @@ from collections import deque
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
# Import Flatland/ Observations and Predictors
from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
# Import Flatland/ Observations and Predictors
from flatland.envs.schedule_generators import complex_schedule_generator
from importlib_resources import path from importlib_resources import path
# Import Torch and utility functions to normalize observation # Import Torch and utility functions to normalize observation
import torch_training.Nets import torch_training.Nets
from torch_training.dueling_double_dqn import Agent from torch_training.dueling_double_dqn import Agent
from utils.observation_utils import norm_obs_clip, split_tree from utils.observation_utils import norm_obs_clip, split_tree_into_feature_groups
def main(argv): def main(argv):
...@@ -40,25 +41,25 @@ def main(argv): ...@@ -40,25 +41,25 @@ def main(argv):
n_agents = np.random.randint(3, 8) n_agents = np.random.randint(3, 8)
n_goals = n_agents + np.random.randint(0, 3) n_goals = n_agents + np.random.randint(0, 3)
min_dist = int(0.75 * min(x_dim, y_dim)) min_dist = int(0.75 * min(x_dim, y_dim))
tree_depth = 3 tree_depth = 2
print("main2") print("main2")
demo = False
# Get an observation builder and predictor # Get an observation builder and predictor
predictor = ShortestPathPredictorForRailEnv() observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv())
observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=predictor())
env = RailEnv(width=x_dim, env = RailEnv(width=x_dim,
height=y_dim, height=y_dim,
rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
max_dist=99999, max_dist=99999,
seed=0), seed=0),
schedule_generator=complex_schedule_generator(),
obs_builder_object=observation_helper, obs_builder_object=observation_helper,
number_of_agents=n_agents) number_of_agents=n_agents)
env.reset(True, True) env.reset(True, True)
handle = env.get_agent_handles() handle = env.get_agent_handles()
features_per_node = env.obs_builder.observation_dim features_per_node = env.obs_builder.observation_dim
tree_depth = 2
nr_nodes = 0 nr_nodes = 0
for i in range(tree_depth + 1): for i in range(tree_depth + 1):
nr_nodes += np.power(4, i) nr_nodes += np.power(4, i)
...@@ -85,11 +86,11 @@ def main(argv): ...@@ -85,11 +86,11 @@ def main(argv):
agent_obs = [None] * env.get_num_agents() agent_obs = [None] * env.get_num_agents()
agent_next_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents()
# Initialize the agent # Initialize the agent
agent = Agent(state_size, action_size, "FC", 0) agent = Agent(state_size, action_size)
# Here you can pre-load an agent # Here you can pre-load an agent
if False: if False:
with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in: with path(torch_training.Nets, "avoid_checkpoint500.pth") as file_in:
agent.qnetwork_local.load_state_dict(torch.load(file_in)) agent.qnetwork_local.load_state_dict(torch.load(file_in))
# Do training over n_episodes # Do training over n_episodes
...@@ -109,6 +110,7 @@ def main(argv): ...@@ -109,6 +110,7 @@ def main(argv):
rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
max_dist=99999, max_dist=99999,
seed=0), seed=0),
schedule_generator=complex_schedule_generator(),
obs_builder_object=TreeObsForRailEnv(max_depth=3, obs_builder_object=TreeObsForRailEnv(max_depth=3,
predictor=ShortestPathPredictorForRailEnv()), predictor=ShortestPathPredictorForRailEnv()),
number_of_agents=n_agents) number_of_agents=n_agents)
...@@ -119,7 +121,7 @@ def main(argv): ...@@ -119,7 +121,7 @@ def main(argv):
agent_next_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents()
# Reset environment # Reset environment
obs = env.reset(True, True) obs, info = env.reset(True, True)
# Setup placeholder for finals observation of a single agent. This is necessary because agents terminate at # Setup placeholder for finals observation of a single agent. This is necessary because agents terminate at
# different times during an episode # different times during an episode
...@@ -128,8 +130,7 @@ def main(argv): ...@@ -128,8 +130,7 @@ def main(argv):
# Build agent specific observations # Build agent specific observations
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
data, distance, agent_data = split_tree(tree=np.array(obs[a]), data, distance, agent_data = split_tree_into_feature_groups(obs[a], tree_depth)
current_depth=0)
data = norm_obs_clip(data) data = norm_obs_clip(data)
distance = norm_obs_clip(distance) distance = norm_obs_clip(distance)
agent_data = np.clip(agent_data, -1, 1) agent_data = np.clip(agent_data, -1, 1)
...@@ -160,8 +161,7 @@ def main(argv): ...@@ -160,8 +161,7 @@ def main(argv):
next_obs, all_rewards, done, _ = env.step(action_dict) next_obs, all_rewards, done, _ = env.step(action_dict)
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
data, distance, agent_data = split_tree(tree=np.array(next_obs[a]), data, distance, agent_data = split_tree_into_feature_groups(next_obs[a], tree_depth)
current_depth=0)
data = norm_obs_clip(data) data = norm_obs_clip(data)
distance = norm_obs_clip(distance) distance = norm_obs_clip(distance)
agent_data = np.clip(agent_data, -1, 1) agent_data = np.clip(agent_data, -1, 1)
......