diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 1a1f5121fd6f24f38bf82b0bec162131d84c7d67..487d44d70b3387fb03df6af5c7ae030862ba8d0a 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -285,21 +285,13 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
             step_timer.start()
             next_obs, all_rewards, done, info = train_env.step(action_dict)
-            if False:
+            if True:
                 for agent in train_env.get_agent_handles():
                     act = action_dict.get(agent, RailEnvActions.DO_NOTHING)
-                    if agent_obs[agent][26] == 1:
-                        if act == RailEnvActions.STOP_MOVING:
-                            all_rewards[agent] *= 0.01
-                    else:
-                        if act == RailEnvActions.MOVE_LEFT:
-                            all_rewards[agent] *= 0.9
-                        else:
-                            if agent_obs[agent][7] == 0 and agent_obs[agent][8] == 0:
-                                if act == RailEnvActions.MOVE_FORWARD:
-                                    all_rewards[agent] *= 0.01
-                    if done[agent]:
-                        all_rewards[agent] += 100.0
+                    if agent_obs[agent][5] == 1:
+                        if agent_obs[agent][26] == 1:
+                            if act != RailEnvActions.STOP_MOVING:
+                                all_rewards[agent] -= 10.0
 
             step_timer.end()
 
@@ -508,11 +500,11 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=5400, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1,
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2,
                         type=int)
-    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=1,
+    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=2,
                         type=int)
-    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=10, type=int)
+    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=25, type=int)
     parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
     parser.add_argument("--eps_start", help="max exploration", default=1.0, type=float)
     parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
diff --git a/run.py b/run.py
index 8cb2630449d3b6442cb0aed3c0e3b1b31fb14484..f4eb48e4aef8bce112a71e68355b7c7e5b85adb2 100644
--- a/run.py
+++ b/run.py
@@ -26,7 +26,7 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
 VERBOSE = True
 
 # Checkpoint to use (remember to push it!)
-checkpoint = "./checkpoints/201111175340-5400.pth"
+checkpoint = "./checkpoints/201112143850-4100.pth"  # 21.543589381053096
 DEPTH=2
 # Use last action cache
 USE_ACTION_CACHE = False
@@ -137,14 +137,13 @@ while True:
                                nb_hit += 1
                            else:
                                action = policy.act(observation[agent], eps=0.01)
-                                #if observation[agent][26] == 1:
-                                #    action = RailEnvActions.STOP_MOVING
 
                            action_dict[agent] = action
 
                            if USE_ACTION_CACHE:
                                agent_last_obs[agent] = observation[agent]
                                agent_last_action[agent] = action
 
+                    policy.end_step()
                    agent_time = time.time() - time_start
                    time_taken_by_controller.append(agent_time)
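Read in isolation, the reward shaping enabled by the first hunk amounts to the sketch below. It assumes `agent_obs[agent]` is the normalized tree-observation vector the starter kit feeds the policy, and it takes the meaning of indices 5 and 26 (a flagged conflict that should make the agent stop) as an assumption carried over from the patch; the helper name `shape_conflict_penalty` is hypothetical.

```python
from flatland.envs.rail_env import RailEnvActions


def shape_conflict_penalty(train_env, agent_obs, action_dict, all_rewards):
    """Sketch of the reward shaping toggled on in the diff above.

    Assumes agent_obs[agent] is the normalized tree-observation vector and that
    indices 5 and 26 flag a conflict condition, as in the patched training loop.
    """
    for agent in train_env.get_agent_handles():
        act = action_dict.get(agent, RailEnvActions.DO_NOTHING)
        # When both observation flags are set, penalize any action other than stopping.
        if agent_obs[agent][5] == 1 and agent_obs[agent][26] == 1:
            if act != RailEnvActions.STOP_MOVING:
                all_rewards[agent] -= 10.0
    return all_rewards
```

Applied right after `train_env.step(action_dict)`, this subtracts a fixed penalty on top of the environment's own reward whenever an agent keeps moving into a flagged conflict instead of stopping, replacing the older multiplicative shaping that the patch removes.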