Compare revisions

Showing 33950 additions and 1147 deletions
......@@ -60,7 +60,7 @@ For training purposes the tree is flattened into a single array.
## Training
### Setting up the environment
Before you get started with the training, make sure that you have [pytorch](https://pytorch.org/get-started/locally/) installed.
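If you want to double-check your setup, a quick sanity check (not part of the original tutorial) is to print the installed version and whether a CUDA device is visible:
```
import torch

print(torch.__version__)          # installed pytorch version
print(torch.cuda.is_available())  # True if pytorch can see a CUDA-capable GPU
```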
Let us now train a simle double dueling DQN agent to navigate to its target on flatland. We start by importing flatland
Let us now train a simple double dueling DQN agent to navigate to its target on flatland. We start by importing flatland:
```
from flatland.envs.generators import complex_rail_generator
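# --- Illustrative sketch (not part of the original snippet, which the diff view
# --- truncates here): how the import above is typically combined with RailEnv to
# --- build a small environment. The keyword values passed to complex_rail_generator
# --- are assumptions and may need adjusting for your installed flatland version.
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv

env = RailEnv(width=15,
              height=15,
              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=2,
                                                    min_dist=8, seed=0),
              obs_builder_object=TreeObsForRailEnv(max_depth=2),
              number_of_agents=3)
obs = env.reset(True, True)  # newer flatland versions return (obs, info) instead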
......@@ -111,7 +111,7 @@ env_renderer = RenderTool(env, gl="PILSVG", )
To set up an appropriate agent we need the state and action space sizes. From the discussion above about the tree observation we end up with:
[**Adrian**: I just wonder why this is not done in a separate method in the observation: get_state_size, then we don't have to write down much more. And the user doesn't need to
understand anything about the oberservation. I suggest moving this into the obersvation, base ObservationBuilder declare it as an abstract method. ... ]
understand anything about the observation. I suggest moving this into the observation, and having the base ObservationBuilder declare it as an abstract method. ... ]
```
# Given the depth of the tree observation and the number of features per node we get the following state_size
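# --- Sketch of the computation this comment introduces (the original snippet is
# --- truncated here by the diff view); it mirrors the training script further down
# --- in this diff. Each tree node provides observation_dim features and has up to
# --- 4 children, so a tree of depth d contains sum_{i=0..d} 4**i nodes.
import numpy as np

num_features_per_node = env.obs_builder.observation_dim
tree_depth = 2
nr_nodes = sum(np.power(4, i) for i in range(tree_depth + 1))  # 21 nodes for depth 2
state_size = num_features_per_node * nr_nodes
# The agent chooses between 5 actions (do nothing, left, forward, right, stop)
action_size = 5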
......@@ -150,7 +150,7 @@ We now use the normalized `agent_obs` for our training loop:
for trials in range(1, n_trials + 1):
# Reset environment
obs = env.reset(True, True)
obs, info = env.reset(True, True)
if not Training:
env_renderer.set_new_rail()
......@@ -218,7 +218,7 @@ for trials in range(1, n_trials + 1):
eps = max(eps_end, eps_decay * eps) # decrease epsilon
```
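The `eps` update above implements a standard exploration schedule: with probability `eps` a random action is chosen, otherwise the greedy one, and `eps` decays towards `eps_end` after every episode. A minimal sketch of that mechanism (the `epsilon_greedy` helper below is illustrative, not the actual `Agent.act` implementation; the starting values are the ones used in the training script):
```
import random
import numpy as np

def epsilon_greedy(q_values, eps):
    # With probability eps explore, otherwise exploit the current Q estimates
    if random.random() < eps:
        return random.randrange(len(q_values))
    return int(np.argmax(q_values))

eps, eps_end, eps_decay = 1.0, 0.005, 0.9995   # values used in the training script
for episode in range(5):
    action = epsilon_greedy(np.zeros(5), eps)  # 5 available actions in flatland
    eps = max(eps_end, eps_decay * eps)        # decrease epsilon
```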
Running the `navigation_training.py` file trains a simple agent to navigate to any random target within the railway network. After running you should see a learning curve similiar to this one:
Running the `training_navigation.py` file trains a simple agent to navigate to any random target within the railway network. After running you should see a learning curve similar to this one:
![Learning_curve](https://i.imgur.com/yVGXpUy.png)
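The training script keeps a running `scores` list and a 100-episode `scores_window`; a curve like the one above can be reproduced with a simple moving average, roughly as follows (the random `scores` array is only a placeholder for the per-episode scores collected during training):
```
import numpy as np
import matplotlib.pyplot as plt

scores = np.random.rand(1000)  # placeholder: use the scores collected in the loop above
window = 100
moving_avg = np.convolve(scores, np.ones(window) / window, mode="valid")

plt.plot(moving_avg)
plt.xlabel("Episode")
plt.ylabel("Mean score over the last 100 episodes")
plt.show()
```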
......
......@@ -174,7 +174,7 @@ We now use the normalized `agent_obs` for our training loop:
agent_next_obs = [None] * env.get_num_agents()
# Reset environment
obs = env.reset(True, True)
obs, info = env.reset(True, True)
# Setup placeholder for the final observation of a single agent. This is necessary because agents terminate at
# different times during an episode
......
File deleted
File deleted
File deleted
......@@ -8,7 +8,7 @@ import torch
import torch.nn.functional as F
import torch.optim as optim
from torch_training.model import QNetwork, QNetwork2
from torch_training.model import QNetwork
BUFFER_SIZE = int(1e5) # replay buffer size
BATCH_SIZE = 512 # minibatch size
......@@ -16,43 +16,33 @@ GAMMA = 0.99 # discount factor 0.99
TAU = 1e-3 # for soft update of target parameters
LR = 0.5e-4 # learning rate 0.5e-4 works
UPDATE_EVERY = 10 # how often to update the network
double_dqn = True # If using double dqn algorithm
input_channels = 5 # Number of Input channels
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
print(device)
class Agent:
"""Interacts with and learns from the environment."""
def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5):
def __init__(self, state_size, action_size, double_dqn=True):
"""Initialize an Agent object.
Params
======
state_size (int): dimension of each state
action_size (int): dimension of each action
seed (int): random seed
"""
self.state_size = state_size
self.action_size = action_size
self.seed = random.seed(seed)
self.version = net_type
self.double_dqn = double_dqn
# Q-Network
if self.version == "Conv":
self.qnetwork_local = QNetwork2(state_size, action_size, seed, input_channels).to(device)
self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
else:
self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device)
self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
self.qnetwork_local = QNetwork(state_size, action_size).to(device)
self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)
# Replay memory
self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE)
# Initialize time step (for updating every UPDATE_EVERY steps)
self.t_step = 0
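# --- Illustrative sketch (not part of this repository's code): the soft update that
# --- TAU above parameterises. After each learning step the target network is nudged
# --- towards the local network:
# ---     theta_target <- TAU * theta_local + (1 - TAU) * theta_target
def _soft_update_sketch(local_model, target_model, tau):
    for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
        target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)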
......@@ -152,7 +142,7 @@ class Agent:
class ReplayBuffer:
"""Fixed-size buffer to store experience tuples."""
def __init__(self, action_size, buffer_size, batch_size, seed):
def __init__(self, action_size, buffer_size, batch_size):
"""Initialize a ReplayBuffer object.
Params
......@@ -160,13 +150,11 @@ class ReplayBuffer:
action_size (int): dimension of each action
buffer_size (int): maximum size of buffer
batch_size (int): size of each training batch
seed (int): random seed
"""
self.action_size = action_size
self.memory = deque(maxlen=buffer_size)
self.batch_size = batch_size
self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
self.seed = random.seed(seed)
def add(self, state, action, reward, next_state, done):
"""Add a new experience to memory."""
......@@ -188,7 +176,7 @@ class ReplayBuffer:
dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \
.float().to(device)
return (states, actions, rewards, next_states, dones)
return states, actions, rewards, next_states, dones
def __len__(self):
"""Return the current size of internal memory."""
......
......@@ -3,7 +3,7 @@ import torch.nn.functional as F
class QNetwork(nn.Module):
def __init__(self, state_size, action_size, seed, hidsize1=128, hidsize2=128):
def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128):
super(QNetwork, self).__init__()
self.fc1_val = nn.Linear(state_size, hidsize1)
......@@ -24,38 +24,3 @@ class QNetwork(nn.Module):
adv = F.relu(self.fc2_adv(adv))
adv = self.fc3_adv(adv)
return val + adv - adv.mean()
class QNetwork2(nn.Module):
def __init__(self, state_size, action_size, seed, input_channels, hidsize1=128, hidsize2=64):
super(QNetwork2, self).__init__()
self.conv1 = nn.Conv2d(input_channels, 16, kernel_size=3, stride=1)
self.bn1 = nn.BatchNorm2d(16)
self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=3)
self.bn2 = nn.BatchNorm2d(32)
self.conv3 = nn.Conv2d(32, 64, kernel_size=5, stride=3)
self.bn3 = nn.BatchNorm2d(64)
self.fc1_val = nn.Linear(6400, hidsize1)
self.fc2_val = nn.Linear(hidsize1, hidsize2)
self.fc3_val = nn.Linear(hidsize2, 1)
self.fc1_adv = nn.Linear(6400, hidsize1)
self.fc2_adv = nn.Linear(hidsize1, hidsize2)
self.fc3_adv = nn.Linear(hidsize2, action_size)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
# value function approximation
val = F.relu(self.fc1_val(x.view(x.size(0), -1)))
val = F.relu(self.fc2_val(val))
val = self.fc3_val(val)
# advantage calculation
adv = F.relu(self.fc1_adv(x.view(x.size(0), -1)))
adv = F.relu(self.fc2_adv(adv))
adv = self.fc3_adv(adv)
return val + adv - adv.mean()
......@@ -3,37 +3,49 @@ from collections import deque
import numpy as np
import torch
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.utils.rendertools import RenderTool
from importlib_resources import path
from observation_builders.observations import TreeObsForRailEnv
from predictors.predictions import ShortestPathPredictorForRailEnv
import torch_training.Nets
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file, sparse_rail_generator
from flatland.envs.schedule_generators import schedule_from_file, sparse_schedule_generator
from flatland.utils.rendertools import RenderTool
from torch_training.dueling_double_dqn import Agent
from utils.observation_utils import normalize_observation
random.seed(3)
np.random.seed(2)
random.seed(1)
np.random.seed(1)
"""
file_name = "./railway/complex_scene.pkl"
env = RailEnv(width=10,
height=20,
rail_generator=rail_from_file(file_name),
obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))
x_dim = env.width
y_dim = env.height
"""
# Parameters for the Environment
x_dim = 20
y_dim = 20
n_agents = 5
tree_depth = 2
x_dim = 25
y_dim = 25
n_agents = 10
# We are training an Agent using the Tree Observation with depth 2
observation_builder = TreeObsForRailEnv(max_depth=2)
# Use the malfunction generator to break agents from time to time
stochastic_data = {'prop_malfunction': 0.1, # Percentage of defective agents
'malfunction_rate': 30, # Rate of malfunction occurence
'min_duration': 3, # Minimal duration of malfunction
'max_duration': 20 # Max duration of malfunction
}
stochastic_data = MalfunctionParameters(malfunction_rate=1./10000, # Rate of malfunction occurrence
min_duration=15, # Minimal duration of malfunction
max_duration=50 # Max duration of malfunction
)
# Custom observation builder
predictor = ShortestPathPredictorForRailEnv()
observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=predictor)
TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
# Different agent types (trains) with different speeds.
speed_ration_map = {1.: 0.25, # Fast passenger train
......@@ -43,37 +55,33 @@ speed_ration_map = {1.: 0.25, # Fast passenger train
env = RailEnv(width=x_dim,
height=y_dim,
rail_generator=sparse_rail_generator(num_cities=5,
rail_generator=sparse_rail_generator(max_num_cities=3,
# Number of cities in map (where train stations are)
num_intersections=4,
# Number of intersections (no start / target)
num_trainstations=10, # Number of possible start/targets on map
min_node_dist=3, # Minimal distance of nodes
node_radius=2, # Proximity of stations to city center
num_neighb=3,
# Number of connections to other cities/intersections
seed=15, # Random seed
grid_mode=True,
enhance_intersection=False
),
seed=1, # Random seed
grid_mode=False,
max_rails_between_cities=2,
max_rails_in_city=2),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=n_agents,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=observation_helper)
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
obs_builder_object=TreeObservation)
env.reset(True, True)
observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
env_renderer = RenderTool(env, gl="PILSVG", )
handle = env.get_agent_handles()
num_features_per_node = env.obs_builder.observation_dim
tree_depth = 2
nr_nodes = 0
for i in range(tree_depth + 1):
nr_nodes += np.power(4, i)
state_size = num_features_per_node * nr_nodes
action_size = 5
n_trials = 10
observation_radius = 10
max_steps = int(3 * (env.height + env.width))
# We set the number of episodes we would like to train on
if 'n_trials' not in locals():
n_trials = 60000
max_steps = int(4 * 2 * (20 + env.height + env.width))
eps = 1.
eps_end = 0.005
eps_decay = 0.9995
......@@ -81,14 +89,13 @@ action_dict = dict()
final_action_dict = dict()
scores_window = deque(maxlen=100)
done_window = deque(maxlen=100)
time_obs = deque(maxlen=2)
scores = []
dones_list = []
action_prob = [0] * action_size
agent_obs = [None] * env.get_num_agents()
agent_next_obs = [None] * env.get_num_agents()
agent = Agent(state_size, action_size, "FC", 0)
with path(torch_training.Nets, "avoid_checkpoint100.pth") as file_in:
agent = Agent(state_size, action_size)
with path(torch_training.Nets, "navigator_checkpoint1200.pth") as file_in:
agent.qnetwork_local.load_state_dict(torch.load(file_in))
record_images = False
......@@ -97,30 +104,36 @@ frame_step = 0
for trials in range(1, n_trials + 1):
# Reset environment
obs = env.reset(True, True)
obs, info = env.reset(True, True)
env_renderer.reset()
# Build agent specific observations
for a in range(env.get_num_agents()):
agent_obs[a] = normalize_observation(obs[a], observation_radius=10)
agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
# Reset score and done
score = 0
env_done = 0
# Run episode
for step in range(max_steps):
env_renderer.render_env(show=True, show_observations=False, show_predictions=True)
if record_images:
env_renderer.gl.save_image("./Images/Avoiding/flatland_frame_{:04d}.bmp".format(frame_step))
frame_step += 1
# time.sleep(1.5)
# Action
for a in range(env.get_num_agents()):
action = agent.act(agent_obs[a], eps=0)
if info['action_required'][a]:
action = agent.act(agent_obs[a], eps=0.)
else:
action = 0
action_prob[action] += 1
action_dict.update({a: action})
# Environment step
next_obs, all_rewards, done, _ = env.step(action_dict)
obs, all_rewards, done, _ = env.step(action_dict)
env_renderer.render_env(show=True, show_predictions=True, show_observations=False)
# Build agent specific observations and normalize
for a in range(env.get_num_agents()):
agent_obs[a] = normalize_observation(next_obs[a], observation_radius=10)
if obs[a]:
agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
if done['__all__']:
break