Commit 81d103a5 authored by Erik Nygren
Updated file for conflict detection training; used as an introduction to the problem.

parent 06065008
+# Import packages for plotting and system
+import getopt
+import random
 import sys
 from collections import deque
-import getopt
 import matplotlib.pyplot as plt
 import numpy as np
-import random
 import torch
+# Import Flatland/ Observations and Predictors
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
-from flatland.utils.rendertools import RenderTool
 from importlib_resources import path
+# Import Torch and utility functions to normalize observation
 import torch_training.Nets
 from torch_training.dueling_double_dqn import Agent
 from utils.observation_utils import norm_obs_clip, split_tree
@@ -20,52 +22,53 @@ from utils.observation_utils import norm_obs_clip, split_tree
 def main(argv):
     try:
-        opts, args = getopt.getopt(argv, "n:", ["n_trials="])
+        opts, args = getopt.getopt(argv, "n:", ["n_episodes="])
     except getopt.GetoptError:
-        print('training_navigation.py -n <n_trials>')
+        print('training_navigation.py -n <n_episodes>')
         sys.exit(2)
     for opt, arg in opts:
-        if opt in ('-n', '--n_trials'):
-            n_trials = int(arg)
+        if opt in ('-n', '--n_episodes'):
+            n_episodes = int(arg)
+    ## Initialize the random
     random.seed(1)
     np.random.seed(1)
+    """
+    file_name = "./railway/complex_scene.pkl"
+    env = RailEnv(width=10,
+                  height=20,
+                  rail_generator=rail_from_data(file_name),
+                  obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))
+    x_dim = env.width
+    y_dim = env.height
+    """
+    # Initialize a random map with a random number of agents
     x_dim = np.random.randint(8, 20)
     y_dim = np.random.randint(8, 20)
     n_agents = np.random.randint(3, 8)
     n_goals = n_agents + np.random.randint(0, 3)
     min_dist = int(0.75 * min(x_dim, y_dim))
+    tree_depth = 3
     print("main2")
+    # Get an observation builder and predictor
+    predictor = ShortestPathPredictorForRailEnv()
+    observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=predictor)
     env = RailEnv(width=x_dim,
                   height=y_dim,
                   rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
                                                         max_dist=99999,
                                                         seed=0),
-                  obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()),
+                  obs_builder_object=observation_helper,
                   number_of_agents=n_agents)
     env.reset(True, True)
-    observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
-    env_renderer = RenderTool(env, gl="PILSVG", )
     handle = env.get_agent_handles()
-    features_per_node = 9
-    state_size = features_per_node * 85 * 2
+    num_features_per_node = env.obs_builder.observation_dim
+    nr_nodes = 0
+    for i in range(tree_depth + 1):
+        nr_nodes += np.power(4, i)
+    state_size = num_features_per_node * nr_nodes
     action_size = 5
     # We set the number of episodes we would like to train on
-    if 'n_trials' not in locals():
-        n_trials = 60000
+    if 'n_episodes' not in locals():
+        n_episodes = 60000
+    # Set max number of steps per episode as well as other training relevant parameter
     max_steps = int(3 * (env.height + env.width))
     eps = 1.
     eps_end = 0.005
@@ -74,23 +77,28 @@ def main(argv):
     final_action_dict = dict()
     scores_window = deque(maxlen=100)
     done_window = deque(maxlen=100)
-    time_obs = deque(maxlen=2)
     scores = []
     dones_list = []
     action_prob = [0] * action_size
     agent_obs = [None] * env.get_num_agents()
     agent_next_obs = [None] * env.get_num_agents()
-    agent = Agent(state_size, action_size, "FC", 0)
-    with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in:
-        agent.qnetwork_local.load_state_dict(torch.load(file_in))
-    demo = False
-    record_images = False
-    frame_step = 0
-    for trials in range(1, n_trials + 1):
-        if trials % 50 == 0 and not demo:
+    observation_radius = 10
+    # Initialize the agent
+    agent = Agent(state_size, action_size, "FC", 0)
+    # Here you can pre-load an agent
+    if False:
+        with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in:
+            agent.qnetwork_local.load_state_dict(torch.load(file_in))
+    # Do training over n_episodes
+    for episodes in range(1, n_episodes + 1):
+        """
+        Training Curriculum: In order to get good generalization we change the number of agents
+        and the size of the levels every 50 episodes.
+        """
+        if episodes % 50 == 0:
             x_dim = np.random.randint(8, 20)
             y_dim = np.random.randint(8, 20)
             n_agents = np.random.randint(3, 8)
@@ -101,90 +109,78 @@ def main(argv):
                           rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
                                                                 max_dist=99999,
                                                                 seed=0),
-                          obs_builder_object=TreeObsForRailEnv(max_depth=3,
-                                                               predictor=ShortestPathPredictorForRailEnv()),
+                          obs_builder_object=observation_helper,
                           number_of_agents=n_agents)
-            env.reset(True, True)
+            # Adjust the parameters according to the new env.
            max_steps = int(3 * (env.height + env.width))
             agent_obs = [None] * env.get_num_agents()
             agent_next_obs = [None] * env.get_num_agents()
         # Reset environment
         obs = env.reset(True, True)
-        if demo:
-            env_renderer.set_new_rail()
-        obs_original = obs.copy()
-        final_obs = obs.copy()
-        final_obs_next = obs.copy()
+        # Setup placeholder for finals observation of a single agent. This is necessary because agents terminate at
+        # different times during an episode
+        final_obs = agent_obs.copy()
+        final_obs_next = agent_next_obs.copy()
+        # Build agent specific observations
         for a in range(env.get_num_agents()):
-            data, distance, agent_data = split_tree(tree=np.array(obs[a]),
+            data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=num_features_per_node,
                                                     current_depth=0)
-            data = norm_obs_clip(data)
+            data = norm_obs_clip(data, fixed_radius=observation_radius)
             distance = norm_obs_clip(distance)
             agent_data = np.clip(agent_data, -1, 1)
-            obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
-            agent_data = env.agents[a]
-            speed = 1  # np.random.randint(1,5)
-            agent_data.speed_data['speed'] = 1. / speed
-        for i in range(2):
-            time_obs.append(obs)
-        # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
-        for a in range(env.get_num_agents()):
-            agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
+            agent_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
         score = 0
         env_done = 0
         # Run episode
         for step in range(max_steps):
-            if demo:
-                env_renderer.renderEnv(show=True, show_observations=False)
-                # observation_helper.util_print_obs_subtree(obs_original[0])
-                if record_images:
-                    env_renderer.gl.saveImage("./Images/flatland_frame_{:04d}.bmp".format(frame_step))
-                    frame_step += 1
-            # print(step)
             # Action
             for a in range(env.get_num_agents()):
-                if demo:
-                    eps = 0
-                # action = agent.act(np.array(obs[a]), eps=eps)
                 action = agent.act(agent_obs[a], eps=eps)
                 action_prob[action] += 1
                 action_dict.update({a: action})
             # Environment step
             next_obs, all_rewards, done, _ = env.step(action_dict)
-            # print(all_rewards,action)
-            obs_original = next_obs.copy()
+            # Build agent specific observations and normalize
             for a in range(env.get_num_agents()):
                 data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
-                                                        current_depth=0)
-                data = norm_obs_clip(data)
+                                                        num_features_per_node=num_features_per_node, current_depth=0)
+                data = norm_obs_clip(data, fixed_radius=observation_radius)
                 distance = norm_obs_clip(distance)
                 agent_data = np.clip(agent_data, -1, 1)
-                next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
-            time_obs.append(next_obs)
+                agent_next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
             # Update replay buffer and train agent
             for a in range(env.get_num_agents()):
-                agent_next_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
                 if done[a]:
                     final_obs[a] = agent_obs[a].copy()
                     final_obs_next[a] = agent_next_obs[a].copy()
                     final_action_dict.update({a: action_dict[a]})
-                if not demo and not done[a]:
+                if not done[a]:
                     agent.step(agent_obs[a], action_dict[a], all_rewards[a], agent_next_obs[a], done[a])
                 score += all_rewards[a] / env.get_num_agents()
+            # Copy observation
             agent_obs = agent_next_obs.copy()
             if done['__all__']:
                 env_done = 1
                 for a in range(env.get_num_agents()):
                     agent.step(final_obs[a], final_action_dict[a], all_rewards[a], final_obs_next[a], done[a])
                 break
         # Epsilon decay
         eps = max(eps_end, eps_decay * eps)  # decrease epsilon
+        # Collection information about training
         done_window.append(env_done)
         scores_window.append(score / max_steps)  # save most recent score
         scores.append(np.mean(scores_window))
@@ -193,22 +189,22 @@ def main(argv):
         print(
             '\rTraining {} Agents on ({},{}).\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
                 env.get_num_agents(), x_dim, y_dim,
-                trials,
+                episodes,
                 np.mean(scores_window),
                 100 * np.mean(done_window),
                 eps, action_prob / np.sum(action_prob)), end=" ")
-        if trials % 100 == 0:
+        if episodes % 100 == 0:
             print(
                 '\rTraining {} Agents.\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
                     env.get_num_agents(),
-                    trials,
+                    episodes,
                     np.mean(scores_window),
                     100 * np.mean(done_window),
                     eps,
                     action_prob / np.sum(action_prob)))
             torch.save(agent.qnetwork_local.state_dict(),
-                       './Nets/avoid_checkpoint' + str(trials) + '.pth')
+                       './Nets/avoid_checkpoint' + str(episodes) + '.pth')
             action_prob = [1] * action_size
     plt.plot(scores)
     plt.show()
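The change above also replaces the hard-coded state size (features_per_node * 85 * 2) with a value derived from the tree depth and the observation builder's feature count. A minimal sketch of that bookkeeping with illustrative numbers (depth 3 and 9 features per node; the real values come from tree_depth and env.obs_builder.observation_dim at runtime):

import numpy as np

# Illustrative values; at runtime they come from tree_depth and env.obs_builder.observation_dim.
tree_depth = 3
num_features_per_node = 9

# Every node of the tree observation has four children, so a tree of depth d
# contains 4^0 + 4^1 + ... + 4^d nodes.
nr_nodes = 0
for i in range(tree_depth + 1):
    nr_nodes += np.power(4, i)  # 1 + 4 + 16 + 64 = 85 nodes

state_size = num_features_per_node * nr_nodes  # 9 * 85 = 765 inputs to the Q-network
print(nr_nodes, state_size)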
# Import packages for plotting and system
import getopt
import random
import sys
from collections import deque

import matplotlib.pyplot as plt
import numpy as np
import torch

# Import Flatland/ Observations and Predictors
from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from importlib_resources import path

# Import Torch and utility functions to normalize observation
import torch_training.Nets
from torch_training.dueling_double_dqn import Agent
from utils.observation_utils import norm_obs_clip, split_tree


def main(argv):
    try:
        opts, args = getopt.getopt(argv, "n:", ["n_episodes="])
    except getopt.GetoptError:
        print('training_navigation.py -n <n_episodes>')
        sys.exit(2)
    for opt, arg in opts:
        if opt in ('-n', '--n_episodes'):
            n_episodes = int(arg)

    # Initialize the random seeds
    random.seed(1)
    np.random.seed(1)

    # Initialize a random map with a random number of agents
    x_dim = np.random.randint(8, 20)
    y_dim = np.random.randint(8, 20)
    n_agents = np.random.randint(3, 8)
    n_goals = n_agents + np.random.randint(0, 3)
    min_dist = int(0.75 * min(x_dim, y_dim))
    tree_depth = 3
    print("main2")

    # Get an observation builder and predictor
    predictor = ShortestPathPredictorForRailEnv()
    observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=predictor)

    env = RailEnv(width=x_dim,
                  height=y_dim,
                  rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
                                                        max_dist=99999,
                                                        seed=0),
                  obs_builder_object=observation_helper,
                  number_of_agents=n_agents)
    env.reset(True, True)
    handle = env.get_agent_handles()
    features_per_node = env.obs_builder.observation_dim
    # Keep the node count consistent with the tree_depth used by the observation builder above
    nr_nodes = 0
    for i in range(tree_depth + 1):
        nr_nodes += np.power(4, i)
    state_size = 2 * features_per_node * nr_nodes  # We will use two time steps per observation --> 2x state_size
    action_size = 5

    # We set the number of episodes we would like to train on
    if 'n_episodes' not in locals():
        n_episodes = 60000

    # Set max number of steps per episode as well as other training relevant parameter
    max_steps = int(3 * (env.height + env.width))
    eps = 1.
    eps_end = 0.005
    eps_decay = 0.9995
    action_dict = dict()
    final_action_dict = dict()
    scores_window = deque(maxlen=100)
    done_window = deque(maxlen=100)
    time_obs = deque(maxlen=2)
    scores = []
    dones_list = []
    action_prob = [0] * action_size
    agent_obs = [None] * env.get_num_agents()
    agent_next_obs = [None] * env.get_num_agents()

    # Initialize the agent
    agent = Agent(state_size, action_size, "FC", 0)

    # Here you can pre-load an agent
    if False:
        with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in:
            agent.qnetwork_local.load_state_dict(torch.load(file_in))

    # Do training over n_episodes
    for episodes in range(1, n_episodes + 1):
        """
        Training Curriculum: In order to get good generalization we change the number of agents
        and the size of the levels every 50 episodes.
        """
        if episodes % 50 == 0:
            x_dim = np.random.randint(8, 20)
            y_dim = np.random.randint(8, 20)
            n_agents = np.random.randint(3, 8)
            n_goals = n_agents + np.random.randint(0, 3)
            min_dist = int(0.75 * min(x_dim, y_dim))
            env = RailEnv(width=x_dim,
                          height=y_dim,
                          rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
                                                                max_dist=99999,
                                                                seed=0),
                          obs_builder_object=TreeObsForRailEnv(max_depth=tree_depth,
                                                               predictor=ShortestPathPredictorForRailEnv()),
                          number_of_agents=n_agents)

            # Adjust the parameters according to the new env.
            max_steps = int(3 * (env.height + env.width))
            agent_obs = [None] * env.get_num_agents()
            agent_next_obs = [None] * env.get_num_agents()

        # Reset environment
        obs = env.reset(True, True)

        # Set up placeholders for the final observation of a single agent. This is necessary because agents
        # terminate at different times during an episode
        final_obs = agent_obs.copy()
        final_obs_next = agent_next_obs.copy()

        # Build agent specific observations
        for a in range(env.get_num_agents()):
            data, distance, agent_data = split_tree(tree=np.array(obs[a]),
                                                    num_features_per_node=features_per_node,
                                                    current_depth=0)
            data = norm_obs_clip(data)
            distance = norm_obs_clip(distance)
            agent_data = np.clip(agent_data, -1, 1)
            obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))

        # Accumulate two time steps of observation (Here just twice the first state)
        for i in range(2):
            time_obs.append(obs)

        # Build the agent specific observation from the two accumulated time steps
        for a in range(env.get_num_agents()):
            agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))

        score = 0
        env_done = 0

        # Run episode
        for step in range(max_steps):
            # Action
            for a in range(env.get_num_agents()):
                action = agent.act(agent_obs[a], eps=eps)
                action_prob[action] += 1
                action_dict.update({a: action})

            # Environment step
            next_obs, all_rewards, done, _ = env.step(action_dict)
            for a in range(env.get_num_agents()):
                data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
                                                        num_features_per_node=features_per_node,
                                                        current_depth=0)
                data = norm_obs_clip(data)
                distance = norm_obs_clip(distance)
                agent_data = np.clip(agent_data, -1, 1)
                next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
            time_obs.append(next_obs)

            # Update replay buffer and train agent
            for a in range(env.get_num_agents()):
                agent_next_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
                if done[a]:
                    final_obs[a] = agent_obs[a].copy()
                    final_obs_next[a] = agent_next_obs[a].copy()
                    final_action_dict.update({a: action_dict[a]})
                if not done[a]:
                    agent.step(agent_obs[a], action_dict[a], all_rewards[a], agent_next_obs[a], done[a])
                score += all_rewards[a] / env.get_num_agents()

            agent_obs = agent_next_obs.copy()
            if done['__all__']:
                env_done = 1
                for a in range(env.get_num_agents()):
                    agent.step(final_obs[a], final_action_dict[a], all_rewards[a], final_obs_next[a], done[a])
                break

        # Epsilon decay
        eps = max(eps_end, eps_decay * eps)  # decrease epsilon

        done_window.append(env_done)
        scores_window.append(score / max_steps)  # save most recent score
        scores.append(np.mean(scores_window))
        dones_list.append((np.mean(done_window)))
        print(
            '\rTraining {} Agents on ({},{}).\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
                env.get_num_agents(), x_dim, y_dim,
                episodes,
                np.mean(scores_window),
                100 * np.mean(done_window),
                eps, action_prob / np.sum(action_prob)), end=" ")
        if episodes % 100 == 0:
            print(
                '\rTraining {} Agents.\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
                    env.get_num_agents(),
                    episodes,
                    np.mean(scores_window),
                    100 * np.mean(done_window),
                    eps,
                    action_prob / np.sum(action_prob)))
            torch.save(agent.qnetwork_local.state_dict(),
                       './Nets/avoid_checkpoint' + str(episodes) + '.pth')
            action_prob = [1] * action_size
    plt.plot(scores)
    plt.show()


if __name__ == '__main__':
    main(sys.argv[1:])
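For reference, the only command-line flag the getopt block above parses is -n / --n_episodes; when it is omitted, n_episodes falls back to 60000. A hypothetical way to kick off a short run, without assuming a particular file name (the commit page does not show one), is to call main directly with the argument list it would receive from sys.argv[1:]:

# Equivalent to running "python <this_script>.py -n 5000" from the command line;
# the script name is whatever this file is saved as and is not given on the commit page.
main(["-n", "5000"])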
@@ -52,12 +52,12 @@ def main(argv):
     env_renderer = RenderTool(env, gl="PILSVG", )
     # Given the depth of the tree observation and the number of features per node we get the following state_size
-    features_per_node = env.obs_builder.observation_dim
+    num_features_per_node = env.obs_builder.observation_dim
     tree_depth = 2
     nr_nodes = 0
     for i in range(tree_depth + 1):
         nr_nodes += np.power(4, i)
-    state_size = features_per_node * nr_nodes
+    state_size = num_features_per_node * nr_nodes
     # The action space of flatland is 5 discrete actions
     action_size = 5
@@ -102,6 +102,7 @@ def main(argv):
         # Build agent specific local observation
         for a in range(env.get_num_agents()):
             rail_data, distance_data, agent_data = split_tree(tree=np.array(obs[a]),
+                                                              num_features_per_node=num_features_per_node,
                                                               current_depth=0)
             rail_data = norm_obs_clip(rail_data)
             distance_data = norm_obs_clip(distance_data)
@@ -135,6 +136,7 @@ def main(argv):
             for a in range(env.get_num_agents()):
                 rail_data, distance_data, agent_data = split_tree(tree=np.array(next_obs[a]),
+                                                                  num_features_per_node=num_features_per_node,
                                                                   current_depth=0)
                 rail_data = norm_obs_clip(rail_data)
                 distance_data = norm_obs_clip(distance_data)
@@ -195,6 +197,7 @@ def main(argv):
         # Build agent specific local observation
         for a in range(env.get_num_agents()):
             rail_data, distance_data, agent_data = split_tree(tree=np.array(obs[a]),
+                                                              num_features_per_node=num_features_per_node,
                                                               current_depth=0)
             rail_data = norm_obs_clip(rail_data)
             distance_data = norm_obs_clip(distance_data)
@@ -220,6 +223,7 @@ def main(argv):
             for a in range(env.get_num_agents()):
                 rail_data, distance_data, agent_data = split_tree(tree=np.array(next_obs[a]),
+                                                                  num_features_per_node=num_features_per_node,
                                                                   current_depth=0)
                 rail_data = norm_obs_clip(rail_data)
                 distance_data = norm_obs_clip(distance_data)
 import numpy as np
-from flatland.envs.observations import TreeObsForRailEnv
 def max_lt(seq, val):
     """
@@ -54,7 +52,7 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1, fixed_radius=0):
     return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
-def split_tree(tree, current_depth=0):
+def split_tree(tree, num_features_per_node, current_depth=0):
     """
     Splits the tree observation into different sub groups that need the same normalization.
     This is necessary because the tree observation includes two different distance:
@@ -70,7 +68,6 @@ def split_tree(tree, current_depth=0):
     :param current_depth: Keeping track of the current depth in the tree
     :return: Returns the three different groups of distance and binary values.
     """
-    num_features_per_node = TreeObsForRailEnv.observation_dim
     if len(tree) < num_features_per_node:
         return [], [], []
@@ -93,7 +90,7 @@ def split_tree(tree, current_depth=0):
     for children in range(4):
         child_tree = tree[(num_features_per_node + children * child_size):
                           (num_features_per_node + (children + 1) * child_size)]
-        tmp_tree_data, tmp_distance_data, tmp_agent_data = split_tree(child_tree,
+        tmp_tree_data, tmp_distance_data, tmp_agent_data = split_tree(child_tree, num_features_per_node,
                                                                       current_depth=current_depth + 1)
         if len(tmp_tree_data) > 0:
             tree_data.extend(tmp_tree_data)
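With this change, split_tree no longer looks up TreeObsForRailEnv.observation_dim internally; every caller passes num_features_per_node explicitly, as the training scripts above now do. A minimal sketch of the adjusted call pattern, assuming env and obs come from a RailEnv rollout with a TreeObsForRailEnv observation builder (env, obs and the agent handle a are not defined in this snippet):

import numpy as np
from utils.observation_utils import norm_obs_clip, split_tree

# env, obs and the agent handle a are assumed to come from a RailEnv rollout as in the scripts above.
num_features_per_node = env.obs_builder.observation_dim  # now supplied by the caller
data, distance, agent_data = split_tree(tree=np.array(obs[a]),
                                        num_features_per_node=num_features_per_node,
                                        current_depth=0)
data = norm_obs_clip(data)
distance = norm_obs_clip(distance)
agent_data = np.clip(agent_data, -1, 1)
normalized_obs = np.concatenate((np.concatenate((data, distance)), agent_data))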