diff --git a/examples/training_example.py b/examples/training_example.py
index d6f2c0268a9d7aece260b8e0f97ae1ff68d28bb6..ee97c7e4c90dc15018d6d87c5b2293455075c5f0 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -12,7 +12,6 @@
 env = RailEnv(width=15,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
               number_of_agents=5)
-
 
 # Import your own Agent or use RLlib to train agents on Flatland
 # As an example we use a random agent here
@@ -39,14 +38,17 @@ class RandomAgent:
         """
         return
 
-    def save(self):
+    def save(self, filename):
         # Store the current policy
         return
 
+    def load(self,filename):
+        # Load a policy
+        return
 
 
 # Initialize the agent with the parameters corresponding to the environment and observation_builder
 agent = RandomAgent(218, 4)
-n_trials = 1000
+n_trials = 5
 
 # Empty dictionary for all agent action
 action_dict = dict()
@@ -71,11 +73,12 @@ for trials in range(1, n_trials + 1):
         next_obs, all_rewards, done, _ = env.step(action_dict)
 
         # Update replay buffer and train agent
-        agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
-        score += all_rewards[a]
+        for a in range(env.get_num_agents()):
+            agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
+            score += all_rewards[a]
 
         obs = next_obs.copy()
 
         if done['__all__']:
             break
-    print('Episode Nr. {}'.format(trials))
+    print('Episode Nr. {}\t Score = {}'.format(trials,score))