diff --git a/examples/training_navigation.py b/examples/training_navigation.py index 51d51b156e627fda9d122a8247cf71058bfd3c38..456a4a03b3a1716f4c69516b8132aa85293bdf91 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -20,8 +20,8 @@ transition_probability = [5, # empty cell - Case 0 0] # Case 7 - dead end # Example generate a random rail -env = RailEnv(width=10, - height=10, +env = RailEnv(width=15, + height=15, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), number_of_agents=3) env_renderer = RenderTool(env, gl="QT") @@ -38,37 +38,41 @@ scores_window = deque(maxlen=100) done_window = deque(maxlen=100) scores = [] dones_list = [] -action_prob = [0]*4 +action_prob = [0] * 4 agent = Agent(state_size, action_size, "FC", 0) -agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth')) +agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint13900.pth')) + +demo = True + -demo = False def max_lt(seq, val): """ Return greatest item in seq for which item < val applies. None is returned if seq was empty or all items in seq were >= val. """ max = 0 - idx = len(seq)-1 + idx = len(seq) - 1 while idx >= 0: if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max: max = seq[idx] idx -= 1 return max + def min_lt(seq, val): """ Return smallest item in seq for which item > val applies. None is returned if seq was empty or all items in seq were >= val. """ min = np.inf - idx = len(seq)-1 + idx = len(seq) - 1 while idx >= 0: if seq[idx] > val and seq[idx] < min: min = seq[idx] idx -= 1 return min + for trials in range(1, n_trials + 1): # Reset environment @@ -86,7 +90,7 @@ for trials in range(1, n_trials + 1): for step in range(100): if demo: env_renderer.renderEnv(show=True) - #print(step) + # print(step) # Action for a in range(env.number_of_agents): if demo: @@ -117,17 +121,17 @@ for trials in range(1, n_trials + 1): scores.append(np.mean(scores_window)) dones_list.append((np.mean(done_window))) - print('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( - env.number_of_agents, - trials, - np.mean( - scores_window), - 100 * np.mean( - done_window), - eps, action_prob/np.sum(action_prob)), - end=" ") + print( + '\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( + env.number_of_agents, + trials, + np.mean( + scores_window), + 100 * np.mean( + done_window), + eps, action_prob / np.sum(action_prob)), + end=" ") if trials % 100 == 0: - print( '\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( env.number_of_agents, @@ -139,4 +143,4 @@ for trials in range(1, n_trials + 1): eps, action_prob / np.sum(action_prob))) torch.save(agent.qnetwork_local.state_dict(), '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth') - action_prob = [1]*4 + action_prob = [1] * 4