diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
index 2cf7ad25e4f581c86cd53bc49671669b8820ab8f..1c323c366c7ff6b08baf1ff2d7d9ceee389bbd57 100644
--- a/reinforcement_learning/dddqn_policy.py
+++ b/reinforcement_learning/dddqn_policy.py
@@ -67,7 +67,7 @@ class DDDQNPolicy(Policy):
         else:
             return random.choice(np.arange(self.action_size))
 
-    def step(self, state, action, reward, next_state, done):
+    def step(self, handle, state, action, reward, next_state, done):
         assert not self.evaluation_mode, "Policy has been initialized for evaluation only."
 
         # Save experience in replay memory
diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 1d99a9b80d51d8ff34991efbeb802af2a373a292..6f250a83fb830afef842037c312629b99cdb78c1 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -18,13 +18,14 @@ from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import RenderTool
 from torch.utils.tensorboard import SummaryWriter
 
+from reinforcement_learning.ppo.ppo_agent import PPOAgent
+
 base_dir = Path(__file__).resolve().parent.parent
 sys.path.append(str(base_dir))
 
 from utils.timer import Timer
 from utils.observation_utils import normalize_observation
-from utils.fast_tree_obs import FastTreeObs, fast_tree_obs_check_agent_deadlock
-from reinforcement_learning.dddqn_policy import DDDQNPolicy
+from utils.fast_tree_obs import FastTreeObs
 
 try:
     import wandb
@@ -171,8 +172,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     completion_window = deque(maxlen=checkpoint_interval)
 
     # Double Dueling DQN policy
-    policy = DDDQNPolicy(state_size, action_size, train_params)
-
+    # policy = DDDQNPolicy(state_size, action_size, train_params)
+    policy = PPOAgent(state_size, action_size, n_agents)
     # Load existing policy
     if train_params.load_policy is not "":
         policy.load(train_params.load_policy)
@@ -226,6 +227,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         train_env_params.n_agents = episode_idx % n_agents + 1
         train_env = create_rail_env(train_env_params, tree_observation)
         obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
+        policy.reset()
         reset_timer.end()
 
         if train_params.render:
@@ -288,18 +290,10 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                 if update_values[agent] or done['__all__']:
                     # Only learn from timesteps where somethings happened
                     learn_timer.start()
-                    call_step = True
-
-                    if not (agent_obs[agent][7] == 1 or agent_obs[agent][8] == 1 or agent_obs[agent][4] == 1):
-                        if action_dict.get(agent) == RailEnvActions.MOVE_FORWARD:
-                            call_step = np.random.random() < 0.1
-                    if fast_tree_obs_check_agent_deadlock(agent_obs[agent]):
-                        all_rewards[agent] -= 10
-                        call_step = True
-                    if call_step:
-                        policy.step(agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
-                                    agent_obs[agent],
-                                    done[agent])
+                    policy.step(agent,
+                                agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
+                                agent_obs[agent],
+                                done[agent])
                     learn_timer.end()
 
                     agent_prev_obs[agent] = agent_obs[agent].copy()
@@ -444,7 +438,6 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
         score = 0.0
 
         obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)
-        final_step = 0
 
 
         for step in range(max_steps - 1):
diff --git a/reinforcement_learning/policy.py b/reinforcement_learning/policy.py
index b8714d1a2fd8e085d4e9e00c48c7362846e8ed87..b605aa3ddaf43ad1a496e44a3fac367be4bd8234 100644
--- a/reinforcement_learning/policy.py
+++ b/reinforcement_learning/policy.py
@@ -1,5 +1,5 @@
 class Policy:
-    def step(self, state, action, reward, next_state, done):
+    def step(self, handle, state, action, reward, next_state, done):
         raise NotImplementedError
 
     def act(self, state, eps=0.):
@@ -16,3 +16,12 @@ class Policy:
 
     def end_step(self):
         pass
+
+    def load_replay_buffer(self, filename):
+        pass
+
+    def test(self):
+        pass
+
+    def reset(self):
+        pass
\ No newline at end of file
diff --git a/reinforcement_learning/ppo/ppo_agent.py b/reinforcement_learning/ppo/ppo_agent.py
index e43cb3080bf5187e148f0b06e69749f46e840e9e..663a05acb42fd5a919f4eff0c3c45146d9bbd471 100644
--- a/reinforcement_learning/ppo/ppo_agent.py
+++ b/reinforcement_learning/ppo/ppo_agent.py
@@ -20,7 +20,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
 class PPOAgent(Policy):
-    def __init__(self, state_size, action_size, num_agents, env):
+    def __init__(self, state_size, action_size, num_agents):
         self.action_size = action_size
         self.state_size = state_size
         self.num_agents = num_agents
@@ -31,7 +31,7 @@ class PPOAgent(Policy):
         self.memory = ReplayBuffer(BUFFER_SIZE)
         self.t_step = 0
         self.loss = 0
-        self.env = env
+        self.num_agents = num_agents
 
     def reset(self):
         self.finished = [False] * len(self.episodes)
@@ -39,21 +39,11 @@ class PPOAgent(Policy):
 
 
     # Decide on an action to take in the environment
-    def act(self, handle, state, eps=None):
-        if True:
-            self.policy.eval()
-            with torch.no_grad():
-                output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
-                return Categorical(output).sample().item()
-
-        # Epsilon-greedy action selection
-        if random.random() > eps:
-            self.policy.eval()
-            with torch.no_grad():
-                output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
-                return Categorical(output).sample().item()
-        else:
-            return random.choice(np.arange(self.action_size))
+    def act(self, state, eps=None):
+        self.policy.eval()
+        with torch.no_grad():
+            output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
+            return Categorical(output).sample().item()
 
     # Record the results of the agent's action and update the model
     def step(self, handle, state, action, reward, next_state, done):
diff --git a/reinforcement_learning/single_agent_training.py b/reinforcement_learning/single_agent_training.py
index 236d1a76cbbff9be26612e81bf24265886acbab8..bfcc88656c8b37a8c09e72b51701d0750cf7f238 100644
--- a/reinforcement_learning/single_agent_training.py
+++ b/reinforcement_learning/single_agent_training.py
@@ -146,7 +146,9 @@ def train_agent(n_episodes):
             for agent in range(env.get_num_agents()):
                 # Only update the values when we are done or when an action was taken and thus relevant information is present
                 if update_values or done[agent]:
-                    policy.step(agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent], agent_obs[agent], done[agent])
+                    policy.step(agent,
+                                agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
+                                agent_obs[agent], done[agent])
 
                 agent_prev_obs[agent] = agent_obs[agent].copy()
                 agent_prev_action[agent] = action_dict[agent]
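
Note on the interface change above: `Policy.step()` now takes the agent handle as its first argument, and the base class gains no-op `load_replay_buffer()`, `test()` and `reset()` hooks. The sketch below is illustrative only and is not part of the diff; `RandomPolicy`, the state size of 11 and the handle values are invented to show the call pattern the training loops use after this change.

```python
# Illustrative sketch only -- RandomPolicy and the sizes below are invented for
# this note; they are not part of the repository changes above.
import numpy as np

from reinforcement_learning.policy import Policy


class RandomPolicy(Policy):
    """Toy policy showing the updated step(handle, ...) signature."""

    def __init__(self, action_size):
        self.action_size = action_size

    def act(self, state, eps=0.):
        # Pick a random action; a real policy would use the observation.
        return np.random.choice(self.action_size)

    def step(self, handle, state, action, reward, next_state, done):
        # `handle` identifies which agent produced this transition, so a
        # per-agent learner (such as the PPOAgent with one episode buffer per
        # agent) can route the experience to the right buffer.
        pass


# Call pattern used by the training loops after this change.
policy = RandomPolicy(action_size=5)
policy.reset()  # new base-class hook; a no-op unless overridden (PPOAgent overrides it)
state = np.zeros(11, dtype=np.float32)  # placeholder observation vector
for handle in range(3):  # agent handles
    action = policy.act(state, eps=0.1)
    policy.step(handle, state, action, reward=-1.0, next_state=state, done=False)
```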