diff --git a/examples/training_navigation.py b/examples/training_navigation.py index 9d45cd175c7d324804a2a34e278428aebad69e28..12902aa5798ae2a0e23cd5d07a5592327b0c7d9f 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -32,8 +32,8 @@ env = RailEnv(width=10, """ env = RailEnv(width=15, height=15, - rail_generator=complex_rail_generator(nr_start_goal=5, min_dist=5, max_dist=99999, seed=0), - number_of_agents=5) + rail_generator=complex_rail_generator(nr_start_goal=10, min_dist=5, max_dist=99999, seed=0), + number_of_agents=3) """ env = RailEnv(width=20, height=20, @@ -50,7 +50,7 @@ action_size = 4 n_trials = 15000 eps = 1. eps_end = 0.005 -eps_decay = 0.998 +eps_decay = 0.9995 action_dict = dict() final_action_dict = dict() scores_window = deque(maxlen=100) @@ -62,9 +62,9 @@ action_prob = [0] * 4 agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() agent = Agent(state_size, action_size, "FC", 0) -# agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth')) +agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth')) -demo = False +demo = True def max_lt(seq, val): diff --git a/flatland/baselines/Nets/avoid_checkpoint15000.pth b/flatland/baselines/Nets/avoid_checkpoint15000.pth index adcfe61576553bbf0e2b4ba00d9fffafbfd9d7da..14882a37a86085b137f4422b6bba75f387a2d3b5 100644 Binary files a/flatland/baselines/Nets/avoid_checkpoint15000.pth and b/flatland/baselines/Nets/avoid_checkpoint15000.pth differ