diff --git a/torch_training/bla.py b/torch_training/bla.py index 37d9a09db78f6989c9559d59b849a58ce3feaabd..80ec308c2b6bc5498d9198e2a03e562b02e7c96d 100644 --- a/torch_training/bla.py +++ b/torch_training/bla.py @@ -92,24 +92,24 @@ def main(argv): print("Going to run training for {} trials...".format(n_trials)) for trials in range(1, n_trials + 1): - # if trials % 50 == 0 and not demo: - # x_dim = np.random.randint(8, 20) - # y_dim = np.random.randint(8, 20) - # n_agents = np.random.randint(3, 8) - # n_goals = n_agents + np.random.randint(0, 3) - # min_dist = int(0.75 * min(x_dim, y_dim)) - # env = RailEnv(width=x_dim, - # height=y_dim, - # rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, - # max_dist=99999, - # seed=0), - # obs_builder_object=TreeObsForRailEnv(max_depth=3, - # predictor=ShortestPathPredictorForRailEnv()), - # number_of_agents=n_agents) - # env.reset(True, True) - # max_steps = int(3 * (env.height + env.width)) - # agent_obs = [None] * env.get_num_agents() - # agent_next_obs = [None] * env.get_num_agents() + if trials % 50 == 0 and not demo: + x_dim = np.random.randint(8, 20) + y_dim = np.random.randint(8, 20) + n_agents = np.random.randint(3, 8) + n_goals = n_agents + np.random.randint(0, 3) + min_dist = int(0.75 * min(x_dim, y_dim)) + env = RailEnv(width=x_dim, + height=y_dim, + rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, + max_dist=99999, + seed=0), + obs_builder_object=TreeObsForRailEnv(max_depth=3, + predictor=ShortestPathPredictorForRailEnv()), + number_of_agents=n_agents) + env.reset(True, True) + max_steps = int(3 * (env.height + env.width)) + agent_obs = [None] * env.get_num_agents() + agent_next_obs = [None] * env.get_num_agents() # # Reset environment # if file_load: # obs = env.reset(False, False)