Commit 7bb7aebd authored by Egli Adrian (IT-SCI-API-PFI)

DQN & PPO

parent 71c6e090
@@ -67,7 +67,7 @@ class DDDQNPolicy(Policy):
         else:
             return random.choice(np.arange(self.action_size))
-    def step(self, state, action, reward, next_state, done):
+    def step(self, handle, state, action, reward, next_state, done):
         assert not self.evaluation_mode, "Policy has been initialized for evaluation only."
         # Save experience in replay memory
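The change above threads the agent handle through `DDDQNPolicy.step()`, matching the base `Policy` interface further down. A minimal, self-contained sketch of the new call pattern (the `DummyPolicy` below is illustrative only, not part of the repository):

import numpy as np

class DummyPolicy:
    # Stand-in with the handle-aware signature; it merely records transitions.
    def __init__(self):
        self.transitions = []

    def step(self, handle, state, action, reward, next_state, done):
        self.transitions.append((handle, action, reward, done))

policy = DummyPolicy()
state = np.zeros(11, dtype=np.float32)        # placeholder observation vector
next_state = np.ones(11, dtype=np.float32)
for handle in range(3):                       # one call per agent handle
    policy.step(handle, state, 2, -1.0, next_state, False)
print(len(policy.transitions))                # -> 3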
@@ -18,13 +18,14 @@ from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import RenderTool
 from torch.utils.tensorboard import SummaryWriter
+from reinforcement_learning.ppo.ppo_agent import PPOAgent
 base_dir = Path(__file__).resolve().parent.parent
 sys.path.append(str(base_dir))
 from utils.timer import Timer
 from utils.observation_utils import normalize_observation
-from utils.fast_tree_obs import FastTreeObs, fast_tree_obs_check_agent_deadlock
 from reinforcement_learning.dddqn_policy import DDDQNPolicy
+from utils.fast_tree_obs import FastTreeObs
 try:
     import wandb
@@ -171,8 +172,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     completion_window = deque(maxlen=checkpoint_interval)
     # Double Dueling DQN policy
-    policy = DDDQNPolicy(state_size, action_size, train_params)
+    # policy = DDDQNPolicy(state_size, action_size, train_params)
+    policy = PPOAgent(state_size, action_size, n_agents)
     # Load existing policy
     if train_params.load_policy is not "":
         policy.load(train_params.load_policy)
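With both implementations behind the shared `Policy` interface, switching from DDDQN to PPO is the one-line swap above. A hedged sketch of how the choice could be made configurable; `make_policy` and the `use_ppo` flag are hypothetical helpers, not something this commit introduces:

from reinforcement_learning.dddqn_policy import DDDQNPolicy
from reinforcement_learning.ppo.ppo_agent import PPOAgent

def make_policy(state_size, action_size, n_agents, train_params, use_ppo=True):
    # Hypothetical factory: both constructor signatures are taken from the diff above.
    if use_ppo:
        return PPOAgent(state_size, action_size, n_agents)
    return DDDQNPolicy(state_size, action_size, train_params)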
@@ -226,6 +227,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         train_env_params.n_agents = episode_idx % n_agents + 1
         train_env = create_rail_env(train_env_params, tree_observation)
         obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
+        policy.reset()
         reset_timer.end()
         if train_params.render:
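`policy.reset()` is now invoked immediately after the environment reset, so any per-episode state inside the agent starts clean. A self-contained illustration of that ordering (`DummyEnv` and `DummyPolicy` are stand-ins, not repository code):

class DummyEnv:
    # Mimics the reset signature used above and returns (obs, info).
    def reset(self, regenerate_rail=True, regenerate_schedule=True):
        return {0: [0.0]}, {}

class DummyPolicy:
    def reset(self):
        self.episode_buffer = []   # per-episode buffers cleared before new transitions arrive

env, policy = DummyEnv(), DummyPolicy()
for episode_idx in range(3):
    obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)
    policy.reset()                 # called right after env.reset(), as in the diff above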
@@ -288,18 +290,10 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                 if update_values[agent] or done['__all__']:
                     # Only learn from timesteps where somethings happened
                     learn_timer.start()
-                    call_step = True
-                    if not (agent_obs[agent][7] == 1 or agent_obs[agent][8] == 1 or agent_obs[agent][4] == 1):
-                        if action_dict.get(agent) == RailEnvActions.MOVE_FORWARD:
-                            call_step = np.random.random() < 0.1
-                    if fast_tree_obs_check_agent_deadlock(agent_obs[agent]):
-                        all_rewards[agent] -= 10
-                        call_step = True
-                    if call_step:
-                        policy.step(agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
-                                    agent_obs[agent],
-                                    done[agent])
+                    policy.step(agent,
+                                agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
+                                agent_obs[agent],
+                                done[agent])
                     learn_timer.end()
                     agent_prev_obs[agent] = agent_obs[agent].copy()
@@ -444,7 +438,6 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
         score = 0.0
         obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)
         final_step = 0
         for step in range(max_steps - 1):
 class Policy:
-    def step(self, state, action, reward, next_state, done):
+    def step(self, handle, state, action, reward, next_state, done):
         raise NotImplementedError
     def act(self, state, eps=0.):
@@ -16,3 +16,12 @@ class Policy:
     def end_step(self):
         pass
+
+    def load_replay_buffer(self, filename):
+        pass
+
+    def test(self):
+        pass
+
+    def reset(self):
+        pass
\ No newline at end of file
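Besides the handle-aware `step()`, the base class now exposes `load_replay_buffer()`, `test()` and `reset()` as no-op hooks, so concrete policies only override what they need. A minimal subclass, purely illustrative (a random policy is not part of this commit; `Policy` is the base class shown above):

import random
import numpy as np

class RandomPolicy(Policy):
    # Implements the extended interface with trivial behaviour.
    def __init__(self, action_size):
        self.action_size = action_size

    def step(self, handle, state, action, reward, next_state, done):
        pass                        # a learning policy would store or consume the transition here

    def act(self, state, eps=0.):
        return random.choice(np.arange(self.action_size))

    def reset(self):
        pass                        # nothing episodic to clear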
@@ -20,7 +20,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 class PPOAgent(Policy):
-    def __init__(self, state_size, action_size, num_agents, env):
+    def __init__(self, state_size, action_size, num_agents):
         self.action_size = action_size
         self.state_size = state_size
         self.num_agents = num_agents
@@ -31,7 +31,7 @@ class PPOAgent(Policy):
         self.memory = ReplayBuffer(BUFFER_SIZE)
         self.t_step = 0
         self.loss = 0
-        self.env = env
+        self.num_agents = num_agents
 
     def reset(self):
         self.finished = [False] * len(self.episodes)
@@ -39,21 +39,11 @@ class PPOAgent(Policy):
     # Decide on an action to take in the environment
-    def act(self, handle, state, eps=None):
-        if True:
-            self.policy.eval()
-            with torch.no_grad():
-                output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
-                return Categorical(output).sample().item()
-        # Epsilon-greedy action selection
-        if random.random() > eps:
-            self.policy.eval()
-            with torch.no_grad():
-                output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
-                return Categorical(output).sample().item()
-        else:
-            return random.choice(np.arange(self.action_size))
+    def act(self, state, eps=None):
+        self.policy.eval()
+        with torch.no_grad():
+            output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
+            return Categorical(output).sample().item()
 
     # Record the results of the agent's action and update the model
     def step(self, handle, state, action, reward, next_state, done):
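The simplified `act()` keeps only the stochastic path: the policy network's output is treated as action probabilities and an action index is drawn with `torch.distributions.Categorical`. A self-contained sketch of that sampling pattern (the small softmax network here is a stand-in, not the repository's actual model):

import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Categorical

state_size, action_size = 11, 5
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Stand-in policy network producing a probability distribution over actions.
policy = nn.Sequential(
    nn.Linear(state_size, 32),
    nn.Tanh(),
    nn.Linear(32, action_size),
    nn.Softmax(dim=-1),
).to(device)

state = np.zeros(state_size, dtype=np.float32)    # placeholder observation
policy.eval()
with torch.no_grad():
    output = policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
    action = Categorical(output).sample().item()  # sampled action index
print(action)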
@@ -146,7 +146,9 @@ def train_agent(n_episodes):
         for agent in range(env.get_num_agents()):
             # Only update the values when we are done or when an action was taken and thus relevant information is present
             if update_values or done[agent]:
-                policy.step(agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent], agent_obs[agent], done[agent])
+                policy.step(agent,
+                            agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
+                            agent_obs[agent], done[agent])
                 agent_prev_obs[agent] = agent_obs[agent].copy()
                 agent_prev_action[agent] = action_dict[agent]