Skip to content
Snippets Groups Projects
Commit 58f7caf0 authored by Erik Nygren
Browse files

minor updates to training files

parent 417678af
No related branches found
No related tags found
No related merge requests found
......@@ -11,38 +11,39 @@ np.random.seed(1)
# Example: generate a rail from a manual specification,
# given as a map of tuples (cell_type, rotation)
transition_probability = [5, # empty cell - Case 0
1, # Case 1 - straight
transition_probability = [15, # empty cell - Case 0
5, # Case 1 - straight
5, # Case 2 - simple switch
1, # Case 3 - diamond crossing
1, # Case 4 - single slip
1, # Case 5 - double slip
1, # Case 6 - symmetrical
0, # Case 7 - dead end
15, # Case 1b (8) - simple turn right
15, # Case 1c (9) - simple turn left
15] # Case 2b (10) - simple switch mirrored
1, # Case 1b (8) - simple turn right
1, # Case 1c (9) - simple turn left
1] # Case 2b (10) - simple switch mirrored
# Example: generate a random rail
"""
env = RailEnv(width=10,
height=10,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=3)
number_of_agents=1)
"""
env = RailEnv(width=20,
height=20,
rail_generator=complex_rail_generator(nr_start_goal=20, min_dist=10, max_dist=99999, seed=0),
number_of_agents=5)
env = RailEnv(width=15,
height=15,
rail_generator=complex_rail_generator(nr_start_goal=15, min_dist=5, max_dist=99999, seed=0),
number_of_agents=10)
"""
env = RailEnv(width=20,
height=20,
rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(
['../notebooks/testing_11.npy']),
number_of_agents=1)
['../notebooks/temp.npy']),
number_of_agents=3)
"""
env_renderer = RenderTool(env, gl="QT")
handle = env.get_agent_handles()
......@@ -125,7 +126,8 @@ for trials in range(1, n_trials + 1):
next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1)
# Update replay buffer and train agent
for a in range(env.number_of_agents):
agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
if not demo:
agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
score += all_rewards[a]
obs = next_obs.copy()
......
No preview for this file type
......@@ -123,7 +123,7 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
# print("failed...")
created_sanity += 1
print("\n> Complex Rail Gen: Created #", len(start_goal), "pairs")
#print("\n> Complex Rail Gen: Created #", len(start_goal), "pairs")
# print(start_goal)
agents_position = [sg[0] for sg in start_goal]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment