diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index b219ace5133d911600edca85da6dc13fc9590753..94c6208ddde96a1bd9aea3728bb2a73f3f4ab667 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -176,7 +176,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     policy = PPOPolicy(state_size, get_action_size())
     if False:
         policy = DeadLockAvoidanceAgent(train_env, get_action_size())
-    if True:
+    if False:
         policy = MultiDecisionAgent(train_env, state_size, get_action_size(), policy)
 
     # Load existing policy
@@ -283,11 +283,11 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             next_obs, all_rewards, done, info = train_env.step(map_actions(action_dict))
 
             # Reward shaping .Dead-lock .NotMoving .NotStarted
-            if False:
+            if True:
                 agent_positions = get_agent_positions(train_env)
                 for agent_handle in train_env.get_agent_handles():
                     agent = train_env.agents[agent_handle]
-                    act = map_action(action_dict.get(agent_handle, RailEnvActions.DO_NOTHING))
+                    act = map_action(action_dict.get(agent_handle, map_rail_env_action(RailEnvActions.DO_NOTHING)))
                     if agent.status == RailAgentStatus.ACTIVE:
                         if done[agent_handle] == False:
                             if check_for_deadlock(agent_handle, train_env, agent_positions):
@@ -298,9 +298,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                                 num_transitions = fast_count_nonzero(possible_transitions)
                                 if num_transitions < 2 and ((act != RailEnvActions.MOVE_FORWARD) or (act != RailEnvActions.STOP_MOVING)):
-                                    all_rewards[agent_handle] -= 0.5
-                                else:
-                                    all_rewards[agent_handle] -= 0.01
+                                    all_rewards[agent_handle] -= 1.0
                         else:
                             all_rewards[agent_handle] *= 9.0
                             all_rewards[agent_handle] += 1.0
diff --git a/reinforcement_learning/ppo_agent.py b/reinforcement_learning/ppo_agent.py
index 51f0f7187f7d0894c97d535e68987b79e644779a..4072ecc190361f3a08aa150c80db17e66675017c 100644
--- a/reinforcement_learning/ppo_agent.py
+++ b/reinforcement_learning/ppo_agent.py
@@ -96,21 +96,21 @@ class ActorCriticModel(nn.Module):
 
 
 class PPOPolicy(LearningPolicy):
-    def __init__(self, state_size, action_size):
+    def __init__(self, state_size, action_size, use_replay_buffer=False):
         print(">> PPOPolicy")
         super(PPOPolicy, self).__init__()
         # parameters
         self.learning_rate = 1.0e-3
         self.gamma = 0.95
-        self.surrogate_eps_clip = 0.01
-        self.K_epoch = 5
-        self.weight_loss = 0.25
+        self.surrogate_eps_clip = 0.1
+        self.K_epoch = 10
+        self.weight_loss = 0.5
         self.weight_entropy = 0.01
 
         self.buffer_size = 32_000
         self.batch_size = 1024
         self.buffer_min_size = 0
-        self.use_replay_buffer = True
+        self.use_replay_buffer = use_replay_buffer
 
         self.device = device
         self.current_episode_memory = EpisodeBuffers()
@@ -178,10 +178,6 @@ class PPOPolicy(LearningPolicy):
             state_next_list.insert(0, state_next_i)
             prob_a_list.insert(0, prob_action_i)
 
-        # standard-normalize rewards
-        reward_list = np.array(reward_list)
-        reward_list = (reward_list - reward_list.mean()) / (reward_list.std() + 1.e-5)
-
         if self.use_replay_buffer:
             self._push_transitions_to_replay_buffer(state_list, action_list,
                                                     reward_list, state_next_list,
diff --git a/reinforcement_learning/rl_agent_test.py b/reinforcement_learning/rl_agent_test.py
index d764b8c016f70d4abd1fe3ed1f2d5f99a001411d..529597171cbe6898aa70f6e163df458dbb4257fc 100644
--- a/reinforcement_learning/rl_agent_test.py
+++ b/reinforcement_learning/rl_agent_test.py
@@ -30,12 +30,12 @@ def cartpole(use_dddqn=False):
     observation_space = env.observation_space.shape[0]
     action_space = env.action_space.n
     if not use_dddqn:
-        policy = PPOPolicy(observation_space, action_space)
+        policy = PPOPolicy(observation_space, action_space, False)
     else:
         policy = DDDQNPolicy(observation_space, action_space, dddqn_param)
     episode = 0
     checkpoint_interval = 20
-    scores_window = deque(maxlen=checkpoint_interval)
+    scores_window = deque(maxlen=100)
     while True:
         episode += 1
         state = env.reset()
@@ -45,7 +45,7 @@ def cartpole(use_dddqn=False):
 
         policy.start_episode(train=training_mode)
         while True:
-            env.render()
+            # env.render()
             policy.start_step(train=training_mode)
             action = policy.act(handle, state, eps)
             state_next, reward, terminal, info = env.step(action)
diff --git a/run.py b/run.py
index a3e116c449a7572aecb90eb85e2f165ae70d1e11..c04de8538db9910b6c1dcbe1d3dd42b3487edf6c 100644
--- a/run.py
+++ b/run.py
@@ -52,7 +52,7 @@ USE_FAST_TREEOBS = True
 USE_PPO_AGENT = True
 
 # Checkpoint to use (remember to push it!)
-checkpoint = "./checkpoints/201217163219-6500.pth"  #
+checkpoint = "./checkpoints/201219090514-8600.pth"  #
 # checkpoint = "./checkpoints/201215212134-12000.pth"  #
 
 EPSILON = 0.0