Skip to content
Snippets Groups Projects
Commit 5c8f9fe6 authored by Egli Adrian (IT-SCI-API-PFI)
Browse files
parents db3a7c31 9cdfef4c
No related branches found
No related tags found
No related merge requests found
......@@ -12,7 +12,6 @@ env = RailEnv(width=15,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
number_of_agents=5)
# Import your own Agent or use RLlib to train agents on Flatland
# As an example we use a random agent here
......@@ -39,14 +38,17 @@ class RandomAgent:
"""
return
def save(self):
def save(self, filename):
# Store the current policy
return
def load(self,filename):
# Load a policy
return
# Initialize the agent with the parameters corresponding to the environment and observation_builder
agent = RandomAgent(218, 4)
n_trials = 1000
n_trials = 5
# Empty dictionary for all agent actions
action_dict = dict()
......@@ -71,11 +73,12 @@ for trials in range(1, n_trials + 1):
next_obs, all_rewards, done, _ = env.step(action_dict)
# Update replay buffer and train agent
agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
score += all_rewards[a]
for a in range(env.get_num_agents()):
agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
score += all_rewards[a]
obs = next_obs.copy()
if done['__all__']:
break
print('Episode Nr. {}'.format(trials))
print('Episode Nr. {}\t Score = {}'.format(trials,score))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment