Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Source

Select target project
No results found

Target

Select target project
  • jack_bruck/baselines
  • rivesunder/baselines
  • xzhaoma/baselines
  • giulia_cantini/baselines
  • sfwatergit/baselines
  • jiaodaxiaozi/baselines
  • flatland/baselines
7 results
Show changes
File added
import random
from collections import deque
import numpy as np
import torch
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.utils.rendertools import RenderTool
from importlib_resources import path
import torch_training.Nets
from torch_training.dueling_double_dqn import Agent
from utils.observation_utils import normalize_observation
random.seed(1)
np.random.seed(1)
"""
file_name = "./railway/complex_scene.pkl"
env = RailEnv(width=10,
              height=20,
              rail_generator=rail_from_file(file_name),
              obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))
x_dim = env.width
y_dim = env.height
"""
# Parameters for the Environment
x_dim = 25
y_dim = 25
n_agents = 1
n_goals = 5
min_dist = 5
# We are training an Agent using the Tree Observation with depth 2
observation_builder = TreeObsForRailEnv(max_depth=2)
# Use the malfunction generator to break agents from time to time
stochastic_data = MalfunctionParameters(malfunction_rate=1. / 10000,  # Rate of malfunction occurrence
                                        min_duration=15,  # Minimal duration of malfunction
                                        max_duration=50  # Max duration of malfunction
                                        )
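# With malfunction_rate = 1. / 10000 (a per-step rate), breakdowns are rare within an episode of
# max_steps = 3 * (height + width) steps; each breakdown then lasts between 15 and 50 steps.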
# Custom observation builder
TreeObservation = TreeObsForRailEnv(max_depth=2)
# Different agent types (trains) with different speeds.
speed_ration_map = {1.: 1.,  # Fast passenger train
                    1. / 2.: 0.0,  # Fast freight train
                    1. / 3.: 0.0,  # Slow commuter train
                    1. / 4.: 0.0}  # Slow freight train
env = RailEnv(width=x_dim,
              height=y_dim,
              rail_generator=sparse_rail_generator(max_num_cities=3,
                                                   # Number of cities in map (where train stations are)
                                                   seed=1,  # Random seed
                                                   grid_mode=False,
                                                   max_rails_between_cities=2,
                                                   max_rails_in_city=4),
              schedule_generator=sparse_schedule_generator(speed_ration_map),
              number_of_agents=n_agents,
              malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
              obs_builder_object=TreeObservation)
env.reset(True,True)
env_renderer = RenderTool(env, gl="PILSVG", )
num_features_per_node = env.obs_builder.observation_dim
tree_depth = 2
nr_nodes = 0
for i in range(tree_depth + 1):
    nr_nodes += np.power(4, i)
state_size = num_features_per_node * nr_nodes
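# For illustration (assuming the standard tree observation with 11 features per node,
# i.e. the 6 + 1 + 4 feature split used in utils/observation_utils.py):
#   nr_nodes   = 4**0 + 4**1 + 4**2 = 21
#   state_size = 11 * 21 = 231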
action_size = 5
# We set the number of episodes we would like to train on
if 'n_trials' not in locals():
    n_trials = 60000
max_steps = int(3 * (env.height + env.width))
eps = 1.
eps_end = 0.005
eps_decay = 0.9995
action_dict = dict()
final_action_dict = dict()
scores_window = deque(maxlen=100)
done_window = deque(maxlen=100)
scores = []
dones_list = []
action_prob = [0] * action_size
agent_obs = [None] * env.get_num_agents()
agent_next_obs = [None] * env.get_num_agents()
agent = Agent(state_size, action_size)
with path(torch_training.Nets, "navigator_checkpoint1000.pth") as file_in:
    agent.qnetwork_local.load_state_dict(torch.load(file_in))
record_images = False
frame_step = 0
for trials in range(1, n_trials + 1):

    # Reset environment
    obs, info = env.reset(True, True)
    env_renderer.reset()
    # Build agent specific observations
    for a in range(env.get_num_agents()):
        agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
    # Reset score and done
    score = 0
    env_done = 0
    # Run episode
    for step in range(max_steps):
        # Action
        for a in range(env.get_num_agents()):
            if info['action_required'][a]:
                # eps=0. means the agent acts greedily on its learned policy (no exploration)
                action = agent.act(agent_obs[a], eps=0.)
            else:
                action = 0
            action_prob[action] += 1
            action_dict.update({a: action})
        # Environment step
        obs, all_rewards, done, _ = env.step(action_dict)
        env_renderer.render_env(show=True, show_predictions=True, show_observations=False)
        # Build agent specific observations and normalize
        for a in range(env.get_num_agents()):
            if obs[a]:
                agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
        if done['__all__']:
            break
import getopt
import random
import sys
from collections import deque
# make sure the root path is in system path
from pathlib import Path

from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters

base_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(base_dir))

import matplotlib.pyplot as plt
import numpy as np
import torch
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.utils.rendertools import RenderTool
from torch_training.dueling_double_dqn import Agent
from utils.observation_utils import normalize_observation
def main(argv):
    try:
        opts, args = getopt.getopt(argv, "n:", ["n_trials="])
    except getopt.GetoptError:
        print('training_navigation.py -n <n_trials>')
        sys.exit(2)
    for opt, arg in opts:
        if opt in ('-n', '--n_trials'):
            n_trials = int(arg)

    random.seed(1)
    np.random.seed(1)

    # Parameters for the Environment
    x_dim = 35
    y_dim = 35
    n_agents = 1

    # Use the malfunction generator to break agents from time to time
    stochastic_data = MalfunctionParameters(malfunction_rate=1. / 10000,  # Rate of malfunction occurrence
                                            min_duration=15,  # Minimal duration of malfunction
                                            max_duration=50  # Max duration of malfunction
                                            )

    # Custom observation builder
    TreeObservation = TreeObsForRailEnv(max_depth=2)

    # Different agent types (trains) with different speeds.
    speed_ration_map = {1.: 0.,  # Fast passenger train
                        1. / 2.: 1.0,  # Fast freight train
                        1. / 3.: 0.0,  # Slow commuter train
                        1. / 4.: 0.0}  # Slow freight train
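    # Note: the values in speed_ration_map are the probabilities with which the schedule generator
    # assigns each speed to an agent, so they should sum to 1. With this map every agent is a fast
    # freight train running at speed 1/2.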
    env = RailEnv(width=x_dim,
                  height=y_dim,
                  rail_generator=sparse_rail_generator(max_num_cities=3,
                                                       # Number of cities in map (where train stations are)
                                                       seed=1,  # Random seed
                                                       grid_mode=False,
                                                       max_rails_between_cities=2,
                                                       max_rails_in_city=3),
                  schedule_generator=sparse_schedule_generator(speed_ration_map),
                  number_of_agents=n_agents,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  # Malfunction data generator
                  obs_builder_object=TreeObservation)
    # Reset env
    env.reset(True, True)
    # After training we want to render the results so we also load a renderer
    env_renderer = RenderTool(env, gl="PILSVG", )
    # Given the depth of the tree observation and the number of features per node we get the following state_size
    num_features_per_node = env.obs_builder.observation_dim
    tree_depth = 2
    nr_nodes = 0
    for i in range(tree_depth + 1):
        nr_nodes += np.power(4, i)
    state_size = num_features_per_node * nr_nodes

    # The action space of flatland is 5 discrete actions
    action_size = 5

    # We set the number of episodes we would like to train on
    if 'n_trials' not in locals():
        n_trials = 15000

    # And the max number of steps we want to take per episode
    max_steps = int(3 * (env.height + env.width))

    # Define training parameters
    eps = 1.
    eps_end = 0.005
    eps_decay = 0.998

    # And some variables to keep track of the progress
    action_dict = dict()
    final_action_dict = dict()
    scores_window = deque(maxlen=100)
    done_window = deque(maxlen=100)
    scores = []
    dones_list = []
    action_prob = [0] * action_size
    agent_obs = [None] * env.get_num_agents()
    agent_next_obs = [None] * env.get_num_agents()
    agent_obs_buffer = [None] * env.get_num_agents()
    agent_action_buffer = [2] * env.get_num_agents()
    cummulated_reward = np.zeros(env.get_num_agents())
    update_values = False

    # Now we load a Double dueling DQN agent
    agent = Agent(state_size, action_size)
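    # As used below, the Agent class from dueling_double_dqn exposes two methods:
    #   agent.act(state, eps)  -> epsilon-greedy action for a normalized observation
    #   agent.step(state, action, reward, next_state, done)  -> store the transition in the replay
    #       buffer and periodically train the Q-networks
    # (description inferred from how the agent is called in this script)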
    for trials in range(1, n_trials + 1):

        # Reset environment
        obs, info = env.reset(True, True)
        env_renderer.reset()
        # Build agent specific observations
        for a in range(env.get_num_agents()):
            if obs[a]:
                agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
                agent_obs_buffer[a] = agent_obs[a].copy()

        # Reset score and done
        score = 0
        env_done = 0
        # Run episode
        for step in range(max_steps):
            # Action
            for a in range(env.get_num_agents()):
                if info['action_required'][a]:
                    # If an action is required, we store the observation at this step as well as the action
                    update_values = True
                    action = agent.act(agent_obs[a], eps=eps)
                    action_prob[action] += 1
                else:
                    update_values = False
                    action = 0
                action_dict.update({a: action})

            # Environment step
            next_obs, all_rewards, done, info = env.step(action_dict)
            # Update replay buffer and train agent
            for a in range(env.get_num_agents()):
                # Only update the values when we are done or when an action was taken and thus relevant information is present
                if update_values or done[a]:
                    agent.step(agent_obs_buffer[a], agent_action_buffer[a], all_rewards[a],
                               agent_obs[a], done[a])
                    cummulated_reward[a] = 0.

                    agent_obs_buffer[a] = agent_obs[a].copy()
                    agent_action_buffer[a] = action_dict[a]
                if next_obs[a]:
                    agent_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)

                score += all_rewards[a] / env.get_num_agents()

            # Copy observation
            if done['__all__']:
                env_done = 1
                break

        # Epsilon decay
        eps = max(eps_end, eps_decay * eps)  # decrease epsilon
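        # Rough back-of-the-envelope: with eps_decay = 0.998 and eps_end = 0.005, epsilon reaches its
        # floor after about ln(0.005) / ln(0.998) ~= 2650 episodes, so exploration tapers off early
        # in a 15000-episode run.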
        # Collect information about training
        tasks_finished = 0
        for _idx in range(env.get_num_agents()):
            if done[_idx] == 1:
                tasks_finished += 1
        done_window.append(tasks_finished / max(1, env.get_num_agents()))
        scores_window.append(score / max_steps)  # save most recent score
        scores.append(np.mean(scores_window))
        dones_list.append((np.mean(done_window)))

        print(
            '\rTraining {} Agents on ({},{}).\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
                env.get_num_agents(), x_dim, y_dim,
                trials,
                np.mean(scores_window),
                100 * np.mean(done_window),
                eps, action_prob / np.sum(action_prob)), end=" ")

        if trials % 100 == 0:
            print(
                '\rTraining {} Agents on ({},{}).\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
                    env.get_num_agents(), x_dim, y_dim,
                    trials,
                    np.mean(scores_window),
                    100 * np.mean(done_window),
                    eps, action_prob / np.sum(action_prob)))
            torch.save(agent.qnetwork_local.state_dict(),
                       './Nets/navigator_checkpoint' + str(trials) + '.pth')
            action_prob = [1] * action_size

    # Plot overall training progress at the end
    plt.plot(scores)
    plt.show()


if __name__ == '__main__':
    main(sys.argv[1:])
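# Example invocation with the command-line option added above (this matches the command the
# tox configuration further down uses for a short smoke-test run):
#   python torch_training/multi_agent_training.py --n_trials=10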
@@ -15,13 +15,14 @@ setenv =
     PYTHONPATH = {toxinidir}
 passenv =
     DISPLAY
+    XAUTHORITY
     ; HTTP_PROXY+HTTPS_PROXY required behind corporate proxies
     HTTP_PROXY
     HTTPS_PROXY
 deps =
     -r{toxinidir}/requirements_torch_training.txt
 commands =
-    python torch_training/multi_agent_training.py
+    python torch_training/multi_agent_training.py --n_trials=10
 [flake8]
 max-line-length = 120
@@ -29,7 +30,12 @@ ignore = E121 E126 E123 E128 E133 E226 E241 E242 E704 W291 W293 W391 W503 W504 W
 [testenv:flake8]
 basepython = python
-passenv = DISPLAY
+passenv =
+    DISPLAY
+    XAUTHORITY
+    ; HTTP_PROXY+HTTPS_PROXY required behind corporate proxies
+    HTTP_PROXY
+    HTTPS_PROXY
 deps =
     -r{toxinidir}/requirements_torch_training.txt
 commands =
@@ -3,15 +3,16 @@ import time
 from collections import deque

 import numpy as np
-from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.schedule_generators import complex_schedule_generator
 from line_profiler import LineProfiler

-from utils.observation_utils import norm_obs_clip, split_tree
+from utils.observation_utils import norm_obs_clip, split_tree_into_feature_groups


-def printProgressBar (iteration, total, prefix = '', suffix = '', decimals = 1, length = 100, fill = '*'):
+def printProgressBar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='*'):
     """
     Call in a loop to create terminal progress bar
     @params:
@@ -31,13 +32,14 @@ def printProgressBar (iteration, total, prefix = '', suffix = '', decimals = 1,
     if iteration == total:
         print('')


 class RandomAgent:

     def __init__(self, state_size, action_size):
         self.state_size = state_size
         self.action_size = action_size

-    def act(self, state, eps = 0):
+    def act(self, state, eps=0):
         """
         :param state: input is the observation of the agent
         :return: returns an action
@@ -87,6 +89,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
                   rail_generator=complex_rail_generator(nr_start_goal=nr_paths, nr_extra=5, min_dist=min_dist,
                                                         max_dist=99999,
                                                         seed=parameters[3]),
+                  schedule_generator=complex_schedule_generator(),
                   obs_builder_object=GlobalObsForRailEnv(),
                   number_of_agents=parameters[2])
     max_steps = int(3 * (env.height + env.width))
@@ -99,10 +102,9 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
     # Reset the env
     lp_reset(True, True)
-    obs = env.reset(True, True)
+    obs, info = env.reset(True, True)

     for a in range(env.get_num_agents()):
-        data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=9,
-                                                current_depth=0)
+        data, distance, agent_data = split_tree_into_feature_groups(obs[a], tree_depth)
         data = norm_obs_clip(data)
         distance = norm_obs_clip(distance)
         agent_data = np.clip(agent_data, -1, 1)
@@ -126,9 +128,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
         next_obs, all_rewards, done, _ = lp_step(action_dict)

         for a in range(env.get_num_agents()):
-            data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
-                                                    num_features_per_node=features_per_node,
-                                                    current_depth=0)
+            data, distance, agent_data = split_tree_into_feature_groups(next_obs[a], tree_depth)
             data = norm_obs_clip(data)
             distance = norm_obs_clip(distance)
             agent_data = np.clip(agent_data, -1, 1)
import numpy as np
from flatland.envs.observations import TreeObsForRailEnv

@@ -17,7 +16,7 @@ def max_lt(seq, val):
     return max


-def min_lt(seq, val):
+def min_gt(seq, val):
     """
     Return smallest item in seq for which item > val applies.
     None is returned if seq was empty or all items in seq were >= val.
@@ -31,7 +30,7 @@ def min_lt(seq, val):
     return min


-def norm_obs_clip(obs, clip_min=-1, clip_max=1):
+def norm_obs_clip(obs, clip_min=-1, clip_max=1, fixed_radius=0, normalize_to_range=False):
     """
     This function returns the difference between min and max value of an observation
     :param obs: Observation that should be normalized
@@ -39,61 +38,89 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
     :param clip_max: max value where observation will be clipped
     :return: returnes normalized and clipped observatoin
     """
-    max_obs = max(1, max_lt(obs, 1000))
-    min_obs = min(max_obs, min_lt(obs, 0))
+    if fixed_radius > 0:
+        max_obs = fixed_radius
+    else:
+        max_obs = max(1, max_lt(obs, 1000)) + 1
+
+    min_obs = 0  # min(max_obs, min_gt(obs, 0))
+    if normalize_to_range:
+        min_obs = min_gt(obs, 0)
+    if min_obs > max_obs:
+        min_obs = max_obs
     if max_obs == min_obs:
         return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
     norm = np.abs(max_obs - min_obs)
+    if norm == 0:
+        norm = 1.
     return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)


-def split_tree(tree, current_depth=0):
+def _split_node_into_feature_groups(node: TreeObsForRailEnv.Node) -> (np.ndarray, np.ndarray, np.ndarray):
+    data = np.zeros(6)
+    distance = np.zeros(1)
+    agent_data = np.zeros(4)
+
+    data[0] = node.dist_own_target_encountered
+    data[1] = node.dist_other_target_encountered
+    data[2] = node.dist_other_agent_encountered
+    data[3] = node.dist_potential_conflict
+    data[4] = node.dist_unusable_switch
+    data[5] = node.dist_to_next_branch
+
+    distance[0] = node.dist_min_to_target
+
+    agent_data[0] = node.num_agents_same_direction
+    agent_data[1] = node.num_agents_opposite_direction
+    agent_data[2] = node.num_agents_malfunctioning
+    agent_data[3] = node.speed_min_fractional
+
+    return data, distance, agent_data
+
+
+def _split_subtree_into_feature_groups(node: TreeObsForRailEnv.Node, current_tree_depth: int, max_tree_depth: int) -> (np.ndarray, np.ndarray, np.ndarray):
+    if node == -np.inf:
+        remaining_depth = max_tree_depth - current_tree_depth
+        # reference: https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure
+        num_remaining_nodes = int((4 ** (remaining_depth + 1) - 1) / (4 - 1))
+        return [-np.inf] * num_remaining_nodes * 6, [-np.inf] * num_remaining_nodes, [-np.inf] * num_remaining_nodes * 4
+
+    data, distance, agent_data = _split_node_into_feature_groups(node)
+
+    if not node.childs:
+        return data, distance, agent_data
+
+    for direction in TreeObsForRailEnv.tree_explored_actions_char:
+        sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(node.childs[direction], current_tree_depth + 1, max_tree_depth)
+        data = np.concatenate((data, sub_data))
+        distance = np.concatenate((distance, sub_distance))
+        agent_data = np.concatenate((agent_data, sub_agent_data))
+    return data, distance, agent_data
+
+
+def split_tree_into_feature_groups(tree: TreeObsForRailEnv.Node, max_tree_depth: int) -> (np.ndarray, np.ndarray, np.ndarray):
     """
-    Splits the tree observation into different sub groups that need the same normalization.
-    This is necessary because the tree observation includes two different distance:
-    1. Distance from the agent --> This is measured in cells from current agent location
-    2. Distance to targer --> This is measured as distance from cell to agent target
-    3. Binary data --> Contains information about presence of object --> No normalization necessary
-    Number 1. will depend on the depth and size of the tree search
-    Number 2. will depend on the size of the map and thus the max distance on the map
-    Number 3. Is independent of tree depth and map size and thus must be handled differently
-    Therefore we split the tree into these two classes for better normalization.
-    :param tree: Tree that needs to be split
-    :param num_features_per_node: Features per node ATTENTION! this parameter is vital to correct splitting of the tree.
-    :param current_depth: Keeping track of the current depth in the tree
-    :return: Returns the three different groups of distance and binary values.
+    This function splits the tree into three difference arrays of values
     """
-    num_features_per_node = TreeObsForRailEnv.observation_dim
-    if len(tree) < num_features_per_node:
-        return [], [], []
-
-    depth = 0
-    tmp = len(tree) / num_features_per_node - 1
-    pow4 = 4
-    while tmp > 0:
-        tmp -= pow4
-        depth += 1
-        pow4 *= 4
-    child_size = (len(tree) - num_features_per_node) // 4
-    """
-    Here we split the node features into the different classes of distances and binary values.
-    Pay close attention to this part if you modify any of the features in the tree observation.
-    """
-    tree_data = tree[:6].tolist()
-    distance_data = [tree[6]]
-    agent_data = tree[7:num_features_per_node].tolist()
-    # Split each child of the current node and continue to next depth level
-    for children in range(4):
-        child_tree = tree[(num_features_per_node + children * child_size):
-                          (num_features_per_node + (children + 1) * child_size)]
-        tmp_tree_data, tmp_distance_data, tmp_agent_data = split_tree(child_tree,
-                                                                      current_depth=current_depth + 1)
-        if len(tmp_tree_data) > 0:
-            tree_data.extend(tmp_tree_data)
-            distance_data.extend(tmp_distance_data)
-            agent_data.extend(tmp_agent_data)
-    return tree_data, distance_data, agent_data
+    data, distance, agent_data = _split_node_into_feature_groups(tree)
+
+    for direction in TreeObsForRailEnv.tree_explored_actions_char:
+        sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(tree.childs[direction], 1, max_tree_depth)
+        data = np.concatenate((data, sub_data))
+        distance = np.concatenate((distance, sub_distance))
+        agent_data = np.concatenate((agent_data, sub_agent_data))
+
+    return data, distance, agent_data
+
+
+def normalize_observation(observation: TreeObsForRailEnv.Node, tree_depth: int, observation_radius=0):
+    """
+    This function normalizes the observation used by the RL algorithm
+    """
+    data, distance, agent_data = split_tree_into_feature_groups(observation, tree_depth)
+
+    data = norm_obs_clip(data, fixed_radius=observation_radius)
+    distance = norm_obs_clip(distance, normalize_to_range=True)
+    agent_data = np.clip(agent_data, -1, 1)
+    normalized_obs = np.concatenate((np.concatenate((data, distance)), agent_data))
+    return normalized_obs
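For context, the new helpers are consumed by the training scripts roughly as in the minimal sketch below; `env` and `tree_depth` are assumed to be set up as in the scripts above.

    # Minimal usage sketch (assumes an env built with TreeObsForRailEnv, as in the training scripts)
    obs, info = env.reset(True, True)
    for handle in range(env.get_num_agents()):
        if obs[handle]:
            # Split the tree into distance / target-distance / agent features, normalize and clip
            # each group, then concatenate into the flat state vector fed to the DQN.
            state = normalize_observation(obs[handle], tree_depth, observation_radius=10)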