diff --git a/examples/training_navigation.py b/examples/training_navigation.py index 231d4e923a3bf2e7a485b804b19d54a48b67fb2d..18bb63563180c90227e5df12554e3d51faddbc91 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -10,8 +10,8 @@ np.random.seed(1) # Example generate a rail given a manual specification, # a map of tuples (cell_type, rotation) -transition_probability = [10.0, # empty cell - Case 0 - 50.0, # Case 1 - straight +transition_probability = [0.5, # empty cell - Case 0 + 1.0, # Case 1 - straight 1.0, # Case 2 - simple switch 0.3, # Case 3 - diamond drossing 0.5, # Case 4 - single slip @@ -20,8 +20,8 @@ transition_probability = [10.0, # empty cell - Case 0 0.0] # Case 7 - dead end # Example generate a random rail -env = RailEnv(width=5, - height=5, +env = RailEnv(width=7, + height=7, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), number_of_agents=1) env_renderer = RenderTool(env) @@ -49,7 +49,7 @@ def max_lt(seq, val): idx = len(seq)-1 while idx >= 0: - if seq[idx] < val and seq[idx] > 0: + if seq[idx] < val and seq[idx] >= 0: return seq[idx] idx -= 1 return None @@ -110,6 +110,7 @@ for trials in range(1, n_trials + 1): eps, action_prob/np.sum(action_prob)), end=" ") if trials % 100 == 0: + action_prob = [1]*4 print( '\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format( env.number_of_agents,