Compare revisions

Changes are shown as if the source revision was being merged into the target revision.
"""
Collection of environment-specific PredictionBuilder.
"""
import numpy as np
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.rail_env import RailEnvActions
class DummyPredictorForRailEnv(PredictionBuilder):
"""
DummyPredictorForRailEnv object.
This object returns predictions for agents in the RailEnv environment.
The prediction acts as if no other agent is in the environment and always takes the forward action.
"""
def get(self, custom_args=None, handle=None):
"""
Called whenever get_many in the observation build is called.
Parameters
-------
custom_args: dict
Not used in this dummy implementation.
handle : int (optional)
Handle of the agent for which to compute the observation vector.
Returns
-------
np.array
Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements:
- time_offset
- position axis 0
- position axis 1
- direction
- action taken to come here
The prediction at 0 is the current position, direction etc.
"""
agents = self.env.agents
if handle:
agents = [self.env.agents[handle]]
prediction_dict = {}
for agent in agents:
action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]
_agent_initial_position = agent.position
_agent_initial_direction = agent.direction
prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
for index in range(1, self.max_depth + 1):
action_done = False
# if we're at the target, stop moving...
if agent.position == agent.target:
prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
continue
for action in action_priorities:
cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
self.env._check_action_on_agent(action, agent)
if all([new_cell_isValid, transition_isValid]):
# move and change direction to face the new_direction that was
# performed
agent.position = new_position
agent.direction = new_direction
prediction[index] = [index, *new_position, new_direction, action]
action_done = True
break
if not action_done:
raise Exception("Cannot move further. Something is wrong")
prediction_dict[agent.handle] = prediction
agent.position = _agent_initial_position
agent.direction = _agent_initial_direction
return prediction_dict
class ShortestPathPredictorForRailEnv(PredictionBuilder):
"""
ShortestPathPredictorForRailEnv object.
This object returns shortest-path predictions for agents in the RailEnv environment.
The prediction acts as if no other agent is in the environment and always takes the forward action.
"""
def __init__(self, max_depth=20):
# Initialize with depth 20
self.max_depth = max_depth
def get(self, custom_args=None, handle=None):
"""
Called whenever get_many in the observation build is called.
Requires distance_map to extract the shortest path.
Parameters
-------
custom_args: dict
- distance_map : dict
handle : int (optional)
Handle of the agent for which to compute the observation vector.
Returns
-------
np.array
Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements:
- time_offset
- position axis 0
- position axis 1
- direction
- action taken to come here
The prediction at 0 is the current position, direction etc.
"""
agents = self.env.agents
if handle:
agents = [self.env.agents[handle]]
assert custom_args is not None
distance_map = custom_args.get('distance_map')
assert distance_map is not None
prediction_dict = {}
for agent in agents:
_agent_initial_position = agent.position
_agent_initial_direction = agent.direction
agent_speed = agent.speed_data["speed"]
times_per_cell = int(np.reciprocal(agent_speed))
prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
new_direction = _agent_initial_direction
new_position = _agent_initial_position
visited = set()
for index in range(1, self.max_depth + 1):
# if we're at the target, stop moving...
if agent.position == agent.target:
prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
visited.add((agent.position[0], agent.position[1], agent.direction))
continue
if not agent.moving:
prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
visited.add((agent.position[0], agent.position[1], agent.direction))
continue
# Take shortest possible path
cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
if np.sum(cell_transitions) == 1 and index % times_per_cell == 0:
new_direction = np.argmax(cell_transitions)
new_position = get_new_position(agent.position, new_direction)
elif np.sum(cell_transitions) > 1 and index % times_per_cell == 0:
min_dist = np.inf
no_dist_found = True
for direction in range(4):
if cell_transitions[direction] == 1:
neighbour_cell = get_new_position(agent.position, direction)
target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
if target_dist < min_dist or no_dist_found:
min_dist = target_dist
new_direction = direction
no_dist_found = False
new_position = get_new_position(agent.position, new_direction)
elif index % times_per_cell == 0:
raise Exception("No transition possible {}".format(cell_transitions))
# update the agent's position and direction
agent.position = new_position
agent.direction = new_direction
# prediction is ready
prediction[index] = [index, *new_position, new_direction, 0]
visited.add((new_position[0], new_position[1], new_direction))
self.env.dev_pred_dict[agent.handle] = visited
prediction_dict[agent.handle] = prediction
# cleanup: reset initial position
agent.position = _agent_initial_position
agent.direction = _agent_initial_direction
return prediction_dict
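For context, a predictor of this kind is not usually called directly: it is handed to a tree observation builder, which invokes it internally whenever observations are rebuilt. The snippet below is a minimal usage sketch; the variable names and parameter values are illustrative assumptions, not part of the compared revisions.

# Minimal usage sketch (illustrative assumptions, not part of this comparison).
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv

# The predictor is attached to the observation builder, which calls it internally.
predictor = ShortestPathPredictorForRailEnv(max_depth=20)
observation_builder = TreeObsForRailEnv(max_depth=2, predictor=predictor)

# Pass obs_builder_object=observation_builder when constructing RailEnv(...).
# Each agent's prediction is a (max_depth + 1) x 5 array whose rows are
# [time_offset, position_row, position_col, direction, action_taken_to_come_here].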
@@ -3,15 +3,15 @@ from collections import deque
 import numpy as np
 import torch
-from importlib_resources import path
-import torch_training.Nets
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
 from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import RenderTool
+from importlib_resources import path
+import torch_training.Nets
 from torch_training.dueling_double_dqn import Agent
 from utils.observation_utils import normalize_observation

@@ -28,8 +28,8 @@ y_dim = env.height
 """
 # Parameters for the Environment
-x_dim = 20
-y_dim = 20
+x_dim = 25
+y_dim = 25
 n_agents = 1
 n_goals = 5
 min_dist = 5

@@ -38,43 +38,34 @@ min_dist = 5
 observation_builder = TreeObsForRailEnv(max_depth=2)

 # Use a the malfunction generator to break agents from time to time
-stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
-                   'malfunction_rate': 30,  # Rate of malfunction occurence
-                   'min_duration': 3,  # Minimal duration of malfunction
-                   'max_duration': 20  # Max duration of malfunction
-                   }
+stochastic_data = MalfunctionParameters(malfunction_rate=1. / 10000,  # Rate of malfunction occurence
+                                        min_duration=15,  # Minimal duration of malfunction
+                                        max_duration=50  # Max duration of malfunction
+                                        )

 # Custom observation builder
 TreeObservation = TreeObsForRailEnv(max_depth=2)

 # Different agent types (trains) with different speeds.
-speed_ration_map = {1.: 0.,  # Fast passenger train
-                    1. / 2.: 1.0,  # Fast freight train
+speed_ration_map = {1.: 1.,  # Fast passenger train
+                    1. / 2.: 0.0,  # Fast freight train
                     1. / 3.: 0.0,  # Slow commuter train
                     1. / 4.: 0.0}  # Slow freight train

 env = RailEnv(width=x_dim,
               height=y_dim,
-              rail_generator=sparse_rail_generator(num_cities=5,
-                                                   # Number of cities in map (where train stations are)
-                                                   num_intersections=4,
-                                                   # Number of intersections (no start / target)
-                                                   num_trainstations=10,  # Number of possible start/targets on map
-                                                   min_node_dist=3,  # Minimal distance of nodes
-                                                   node_radius=2,  # Proximity of stations to city center
-                                                   num_neighb=3,
-                                                   # Number of connections to other cities/intersections
-                                                   seed=15,  # Random seed
-                                                   grid_mode=True,
-                                                   enhance_intersection=False
-                                                   ),
+              rail_generator=sparse_rail_generator(max_num_cities=3,
+                                                   # Number of cities in map (where train stations are)
+                                                   seed=1,  # Random seed
+                                                   grid_mode=False,
+                                                   max_rails_between_cities=2,
+                                                   max_rails_in_city=4),
               schedule_generator=sparse_schedule_generator(speed_ration_map),
               number_of_agents=n_agents,
-              stochastic_data=stochastic_data,  # Malfunction data generator
+              malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
               obs_builder_object=TreeObservation)
-env.reset(True, True)
-observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
+env.reset(True,True)

 env_renderer = RenderTool(env, gl="PILSVG", )
 num_features_per_node = env.obs_builder.observation_dim

@@ -101,8 +92,8 @@ dones_list = []
 action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
-agent = Agent(state_size, action_size, "FC", 0)
-with path(torch_training.Nets, "navigator_checkpoint10700.pth") as file_in:
+agent = Agent(state_size, action_size)
+with path(torch_training.Nets, "navigator_checkpoint1000.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))

 record_images = False

@@ -111,11 +102,11 @@ frame_step = 0
 for trials in range(1, n_trials + 1):

     # Reset environment
-    obs = env.reset(True, True)
+    obs, info = env.reset(True, True)
     env_renderer.reset()
     # Build agent specific observations
     for a in range(env.get_num_agents()):
-        agent_obs[a] = agent_obs[a] = normalize_observation(obs[a], observation_radius=10)
+        agent_obs[a] = agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
     # Reset score and done
     score = 0
     env_done = 0

@@ -125,15 +116,22 @@ for trials in range(1, n_trials + 1):
         # Action
         for a in range(env.get_num_agents()):
-            action = agent.act(agent_obs[a], eps=0.)
+            if info['action_required'][a]:
+                action = agent.act(agent_obs[a], eps=0.)
+            else:
+                action = 0
             action_prob[action] += 1
             action_dict.update({a: action})

         # Environment step
         obs, all_rewards, done, _ = env.step(action_dict)
         env_renderer.render_env(show=True, show_predictions=True, show_observations=False)

         # Build agent specific observations and normalize
         for a in range(env.get_num_agents()):
-            agent_obs[a] = normalize_observation(obs[a], observation_radius=10)
+            if obs[a]:
+                agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)

         if done['__all__']:
...
@@ -2,19 +2,25 @@ import getopt
 import random
 import sys
 from collections import deque
+# make sure the root path is in system path
+from pathlib import Path
+
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
+
+base_dir = Path(__file__).resolve().parent.parent
+sys.path.append(str(base_dir))

 import matplotlib.pyplot as plt
 import numpy as np
 import torch
-from dueling_double_dqn import Agent
+from torch_training.dueling_double_dqn import Agent
+from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import RenderTool
 from utils.observation_utils import normalize_observation
-from flatland.envs.observations import TreeObsForRailEnv


 def main(argv):
     try:

@@ -30,18 +36,17 @@ def main(argv):
     np.random.seed(1)

     # Parameters for the Environment
-    x_dim = 20
-    y_dim = 20
+    x_dim = 35
+    y_dim = 35
     n_agents = 1
-    n_goals = 5
-    min_dist = 5

     # Use a the malfunction generator to break agents from time to time
-    stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
-                       'malfunction_rate': 30,  # Rate of malfunction occurence
-                       'min_duration': 3,  # Minimal duration of malfunction
-                       'max_duration': 20  # Max duration of malfunction
-                       }
+    stochastic_data = MalfunctionParameters(malfunction_rate=1. / 10000,  # Rate of malfunction occurence
+                                            min_duration=15,  # Minimal duration of malfunction
+                                            max_duration=50  # Max duration of malfunction
+                                            )

     # Custom observation builder
     TreeObservation = TreeObsForRailEnv(max_depth=2)

@@ -54,24 +59,19 @@ def main(argv):
     env = RailEnv(width=x_dim,
                   height=y_dim,
-                  rail_generator=sparse_rail_generator(num_cities=5,
-                                                       # Number of cities in map (where train stations are)
-                                                       num_intersections=4,
-                                                       # Number of intersections (no start / target)
-                                                       num_trainstations=10,  # Number of possible start/targets on map
-                                                       min_node_dist=3,  # Minimal distance of nodes
-                                                       node_radius=2,  # Proximity of stations to city center
-                                                       num_neighb=3,
-                                                       # Number of connections to other cities/intersections
-                                                       seed=15,  # Random seed
-                                                       grid_mode=True,
-                                                       enhance_intersection=False
-                                                       ),
+                  rail_generator=sparse_rail_generator(max_num_cities=3,
+                                                       # Number of cities in map (where train stations are)
+                                                       seed=1,  # Random seed
+                                                       grid_mode=False,
+                                                       max_rails_between_cities=2,
+                                                       max_rails_in_city=3),
                   schedule_generator=sparse_schedule_generator(speed_ration_map),
                   number_of_agents=n_agents,
-                  stochastic_data=stochastic_data,  # Malfunction data generator
+                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
+                  # Malfunction data generator
                   obs_builder_object=TreeObservation)

+    # Reset env
+    env.reset(True,True)
     # After training we want to render the results so we also load a renderer
     env_renderer = RenderTool(env, gl="PILSVG", )
     # Given the depth of the tree observation and the number of features per node we get the following state_size

@@ -108,24 +108,22 @@ def main(argv):
     agent_obs = [None] * env.get_num_agents()
     agent_next_obs = [None] * env.get_num_agents()
     agent_obs_buffer = [None] * env.get_num_agents()
-    agent_action_buffer = [None] * env.get_num_agents()
+    agent_action_buffer = [2] * env.get_num_agents()
     cummulated_reward = np.zeros(env.get_num_agents())
+    update_values = False
     # Now we load a Double dueling DQN agent
-    agent = Agent(state_size, action_size, "FC", 0)
+    agent = Agent(state_size, action_size)

     for trials in range(1, n_trials + 1):

         # Reset environment
-        obs = env.reset(True, True)
-        register_action_state = np.zeros(env.get_num_agents(), dtype=bool)
-        final_obs = agent_obs.copy()
-        final_obs_next = agent_next_obs.copy()
+        obs, info = env.reset(True, True)
+        env_renderer.reset()

         # Build agent specific observations
         for a in range(env.get_num_agents()):
-            agent_obs[a] = normalize_observation(obs[a], observation_radius=10)
-            agent_obs_buffer[a] = agent_obs[a].copy()
+            if obs[a]:
+                agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
+                agent_obs_buffer[a] = agent_obs[a].copy()

         # Reset score and done
         score = 0

@@ -135,49 +133,36 @@ def main(argv):
         for step in range(max_steps):
             # Action
             for a in range(env.get_num_agents()):
-                if env.agents[a].speed_data['position_fraction'] < 0.001:
-                    register_action_state[a] = True
+                if info['action_required'][a]:
+                    # If an action is require, we want to store the obs a that step as well as the action
+                    update_values = True
                     action = agent.act(agent_obs[a], eps=eps)
                     action_prob[action] += 1
-                    if step == 0:
-                        agent_action_buffer[a] = action
                 else:
-                    register_action_state[a] = False
+                    update_values = False
                     action = 0
                 action_dict.update({a: action})

             # Environment step
-            next_obs, all_rewards, done, _ = env.step(action_dict)
-            # Build agent specific observations and normalize
-            for a in range(env.get_num_agents()):
-                agent_next_obs[a] = normalize_observation(next_obs[a], observation_radius=10)
-                cummulated_reward[a] += all_rewards[a]
+            next_obs, all_rewards, done, info = env.step(action_dict)

             # Update replay buffer and train agent
             for a in range(env.get_num_agents()):
-                if done[a]:
-                    final_obs[a] = agent_obs_buffer[a]
-                    final_obs_next[a] = agent_next_obs[a].copy()
-                    final_action_dict.update({a: agent_action_buffer[a]})
-                if not done[a]:
-                    if agent_obs_buffer[a] is not None and register_action_state[a]:
-                        agent_delayed_next = agent_obs[a].copy()
-                        agent.step(agent_obs_buffer[a], agent_action_buffer[a], all_rewards[a],
-                                   agent_delayed_next, done[a])
-                        cummulated_reward[a] = 0.
-                if register_action_state[a]:
-                    agent_obs_buffer[a] = agent_obs[a].copy()
-                    agent_action_buffer[a] = action_dict[a]
+                # Only update the values when we are done or when an action was taken and thus relevant information is present
+                if update_values or done[a]:
+                    agent.step(agent_obs_buffer[a], agent_action_buffer[a], all_rewards[a],
+                               agent_obs[a], done[a])
+                    cummulated_reward[a] = 0.
+
+                    agent_obs_buffer[a] = agent_obs[a].copy()
+                    agent_action_buffer[a] = action_dict[a]
+                if next_obs[a]:
+                    agent_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)

                 score += all_rewards[a] / env.get_num_agents()

             # Copy observation
-            agent_obs = agent_next_obs.copy()
             if done['__all__']:
                 env_done = 1
-                for a in range(env.get_num_agents()):
-                    agent.step(final_obs[a], final_action_dict[a], all_rewards[a], final_obs_next[a], done[a])
                 break

         # Epsilon decay

@@ -188,7 +173,7 @@ def main(argv):
         for _idx in range(env.get_num_agents()):
             if done[_idx] == 1:
                 tasks_finished += 1
-        done_window.append(tasks_finished / env.get_num_agents())
+        done_window.append(tasks_finished / max(1, env.get_num_agents()))
         scores_window.append(score / max_steps)  # save most recent score
         scores.append(np.mean(scores_window))
         dones_list.append((np.mean(done_window)))
...
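The core of the rewritten training loop above is its buffering rule: an observation/action pair is stored only at a genuine decision point (when info['action_required'] is set for that agent), and agent.step is called with the buffered pair once the next decision point or the episode end is reached. The following is a condensed sketch of that pattern, not the script itself; the surrounding setup (env, agent, buffers, tree_depth, eps) is assumed.

# Condensed sketch of the decision-point buffering used above (setup assumed, illustration only).
for step in range(max_steps):
    for a in range(env.get_num_agents()):
        if info['action_required'][a]:
            update_values = True                   # the agent is at a decision point
            action_dict[a] = agent.act(agent_obs[a], eps=eps)
        else:
            update_values = False                  # agent is between cells; nothing to learn here
            action_dict[a] = 0

    next_obs, all_rewards, done, info = env.step(action_dict)

    for a in range(env.get_num_agents()):
        # Train only on (state, action) pairs taken at decision points, or when the episode ends.
        if update_values or done[a]:
            agent.step(agent_obs_buffer[a], agent_action_buffer[a], all_rewards[a], agent_obs[a], done[a])
            agent_obs_buffer[a] = agent_obs[a].copy()
            agent_action_buffer[a] = action_dict[a]
        if next_obs[a]:
            agent_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)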
@@ -3,13 +3,13 @@ import time
 from collections import deque

 import numpy as np
-from line_profiler import LineProfiler

 from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
 from flatland.envs.schedule_generators import complex_schedule_generator
-from utils.observation_utils import norm_obs_clip, split_tree
+from line_profiler import LineProfiler
+
+from utils.observation_utils import norm_obs_clip, split_tree_into_feature_groups


 def printProgressBar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='*'):

@@ -102,10 +102,9 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
     # Reset the env
     lp_reset(True, True)
-    obs = env.reset(True, True)
+    obs, info = env.reset(True, True)
     for a in range(env.get_num_agents()):
-        data, distance, agent_data = split_tree(tree=np.array(obs[a]),
-                                                current_depth=0)
+        data, distance, agent_data = split_tree_into_feature_groups(obs[a], tree_depth)
         data = norm_obs_clip(data)
         distance = norm_obs_clip(distance)
         agent_data = np.clip(agent_data, -1, 1)

@@ -129,8 +128,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
             next_obs, all_rewards, done, _ = lp_step(action_dict)

             for a in range(env.get_num_agents()):
-                data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
-                                                        current_depth=0)
+                data, distance, agent_data = split_tree_into_feature_groups(next_obs[a], tree_depth)
                 data = norm_obs_clip(data)
                 distance = norm_obs_clip(distance)
                 agent_data = np.clip(agent_data, -1, 1)
...
 import numpy as np
+
+from flatland.envs.observations import TreeObsForRailEnv


 def max_lt(seq, val):

@@ -53,57 +54,71 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1, fixed_radius=0, normalize_to_ran
     return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)


-def split_tree(tree, num_features_per_node, current_depth=0):
+def _split_node_into_feature_groups(node: TreeObsForRailEnv.Node) -> (np.ndarray, np.ndarray, np.ndarray):
+    data = np.zeros(6)
+    distance = np.zeros(1)
+    agent_data = np.zeros(4)
+
+    data[0] = node.dist_own_target_encountered
+    data[1] = node.dist_other_target_encountered
+    data[2] = node.dist_other_agent_encountered
+    data[3] = node.dist_potential_conflict
+    data[4] = node.dist_unusable_switch
+    data[5] = node.dist_to_next_branch
+
+    distance[0] = node.dist_min_to_target
+
+    agent_data[0] = node.num_agents_same_direction
+    agent_data[1] = node.num_agents_opposite_direction
+    agent_data[2] = node.num_agents_malfunctioning
+    agent_data[3] = node.speed_min_fractional
+
+    return data, distance, agent_data
+
+
+def _split_subtree_into_feature_groups(node: TreeObsForRailEnv.Node, current_tree_depth: int, max_tree_depth: int) -> (np.ndarray, np.ndarray, np.ndarray):
+    if node == -np.inf:
+        remaining_depth = max_tree_depth - current_tree_depth
+        # reference: https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure
+        num_remaining_nodes = int((4**(remaining_depth+1) - 1) / (4 - 1))
+        return [-np.inf] * num_remaining_nodes*6, [-np.inf] * num_remaining_nodes, [-np.inf] * num_remaining_nodes*4
+
+    data, distance, agent_data = _split_node_into_feature_groups(node)
+
+    if not node.childs:
+        return data, distance, agent_data
+
+    for direction in TreeObsForRailEnv.tree_explored_actions_char:
+        sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(node.childs[direction], current_tree_depth + 1, max_tree_depth)
+        data = np.concatenate((data, sub_data))
+        distance = np.concatenate((distance, sub_distance))
+        agent_data = np.concatenate((agent_data, sub_agent_data))
+
+    return data, distance, agent_data
+
+
+def split_tree_into_feature_groups(tree: TreeObsForRailEnv.Node, max_tree_depth: int) -> (np.ndarray, np.ndarray, np.ndarray):
     """
-    Splits the tree observation into different sub groups that need the same normalization.
-    This is necessary because the tree observation includes two different distance:
-    1. Distance from the agent --> This is measured in cells from current agent location
-    2. Distance to targer --> This is measured as distance from cell to agent target
-    3. Binary data --> Contains information about presence of object --> No normalization necessary
-    Number 1. will depend on the depth and size of the tree search
-    Number 2. will depend on the size of the map and thus the max distance on the map
-    Number 3. Is independent of tree depth and map size and thus must be handled differently
-    Therefore we split the tree into these two classes for better normalization.
-    :param tree: Tree that needs to be split
-    :param num_features_per_node: Features per node ATTENTION! this parameter is vital to correct splitting of the tree.
-    :param current_depth: Keeping track of the current depth in the tree
-    :return: Returns the three different groups of distance and binary values.
+    This function splits the tree into three difference arrays of values
     """
-    if len(tree) < num_features_per_node:
-        return [], [], []
-
-    depth = 0
-    tmp = len(tree) / num_features_per_node - 1
-    pow4 = 4
-    while tmp > 0:
-        tmp -= pow4
-        depth += 1
-        pow4 *= 4
-    child_size = (len(tree) - num_features_per_node) // 4
-    """
-    Here we split the node features into the different classes of distances and binary values.
-    Pay close attention to this part if you modify any of the features in the tree observation.
-    """
-    tree_data = tree[:6].tolist()
-    distance_data = [tree[6]]
-    agent_data = tree[7:num_features_per_node].tolist()
-    # Split each child of the current node and continue to next depth level
-    for children in range(4):
-        child_tree = tree[(num_features_per_node + children * child_size):
-                          (num_features_per_node + (children + 1) * child_size)]
-        tmp_tree_data, tmp_distance_data, tmp_agent_data = split_tree(child_tree, num_features_per_node,
-                                                                      current_depth=current_depth + 1)
-        if len(tmp_tree_data) > 0:
-            tree_data.extend(tmp_tree_data)
-            distance_data.extend(tmp_distance_data)
-            agent_data.extend(tmp_agent_data)
-    return tree_data, distance_data, agent_data
-
-
-def normalize_observation(observation, num_features_per_node=11, observation_radius=0):
-    data, distance, agent_data = split_tree(tree=np.array(observation), num_features_per_node=num_features_per_node,
-                                            current_depth=0)
+    data, distance, agent_data = _split_node_into_feature_groups(tree)
+
+    for direction in TreeObsForRailEnv.tree_explored_actions_char:
+        sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(tree.childs[direction], 1, max_tree_depth)
+        data = np.concatenate((data, sub_data))
+        distance = np.concatenate((distance, sub_distance))
+        agent_data = np.concatenate((agent_data, sub_agent_data))
+
+    return data, distance, agent_data
+
+
+def normalize_observation(observation: TreeObsForRailEnv.Node, tree_depth: int, observation_radius=0):
+    """
+    This function normalizes the observation used by the RL algorithm
+    """
+    data, distance, agent_data = split_tree_into_feature_groups(observation, tree_depth)
+
     data = norm_obs_clip(data, fixed_radius=observation_radius)
     distance = norm_obs_clip(distance, normalize_to_range=True)
     agent_data = np.clip(agent_data, -1, 1)
...
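With this refactor, normalize_observation takes the raw TreeObsForRailEnv.Node and the tree depth instead of a flattened array and a per-node feature count. The following is a minimal call-site sketch; the surrounding variables (env, agent_obs) are assumptions matching the training scripts above, not part of the compared revisions.

# Minimal call-site sketch of the refactored helper (surrounding setup assumed, illustration only).
from utils.observation_utils import normalize_observation

tree_depth = 2                                 # must match TreeObsForRailEnv(max_depth=...)
obs, info = env.reset(True, True)              # obs[a] is a TreeObsForRailEnv.Node
for a in range(env.get_num_agents()):
    if obs[a]:
        agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)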