diff --git a/torch_training/bla.py b/torch_training/bla.py
index a103f9f7a91d565d449236231da0ac1ed034fc39..cd1aa13a824a58815778f0779e64739093a5a095 100644
--- a/torch_training/bla.py
+++ b/torch_training/bla.py
@@ -55,6 +55,33 @@ def main(argv):
               number_of_agents=n_agents)
     env.reset(True, True)
     file_load = False
+    observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
+    env_renderer = RenderTool(env, gl="PILSVG", )
+    handle = env.get_agent_handles()
+    features_per_node = 9
+    state_size = features_per_node * 85 * 2
+    action_size = 5
+
+    print("main3")
+
+    # We set the number of episodes we would like to train on
+    if 'n_trials' not in locals():
+        n_trials = 30000
+    max_steps = int(3 * (env.height + env.width))
+    eps = 1.
+    eps_end = 0.005
+    eps_decay = 0.9995
+    action_dict = dict()
+    final_action_dict = dict()
+    scores_window = deque(maxlen=100)
+    done_window = deque(maxlen=100)
+    time_obs = deque(maxlen=2)
+    scores = []
+    dones_list = []
+    action_prob = [0] * action_size
+    agent_obs = [None] * env.get_num_agents()
+    agent_next_obs = [None] * env.get_num_agents()
+    agent = Agent(state_size, action_size, "FC", 0)
     print("multi_agent_trainging.py (2)")