diff --git a/examples/training_navigation.py b/examples/training_navigation.py index 0e0f43964c8e6fbb9bba9325cae3b29f6c32f0e1..51d51b156e627fda9d122a8247cf71058bfd3c38 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -10,21 +10,21 @@ np.random.seed(1) # Example generate a rail given a manual specification, # a map of tuples (cell_type, rotation) -transition_probability = [0.5, # empty cell - Case 0 - 1.0, # Case 1 - straight - 1.0, # Case 2 - simple switch - 0.3, # Case 3 - diamond crossing - 0.5, # Case 4 - single slip - 0.5, # Case 5 - double slip - 0.2, # Case 6 - symmetrical - 0.0] # Case 7 - dead end +transition_probability = [5, # empty cell - Case 0 + 15, # Case 1 - straight + 5, # Case 2 - simple switch + 1, # Case 3 - diamond crossing + 1, # Case 4 - single slip + 1, # Case 5 - double slip + 1, # Case 6 - symmetrical + 0] # Case 7 - dead end # Example generate a random rail -env = RailEnv(width=20, - height=20, +env = RailEnv(width=10, + height=10, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), - number_of_agents=1) -env_renderer = RenderTool(env) + number_of_agents=3) +env_renderer = RenderTool(env, gl="QT") handle = env.get_agent_handles() state_size = 105 @@ -42,7 +42,7 @@ action_prob = [0]*4 agent = Agent(state_size, action_size, "FC", 0) agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth')) -demo = True +demo = False def max_lt(seq, val): """ Return greatest item in seq for which item < val applies. @@ -74,7 +74,7 @@ for trials in range(1, n_trials + 1): # Reset environment obs = env.reset() for a in range(env.number_of_agents): - norm = max(1, max_lt(obs[a],np.inf)) + norm = max(1, max_lt(obs[a], np.inf)) obs[a] = np.clip(np.array(obs[a]) / norm, -1, 1) # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5) @@ -83,7 +83,7 @@ for trials in range(1, n_trials + 1): env_done = 0 # Run episode - for step in range(50): + for step in range(100): if demo: env_renderer.renderEnv(show=True) #print(step)