diff --git a/examples/training_navigation.py b/examples/training_navigation.py index 3990587743c20aa214348ccdc134f01f7acd8be6..52e76450bd1a97487471b9d531d72021b7b4fcef 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -11,38 +11,39 @@ np.random.seed(1) # Example generate a rail given a manual specification, # a map of tuples (cell_type, rotation) -transition_probability = [5, # empty cell - Case 0 - 1, # Case 1 - straight +transition_probability = [15, # empty cell - Case 0 + 5, # Case 1 - straight 5, # Case 2 - simple switch 1, # Case 3 - diamond crossing 1, # Case 4 - single slip 1, # Case 5 - double slip 1, # Case 6 - symmetrical 0, # Case 7 - dead end - 15, # Case 1b (8) - simple turn right - 15, # Case 1c (9) - simple turn left - 15] # Case 2b (10) - simple switch mirrored + 1, # Case 1b (8) - simple turn right + 1, # Case 1c (9) - simple turn left + 1] # Case 2b (10) - simple switch mirrored # Example generate a random rail +""" env = RailEnv(width=10, height=10, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), - number_of_agents=3) + number_of_agents=1) """ -env = RailEnv(width=20, - height=20, - rail_generator=complex_rail_generator(nr_start_goal=20, min_dist=10, max_dist=99999, seed=0), - number_of_agents=5) +env = RailEnv(width=15, + height=15, + rail_generator=complex_rail_generator(nr_start_goal=15, min_dist=5, max_dist=99999, seed=0), + number_of_agents=10) """ env = RailEnv(width=20, height=20, rail_generator=rail_from_list_of_saved_GridTransitionMap_generator( - ['../notebooks/testing_11.npy']), - number_of_agents=1) - + ['../notebooks/temp.npy']), + number_of_agents=3) +""" env_renderer = RenderTool(env, gl="QT") handle = env.get_agent_handles() @@ -125,7 +126,8 @@ for trials in range(1, n_trials + 1): next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1) # Update replay buffer and train agent for a in range(env.number_of_agents): - agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]) + if not demo: + agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]) score += all_rewards[a] obs = next_obs.copy() diff --git a/flatland/baselines/Nets/avoid_checkpoint15000.pth b/flatland/baselines/Nets/avoid_checkpoint15000.pth index 9a63ce495867f6bf5464d0a0856187a6dba736b4..ca019b7b5d221577bcdb65e3979ba9795e5fd65b 100644 Binary files a/flatland/baselines/Nets/avoid_checkpoint15000.pth and b/flatland/baselines/Nets/avoid_checkpoint15000.pth differ diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 4f356e1c673d37af40d59a1b4c297bec59f5ce6c..fe971e6b24b90e31dadd797359247537078ad5f6 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -123,7 +123,7 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): # print("failed...") created_sanity += 1 - print("\n> Complex Rail Gen: Created #", len(start_goal), "pairs") + #print("\n> Complex Rail Gen: Created #", len(start_goal), "pairs") # print(start_goal) agents_position = [sg[0] for sg in start_goal]