From e99a59af87524ac0c8cfeb0038bdf5c734bbe66f Mon Sep 17 00:00:00 2001 From: Giacomo Spigler <spiglerg@gmail.com> Date: Sun, 21 Apr 2019 01:40:22 +0200 Subject: [PATCH] railenv generator probabilities: softmax -> normalize by sum --- flatland/envs/rail_env.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 4735f9ab..31c91b57 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -257,8 +257,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): rail[replace_row][replace_col] = None possible_transitions, possible_probabilities = zip(*besttrans) - possible_probabilities = \ - np.exp(possible_probabilities) / sum(np.exp(possible_probabilities)) + possible_probabilities = [p/sum(possible_probabilities) for p in possible_probabilities] rail[row][col] = np.random.choice(possible_transitions, p=possible_probabilities) @@ -272,7 +271,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): else: possible_transitions, possible_probabilities = zip(*possible_cell_transitions) - possible_probabilities = np.exp(possible_probabilities) / sum(np.exp(possible_probabilities)) + possible_probabilities = [p/sum(possible_probabilities) for p in possible_probabilities] rail[row][col] = np.random.choice(possible_transitions, p=possible_probabilities) -- GitLab