From 7bb7aebd8893d291415505441e9245e36b8de914 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Thu, 5 Nov 2020 21:12:41 +0100
Subject: [PATCH] DQN & PPO

---
 reinforcement_learning/dddqn_policy.py  |  2 +-
 .../multi_agent_training.py             | 27 +++++++------------
 reinforcement_learning/policy.py        | 11 +++++++-
 reinforcement_learning/ppo/ppo_agent.py | 24 +++++------------
 .../single_agent_training.py            |  4 ++-
 5 files changed, 31 insertions(+), 37 deletions(-)

diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
index 2cf7ad2..1c323c3 100644
--- a/reinforcement_learning/dddqn_policy.py
+++ b/reinforcement_learning/dddqn_policy.py
@@ -67,7 +67,7 @@ class DDDQNPolicy(Policy):
         else:
             return random.choice(np.arange(self.action_size))
 
-    def step(self, state, action, reward, next_state, done):
+    def step(self, handle, state, action, reward, next_state, done):
         assert not self.evaluation_mode, "Policy has been initialized for evaluation only."
 
         # Save experience in replay memory
diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 1d99a9b..6f250a8 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -18,13 +18,14 @@ from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import RenderTool
 from torch.utils.tensorboard import SummaryWriter
 
+from reinforcement_learning.ppo.ppo_agent import PPOAgent
+
 base_dir = Path(__file__).resolve().parent.parent
 sys.path.append(str(base_dir))
 
 from utils.timer import Timer
 from utils.observation_utils import normalize_observation
-from utils.fast_tree_obs import FastTreeObs, fast_tree_obs_check_agent_deadlock
-from reinforcement_learning.dddqn_policy import DDDQNPolicy
+from utils.fast_tree_obs import FastTreeObs
 
 try:
     import wandb
@@ -171,8 +172,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     completion_window = deque(maxlen=checkpoint_interval)
 
     # Double Dueling DQN policy
-    policy = DDDQNPolicy(state_size, action_size, train_params)
-
+    # policy = DDDQNPolicy(state_size, action_size, train_params)
+    policy = PPOAgent(state_size, action_size, n_agents)
     # Load existing policy
     if train_params.load_policy is not "":
         policy.load(train_params.load_policy)
@@ -226,6 +227,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         train_env_params.n_agents = episode_idx % n_agents + 1
         train_env = create_rail_env(train_env_params, tree_observation)
         obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
+        policy.reset()
         reset_timer.end()
 
         if train_params.render:
@@ -288,18 +290,10 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                 if update_values[agent] or done['__all__']:
                     # Only learn from timesteps where somethings happened
                     learn_timer.start()
-                    call_step = True
-
-                    if not (agent_obs[agent][7] == 1 or agent_obs[agent][8] == 1 or agent_obs[agent][4] == 1):
-                        if action_dict.get(agent) == RailEnvActions.MOVE_FORWARD:
-                            call_step = np.random.random() < 0.1
-                    if fast_tree_obs_check_agent_deadlock(agent_obs[agent]):
-                        all_rewards[agent] -= 10
-                        call_step = True
-                    if call_step:
-                        policy.step(agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
-                                    agent_obs[agent],
-                                    done[agent])
+                    policy.step(agent,
+                                agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
+                                agent_obs[agent],
+                                done[agent])
                     learn_timer.end()
 
                     agent_prev_obs[agent] = agent_obs[agent].copy()
@@ -444,7 +438,6 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
         score = 0.0
 
         obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)
-
         final_step = 0
 
         for step in range(max_steps - 1):
diff --git a/reinforcement_learning/policy.py b/reinforcement_learning/policy.py
index b8714d1..b605aa3 100644
--- a/reinforcement_learning/policy.py
+++ b/reinforcement_learning/policy.py
@@ -1,5 +1,5 @@
 class Policy:
-    def step(self, state, action, reward, next_state, done):
+    def step(self, handle, state, action, reward, next_state, done):
         raise NotImplementedError
 
     def act(self, state, eps=0.):
@@ -16,3 +16,12 @@ class Policy:
 
     def end_step(self):
         pass
+
+    def load_replay_buffer(self, filename):
+        pass
+
+    def test(self):
+        pass
+
+    def reset(self):
+        pass
\ No newline at end of file
diff --git a/reinforcement_learning/ppo/ppo_agent.py b/reinforcement_learning/ppo/ppo_agent.py
index e43cb30..663a05a 100644
--- a/reinforcement_learning/ppo/ppo_agent.py
+++ b/reinforcement_learning/ppo/ppo_agent.py
@@ -20,7 +20,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
 class PPOAgent(Policy):
-    def __init__(self, state_size, action_size, num_agents, env):
+    def __init__(self, state_size, action_size, num_agents):
         self.action_size = action_size
         self.state_size = state_size
         self.num_agents = num_agents
@@ -31,7 +31,7 @@ class PPOAgent(Policy):
         self.memory = ReplayBuffer(BUFFER_SIZE)
         self.t_step = 0
         self.loss = 0
-        self.env = env
+        self.num_agents = num_agents
 
     def reset(self):
         self.finished = [False] * len(self.episodes)
@@ -39,21 +39,11 @@ class PPOAgent(Policy):
 
     # Decide on an action to take in the environment
 
-    def act(self, handle, state, eps=None):
-        if True:
-            self.policy.eval()
-            with torch.no_grad():
-                output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
-            return Categorical(output).sample().item()
-
-        # Epsilon-greedy action selection
-        if random.random() > eps:
-            self.policy.eval()
-            with torch.no_grad():
-                output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
-            return Categorical(output).sample().item()
-        else:
-            return random.choice(np.arange(self.action_size))
+    def act(self, state, eps=None):
+        self.policy.eval()
+        with torch.no_grad():
+            output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
+        return Categorical(output).sample().item()
 
     # Record the results of the agent's action and update the model
     def step(self, handle, state, action, reward, next_state, done):
diff --git a/reinforcement_learning/single_agent_training.py b/reinforcement_learning/single_agent_training.py
index 236d1a7..bfcc886 100644
--- a/reinforcement_learning/single_agent_training.py
+++ b/reinforcement_learning/single_agent_training.py
@@ -146,7 +146,9 @@ def train_agent(n_episodes):
         for agent in range(env.get_num_agents()):
            # Only update the values when we are done or when an action was taken and thus relevant information is present
             if update_values or done[agent]:
-                policy.step(agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent], agent_obs[agent], done[agent])
+                policy.step(agent,
+                            agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
+                            agent_obs[agent], done[agent])
                 agent_prev_obs[agent] = agent_obs[agent].copy()
                 agent_prev_action[agent] = action_dict[agent]
 
--
GitLab
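
Note: the sketch below is not part of the patch. It only illustrates the interface change the patch introduces: Policy.step() now receives the agent handle as its first argument (and Policy gains a no-op reset() hook), so an on-policy learner such as PPOAgent can keep a separate trajectory buffer per agent. ToyPolicy and every name in the sketch are invented for illustration; the real PPOAgent lives in reinforcement_learning/ppo/ppo_agent.py.

# Hypothetical sketch of the new per-agent step(handle, ...) interface.
import numpy as np


class ToyPolicy:
    """Invented stand-in for PPOAgent/DDDQNPolicy; only the interface mirrors the patch."""

    def __init__(self, state_size, action_size, n_agents):
        self.action_size = action_size
        # One transition buffer per agent handle, so trajectories stay separated.
        self.buffers = {handle: [] for handle in range(n_agents)}

    def act(self, state, eps=0.0):
        # Placeholder action selection (uniform random), standing in for the policy network.
        return int(np.random.randint(self.action_size))

    def step(self, handle, state, action, reward, next_state, done):
        # The handle routes the transition into that agent's own buffer.
        self.buffers[handle].append((state, action, reward, next_state, done))

    def reset(self):
        # Mirrors the new no-op reset() hook on Policy; called once per episode.
        self.buffers = {handle: [] for handle in self.buffers}


if __name__ == "__main__":
    n_agents, state_size, action_size = 2, 4, 5
    policy = ToyPolicy(state_size, action_size, n_agents)
    policy.reset()
    for handle in range(n_agents):
        obs = np.zeros(state_size, dtype=np.float32)
        action = policy.act(obs)
        policy.step(handle, obs, action, -1.0, np.ones(state_size, dtype=np.float32), False)
    print({handle: len(buf) for handle, buf in policy.buffers.items()})  # {0: 1, 1: 1}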