diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 4735f9ab44ec19fc544daa2c91cc9ba0533317c9..31c91b571353d1b0f05826985e49ec26151d59c7 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -257,8 +257,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): rail[replace_row][replace_col] = None possible_transitions, possible_probabilities = zip(*besttrans) - possible_probabilities = \ - np.exp(possible_probabilities) / sum(np.exp(possible_probabilities)) + possible_probabilities = [p/sum(possible_probabilities) for p in possible_probabilities] rail[row][col] = np.random.choice(possible_transitions, p=possible_probabilities) @@ -272,7 +271,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): else: possible_transitions, possible_probabilities = zip(*possible_cell_transitions) - possible_probabilities = np.exp(possible_probabilities) / sum(np.exp(possible_probabilities)) + possible_probabilities = [p/sum(possible_probabilities) for p in possible_probabilities] rail[row][col] = np.random.choice(possible_transitions, p=possible_probabilities)