Commit 876f8c74 authored by Egli Adrian (IT-SCI-API-PFI)

DQN & PPO

parent 7bb7aebd
File added
File added
@@ -18,6 +18,7 @@ from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import RenderTool
 from torch.utils.tensorboard import SummaryWriter
+from reinforcement_learning.dddqn_policy import DDDQNPolicy
 from reinforcement_learning.ppo.ppo_agent import PPOAgent
 base_dir = Path(__file__).resolve().parent.parent
@@ -172,8 +173,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     completion_window = deque(maxlen=checkpoint_interval)
     # Double Dueling DQN policy
-    # policy = DDDQNPolicy(state_size, action_size, train_params)
-    policy = PPOAgent(state_size, action_size, n_agents)
+    policy = DDDQNPolicy(state_size, action_size, train_params)
+    # policy = PPOAgent(state_size, action_size, n_agents)
     # Load existing policy
     if train_params.load_policy is not "":
         policy.load(train_params.load_policy)
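
This hunk switches training back from PPO to Double Dueling DQN by swapping which constructor line is commented out; the two policies appear to share the same interface (act / step / save / load), so they are interchangeable here. A minimal sketch of making that switch a command-line choice instead of an edit, assuming a hypothetical "--policy" value (not part of this commit) and the constructor signatures shown above:

from reinforcement_learning.dddqn_policy import DDDQNPolicy
from reinforcement_learning.ppo.ppo_agent import PPOAgent

def create_policy(name, state_size, action_size, n_agents, train_params):
    # "name" would come from a hypothetical --policy argument.
    if name == "dddqn":
        return DDDQNPolicy(state_size, action_size, train_params)
    if name == "ppo":
        return PPOAgent(state_size, action_size, n_agents)
    raise ValueError("unknown policy: " + name)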
@@ -480,7 +481,7 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=5400, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1, type=int)
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2, type=int)
     parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
                         type=int)
     parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=5, type=int)
 ...
 import os
-import random
 import numpy as np
 import torch
@@ -17,6 +16,7 @@ CLIP_FACTOR = .005
 UPDATE_EVERY = 30
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+print("device:", device)
 class PPOAgent(Policy):
@@ -31,7 +31,6 @@ class PPOAgent(Policy):
         self.memory = ReplayBuffer(BUFFER_SIZE)
         self.t_step = 0
         self.loss = 0
-        self.num_agents = num_agents

     def reset(self):
         self.finished = [False] * len(self.episodes)
@@ -43,7 +42,8 @@ class PPOAgent(Policy):
         self.policy.eval()
         with torch.no_grad():
             output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
-            return Categorical(output).sample().item()
+            ret = Categorical(output).sample().item()
+        return ret

     # Record the results of the agent's action and update the model
     def step(self, handle, state, action, reward, next_state, done):
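
In act(), the new code stores the sampled action in ret before returning it; the sampling itself is unchanged, treating the network output as a categorical distribution over the five Flatland actions. A self-contained sketch of that step, assuming the policy head returns a probability vector:

import torch
from torch.distributions import Categorical

probs = torch.tensor([[0.1, 0.2, 0.4, 0.2, 0.1]])  # stand-in for self.policy(state), shape (1, action_size)
ret = Categorical(probs).sample().item()           # integer action index in [0, 4]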
@@ -118,14 +118,14 @@
         if os.path.exists(filename + ".policy"):
             print(' >> ', filename + ".policy")
             try:
-                self.policy.load_state_dict(torch.load(filename + ".policy"))
+                self.policy.load_state_dict(torch.load(filename + ".policy", map_location=device))
             except:
                 print(" >> failed!")
                 pass
         if os.path.exists(filename + ".optimizer"):
             print(' >> ', filename + ".optimizer")
             try:
-                self.optimizer.load_state_dict(torch.load(filename + ".optimizer"))
+                self.optimizer.load_state_dict(torch.load(filename + ".optimizer", map_location=device))
             except:
                 print(" >> failed!")
                 pass
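
Adding map_location=device makes the checkpoints loadable on machines without a GPU (such as the evaluation server), since tensors saved from cuda:0 would otherwise fail to deserialize on a CPU-only host. A minimal sketch of the pattern; the file name here is a placeholder:

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Remap storages saved on GPU to whatever device is available locally.
state_dict = torch.load("checkpoint.policy", map_location=device)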
@@ -6,6 +6,7 @@ from pathlib import Path
 import numpy as np
 from flatland.core.env_observation_builder import DummyObservationBuilder
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnvActions
 from flatland.evaluators.client import FlatlandRemoteClient
 from flatland.evaluators.client import TimeoutException
@@ -25,10 +26,12 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
 VERBOSE = True
 # Checkpoint to use (remember to push it!)
-checkpoint = "./checkpoints/201105173637-4700.pth"  # 18.50097663335293 : Depth = 1
+checkpoint = "./checkpoints/201105222046-5400.pth"  # 17.66104361971127 Depth 1
+checkpoint = "./checkpoints/201106073658-4300.pth"  # 15.64082361736683 Depth 1
+checkpoint = "./checkpoints/201106090621-3300.pth"  # 15.64082361736683 Depth 1
 # Use last action cache
-USE_ACTION_CACHE = True
+USE_ACTION_CACHE = False
 USE_DEAD_LOCK_AVOIDANCE_AGENT = False
 # Observation parameters (must match training parameters!)
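
USE_ACTION_CACHE is switched off for this submission. When enabled, the runner reuses an agent's previous action while its observation has not changed, saving policy calls under the evaluation time limit (see the agent_last_action / nb_hit lines in the later hunk). A hedged sketch of that pattern, not the repository's exact code:

import numpy as np

agent_last_obs = {}
agent_last_action = {}
nb_hit = 0

def cached_act(policy, agent, obs, eps=0.01):
    """Reuse the last action while the agent's observation is unchanged."""
    global nb_hit
    if agent in agent_last_obs and np.array_equal(agent_last_obs[agent], obs):
        nb_hit += 1
        return agent_last_action[agent]
    action = policy.act(obs, eps=eps)
    agent_last_obs[agent] = obs
    agent_last_action[agent] = action
    return action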
@@ -50,6 +53,7 @@ action_size = 5
 # Creates the policy. No GPU on evaluation server.
 policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
+# policy = PPOAgent(state_size, action_size, 10)
 policy.load(checkpoint)
 #####################################################################
@@ -134,7 +138,10 @@ while True:
                     action = agent_last_action[agent]
                     nb_hit += 1
                 else:
-                    action = policy.act(observation[agent], eps=0.0)
+                    action = policy.act(observation[agent], eps=0.01)
+                    if observation[agent][26] == 1:
+                        action = RailEnvActions.STOP_MOVING

                 action_dict[agent] = action
 ...
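
The runner now calls the policy with a small exploration rate (eps=0.01) and then applies a hand-crafted override: when element 26 of the agent's observation vector equals 1, the agent is forced to STOP_MOVING. Index 26 is assumed to be a binary flag produced by this project's observation builder (for example a deadlock or occupancy indicator); its exact meaning depends on the training-time observation. A sketch of the pattern:

from flatland.envs.rail_env import RailEnvActions

def choose_action(policy, obs_vector, eps=0.01):
    action = policy.act(obs_vector, eps=eps)
    if obs_vector[26] == 1:  # assumed "must stop" flag in the custom observation
        action = RailEnvActions.STOP_MOVING
    return action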