diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py
index 94c4ee035217828b3a31eef800be7153a9202f56..4e12353eaeaf353f70af79987ca9cc2e87a302ae 100644
--- a/torch_training/multi_agent_inference.py
+++ b/torch_training/multi_agent_inference.py
@@ -3,7 +3,7 @@ from collections import deque
 
 import numpy as np
 import torch
-from flatland.envs.generators import complex_rail_generator
+from flatland.envs.generators import complex_rail_generator, rail_from_file
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
@@ -14,16 +14,17 @@ import torch_training.Nets
 from torch_training.dueling_double_dqn import Agent
 from utils.observation_utils import norm_obs_clip, split_tree
 
-random.seed(1)
-np.random.seed(1)
-"""
-file_name = "./railway/complex_scene.pkl"
+random.seed(3)
+np.random.seed(2)
+
+file_name = "./railway/navigate_and_avoid.pkl"
 env = RailEnv(width=10,
               height=20,
               rail_generator=rail_from_file(file_name),
               obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))
 x_dim = env.width
 y_dim = env.height
+
 """
 
 x_dim = np.random.randint(8, 20)
@@ -40,7 +41,7 @@ env = RailEnv(width=x_dim,
               obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()),
               number_of_agents=n_agents)
 env.reset(True, True)
-
+"""
 tree_depth = 3
 observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv())
 env_renderer = RenderTool(env, gl="PILSVG", )
@@ -52,7 +53,7 @@ for i in range(tree_depth + 1):
 
 state_size = num_features_per_node * nr_nodes
 action_size = 5
-n_trials = 100
+n_trials = 5
 observation_radius = 10
 max_steps = int(3 * (env.height + env.width))
 eps = 1.
@@ -69,7 +70,7 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-with path(torch_training.Nets, "avoid_checkpoint49700.pth") as file_in:
+with path(torch_training.Nets, "avoid_checkpoint53400.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
 record_images = False
@@ -95,7 +96,7 @@ for trials in range(1, n_trials + 1):
         env_renderer.render_env(show=True, show_observations=False, show_predictions=True)
 
         if record_images:
-            env_renderer.gl.saveImage("./Images/flatland_frame_{:04d}.bmp".format(frame_step))
+            env_renderer.gl.save_image("./Images/Avoiding/flatland_frame_{:04d}.bmp".format(frame_step))
             frame_step += 1
 
         # Action