diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index e13584f491392974251b2134e29d9a5736dbac93..3e461fd9fe3e13aa341c8e4aab1cde6af2854bf7 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -24,6 +24,7 @@ from reinforcement_learning.ppo_agent import PPOAgent
 from reinforcement_learning.ppo_deadlockavoidance_agent import MultiDecisionAgent
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 from utils.deadlock_check import get_agent_positions, check_for_deadlock
+from utils.agent_action_config import get_flatland_full_action_size, get_action_size, map_actions, map_action
 
 base_dir = Path(__file__).resolve().parent.parent
 sys.path.append(str(base_dir))
@@ -155,10 +156,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     # Calculate the state size given the depth of the tree observation and the number of features
     state_size = tree_observation.observation_dim
 
-    # The action space of flatland is 5 discrete actions
-    action_size = 5
-
-    action_count = [0] * action_size
+    action_count = [0] * get_flatland_full_action_size()
     action_dict = dict()
     agent_obs = [None] * n_agents
     agent_prev_obs = [None] * n_agents
@@ -173,13 +171,13 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     completion_window = deque(maxlen=checkpoint_interval)
 
     # Double Dueling DQN policy
-    policy = DDDQNPolicy(state_size, action_size, train_params)
-    if True:
-        policy = PPOAgent(state_size, action_size)
+    policy = DDDQNPolicy(state_size, get_action_size(), train_params)
     if False:
-        policy = DeadLockAvoidanceAgent(train_env, action_size)
+        policy = PPOAgent(state_size, get_action_size())
+    if True:
+        policy = DeadLockAvoidanceAgent(train_env, get_action_size())
     if True:
-        policy = MultiDecisionAgent(train_env, state_size, action_size, policy)
+        policy = MultiDecisionAgent(train_env, state_size, get_action_size(), policy)
 
     # Load existing policy
     if train_params.load_policy is not "":
@@ -269,8 +267,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                 if info['action_required'][agent_handle]:
                     update_values[agent_handle] = True
                     action = policy.act(agent_handle, agent_obs[agent_handle], eps=eps_start)
-                    action_count[action] += 1
-                    actions_taken.append(action)
+                    action_count[map_action(action, get_action_size())] += 1
+                    actions_taken.append(map_action(action, get_action_size()))
                 else:
                     # An action is not required if the train hasn't joined the railway network,
                     # if it already reached its target, or if is currently malfunctioning.
@@ -282,7 +280,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
             # Environment step
             step_timer.start()
-            next_obs, all_rewards, done, info = train_env.step(action_dict)
+            next_obs, all_rewards, done, info = train_env.step(map_actions(action_dict, get_action_size()))
 
             # Reward shaping .Dead-lock .NotMoving .NotStarted
             if False:
@@ -290,6 +288,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                 for agent_handle in train_env.get_agent_handles():
                     agent = train_env.agents[agent_handle]
                     act = action_dict.get(agent_handle, RailEnvActions.DO_NOTHING)
+                    act = map_action(act, get_action_size())
                     if agent.status == RailAgentStatus.ACTIVE:
                         all_rewards[agent_handle] = 0.0
                         if done[agent_handle] == False:
@@ -305,7 +304,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                             else:
                                 all_rewards[agent_handle] = -0.01
                     else:
-                        all_rewards[agent_handle] = 1.0
+                        all_rewards[agent_handle] *= 10.0
+                        all_rewards[agent_handle] += 1.0
 
             step_timer.end()
 
@@ -325,7 +325,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                     learn_timer.start()
                     policy.step(agent_handle,
                                 agent_prev_obs[agent_handle],
-                                agent_prev_action[agent_handle],
+                                agent_prev_action[agent_handle] - 1,
                                 all_rewards[agent_handle],
                                 agent_obs[agent_handle],
                                 done[agent_handle])
@@ -375,19 +375,19 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                 policy.save_replay_buffer('./replay_buffers/' + training_id + '-' + str(episode_idx) + '.pkl')
 
             # reset action count
-            action_count = [0] * action_size
+            action_count = [0] * get_flatland_full_action_size()
 
         print(
             '\r🚂 Episode {}'
-            '\t 🚉 nAgents {}'
-            '\t 🏆 Score: {:7.3f}'
+            '\t 🚉 nAgents {:2}/{:2}'
+            ' 🏆 Score: {:7.3f}'
            ' Avg: {:7.3f}'
            '\t 💯 Done: {:6.2f}%'
            ' Avg: {:6.2f}%'
            '\t 🎲 Epsilon: {:.3f} '
            '\t 🔀 Action Probs: {}'.format(
                episode_idx,
-               train_env_params.n_agents,
+               train_env_params.n_agents, train_env.get_num_agents(),
                normalized_score,
                smoothed_normalized_score,
                100 * completion,
@@ -494,7 +494,7 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
                     action = policy.act(agent, agent_obs[agent], eps=0.0)
                     action_dict.update({agent: action})
             policy.end_step(train=False)
-            obs, all_rewards, done, info = env.step(action_dict)
+            obs, all_rewards, done, info = env.step(map_actions(action_dict, get_action_size()))
 
             for agent in env.get_agent_handles():
                 score += all_rewards[agent]
diff --git a/reinforcement_learning/ppo_deadlockavoidance_agent.py b/reinforcement_learning/ppo_deadlockavoidance_agent.py
index a3cf21638a4f04fba1b91e4cacbd668b62ce5996..e344d8ebcd908d2508829571d6cde59d413390df 100644
--- a/reinforcement_learning/ppo_deadlockavoidance_agent.py
+++ b/reinforcement_learning/ppo_deadlockavoidance_agent.py
@@ -35,7 +35,10 @@ class MultiDecisionAgent(Policy):
             if agents_on_switch or agents_near_to_switch:
                 return self.learning_agent.act(handle, state, eps)
             else:
-                return self.dead_lock_avoidance_agent.act(handle, state, -1.0)
+                act = self.dead_lock_avoidance_agent.act(handle, state, -1.0)
+                if self.action_size == 4:
+                    act = max(act - 1, 0)
+                return act
         # Agent is still at target cell
         return RailEnvActions.DO_NOTHING
 
diff --git a/run.py b/run.py
index 1b1d11fd79aa3aef4a87e5e043f319e0c507edf1..0ba9acc5a0f976ea7373e6925ef67411978f1a42 100644
--- a/run.py
+++ b/run.py
@@ -32,6 +32,7 @@ from flatland.evaluators.client import TimeoutException
 
 from reinforcement_learning.ppo_agent import PPOAgent
 from reinforcement_learning.ppo_deadlockavoidance_agent import MultiDecisionAgent
+from utils.agent_action_config import get_action_size, map_actions
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 from utils.deadlock_check import check_if_all_blocked
 from utils.fast_tree_obs import FastTreeObs
@@ -60,14 +61,14 @@ checkpoint = "./checkpoints/201211095604-12000.pth"  # DDDQN: 17.3862941316504
 checkpoint = "./checkpoints/201211164554-9400.pth"  # DDDQN: 16.09241366013537
 checkpoint = "./checkpoints/201213181400-6800.pth"  # PPO: 13.944402986414723
 checkpoint = "./checkpoints/201214140158-5000.pth"  # USE_MULTI_DECISION_AGENT with DDDQN: 13.944402986414723
-checkpoint = "./checkpoints/201214160604-3000.pth"  # USE_MULTI_DECISION_AGENT with DDDQN: 13.944402986414723
+checkpoint = "./checkpoints/201215120518-3700.pth"  # USE_MULTI_DECISION_AGENT with PPO: 13.944402986414723
 
 EPSILON = 0.0
 
 # Use last action cache
 USE_ACTION_CACHE = False
 USE_DEAD_LOCK_AVOIDANCE_AGENT = False  # 21.54485505223213
-USE_MULTI_DECISION_AGENT = True
+USE_MULTI_DECISION_AGENT = False
 
 # Observation parameters (must match training parameters!)
 observation_tree_depth = 2
@@ -106,7 +107,7 @@ else:
     n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)])
     state_size = n_features_per_node * n_nodes
 
-action_size = 5
+action_size = get_action_size()
 
 # Creates the policy. No GPU on evaluation server.
 if not USE_PPO_AGENT:
@@ -221,7 +222,7 @@ while True:
                 time_taken_by_controller.append(agent_time)
 
                 time_start = time.time()
-                _, all_rewards, done, info = remote_client.env_step(action_dict)
+                _, all_rewards, done, info = remote_client.env_step(map_actions(action_dict, get_action_size()))
                 step_time = time.time() - time_start
                 time_taken_per_step.append(step_time)
 
diff --git a/utils/agent_action_config.py b/utils/agent_action_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..dceba553310e2d4e47a5554da267c11dfa338ee1
--- /dev/null
+++ b/utils/agent_action_config.py
@@ -0,0 +1,25 @@
+
+def get_flatland_full_action_size():
+    # The action space of flatland is 5 discrete actions
+    return 5
+
+
+def get_action_size():
+    # The learning agents (DDDQN, PPO, ...) use this reduced action space
+    return 4
+
+
+def map_actions(actions, action_size):
+    # Map the agents' 0..3 actions onto flatland's 1..4 (skipping DO_NOTHING)
+    if action_size == get_flatland_full_action_size():
+        return actions
+    for key in actions:
+        value = actions.get(key, 0)
+        actions.update({key: (value + 1)})
+    return actions
+
+
+def map_action(action, action_size):
+    if action_size == get_flatland_full_action_size():
+        return action
+    return action + 1
diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py
index 286718ea4a86ae75f9d32e72476f5e48f14a558f..87f7e28a6ba20baad1ed78ae256f44b39c47cfea 100644
--- a/utils/dead_lock_avoidance_agent.py
+++ b/utils/dead_lock_avoidance_agent.py
@@ -95,9 +95,12 @@ class DeadLockAvoidanceAgent(Policy):
 
         # agent = self.env.agents[state[0]]
         check = self.agent_can_move.get(handle, None)
-        if check is None:
-            return RailEnvActions.STOP_MOVING
-        return check[3]
+        act = RailEnvActions.STOP_MOVING
+        if check is not None:
+            act = check[3]
+        if self.action_size == 4:
+            act = max(act - 1, 0)
+        return act
 
     def get_agent_can_move_value(self, handle):
         return self.agent_can_move_value.get(handle, np.inf)
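
For reference, not part of the patch: a minimal usage sketch of the helpers introduced in utils/agent_action_config.py above. It assumes the reduced 4-action space simply drops DO_NOTHING and keeps the order of the remaining Flatland actions; the handles and action values in action_dict below are made up for illustration.

# Usage sketch (assumption, not repository code): the learning agents pick
# actions in a 4-action space, and map_action/map_actions shift them by +1
# into Flatland's 5-action space (index 0 is DO_NOTHING) before env.step.
from flatland.envs.rail_env import RailEnvActions

from utils.agent_action_config import (
    get_action_size,
    get_flatland_full_action_size,
    map_action,
    map_actions,
)

# A single action: reduced action 1 becomes RailEnvActions.MOVE_FORWARD (value 2).
assert map_action(1, get_action_size()) == RailEnvActions.MOVE_FORWARD

# A whole action dict, as handed to train_env.step / remote_client.env_step:
# handles keep their keys, every value is shifted past DO_NOTHING.
action_dict = {0: 0, 1: 3}  # handle -> reduced action
assert map_actions(action_dict, get_action_size()) == {0: 1, 1: 4}

# With the full 5-action space configured, both helpers are pass-throughs.
assert map_action(2, get_flatland_full_action_size()) == 2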