diff --git a/examples/training_example.py b/examples/training_example.py
index 6910461327c778ff52824165032641ece019cf7a..70986c55ca5443ec9dd330fbc1dbfa6def768187 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -14,11 +14,12 @@ np.random.seed(1)
 
 TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
 LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
+
 env = RailEnv(width=50,
               height=50,
               rail_generator=complex_rail_generator(nr_start_goal=20, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
               obs_builder_object=TreeObservation,
-              number_of_agents=20)
+              number_of_agents=10)
 
 env_renderer = RenderTool(env, gl="PILSVG", )
 
@@ -58,7 +59,7 @@ class RandomAgent:
 
 
 # Initialize the agent with the parameters corresponding to the environment and observation_builder
-agent = RandomAgent(218, 4)
+agent = RandomAgent(218, 5)
 n_trials = 5
 
 # Empty dictionary for all agent action
@@ -75,7 +76,7 @@ for trials in range(1, n_trials + 1):
 
     score = 0
     # Run episode
-    for step in range(100):
+    for step in range(500):
         # Chose an action for each agent in the environment
         for a in range(env.get_num_agents()):
             action = agent.act(obs[a])
@@ -89,7 +90,6 @@ for trials in range(1, n_trials + 1):
         for a in range(env.get_num_agents()):
             agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
             score += all_rewards[a]
-
         obs = next_obs.copy()
         if done['__all__']:
             break