Commit c12f806e authored by Egli Adrian (IT-SCI-API-PFI)

clean up code - simplified

parent 04a942e5
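This commit removes the experimental second-policy path (USE_SINGLE_AGENT_TRAINING, MAX_SINGLE_TRAINING_ITERATION, UPDATE_POLICY2_N_EPISODE, USE_DEADLOCK_AVOIDANCE_AS_POLICY2, the periodically refreshed policy2 clone and the agent_to_learn sampling), so a single Double Dueling DQN policy now selects actions and learns for every agent. Below is a minimal sketch of how the per-agent loop reads after the change, pieced together from the hunks that follow; it is an excerpt of the loop body rather than a standalone program, and the else branch that resets update_values is assumed from the surrounding script, not visible in this diff:

    for agent_handle in train_env.get_agent_handles():
        if info['action_required'][agent_handle]:
            update_values[agent_handle] = True
            # the single shared policy picks an epsilon-greedy action for every agent
            action = policy.act(agent_obs[agent_handle], eps=eps_start)
            action_count[action] += 1
            actions_taken.append(action)
        else:
            # no decision required at this cell (assumed branch, not shown in the diff)
            update_values[agent_handle] = False
            action = 0
        action_dict.update({agent_handle: action})

    # later, after the environment step, the same policy is trained on the transition
    if update_values[agent_handle] or done['__all__']:
        policy.step(agent_handle,
                    agent_prev_obs[agent_handle],
                    agent_prev_action[agent_handle],
                    all_rewards[agent_handle],
                    agent_obs[agent_handle],
                    done[agent_handle])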
@@ -206,14 +206,6 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
scores_window = deque(maxlen=checkpoint_interval) # todo smooth when rendering instead
completion_window = deque(maxlen=checkpoint_interval)
# If USE_SINGLE_AGENT_TRAINING is set and episode_idx <= MAX_SINGLE_TRAINING_ITERATION, then
# training is done with a single agent. Every UPDATE_POLICY2_N_EPISODE episodes the second policy is replaced
# with the policy that is being trained.
USE_SINGLE_AGENT_TRAINING = False
MAX_SINGLE_TRAINING_ITERATION = 100000
UPDATE_POLICY2_N_EPISODE = 200
USE_DEADLOCK_AVOIDANCE_AS_POLICY2 = False
# Double Dueling DQN policy
policy = DDDQNPolicy(state_size, action_size, train_params)
if False:
@@ -263,9 +255,6 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
preproc_timer = Timer()
inference_timer = Timer()
if episode_idx > MAX_SINGLE_TRAINING_ITERATION:
USE_SINGLE_AGENT_TRAINING = False
# Reset environment
reset_timer.start()
number_of_agents = int(min(n_agents, 1 + np.floor(episode_idx / 200)))
@@ -274,13 +263,6 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
train_env = create_rail_env(train_env_params, tree_observation)
obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
policy.reset()
if USE_DEADLOCK_AVOIDANCE_AS_POLICY2:
policy2 = DeadLockAvoidanceAgent(train_env, action_size)
else:
if episode_idx % UPDATE_POLICY2_N_EPISODE == 0:
policy2 = policy.clone()
reset_timer.end()
if train_params.render:
@@ -307,26 +289,14 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
max_steps = train_env._max_episode_steps
# Run episode
agent_to_learn = [0]
if train_env.get_num_agents() > 1:
agent_to_learn = np.unique(np.random.choice(train_env.get_num_agents(), train_env.get_num_agents()))
# agent_to_learn = np.arange(train_env.get_num_agents())
for step in range(max_steps - 1):
inference_timer.start()
policy.start_step()
policy2.start_step()
for agent_handle in train_env.get_agent_handles():
agent = train_env.agents[agent_handle]
if info['action_required'][agent_handle]:
update_values[agent_handle] = True
if (agent_handle in agent_to_learn) or (not USE_SINGLE_AGENT_TRAINING):
action = policy.act(agent_obs[agent_handle], eps=eps_start)
else:
if USE_DEADLOCK_AVOIDANCE_AS_POLICY2:
action = policy2.act([agent_handle], eps=0.0)
else:
action = policy2.act(agent_obs[agent_handle], eps=0.0)
action = policy.act(agent_obs[agent_handle], eps=eps_start)
action_count[action] += 1
actions_taken.append(action)
@@ -337,7 +307,6 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
action = 0
action_dict.update({agent_handle: action})
policy.end_step()
policy2.end_step()
inference_timer.end()
# Environment step
@@ -383,13 +352,12 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
if update_values[agent_handle] or done['__all__']:
# Only learn from timesteps where something happened
learn_timer.start()
if (agent_handle in agent_to_learn) or (not USE_SINGLE_AGENT_TRAINING):
policy.step(agent_handle,
agent_prev_obs[agent_handle],
agent_prev_action[agent_handle],
all_rewards[agent_handle],
agent_obs[agent_handle],
done[agent_handle])
policy.step(agent_handle,
agent_prev_obs[agent_handle],
agent_prev_action[agent_handle],
all_rewards[agent_handle],
agent_obs[agent_handle],
done[agent_handle])
learn_timer.end()
agent_prev_obs[agent_handle] = agent_obs[agent_handle].copy()