diff --git a/checkpoints/201106170544-5400.pth.local b/checkpoints/201106170544-5400.pth.local
deleted file mode 100644
index ea068d906ebad8ed514f0c0b4fa70b54f17cedbf..0000000000000000000000000000000000000000
Binary files a/checkpoints/201106170544-5400.pth.local and /dev/null differ
diff --git a/checkpoints/201106170544-5400.pth.target b/checkpoints/201106170544-5400.pth.target
deleted file mode 100644
index f789a5afa01268708e0009ae5d03e98166db36e2..0000000000000000000000000000000000000000
Binary files a/checkpoints/201106170544-5400.pth.target and /dev/null differ
diff --git a/checkpoints/201106234900-100.pth.local b/checkpoints/201106234900-100.pth.local
new file mode 100644
index 0000000000000000000000000000000000000000..1b4b1cb2c430282098d599f76b71123fc84c9ba4
Binary files /dev/null and b/checkpoints/201106234900-100.pth.local differ
diff --git a/checkpoints/201106234900-100.pth.target b/checkpoints/201106234900-100.pth.target
new file mode 100644
index 0000000000000000000000000000000000000000..391a36c37d22af1fdc97de554f96d7ecdc0d4874
Binary files /dev/null and b/checkpoints/201106234900-100.pth.target differ
diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 98a62c180dc076ff99da7ab594f3e1c3c7978a70..b9d103961da0b2474eb0fb3cb1dd65058e2bc51c 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -19,7 +19,6 @@ from flatland.utils.rendertools import RenderTool
 from torch.utils.tensorboard import SummaryWriter
 
 from reinforcement_learning.dddqn_policy import DDDQNPolicy
-from reinforcement_learning.ppo.ppo_agent import PPOAgent
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 
 base_dir = Path(__file__).resolve().parent.parent
@@ -200,9 +199,6 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
     # TensorBoard writer
     writer = SummaryWriter()
-    writer.add_hparams(vars(train_params), {})
-    writer.add_hparams(vars(train_env_params), {})
-    writer.add_hparams(vars(obs_params), {})
 
     training_timer = Timer()
     training_timer.start()
@@ -287,6 +283,22 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             # Environment step
             step_timer.start()
             next_obs, all_rewards, done, info = train_env.step(action_dict)
+
+            for agent in train_env.get_agent_handles():
+                act = action_dict.get(agent, RailEnvActions.DO_NOTHING)
+                if agent_obs[agent][26] == 1:
+                    if act == RailEnvActions.STOP_MOVING:
+                        all_rewards[agent] *= 0.01
+                else:
+                    if act == RailEnvActions.MOVE_LEFT:
+                        all_rewards[agent] *= 0.9
+                    else:
+                        if agent_obs[agent][7] == 0 and agent_obs[agent][8] == 0:
+                            if act == RailEnvActions.MOVE_FORWARD:
+                                all_rewards[agent] *= 0.01
+                if done[agent]:
+                    all_rewards[agent] += 100.0
+
             step_timer.end()
 
             # Render an episode at some interval
@@ -495,11 +507,11 @@ if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=5400, type=int)
     parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1, type=int)
-    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=1,
+    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
                         type=int)
-    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=25, type=int)
+    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=1, type=int)
     parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
-    parser.add_argument("--eps_start", help="max exploration", default=1.0, type=float)
+    parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float)
     parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
     parser.add_argument("--eps_decay", help="exploration decay", default=0.9998, type=float)
     parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e7), type=int)
@@ -519,7 +531,7 @@ if __name__ == "__main__":
     parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
     parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs",
                         action='store_true')
-    parser.add_argument("--max_depth", help="max depth", default=2, type=int)
+    parser.add_argument("--max_depth", help="max depth", default=1, type=int)
 
     training_params = parser.parse_args()
     env_params = [
diff --git a/run.py b/run.py
index b780e21e2287c0f3b87472c8b207c9851c0cd218..7f5f0d0449ba4e2e663051898d64c845eac7e112 100644
--- a/run.py
+++ b/run.py
@@ -10,7 +10,6 @@ from flatland.envs.rail_env import RailEnvActions
 from flatland.evaluators.client import FlatlandRemoteClient
 from flatland.evaluators.client import TimeoutException
 
-from reinforcement_learning.ppo.ppo_agent import PPOAgent
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 from utils.deadlock_check import check_if_all_blocked
 from utils.fast_tree_obs import FastTreeObs
@@ -27,9 +26,8 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
 VERBOSE = True
 
 # Checkpoint to use (remember to push it!)
-checkpoint = "./checkpoints/201105222046-5400.pth" # 17.66104361971127 Depth 1
-checkpoint = "./checkpoints/201106073658-4400.pth" # 15.64082361736683 Depth 1
-checkpoint = "./checkpoints/201106170544-5400.pth" # 15.64082361736683 Depth 1
+checkpoint = "./checkpoints/201106234244-400.pth" # 15.64082361736683 Depth 1
+checkpoint = "./checkpoints/201106234900-100.pth" # 15.64082361736683 Depth 1
 
 # Use last action cache
 USE_ACTION_CACHE = False
@@ -140,9 +138,8 @@ while True:
                     nb_hit += 1
                 else:
                     action = policy.act(observation[agent], eps=0.01)
-
-                    if observation[agent][26] == 1:
-                        action = RailEnvActions.STOP_MOVING
+                if observation[agent][26] == 1:
+                    action = RailEnvActions.STOP_MOVING
 
                 action_dict[agent] = action
 
diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py
index 919be23b0eaafa1be76869a0c90ff18dd647e773..b6d673d997b6e43ebdba22578f8a0a7e21831000 100755
--- a/utils/fast_tree_obs.py
+++ b/utils/fast_tree_obs.py
@@ -23,7 +23,7 @@ class FastTreeObs(ObservationBuilder):
 
     def __init__(self, max_depth):
         self.max_depth = max_depth
-        self.observation_dim = 32
+        self.observation_dim = 27
 
     def build_data(self):
         if self.env is not None:
@@ -303,23 +303,6 @@ class FastTreeObs(ObservationBuilder):
             action = self.dead_lock_avoidance_agent.act([handle], 0.0)
             observation[26] = int(action == RailEnvActions.STOP_MOVING)
-            observation[27] = int(action == RailEnvActions.MOVE_LEFT)
-            observation[28] = int(action == RailEnvActions.MOVE_FORWARD)
-            observation[29] = int(action == RailEnvActions.MOVE_RIGHT)
-            observation[30] = int(self.full_action_required(observation))
-            observation[31] = int(fast_tree_obs_check_agent_deadlock(observation))
 
         self.env.dev_obs_dict.update({handle: visited})
 
         return observation
-
-    def full_action_required(self, observation):
-        return observation[7] == 1 or observation[8] == 1 or observation[4] == 1
-
-
-def fast_tree_obs_check_agent_deadlock(observation):
-    nbr_of_path = 0
-    nbr_of_blocked_path = 0
-    for dir_loop in range(4):
-        nbr_of_path += observation[10 + dir_loop]
-        nbr_of_blocked_path += int(observation[14 + dir_loop] > 0)
-    return nbr_of_path <= nbr_of_blocked_path
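For readability, the sketch below restates the reward-shaping rule that the `multi_agent_training.py` hunk applies inline after `train_env.step()`. The helper name `shape_reward` and its signature are illustrative and not part of the patch; the interpretation of the observation bits is an assumption based on `fast_tree_obs.py` in this diff (index 26 mirrors the `DeadLockAvoidanceAgent`'s stop recommendation, indices 7/8 look like switch-related flags).

```python
# Minimal sketch of the added reward shaping, assuming the observation-bit
# meanings described above. Not part of the patch itself.
from flatland.envs.rail_env import RailEnvActions


def shape_reward(obs, action, reward, agent_done):
    """Scale a single agent's per-step reward the way the training-loop hunk does."""
    if obs[26] == 1:
        # The dead-lock-avoidance hint says "stop": scale the step reward
        # when the agent actually stops (damping the usual per-step penalty).
        if action == RailEnvActions.STOP_MOVING:
            reward *= 0.01
    else:
        if action == RailEnvActions.MOVE_LEFT:
            reward *= 0.9
        elif obs[7] == 0 and obs[8] == 0:
            # No switch to handle (assumed meaning of bits 7/8): scale the
            # step reward for simply moving forward.
            if action == RailEnvActions.MOVE_FORWARD:
                reward *= 0.01
    if agent_done:
        # Large terminal bonus once the agent reaches its target.
        reward += 100.0
    return reward
```

In the patch this logic runs over `train_env.get_agent_handles()` with `act = action_dict.get(agent, RailEnvActions.DO_NOTHING)`, so agents without an explicit action fall through to the default branch.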