diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py
index 3d6d27c09ad8920720fa2312a555296a57d39924..30f5a91288f71330ed6c9bab00b79090aa6bf818 100644
--- a/torch_training/multi_agent_inference.py
+++ b/torch_training/multi_agent_inference.py
@@ -3,7 +3,7 @@ from collections import deque
 
 import numpy as np
 import torch
-from flatland.envs.generators import complex_rail_generator, rail_from_file
+from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
@@ -16,8 +16,8 @@ from utils.observation_utils import norm_obs_clip, split_tree
 random.seed(3)
 np.random.seed(2)
 
-
-file_name = "./railway/complex_scene.pkl"
+"""
+file_name = "./railway/simple_avoid.pkl"
 env = RailEnv(width=10,
               height=20,
               rail_generator=rail_from_file(file_name),
@@ -29,8 +29,8 @@ y_dim = env.height
 
 x_dim = 20 #np.random.randint(8, 20)
 y_dim = 20 #np.random.randint(8, 20)
-n_agents = 10 #np.random.randint(3, 8)
-n_goals = n_agents + np.random.randint(0, 3)
+n_agents = 1 # np.random.randint(3, 8)
+n_goals = 10 + n_agents + np.random.randint(0, 3)
 min_dist = int(0.75 * min(x_dim, y_dim))
 
 env = RailEnv(width=x_dim,
@@ -41,7 +41,7 @@ env = RailEnv(width=x_dim,
               obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()),
               number_of_agents=n_agents)
 env.reset(True, True)
-"""
+
 tree_depth = 3
 observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv())
 env_renderer = RenderTool(env, gl="PILSVG", )
@@ -53,8 +53,8 @@ for i in range(tree_depth + 1):
 state_size = num_features_per_node * nr_nodes
 action_size = 5
 
-n_trials = 1
-observation_radius = 10
+n_trials = 10
+observation_radius = 20
 max_steps = int(3 * (env.height + env.width))
 eps = 1.
 eps_end = 0.005
@@ -73,7 +73,7 @@ agent = Agent(state_size, action_size, "FC", 0)
 with path(torch_training.Nets, "avoid_checkpoint60000.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
-record_images = True
+record_images = False
 frame_step = 0
 
 for trials in range(1, n_trials + 1):
@@ -93,7 +93,7 @@ for trials in range(1, n_trials + 1):
 
     # Run episode
    for step in range(max_steps):
-        env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
+        env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
 
        if record_images:
            env_renderer.gl.save_image("./Images/Avoiding/flatland_frame_{:04d}.bmp".format(frame_step))
@@ -114,8 +114,7 @@ for trials in range(1, n_trials + 1):
             data = norm_obs_clip(data, fixed_radius=observation_radius)
             distance = norm_obs_clip(distance)
             agent_data = np.clip(agent_data, -1, 1)
-            agent_next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
+            agent_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
 
-        agent_obs = agent_next_obs.copy()
         if done['__all__']:
             break
diff --git a/torch_training/railway/complex_scene.pkl b/torch_training/railway/complex_scene.pkl
index b5c272477f53794d78a896c33d7c91e5b8cb0ea3..9bad5f9674b7a7b7e792c4e4805bce51ced35f7c 100644
Binary files a/torch_training/railway/complex_scene.pkl and b/torch_training/railway/complex_scene.pkl differ
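
Note on the last hunk of multi_agent_inference.py: since this script only runs a trained policy and never stores transitions, there is no need to stage the normalized observations in agent_next_obs and copy them into agent_obs after the step, so the loop now writes each agent's observation into agent_obs directly. The sketch below shows that per-agent pipeline in isolation. It reuses norm_obs_clip and split_tree from utils.observation_utils (imported at the top of the script); the keyword arguments to split_tree are assumed from this repo's training scripts, and the helper name build_agent_obs is hypothetical.

import numpy as np

from utils.observation_utils import norm_obs_clip, split_tree


def build_agent_obs(obs, num_features_per_node, observation_radius):
    # Hypothetical helper mirroring the body of the episode loop above.
    agent_obs = {}
    for a, tree in obs.items():
        # Split the tree observation into node features, distance targets and
        # agent-specific flags (argument names assumed, not authoritative).
        data, distance, agent_data = split_tree(tree=np.array(tree),
                                                num_features_per_node=num_features_per_node,
                                                current_depth=0)
        # Normalize and clip each part exactly as in the diff above.
        data = norm_obs_clip(data, fixed_radius=observation_radius)
        distance = norm_obs_clip(distance)
        agent_data = np.clip(agent_data, -1, 1)
        # Write straight into agent_obs: no agent_next_obs staging buffer and
        # no copy() at the end of the agent loop, which the diff removes.
        agent_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
    return agent_obs

Collapsing the two buffers is safe here because the observation is only consumed for the next action selection; a separate agent_next_obs buffer matters only when (obs, action, next_obs) transitions must be kept intact for training.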