Commit 1c60b970 authored by Egli Adrian (IT-SCI-API-PFI)

refactored and added new agent

parent 66929e4a
@@ -22,9 +22,9 @@ from torch.utils.tensorboard import SummaryWriter
 from reinforcement_learning.dddqn_policy import DDDQNPolicy
 from reinforcement_learning.ppo_agent import PPOAgent
 from reinforcement_learning.ppo_deadlockavoidance_agent import MultiDecisionAgent
+from utils.agent_action_config import get_flatland_full_action_size, get_action_size, map_actions, map_action
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 from utils.deadlock_check import get_agent_positions, check_for_deadlock
-from utils.agent_action_config import get_flatland_full_action_size, get_action_size, map_actions, map_action

 base_dir = Path(__file__).resolve().parent.parent
 sys.path.append(str(base_dir))
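A note on this hunk: the imported helpers (get_action_size, map_action, map_actions, get_flatland_full_action_size) suggest the policies act in a reduced action space that is mapped back to Flatland's full five-action space before the environment is stepped. The sketch below illustrates that idea under the assumption that the reduced space simply drops DO_NOTHING; only the concept is taken from the diff, the helper bodies here are hypothetical.

# Hypothetical illustration of the mapping the imported helpers appear to
# provide; not the repo's utils.agent_action_config implementation.
FULL_ACTION_SIZE = 5     # Flatland: DO_NOTHING, MOVE_LEFT, MOVE_FORWARD, MOVE_RIGHT, STOP_MOVING
REDUCED_ACTION_SIZE = 4  # assumed: drop DO_NOTHING, shift the remaining actions down by one

def map_action_sketch(reduced_action: int) -> int:
    # Map a policy action from the reduced space back to a Flatland action id.
    return reduced_action + 1

def map_actions_sketch(action_dict: dict) -> dict:
    # Map a whole {agent_handle: action} dict in one go, as map_actions suggests.
    return {handle: map_action_sketch(a) for handle, a in action_dict.items()}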
@@ -174,9 +174,9 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     policy = DDDQNPolicy(state_size, get_action_size(), train_params)
     if False:
         policy = PPOAgent(state_size, get_action_size())
-    if True:
+    if False:
         policy = DeadLockAvoidanceAgent(train_env, get_action_size())
-    if True:
+    if False:
         policy = MultiDecisionAgent(train_env, state_size, get_action_size(), policy)

     # Load existing policy
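After this change, plain DDDQN is the only active branch during training; the False-guarded lines show how the PPO, DeadLockAvoidanceAgent and MultiDecisionAgent alternatives would be wired in. As a sketch, the same selection written with a single knob (POLICY_KIND is hypothetical; the constructor calls are copied from the hunk, and state_size, train_params and train_env come from the surrounding train_agent() scope):

POLICY_KIND = "dddqn"  # hypothetical knob: "dddqn", "ppo", "dead_lock_avoidance" or "multi_decision"

policy = DDDQNPolicy(state_size, get_action_size(), train_params)
if POLICY_KIND == "ppo":
    policy = PPOAgent(state_size, get_action_size())
elif POLICY_KIND == "dead_lock_avoidance":
    policy = DeadLockAvoidanceAgent(train_env, get_action_size())
elif POLICY_KIND == "multi_decision":
    # MultiDecisionAgent wraps the learned policy chosen above.
    policy = MultiDecisionAgent(train_env, state_size, get_action_size(), policy)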
@@ -387,7 +387,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                 '\t 🎲 Epsilon: {:.3f} '
                 '\t 🔀 Action Probs: {}'.format(
                     episode_idx,
-                    train_env_params.n_agents, train_env.get_num_agents(),
+                    train_env_params.n_agents, number_of_agents,
                     normalized_score,
                     smoothed_normalized_score,
                     100 * completion,
@@ -521,11 +521,11 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=12000, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2,
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=3,
                         type=int)
-    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
+    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=2,
                         type=int)
-    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=5, type=int)
+    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=10, type=int)
     parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
     parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float)
     parser.add_argument("--eps_end", help="min exploration", default=0.005, type=float)
@@ -49,26 +49,18 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
 # Print per-step logs
 VERBOSE = True
 USE_FAST_TREEOBS = True
-USE_PPO_AGENT = True
+USE_PPO_AGENT = False

 # Checkpoint to use (remember to push it!)
checkpoint = "./checkpoints/201124171810-7800.pth" # DDDQN: 18.249244799876152 DEPTH=2 AGENTS=10
# checkpoint = "./checkpoints/201126150143-5200.pth" # DDDQN: 18.249244799876152 DEPTH=2 AGENTS=10
# checkpoint = "./checkpoints/201126160144-2000.pth" # DDDQN: 18.249244799876152 DEPTH=2 AGENTS=10
checkpoint = "./checkpoints/201207144650-20000.pth" # PPO: 14.45790721540786
checkpoint = "./checkpoints/201211063511-6300.pth" # DDDQN: 16.948349308440857
checkpoint = "./checkpoints/201211095604-12000.pth" # DDDQN: 17.3862941316504
checkpoint = "./checkpoints/201211164554-9400.pth" # DDDQN: 16.09241366013537
checkpoint = "./checkpoints/201213181400-6800.pth" # PPO: 13.944402986414723
checkpoint = "./checkpoints/201214140158-5000.pth" # USE_MULTI_DECISION_AGENT with DDDQN: 13.944402986414723
checkpoint = "./checkpoints/201215120518-3700.pth" # USE_MULTI_DECISION_AGENT with PPO: 13.944402986414723
checkpoint = "./checkpoints/201215160226-12000.pth" #
# checkpoint = "./checkpoints/201215212134-12000.pth" #
 EPSILON = 0.0

 # Use last action cache
 USE_ACTION_CACHE = False
 USE_DEAD_LOCK_AVOIDANCE_AGENT = False # 21.54485505223213
-USE_MULTI_DECISION_AGENT = False
+USE_MULTI_DECISION_AGENT = True

 # Observation parameters (must match training parameters!)
 observation_tree_depth = 2
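The runner now ships with USE_MULTI_DECISION_AGENT = True while USE_PPO_AGENT is off. The MultiDecisionAgent implementation itself is not part of this commit view; the class below is only a sketch of the dispatch pattern its name and constructor suggest (consult a deadlock-avoidance heuristic first, fall back to the learned policy otherwise) and not the repo's actual code:

class MultiDecisionAgentSketch:
    # Hypothetical illustration, not the repo's MultiDecisionAgent.

    def __init__(self, env, learned_policy, heuristic_policy):
        self.env = env
        self.learned_policy = learned_policy      # e.g. a DDDQN or PPO policy
        self.heuristic_policy = heuristic_policy  # e.g. DeadLockAvoidanceAgent

    def act(self, handle, state, eps=0.0):
        # Ask the deadlock-avoidance heuristic first; if it insists on an
        # action (e.g. stopping to avoid a deadlock), obey it. Otherwise let
        # the trained policy decide how to route this agent.
        heuristic_action = self.heuristic_policy.act(handle, state, eps)
        if heuristic_action is not None:
            return heuristic_action
        return self.learned_policy.act(handle, state, eps)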