From ee0ed5c3bd7c33e96fcc31bb2b4b2aa08d844156 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Wed, 10 Jul 2019 17:01:45 +0200 Subject: [PATCH] #92 reward function test --- torch_training/bla.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/torch_training/bla.py b/torch_training/bla.py index 37d9a09..80ec308 100644 --- a/torch_training/bla.py +++ b/torch_training/bla.py @@ -92,24 +92,24 @@ def main(argv): print("Going to run training for {} trials...".format(n_trials)) for trials in range(1, n_trials + 1): - # if trials % 50 == 0 and not demo: - # x_dim = np.random.randint(8, 20) - # y_dim = np.random.randint(8, 20) - # n_agents = np.random.randint(3, 8) - # n_goals = n_agents + np.random.randint(0, 3) - # min_dist = int(0.75 * min(x_dim, y_dim)) - # env = RailEnv(width=x_dim, - # height=y_dim, - # rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, - # max_dist=99999, - # seed=0), - # obs_builder_object=TreeObsForRailEnv(max_depth=3, - # predictor=ShortestPathPredictorForRailEnv()), - # number_of_agents=n_agents) - # env.reset(True, True) - # max_steps = int(3 * (env.height + env.width)) - # agent_obs = [None] * env.get_num_agents() - # agent_next_obs = [None] * env.get_num_agents() + if trials % 50 == 0 and not demo: + x_dim = np.random.randint(8, 20) + y_dim = np.random.randint(8, 20) + n_agents = np.random.randint(3, 8) + n_goals = n_agents + np.random.randint(0, 3) + min_dist = int(0.75 * min(x_dim, y_dim)) + env = RailEnv(width=x_dim, + height=y_dim, + rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, + max_dist=99999, + seed=0), + obs_builder_object=TreeObsForRailEnv(max_depth=3, + predictor=ShortestPathPredictorForRailEnv()), + number_of_agents=n_agents) + env.reset(True, True) + max_steps = int(3 * (env.height + env.width)) + agent_obs = [None] * env.get_num_agents() + agent_next_obs = [None] * env.get_num_agents() # # Reset environment # if file_load: # obs = env.reset(False, False) -- GitLab