diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py index f7cba9b6f9c94190e556f998b05c50fc5fa3c79b..7a3525d903d323487507f991e6bcb099a436b35e 100644 --- a/reinforcement_learning/dddqn_policy.py +++ b/reinforcement_learning/dddqn_policy.py @@ -55,11 +55,13 @@ class DDDQNPolicy(Policy): self.qnetwork_target = copy.deepcopy(self.qnetwork_local) self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=self.learning_rate) self.memory = ReplayBuffer(action_size, self.buffer_size, self.batch_size, self.device) - self.t_step = 0 self.loss = 0.0 + else: + self.memory = ReplayBuffer(action_size, 1, 1, self.device) + self.loss = 0.0 - def act(self, state, eps=0.): + def act(self, handle, state, eps=0.): state = torch.from_numpy(state).float().unsqueeze(0).to(self.device) self.qnetwork_local.eval() with torch.no_grad(): @@ -151,7 +153,7 @@ class DDDQNPolicy(Policy): self.memory.memory = pickle.load(f) def test(self): - self.act(np.array([[0] * self.state_size])) + self.act(0, np.array([[0] * self.state_size])) self._learn() def clone(self): diff --git a/reinforcement_learning/evaluate_agent.py b/reinforcement_learning/evaluate_agent.py index 64eb9433a9df00457d698e5873b31a16712de718..5488f81eae52753a071ef18142a5514579dd4c5c 100644 --- a/reinforcement_learning/evaluate_agent.py +++ b/reinforcement_learning/evaluate_agent.py @@ -26,7 +26,8 @@ from utils.observation_utils import normalize_observation from reinforcement_learning.dddqn_policy import DDDQNPolicy -def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size, state_size, seed, render, allow_skipping, allow_caching): +def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size, state_size, seed, render, + allow_skipping, allow_caching): # Evaluation is faster on CPU (except if you use a really huge policy) parameters = { 'use_gpu': False @@ -140,11 +141,12 @@ def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size, else: preproc_timer.start() - norm_obs = normalize_observation(obs[agent], tree_depth=observation_tree_depth, observation_radius=observation_radius) + norm_obs = normalize_observation(obs[agent], tree_depth=observation_tree_depth, + observation_radius=observation_radius) preproc_timer.end() inference_timer.start() - action = policy.act(norm_obs, eps=0.0) + action = policy.act(agent, norm_obs, eps=0.0) inference_timer.end() action_dict.update({agent: action}) @@ -319,12 +321,15 @@ def evaluate_agents(file, n_evaluation_episodes, use_gpu, render, allow_skipping results = [] if render: - results.append(eval_policy(params, file, eval_per_thread, max_steps, action_size, state_size, 0, render, allow_skipping, allow_caching)) + results.append( + eval_policy(params, file, eval_per_thread, max_steps, action_size, state_size, 0, render, allow_skipping, + allow_caching)) else: with Pool() as p: results = p.starmap(eval_policy, - [(params, file, 1, max_steps, action_size, state_size, seed * nb_threads, render, allow_skipping, allow_caching) + [(params, file, 1, max_steps, action_size, state_size, seed * nb_threads, render, + allow_skipping, allow_caching) for seed in range(total_nb_eval)]) @@ -367,10 +372,12 @@ if __name__ == "__main__": parser.add_argument("--use_gpu", dest="use_gpu", help="use GPU if available", action='store_true') parser.add_argument("--render", help="render a single episode", action='store_true') - parser.add_argument("--allow_skipping", help="skips to the end of the episode if all agents are deadlocked", 
action='store_true') + parser.add_argument("--allow_skipping", help="skips to the end of the episode if all agents are deadlocked", + action='store_true') parser.add_argument("--allow_caching", help="caches the last observation-action pair", action='store_true') args = parser.parse_args() os.environ["OMP_NUM_THREADS"] = str(1) - evaluate_agents(file=args.file, n_evaluation_episodes=args.n_evaluation_episodes, use_gpu=args.use_gpu, render=args.render, + evaluate_agents(file=args.file, n_evaluation_episodes=args.n_evaluation_episodes, use_gpu=args.use_gpu, + render=args.render, allow_skipping=args.allow_skipping, allow_caching=args.allow_caching) diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py index dac599225c528dbd0a5f84c729cbb2568f942c51..e13584f491392974251b2134e29d9a5736dbac93 100755 --- a/reinforcement_learning/multi_agent_training.py +++ b/reinforcement_learning/multi_agent_training.py @@ -21,6 +21,8 @@ from torch.utils.tensorboard import SummaryWriter from reinforcement_learning.dddqn_policy import DDDQNPolicy from reinforcement_learning.ppo_agent import PPOAgent +from reinforcement_learning.ppo_deadlockavoidance_agent import MultiDecisionAgent +from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent from utils.deadlock_check import get_agent_positions, check_for_deadlock base_dir = Path(__file__).resolve().parent.parent @@ -174,6 +176,11 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params): policy = DDDQNPolicy(state_size, action_size, train_params) if True: policy = PPOAgent(state_size, action_size) + if False: + policy = DeadLockAvoidanceAgent(train_env, action_size) + if True: + policy = MultiDecisionAgent(train_env, state_size, action_size, policy) + # Load existing policy if train_params.load_policy is not "": policy.load(train_params.load_policy) @@ -226,7 +233,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params): train_env = create_rail_env(train_env_params, tree_observation) obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True) - policy.reset() + policy.reset(train_env) reset_timer.end() if train_params.render: @@ -261,8 +268,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params): agent = train_env.agents[agent_handle] if info['action_required'][agent_handle]: update_values[agent_handle] = True - action = policy.act(agent_obs[agent_handle], eps=eps_start) - + action = policy.act(agent_handle, agent_obs[agent_handle], eps=eps_start) action_count[action] += 1 actions_taken.append(action) else: @@ -288,7 +294,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params): all_rewards[agent_handle] = 0.0 if done[agent_handle] == False: if check_for_deadlock(agent_handle, train_env, agent_positions): - all_rewards[agent_handle] = -1.0 + all_rewards[agent_handle] = -5.0 else: pos = agent.position possible_transitions = train_env.rail.get_transitions(*pos, agent.direction) @@ -471,6 +477,7 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params): score = 0.0 obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True) + policy.reset(env) final_step = 0 policy.start_episode(train=False) @@ -484,7 +491,7 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params): action = 0 if info['action_required'][agent]: if tree_observation.check_is_observation_valid(agent_obs[agent]): - action = policy.act(agent_obs[agent], eps=0.0) + action = 
policy.act(agent, agent_obs[agent], eps=0.0) action_dict.update({agent: action}) policy.end_step(train=False) obs, all_rewards, done, info = env.step(action_dict) diff --git a/reinforcement_learning/multi_policy.py b/reinforcement_learning/multi_policy.py index 87763b1ab0e528627247f246ce734b7ddcbe55ab..0c2ae32144216bef5a95cc2ec43eb6ac027bfc7f 100644 --- a/reinforcement_learning/multi_policy.py +++ b/reinforcement_learning/multi_policy.py @@ -1,4 +1,5 @@ import numpy as np +from flatland.envs.rail_env import RailEnv from reinforcement_learning.policy import Policy from reinforcement_learning.ppo_agent import PPOAgent @@ -45,9 +46,9 @@ class MultiPolicy(Policy): self.loss = self.ppo_policy.loss return action_ppo - def reset(self): - self.ppo_policy.reset() - self.deadlock_avoidance_policy.reset() + def reset(self, env: RailEnv): + self.ppo_policy.reset(env) + self.deadlock_avoidance_policy.reset(env) def test(self): self.ppo_policy.test() diff --git a/reinforcement_learning/ordered_policy.py b/reinforcement_learning/ordered_policy.py index daf6639d33052eedc5b69481e84413edea552eee..2db171d2e1429a085488b02f9818ba75c57b2694 100644 --- a/reinforcement_learning/ordered_policy.py +++ b/reinforcement_learning/ordered_policy.py @@ -15,7 +15,7 @@ class OrderedPolicy(Policy): def __init__(self): self.action_size = 5 - def act(self, state, eps=0.): + def act(self, handle, state, eps=0.): _, distance, _ = split_tree_into_feature_groups(state, 1) distance = distance[1:] min_dist = min_gt(distance, 0) diff --git a/reinforcement_learning/policy.py b/reinforcement_learning/policy.py index 45889da1780a9188d85cedcd6c40c899fe088c51..5b118aee15253d7dfb86c04925ea8a058abdbf2d 100644 --- a/reinforcement_learning/policy.py +++ b/reinforcement_learning/policy.py @@ -1,10 +1,11 @@ -import torch.nn as nn +from flatland.envs.rail_env import RailEnv + class Policy: def step(self, handle, state, action, reward, next_state, done): raise NotImplementedError - def act(self, state, eps=0.): + def act(self, handle, state, eps=0.): raise NotImplementedError def save(self, filename): @@ -13,16 +14,16 @@ class Policy: def load(self, filename): raise NotImplementedError - def start_step(self,train): + def start_step(self, train): pass - def end_step(self,train): + def end_step(self, train): pass - def start_episode(self,train): + def start_episode(self, train): pass - def end_episode(self,train): + def end_episode(self, train): pass def load_replay_buffer(self, filename): @@ -31,8 +32,8 @@ class Policy: def test(self): pass - def reset(self): + def reset(self, env: RailEnv): pass def clone(self): - return self \ No newline at end of file + return self diff --git a/reinforcement_learning/ppo_agent.py b/reinforcement_learning/ppo_agent.py index ee179d19bf49fb7f3fe2d2ed47eb3b9123d257e4..a4e74ec884fd99288e45353a56d52cf31a27555f 100644 --- a/reinforcement_learning/ppo_agent.py +++ b/reinforcement_learning/ppo_agent.py @@ -9,7 +9,7 @@ from torch.distributions import Categorical # Hyperparameters from reinforcement_learning.policy import Policy -device = torch.device("cpu")#"cuda:0" if torch.cuda.is_available() else "cpu") +device = torch.device("cpu") # "cuda:0" if torch.cuda.is_available() else "cpu") print("device:", device) @@ -111,10 +111,10 @@ class PPOAgent(Policy): self.optimizer = optim.Adam(self.actor_critic_model.parameters(), lr=self.learning_rate) self.loss_function = nn.SmoothL1Loss() # nn.MSELoss() - def reset(self): + def reset(self, env): pass - def act(self, state, eps=None): + def act(self, handle, state, eps=None): # 
sample a action to take torch_state = torch.tensor(state, dtype=torch.float).to(device) dist = self.actor_critic_model.get_actor_dist(torch_state) @@ -148,10 +148,8 @@ class PPOAgent(Policy): reward_i = 1 else: done_list.insert(0, 0) - if reward_i < -1: - reward_i = -1 - else: - reward_i = 0 + reward_i = 0 + discounted_reward = reward_i + self.gamma * discounted_reward reward_list.insert(0, discounted_reward) state_next_list.insert(0, state_next_i) diff --git a/reinforcement_learning/ppo_deadlockavoidance_agent.py b/reinforcement_learning/ppo_deadlockavoidance_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..a3cf21638a4f04fba1b91e4cacbd668b62ce5996 --- /dev/null +++ b/reinforcement_learning/ppo_deadlockavoidance_agent.py @@ -0,0 +1,81 @@ +from flatland.envs.agent_utils import RailAgentStatus +from flatland.envs.rail_env import RailEnv, RailEnvActions + +from reinforcement_learning.policy import Policy +from utils.agent_can_choose_helper import AgentCanChooseHelper +from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent + + +class MultiDecisionAgent(Policy): + + def __init__(self, env: RailEnv, state_size, action_size, learning_agent): + self.env = env + self.state_size = state_size + self.action_size = action_size + self.learning_agent = learning_agent + self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, action_size, False) + self.agent_can_choose_helper = AgentCanChooseHelper() + self.memory = self.learning_agent.memory + self.loss = self.learning_agent.loss + + def step(self, handle, state, action, reward, next_state, done): + self.dead_lock_avoidance_agent.step(handle, state, action, reward, next_state, done) + self.learning_agent.step(handle, state, action, reward, next_state, done) + self.loss = self.learning_agent.loss + + def act(self, handle, state, eps=0.): + agent = self.env.agents[handle] + position = agent.position + if position is None: + position = agent.initial_position + direction = agent.direction + if agent.status < RailAgentStatus.DONE: + agents_on_switch, agents_near_to_switch, _, _ = \ + self.agent_can_choose_helper.check_agent_decision(position, direction) + if agents_on_switch or agents_near_to_switch: + return self.learning_agent.act(handle, state, eps) + else: + return self.dead_lock_avoidance_agent.act(handle, state, -1.0) + # Agent is still at target cell + return RailEnvActions.DO_NOTHING + + def save(self, filename): + self.dead_lock_avoidance_agent.save(filename) + self.learning_agent.save(filename) + + def load(self, filename): + self.dead_lock_avoidance_agent.load(filename) + self.learning_agent.load(filename) + + def start_step(self, train): + self.dead_lock_avoidance_agent.start_step(train) + self.learning_agent.start_step(train) + + def end_step(self, train): + self.dead_lock_avoidance_agent.end_step(train) + self.learning_agent.end_step(train) + + def start_episode(self, train): + self.dead_lock_avoidance_agent.start_episode(train) + self.learning_agent.start_episode(train) + + def end_episode(self, train): + self.dead_lock_avoidance_agent.end_episode(train) + self.learning_agent.end_episode(train) + + def load_replay_buffer(self, filename): + self.dead_lock_avoidance_agent.load_replay_buffer(filename) + self.learning_agent.load_replay_buffer(filename) + + def test(self): + self.dead_lock_avoidance_agent.test() + self.learning_agent.test() + + def reset(self, env: RailEnv): + self.env = env + self.agent_can_choose_helper.build_data(env) + self.dead_lock_avoidance_agent.reset(env) + 
self.learning_agent.reset(env) + + def clone(self): + return self diff --git a/reinforcement_learning/sequential_agent.py b/reinforcement_learning/sequential_agent.py index 3bb5a73cdc42f33a5e771eeaf530cf4af9742be8..e2055a69576454a0252a24e21408db0f04131da0 100644 --- a/reinforcement_learning/sequential_agent.py +++ b/reinforcement_learning/sequential_agent.py @@ -1,13 +1,13 @@ import sys -import numpy as np +from pathlib import Path +import numpy as np from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool -from pathlib import Path base_dir = Path(__file__).resolve().parent.parent sys.path.append(str(base_dir)) @@ -73,7 +73,7 @@ for trials in range(1, n_episodes + 1): if done[a]: acting_agent += 1 if a == acting_agent: - action = policy.act(obs[a]) + action = policy.act(a, obs[a]) else: action = 4 action_dict.update({a: action}) diff --git a/reinforcement_learning/sequential_agent_training.py b/reinforcement_learning/sequential_agent_training.py index ca19d1fcbbb4e3508a16b847d4b4cfcefc6aad98..d1ddd4348a462a9b7c17d6dae36c780acff1fd8b 100644 --- a/reinforcement_learning/sequential_agent_training.py +++ b/reinforcement_learning/sequential_agent_training.py @@ -1,13 +1,13 @@ import sys -import numpy as np +from pathlib import Path +import numpy as np from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator from flatland.envs.schedule_generators import complex_schedule_generator from flatland.utils.rendertools import RenderTool -from pathlib import Path base_dir = Path(__file__).resolve().parent.parent sys.path.append(str(base_dir)) @@ -66,7 +66,7 @@ for trials in range(1, n_episodes + 1): if done[a]: acting_agent += 1 if a == acting_agent: - action = policy.act(obs[a]) + action = policy.act(a, obs[a]) else: action = 4 action_dict.update({a: action}) diff --git a/reinforcement_learning/single_agent_training.py b/reinforcement_learning/single_agent_training.py index bfcc88656c8b37a8c09e72b51701d0750cf7f238..dda07a9db5b6da3c2185f65d259fa0a9cf549c50 100644 --- a/reinforcement_learning/single_agent_training.py +++ b/reinforcement_learning/single_agent_training.py @@ -123,7 +123,8 @@ def train_agent(n_episodes): # Build agent specific observations for agent in env.get_agent_handles(): if obs[agent]: - agent_obs[agent] = normalize_observation(obs[agent], observation_tree_depth, observation_radius=observation_radius) + agent_obs[agent] = normalize_observation(obs[agent], observation_tree_depth, + observation_radius=observation_radius) agent_prev_obs[agent] = agent_obs[agent].copy() # Run episode @@ -132,7 +133,7 @@ def train_agent(n_episodes): if info['action_required'][agent]: # If an action is required, we want to store the obs at that step as well as the action update_values = True - action = policy.act(agent_obs[agent], eps=eps_start) + action = policy.act(agent, agent_obs[agent], eps=eps_start) action_count[action] += 1 else: update_values = False @@ -154,7 +155,8 @@ def train_agent(n_episodes): agent_prev_action[agent] = action_dict[agent] if next_obs[agent]: - agent_obs[agent] = normalize_observation(next_obs[agent], 
observation_tree_depth, observation_radius=10) + agent_obs[agent] = normalize_observation(next_obs[agent], observation_tree_depth, + observation_radius=10) score += all_rewards[agent] @@ -179,15 +181,16 @@ def train_agent(n_episodes): else: end = " " - print('\rTraining {} agents on {}x{}\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( - env.get_num_agents(), - x_dim, y_dim, - episode_idx, - np.mean(scores_window), - 100 * np.mean(completion_window), - eps_start, - action_probs - ), end=end) + print( + '\rTraining {} agents on {}x{}\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( + env.get_num_agents(), + x_dim, y_dim, + episode_idx, + np.mean(scores_window), + 100 * np.mean(completion_window), + eps_start, + action_probs + ), end=end) # Plot overall training progress at the end plt.plot(scores) @@ -199,7 +202,8 @@ def train_agent(n_episodes): if __name__ == "__main__": parser = ArgumentParser() - parser.add_argument("-n", "--n_episodes", dest="n_episodes", help="number of episodes to run", default=500, type=int) + parser.add_argument("-n", "--n_episodes", dest="n_episodes", help="number of episodes to run", default=500, + type=int) args = parser.parse_args() train_agent(args.n_episodes) diff --git a/run.py b/run.py index 8eb8f8109c498a2cf0f9bd27a9174a577a41e240..1b1d11fd79aa3aef4a87e5e043f319e0c507edf1 100644 --- a/run.py +++ b/run.py @@ -31,6 +31,7 @@ from flatland.evaluators.client import FlatlandRemoteClient from flatland.evaluators.client import TimeoutException from reinforcement_learning.ppo_agent import PPOAgent +from reinforcement_learning.ppo_deadlockavoidance_agent import MultiDecisionAgent from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent from utils.deadlock_check import check_if_all_blocked from utils.fast_tree_obs import FastTreeObs @@ -50,20 +51,23 @@ USE_FAST_TREEOBS = True USE_PPO_AGENT = True # Checkpoint to use (remember to push it!) 
-checkpoint = "./checkpoints/201124171810-7800.pth" # DDDQN: 18.249244799876152 DEPTH=2 AGENTS=10 +checkpoint = "./checkpoints/201124171810-7800.pth" # DDDQN: 18.249244799876152 DEPTH=2 AGENTS=10 # checkpoint = "./checkpoints/201126150143-5200.pth" # DDDQN: 18.249244799876152 DEPTH=2 AGENTS=10 # checkpoint = "./checkpoints/201126160144-2000.pth" # DDDQN: 18.249244799876152 DEPTH=2 AGENTS=10 -checkpoint = "./checkpoints/201207144650-20000.pth" # PPO: 14.45790721540786 -checkpoint = "./checkpoints/201211063511-6300.pth" # DDDQN: 16.948349308440857 -checkpoint = "./checkpoints/201211095604-12000.pth" # DDDQN: 17.3862941316504 -checkpoint = "./checkpoints/201211164554-9400.pth" # DDDQN: 16.09241366013537 -checkpoint = "./checkpoints/201213181400-6800.pth" # PPO: 13.944402986414723 +checkpoint = "./checkpoints/201207144650-20000.pth" # PPO: 14.45790721540786 +checkpoint = "./checkpoints/201211063511-6300.pth" # DDDQN: 16.948349308440857 +checkpoint = "./checkpoints/201211095604-12000.pth" # DDDQN: 17.3862941316504 +checkpoint = "./checkpoints/201211164554-9400.pth" # DDDQN: 16.09241366013537 +checkpoint = "./checkpoints/201213181400-6800.pth" # PPO: 13.944402986414723 +checkpoint = "./checkpoints/201214140158-5000.pth" # USE_MULTI_DECISION_AGENT with DDDQN: 13.944402986414723 +checkpoint = "./checkpoints/201214160604-3000.pth" # USE_MULTI_DECISION_AGENT with DDDQN: 13.944402986414723 EPSILON = 0.0 # Use last action cache USE_ACTION_CACHE = False USE_DEAD_LOCK_AVOIDANCE_AGENT = False # 21.54485505223213 +USE_MULTI_DECISION_AGENT = True # Observation parameters (must match training parameters!) observation_tree_depth = 2 @@ -106,10 +110,10 @@ action_size = 5 # Creates the policy. No GPU on evaluation server. if not USE_PPO_AGENT: - policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True) + trained_policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True) else: - policy = PPOAgent(state_size, action_size) -policy.load(checkpoint) + trained_policy = PPOAgent(state_size, action_size) +trained_policy.load(checkpoint) ##################################################################### # Main evaluation loop @@ -144,6 +148,11 @@ while True: tree_observation.set_env(local_env) tree_observation.reset() + + policy = trained_policy + if USE_MULTI_DECISION_AGENT: + policy = MultiDecisionAgent(local_env, state_size, action_size, trained_policy) + policy.reset(local_env) observation = tree_observation.get_many(list(range(nb_agents))) print("Evaluation {}: {} agents in {}x{}".format(evaluation_number, nb_agents, local_env.width, local_env.height)) @@ -199,7 +208,7 @@ while True: observation_tree_depth, observation_radius=observation_radius) - action = policy.act(normalized_observation, eps=EPSILON) + action = policy.act(agent_handle, normalized_observation, eps=EPSILON) action_dict[agent_handle] = action diff --git a/utils/agent_can_choose_helper.py b/utils/agent_can_choose_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..95636c6a13dc847e798ff283bfb4ccf8fc871a64 --- /dev/null +++ b/utils/agent_can_choose_helper.py @@ -0,0 +1,107 @@ +from flatland.core.grid.grid4_utils import get_new_position +from flatland.envs.agent_utils import RailAgentStatus +from flatland.envs.rail_env import fast_count_nonzero + + +class AgentCanChooseHelper: + def __init__(self): + pass + + def build_data(self, env): + self.env = env + if self.env is not None: + self.env.dev_obs_dict = {} + self.switches = {} + 
self.switches_neighbours = {} + if self.env is not None: + self.find_all_cell_where_agent_can_choose() + + def find_all_switches(self): + # Search the environment (rail grid) for all switch cells. A switch is a cell where more than one transition + exists; collect every direction in which the cell acts as a switch. + self.switches = {} + for h in range(self.env.height): + for w in range(self.env.width): + pos = (h, w) + for dir in range(4): + possible_transitions = self.env.rail.get_transitions(*pos, dir) + num_transitions = fast_count_nonzero(possible_transitions) + if num_transitions > 1: + if pos not in self.switches.keys(): + self.switches.update({pos: [dir]}) + else: + self.switches[pos].append(dir) + + def find_all_switch_neighbours(self): + # Collect all cells that are neighbours of a switch cell, i.e. cells from which the agent can reach + a switch in a single step. A switch is a cell where the agent has more than one transition. + self.switches_neighbours = {} + for h in range(self.env.height): + for w in range(self.env.width): + # look one step forward + for dir in range(4): + pos = (h, w) + possible_transitions = self.env.rail.get_transitions(*pos, dir) + for d in range(4): + if possible_transitions[d] == 1: + new_cell = get_new_position(pos, d) + if new_cell in self.switches.keys() and pos not in self.switches.keys(): + if pos not in self.switches_neighbours.keys(): + self.switches_neighbours.update({pos: [dir]}) + else: + self.switches_neighbours[pos].append(dir) + + def find_all_cell_where_agent_can_choose(self): + # prepare the memory - collect all cells where the agent can choose more than FORWARD/STOP. + self.find_all_switches() + self.find_all_switch_neighbours() + + def check_agent_decision(self, position, direction): + # Decide whether the agent is + # - on a switch + # - at a switch neighbour (near to switch). 
The switch must be a switch where the agent has more option than + # FORWARD/STOP + # - all switch : doesn't matter whether the agent has more options than FORWARD/STOP + # - all switch neightbors : doesn't matter the agent has more then one options (transistion) when he reach the + # switch + agents_on_switch = False + agents_on_switch_all = False + agents_near_to_switch = False + agents_near_to_switch_all = False + if position in self.switches.keys(): + agents_on_switch = direction in self.switches[position] + agents_on_switch_all = True + + if position in self.switches_neighbours.keys(): + new_cell = get_new_position(position, direction) + if new_cell in self.switches.keys(): + if not direction in self.switches[new_cell]: + agents_near_to_switch = direction in self.switches_neighbours[position] + else: + agents_near_to_switch = direction in self.switches_neighbours[position] + + agents_near_to_switch_all = direction in self.switches_neighbours[position] + + return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all + + def required_agent_decision(self): + agents_can_choose = {} + agents_on_switch = {} + agents_on_switch_all = {} + agents_near_to_switch = {} + agents_near_to_switch_all = {} + for a in range(self.env.get_num_agents()): + ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all, ret_agents_on_switch_all = \ + self.check_agent_decision( + self.env.agents[a].position, + self.env.agents[a].direction) + agents_on_switch.update({a: ret_agents_on_switch}) + agents_on_switch_all.update({a: ret_agents_on_switch_all}) + ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART + agents_near_to_switch.update({a: (ret_agents_near_to_switch and not ready_to_depart)}) + + agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]}) + + agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all and not ready_to_depart)}) + + return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py index 07840db7b28505ea228db35e4e10f961c4015313..286718ea4a86ae75f9d32e72476f5e48f14a558f 100644 --- a/utils/dead_lock_avoidance_agent.py +++ b/utils/dead_lock_avoidance_agent.py @@ -66,10 +66,18 @@ class DeadlockAvoidanceShortestDistanceWalker(ShortestDistanceWalker): self.full_shortest_distance_agent_map[(handle, position[0], position[1])] = 1 +class DummyMemory: + def __init__(self): + self.memory = [] + + def __len__(self): + return 0 + + class DeadLockAvoidanceAgent(Policy): def __init__(self, env: RailEnv, action_size, show_debug_plot=False): self.env = env - self.memory = None + self.memory = DummyMemory() self.loss = 0 self.action_size = action_size self.agent_can_move = {} @@ -77,16 +85,16 @@ class DeadLockAvoidanceAgent(Policy): self.switches = {} self.show_debug_plot = show_debug_plot - def step(self, state, action, reward, next_state, done): + def step(self, handle, state, action, reward, next_state, done): pass - def act(self, state, eps=0.): + def act(self, handle, state, eps=0.): # Epsilon-greedy action selection if np.random.random() < eps: return np.random.choice(np.arange(self.action_size)) # agent = self.env.agents[state[0]] - check = self.agent_can_move.get(state[0], None) + check = self.agent_can_move.get(handle, None) if check is None: return RailEnvActions.STOP_MOVING return check[3] @@ -94,7 +102,8 @@ class DeadLockAvoidanceAgent(Policy): def 
get_agent_can_move_value(self, handle): return self.agent_can_move_value.get(handle, np.inf) - def reset(self): + def reset(self, env): + self.env = env self.agent_positions = None self.shortest_distance_walker = None self.switches = {} diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py index c45477db4dcb91ebbea36b729ebc6ff142bec000..8172703b51cf754981c03e487c08f51ea400aec0 100755 --- a/utils/fast_tree_obs.py +++ b/utils/fast_tree_obs.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Any import numpy as np from flatland.core.env_observation_builder import ObservationBuilder @@ -6,6 +6,7 @@ from flatland.core.grid.grid4_utils import get_new_position from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.rail_env import fast_count_nonzero, fast_argmax, RailEnvActions +from utils.agent_can_choose_helper import AgentCanChooseHelper from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent from utils.deadlock_check import check_for_deadlock, get_agent_positions @@ -24,116 +25,14 @@ Author: Adrian Egli (adrian.egli@gmail.com) class FastTreeObs(ObservationBuilder): - def __init__(self, max_depth): + def __init__(self, max_depth: Any): self.max_depth = max_depth self.observation_dim = 41 - - def build_data(self): - if self.env is not None: - self.env.dev_obs_dict = {} - self.switches = {} - self.switches_neighbours = {} - self.debug_render_list = [] - self.debug_render_path_list = [] - if self.env is not None: - self.find_all_cell_where_agent_can_choose() - self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, 5, False) - else: - self.dead_lock_avoidance_agent = None - - def find_all_switches(self): - # Search the environment (rail grid) for all switch cells. A switch is a cell where more than one tranisation - # exists and collect all direction where the switch is a switch. - self.switches = {} - for h in range(self.env.height): - for w in range(self.env.width): - pos = (h, w) - for dir in range(4): - possible_transitions = self.env.rail.get_transitions(*pos, dir) - num_transitions = fast_count_nonzero(possible_transitions) - if num_transitions > 1: - if pos not in self.switches.keys(): - self.switches.update({pos: [dir]}) - else: - self.switches[pos].append(dir) - - def find_all_switch_neighbours(self): - # Collect all cells where is a neighbour to a switch cell. All cells are neighbour where the agent can make - # just one step and he stands on a switch. A switch is a cell where the agents has more than one transition. - self.switches_neighbours = {} - for h in range(self.env.height): - for w in range(self.env.width): - # look one step forward - for dir in range(4): - pos = (h, w) - possible_transitions = self.env.rail.get_transitions(*pos, dir) - for d in range(4): - if possible_transitions[d] == 1: - new_cell = get_new_position(pos, d) - if new_cell in self.switches.keys() and pos not in self.switches.keys(): - if pos not in self.switches_neighbours.keys(): - self.switches_neighbours.update({pos: [dir]}) - else: - self.switches_neighbours[pos].append(dir) - - def find_all_cell_where_agent_can_choose(self): - # prepare the memory - collect all cells where the agent can choose more than FORWARD/STOP. - self.find_all_switches() - self.find_all_switch_neighbours() - - def check_agent_decision(self, position, direction): - # Decide whether the agent is - # - on a switch - # - at a switch neighbour (near to switch). 
The switch must be a switch where the agent has more option than - # FORWARD/STOP - # - all switch : doesn't matter whether the agent has more options than FORWARD/STOP - # - all switch neightbors : doesn't matter the agent has more then one options (transistion) when he reach the - # switch - agents_on_switch = False - agents_on_switch_all = False - agents_near_to_switch = False - agents_near_to_switch_all = False - if position in self.switches.keys(): - agents_on_switch = direction in self.switches[position] - agents_on_switch_all = True - - if position in self.switches_neighbours.keys(): - new_cell = get_new_position(position, direction) - if new_cell in self.switches.keys(): - if not direction in self.switches[new_cell]: - agents_near_to_switch = direction in self.switches_neighbours[position] - else: - agents_near_to_switch = direction in self.switches_neighbours[position] - - agents_near_to_switch_all = direction in self.switches_neighbours[position] - - return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all - - def required_agent_decision(self): - agents_can_choose = {} - agents_on_switch = {} - agents_on_switch_all = {} - agents_near_to_switch = {} - agents_near_to_switch_all = {} - for a in range(self.env.get_num_agents()): - ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all, ret_agents_on_switch_all = \ - self.check_agent_decision( - self.env.agents[a].position, - self.env.agents[a].direction) - agents_on_switch.update({a: ret_agents_on_switch}) - agents_on_switch_all.update({a: ret_agents_on_switch_all}) - ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART - agents_near_to_switch.update({a: (ret_agents_near_to_switch and not ready_to_depart)}) - - agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]}) - - agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all and not ready_to_depart)}) - - return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all + self.agent_can_choose_helper = None def debug_render(self, env_renderer): agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \ - self.required_agent_decision() + self.agent_can_choose_helper.required_agent_decision() self.env.dev_obs_dict = {} for a in range(max(3, self.env.get_num_agents())): self.env.dev_obs_dict.update({a: []}) @@ -156,13 +55,20 @@ class FastTreeObs(ObservationBuilder): env_renderer.gl.agent_colors[3] = env_renderer.gl.rgb_s2i("550000") self.env.dev_obs_dict[0] = self.debug_render_list - self.env.dev_obs_dict[1] = self.switches.keys() - self.env.dev_obs_dict[2] = self.switches_neighbours.keys() + self.env.dev_obs_dict[1] = self.agent_can_choose_helper.switches.keys() + self.env.dev_obs_dict[2] = self.agent_can_choose_helper.switches_neighbours.keys() self.env.dev_obs_dict[3] = self.debug_render_path_list def reset(self): - self.build_data() - return + if self.agent_can_choose_helper is None: + self.agent_can_choose_helper = AgentCanChooseHelper() + self.agent_can_choose_helper.build_data(self.env) + self.debug_render_list = [] + self.debug_render_path_list = [] + if self.env is not None: + self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, 5, False) + else: + self.dead_lock_avoidance_agent = None def _explore(self, handle, new_position, new_direction, distance_map, depth=0): has_opp_agent = 0 @@ -201,7 +107,7 @@ class FastTreeObs(ObservationBuilder): # agent_near_to_switch == 
TRUE -> One cell before the switch, where the agent can decide # agents_on_switch, agents_near_to_switch, _, _ = \ - self.check_agent_decision(new_position, new_direction) + self.agent_can_choose_helper.check_agent_decision(new_position, new_direction) if agents_near_to_switch: # The exploration was walking on a path where the agent can not decide @@ -250,7 +156,7 @@ class FastTreeObs(ObservationBuilder): self.dead_lock_avoidance_agent.end_step(train=False) return observations - def get(self, handle): + def get(self, handle: int = 0): # all values are [0,1] # observation[0] : 1 path towards target (direction 0) / otherwise 0 -> path is longer or there is no path # observation[1] : 1 path towards target (direction 1) / otherwise 0 -> path is longer or there is no path @@ -340,14 +246,14 @@ class FastTreeObs(ObservationBuilder): agents_near_to_switch, \ agents_near_to_switch_all, \ agents_on_switch_all = \ - self.check_agent_decision(agent_virtual_position, agent.direction) + self.agent_can_choose_helper.check_agent_decision(agent_virtual_position, agent.direction) observation[7] = int(agents_on_switch) observation[8] = int(agents_on_switch_all) observation[9] = int(agents_near_to_switch) observation[10] = int(agents_near_to_switch_all) - action = self.dead_lock_avoidance_agent.act([handle], 0.0) + action = self.dead_lock_avoidance_agent.act(handle, None, 0.0) observation[35] = int(action == RailEnvActions.STOP_MOVING) observation[40] = int(check_for_deadlock(handle, self.env, self.agent_positions)) diff --git a/utils/shortest_path_walker_heuristic_agent.py b/utils/shortest_path_walker_heuristic_agent.py index eaa71e91a416b0c899519e690c4e29ad8147a48d..d2cbab04f407edeae3fba5030a0b7b3309560cfc 100644 --- a/utils/shortest_path_walker_heuristic_agent.py +++ b/utils/shortest_path_walker_heuristic_agent.py @@ -8,7 +8,7 @@ class ShortestPathWalkerHeuristicPolicy(Policy): def step(self, state, action, reward, next_state, done): pass - def act(self, node, eps=0.): + def act(self, handle, node, eps=0.): left_node = node.childs.get('L') forward_node = node.childs.get('F')
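Usage note (illustrative sketch, not part of the patch): the snippet below shows how the handle-aware policy interface introduced above (act(handle, state, eps), reset(env)) and the new MultiDecisionAgent wrapper fit together at evaluation time, mirroring the wiring in run.py. It assumes an already constructed Flatland RailEnv named env that uses the FastTreeObs observation builder (so obs[handle] is the flat 41-dimensional observation vector) and that the repository's reinforcement_learning package is on the import path.

# Illustrative sketch only -- not part of the diff above. Assumes an existing
# Flatland RailEnv `env` built with the FastTreeObs observation builder.
from argparse import Namespace

from reinforcement_learning.dddqn_policy import DDDQNPolicy
from reinforcement_learning.ppo_deadlockavoidance_agent import MultiDecisionAgent

state_size = 41   # FastTreeObs.observation_dim used throughout this repo
action_size = 5   # Flatland rail actions

# Underlying learning policy (evaluation mode, CPU), created as in run.py.
learning_policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
# learning_policy.load("./checkpoints/201214160604-3000.pth")  # optionally load trained weights

obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)

# Wrap the learned policy: on or near a switch the learning agent decides,
# everywhere else the embedded DeadLockAvoidanceAgent drives the agent.
policy = MultiDecisionAgent(env, state_size, action_size, learning_policy)
policy.reset(env)  # reset() now receives the environment

policy.start_episode(train=False)
policy.start_step(train=False)
action_dict = {}
for handle in env.get_agent_handles():
    # act() now takes the agent handle as its first argument
    action_dict[handle] = policy.act(handle, obs[handle], eps=0.0)
policy.end_step(train=False)
obs, all_rewards, done, info = env.step(action_dict)
policy.end_episode(train=False)

The delegation itself is the point of the patch: MultiDecisionAgent only consults the learned policy where a decision actually matters (AgentCanChooseHelper.check_agent_decision reports the agent on or next to a switch) and otherwise falls back to the deadlock-avoidance walker, which is why both run.py and multi_agent_training.py now pass the agent handle and the environment into the policy.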