From 72c69df1274bd2b2e24160b8cf7c2f0e25441683 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Sun, 5 May 2019 12:33:15 +0200 Subject: [PATCH] minor test in navigation training --- examples/training_navigation.py | 12 ++++++++---- flatland/envs/generators.py | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/training_navigation.py b/examples/training_navigation.py index ddf10b10..47487086 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -12,19 +12,23 @@ np.random.seed(1) # Example generate a rail given a manual specification, # a map of tuples (cell_type, rotation) transition_probability = [5, # empty cell - Case 0 - 15, # Case 1 - straight + 1, # Case 1 - straight 5, # Case 2 - simple switch 1, # Case 3 - diamond crossing 1, # Case 4 - single slip 1, # Case 5 - double slip 1, # Case 6 - symmetrical - 0] # Case 7 - dead end + 0, # Case 7 - dead end + 15, # Case 1b (8) - simple turn right + 15, # Case 1c (9) - simple turn left + 15] # Case 2b (10) - simple switch mirrored + # Example generate a random rail env = RailEnv(width=10, height=10, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), - number_of_agents=1) + number_of_agents=3) """ env = RailEnv(width=20, height=20, @@ -57,7 +61,7 @@ action_prob = [0] * 4 agent = Agent(state_size, action_size, "FC", 0) agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint14900.pth')) -demo = True +demo = False def max_lt(seq, val): diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index fed79017..baba6baa 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -250,7 +250,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): transitions_templates_ = [] transition_probabilities = [] - for i in range(len(t_utils.transitions) - 4): # don't include dead-ends + for i in range(len(t_utils.transitions)-4): # don't include dead-ends all_transitions = 0 for dir_ in range(4): trans = t_utils.get_transitions(t_utils.transitions[i], dir_) -- GitLab