diff --git a/examples/training_navigation.py b/examples/training_navigation.py index ddf10b10619387b32fafe4780e97e80fe3550ac9..47487086a3858dc1fd71006b934dc353300498bc 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -12,19 +12,23 @@ np.random.seed(1) # Example generate a rail given a manual specification, # a map of tuples (cell_type, rotation) transition_probability = [5, # empty cell - Case 0 - 15, # Case 1 - straight + 1, # Case 1 - straight 5, # Case 2 - simple switch 1, # Case 3 - diamond crossing 1, # Case 4 - single slip 1, # Case 5 - double slip 1, # Case 6 - symmetrical - 0] # Case 7 - dead end + 0, # Case 7 - dead end + 15, # Case 1b (8) - simple turn right + 15, # Case 1c (9) - simple turn left + 15] # Case 2b (10) - simple switch mirrored + # Example generate a random rail env = RailEnv(width=10, height=10, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), - number_of_agents=1) + number_of_agents=3) """ env = RailEnv(width=20, height=20, @@ -57,7 +61,7 @@ action_prob = [0] * 4 agent = Agent(state_size, action_size, "FC", 0) agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint14900.pth')) -demo = True +demo = False def max_lt(seq, val): diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index fed790173941d7afc5ea5e3956bd131443f3ef57..baba6baac4a4efe76844147d7c49435a88ad79af 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -250,7 +250,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): transitions_templates_ = [] transition_probabilities = [] - for i in range(len(t_utils.transitions) - 4): # don't include dead-ends + for i in range(len(t_utils.transitions)-4): # don't include dead-ends all_transitions = 0 for dir_ in range(4): trans = t_utils.get_transitions(t_utils.transitions[i], dir_)