Commit 18dcace9 authored by MasterScrat

Hyperparam tweaks, cleanup

parent 9b771c14
@@ -197,7 +197,7 @@ def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size,
hit_text = ""
if nb_hit > 0:
-hit_text = "\t⚡ Hit {} ({:.1f}%)".format(nb_hit, (100*nb_hit)/(n_agents*final_step))
+hit_text = "\t⚡ Hit {} ({:.1f}%)".format(nb_hit, (100 * nb_hit) / (n_agents * final_step))
print(
"☑️ Score: {:.3f} \tDone: {:.1f}% \tNb steps: {:.3f} "
@@ -302,8 +302,8 @@ def evaluate_agents(file, n_evaluation_episodes, use_gpu, render, allow_skipping
"observation_max_path_depth": 20
}
-params = test5_params
-env_params = Namespace(**test0_params)
+params = small_v0_params
+env_params = Namespace(**params)
print("Environment parameters:")
pprint(params)
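For context, these env configs are plain dicts wrapped in a `Namespace` for attribute-style access. A minimal sketch of what such a dict might look like; only `observation_max_path_depth` is taken from the hunk above, every other key and value here is a placeholder, not the repository's actual `small_v0_params`:

```python
from argparse import Namespace

# Hypothetical stand-in for small_v0_params; the real values live in the evaluation script.
small_v0_params = {
    "n_agents": 5,                      # placeholder
    "x_dim": 25,                        # placeholder
    "y_dim": 25,                        # placeholder
    "observation_tree_depth": 2,        # must match the training observation setup
    "observation_radius": 10,
    "observation_max_path_depth": 20,   # value visible in the hunk above
}

env_params = Namespace(**small_v0_params)
print(env_params.observation_max_path_depth)  # 20 — attribute access instead of dict lookups
```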
@@ -361,6 +361,10 @@ if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("-f", "--file", help="checkpoint to load", required=True, type=str)
parser.add_argument("-n", "--n_evaluation_episodes", help="number of evaluation episodes", default=25, type=int)
+# TODO
+# parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0, type=int)
parser.add_argument("--use_gpu", dest="use_gpu", help="use GPU if available", action='store_true')
parser.add_argument("--render", help="render a single episode", action='store_true')
parser.add_argument("--allow_skipping", help="skips to the end of the episode if all agents are deadlocked", action='store_true')
@@ -429,18 +429,18 @@ if __name__ == "__main__":
parser.add_argument("--eps_start", help="max exploration", default=1.0, type=float)
parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
parser.add_argument("--eps_decay", help="exploration decay", default=0.99, type=float)
parser.add_argument("--buffer_size", help="replay buffer size", default=int(5e5), type=int)
parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e5), type=int)
parser.add_argument("--buffer_min_size", help="min buffer size to start training", default=0, type=int)
parser.add_argument("--restore_replay_buffer", help="replay buffer to restore", default="", type=str)
parser.add_argument("--save_replay_buffer", help="save replay buffer at each evaluation interval", default=False, type=bool)
parser.add_argument("--batch_size", help="minibatch size", default=32, type=int)
parser.add_argument("--batch_size", help="minibatch size", default=128, type=int)
parser.add_argument("--gamma", help="discount factor", default=0.99, type=float)
parser.add_argument("--tau", help="soft update of target parameters", default=1e-3, type=float)
parser.add_argument("--learning_rate", help="learning rate", default=0.5e-4, type=float)
parser.add_argument("--hidden_size", help="hidden size (2 fc layers)", default=256, type=int)
parser.add_argument("--hidden_size", help="hidden size (2 fc layers)", default=128, type=int)
parser.add_argument("--update_every", help="how often to update the network", default=8, type=int)
parser.add_argument("--use_gpu", help="use GPU if available", default=False, type=bool)
parser.add_argument("--num_threads", help="number of threads to use", default=1, type=int)
parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=1, type=int)
parser.add_argument("--render", help="render 1 episode in 100", default=False, type=bool)
training_params = parser.parse_args()
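The `--hidden_size` help text says "2 fc layers", so the new default of 128 is the width of both hidden layers. A minimal sketch of such a network (the class name and exact layout are assumptions; the repository's actual model may differ, e.g. it could be a dueling architecture):

```python
import torch
import torch.nn as nn

class QNetwork(nn.Module):
    """Sketch of a Q-network with two fully connected hidden layers of width hidden_size."""

    def __init__(self, state_size: int, action_size: int, hidden_size: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size),  # one Q-value per action
        )

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        return self.net(state)
```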
@@ -25,8 +25,13 @@ from utils.observation_utils import normalize_observation
# Checkpoint to use
checkpoint = "checkpoints/sample-checkpoint.pth"
-# Beyond this number of agents, skip the episode
-max_env_width = 35
+# Beyond this env width, skip the episode
+MAX_ENV_WIDTH = 35
+# Use action cache:
+# Saves the last observation-action mapping for each
+# Only works for deterministic agents!
+USE_ACTION_CACHE = False
# Observation parameters
# These need to match your training parameters!
@@ -34,10 +39,6 @@ observation_tree_depth = 2
observation_radius = 10
observation_max_path_depth = 30
-# Use action cache:
-# Saves the last observation-action mapping for each
-# Only works for deterministic agents!
-use_action_cache = True
####################################################
remote_client = FlatlandRemoteClient()
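The observation parameters above are what the tree observation builder is constructed from; the construction itself is outside this hunk, so the following is only a sketch of the usual Flatland wiring under that assumption:

```python
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv

# Values copied from the constants above; they must match the training setup.
observation_tree_depth = 2
observation_max_path_depth = 30

predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor)
# observation_radius is later passed to normalize_observation() when flattening
# each agent's tree observation for the network (see the import at the top of the file).
```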
@@ -108,7 +109,7 @@ while True:
obs_time, agent_time, step_time = 0.0, 0.0, 0.0
no_ops_mode = False
-if local_env.width <= max_env_width and not check_if_all_blocked(env=local_env):
+if local_env.width <= MAX_ENV_WIDTH and not check_if_all_blocked(env=local_env):
time_start = time.time()
action_dict = {}
for agent in range(nb_agents):
@@ -122,7 +123,7 @@ while True:
action_dict[agent] = action
-if use_action_cache:
+if USE_ACTION_CACHE:
agent_last_obs[agent] = observation[agent]
agent_last_action[agent] = action
agent_time = time.time() - time_start
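Only the write side of the cache appears in this hunk (`agent_last_obs` / `agent_last_action` being stored). Below is a sketch of the lookup side, under the assumption that the per-agent loop reuses the cached action when the observation hasn't changed; the `policy.act` call, the `normalize_observation` arguments and the hit counter are illustrative, not copied from the file:

```python
import numpy as np

# Inside the same per-agent loop, before calling the policy:
if USE_ACTION_CACHE and agent in agent_last_obs \
        and np.all(agent_last_obs[agent] == observation[agent]):
    # Observation unchanged since the last step: reuse the cached action.
    # Safe only for deterministic policies, as the comment above warns.
    action = agent_last_action[agent]
    nb_hit += 1  # presumably what the "⚡ Hit" statistic in eval_policy counts
else:
    norm_obs = normalize_observation(observation[agent], observation_tree_depth,
                                     observation_radius=observation_radius)
    action = policy.act(norm_obs, eps=0.0)  # assumed policy API
action_dict[agent] = action
```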