From f9d47bba3379d60b0564f5c81421a4af53a5ba9f Mon Sep 17 00:00:00 2001 From: Adrian Egli <adrian.egli@sbb.ch> Date: Fri, 6 Nov 2020 23:58:19 +0100 Subject: [PATCH] rl --- apt.txt | 10 +- reinforcement_learning/dddqn_policy.py | 402 +++--- reinforcement_learning/evaluate_agent.py | 752 +++++----- .../multi_agent_training.py | 1216 ++++++++--------- reinforcement_learning/policy.py | 52 +- reinforcement_learning/ppo/model.py | 40 +- reinforcement_learning/ppo/ppo_agent.py | 262 ++-- reinforcement_learning/ppo/replay_memory.py | 106 +- run.py | 430 +++--- utils/dead_lock_avoidance_agent.py | 350 ++--- utils/fast_tree_obs.py | 616 ++++----- utils/shortest_distance_walker.py | 174 +-- 12 files changed, 2205 insertions(+), 2205 deletions(-) diff --git a/apt.txt b/apt.txt index c0e0ffb..d593bcc 100644 --- a/apt.txt +++ b/apt.txt @@ -1,6 +1,6 @@ -curl -git -vim -ssh -gcc +curl +git +vim +ssh +gcc build-essential \ No newline at end of file diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py index 1c323c3..6218ab8 100644 --- a/reinforcement_learning/dddqn_policy.py +++ b/reinforcement_learning/dddqn_policy.py @@ -1,201 +1,201 @@ -import copy -import os -import pickle -import random -from collections import namedtuple, deque, Iterable - -import numpy as np -import torch -import torch.nn.functional as F -import torch.optim as optim - -from reinforcement_learning.model import DuelingQNetwork -from reinforcement_learning.policy import Policy - - -class DDDQNPolicy(Policy): - """Dueling Double DQN policy""" - - def __init__(self, state_size, action_size, parameters, evaluation_mode=False): - self.evaluation_mode = evaluation_mode - - self.state_size = state_size - self.action_size = action_size - self.double_dqn = True - self.hidsize = 128 - - if not evaluation_mode: - self.hidsize = parameters.hidden_size - self.buffer_size = parameters.buffer_size - self.batch_size = parameters.batch_size - self.update_every = parameters.update_every - self.learning_rate = parameters.learning_rate - self.tau = parameters.tau - self.gamma = parameters.gamma - self.buffer_min_size = parameters.buffer_min_size - - # Device - if parameters.use_gpu and torch.cuda.is_available(): - self.device = torch.device("cuda:0") - # print("🇠Using GPU") - else: - self.device = torch.device("cpu") - # print("🢠Using CPU") - - # Q-Network - self.qnetwork_local = DuelingQNetwork(state_size, action_size, hidsize1=self.hidsize, hidsize2=self.hidsize).to( - self.device) - - if not evaluation_mode: - self.qnetwork_target = copy.deepcopy(self.qnetwork_local) - self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=self.learning_rate) - self.memory = ReplayBuffer(action_size, self.buffer_size, self.batch_size, self.device) - - self.t_step = 0 - self.loss = 0.0 - - def act(self, state, eps=0.): - state = torch.from_numpy(state).float().unsqueeze(0).to(self.device) - self.qnetwork_local.eval() - with torch.no_grad(): - action_values = self.qnetwork_local(state) - self.qnetwork_local.train() - - # Epsilon-greedy action selection - if random.random() > eps: - return np.argmax(action_values.cpu().data.numpy()) - else: - return random.choice(np.arange(self.action_size)) - - def step(self, handle, state, action, reward, next_state, done): - assert not self.evaluation_mode, "Policy has been initialized for evaluation only." - - # Save experience in replay memory - self.memory.add(state, action, reward, next_state, done) - - # Learn every UPDATE_EVERY time steps. 
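# A minimal sketch of the Double DQN target that the _learn() step below computes, using
# toy tensors; it assumes only PyTorch, and the batch size, action count and gamma here
# are made-up values chosen for illustration, not taken from the patch.
import torch

batch_size, action_size, gamma = 4, 5, 0.99
q_local_next = torch.rand(batch_size, action_size)       # stand-in for qnetwork_local(next_states)
q_target_next = torch.rand(batch_size, action_size)      # stand-in for qnetwork_target(next_states)
rewards = torch.rand(batch_size, 1)
dones = torch.zeros(batch_size, 1)

# Double DQN: the local network selects the best next action, the target network evaluates it.
best_actions = q_local_next.max(1)[1]                                   # shape (batch,)
q_targets_next = q_target_next.gather(1, best_actions.unsqueeze(-1))    # shape (batch, 1)
q_targets = rewards + gamma * q_targets_next * (1 - dones)
print(q_targets.shape)                                                  # torch.Size([4, 1])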
- self.t_step = (self.t_step + 1) % self.update_every - if self.t_step == 0: - # If enough samples are available in memory, get random subset and learn - if len(self.memory) > self.buffer_min_size and len(self.memory) > self.batch_size: - self._learn() - - def _learn(self): - experiences = self.memory.sample() - states, actions, rewards, next_states, dones = experiences - - # Get expected Q values from local model - q_expected = self.qnetwork_local(states).gather(1, actions) - - if self.double_dqn: - # Double DQN - q_best_action = self.qnetwork_local(next_states).max(1)[1] - q_targets_next = self.qnetwork_target(next_states).gather(1, q_best_action.unsqueeze(-1)) - else: - # DQN - q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1) - - # Compute Q targets for current states - q_targets = rewards + (self.gamma * q_targets_next * (1 - dones)) - - # Compute loss - self.loss = F.mse_loss(q_expected, q_targets) - - # Minimize the loss - self.optimizer.zero_grad() - self.loss.backward() - self.optimizer.step() - - # Update target network - self._soft_update(self.qnetwork_local, self.qnetwork_target, self.tau) - - def _soft_update(self, local_model, target_model, tau): - # Soft update model parameters. - # θ_target = Ï„*θ_local + (1 - Ï„)*θ_target - for target_param, local_param in zip(target_model.parameters(), local_model.parameters()): - target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data) - - def save(self, filename): - torch.save(self.qnetwork_local.state_dict(), filename + ".local") - torch.save(self.qnetwork_target.state_dict(), filename + ".target") - - def load(self, filename): - try: - if os.path.exists(filename + ".local") and os.path.exists(filename + ".target"): - self.qnetwork_local.load_state_dict(torch.load(filename + ".local")) - print("qnetwork_local loaded ('{}')".format(filename + ".local")) - if not self.evaluation_mode: - self.qnetwork_target.load_state_dict(torch.load(filename + ".target")) - print("qnetwork_target loaded ('{}' )".format(filename + ".target")) - else: - print(">> Checkpoint not found, using untrained policy! ('{}', '{}')".format(filename + ".local", - filename + ".target")) - except Exception as exc: - print(exc) - print("Couldn't load policy from, using untrained policy! ('{}', '{}')".format(filename + ".local", - filename + ".target")) - - def save_replay_buffer(self, filename): - memory = self.memory.memory - with open(filename, 'wb') as f: - pickle.dump(list(memory)[-500000:], f) - - def load_replay_buffer(self, filename): - with open(filename, 'rb') as f: - self.memory.memory = pickle.load(f) - - def test(self): - self.act(np.array([[0] * self.state_size])) - self._learn() - - -Experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"]) - - -class ReplayBuffer: - """Fixed-size buffer to store experience tuples.""" - - def __init__(self, action_size, buffer_size, batch_size, device): - """Initialize a ReplayBuffer object. 
- - Params - ====== - action_size (int): dimension of each action - buffer_size (int): maximum size of buffer - batch_size (int): size of each training batch - """ - self.action_size = action_size - self.memory = deque(maxlen=buffer_size) - self.batch_size = batch_size - self.device = device - - def add(self, state, action, reward, next_state, done): - """Add a new experience to memory.""" - e = Experience(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done) - self.memory.append(e) - - def sample(self): - """Randomly sample a batch of experiences from memory.""" - experiences = random.sample(self.memory, k=self.batch_size) - - states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \ - .float().to(self.device) - actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \ - .long().to(self.device) - rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \ - .float().to(self.device) - next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \ - .float().to(self.device) - dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \ - .float().to(self.device) - - return states, actions, rewards, next_states, dones - - def __len__(self): - """Return the current size of internal memory.""" - return len(self.memory) - - def __v_stack_impr(self, states): - sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1 - np_states = np.reshape(np.array(states), (len(states), sub_dim)) - return np_states +import copy +import os +import pickle +import random +from collections import namedtuple, deque, Iterable + +import numpy as np +import torch +import torch.nn.functional as F +import torch.optim as optim + +from reinforcement_learning.model import DuelingQNetwork +from reinforcement_learning.policy import Policy + + +class DDDQNPolicy(Policy): + """Dueling Double DQN policy""" + + def __init__(self, state_size, action_size, parameters, evaluation_mode=False): + self.evaluation_mode = evaluation_mode + + self.state_size = state_size + self.action_size = action_size + self.double_dqn = True + self.hidsize = 128 + + if not evaluation_mode: + self.hidsize = parameters.hidden_size + self.buffer_size = parameters.buffer_size + self.batch_size = parameters.batch_size + self.update_every = parameters.update_every + self.learning_rate = parameters.learning_rate + self.tau = parameters.tau + self.gamma = parameters.gamma + self.buffer_min_size = parameters.buffer_min_size + + # Device + if parameters.use_gpu and torch.cuda.is_available(): + self.device = torch.device("cuda:0") + # print("🇠Using GPU") + else: + self.device = torch.device("cpu") + # print("🢠Using CPU") + + # Q-Network + self.qnetwork_local = DuelingQNetwork(state_size, action_size, hidsize1=self.hidsize, hidsize2=self.hidsize).to( + self.device) + + if not evaluation_mode: + self.qnetwork_target = copy.deepcopy(self.qnetwork_local) + self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=self.learning_rate) + self.memory = ReplayBuffer(action_size, self.buffer_size, self.batch_size, self.device) + + self.t_step = 0 + self.loss = 0.0 + + def act(self, state, eps=0.): + state = torch.from_numpy(state).float().unsqueeze(0).to(self.device) + self.qnetwork_local.eval() + with torch.no_grad(): + action_values = self.qnetwork_local(state) + self.qnetwork_local.train() + + # Epsilon-greedy 
action selection + if random.random() > eps: + return np.argmax(action_values.cpu().data.numpy()) + else: + return random.choice(np.arange(self.action_size)) + + def step(self, handle, state, action, reward, next_state, done): + assert not self.evaluation_mode, "Policy has been initialized for evaluation only." + + # Save experience in replay memory + self.memory.add(state, action, reward, next_state, done) + + # Learn every UPDATE_EVERY time steps. + self.t_step = (self.t_step + 1) % self.update_every + if self.t_step == 0: + # If enough samples are available in memory, get random subset and learn + if len(self.memory) > self.buffer_min_size and len(self.memory) > self.batch_size: + self._learn() + + def _learn(self): + experiences = self.memory.sample() + states, actions, rewards, next_states, dones = experiences + + # Get expected Q values from local model + q_expected = self.qnetwork_local(states).gather(1, actions) + + if self.double_dqn: + # Double DQN + q_best_action = self.qnetwork_local(next_states).max(1)[1] + q_targets_next = self.qnetwork_target(next_states).gather(1, q_best_action.unsqueeze(-1)) + else: + # DQN + q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1) + + # Compute Q targets for current states + q_targets = rewards + (self.gamma * q_targets_next * (1 - dones)) + + # Compute loss + self.loss = F.mse_loss(q_expected, q_targets) + + # Minimize the loss + self.optimizer.zero_grad() + self.loss.backward() + self.optimizer.step() + + # Update target network + self._soft_update(self.qnetwork_local, self.qnetwork_target, self.tau) + + def _soft_update(self, local_model, target_model, tau): + # Soft update model parameters. + # θ_target = Ï„*θ_local + (1 - Ï„)*θ_target + for target_param, local_param in zip(target_model.parameters(), local_model.parameters()): + target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data) + + def save(self, filename): + torch.save(self.qnetwork_local.state_dict(), filename + ".local") + torch.save(self.qnetwork_target.state_dict(), filename + ".target") + + def load(self, filename): + try: + if os.path.exists(filename + ".local") and os.path.exists(filename + ".target"): + self.qnetwork_local.load_state_dict(torch.load(filename + ".local")) + print("qnetwork_local loaded ('{}')".format(filename + ".local")) + if not self.evaluation_mode: + self.qnetwork_target.load_state_dict(torch.load(filename + ".target")) + print("qnetwork_target loaded ('{}' )".format(filename + ".target")) + else: + print(">> Checkpoint not found, using untrained policy! ('{}', '{}')".format(filename + ".local", + filename + ".target")) + except Exception as exc: + print(exc) + print("Couldn't load policy from, using untrained policy! ('{}', '{}')".format(filename + ".local", + filename + ".target")) + + def save_replay_buffer(self, filename): + memory = self.memory.memory + with open(filename, 'wb') as f: + pickle.dump(list(memory)[-500000:], f) + + def load_replay_buffer(self, filename): + with open(filename, 'rb') as f: + self.memory.memory = pickle.load(f) + + def test(self): + self.act(np.array([[0] * self.state_size])) + self._learn() + + +Experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"]) + + +class ReplayBuffer: + """Fixed-size buffer to store experience tuples.""" + + def __init__(self, action_size, buffer_size, batch_size, device): + """Initialize a ReplayBuffer object. 
+ + Params + ====== + action_size (int): dimension of each action + buffer_size (int): maximum size of buffer + batch_size (int): size of each training batch + """ + self.action_size = action_size + self.memory = deque(maxlen=buffer_size) + self.batch_size = batch_size + self.device = device + + def add(self, state, action, reward, next_state, done): + """Add a new experience to memory.""" + e = Experience(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done) + self.memory.append(e) + + def sample(self): + """Randomly sample a batch of experiences from memory.""" + experiences = random.sample(self.memory, k=self.batch_size) + + states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \ + .float().to(self.device) + actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \ + .long().to(self.device) + rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \ + .float().to(self.device) + next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \ + .float().to(self.device) + dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \ + .float().to(self.device) + + return states, actions, rewards, next_states, dones + + def __len__(self): + """Return the current size of internal memory.""" + return len(self.memory) + + def __v_stack_impr(self, states): + sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1 + np_states = np.reshape(np.array(states), (len(states), sub_dim)) + return np_states diff --git a/reinforcement_learning/evaluate_agent.py b/reinforcement_learning/evaluate_agent.py index 80d4298..64eb943 100644 --- a/reinforcement_learning/evaluate_agent.py +++ b/reinforcement_learning/evaluate_agent.py @@ -1,376 +1,376 @@ -import math -import multiprocessing -import os -import sys -from argparse import ArgumentParser, Namespace -from multiprocessing import Pool -from pathlib import Path -from pprint import pprint - -import numpy as np -import torch -from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters -from flatland.envs.observations import TreeObsForRailEnv -from flatland.envs.predictions import ShortestPathPredictorForRailEnv -from flatland.envs.rail_env import RailEnv -from flatland.envs.rail_generators import sparse_rail_generator -from flatland.envs.schedule_generators import sparse_schedule_generator -from flatland.utils.rendertools import RenderTool - -base_dir = Path(__file__).resolve().parent.parent -sys.path.append(str(base_dir)) - -from utils.deadlock_check import check_if_all_blocked -from utils.timer import Timer -from utils.observation_utils import normalize_observation -from reinforcement_learning.dddqn_policy import DDDQNPolicy - - -def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size, state_size, seed, render, allow_skipping, allow_caching): - # Evaluation is faster on CPU (except if you use a really huge policy) - parameters = { - 'use_gpu': False - } - - policy = DDDQNPolicy(state_size, action_size, Namespace(**parameters), evaluation_mode=True) - policy.qnetwork_local = torch.load(checkpoint) - - env_params = Namespace(**env_params) - - # Environment parameters - n_agents = env_params.n_agents - x_dim = env_params.x_dim - y_dim = env_params.y_dim - n_cities = env_params.n_cities - max_rails_between_cities = 
env_params.max_rails_between_cities - max_rails_in_city = env_params.max_rails_in_city - - # Malfunction and speed profiles - # TODO pass these parameters properly from main! - malfunction_parameters = MalfunctionParameters( - malfunction_rate=1. / 2000, # Rate of malfunctions - min_duration=20, # Minimal duration - max_duration=50 # Max duration - ) - - # Only fast trains in Round 1 - speed_profiles = { - 1.: 1.0, # Fast passenger train - 1. / 2.: 0.0, # Fast freight train - 1. / 3.: 0.0, # Slow commuter train - 1. / 4.: 0.0 # Slow freight train - } - - # Observation parameters - observation_tree_depth = env_params.observation_tree_depth - observation_radius = env_params.observation_radius - observation_max_path_depth = env_params.observation_max_path_depth - - # Observation builder - predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth) - tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor) - - # Setup the environment - env = RailEnv( - width=x_dim, height=y_dim, - rail_generator=sparse_rail_generator( - max_num_cities=n_cities, - grid_mode=False, - max_rails_between_cities=max_rails_between_cities, - max_rails_in_city=max_rails_in_city, - ), - schedule_generator=sparse_schedule_generator(speed_profiles), - number_of_agents=n_agents, - malfunction_generator_and_process_data=malfunction_from_params(malfunction_parameters), - obs_builder_object=tree_observation - ) - - if render: - env_renderer = RenderTool(env, gl="PGL") - - action_dict = dict() - scores = [] - completions = [] - nb_steps = [] - inference_times = [] - preproc_times = [] - agent_times = [] - step_times = [] - - for episode_idx in range(n_eval_episodes): - seed += 1 - - inference_timer = Timer() - preproc_timer = Timer() - agent_timer = Timer() - step_timer = Timer() - - step_timer.start() - obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True, random_seed=seed) - step_timer.end() - - agent_obs = [None] * env.get_num_agents() - score = 0.0 - - if render: - env_renderer.set_new_rail() - - final_step = 0 - skipped = 0 - - nb_hit = 0 - agent_last_obs = {} - agent_last_action = {} - - for step in range(max_steps - 1): - if allow_skipping and check_if_all_blocked(env): - # FIXME why -1? bug where all agents are "done" after max_steps! 
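# A back-of-the-envelope check of the skip bookkeeping below, with made-up numbers:
# max_steps, the step at which every agent is blocked, and the number of unfinished agents
# are all illustrative. The lump-sum penalty approximates the roughly -1 reward each
# unfinished agent would keep collecting per remaining step.
max_steps, step, n_unfinished_agents = 410, 100, 3
skipped = max_steps - step - 1                 # 309 remaining steps are written off
score_penalty = skipped * n_unfinished_agents  # 927 subtracted from the episode score
print(skipped, score_penalty)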
- skipped = max_steps - step - 1 - final_step = max_steps - 2 - n_unfinished_agents = sum(not done[idx] for idx in env.get_agent_handles()) - score -= skipped * n_unfinished_agents - break - - agent_timer.start() - for agent in env.get_agent_handles(): - if obs[agent] and info['action_required'][agent]: - if agent in agent_last_obs and np.all(agent_last_obs[agent] == obs[agent]): - nb_hit += 1 - action = agent_last_action[agent] - - else: - preproc_timer.start() - norm_obs = normalize_observation(obs[agent], tree_depth=observation_tree_depth, observation_radius=observation_radius) - preproc_timer.end() - - inference_timer.start() - action = policy.act(norm_obs, eps=0.0) - inference_timer.end() - - action_dict.update({agent: action}) - - if allow_caching: - agent_last_obs[agent] = obs[agent] - agent_last_action[agent] = action - agent_timer.end() - - step_timer.start() - obs, all_rewards, done, info = env.step(action_dict) - step_timer.end() - - if render: - env_renderer.render_env( - show=True, - frames=False, - show_observations=False, - show_predictions=False - ) - - if step % 100 == 0: - print("{}/{}".format(step, max_steps - 1)) - - for agent in env.get_agent_handles(): - score += all_rewards[agent] - - final_step = step - - if done['__all__']: - break - - normalized_score = score / (max_steps * env.get_num_agents()) - scores.append(normalized_score) - - tasks_finished = sum(done[idx] for idx in env.get_agent_handles()) - completion = tasks_finished / max(1, env.get_num_agents()) - completions.append(completion) - - nb_steps.append(final_step) - - inference_times.append(inference_timer.get()) - preproc_times.append(preproc_timer.get()) - agent_times.append(agent_timer.get()) - step_times.append(step_timer.get()) - - skipped_text = "" - if skipped > 0: - skipped_text = "\tâš¡ Skipped {}".format(skipped) - - hit_text = "" - if nb_hit > 0: - hit_text = "\tâš¡ Hit {} ({:.1f}%)".format(nb_hit, (100 * nb_hit) / (n_agents * final_step)) - - print( - "â˜‘ï¸ Score: {:.3f} \tDone: {:.1f}% \tNb steps: {:.3f} " - "\tðŸ Seed: {}" - "\t🚉 Env: {:.3f}s " - "\t🤖 Agent: {:.3f}s (per step: {:.3f}s) \t[preproc: {:.3f}s \tinfer: {:.3f}s]" - "{}{}".format( - normalized_score, - completion * 100.0, - final_step, - seed, - step_timer.get(), - agent_timer.get(), - agent_timer.get() / final_step, - preproc_timer.get(), - inference_timer.get(), - skipped_text, - hit_text - ) - ) - - return scores, completions, nb_steps, agent_times, step_times - - -def evaluate_agents(file, n_evaluation_episodes, use_gpu, render, allow_skipping, allow_caching): - nb_threads = 1 - eval_per_thread = n_evaluation_episodes - - if not render: - nb_threads = multiprocessing.cpu_count() - eval_per_thread = max(1, math.ceil(n_evaluation_episodes / nb_threads)) - - total_nb_eval = eval_per_thread * nb_threads - print("Will evaluate policy {} over {} episodes on {} threads.".format(file, total_nb_eval, nb_threads)) - - if total_nb_eval != n_evaluation_episodes: - print("(Rounding up from {} to fill all cores)".format(n_evaluation_episodes)) - - # Observation parameters need to match the ones used during training! 
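# Worked example of the state size these observation parameters imply. The 11 features per
# tree node is an assumption (it is what TreeObsForRailEnv.observation_dim reported in
# challenge-era Flatland releases); check env.obs_builder.observation_dim on your install.
observation_tree_depth = 2
num_features_per_node = 11                                        # assumed, see note above
n_nodes = sum(4 ** i for i in range(observation_tree_depth + 1))  # 1 + 4 + 16 = 21
state_size = num_features_per_node * n_nodes                      # 231
print(n_nodes, state_size)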
- - # small_v0 - small_v0_params = { - # sample configuration - "n_agents": 5, - "x_dim": 25, - "y_dim": 25, - "n_cities": 4, - "max_rails_between_cities": 2, - "max_rails_in_city": 3, - - # observations - "observation_tree_depth": 2, - "observation_radius": 10, - "observation_max_path_depth": 20 - } - - # Test_0 - test0_params = { - # sample configuration - "n_agents": 5, - "x_dim": 25, - "y_dim": 25, - "n_cities": 2, - "max_rails_between_cities": 2, - "max_rails_in_city": 3, - - # observations - "observation_tree_depth": 2, - "observation_radius": 10, - "observation_max_path_depth": 20 - } - - # Test_1 - test1_params = { - # environment - "n_agents": 10, - "x_dim": 30, - "y_dim": 30, - "n_cities": 2, - "max_rails_between_cities": 2, - "max_rails_in_city": 3, - - # observations - "observation_tree_depth": 2, - "observation_radius": 10, - "observation_max_path_depth": 10 - } - - # Test_5 - test5_params = { - # environment - "n_agents": 80, - "x_dim": 35, - "y_dim": 35, - "n_cities": 5, - "max_rails_between_cities": 2, - "max_rails_in_city": 4, - - # observations - "observation_tree_depth": 2, - "observation_radius": 10, - "observation_max_path_depth": 20 - } - - params = small_v0_params - env_params = Namespace(**params) - - print("Environment parameters:") - pprint(params) - - # Calculate space dimensions and max steps - max_steps = int(4 * 2 * (env_params.x_dim + env_params.y_dim + (env_params.n_agents / env_params.n_cities))) - action_size = 5 - tree_observation = TreeObsForRailEnv(max_depth=env_params.observation_tree_depth) - tree_depth = env_params.observation_tree_depth - num_features_per_node = tree_observation.observation_dim - n_nodes = sum([np.power(4, i) for i in range(tree_depth + 1)]) - state_size = num_features_per_node * n_nodes - - results = [] - if render: - results.append(eval_policy(params, file, eval_per_thread, max_steps, action_size, state_size, 0, render, allow_skipping, allow_caching)) - - else: - with Pool() as p: - results = p.starmap(eval_policy, - [(params, file, 1, max_steps, action_size, state_size, seed * nb_threads, render, allow_skipping, allow_caching) - for seed in - range(total_nb_eval)]) - - scores = [] - completions = [] - nb_steps = [] - times = [] - step_times = [] - for s, c, n, t, st in results: - scores.append(s) - completions.append(c) - nb_steps.append(n) - times.append(t) - step_times.append(st) - - print("-" * 200) - - print("✅ Score: {:.3f} \tDone: {:.1f}% \tNb steps: {:.3f} \tAgent total: {:.3f}s (per step: {:.3f}s)".format( - np.mean(scores), - np.mean(completions) * 100.0, - np.mean(nb_steps), - np.mean(times), - np.mean(times) / np.mean(nb_steps) - )) - - print("â²ï¸ Agent sum: {:.3f}s \tEnv sum: {:.3f}s \tTotal sum: {:.3f}s".format( - np.sum(times), - np.sum(step_times), - np.sum(times) + np.sum(step_times) - )) - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("-f", "--file", help="checkpoint to load", required=True, type=str) - parser.add_argument("-n", "--n_evaluation_episodes", help="number of evaluation episodes", default=25, type=int) - - # TODO - # parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0, type=int) - - parser.add_argument("--use_gpu", dest="use_gpu", help="use GPU if available", action='store_true') - parser.add_argument("--render", help="render a single episode", action='store_true') - parser.add_argument("--allow_skipping", help="skips to the end of the episode if all agents are deadlocked", action='store_true') - 
parser.add_argument("--allow_caching", help="caches the last observation-action pair", action='store_true') - args = parser.parse_args() - - os.environ["OMP_NUM_THREADS"] = str(1) - evaluate_agents(file=args.file, n_evaluation_episodes=args.n_evaluation_episodes, use_gpu=args.use_gpu, render=args.render, - allow_skipping=args.allow_skipping, allow_caching=args.allow_caching) +import math +import multiprocessing +import os +import sys +from argparse import ArgumentParser, Namespace +from multiprocessing import Pool +from pathlib import Path +from pprint import pprint + +import numpy as np +import torch +from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import sparse_rail_generator +from flatland.envs.schedule_generators import sparse_schedule_generator +from flatland.utils.rendertools import RenderTool + +base_dir = Path(__file__).resolve().parent.parent +sys.path.append(str(base_dir)) + +from utils.deadlock_check import check_if_all_blocked +from utils.timer import Timer +from utils.observation_utils import normalize_observation +from reinforcement_learning.dddqn_policy import DDDQNPolicy + + +def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size, state_size, seed, render, allow_skipping, allow_caching): + # Evaluation is faster on CPU (except if you use a really huge policy) + parameters = { + 'use_gpu': False + } + + policy = DDDQNPolicy(state_size, action_size, Namespace(**parameters), evaluation_mode=True) + policy.qnetwork_local = torch.load(checkpoint) + + env_params = Namespace(**env_params) + + # Environment parameters + n_agents = env_params.n_agents + x_dim = env_params.x_dim + y_dim = env_params.y_dim + n_cities = env_params.n_cities + max_rails_between_cities = env_params.max_rails_between_cities + max_rails_in_city = env_params.max_rails_in_city + + # Malfunction and speed profiles + # TODO pass these parameters properly from main! + malfunction_parameters = MalfunctionParameters( + malfunction_rate=1. / 2000, # Rate of malfunctions + min_duration=20, # Minimal duration + max_duration=50 # Max duration + ) + + # Only fast trains in Round 1 + speed_profiles = { + 1.: 1.0, # Fast passenger train + 1. / 2.: 0.0, # Fast freight train + 1. / 3.: 0.0, # Slow commuter train + 1. 
/ 4.: 0.0 # Slow freight train + } + + # Observation parameters + observation_tree_depth = env_params.observation_tree_depth + observation_radius = env_params.observation_radius + observation_max_path_depth = env_params.observation_max_path_depth + + # Observation builder + predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth) + tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor) + + # Setup the environment + env = RailEnv( + width=x_dim, height=y_dim, + rail_generator=sparse_rail_generator( + max_num_cities=n_cities, + grid_mode=False, + max_rails_between_cities=max_rails_between_cities, + max_rails_in_city=max_rails_in_city, + ), + schedule_generator=sparse_schedule_generator(speed_profiles), + number_of_agents=n_agents, + malfunction_generator_and_process_data=malfunction_from_params(malfunction_parameters), + obs_builder_object=tree_observation + ) + + if render: + env_renderer = RenderTool(env, gl="PGL") + + action_dict = dict() + scores = [] + completions = [] + nb_steps = [] + inference_times = [] + preproc_times = [] + agent_times = [] + step_times = [] + + for episode_idx in range(n_eval_episodes): + seed += 1 + + inference_timer = Timer() + preproc_timer = Timer() + agent_timer = Timer() + step_timer = Timer() + + step_timer.start() + obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True, random_seed=seed) + step_timer.end() + + agent_obs = [None] * env.get_num_agents() + score = 0.0 + + if render: + env_renderer.set_new_rail() + + final_step = 0 + skipped = 0 + + nb_hit = 0 + agent_last_obs = {} + agent_last_action = {} + + for step in range(max_steps - 1): + if allow_skipping and check_if_all_blocked(env): + # FIXME why -1? bug where all agents are "done" after max_steps! 
+ skipped = max_steps - step - 1 + final_step = max_steps - 2 + n_unfinished_agents = sum(not done[idx] for idx in env.get_agent_handles()) + score -= skipped * n_unfinished_agents + break + + agent_timer.start() + for agent in env.get_agent_handles(): + if obs[agent] and info['action_required'][agent]: + if agent in agent_last_obs and np.all(agent_last_obs[agent] == obs[agent]): + nb_hit += 1 + action = agent_last_action[agent] + + else: + preproc_timer.start() + norm_obs = normalize_observation(obs[agent], tree_depth=observation_tree_depth, observation_radius=observation_radius) + preproc_timer.end() + + inference_timer.start() + action = policy.act(norm_obs, eps=0.0) + inference_timer.end() + + action_dict.update({agent: action}) + + if allow_caching: + agent_last_obs[agent] = obs[agent] + agent_last_action[agent] = action + agent_timer.end() + + step_timer.start() + obs, all_rewards, done, info = env.step(action_dict) + step_timer.end() + + if render: + env_renderer.render_env( + show=True, + frames=False, + show_observations=False, + show_predictions=False + ) + + if step % 100 == 0: + print("{}/{}".format(step, max_steps - 1)) + + for agent in env.get_agent_handles(): + score += all_rewards[agent] + + final_step = step + + if done['__all__']: + break + + normalized_score = score / (max_steps * env.get_num_agents()) + scores.append(normalized_score) + + tasks_finished = sum(done[idx] for idx in env.get_agent_handles()) + completion = tasks_finished / max(1, env.get_num_agents()) + completions.append(completion) + + nb_steps.append(final_step) + + inference_times.append(inference_timer.get()) + preproc_times.append(preproc_timer.get()) + agent_times.append(agent_timer.get()) + step_times.append(step_timer.get()) + + skipped_text = "" + if skipped > 0: + skipped_text = "\tâš¡ Skipped {}".format(skipped) + + hit_text = "" + if nb_hit > 0: + hit_text = "\tâš¡ Hit {} ({:.1f}%)".format(nb_hit, (100 * nb_hit) / (n_agents * final_step)) + + print( + "â˜‘ï¸ Score: {:.3f} \tDone: {:.1f}% \tNb steps: {:.3f} " + "\tðŸ Seed: {}" + "\t🚉 Env: {:.3f}s " + "\t🤖 Agent: {:.3f}s (per step: {:.3f}s) \t[preproc: {:.3f}s \tinfer: {:.3f}s]" + "{}{}".format( + normalized_score, + completion * 100.0, + final_step, + seed, + step_timer.get(), + agent_timer.get(), + agent_timer.get() / final_step, + preproc_timer.get(), + inference_timer.get(), + skipped_text, + hit_text + ) + ) + + return scores, completions, nb_steps, agent_times, step_times + + +def evaluate_agents(file, n_evaluation_episodes, use_gpu, render, allow_skipping, allow_caching): + nb_threads = 1 + eval_per_thread = n_evaluation_episodes + + if not render: + nb_threads = multiprocessing.cpu_count() + eval_per_thread = max(1, math.ceil(n_evaluation_episodes / nb_threads)) + + total_nb_eval = eval_per_thread * nb_threads + print("Will evaluate policy {} over {} episodes on {} threads.".format(file, total_nb_eval, nb_threads)) + + if total_nb_eval != n_evaluation_episodes: + print("(Rounding up from {} to fill all cores)".format(n_evaluation_episodes)) + + # Observation parameters need to match the ones used during training! 
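# Quick check of the max_steps formula used further down, plugged with the small_v0
# configuration that follows (25x25 grid, 5 agents, 4 cities); purely illustrative numbers.
x_dim, y_dim, n_agents, n_cities = 25, 25, 5, 4
max_steps = int(4 * 2 * (x_dim + y_dim + (n_agents / n_cities)))
print(max_steps)  # 410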
+ + # small_v0 + small_v0_params = { + # sample configuration + "n_agents": 5, + "x_dim": 25, + "y_dim": 25, + "n_cities": 4, + "max_rails_between_cities": 2, + "max_rails_in_city": 3, + + # observations + "observation_tree_depth": 2, + "observation_radius": 10, + "observation_max_path_depth": 20 + } + + # Test_0 + test0_params = { + # sample configuration + "n_agents": 5, + "x_dim": 25, + "y_dim": 25, + "n_cities": 2, + "max_rails_between_cities": 2, + "max_rails_in_city": 3, + + # observations + "observation_tree_depth": 2, + "observation_radius": 10, + "observation_max_path_depth": 20 + } + + # Test_1 + test1_params = { + # environment + "n_agents": 10, + "x_dim": 30, + "y_dim": 30, + "n_cities": 2, + "max_rails_between_cities": 2, + "max_rails_in_city": 3, + + # observations + "observation_tree_depth": 2, + "observation_radius": 10, + "observation_max_path_depth": 10 + } + + # Test_5 + test5_params = { + # environment + "n_agents": 80, + "x_dim": 35, + "y_dim": 35, + "n_cities": 5, + "max_rails_between_cities": 2, + "max_rails_in_city": 4, + + # observations + "observation_tree_depth": 2, + "observation_radius": 10, + "observation_max_path_depth": 20 + } + + params = small_v0_params + env_params = Namespace(**params) + + print("Environment parameters:") + pprint(params) + + # Calculate space dimensions and max steps + max_steps = int(4 * 2 * (env_params.x_dim + env_params.y_dim + (env_params.n_agents / env_params.n_cities))) + action_size = 5 + tree_observation = TreeObsForRailEnv(max_depth=env_params.observation_tree_depth) + tree_depth = env_params.observation_tree_depth + num_features_per_node = tree_observation.observation_dim + n_nodes = sum([np.power(4, i) for i in range(tree_depth + 1)]) + state_size = num_features_per_node * n_nodes + + results = [] + if render: + results.append(eval_policy(params, file, eval_per_thread, max_steps, action_size, state_size, 0, render, allow_skipping, allow_caching)) + + else: + with Pool() as p: + results = p.starmap(eval_policy, + [(params, file, 1, max_steps, action_size, state_size, seed * nb_threads, render, allow_skipping, allow_caching) + for seed in + range(total_nb_eval)]) + + scores = [] + completions = [] + nb_steps = [] + times = [] + step_times = [] + for s, c, n, t, st in results: + scores.append(s) + completions.append(c) + nb_steps.append(n) + times.append(t) + step_times.append(st) + + print("-" * 200) + + print("✅ Score: {:.3f} \tDone: {:.1f}% \tNb steps: {:.3f} \tAgent total: {:.3f}s (per step: {:.3f}s)".format( + np.mean(scores), + np.mean(completions) * 100.0, + np.mean(nb_steps), + np.mean(times), + np.mean(times) / np.mean(nb_steps) + )) + + print("â²ï¸ Agent sum: {:.3f}s \tEnv sum: {:.3f}s \tTotal sum: {:.3f}s".format( + np.sum(times), + np.sum(step_times), + np.sum(times) + np.sum(step_times) + )) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("-f", "--file", help="checkpoint to load", required=True, type=str) + parser.add_argument("-n", "--n_evaluation_episodes", help="number of evaluation episodes", default=25, type=int) + + # TODO + # parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0, type=int) + + parser.add_argument("--use_gpu", dest="use_gpu", help="use GPU if available", action='store_true') + parser.add_argument("--render", help="render a single episode", action='store_true') + parser.add_argument("--allow_skipping", help="skips to the end of the episode if all agents are deadlocked", action='store_true') + 
parser.add_argument("--allow_caching", help="caches the last observation-action pair", action='store_true') + args = parser.parse_args() + + os.environ["OMP_NUM_THREADS"] = str(1) + evaluate_agents(file=args.file, n_evaluation_episodes=args.n_evaluation_episodes, use_gpu=args.use_gpu, render=args.render, + allow_skipping=args.allow_skipping, allow_caching=args.allow_caching) diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py index b9d1039..74288e7 100755 --- a/reinforcement_learning/multi_agent_training.py +++ b/reinforcement_learning/multi_agent_training.py @@ -1,608 +1,608 @@ -import os -import random -import sys -from argparse import ArgumentParser, Namespace -from collections import deque -from datetime import datetime -from pathlib import Path -from pprint import pprint - -import numpy as np -import psutil -from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters -from flatland.envs.observations import TreeObsForRailEnv -from flatland.envs.predictions import ShortestPathPredictorForRailEnv -from flatland.envs.rail_env import RailEnv, RailEnvActions -from flatland.envs.rail_generators import sparse_rail_generator -from flatland.envs.schedule_generators import sparse_schedule_generator -from flatland.utils.rendertools import RenderTool -from torch.utils.tensorboard import SummaryWriter - -from reinforcement_learning.dddqn_policy import DDDQNPolicy -from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent - -base_dir = Path(__file__).resolve().parent.parent -sys.path.append(str(base_dir)) - -from utils.timer import Timer -from utils.observation_utils import normalize_observation -from utils.fast_tree_obs import FastTreeObs - -try: - import wandb - - wandb.init(sync_tensorboard=True) -except ImportError: - print("Install wandb to log to Weights & Biases") - -""" -This file shows how to train multiple agents using a reinforcement learning approach. -After training an agent, you can submit it straight away to the NeurIPS 2020 Flatland challenge! 
- -Agent documentation: https://flatland.aicrowd.com/getting-started/rl/multi-agent.html -Submission documentation: https://flatland.aicrowd.com/getting-started/first-submission.html -""" - - -def create_rail_env(env_params, tree_observation): - n_agents = env_params.n_agents - x_dim = env_params.x_dim - y_dim = env_params.y_dim - n_cities = env_params.n_cities - max_rails_between_cities = env_params.max_rails_between_cities - max_rails_in_city = env_params.max_rails_in_city - seed = env_params.seed - - # Break agents from time to time - malfunction_parameters = MalfunctionParameters( - malfunction_rate=env_params.malfunction_rate, - min_duration=20, - max_duration=50 - ) - - return RailEnv( - width=x_dim, height=y_dim, - rail_generator=sparse_rail_generator( - max_num_cities=n_cities, - grid_mode=False, - max_rails_between_cities=max_rails_between_cities, - max_rails_in_city=max_rails_in_city - ), - schedule_generator=sparse_schedule_generator(), - number_of_agents=n_agents, - malfunction_generator_and_process_data=malfunction_from_params(malfunction_parameters), - obs_builder_object=tree_observation, - random_seed=seed - ) - - -def train_agent(train_params, train_env_params, eval_env_params, obs_params): - # Environment parameters - n_agents = train_env_params.n_agents - x_dim = train_env_params.x_dim - y_dim = train_env_params.y_dim - n_cities = train_env_params.n_cities - max_rails_between_cities = train_env_params.max_rails_between_cities - max_rails_in_city = train_env_params.max_rails_in_city - seed = train_env_params.seed - - # Unique ID for this training - now = datetime.now() - training_id = now.strftime('%y%m%d%H%M%S') - - # Observation parameters - observation_tree_depth = obs_params.observation_tree_depth - observation_radius = obs_params.observation_radius - observation_max_path_depth = obs_params.observation_max_path_depth - - # Training parameters - eps_start = train_params.eps_start - eps_end = train_params.eps_end - eps_decay = train_params.eps_decay - n_episodes = train_params.n_episodes - checkpoint_interval = train_params.checkpoint_interval - n_eval_episodes = train_params.n_evaluation_episodes - restore_replay_buffer = train_params.restore_replay_buffer - save_replay_buffer = train_params.save_replay_buffer - - # Set the seeds - random.seed(seed) - np.random.seed(seed) - - # Observation builder - predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth) - if not train_params.use_fast_tree_observation: - print("\nUsing standard TreeObs") - - def check_is_observation_valid(observation): - return observation - - def get_normalized_observation(observation, tree_depth: int, observation_radius=0): - return normalize_observation(observation, tree_depth, observation_radius) - - tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor) - tree_observation.check_is_observation_valid = check_is_observation_valid - tree_observation.get_normalized_observation = get_normalized_observation - else: - print("\nUsing FastTreeObs") - - def check_is_observation_valid(observation): - return True - - def get_normalized_observation(observation, tree_depth: int, observation_radius=0): - return observation - - tree_observation = FastTreeObs(max_depth=observation_tree_depth) - tree_observation.check_is_observation_valid = check_is_observation_valid - tree_observation.get_normalized_observation = get_normalized_observation - - # Setup the environments - train_env = create_rail_env(train_env_params, tree_observation) - 
train_env.reset(regenerate_schedule=True, regenerate_rail=True) - eval_env = create_rail_env(eval_env_params, tree_observation) - eval_env.reset(regenerate_schedule=True, regenerate_rail=True) - - if not train_params.use_fast_tree_observation: - # Calculate the state size given the depth of the tree observation and the number of features - n_features_per_node = train_env.obs_builder.observation_dim - n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)]) - state_size = n_features_per_node * n_nodes - else: - # Calculate the state size given the depth of the tree observation and the number of features - state_size = tree_observation.observation_dim - - # Setup renderer - if train_params.render: - env_renderer = RenderTool(train_env, gl="PGL") - - # The action space of flatland is 5 discrete actions - action_size = 5 - - action_count = [0] * action_size - action_dict = dict() - agent_obs = [None] * n_agents - agent_prev_obs = [None] * n_agents - agent_prev_action = [2] * n_agents - update_values = [False] * n_agents - - # Smoothed values used as target for hyperparameter tuning - smoothed_eval_normalized_score = -1.0 - smoothed_eval_completion = 0.0 - - scores_window = deque(maxlen=checkpoint_interval) # todo smooth when rendering instead - completion_window = deque(maxlen=checkpoint_interval) - - # Double Dueling DQN policy - policy = DDDQNPolicy(state_size, action_size, train_params) - # policy = PPOAgent(state_size, action_size, n_agents) - # Load existing policy - if train_params.load_policy is not "": - policy.load(train_params.load_policy) - - # Loads existing replay buffer - if restore_replay_buffer: - try: - policy.load_replay_buffer(restore_replay_buffer) - policy.test() - except RuntimeError as e: - print("\n🛑 Could't load replay buffer, were the experiences generated using the same tree depth?") - print(e) - exit(1) - - print("\n💾 Replay buffer status: {}/{} experiences".format(len(policy.memory.memory), train_params.buffer_size)) - - hdd = psutil.disk_usage('/') - if save_replay_buffer and (hdd.free / (2 ** 30)) < 500.0: - print( - "âš ï¸ Careful! Saving replay buffers will quickly consume a lot of disk space. You have {:.2f}gb left.".format( - hdd.free / (2 ** 30))) - - # TensorBoard writer - writer = SummaryWriter() - - training_timer = Timer() - training_timer.start() - - print( - "\n🚉 Training {} trains on {}x{} grid for {} episodes, evaluating on {} episodes every {} episodes. 
Training id '{}'.\n".format( - train_env.get_num_agents(), - x_dim, y_dim, - n_episodes, - n_eval_episodes, - checkpoint_interval, - training_id - )) - - for episode_idx in range(n_episodes + 1): - step_timer = Timer() - reset_timer = Timer() - learn_timer = Timer() - preproc_timer = Timer() - inference_timer = Timer() - - # Reset environment - reset_timer.start() - train_env_params.n_agents = episode_idx % n_agents + 1 - train_env = create_rail_env(train_env_params, tree_observation) - obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True) - policy.reset() - - policy2 = DeadLockAvoidanceAgent(train_env) - policy2.reset() - - reset_timer.end() - - if train_params.render: - env_renderer.set_new_rail() - - score = 0 - nb_steps = 0 - actions_taken = [] - - # Build initial agent-specific observations - for agent in train_env.get_agent_handles(): - if tree_observation.check_is_observation_valid(obs[agent]): - agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], observation_tree_depth, - observation_radius=observation_radius) - agent_prev_obs[agent] = agent_obs[agent].copy() - - # Max number of steps per episode - # This is the official formula used during evaluations - # See details in flatland.envs.schedule_generators.sparse_schedule_generator - # max_steps = int(4 * 2 * (env.height + env.width + (n_agents / n_cities))) - max_steps = train_env._max_episode_steps - - # Run episode - agent_to_learn = 0 - if train_env.get_num_agents() > 1: - agent_to_learn = np.random.choice(train_env.get_num_agents()) - for step in range(max_steps - 1): - inference_timer.start() - policy.start_step() - policy2.start_step() - for agent in train_env.get_agent_handles(): - if info['action_required'][agent]: - update_values[agent] = True - - if agent == agent_to_learn: - action = policy.act(agent_obs[agent], eps=eps_start) - else: - action = policy2.act([agent], eps=eps_start) - action_count[action] += 1 - actions_taken.append(action) - else: - # An action is not required if the train hasn't joined the railway network, - # if it already reached its target, or if is currently malfunctioning. 
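# The same action_required gating in isolation, with stand-ins for the environment and the
# policy (the dict and the function below are mocks; only the control flow mirrors the loop):
# agents that do not require an action simply receive DO_NOTHING (action 0).
import random

action_required = {0: True, 1: False, 2: True}   # mock of info['action_required']

def act(handle):                                 # stand-in for policy.act(obs, eps)
    return random.randrange(5)

DO_NOTHING = 0                                   # RailEnvActions.DO_NOTHING

action_dict = {handle: (act(handle) if required else DO_NOTHING)
               for handle, required in action_required.items()}
print(action_dict)                               # e.g. {0: 3, 1: 0, 2: 1}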
- update_values[agent] = False - action = 0 - action_dict.update({agent: action}) - policy.end_step() - policy2.end_step() - inference_timer.end() - - # Environment step - step_timer.start() - next_obs, all_rewards, done, info = train_env.step(action_dict) - - for agent in train_env.get_agent_handles(): - act = action_dict.get(agent, RailEnvActions.DO_NOTHING) - if agent_obs[agent][26] == 1: - if act == RailEnvActions.STOP_MOVING: - all_rewards[agent] *= 0.01 - else: - if act == RailEnvActions.MOVE_LEFT: - all_rewards[agent] *= 0.9 - else: - if agent_obs[agent][7] == 0 and agent_obs[agent][8] == 0: - if act == RailEnvActions.MOVE_FORWARD: - all_rewards[agent] *= 0.01 - if done[agent]: - all_rewards[agent] += 100.0 - - step_timer.end() - - # Render an episode at some interval - if train_params.render and episode_idx % checkpoint_interval == 0: - env_renderer.render_env( - show=True, - frames=False, - show_observations=False, - show_predictions=False - ) - - # Update replay buffer and train agent - for agent in train_env.get_agent_handles(): - if update_values[agent] or done['__all__']: - # Only learn from timesteps where somethings happened - learn_timer.start() - if agent == agent_to_learn: - policy.step(agent, - agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent], - agent_obs[agent], - done[agent]) - learn_timer.end() - - agent_prev_obs[agent] = agent_obs[agent].copy() - agent_prev_action[agent] = action_dict[agent] - - # Preprocess the new observations - if tree_observation.check_is_observation_valid(next_obs[agent]): - preproc_timer.start() - agent_obs[agent] = tree_observation.get_normalized_observation(next_obs[agent], - observation_tree_depth, - observation_radius=observation_radius) - preproc_timer.end() - - score += all_rewards[agent] - - nb_steps = step - - if done['__all__']: - break - - # Epsilon decay - eps_start = max(eps_end, eps_decay * eps_start) - - # Collect information about training - tasks_finished = sum(done[idx] for idx in train_env.get_agent_handles()) - completion = tasks_finished / max(1, train_env.get_num_agents()) - normalized_score = score / (max_steps * train_env.get_num_agents()) - action_probs = action_count / max(1, np.sum(action_count)) - - scores_window.append(normalized_score) - completion_window.append(completion) - smoothed_normalized_score = np.mean(scores_window) - smoothed_completion = np.mean(completion_window) - - # Print logs - if episode_idx % checkpoint_interval == 0: - policy.save('./checkpoints/' + training_id + '-' + str(episode_idx) + '.pth') - - if save_replay_buffer: - policy.save_replay_buffer('./replay_buffers/' + training_id + '-' + str(episode_idx) + '.pkl') - - if train_params.render: - env_renderer.close_window() - - # reset action count - action_count = [0] * action_size - - print( - '\r🚂 Episode {}' - '\t 🆠Score: {:7.3f}' - ' Avg: {:7.3f}' - '\t 💯 Done: {:6.2f}%' - ' Avg: {:6.2f}%' - '\t 🎲 Epsilon: {:.3f} ' - '\t 🔀 Action Probs: {}'.format( - episode_idx, - normalized_score, - smoothed_normalized_score, - 100 * completion, - 100 * smoothed_completion, - eps_start, - format_action_prob(action_probs) - ), end=" ") - - # Evaluate policy and log results at some interval - if episode_idx % checkpoint_interval == 0 and n_eval_episodes > 0: - scores, completions, nb_steps_eval = eval_policy(eval_env, - tree_observation, - policy, - train_params, - obs_params) - - writer.add_scalar("evaluation/scores_min", np.min(scores), episode_idx) - writer.add_scalar("evaluation/scores_max", np.max(scores), episode_idx) - 
writer.add_scalar("evaluation/scores_mean", np.mean(scores), episode_idx) - writer.add_scalar("evaluation/scores_std", np.std(scores), episode_idx) - writer.add_histogram("evaluation/scores", np.array(scores), episode_idx) - writer.add_scalar("evaluation/completions_min", np.min(completions), episode_idx) - writer.add_scalar("evaluation/completions_max", np.max(completions), episode_idx) - writer.add_scalar("evaluation/completions_mean", np.mean(completions), episode_idx) - writer.add_scalar("evaluation/completions_std", np.std(completions), episode_idx) - writer.add_histogram("evaluation/completions", np.array(completions), episode_idx) - writer.add_scalar("evaluation/nb_steps_min", np.min(nb_steps_eval), episode_idx) - writer.add_scalar("evaluation/nb_steps_max", np.max(nb_steps_eval), episode_idx) - writer.add_scalar("evaluation/nb_steps_mean", np.mean(nb_steps_eval), episode_idx) - writer.add_scalar("evaluation/nb_steps_std", np.std(nb_steps_eval), episode_idx) - writer.add_histogram("evaluation/nb_steps", np.array(nb_steps_eval), episode_idx) - - smoothing = 0.9 - smoothed_eval_normalized_score = smoothed_eval_normalized_score * smoothing + np.mean(scores) * ( - 1.0 - smoothing) - smoothed_eval_completion = smoothed_eval_completion * smoothing + np.mean(completions) * (1.0 - smoothing) - writer.add_scalar("evaluation/smoothed_score", smoothed_eval_normalized_score, episode_idx) - writer.add_scalar("evaluation/smoothed_completion", smoothed_eval_completion, episode_idx) - - # Save logs to tensorboard - writer.add_scalar("training/score", normalized_score, episode_idx) - writer.add_scalar("training/smoothed_score", smoothed_normalized_score, episode_idx) - writer.add_scalar("training/completion", np.mean(completion), episode_idx) - writer.add_scalar("training/smoothed_completion", np.mean(smoothed_completion), episode_idx) - writer.add_scalar("training/nb_steps", nb_steps, episode_idx) - writer.add_histogram("actions/distribution", np.array(actions_taken), episode_idx) - writer.add_scalar("actions/nothing", action_probs[RailEnvActions.DO_NOTHING], episode_idx) - writer.add_scalar("actions/left", action_probs[RailEnvActions.MOVE_LEFT], episode_idx) - writer.add_scalar("actions/forward", action_probs[RailEnvActions.MOVE_FORWARD], episode_idx) - writer.add_scalar("actions/right", action_probs[RailEnvActions.MOVE_RIGHT], episode_idx) - writer.add_scalar("actions/stop", action_probs[RailEnvActions.STOP_MOVING], episode_idx) - writer.add_scalar("training/epsilon", eps_start, episode_idx) - writer.add_scalar("training/buffer_size", len(policy.memory), episode_idx) - writer.add_scalar("training/loss", policy.loss, episode_idx) - writer.add_scalar("timer/reset", reset_timer.get(), episode_idx) - writer.add_scalar("timer/step", step_timer.get(), episode_idx) - writer.add_scalar("timer/learn", learn_timer.get(), episode_idx) - writer.add_scalar("timer/preproc", preproc_timer.get(), episode_idx) - writer.add_scalar("timer/total", training_timer.get_current(), episode_idx) - - -def format_action_prob(action_probs): - action_probs = np.round(action_probs, 3) - actions = ["↻", "â†", "↑", "→", "â—¼"] - - buffer = "" - for action, action_prob in zip(actions, action_probs): - buffer += action + " " + "{:.3f}".format(action_prob) + " " - - return buffer - - -def eval_policy(env, tree_observation, policy, train_params, obs_params): - n_eval_episodes = train_params.n_evaluation_episodes - max_steps = env._max_episode_steps - tree_depth = obs_params.observation_tree_depth - observation_radius = 
obs_params.observation_radius - - action_dict = dict() - scores = [] - completions = [] - nb_steps = [] - - for episode_idx in range(n_eval_episodes): - agent_obs = [None] * env.get_num_agents() - score = 0.0 - - obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True) - final_step = 0 - - for step in range(max_steps - 1): - policy.start_step() - for agent in env.get_agent_handles(): - if tree_observation.check_is_observation_valid(agent_obs[agent]): - agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], tree_depth=tree_depth, - observation_radius=observation_radius) - - action = 0 - if info['action_required'][agent]: - if tree_observation.check_is_observation_valid(agent_obs[agent]): - action = policy.act(agent_obs[agent], eps=0.0) - action_dict.update({agent: action}) - policy.end_step() - obs, all_rewards, done, info = env.step(action_dict) - - for agent in env.get_agent_handles(): - score += all_rewards[agent] - - final_step = step - - if done['__all__']: - break - - normalized_score = score / (max_steps * env.get_num_agents()) - scores.append(normalized_score) - - tasks_finished = sum(done[idx] for idx in env.get_agent_handles()) - completion = tasks_finished / max(1, env.get_num_agents()) - completions.append(completion) - - nb_steps.append(final_step) - - print("\t✅ Eval: score {:.3f} done {:.1f}%".format(np.mean(scores), np.mean(completions) * 100.0)) - - return scores, completions, nb_steps - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=5400, type=int) - parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1, type=int) - parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0, - type=int) - parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=1, type=int) - parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int) - parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float) - parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float) - parser.add_argument("--eps_decay", help="exploration decay", default=0.9998, type=float) - parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e7), type=int) - parser.add_argument("--buffer_min_size", help="min buffer size to start training", default=0, type=int) - parser.add_argument("--restore_replay_buffer", help="replay buffer to restore", default="", type=str) - parser.add_argument("--save_replay_buffer", help="save replay buffer at each evaluation interval", default=False, - type=bool) - parser.add_argument("--batch_size", help="minibatch size", default=128, type=int) - parser.add_argument("--gamma", help="discount factor", default=0.99, type=float) - parser.add_argument("--tau", help="soft update of target parameters", default=1e-3, type=float) - parser.add_argument("--learning_rate", help="learning rate", default=0.5e-4, type=float) - parser.add_argument("--hidden_size", help="hidden size (2 fc layers)", default=128, type=int) - parser.add_argument("--update_every", help="how often to update the network", default=8, type=int) - parser.add_argument("--use_gpu", help="use GPU if available", default=False, type=bool) - parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=1, type=int) - parser.add_argument("--render", 
help="render 1 episode in 100", action='store_true') - parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str) - parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs", - action='store_true') - parser.add_argument("--max_depth", help="max depth", default=1, type=int) - - training_params = parser.parse_args() - env_params = [ - { - # Test_0 - "n_agents": 5, - "x_dim": 25, - "y_dim": 25, - "n_cities": 2, - "max_rails_between_cities": 2, - "max_rails_in_city": 3, - "malfunction_rate": 1 / 50, - "seed": 0 - }, - { - # Test_1 - "n_agents": 10, - "x_dim": 30, - "y_dim": 30, - "n_cities": 2, - "max_rails_between_cities": 2, - "max_rails_in_city": 3, - "malfunction_rate": 1 / 100, - "seed": 0 - }, - { - # Test_2 - "n_agents": 20, - "x_dim": 30, - "y_dim": 30, - "n_cities": 3, - "max_rails_between_cities": 2, - "max_rails_in_city": 3, - "malfunction_rate": 1 / 200, - "seed": 0 - }, - ] - - obs_params = { - "observation_tree_depth": training_params.max_depth, - "observation_radius": 10, - "observation_max_path_depth": 30 - } - - - def check_env_config(id): - if id >= len(env_params) or id < 0: - print("\n🛑 Invalid environment configuration, only Test_0 to Test_{} are supported.".format( - len(env_params) - 1)) - exit(1) - - - check_env_config(training_params.training_env_config) - check_env_config(training_params.evaluation_env_config) - - training_env_params = env_params[training_params.training_env_config] - evaluation_env_params = env_params[training_params.evaluation_env_config] - - # FIXME hard-coded for sweep search - # see https://wb-forum.slack.com/archives/CL4V2QE59/p1602931982236600 to implement properly - # training_params.use_fast_tree_observation = True - - print("\nTraining parameters:") - pprint(vars(training_params)) - print("\nTraining environment parameters (Test_{}):".format(training_params.training_env_config)) - pprint(training_env_params) - print("\nEvaluation environment parameters (Test_{}):".format(training_params.evaluation_env_config)) - pprint(evaluation_env_params) - print("\nObservation parameters:") - pprint(obs_params) - - os.environ["OMP_NUM_THREADS"] = str(training_params.num_threads) - train_agent(training_params, Namespace(**training_env_params), Namespace(**evaluation_env_params), - Namespace(**obs_params)) +import os +import random +import sys +from argparse import ArgumentParser, Namespace +from collections import deque +from datetime import datetime +from pathlib import Path +from pprint import pprint + +import numpy as np +import psutil +from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.rail_env import RailEnv, RailEnvActions +from flatland.envs.rail_generators import sparse_rail_generator +from flatland.envs.schedule_generators import sparse_schedule_generator +from flatland.utils.rendertools import RenderTool +from torch.utils.tensorboard import SummaryWriter + +from reinforcement_learning.dddqn_policy import DDDQNPolicy +from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent + +base_dir = Path(__file__).resolve().parent.parent +sys.path.append(str(base_dir)) + +from utils.timer import Timer +from utils.observation_utils import normalize_observation +from utils.fast_tree_obs import FastTreeObs + +try: + import wandb + + wandb.init(sync_tensorboard=True) +except 
ImportError: + print("Install wandb to log to Weights & Biases") + +""" +This file shows how to train multiple agents using a reinforcement learning approach. +After training an agent, you can submit it straight away to the NeurIPS 2020 Flatland challenge! + +Agent documentation: https://flatland.aicrowd.com/getting-started/rl/multi-agent.html +Submission documentation: https://flatland.aicrowd.com/getting-started/first-submission.html +""" + + +def create_rail_env(env_params, tree_observation): + n_agents = env_params.n_agents + x_dim = env_params.x_dim + y_dim = env_params.y_dim + n_cities = env_params.n_cities + max_rails_between_cities = env_params.max_rails_between_cities + max_rails_in_city = env_params.max_rails_in_city + seed = env_params.seed + + # Break agents from time to time + malfunction_parameters = MalfunctionParameters( + malfunction_rate=env_params.malfunction_rate, + min_duration=20, + max_duration=50 + ) + + return RailEnv( + width=x_dim, height=y_dim, + rail_generator=sparse_rail_generator( + max_num_cities=n_cities, + grid_mode=False, + max_rails_between_cities=max_rails_between_cities, + max_rails_in_city=max_rails_in_city + ), + schedule_generator=sparse_schedule_generator(), + number_of_agents=n_agents, + malfunction_generator_and_process_data=malfunction_from_params(malfunction_parameters), + obs_builder_object=tree_observation, + random_seed=seed + ) + + +def train_agent(train_params, train_env_params, eval_env_params, obs_params): + # Environment parameters + n_agents = train_env_params.n_agents + x_dim = train_env_params.x_dim + y_dim = train_env_params.y_dim + n_cities = train_env_params.n_cities + max_rails_between_cities = train_env_params.max_rails_between_cities + max_rails_in_city = train_env_params.max_rails_in_city + seed = train_env_params.seed + + # Unique ID for this training + now = datetime.now() + training_id = now.strftime('%y%m%d%H%M%S') + + # Observation parameters + observation_tree_depth = obs_params.observation_tree_depth + observation_radius = obs_params.observation_radius + observation_max_path_depth = obs_params.observation_max_path_depth + + # Training parameters + eps_start = train_params.eps_start + eps_end = train_params.eps_end + eps_decay = train_params.eps_decay + n_episodes = train_params.n_episodes + checkpoint_interval = train_params.checkpoint_interval + n_eval_episodes = train_params.n_evaluation_episodes + restore_replay_buffer = train_params.restore_replay_buffer + save_replay_buffer = train_params.save_replay_buffer + + # Set the seeds + random.seed(seed) + np.random.seed(seed) + + # Observation builder + predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth) + if not train_params.use_fast_tree_observation: + print("\nUsing standard TreeObs") + + def check_is_observation_valid(observation): + return observation + + def get_normalized_observation(observation, tree_depth: int, observation_radius=0): + return normalize_observation(observation, tree_depth, observation_radius) + + tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor) + tree_observation.check_is_observation_valid = check_is_observation_valid + tree_observation.get_normalized_observation = get_normalized_observation + else: + print("\nUsing FastTreeObs") + + def check_is_observation_valid(observation): + return True + + def get_normalized_observation(observation, tree_depth: int, observation_radius=0): + return observation + + tree_observation = FastTreeObs(max_depth=observation_tree_depth) + 
tree_observation.check_is_observation_valid = check_is_observation_valid + tree_observation.get_normalized_observation = get_normalized_observation + + # Setup the environments + train_env = create_rail_env(train_env_params, tree_observation) + train_env.reset(regenerate_schedule=True, regenerate_rail=True) + eval_env = create_rail_env(eval_env_params, tree_observation) + eval_env.reset(regenerate_schedule=True, regenerate_rail=True) + + if not train_params.use_fast_tree_observation: + # Calculate the state size given the depth of the tree observation and the number of features + n_features_per_node = train_env.obs_builder.observation_dim + n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)]) + state_size = n_features_per_node * n_nodes + else: + # Calculate the state size given the depth of the tree observation and the number of features + state_size = tree_observation.observation_dim + + # Setup renderer + if train_params.render: + env_renderer = RenderTool(train_env, gl="PGL") + + # The action space of flatland is 5 discrete actions + action_size = 5 + + action_count = [0] * action_size + action_dict = dict() + agent_obs = [None] * n_agents + agent_prev_obs = [None] * n_agents + agent_prev_action = [2] * n_agents + update_values = [False] * n_agents + + # Smoothed values used as target for hyperparameter tuning + smoothed_eval_normalized_score = -1.0 + smoothed_eval_completion = 0.0 + + scores_window = deque(maxlen=checkpoint_interval)  # todo smooth when rendering instead + completion_window = deque(maxlen=checkpoint_interval) + + # Double Dueling DQN policy + policy = DDDQNPolicy(state_size, action_size, train_params) + # policy = PPOAgent(state_size, action_size, n_agents) + # Load existing policy + if train_params.load_policy != "": + policy.load(train_params.load_policy) + + # Loads existing replay buffer + if restore_replay_buffer: + try: + policy.load_replay_buffer(restore_replay_buffer) + policy.test() + except RuntimeError as e: + print("\n🛑 Couldn't load replay buffer, were the experiences generated using the same tree depth?") + print(e) + exit(1) + + print("\n💾 Replay buffer status: {}/{} experiences".format(len(policy.memory.memory), train_params.buffer_size)) + + hdd = psutil.disk_usage('/') + if save_replay_buffer and (hdd.free / (2 ** 30)) < 500.0: + print( + "⚠️ Careful! Saving replay buffers will quickly consume a lot of disk space. You have {:.2f} GB left.".format( + hdd.free / (2 ** 30))) + + # TensorBoard writer + writer = SummaryWriter() + + training_timer = Timer() + training_timer.start() + + print( + "\n🚉 Training {} trains on {}x{} grid for {} episodes, evaluating on {} episodes every {} episodes. 
Training id '{}'.\n".format( + train_env.get_num_agents(), + x_dim, y_dim, + n_episodes, + n_eval_episodes, + checkpoint_interval, + training_id + )) + + for episode_idx in range(n_episodes + 1): + step_timer = Timer() + reset_timer = Timer() + learn_timer = Timer() + preproc_timer = Timer() + inference_timer = Timer() + + # Reset environment + reset_timer.start() + train_env_params.n_agents = episode_idx % n_agents + 1 + train_env = create_rail_env(train_env_params, tree_observation) + obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True) + policy.reset() + + policy2 = DeadLockAvoidanceAgent(train_env) + policy2.reset() + + reset_timer.end() + + if train_params.render: + env_renderer.set_new_rail() + + score = 0 + nb_steps = 0 + actions_taken = [] + + # Build initial agent-specific observations + for agent in train_env.get_agent_handles(): + if tree_observation.check_is_observation_valid(obs[agent]): + agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], observation_tree_depth, + observation_radius=observation_radius) + agent_prev_obs[agent] = agent_obs[agent].copy() + + # Max number of steps per episode + # This is the official formula used during evaluations + # See details in flatland.envs.schedule_generators.sparse_schedule_generator + # max_steps = int(4 * 2 * (env.height + env.width + (n_agents / n_cities))) + max_steps = train_env._max_episode_steps + + # Run episode + agent_to_learn = 0 + if train_env.get_num_agents() > 1: + agent_to_learn = np.random.choice(train_env.get_num_agents()) + for step in range(max_steps - 1): + inference_timer.start() + policy.start_step() + policy2.start_step() + for agent in train_env.get_agent_handles(): + if info['action_required'][agent]: + update_values[agent] = True + + if agent == agent_to_learn: + action = policy.act(agent_obs[agent], eps=eps_start) + else: + action = policy2.act([agent], eps=eps_start) + action_count[action] += 1 + actions_taken.append(action) + else: + # An action is not required if the train hasn't joined the railway network, + # if it already reached its target, or if it is currently malfunctioning. 
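+ # No action needed: default to 0 (RailEnvActions.DO_NOTHING) and skip the learning update for this agent on this step.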
+ update_values[agent] = False + action = 0 + action_dict.update({agent: action}) + policy.end_step() + policy2.end_step() + inference_timer.end() + + # Environment step + step_timer.start() + next_obs, all_rewards, done, info = train_env.step(action_dict) + + # Reward shaping: scale the environment reward using FastTreeObs feature flags (observation entries 7, 8 and 26) and add a bonus for agents that reach their target. + for agent in train_env.get_agent_handles(): + act = action_dict.get(agent, RailEnvActions.DO_NOTHING) + if agent_obs[agent][26] == 1: + if act == RailEnvActions.STOP_MOVING: + all_rewards[agent] *= 0.01 + else: + if act == RailEnvActions.MOVE_LEFT: + all_rewards[agent] *= 0.9 + else: + if agent_obs[agent][7] == 0 and agent_obs[agent][8] == 0: + if act == RailEnvActions.MOVE_FORWARD: + all_rewards[agent] *= 0.01 + if done[agent]: + all_rewards[agent] += 100.0 + + step_timer.end() + + # Render an episode at some interval + if train_params.render and episode_idx % checkpoint_interval == 0: + env_renderer.render_env( + show=True, + frames=False, + show_observations=False, + show_predictions=False + ) + + # Update replay buffer and train agent + for agent in train_env.get_agent_handles(): + if update_values[agent] or done['__all__']: + # Only learn from timesteps where something happened + learn_timer.start() + if agent == agent_to_learn: + policy.step(agent, + agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent], + agent_obs[agent], + done[agent]) + learn_timer.end() + + agent_prev_obs[agent] = agent_obs[agent].copy() + agent_prev_action[agent] = action_dict[agent] + + # Preprocess the new observations + if tree_observation.check_is_observation_valid(next_obs[agent]): + preproc_timer.start() + agent_obs[agent] = tree_observation.get_normalized_observation(next_obs[agent], + observation_tree_depth, + observation_radius=observation_radius) + preproc_timer.end() + + score += all_rewards[agent] + + nb_steps = step + + if done['__all__']: + break + + # Epsilon decay + eps_start = max(eps_end, eps_decay * eps_start) + + # Collect information about training + tasks_finished = sum(done[idx] for idx in train_env.get_agent_handles()) + completion = tasks_finished / max(1, train_env.get_num_agents()) + normalized_score = score / (max_steps * train_env.get_num_agents()) + action_probs = action_count / max(1, np.sum(action_count)) + + scores_window.append(normalized_score) + completion_window.append(completion) + smoothed_normalized_score = np.mean(scores_window) + smoothed_completion = np.mean(completion_window) + + # Print logs + if episode_idx % checkpoint_interval == 0: + policy.save('./checkpoints/' + training_id + '-' + str(episode_idx) + '.pth') + + if save_replay_buffer: + policy.save_replay_buffer('./replay_buffers/' + training_id + '-' + str(episode_idx) + '.pkl') + + if train_params.render: + env_renderer.close_window() + + # reset action count + action_count = [0] * action_size + + print( + '\r🚂 Episode {}' + '\t 🏆 Score: {:7.3f}' + ' Avg: {:7.3f}' + '\t 💯 Done: {:6.2f}%' + ' Avg: {:6.2f}%' + '\t 🎲 Epsilon: {:.3f} ' + '\t 🔀 Action Probs: {}'.format( + episode_idx, + normalized_score, + smoothed_normalized_score, + 100 * completion, + 100 * smoothed_completion, + eps_start, + format_action_prob(action_probs) + ), end=" ") + + # Evaluate policy and log results at some interval + if episode_idx % checkpoint_interval == 0 and n_eval_episodes > 0: + scores, completions, nb_steps_eval = eval_policy(eval_env, + tree_observation, + policy, + train_params, + obs_params) + + writer.add_scalar("evaluation/scores_min", np.min(scores), episode_idx) + writer.add_scalar("evaluation/scores_max", np.max(scores), episode_idx) + 
writer.add_scalar("evaluation/scores_mean", np.mean(scores), episode_idx) + writer.add_scalar("evaluation/scores_std", np.std(scores), episode_idx) + writer.add_histogram("evaluation/scores", np.array(scores), episode_idx) + writer.add_scalar("evaluation/completions_min", np.min(completions), episode_idx) + writer.add_scalar("evaluation/completions_max", np.max(completions), episode_idx) + writer.add_scalar("evaluation/completions_mean", np.mean(completions), episode_idx) + writer.add_scalar("evaluation/completions_std", np.std(completions), episode_idx) + writer.add_histogram("evaluation/completions", np.array(completions), episode_idx) + writer.add_scalar("evaluation/nb_steps_min", np.min(nb_steps_eval), episode_idx) + writer.add_scalar("evaluation/nb_steps_max", np.max(nb_steps_eval), episode_idx) + writer.add_scalar("evaluation/nb_steps_mean", np.mean(nb_steps_eval), episode_idx) + writer.add_scalar("evaluation/nb_steps_std", np.std(nb_steps_eval), episode_idx) + writer.add_histogram("evaluation/nb_steps", np.array(nb_steps_eval), episode_idx) + + smoothing = 0.9 + smoothed_eval_normalized_score = smoothed_eval_normalized_score * smoothing + np.mean(scores) * ( + 1.0 - smoothing) + smoothed_eval_completion = smoothed_eval_completion * smoothing + np.mean(completions) * (1.0 - smoothing) + writer.add_scalar("evaluation/smoothed_score", smoothed_eval_normalized_score, episode_idx) + writer.add_scalar("evaluation/smoothed_completion", smoothed_eval_completion, episode_idx) + + # Save logs to tensorboard + writer.add_scalar("training/score", normalized_score, episode_idx) + writer.add_scalar("training/smoothed_score", smoothed_normalized_score, episode_idx) + writer.add_scalar("training/completion", np.mean(completion), episode_idx) + writer.add_scalar("training/smoothed_completion", np.mean(smoothed_completion), episode_idx) + writer.add_scalar("training/nb_steps", nb_steps, episode_idx) + writer.add_histogram("actions/distribution", np.array(actions_taken), episode_idx) + writer.add_scalar("actions/nothing", action_probs[RailEnvActions.DO_NOTHING], episode_idx) + writer.add_scalar("actions/left", action_probs[RailEnvActions.MOVE_LEFT], episode_idx) + writer.add_scalar("actions/forward", action_probs[RailEnvActions.MOVE_FORWARD], episode_idx) + writer.add_scalar("actions/right", action_probs[RailEnvActions.MOVE_RIGHT], episode_idx) + writer.add_scalar("actions/stop", action_probs[RailEnvActions.STOP_MOVING], episode_idx) + writer.add_scalar("training/epsilon", eps_start, episode_idx) + writer.add_scalar("training/buffer_size", len(policy.memory), episode_idx) + writer.add_scalar("training/loss", policy.loss, episode_idx) + writer.add_scalar("timer/reset", reset_timer.get(), episode_idx) + writer.add_scalar("timer/step", step_timer.get(), episode_idx) + writer.add_scalar("timer/learn", learn_timer.get(), episode_idx) + writer.add_scalar("timer/preproc", preproc_timer.get(), episode_idx) + writer.add_scalar("timer/total", training_timer.get_current(), episode_idx) + + +def format_action_prob(action_probs): + action_probs = np.round(action_probs, 3) + actions = ["↻", "â†", "↑", "→", "â—¼"] + + buffer = "" + for action, action_prob in zip(actions, action_probs): + buffer += action + " " + "{:.3f}".format(action_prob) + " " + + return buffer + + +def eval_policy(env, tree_observation, policy, train_params, obs_params): + n_eval_episodes = train_params.n_evaluation_episodes + max_steps = env._max_episode_steps + tree_depth = obs_params.observation_tree_depth + observation_radius = 
obs_params.observation_radius + + action_dict = dict() + scores = [] + completions = [] + nb_steps = [] + + for episode_idx in range(n_eval_episodes): + agent_obs = [None] * env.get_num_agents() + score = 0.0 + + obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True) + final_step = 0 + + for step in range(max_steps - 1): + policy.start_step() + for agent in env.get_agent_handles(): + if tree_observation.check_is_observation_valid(agent_obs[agent]): + agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], tree_depth=tree_depth, + observation_radius=observation_radius) + + action = 0 + if info['action_required'][agent]: + if tree_observation.check_is_observation_valid(agent_obs[agent]): + action = policy.act(agent_obs[agent], eps=0.0) + action_dict.update({agent: action}) + policy.end_step() + obs, all_rewards, done, info = env.step(action_dict) + + for agent in env.get_agent_handles(): + score += all_rewards[agent] + + final_step = step + + if done['__all__']: + break + + normalized_score = score / (max_steps * env.get_num_agents()) + scores.append(normalized_score) + + tasks_finished = sum(done[idx] for idx in env.get_agent_handles()) + completion = tasks_finished / max(1, env.get_num_agents()) + completions.append(completion) + + nb_steps.append(final_step) + + print("\t✅ Eval: score {:.3f} done {:.1f}%".format(np.mean(scores), np.mean(completions) * 100.0)) + + return scores, completions, nb_steps + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=5400, type=int) + parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1, type=int) + parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0, + type=int) + parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=1, type=int) + parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int) + parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float) + parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float) + parser.add_argument("--eps_decay", help="exploration decay", default=0.9998, type=float) + parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e7), type=int) + parser.add_argument("--buffer_min_size", help="min buffer size to start training", default=0, type=int) + parser.add_argument("--restore_replay_buffer", help="replay buffer to restore", default="", type=str) + parser.add_argument("--save_replay_buffer", help="save replay buffer at each evaluation interval", default=False, + type=bool) + parser.add_argument("--batch_size", help="minibatch size", default=128, type=int) + parser.add_argument("--gamma", help="discount factor", default=0.99, type=float) + parser.add_argument("--tau", help="soft update of target parameters", default=1e-3, type=float) + parser.add_argument("--learning_rate", help="learning rate", default=0.5e-4, type=float) + parser.add_argument("--hidden_size", help="hidden size (2 fc layers)", default=128, type=int) + parser.add_argument("--update_every", help="how often to update the network", default=8, type=int) + parser.add_argument("--use_gpu", help="use GPU if available", default=False, type=bool) + parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=1, type=int) + parser.add_argument("--render", 
help="render 1 episode in 100", action='store_true') + parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str) + parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs", + action='store_true') + parser.add_argument("--max_depth", help="max depth", default=1, type=int) + + training_params = parser.parse_args() + env_params = [ + { + # Test_0 + "n_agents": 5, + "x_dim": 25, + "y_dim": 25, + "n_cities": 2, + "max_rails_between_cities": 2, + "max_rails_in_city": 3, + "malfunction_rate": 1 / 50, + "seed": 0 + }, + { + # Test_1 + "n_agents": 10, + "x_dim": 30, + "y_dim": 30, + "n_cities": 2, + "max_rails_between_cities": 2, + "max_rails_in_city": 3, + "malfunction_rate": 1 / 100, + "seed": 0 + }, + { + # Test_2 + "n_agents": 20, + "x_dim": 30, + "y_dim": 30, + "n_cities": 3, + "max_rails_between_cities": 2, + "max_rails_in_city": 3, + "malfunction_rate": 1 / 200, + "seed": 0 + }, + ] + + obs_params = { + "observation_tree_depth": training_params.max_depth, + "observation_radius": 10, + "observation_max_path_depth": 30 + } + + + def check_env_config(id): + if id >= len(env_params) or id < 0: + print("\n🛑 Invalid environment configuration, only Test_0 to Test_{} are supported.".format( + len(env_params) - 1)) + exit(1) + + + check_env_config(training_params.training_env_config) + check_env_config(training_params.evaluation_env_config) + + training_env_params = env_params[training_params.training_env_config] + evaluation_env_params = env_params[training_params.evaluation_env_config] + + # FIXME hard-coded for sweep search + # see https://wb-forum.slack.com/archives/CL4V2QE59/p1602931982236600 to implement properly + # training_params.use_fast_tree_observation = True + + print("\nTraining parameters:") + pprint(vars(training_params)) + print("\nTraining environment parameters (Test_{}):".format(training_params.training_env_config)) + pprint(training_env_params) + print("\nEvaluation environment parameters (Test_{}):".format(training_params.evaluation_env_config)) + pprint(evaluation_env_params) + print("\nObservation parameters:") + pprint(obs_params) + + os.environ["OMP_NUM_THREADS"] = str(training_params.num_threads) + train_agent(training_params, Namespace(**training_env_params), Namespace(**evaluation_env_params), + Namespace(**obs_params)) diff --git a/reinforcement_learning/policy.py b/reinforcement_learning/policy.py index b605aa3..c7621a6 100644 --- a/reinforcement_learning/policy.py +++ b/reinforcement_learning/policy.py @@ -1,27 +1,27 @@ -class Policy: - def step(self, handle, state, action, reward, next_state, done): - raise NotImplementedError - - def act(self, state, eps=0.): - raise NotImplementedError - - def save(self, filename): - raise NotImplementedError - - def load(self, filename): - raise NotImplementedError - - def start_step(self): - pass - - def end_step(self): - pass - - def load_replay_buffer(self, filename): - pass - - def test(self): - pass - - def reset(self): +class Policy: + def step(self, handle, state, action, reward, next_state, done): + raise NotImplementedError + + def act(self, state, eps=0.): + raise NotImplementedError + + def save(self, filename): + raise NotImplementedError + + def load(self, filename): + raise NotImplementedError + + def start_step(self): + pass + + def end_step(self): + pass + + def load_replay_buffer(self, filename): + pass + + def test(self): + pass + + def reset(self): pass \ No newline at end of file diff --git 
a/reinforcement_learning/ppo/model.py b/reinforcement_learning/ppo/model.py index 03d72c9..51b86ff 100644 --- a/reinforcement_learning/ppo/model.py +++ b/reinforcement_learning/ppo/model.py @@ -1,20 +1,20 @@ -import torch.nn as nn -import torch.nn.functional as F - - -class PolicyNetwork(nn.Module): - def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128, hidsize3=32): - super().__init__() - self.fc1 = nn.Linear(state_size, hidsize1) - self.fc2 = nn.Linear(hidsize1, hidsize2) - # self.fc3 = nn.Linear(hidsize2, hidsize3) - self.output = nn.Linear(hidsize2, action_size) - self.softmax = nn.Softmax(dim=1) - self.bn0 = nn.BatchNorm1d(state_size, affine=False) - - def forward(self, inputs): - x = self.bn0(inputs.float()) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - # x = F.relu(self.fc3(x)) - return self.softmax(self.output(x)) +import torch.nn as nn +import torch.nn.functional as F + + +class PolicyNetwork(nn.Module): + def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128, hidsize3=32): + super().__init__() + self.fc1 = nn.Linear(state_size, hidsize1) + self.fc2 = nn.Linear(hidsize1, hidsize2) + # self.fc3 = nn.Linear(hidsize2, hidsize3) + self.output = nn.Linear(hidsize2, action_size) + self.softmax = nn.Softmax(dim=1) + self.bn0 = nn.BatchNorm1d(state_size, affine=False) + + def forward(self, inputs): + x = self.bn0(inputs.float()) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + # x = F.relu(self.fc3(x)) + return self.softmax(self.output(x)) diff --git a/reinforcement_learning/ppo/ppo_agent.py b/reinforcement_learning/ppo/ppo_agent.py index a7431f8..49fe7e6 100644 --- a/reinforcement_learning/ppo/ppo_agent.py +++ b/reinforcement_learning/ppo/ppo_agent.py @@ -1,131 +1,131 @@ -import os - -import numpy as np -import torch -from torch.distributions.categorical import Categorical - -from reinforcement_learning.policy import Policy -from reinforcement_learning.ppo.model import PolicyNetwork -from reinforcement_learning.ppo.replay_memory import Episode, ReplayBuffer - -BUFFER_SIZE = 128_000 -BATCH_SIZE = 8192 -GAMMA = 0.95 -LR = 0.5e-4 -CLIP_FACTOR = .005 -UPDATE_EVERY = 30 - -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -print("device:", device) - - -class PPOAgent(Policy): - def __init__(self, state_size, action_size, num_agents): - self.action_size = action_size - self.state_size = state_size - self.num_agents = num_agents - self.policy = PolicyNetwork(state_size, action_size).to(device) - self.old_policy = PolicyNetwork(state_size, action_size).to(device) - self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=LR) - self.episodes = [Episode() for _ in range(num_agents)] - self.memory = ReplayBuffer(BUFFER_SIZE) - self.t_step = 0 - self.loss = 0 - - def reset(self): - self.finished = [False] * len(self.episodes) - self.tot_reward = [0] * self.num_agents - - # Decide on an action to take in the environment - - def act(self, state, eps=None): - self.policy.eval() - with torch.no_grad(): - output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device)) - ret = Categorical(output).sample().item() - return ret - - # Record the results of the agent's action and update the model - def step(self, handle, state, action, reward, next_state, done): - if not self.finished[handle]: - # Push experience into Episode memory - self.tot_reward[handle] += reward - if done == 1: - reward = 1 # self.tot_reward[handle] - else: - reward = 0 - - self.episodes[handle].push(state, action, reward, next_state, done) - - # 
When we finish the episode, discount rewards and push the experience into replay memory - if done: - self.episodes[handle].discount_rewards(GAMMA) - self.memory.push_episode(self.episodes[handle]) - self.episodes[handle].reset() - self.finished[handle] = True - - # Perform a gradient update every UPDATE_EVERY time steps - self.t_step = (self.t_step + 1) % UPDATE_EVERY - if self.t_step == 0 and len(self.memory) > BATCH_SIZE * 4: - self._learn(*self.memory.sample(BATCH_SIZE, device)) - - def _clip_gradient(self, model, clip): - - for p in model.parameters(): - p.grad.data.clamp_(-clip, clip) - return - - """Computes a gradient clipping coefficient based on gradient norm.""" - totalnorm = 0 - for p in model.parameters(): - if p.grad is not None: - modulenorm = p.grad.data.norm() - totalnorm += modulenorm ** 2 - totalnorm = np.sqrt(totalnorm) - coeff = min(1, clip / (totalnorm + 1e-6)) - - for p in model.parameters(): - if p.grad is not None: - p.grad.mul_(coeff) - - def _learn(self, states, actions, rewards, next_state, done): - self.policy.train() - - responsible_outputs = torch.gather(self.policy(states), 1, actions) - old_responsible_outputs = torch.gather(self.old_policy(states), 1, actions).detach() - - # rewards = rewards - rewards.mean() - ratio = responsible_outputs / (old_responsible_outputs + 1e-5) - clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR) - loss = -torch.min(ratio * rewards, clamped_ratio * rewards).mean() - self.loss = loss - - # Compute loss and perform a gradient step - self.old_policy.load_state_dict(self.policy.state_dict()) - self.optimizer.zero_grad() - loss.backward() - # self._clip_gradient(self.policy, 1.0) - self.optimizer.step() - - # Checkpointing methods - def save(self, filename): - # print("Saving model from checkpoint:", filename) - torch.save(self.policy.state_dict(), filename + ".policy") - torch.save(self.optimizer.state_dict(), filename + ".optimizer") - - def load(self, filename): - print("load policy from file", filename) - if os.path.exists(filename + ".policy"): - print(' >> ', filename + ".policy") - try: - self.policy.load_state_dict(torch.load(filename + ".policy", map_location=device)) - except: - print(" >> failed!") - pass - if os.path.exists(filename + ".optimizer"): - print(' >> ', filename + ".optimizer") - try: - self.optimizer.load_state_dict(torch.load(filename + ".optimizer", map_location=device)) - except: - print(" >> failed!") - pass +import os + +import numpy as np +import torch +from torch.distributions.categorical import Categorical + +from reinforcement_learning.policy import Policy +from reinforcement_learning.ppo.model import PolicyNetwork +from reinforcement_learning.ppo.replay_memory import Episode, ReplayBuffer + +BUFFER_SIZE = 128_000 +BATCH_SIZE = 8192 +GAMMA = 0.95 +LR = 0.5e-4 +CLIP_FACTOR = .005 +UPDATE_EVERY = 30 + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +print("device:", device) + + +class PPOAgent(Policy): + def __init__(self, state_size, action_size, num_agents): + self.action_size = action_size + self.state_size = state_size + self.num_agents = num_agents + self.policy = PolicyNetwork(state_size, action_size).to(device) + self.old_policy = PolicyNetwork(state_size, action_size).to(device) + self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=LR) + self.episodes = [Episode() for _ in range(num_agents)] + self.memory = ReplayBuffer(BUFFER_SIZE) + self.t_step = 0 + self.loss = 0 + + def reset(self): + self.finished = [False] * len(self.episodes) + 
self.tot_reward = [0] * self.num_agents + + # Decide on an action to take in the environment + + def act(self, state, eps=None): + self.policy.eval() + with torch.no_grad(): + output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device)) + ret = Categorical(output).sample().item() + return ret + + # Record the results of the agent's action and update the model + def step(self, handle, state, action, reward, next_state, done): + if not self.finished[handle]: + # Push experience into Episode memory + self.tot_reward[handle] += reward + if done == 1: + reward = 1 # self.tot_reward[handle] + else: + reward = 0 + + self.episodes[handle].push(state, action, reward, next_state, done) + + # When we finish the episode, discount rewards and push the experience into replay memory + if done: + self.episodes[handle].discount_rewards(GAMMA) + self.memory.push_episode(self.episodes[handle]) + self.episodes[handle].reset() + self.finished[handle] = True + + # Perform a gradient update every UPDATE_EVERY time steps + self.t_step = (self.t_step + 1) % UPDATE_EVERY + if self.t_step == 0 and len(self.memory) > BATCH_SIZE * 4: + self._learn(*self.memory.sample(BATCH_SIZE, device)) + + def _clip_gradient(self, model, clip): + + for p in model.parameters(): + p.grad.data.clamp_(-clip, clip) + return + + """Computes a gradient clipping coefficient based on gradient norm.""" + totalnorm = 0 + for p in model.parameters(): + if p.grad is not None: + modulenorm = p.grad.data.norm() + totalnorm += modulenorm ** 2 + totalnorm = np.sqrt(totalnorm) + coeff = min(1, clip / (totalnorm + 1e-6)) + + for p in model.parameters(): + if p.grad is not None: + p.grad.mul_(coeff) + + def _learn(self, states, actions, rewards, next_state, done): + self.policy.train() + + responsible_outputs = torch.gather(self.policy(states), 1, actions) + old_responsible_outputs = torch.gather(self.old_policy(states), 1, actions).detach() + + # rewards = rewards - rewards.mean() + ratio = responsible_outputs / (old_responsible_outputs + 1e-5) + clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. 
+ CLIP_FACTOR) + loss = -torch.min(ratio * rewards, clamped_ratio * rewards).mean() + self.loss = loss + + # Compute loss and perform a gradient step + self.old_policy.load_state_dict(self.policy.state_dict()) + self.optimizer.zero_grad() + loss.backward() + # self._clip_gradient(self.policy, 1.0) + self.optimizer.step() + + # Checkpointing methods + def save(self, filename): + # print("Saving model from checkpoint:", filename) + torch.save(self.policy.state_dict(), filename + ".policy") + torch.save(self.optimizer.state_dict(), filename + ".optimizer") + + def load(self, filename): + print("load policy from file", filename) + if os.path.exists(filename + ".policy"): + print(' >> ', filename + ".policy") + try: + self.policy.load_state_dict(torch.load(filename + ".policy", map_location=device)) + except: + print(" >> failed!") + pass + if os.path.exists(filename + ".optimizer"): + print(' >> ', filename + ".optimizer") + try: + self.optimizer.load_state_dict(torch.load(filename + ".optimizer", map_location=device)) + except: + print(" >> failed!") + pass diff --git a/reinforcement_learning/ppo/replay_memory.py b/reinforcement_learning/ppo/replay_memory.py index 61a1b81..3e6619b 100644 --- a/reinforcement_learning/ppo/replay_memory.py +++ b/reinforcement_learning/ppo/replay_memory.py @@ -1,53 +1,53 @@ -import torch -import random -import numpy as np -from collections import namedtuple, deque, Iterable - - -Transition = namedtuple("Experience", ("state", "action", "reward", "next_state", "done")) - - -class Episode: - memory = [] - - def reset(self): - self.memory = [] - - def push(self, *args): - self.memory.append(tuple(args)) - - def discount_rewards(self, gamma): - running_add = 0. - for i, (state, action, reward, *rest) in list(enumerate(self.memory))[::-1]: - running_add = running_add * gamma + reward - self.memory[i] = (state, action, running_add, *rest) - - -class ReplayBuffer: - def __init__(self, buffer_size): - self.memory = deque(maxlen=buffer_size) - - def push(self, state, action, reward, next_state, done): - self.memory.append(Transition(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done)) - - def push_episode(self, episode): - for step in episode.memory: - self.push(*step) - - def sample(self, batch_size, device): - experiences = random.sample(self.memory, k=batch_size) - - states = torch.from_numpy(self.stack([e.state for e in experiences])).float().to(device) - actions = torch.from_numpy(self.stack([e.action for e in experiences])).long().to(device) - rewards = torch.from_numpy(self.stack([e.reward for e in experiences])).float().to(device) - next_states = torch.from_numpy(self.stack([e.next_state for e in experiences])).float().to(device) - dones = torch.from_numpy(self.stack([e.done for e in experiences]).astype(np.uint8)).float().to(device) - - return states, actions, rewards, next_states, dones - - def stack(self, states): - sub_dims = states[0].shape[1:] if isinstance(states[0], Iterable) else [1] - return np.reshape(np.array(states), (len(states), *sub_dims)) - - def __len__(self): - return len(self.memory) +import torch +import random +import numpy as np +from collections import namedtuple, deque, Iterable + + +Transition = namedtuple("Experience", ("state", "action", "reward", "next_state", "done")) + + +class Episode: + memory = [] + + def reset(self): + self.memory = [] + + def push(self, *args): + self.memory.append(tuple(args)) + + def discount_rewards(self, gamma): + running_add = 0. 
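+ # Iterate the episode in reverse so each stored reward becomes the discounted return G_t = r_t + gamma * G_{t+1}.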
+ for i, (state, action, reward, *rest) in list(enumerate(self.memory))[::-1]: + running_add = running_add * gamma + reward + self.memory[i] = (state, action, running_add, *rest) + + +class ReplayBuffer: + def __init__(self, buffer_size): + self.memory = deque(maxlen=buffer_size) + + def push(self, state, action, reward, next_state, done): + self.memory.append(Transition(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done)) + + def push_episode(self, episode): + for step in episode.memory: + self.push(*step) + + def sample(self, batch_size, device): + experiences = random.sample(self.memory, k=batch_size) + + states = torch.from_numpy(self.stack([e.state for e in experiences])).float().to(device) + actions = torch.from_numpy(self.stack([e.action for e in experiences])).long().to(device) + rewards = torch.from_numpy(self.stack([e.reward for e in experiences])).float().to(device) + next_states = torch.from_numpy(self.stack([e.next_state for e in experiences])).float().to(device) + dones = torch.from_numpy(self.stack([e.done for e in experiences]).astype(np.uint8)).float().to(device) + + return states, actions, rewards, next_states, dones + + def stack(self, states): + sub_dims = states[0].shape[1:] if isinstance(states[0], Iterable) else [1] + return np.reshape(np.array(states), (len(states), *sub_dims)) + + def __len__(self): + return len(self.memory) diff --git a/run.py b/run.py index c5e879e..1094d1b 100644 --- a/run.py +++ b/run.py @@ -1,215 +1,215 @@ -import sys -import time -from argparse import Namespace -from pathlib import Path - -import numpy as np -from flatland.core.env_observation_builder import DummyObservationBuilder -from flatland.envs.predictions import ShortestPathPredictorForRailEnv -from flatland.envs.rail_env import RailEnvActions -from flatland.evaluators.client import FlatlandRemoteClient -from flatland.evaluators.client import TimeoutException - -from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent -from utils.deadlock_check import check_if_all_blocked -from utils.fast_tree_obs import FastTreeObs - -base_dir = Path(__file__).resolve().parent.parent -sys.path.append(str(base_dir)) - -from reinforcement_learning.dddqn_policy import DDDQNPolicy - -#################################################### -# EVALUATION PARAMETERS - -# Print per-step logs -VERBOSE = True - -# Checkpoint to use (remember to push it!) -checkpoint = "./checkpoints/201106234244-400.pth" # 15.64082361736683 Depth 1 -checkpoint = "./checkpoints/201106234900-200.pth" # 15.64082361736683 Depth 1 - -# Use last action cache -USE_ACTION_CACHE = False -USE_DEAD_LOCK_AVOIDANCE_AGENT = False - -# Observation parameters (must match training parameters!) -observation_tree_depth = 1 -observation_radius = 10 -observation_max_path_depth = 30 - -#################################################### - -remote_client = FlatlandRemoteClient() - -# Observation builder -predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth) -tree_observation = FastTreeObs(max_depth=observation_tree_depth) - -# Calculates state and action sizes -state_size = tree_observation.observation_dim -action_size = 5 - -# Creates the policy. No GPU on evaluation server. 
-policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True) -# policy = PPOAgent(state_size, action_size, 10) -policy.load(checkpoint) - -##################################################################### -# Main evaluation loop -##################################################################### -evaluation_number = 0 - -while True: - evaluation_number += 1 - - # We use a dummy observation and call TreeObsForRailEnv ourselves when needed. - # This way we decide if we want to calculate the observations or not instead - # of having them calculated every time we perform an env step. - time_start = time.time() - observation, info = remote_client.env_create( - obs_builder_object=DummyObservationBuilder() - ) - env_creation_time = time.time() - time_start - - if not observation: - # If the remote_client returns False on a `env_create` call, - # then it basically means that your agent has already been - # evaluated on all the required evaluation environments, - # and hence it's safe to break out of the main evaluation loop. - break - - print("Env Path : ", remote_client.current_env_path) - print("Env Creation Time : ", env_creation_time) - - local_env = remote_client.env - nb_agents = len(local_env.agents) - max_nb_steps = local_env._max_episode_steps - - tree_observation.set_env(local_env) - tree_observation.reset() - observation = tree_observation.get_many(list(range(nb_agents))) - - print("Evaluation {}: {} agents in {}x{}".format(evaluation_number, nb_agents, local_env.width, local_env.height)) - - # Now we enter into another infinite loop where we - # compute the actions for all the individual steps in this episode - # until the episode is `done` - steps = 0 - - # Bookkeeping - time_taken_by_controller = [] - time_taken_per_step = [] - - # Action cache: keep track of last observation to avoid running the same inferrence multiple times. - # This only makes sense for deterministic policies. 
- agent_last_obs = {} - agent_last_action = {} - nb_hit = 0 - - if USE_DEAD_LOCK_AVOIDANCE_AGENT: - policy = DeadLockAvoidanceAgent(local_env) - - while True: - try: - ##################################################################### - # Evaluation of a single episode - ##################################################################### - steps += 1 - obs_time, agent_time, step_time = 0.0, 0.0, 0.0 - no_ops_mode = False - - if not check_if_all_blocked(env=local_env): - time_start = time.time() - action_dict = {} - policy.start_step() - if USE_DEAD_LOCK_AVOIDANCE_AGENT: - observation = np.zeros((local_env.get_num_agents(), 2)) - for agent in range(nb_agents): - - if USE_DEAD_LOCK_AVOIDANCE_AGENT: - observation[agent][0] = agent - observation[agent][1] = steps - - if info['action_required'][agent]: - if agent in agent_last_obs and np.all(agent_last_obs[agent] == observation[agent]): - # cache hit - action = agent_last_action[agent] - nb_hit += 1 - else: - action = policy.act(observation[agent], eps=0.01) - if observation[agent][26] == 1: - action = RailEnvActions.STOP_MOVING - - action_dict[agent] = action - - if USE_ACTION_CACHE: - agent_last_obs[agent] = observation[agent] - agent_last_action[agent] = action - policy.end_step() - agent_time = time.time() - time_start - time_taken_by_controller.append(agent_time) - - time_start = time.time() - _, all_rewards, done, info = remote_client.env_step(action_dict) - step_time = time.time() - time_start - time_taken_per_step.append(step_time) - - time_start = time.time() - observation = tree_observation.get_many(list(range(nb_agents))) - obs_time = time.time() - time_start - - else: - # Fully deadlocked: perform no-ops - no_ops_mode = True - - time_start = time.time() - _, all_rewards, done, info = remote_client.env_step({}) - step_time = time.time() - time_start - time_taken_per_step.append(step_time) - - nb_agents_done = sum(done[idx] for idx in local_env.get_agent_handles()) - - if VERBOSE or done['__all__']: - print( - "Step {}/{}\tAgents done: {}\t Obs time {:.3f}s\t Inference time {:.5f}s\t Step time {:.3f}s\t Cache hits {}\t No-ops? {}".format( - str(steps).zfill(4), - max_nb_steps, - nb_agents_done, - obs_time, - agent_time, - step_time, - nb_hit, - no_ops_mode - ), end="\r") - - if done['__all__']: - # When done['__all__'] == True, then the evaluation of this - # particular Env instantiation is complete, and we can break out - # of this loop, and move onto the next Env evaluation - print() - break - - except TimeoutException as err: - # A timeout occurs, won't get any reward for this episode :-( - # Skip to next episode as further actions in this one will be ignored. - # The whole evaluation will be stopped if there are 10 consecutive timeouts. - print("Timeout! 
Will skip this episode and go to the next.", err) - break - - np_time_taken_by_controller = np.array(time_taken_by_controller) - np_time_taken_per_step = np.array(time_taken_per_step) - print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(), - np_time_taken_by_controller.std()) - print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std()) - print("=" * 100) - -print("Evaluation of all environments complete!") -######################################################################## -# Submit your Results -# -# Please do not forget to include this call, as this triggers the -# final computation of the score statistics, video generation, etc -# and is necessary to have your submission marked as successfully evaluated -######################################################################## -print(remote_client.submit()) +import sys +import time +from argparse import Namespace +from pathlib import Path + +import numpy as np +from flatland.core.env_observation_builder import DummyObservationBuilder +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.rail_env import RailEnvActions +from flatland.evaluators.client import FlatlandRemoteClient +from flatland.evaluators.client import TimeoutException + +from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent +from utils.deadlock_check import check_if_all_blocked +from utils.fast_tree_obs import FastTreeObs + +base_dir = Path(__file__).resolve().parent.parent +sys.path.append(str(base_dir)) + +from reinforcement_learning.dddqn_policy import DDDQNPolicy + +#################################################### +# EVALUATION PARAMETERS + +# Print per-step logs +VERBOSE = True + +# Checkpoint to use (remember to push it!) +checkpoint = "./checkpoints/201106234244-400.pth" # 15.64082361736683 Depth 1 +checkpoint = "./checkpoints/201106234900-200.pth" # 15.64082361736683 Depth 1 + +# Use last action cache +USE_ACTION_CACHE = False +USE_DEAD_LOCK_AVOIDANCE_AGENT = False + +# Observation parameters (must match training parameters!) +observation_tree_depth = 1 +observation_radius = 10 +observation_max_path_depth = 30 + +#################################################### + +remote_client = FlatlandRemoteClient() + +# Observation builder +predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth) +tree_observation = FastTreeObs(max_depth=observation_tree_depth) + +# Calculates state and action sizes +state_size = tree_observation.observation_dim +action_size = 5 + +# Creates the policy. No GPU on evaluation server. +policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True) +# policy = PPOAgent(state_size, action_size, 10) +policy.load(checkpoint) + +##################################################################### +# Main evaluation loop +##################################################################### +evaluation_number = 0 + +while True: + evaluation_number += 1 + + # We use a dummy observation and call TreeObsForRailEnv ourselves when needed. + # This way we decide if we want to calculate the observations or not instead + # of having them calculated every time we perform an env step. 
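+ # Time the remote env_create call; an empty/False observation here means all evaluation environments have been served (handled just below).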
+ time_start = time.time() + observation, info = remote_client.env_create( + obs_builder_object=DummyObservationBuilder() + ) + env_creation_time = time.time() - time_start + + if not observation: + # If the remote_client returns False on a `env_create` call, + # then it basically means that your agent has already been + # evaluated on all the required evaluation environments, + # and hence it's safe to break out of the main evaluation loop. + break + + print("Env Path : ", remote_client.current_env_path) + print("Env Creation Time : ", env_creation_time) + + local_env = remote_client.env + nb_agents = len(local_env.agents) + max_nb_steps = local_env._max_episode_steps + + tree_observation.set_env(local_env) + tree_observation.reset() + observation = tree_observation.get_many(list(range(nb_agents))) + + print("Evaluation {}: {} agents in {}x{}".format(evaluation_number, nb_agents, local_env.width, local_env.height)) + + # Now we enter into another infinite loop where we + # compute the actions for all the individual steps in this episode + # until the episode is `done` + steps = 0 + + # Bookkeeping + time_taken_by_controller = [] + time_taken_per_step = [] + + # Action cache: keep track of last observation to avoid running the same inferrence multiple times. + # This only makes sense for deterministic policies. + agent_last_obs = {} + agent_last_action = {} + nb_hit = 0 + + if USE_DEAD_LOCK_AVOIDANCE_AGENT: + policy = DeadLockAvoidanceAgent(local_env) + + while True: + try: + ##################################################################### + # Evaluation of a single episode + ##################################################################### + steps += 1 + obs_time, agent_time, step_time = 0.0, 0.0, 0.0 + no_ops_mode = False + + if not check_if_all_blocked(env=local_env): + time_start = time.time() + action_dict = {} + policy.start_step() + if USE_DEAD_LOCK_AVOIDANCE_AGENT: + observation = np.zeros((local_env.get_num_agents(), 2)) + for agent in range(nb_agents): + + if USE_DEAD_LOCK_AVOIDANCE_AGENT: + observation[agent][0] = agent + observation[agent][1] = steps + + if info['action_required'][agent]: + if agent in agent_last_obs and np.all(agent_last_obs[agent] == observation[agent]): + # cache hit + action = agent_last_action[agent] + nb_hit += 1 + else: + action = policy.act(observation[agent], eps=0.0) + #if observation[agent][26] == 1: + # action = RailEnvActions.STOP_MOVING + + action_dict[agent] = action + + if USE_ACTION_CACHE: + agent_last_obs[agent] = observation[agent] + agent_last_action[agent] = action + policy.end_step() + agent_time = time.time() - time_start + time_taken_by_controller.append(agent_time) + + time_start = time.time() + _, all_rewards, done, info = remote_client.env_step(action_dict) + step_time = time.time() - time_start + time_taken_per_step.append(step_time) + + time_start = time.time() + observation = tree_observation.get_many(list(range(nb_agents))) + obs_time = time.time() - time_start + + else: + # Fully deadlocked: perform no-ops + no_ops_mode = True + + time_start = time.time() + _, all_rewards, done, info = remote_client.env_step({}) + step_time = time.time() - time_start + time_taken_per_step.append(step_time) + + nb_agents_done = sum(done[idx] for idx in local_env.get_agent_handles()) + + if VERBOSE or done['__all__']: + print( + "Step {}/{}\tAgents done: {}\t Obs time {:.3f}s\t Inference time {:.5f}s\t Step time {:.3f}s\t Cache hits {}\t No-ops? 
{}".format( + str(steps).zfill(4), + max_nb_steps, + nb_agents_done, + obs_time, + agent_time, + step_time, + nb_hit, + no_ops_mode + ), end="\r") + + if done['__all__']: + # When done['__all__'] == True, then the evaluation of this + # particular Env instantiation is complete, and we can break out + # of this loop, and move onto the next Env evaluation + print() + break + + except TimeoutException as err: + # A timeout occurs, won't get any reward for this episode :-( + # Skip to next episode as further actions in this one will be ignored. + # The whole evaluation will be stopped if there are 10 consecutive timeouts. + print("Timeout! Will skip this episode and go to the next.", err) + break + + np_time_taken_by_controller = np.array(time_taken_by_controller) + np_time_taken_per_step = np.array(time_taken_per_step) + print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(), + np_time_taken_by_controller.std()) + print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std()) + print("=" * 100) + +print("Evaluation of all environments complete!") +######################################################################## +# Submit your Results +# +# Please do not forget to include this call, as this triggers the +# final computation of the score statistics, video generation, etc +# and is necessary to have your submission marked as successfully evaluated +######################################################################## +print(remote_client.submit()) diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py index 700600c..1d0b52c 100644 --- a/utils/dead_lock_avoidance_agent.py +++ b/utils/dead_lock_avoidance_agent.py @@ -1,175 +1,175 @@ -from typing import Optional, List - -import matplotlib.pyplot as plt -import numpy as np -from flatland.core.env_observation_builder import DummyObservationBuilder -from flatland.envs.agent_utils import RailAgentStatus -from flatland.envs.rail_env import RailEnv, RailEnvActions, fast_count_nonzero - -from reinforcement_learning.policy import Policy -from utils.shortest_distance_walker import ShortestDistanceWalker - - -class DeadlockAvoidanceObservation(DummyObservationBuilder): - def __init__(self): - self.counter = 0 - - def get_many(self, handles: Optional[List[int]] = None) -> bool: - self.counter += 1 - obs = np.ones(len(handles), 2) - for handle in handles: - obs[handle][0] = handle - obs[handle][1] = self.counter - return obs - - -class DeadlockAvoidanceShortestDistanceWalker(ShortestDistanceWalker): - def __init__(self, env: RailEnv, agent_positions, switches): - super().__init__(env) - self.shortest_distance_agent_map = np.zeros((self.env.get_num_agents(), - self.env.height, - self.env.width), - dtype=int) - 1 - - self.full_shortest_distance_agent_map = np.zeros((self.env.get_num_agents(), - self.env.height, - self.env.width), - dtype=int) - 1 - - self.agent_positions = agent_positions - - self.opp_agent_map = {} - self.same_agent_map = {} - self.switches = switches - - def getData(self): - return self.shortest_distance_agent_map, self.full_shortest_distance_agent_map - - def callback(self, handle, agent, position, direction, action, possible_transitions): - opp_a = self.agent_positions[position] - if opp_a != -1 and opp_a != handle: - if self.env.agents[opp_a].direction != direction: - d = self.opp_agent_map.get(handle, []) - if opp_a not in d: - d.append(opp_a) - self.opp_agent_map.update({handle: d}) - else: - if len(self.opp_agent_map.get(handle, [])) 
== 0: - d = self.same_agent_map.get(handle, []) - if opp_a not in d: - d.append(opp_a) - self.same_agent_map.update({handle: d}) - - if len(self.opp_agent_map.get(handle, [])) == 0: - if self.switches.get(position, None) is None: - self.shortest_distance_agent_map[(handle, position[0], position[1])] = 1 - self.full_shortest_distance_agent_map[(handle, position[0], position[1])] = 1 - - -class DeadLockAvoidanceAgent(Policy): - def __init__(self, env: RailEnv, show_debug_plot=False): - self.env = env - self.memory = None - self.loss = 0 - self.agent_can_move = {} - self.switches = {} - self.show_debug_plot = show_debug_plot - - def step(self, state, action, reward, next_state, done): - pass - - def act(self, state, eps=0.): - # agent = self.env.agents[state[0]] - check = self.agent_can_move.get(state[0], None) - if check is None: - return RailEnvActions.STOP_MOVING - return check[3] - - def reset(self): - self.agent_positions = None - self.shortest_distance_walker = None - self.switches = {} - for h in range(self.env.height): - for w in range(self.env.width): - pos = (h, w) - for dir in range(4): - possible_transitions = self.env.rail.get_transitions(*pos, dir) - num_transitions = fast_count_nonzero(possible_transitions) - if num_transitions > 1: - if pos not in self.switches.keys(): - self.switches.update({pos: [dir]}) - else: - self.switches[pos].append(dir) - - def start_step(self): - self.build_agent_position_map() - self.shortest_distance_mapper() - self.extract_agent_can_move() - - def end_step(self): - pass - - def get_actions(self): - pass - - def build_agent_position_map(self): - # build map with agent positions (only active agents) - self.agent_positions = np.zeros((self.env.height, self.env.width), dtype=int) - 1 - for handle in range(self.env.get_num_agents()): - agent = self.env.agents[handle] - if agent.status == RailAgentStatus.ACTIVE: - if agent.position is not None: - self.agent_positions[agent.position] = handle - - def shortest_distance_mapper(self): - self.shortest_distance_walker = DeadlockAvoidanceShortestDistanceWalker(self.env, - self.agent_positions, - self.switches) - for handle in range(self.env.get_num_agents()): - agent = self.env.agents[handle] - if agent.status <= RailAgentStatus.ACTIVE: - self.shortest_distance_walker.walk_to_target(handle) - - def extract_agent_can_move(self): - self.agent_can_move = {} - shortest_distance_agent_map, full_shortest_distance_agent_map = self.shortest_distance_walker.getData() - for handle in range(self.env.get_num_agents()): - agent = self.env.agents[handle] - if agent.status < RailAgentStatus.DONE: - next_step_ok = self.check_agent_can_move(shortest_distance_agent_map[handle], - self.shortest_distance_walker.same_agent_map.get(handle, []), - self.shortest_distance_walker.opp_agent_map.get(handle, []), - full_shortest_distance_agent_map) - if next_step_ok: - next_position, next_direction, action, _ = self.shortest_distance_walker.walk_one_step(handle) - self.agent_can_move.update({handle: [next_position[0], next_position[1], next_direction, action]}) - - if self.show_debug_plot: - a = np.floor(np.sqrt(self.env.get_num_agents())) - b = np.ceil(self.env.get_num_agents() / a) - for handle in range(self.env.get_num_agents()): - plt.subplot(a, b, handle + 1) - plt.imshow(full_shortest_distance_agent_map[handle] + shortest_distance_agent_map[handle]) - plt.show(block=False) - plt.pause(0.01) - - def check_agent_can_move(self, - my_shortest_walking_path, - same_agents, - opp_agents, - full_shortest_distance_agent_map): - 
agent_positions_map = (self.agent_positions > -1).astype(int) - delta = my_shortest_walking_path - next_step_ok = True - for opp_a in opp_agents: - opp = full_shortest_distance_agent_map[opp_a] - delta = ((my_shortest_walking_path - opp - agent_positions_map) > 0).astype(int) - if np.sum(delta) < (3 + len(opp_agents)): - next_step_ok = False - return next_step_ok - - def save(self, filename): - pass - - def load(self, filename): - pass +from typing import Optional, List + +import matplotlib.pyplot as plt +import numpy as np +from flatland.core.env_observation_builder import DummyObservationBuilder +from flatland.envs.agent_utils import RailAgentStatus +from flatland.envs.rail_env import RailEnv, RailEnvActions, fast_count_nonzero + +from reinforcement_learning.policy import Policy +from utils.shortest_distance_walker import ShortestDistanceWalker + + +class DeadlockAvoidanceObservation(DummyObservationBuilder): + def __init__(self): + self.counter = 0 + + def get_many(self, handles: Optional[List[int]] = None) -> bool: + self.counter += 1 + obs = np.ones(len(handles), 2) + for handle in handles: + obs[handle][0] = handle + obs[handle][1] = self.counter + return obs + + +class DeadlockAvoidanceShortestDistanceWalker(ShortestDistanceWalker): + def __init__(self, env: RailEnv, agent_positions, switches): + super().__init__(env) + self.shortest_distance_agent_map = np.zeros((self.env.get_num_agents(), + self.env.height, + self.env.width), + dtype=int) - 1 + + self.full_shortest_distance_agent_map = np.zeros((self.env.get_num_agents(), + self.env.height, + self.env.width), + dtype=int) - 1 + + self.agent_positions = agent_positions + + self.opp_agent_map = {} + self.same_agent_map = {} + self.switches = switches + + def getData(self): + return self.shortest_distance_agent_map, self.full_shortest_distance_agent_map + + def callback(self, handle, agent, position, direction, action, possible_transitions): + opp_a = self.agent_positions[position] + if opp_a != -1 and opp_a != handle: + if self.env.agents[opp_a].direction != direction: + d = self.opp_agent_map.get(handle, []) + if opp_a not in d: + d.append(opp_a) + self.opp_agent_map.update({handle: d}) + else: + if len(self.opp_agent_map.get(handle, [])) == 0: + d = self.same_agent_map.get(handle, []) + if opp_a not in d: + d.append(opp_a) + self.same_agent_map.update({handle: d}) + + if len(self.opp_agent_map.get(handle, [])) == 0: + if self.switches.get(position, None) is None: + self.shortest_distance_agent_map[(handle, position[0], position[1])] = 1 + self.full_shortest_distance_agent_map[(handle, position[0], position[1])] = 1 + + +class DeadLockAvoidanceAgent(Policy): + def __init__(self, env: RailEnv, show_debug_plot=False): + self.env = env + self.memory = None + self.loss = 0 + self.agent_can_move = {} + self.switches = {} + self.show_debug_plot = show_debug_plot + + def step(self, state, action, reward, next_state, done): + pass + + def act(self, state, eps=0.): + # agent = self.env.agents[state[0]] + check = self.agent_can_move.get(state[0], None) + if check is None: + return RailEnvActions.STOP_MOVING + return check[3] + + def reset(self): + self.agent_positions = None + self.shortest_distance_walker = None + self.switches = {} + for h in range(self.env.height): + for w in range(self.env.width): + pos = (h, w) + for dir in range(4): + possible_transitions = self.env.rail.get_transitions(*pos, dir) + num_transitions = fast_count_nonzero(possible_transitions) + if num_transitions > 1: + if pos not in self.switches.keys(): + 
self.switches.update({pos: [dir]}) + else: + self.switches[pos].append(dir) + + def start_step(self): + self.build_agent_position_map() + self.shortest_distance_mapper() + self.extract_agent_can_move() + + def end_step(self): + pass + + def get_actions(self): + pass + + def build_agent_position_map(self): + # build map with agent positions (only active agents) + self.agent_positions = np.zeros((self.env.height, self.env.width), dtype=int) - 1 + for handle in range(self.env.get_num_agents()): + agent = self.env.agents[handle] + if agent.status == RailAgentStatus.ACTIVE: + if agent.position is not None: + self.agent_positions[agent.position] = handle + + def shortest_distance_mapper(self): + self.shortest_distance_walker = DeadlockAvoidanceShortestDistanceWalker(self.env, + self.agent_positions, + self.switches) + for handle in range(self.env.get_num_agents()): + agent = self.env.agents[handle] + if agent.status <= RailAgentStatus.ACTIVE: + self.shortest_distance_walker.walk_to_target(handle) + + def extract_agent_can_move(self): + self.agent_can_move = {} + shortest_distance_agent_map, full_shortest_distance_agent_map = self.shortest_distance_walker.getData() + for handle in range(self.env.get_num_agents()): + agent = self.env.agents[handle] + if agent.status < RailAgentStatus.DONE: + next_step_ok = self.check_agent_can_move(shortest_distance_agent_map[handle], + self.shortest_distance_walker.same_agent_map.get(handle, []), + self.shortest_distance_walker.opp_agent_map.get(handle, []), + full_shortest_distance_agent_map) + if next_step_ok: + next_position, next_direction, action, _ = self.shortest_distance_walker.walk_one_step(handle) + self.agent_can_move.update({handle: [next_position[0], next_position[1], next_direction, action]}) + + if self.show_debug_plot: + a = np.floor(np.sqrt(self.env.get_num_agents())) + b = np.ceil(self.env.get_num_agents() / a) + for handle in range(self.env.get_num_agents()): + plt.subplot(a, b, handle + 1) + plt.imshow(full_shortest_distance_agent_map[handle] + shortest_distance_agent_map[handle]) + plt.show(block=False) + plt.pause(0.01) + + def check_agent_can_move(self, + my_shortest_walking_path, + same_agents, + opp_agents, + full_shortest_distance_agent_map): + agent_positions_map = (self.agent_positions > -1).astype(int) + delta = my_shortest_walking_path + next_step_ok = True + for opp_a in opp_agents: + opp = full_shortest_distance_agent_map[opp_a] + delta = ((my_shortest_walking_path - opp - agent_positions_map) > 0).astype(int) + if np.sum(delta) < (3 + len(opp_agents)): + next_step_ok = False + return next_step_ok + + def save(self, filename): + pass + + def load(self, filename): + pass diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py index b6d673d..3b14b0f 100755 --- a/utils/fast_tree_obs.py +++ b/utils/fast_tree_obs.py @@ -1,308 +1,308 @@ -import numpy as np -from flatland.core.env_observation_builder import ObservationBuilder -from flatland.core.grid.grid4_utils import get_new_position -from flatland.envs.agent_utils import RailAgentStatus -from flatland.envs.rail_env import fast_count_nonzero, fast_argmax, RailEnvActions - -from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent - -""" -LICENCE for the FastTreeObs Observation Builder - -The observation can be used freely and reused for further submissions. Only the author needs to be referred to -/mentioned in any submissions - if the entire observation or parts, or the main idea is used. 
- -Author: Adrian Egli (adrian.egli@gmail.com) - -[Linkedin](https://www.researchgate.net/profile/Adrian_Egli2) -[Researchgate](https://www.linkedin.com/in/adrian-egli-733a9544/) -""" - - -class FastTreeObs(ObservationBuilder): - - def __init__(self, max_depth): - self.max_depth = max_depth - self.observation_dim = 27 - - def build_data(self): - if self.env is not None: - self.env.dev_obs_dict = {} - self.switches = {} - self.switches_neighbours = {} - self.debug_render_list = [] - self.debug_render_path_list = [] - if self.env is not None: - self.find_all_cell_where_agent_can_choose() - self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env) - else: - self.dead_lock_avoidance_agent = None - - def find_all_cell_where_agent_can_choose(self): - switches = {} - for h in range(self.env.height): - for w in range(self.env.width): - pos = (h, w) - for dir in range(4): - possible_transitions = self.env.rail.get_transitions(*pos, dir) - num_transitions = fast_count_nonzero(possible_transitions) - if num_transitions > 1: - if pos not in switches.keys(): - switches.update({pos: [dir]}) - else: - switches[pos].append(dir) - - switches_neighbours = {} - for h in range(self.env.height): - for w in range(self.env.width): - # look one step forward - for dir in range(4): - pos = (h, w) - possible_transitions = self.env.rail.get_transitions(*pos, dir) - for d in range(4): - if possible_transitions[d] == 1: - new_cell = get_new_position(pos, d) - if new_cell in switches.keys() and pos not in switches.keys(): - if pos not in switches_neighbours.keys(): - switches_neighbours.update({pos: [dir]}) - else: - switches_neighbours[pos].append(dir) - - self.switches = switches - self.switches_neighbours = switches_neighbours - - def check_agent_decision(self, position, direction): - switches = self.switches - switches_neighbours = self.switches_neighbours - agents_on_switch = False - agents_on_switch_all = False - agents_near_to_switch = False - agents_near_to_switch_all = False - if position in switches.keys(): - agents_on_switch = direction in switches[position] - agents_on_switch_all = True - - if position in switches_neighbours.keys(): - new_cell = get_new_position(position, direction) - if new_cell in switches.keys(): - if not direction in switches[new_cell]: - agents_near_to_switch = direction in switches_neighbours[position] - else: - agents_near_to_switch = direction in switches_neighbours[position] - - agents_near_to_switch_all = direction in switches_neighbours[position] - - return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all - - def required_agent_decision(self): - agents_can_choose = {} - agents_on_switch = {} - agents_on_switch_all = {} - agents_near_to_switch = {} - agents_near_to_switch_all = {} - for a in range(self.env.get_num_agents()): - ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all, ret_agents_on_switch_all = \ - self.check_agent_decision( - self.env.agents[a].position, - self.env.agents[a].direction) - agents_on_switch.update({a: ret_agents_on_switch}) - agents_on_switch_all.update({a: ret_agents_on_switch_all}) - ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART - agents_near_to_switch.update({a: (ret_agents_near_to_switch and not ready_to_depart)}) - - agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]}) - - agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all and not ready_to_depart)}) - - return agents_can_choose, agents_on_switch, 
agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all - - def debug_render(self, env_renderer): - agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \ - self.required_agent_decision() - self.env.dev_obs_dict = {} - for a in range(max(3, self.env.get_num_agents())): - self.env.dev_obs_dict.update({a: []}) - - selected_agent = None - if agents_can_choose[0]: - if self.env.agents[0].position is not None: - self.debug_render_list.append(self.env.agents[0].position) - else: - self.debug_render_list.append(self.env.agents[0].initial_position) - - if self.env.agents[0].position is not None: - self.debug_render_path_list.append(self.env.agents[0].position) - else: - self.debug_render_path_list.append(self.env.agents[0].initial_position) - - env_renderer.gl.agent_colors[0] = env_renderer.gl.rgb_s2i("FF0000") - env_renderer.gl.agent_colors[1] = env_renderer.gl.rgb_s2i("666600") - env_renderer.gl.agent_colors[2] = env_renderer.gl.rgb_s2i("006666") - env_renderer.gl.agent_colors[3] = env_renderer.gl.rgb_s2i("550000") - - self.env.dev_obs_dict[0] = self.debug_render_list - self.env.dev_obs_dict[1] = self.switches.keys() - self.env.dev_obs_dict[2] = self.switches_neighbours.keys() - self.env.dev_obs_dict[3] = self.debug_render_path_list - - def reset(self): - self.build_data() - return - - def fast_argmax(self, array): - if array[0] == 1: - return 0 - if array[1] == 1: - return 1 - if array[2] == 1: - return 2 - return 3 - - def _explore(self, handle, new_position, new_direction, depth=0): - has_opp_agent = 0 - has_same_agent = 0 - has_switch = 0 - visited = [] - - # stop exploring (max_depth reached) - if depth >= self.max_depth: - return has_opp_agent, has_same_agent, has_switch, visited - - # max_explore_steps = 100 - cnt = 0 - while cnt < 100: - cnt += 1 - - visited.append(new_position) - opp_a = self.env.agent_positions[new_position] - if opp_a != -1 and opp_a != handle: - if self.env.agents[opp_a].direction != new_direction: - # opp agent found - has_opp_agent = 1 - return has_opp_agent, has_same_agent, has_switch, visited - else: - has_same_agent = 1 - return has_opp_agent, has_same_agent, has_switch, visited - - # convert one-hot encoding to 0,1,2,3 - agents_on_switch, \ - agents_near_to_switch, \ - agents_near_to_switch_all, \ - agents_on_switch_all = \ - self.check_agent_decision(new_position, new_direction) - if agents_near_to_switch: - return has_opp_agent, has_same_agent, has_switch, visited - - possible_transitions = self.env.rail.get_transitions(*new_position, new_direction) - if agents_on_switch: - f = 0 - for dir_loop in range(4): - if possible_transitions[dir_loop] == 1: - f += 1 - hoa, hsa, hs, v = self._explore(handle, - get_new_position(new_position, dir_loop), - dir_loop, - depth + 1) - visited.append(v) - has_opp_agent += hoa - has_same_agent += hsa - has_switch += hs - f = max(f, 1.0) - return has_opp_agent / f, has_same_agent / f, has_switch / f, visited - else: - new_direction = fast_argmax(possible_transitions) - new_position = get_new_position(new_position, new_direction) - - return has_opp_agent, has_same_agent, has_switch, visited - - def get(self, handle): - # all values are [0,1] - # observation[0] : 1 path towards target (direction 0) / otherwise 0 -> path is longer or there is no path - # observation[1] : 1 path towards target (direction 1) / otherwise 0 -> path is longer or there is no path - # observation[2] : 1 path towards target (direction 2) / otherwise 0 -> path is longer or there is no path - # observation[3] 
: 1 path towards target (direction 3) / otherwise 0 -> path is longer or there is no path - # observation[4] : int(agent.status == RailAgentStatus.READY_TO_DEPART) - # observation[5] : int(agent.status == RailAgentStatus.ACTIVE) - # observation[6] : int(agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED) - # observation[7] : current agent is located at a switch, where it can take a routing decision - # observation[8] : current agent is located at a cell, where it has to take a stop-or-go decision - # observation[9] : current agent is located one step before/after a switch - # observation[10] : 1 if there is a path (track/branch) otherwise 0 (direction 0) - # observation[11] : 1 if there is a path (track/branch) otherwise 0 (direction 1) - # observation[12] : 1 if there is a path (track/branch) otherwise 0 (direction 2) - # observation[13] : 1 if there is a path (track/branch) otherwise 0 (direction 3) - # observation[14] : If there is a path with step (direction 0) and there is a agent with opposite direction -> 1 - # observation[15] : If there is a path with step (direction 1) and there is a agent with opposite direction -> 1 - # observation[16] : If there is a path with step (direction 2) and there is a agent with opposite direction -> 1 - # observation[17] : If there is a path with step (direction 3) and there is a agent with opposite direction -> 1 - # observation[18] : If there is a path with step (direction 0) and there is a agent with same direction -> 1 - # observation[19] : If there is a path with step (direction 1) and there is a agent with same direction -> 1 - # observation[20] : If there is a path with step (direction 2) and there is a agent with same direction -> 1 - # observation[21] : If there is a path with step (direction 3) and there is a agent with same direction -> 1 - # observation[22] : If there is a switch on the path which agent can not use -> 1 - # observation[23] : If there is a switch on the path which agent can not use -> 1 - # observation[24] : If there is a switch on the path which agent can not use -> 1 - # observation[25] : If there is a switch on the path which agent can not use -> 1 - # observation[26] : If there the dead-lock avoidance agent predicts a deadlock -> 1 - - if handle == 0: - self.dead_lock_avoidance_agent.start_step() - - observation = np.zeros(self.observation_dim) - visited = [] - agent = self.env.agents[handle] - - agent_done = False - if agent.status == RailAgentStatus.READY_TO_DEPART: - agent_virtual_position = agent.initial_position - observation[4] = 1 - elif agent.status == RailAgentStatus.ACTIVE: - agent_virtual_position = agent.position - observation[5] = 1 - else: - observation[6] = 1 - agent_virtual_position = (-1, -1) - agent_done = True - - if not agent_done: - visited.append(agent_virtual_position) - distance_map = self.env.distance_map.get() - current_cell_dist = distance_map[handle, - agent_virtual_position[0], agent_virtual_position[1], - agent.direction] - possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction) - orientation = agent.direction - if fast_count_nonzero(possible_transitions) == 1: - orientation = fast_argmax(possible_transitions) - - for dir_loop, branch_direction in enumerate([(orientation + dir_loop) % 4 for dir_loop in range(-1, 3)]): - if possible_transitions[branch_direction]: - new_position = get_new_position(agent_virtual_position, branch_direction) - new_cell_dist = distance_map[handle, - new_position[0], new_position[1], - 
branch_direction] - if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)): - observation[dir_loop] = int(new_cell_dist < current_cell_dist) - - has_opp_agent, has_same_agent, has_switch, v = self._explore(handle, new_position, branch_direction) - visited.append(v) - - observation[10 + dir_loop] = int(not np.math.isinf(new_cell_dist)) - observation[14 + dir_loop] = has_opp_agent - observation[18 + dir_loop] = has_same_agent - observation[22 + dir_loop] = has_switch - - agents_on_switch, \ - agents_near_to_switch, \ - agents_near_to_switch_all, \ - agents_on_switch_all = \ - self.check_agent_decision(agent_virtual_position, agent.direction) - observation[7] = int(agents_on_switch) - observation[8] = int(agents_near_to_switch) - observation[9] = int(agents_near_to_switch_all) - - action = self.dead_lock_avoidance_agent.act([handle], 0.0) - observation[26] = int(action == RailEnvActions.STOP_MOVING) - self.env.dev_obs_dict.update({handle: visited}) - - return observation +import numpy as np +from flatland.core.env_observation_builder import ObservationBuilder +from flatland.core.grid.grid4_utils import get_new_position +from flatland.envs.agent_utils import RailAgentStatus +from flatland.envs.rail_env import fast_count_nonzero, fast_argmax, RailEnvActions + +from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent + +""" +LICENCE for the FastTreeObs Observation Builder + +The observation can be used freely and reused for further submissions. Only the author needs to be referred to +/mentioned in any submissions - if the entire observation or parts, or the main idea is used. + +Author: Adrian Egli (adrian.egli@gmail.com) + +[Linkedin](https://www.researchgate.net/profile/Adrian_Egli2) +[Researchgate](https://www.linkedin.com/in/adrian-egli-733a9544/) +""" + + +class FastTreeObs(ObservationBuilder): + + def __init__(self, max_depth): + self.max_depth = max_depth + self.observation_dim = 27 + + def build_data(self): + if self.env is not None: + self.env.dev_obs_dict = {} + self.switches = {} + self.switches_neighbours = {} + self.debug_render_list = [] + self.debug_render_path_list = [] + if self.env is not None: + self.find_all_cell_where_agent_can_choose() + self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env) + else: + self.dead_lock_avoidance_agent = None + + def find_all_cell_where_agent_can_choose(self): + switches = {} + for h in range(self.env.height): + for w in range(self.env.width): + pos = (h, w) + for dir in range(4): + possible_transitions = self.env.rail.get_transitions(*pos, dir) + num_transitions = fast_count_nonzero(possible_transitions) + if num_transitions > 1: + if pos not in switches.keys(): + switches.update({pos: [dir]}) + else: + switches[pos].append(dir) + + switches_neighbours = {} + for h in range(self.env.height): + for w in range(self.env.width): + # look one step forward + for dir in range(4): + pos = (h, w) + possible_transitions = self.env.rail.get_transitions(*pos, dir) + for d in range(4): + if possible_transitions[d] == 1: + new_cell = get_new_position(pos, d) + if new_cell in switches.keys() and pos not in switches.keys(): + if pos not in switches_neighbours.keys(): + switches_neighbours.update({pos: [dir]}) + else: + switches_neighbours[pos].append(dir) + + self.switches = switches + self.switches_neighbours = switches_neighbours + + def check_agent_decision(self, position, direction): + switches = self.switches + switches_neighbours = self.switches_neighbours + agents_on_switch = False + agents_on_switch_all = 
False + agents_near_to_switch = False + agents_near_to_switch_all = False + if position in switches.keys(): + agents_on_switch = direction in switches[position] + agents_on_switch_all = True + + if position in switches_neighbours.keys(): + new_cell = get_new_position(position, direction) + if new_cell in switches.keys(): + if not direction in switches[new_cell]: + agents_near_to_switch = direction in switches_neighbours[position] + else: + agents_near_to_switch = direction in switches_neighbours[position] + + agents_near_to_switch_all = direction in switches_neighbours[position] + + return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all + + def required_agent_decision(self): + agents_can_choose = {} + agents_on_switch = {} + agents_on_switch_all = {} + agents_near_to_switch = {} + agents_near_to_switch_all = {} + for a in range(self.env.get_num_agents()): + ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all, ret_agents_on_switch_all = \ + self.check_agent_decision( + self.env.agents[a].position, + self.env.agents[a].direction) + agents_on_switch.update({a: ret_agents_on_switch}) + agents_on_switch_all.update({a: ret_agents_on_switch_all}) + ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART + agents_near_to_switch.update({a: (ret_agents_near_to_switch and not ready_to_depart)}) + + agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]}) + + agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all and not ready_to_depart)}) + + return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all + + def debug_render(self, env_renderer): + agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \ + self.required_agent_decision() + self.env.dev_obs_dict = {} + for a in range(max(3, self.env.get_num_agents())): + self.env.dev_obs_dict.update({a: []}) + + selected_agent = None + if agents_can_choose[0]: + if self.env.agents[0].position is not None: + self.debug_render_list.append(self.env.agents[0].position) + else: + self.debug_render_list.append(self.env.agents[0].initial_position) + + if self.env.agents[0].position is not None: + self.debug_render_path_list.append(self.env.agents[0].position) + else: + self.debug_render_path_list.append(self.env.agents[0].initial_position) + + env_renderer.gl.agent_colors[0] = env_renderer.gl.rgb_s2i("FF0000") + env_renderer.gl.agent_colors[1] = env_renderer.gl.rgb_s2i("666600") + env_renderer.gl.agent_colors[2] = env_renderer.gl.rgb_s2i("006666") + env_renderer.gl.agent_colors[3] = env_renderer.gl.rgb_s2i("550000") + + self.env.dev_obs_dict[0] = self.debug_render_list + self.env.dev_obs_dict[1] = self.switches.keys() + self.env.dev_obs_dict[2] = self.switches_neighbours.keys() + self.env.dev_obs_dict[3] = self.debug_render_path_list + + def reset(self): + self.build_data() + return + + def fast_argmax(self, array): + if array[0] == 1: + return 0 + if array[1] == 1: + return 1 + if array[2] == 1: + return 2 + return 3 + + def _explore(self, handle, new_position, new_direction, depth=0): + has_opp_agent = 0 + has_same_agent = 0 + has_switch = 0 + visited = [] + + # stop exploring (max_depth reached) + if depth >= self.max_depth: + return has_opp_agent, has_same_agent, has_switch, visited + + # max_explore_steps = 100 + cnt = 0 + while cnt < 100: + cnt += 1 + + visited.append(new_position) + opp_a = self.env.agent_positions[new_position] + if opp_a != -1 
and opp_a != handle: + if self.env.agents[opp_a].direction != new_direction: + # opp agent found + has_opp_agent = 1 + return has_opp_agent, has_same_agent, has_switch, visited + else: + has_same_agent = 1 + return has_opp_agent, has_same_agent, has_switch, visited + + # convert one-hot encoding to 0,1,2,3 + agents_on_switch, \ + agents_near_to_switch, \ + agents_near_to_switch_all, \ + agents_on_switch_all = \ + self.check_agent_decision(new_position, new_direction) + if agents_near_to_switch: + return has_opp_agent, has_same_agent, has_switch, visited + + possible_transitions = self.env.rail.get_transitions(*new_position, new_direction) + if agents_on_switch: + f = 0 + for dir_loop in range(4): + if possible_transitions[dir_loop] == 1: + f += 1 + hoa, hsa, hs, v = self._explore(handle, + get_new_position(new_position, dir_loop), + dir_loop, + depth + 1) + visited.append(v) + has_opp_agent += hoa + has_same_agent += hsa + has_switch += hs + f = max(f, 1.0) + return has_opp_agent / f, has_same_agent / f, has_switch / f, visited + else: + new_direction = fast_argmax(possible_transitions) + new_position = get_new_position(new_position, new_direction) + + return has_opp_agent, has_same_agent, has_switch, visited + + def get(self, handle): + # all values are [0,1] + # observation[0] : 1 path towards target (direction 0) / otherwise 0 -> path is longer or there is no path + # observation[1] : 1 path towards target (direction 1) / otherwise 0 -> path is longer or there is no path + # observation[2] : 1 path towards target (direction 2) / otherwise 0 -> path is longer or there is no path + # observation[3] : 1 path towards target (direction 3) / otherwise 0 -> path is longer or there is no path + # observation[4] : int(agent.status == RailAgentStatus.READY_TO_DEPART) + # observation[5] : int(agent.status == RailAgentStatus.ACTIVE) + # observation[6] : int(agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED) + # observation[7] : current agent is located at a switch, where it can take a routing decision + # observation[8] : current agent is located at a cell, where it has to take a stop-or-go decision + # observation[9] : current agent is located one step before/after a switch + # observation[10] : 1 if there is a path (track/branch) otherwise 0 (direction 0) + # observation[11] : 1 if there is a path (track/branch) otherwise 0 (direction 1) + # observation[12] : 1 if there is a path (track/branch) otherwise 0 (direction 2) + # observation[13] : 1 if there is a path (track/branch) otherwise 0 (direction 3) + # observation[14] : If there is a path with step (direction 0) and there is a agent with opposite direction -> 1 + # observation[15] : If there is a path with step (direction 1) and there is a agent with opposite direction -> 1 + # observation[16] : If there is a path with step (direction 2) and there is a agent with opposite direction -> 1 + # observation[17] : If there is a path with step (direction 3) and there is a agent with opposite direction -> 1 + # observation[18] : If there is a path with step (direction 0) and there is a agent with same direction -> 1 + # observation[19] : If there is a path with step (direction 1) and there is a agent with same direction -> 1 + # observation[20] : If there is a path with step (direction 2) and there is a agent with same direction -> 1 + # observation[21] : If there is a path with step (direction 3) and there is a agent with same direction -> 1 + # observation[22] : If there is a switch on the path which agent can not 
use -> 1 + # observation[23] : If there is a switch on the path which agent can not use -> 1 + # observation[24] : If there is a switch on the path which agent can not use -> 1 + # observation[25] : If there is a switch on the path which agent can not use -> 1 + # observation[26] : If there the dead-lock avoidance agent predicts a deadlock -> 1 + + if handle == 0: + self.dead_lock_avoidance_agent.start_step() + + observation = np.zeros(self.observation_dim) + visited = [] + agent = self.env.agents[handle] + + agent_done = False + if agent.status == RailAgentStatus.READY_TO_DEPART: + agent_virtual_position = agent.initial_position + observation[4] = 1 + elif agent.status == RailAgentStatus.ACTIVE: + agent_virtual_position = agent.position + observation[5] = 1 + else: + observation[6] = 1 + agent_virtual_position = (-1, -1) + agent_done = True + + if not agent_done: + visited.append(agent_virtual_position) + distance_map = self.env.distance_map.get() + current_cell_dist = distance_map[handle, + agent_virtual_position[0], agent_virtual_position[1], + agent.direction] + possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction) + orientation = agent.direction + if fast_count_nonzero(possible_transitions) == 1: + orientation = fast_argmax(possible_transitions) + + for dir_loop, branch_direction in enumerate([(orientation + dir_loop) % 4 for dir_loop in range(-1, 3)]): + if possible_transitions[branch_direction]: + new_position = get_new_position(agent_virtual_position, branch_direction) + new_cell_dist = distance_map[handle, + new_position[0], new_position[1], + branch_direction] + if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)): + observation[dir_loop] = int(new_cell_dist < current_cell_dist) + + has_opp_agent, has_same_agent, has_switch, v = self._explore(handle, new_position, branch_direction) + visited.append(v) + + observation[10 + dir_loop] = int(not np.math.isinf(new_cell_dist)) + observation[14 + dir_loop] = has_opp_agent + observation[18 + dir_loop] = has_same_agent + observation[22 + dir_loop] = has_switch + + agents_on_switch, \ + agents_near_to_switch, \ + agents_near_to_switch_all, \ + agents_on_switch_all = \ + self.check_agent_decision(agent_virtual_position, agent.direction) + observation[7] = int(agents_on_switch) + observation[8] = int(agents_near_to_switch) + observation[9] = int(agents_near_to_switch_all) + + action = self.dead_lock_avoidance_agent.act([handle], 0.0) + observation[26] = int(action == RailEnvActions.STOP_MOVING) + self.env.dev_obs_dict.update({handle: visited}) + + return observation diff --git a/utils/shortest_distance_walker.py b/utils/shortest_distance_walker.py index 0b8121f..62b686f 100644 --- a/utils/shortest_distance_walker.py +++ b/utils/shortest_distance_walker.py @@ -1,87 +1,87 @@ -import numpy as np -from flatland.core.grid.grid4_utils import get_new_position -from flatland.envs.rail_env import RailEnv, RailEnvActions -from flatland.envs.rail_env import fast_count_nonzero, fast_argmax - - -class ShortestDistanceWalker: - def __init__(self, env: RailEnv): - self.env = env - - def walk(self, handle, position, direction): - possible_transitions = self.env.rail.get_transitions(*position, direction) - num_transitions = fast_count_nonzero(possible_transitions) - if num_transitions == 1: - new_direction = fast_argmax(possible_transitions) - new_position = get_new_position(position, new_direction) - - dist = self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction] - 
return new_position, new_direction, dist, RailEnvActions.MOVE_FORWARD, possible_transitions - else: - min_distances = [] - positions = [] - directions = [] - for new_direction in [(direction + i) % 4 for i in range(-1, 2)]: - if possible_transitions[new_direction]: - new_position = get_new_position(position, new_direction) - min_distances.append( - self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction]) - positions.append(new_position) - directions.append(new_direction) - else: - min_distances.append(np.inf) - positions.append(None) - directions.append(None) - - a = self.get_action(handle, min_distances) - return positions[a], directions[a], min_distances[a], a + 1, possible_transitions - - def get_action(self, handle, min_distances): - return np.argmin(min_distances) - - def callback(self, handle, agent, position, direction, action, possible_transitions): - pass - - def get_agent_position_and_direction(self, handle): - agent = self.env.agents[handle] - if agent.position is not None: - position = agent.position - else: - position = agent.initial_position - direction = agent.direction - return position, direction - - def walk_to_target(self, handle, position=None, direction=None, max_step=500): - if position is None and direction is None: - position, direction = self.get_agent_position_and_direction(handle) - elif position is None: - position, _ = self.get_agent_position_and_direction(handle) - elif direction is None: - _, direction = self.get_agent_position_and_direction(handle) - - agent = self.env.agents[handle] - step = 0 - while (position != agent.target) and (step < max_step): - position, direction, dist, action, possible_transitions = self.walk(handle, position, direction) - if position is None: - break - self.callback(handle, agent, position, direction, action, possible_transitions) - step += 1 - - def callback_one_step(self, handle, agent, position, direction, action, possible_transitions): - pass - - def walk_one_step(self, handle): - agent = self.env.agents[handle] - if agent.position is not None: - position = agent.position - else: - position = agent.initial_position - direction = agent.direction - possible_transitions = (0, 1, 0, 0) - if (position != agent.target): - new_position, new_direction, dist, action, possible_transitions = self.walk(handle, position, direction) - if new_position is None: - return position, direction, RailEnvActions.STOP_MOVING, possible_transitions - self.callback_one_step(handle, agent, new_position, new_direction, action, possible_transitions) - return new_position, new_direction, action, possible_transitions +import numpy as np +from flatland.core.grid.grid4_utils import get_new_position +from flatland.envs.rail_env import RailEnv, RailEnvActions +from flatland.envs.rail_env import fast_count_nonzero, fast_argmax + + +class ShortestDistanceWalker: + def __init__(self, env: RailEnv): + self.env = env + + def walk(self, handle, position, direction): + possible_transitions = self.env.rail.get_transitions(*position, direction) + num_transitions = fast_count_nonzero(possible_transitions) + if num_transitions == 1: + new_direction = fast_argmax(possible_transitions) + new_position = get_new_position(position, new_direction) + + dist = self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction] + return new_position, new_direction, dist, RailEnvActions.MOVE_FORWARD, possible_transitions + else: + min_distances = [] + positions = [] + directions = [] + for new_direction in [(direction + i) % 4 for i in 
range(-1, 2)]: + if possible_transitions[new_direction]: + new_position = get_new_position(position, new_direction) + min_distances.append( + self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction]) + positions.append(new_position) + directions.append(new_direction) + else: + min_distances.append(np.inf) + positions.append(None) + directions.append(None) + + a = self.get_action(handle, min_distances) + return positions[a], directions[a], min_distances[a], a + 1, possible_transitions + + def get_action(self, handle, min_distances): + return np.argmin(min_distances) + + def callback(self, handle, agent, position, direction, action, possible_transitions): + pass + + def get_agent_position_and_direction(self, handle): + agent = self.env.agents[handle] + if agent.position is not None: + position = agent.position + else: + position = agent.initial_position + direction = agent.direction + return position, direction + + def walk_to_target(self, handle, position=None, direction=None, max_step=500): + if position is None and direction is None: + position, direction = self.get_agent_position_and_direction(handle) + elif position is None: + position, _ = self.get_agent_position_and_direction(handle) + elif direction is None: + _, direction = self.get_agent_position_and_direction(handle) + + agent = self.env.agents[handle] + step = 0 + while (position != agent.target) and (step < max_step): + position, direction, dist, action, possible_transitions = self.walk(handle, position, direction) + if position is None: + break + self.callback(handle, agent, position, direction, action, possible_transitions) + step += 1 + + def callback_one_step(self, handle, agent, position, direction, action, possible_transitions): + pass + + def walk_one_step(self, handle): + agent = self.env.agents[handle] + if agent.position is not None: + position = agent.position + else: + position = agent.initial_position + direction = agent.direction + possible_transitions = (0, 1, 0, 0) + if (position != agent.target): + new_position, new_direction, dist, action, possible_transitions = self.walk(handle, position, direction) + if new_position is None: + return position, direction, RailEnvActions.STOP_MOVING, possible_transitions + self.callback_one_step(handle, agent, new_position, new_direction, action, possible_transitions) + return new_position, new_direction, action, possible_transitions -- GitLab
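
Editor's note (not part of the patch): the files above only move code without changing behaviour, so a short usage sketch may help readers see how the pieces fit together. The snippet below is a minimal, illustrative example of wiring FastTreeObs and DeadLockAvoidanceAgent into a Flatland episode. It assumes the flatland-rl 2.2.x API used elsewhere in this repository (RailEnv, sparse_rail_generator, sparse_schedule_generator, env.reset returning obs and info, env.step returning a 4-tuple); the grid size, number of agents, and generator parameters are placeholder values chosen for the example, not values taken from this patch.

    from flatland.envs.rail_env import RailEnv
    from flatland.envs.rail_generators import sparse_rail_generator
    from flatland.envs.schedule_generators import sparse_schedule_generator

    from utils.fast_tree_obs import FastTreeObs
    from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent

    # Build a small sparse environment that uses the 27-dimensional FastTreeObs
    # observation builder introduced in utils/fast_tree_obs.py.
    env = RailEnv(width=30, height=30,
                  rail_generator=sparse_rail_generator(max_num_cities=3),
                  schedule_generator=sparse_schedule_generator(),
                  number_of_agents=5,
                  obs_builder_object=FastTreeObs(max_depth=2))
    obs, info = env.reset()

    # Use the deadlock-avoidance heuristic directly as the acting policy.
    # reset() scans the rail grid once to collect all switch cells.
    policy = DeadLockAvoidanceAgent(env)
    policy.reset()

    done = {"__all__": False}
    while not done["__all__"]:
        # start_step() rebuilds the agent-position map and the per-agent
        # shortest-distance maps before any action is queried this step.
        policy.start_step()
        actions = {handle: policy.act([handle])
                   for handle in range(env.get_num_agents())}
        obs, rewards, done, info = env.step(actions)
        policy.end_step()

The policy returns RailEnvActions.STOP_MOVING for any agent whose next step is judged unsafe by check_agent_can_move; otherwise it follows the precomputed shortest-distance walk one cell at a time.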