diff --git a/torch_training/Nets/avoid_checkpoint15000.pth b/torch_training/Nets/avoid_checkpoint15000.pth index 9d1936ab4a1d51530662b589423f78c0ccb57c44..bb681d11151a13c54c78c22ac7dd421eea45ed32 100644 Binary files a/torch_training/Nets/avoid_checkpoint15000.pth and b/torch_training/Nets/avoid_checkpoint15000.pth differ diff --git a/torch_training/Nets/avoid_checkpoint30000.pth b/torch_training/Nets/avoid_checkpoint30000.pth index 066b00180693a783ae134195e7cfdb1cd8975624..b6a0782cc1899a1e799011d19b3a9afb5906467c 100644 Binary files a/torch_training/Nets/avoid_checkpoint30000.pth and b/torch_training/Nets/avoid_checkpoint30000.pth differ diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py index 84a0846cef45454f574c03db8c8a77d264fd798d..ef6ef4e1704cb1e1c53b622d27bb48ede5af4a54 100644 --- a/torch_training/multi_agent_inference.py +++ b/torch_training/multi_agent_inference.py @@ -17,7 +17,7 @@ from utils.observation_utils import normalize_observation random.seed(3) np.random.seed(2) -file_name = "./railway/complex_scene.pkl" +file_name = "./railway/simple_avoid.pkl" env = RailEnv(width=10, height=20, rail_generator=rail_from_file(file_name), @@ -27,9 +27,9 @@ y_dim = env.height """ -x_dim = 10 # np.random.randint(8, 20) -y_dim = 10 # np.random.randint(8, 20) -n_agents = 5 # np.random.randint(3, 8) +x_dim = 18 # np.random.randint(8, 20) +y_dim = 14 # np.random.randint(8, 20) +n_agents = 7 # np.random.randint(3, 8) n_goals = n_agents + np.random.randint(0, 3) min_dist = int(0.75 * min(x_dim, y_dim)) @@ -53,7 +53,7 @@ for i in range(tree_depth + 1): state_size = num_features_per_node * nr_nodes action_size = 5 -n_trials = 1 +n_trials = 10 observation_radius = 10 max_steps = int(3 * (env.height + env.width)) eps = 1. @@ -70,7 +70,7 @@ action_prob = [0] * action_size agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() agent = Agent(state_size, action_size, "FC", 0) -with path(torch_training.Nets, "avoid_checkpoint52800.pth") as file_in: +with path(torch_training.Nets, "avoid_checkpoint46200.pth") as file_in: agent.qnetwork_local.load_state_dict(torch.load(file_in)) record_images = False @@ -102,7 +102,7 @@ for trials in range(1, n_trials + 1): next_obs, all_rewards, done, _ = env.step(action_dict) for a in range(env.get_num_agents()): - agent_obs[a] = agent_obs[a] = normalize_observation(next_obs[a], observation_radius=10) + agent_obs[a] = normalize_observation(next_obs[a], observation_radius=10) if done['__all__']: break