diff --git a/torch_training/Nets/avoid_checkpoint15000.pth b/torch_training/Nets/avoid_checkpoint15000.pth
index 9cc03d3c2ac88d946ddd33fa1009cb3ab56a7b59..7f629e2e980777167d964e12dfcf8f9f4b86fcb9 100644
Binary files a/torch_training/Nets/avoid_checkpoint15000.pth and b/torch_training/Nets/avoid_checkpoint15000.pth differ
diff --git a/torch_training/Nets/avoid_checkpoint30000.pth b/torch_training/Nets/avoid_checkpoint30000.pth
index 2625b7648ec3ff8e3efba2ed33eebe516654c252..b146ffa5265aad9b05c112d86abbd2119ceea775 100644
Binary files a/torch_training/Nets/avoid_checkpoint30000.pth and b/torch_training/Nets/avoid_checkpoint30000.pth differ
diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py
index e88bf3132eb54714e966025b317d50bf1edbb576..94c4ee035217828b3a31eef800be7153a9202f56 100644
--- a/torch_training/multi_agent_inference.py
+++ b/torch_training/multi_agent_inference.py
@@ -53,6 +53,7 @@
 state_size = num_features_per_node * nr_nodes
 action_size = 5
 n_trials = 100
+observation_radius = 10
 max_steps = int(3 * (env.height + env.width))
 eps = 1.
 eps_end = 0.005
@@ -68,7 +69,7 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-with path(torch_training.Nets, "avoid_checkpoint2900.pth") as file_in:
+with path(torch_training.Nets, "avoid_checkpoint49700.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
 record_images = False
@@ -84,7 +85,7 @@ for trials in range(1, n_trials + 1):
     for a in range(env.get_num_agents()):
         data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=num_features_per_node,
                                                 current_depth=0)
-        data = norm_obs_clip(data)
+        data = norm_obs_clip(data, fixed_radius=observation_radius)
         distance = norm_obs_clip(distance)
         agent_data = np.clip(agent_data, -1, 1)
         agent_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
@@ -106,9 +107,10 @@
         next_obs, all_rewards, done, _ = env.step(action_dict)
 
         for a in range(env.get_num_agents()):
-            data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=num_features_per_node,
+            data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
+                                                    num_features_per_node=num_features_per_node,
                                                     current_depth=0)
-            data = norm_obs_clip(data)
+            data = norm_obs_clip(data, fixed_radius=observation_radius)
             distance = norm_obs_clip(distance)
             agent_data = np.clip(agent_data, -1, 1)
             agent_next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py
index bc377e393e51ac17dba9d198ef44565e15c1c062..489501ae8ca43df9ca94a86a837e274181b207ee 100644
--- a/torch_training/render_agent_behavior.py
+++ b/torch_training/render_agent_behavior.py
@@ -3,7 +3,7 @@ from collections import deque
 
 import numpy as np
 import torch
-from flatland.envs.generators import rail_from_file
+from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
@@ -16,7 +16,7 @@ from utils.observation_utils import norm_obs_clip, split_tree
 random.seed(1)
 np.random.seed(1)
 
-
+"""
 file_name = "./railway/complex_scene.pkl"
 env = RailEnv(width=10,
               height=20,
@@ -40,7 +40,7 @@ env = RailEnv(width=x_dim,
               obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()),
               number_of_agents=n_agents)
 env.reset(True, True)
-"""
+
 observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
 env_renderer = RenderTool(env, gl="PILSVG", )
 num_features_per_node = env.obs_builder.observation_dim
@@ -67,7 +67,7 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-with path(torch_training.Nets, "avoid_checkpoint60000.pth") as file_in:
+with path(torch_training.Nets, "avoid_checkpoint49700.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
 record_images = False
@@ -101,7 +101,7 @@
 
     # Run episode
    for step in range(max_steps):
-        env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
+        env_renderer.render_env(show=True, show_observations=False, show_predictions=True)
 
         if record_images:
             env_renderer.gl.saveImage("./Images/flatland_frame_{:04d}.bmp".format(frame_step))
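A note on the two normalization changes in multi_agent_inference.py above: besides reading the tree observation from next_obs[a] instead of the stale obs[a] after env.step(), both call sites now normalize the tree features against a fixed observation radius rather than the per-sample maximum, so the feature scale stays constant across steps and matches training. A minimal sketch of the shared preprocessing, assuming norm_obs_clip's fixed_radius argument acts as a constant normalization bound (the preprocess_agent_obs helper is hypothetical, not part of the repository):

import numpy as np

from utils.observation_utils import norm_obs_clip, split_tree


def preprocess_agent_obs(raw_obs, num_features_per_node, observation_radius=10):
    # Hypothetical helper mirroring the per-agent preprocessing used at both
    # call sites in this diff; not part of the repository itself.
    data, distance, agent_data = split_tree(tree=np.array(raw_obs),
                                            num_features_per_node=num_features_per_node,
                                            current_depth=0)
    # Assumed: with fixed_radius > 0, norm_obs_clip scales by that constant
    # instead of the per-sample maximum, keeping features comparable across steps.
    data = norm_obs_clip(data, fixed_radius=observation_radius)
    distance = norm_obs_clip(distance)
    agent_data = np.clip(agent_data, -1, 1)
    return np.concatenate((np.concatenate((data, distance)), agent_data))

With such a helper, the two loops would reduce to agent_obs[a] = preprocess_agent_obs(obs[a], num_features_per_node) and agent_next_obs[a] = preprocess_agent_obs(next_obs[a], num_features_per_node).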