Commit 3df98d5c authored by Egli Adrian (IT-SCI-API-PFI)

single agent learning in multi agent environment

parent 95dbb5be
Showing changes with 21 additions and 13 deletions
12 files deleted (file names not captured in this view)
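The hunks that follow come from two source files (paths are not captured in this copy): the training script whose train_agent loop is edited below, and the PPO agent's act() method in the last hunk. The idea of the commit: each episode, one randomly chosen agent is controlled and trained by the learning policy, while every other agent is driven by the rule-based DeadLockAvoidanceAgent. A minimal, self-contained sketch of that pattern, using illustrative stand-in objects rather than the repository's classes:

import numpy as np

def run_episode(env, learner, fallback, eps):
    # Pick the single agent that will generate training data this episode.
    obs, _ = env.reset()
    agent_to_learn = np.random.choice(env.get_num_agents())
    for _ in range(env.max_steps):
        actions = {}
        for handle in env.get_agent_handles():
            if handle == agent_to_learn:
                actions[handle] = learner.act(obs[handle], eps=eps)   # learning policy
            else:
                actions[handle] = fallback.act(handle)                # scripted heuristic
        next_obs, rewards, done, _ = env.step(actions)
        # Only the chosen agent's transition is used for learning.
        learner.step(obs[agent_to_learn], actions[agent_to_learn],
                     rewards[agent_to_learn], next_obs[agent_to_learn],
                     done[agent_to_learn])
        obs = next_obs
        if done['__all__']:
            break
    return done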
@@ -20,6 +20,7 @@ from torch.utils.tensorboard import SummaryWriter
 from reinforcement_learning.dddqn_policy import DDDQNPolicy
 from reinforcement_learning.ppo.ppo_agent import PPOAgent
+from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent

 base_dir = Path(__file__).resolve().parent.parent
 sys.path.append(str(base_dir))
@@ -229,6 +230,10 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         train_env = create_rail_env(train_env_params, tree_observation)
         obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
         policy.reset()
+
+        policy2 = DeadLockAvoidanceAgent(train_env)
+        policy2.reset()
+
         reset_timer.end()

         if train_params.render:
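Note that policy2 is set up inside the per-episode reset, right after create_rail_env(): the environment object is replaced every episode, so the DeadLockAvoidanceAgent, which is constructed from the concrete train_env instance, is re-created and reset alongside it rather than once at start-up, presumably so that it always tracks the current environment.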
@@ -252,15 +257,21 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         max_steps = train_env._max_episode_steps

         # Run episode
+        agent_to_learn = 0
+        if train_env.get_num_agents() > 1:
+            agent_to_learn = np.random.choice(train_env.get_num_agents())
         for step in range(max_steps - 1):
             inference_timer.start()
             policy.start_step()
+            policy2.start_step()
             for agent in train_env.get_agent_handles():
                 if info['action_required'][agent]:
                     update_values[agent] = True
-                    action = policy.act(agent_obs[agent], eps=eps_start)
+                    if agent == agent_to_learn:
+                        action = policy.act(agent_obs[agent], eps=eps_start)
+                    else:
+                        action = policy2.act([agent], eps=eps_start)
                     action_count[action] += 1
                     actions_taken.append(action)
                 else:
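Two details of this hunk are worth calling out. np.random.choice(train_env.get_num_agents()) draws a handle uniformly from 0..n_agents-1, so a different agent is typically trained each episode (with a single agent, handle 0 is always used). The two policies are also called differently: the learning policy consumes the agent's observation, while the deadlock-avoidance policy is handed the agent handle wrapped in a list. A commented restatement of the dispatch, with the assumed signatures spelled out:

if agent == agent_to_learn:
    # DQN/PPO-style interface: act(observation, eps) -> action id
    action = policy.act(agent_obs[agent], eps=eps_start)
else:
    # DeadLockAvoidanceAgent as used here: act([handle], eps) -> action id;
    # it presumably resolves the agent's state from the env it was constructed with,
    # so no observation is passed.
    action = policy2.act([agent], eps=eps_start)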
@@ -270,6 +281,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                     action = 0
                 action_dict.update({agent: action})
             policy.end_step()
+            policy2.end_step()
             inference_timer.end()

             # Environment step
@@ -291,10 +303,11 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                 if update_values[agent] or done['__all__']:
                     # Only learn from timesteps where somethings happened
                     learn_timer.start()
-                    policy.step(agent,
-                                agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
-                                agent_obs[agent],
-                                done[agent])
+                    if agent == agent_to_learn:
+                        policy.step(agent,
+                                    agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
+                                    agent_obs[agent],
+                                    done[agent])
                     learn_timer.end()

                     agent_prev_obs[agent] = agent_obs[agent].copy()
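With this gate, policy.step() only ever receives transitions belonging to agent_to_learn, so the learner collects at most one transition per environment step instead of one per agent; the experience generated by the other agents under the deadlock-avoidance policy is not used for learning.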
@@ -481,7 +494,7 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=5400, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2, type=int)
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1, type=int)
     parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
                         type=int)
     parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=5, type=int)
@@ -506,7 +519,7 @@ if __name__ == "__main__":
     parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
     parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs",
                         action='store_true')
-    parser.add_argument("--max_depth", help="max depth", default=1, type=int)
+    parser.add_argument("--max_depth", help="max depth", default=2, type=int)
     training_params = parser.parse_args()

     env_params = [
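With the updated defaults, a plain invocation now trains on environment config Test_1 with tree-observation depth 2. An equivalent explicit call (the script name below is assumed; substitute the actual training entry point this diff edits):

python multi_agent_training.py -n 5400 -t 1 -e 0 --n_evaluation_episodes 5 --max_depth 2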
......@@ -39,11 +39,6 @@ class PPOAgent(Policy):
# Decide on an action to take in the environment
def act(self, state, eps=None):
# if eps is not None:
# # Epsilon-greedy action selection
# if np.random.random() < eps:
# return np.random.choice(np.arange(self.action_size))
self.policy.eval()
with torch.no_grad():
output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
......
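Removing this (already commented-out) block means PPOAgent.act() ignores the eps argument entirely and always queries the policy network; exploration for the PPO agent presumably comes from the policy's own action distribution. If epsilon-greedy behaviour were still wanted, it could be layered on by the caller; a minimal sketch under that assumption (act_fn and action_size are the caller's, not part of this diff):

import numpy as np

def eps_greedy(act_fn, state, eps, action_size):
    # With probability eps take a uniformly random action, otherwise defer to the policy.
    if eps is not None and np.random.random() < eps:
        return np.random.choice(np.arange(action_size))
    return act_fn(state)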