diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index 1720d7bc67b55f8e4a8a5a88501cdad2368b603f..8c30d72010d95389dd22af06971170aa9c4b4480 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -3,7 +3,6 @@ from collections import deque import numpy as np import torch - from dueling_double_dqn import Agent from flatland.envs.generators import complex_rail_generator from flatland.envs.rail_env import RailEnv @@ -38,16 +37,16 @@ env = RailEnv(width=15, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0), number_of_agents=1) -""" + env = RailEnv(width=10, height=20) env.load("./railway/complex_scene.pkl") - +""" env = RailEnv(width=8, height=8, rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=5, max_dist=99999, seed=0), - number_of_agents=2) + number_of_agents=1) env.reset(True, True) env_renderer = RenderTool(env, gl="PILSVG") @@ -133,7 +132,7 @@ for trials in range(1, n_trials + 1): data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=7, current_depth=0) data = norm_obs_clip(data) distance = norm_obs_clip(distance) - + agent_data = np.clip(agent_data, -1, 1) obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data)) for i in range(2): time_obs.append(obs) @@ -146,13 +145,12 @@ for trials in range(1, n_trials + 1): # Run episode for step in range(100): if demo: - - env_renderer.renderEnv(show=True, show_observations=True) + env_renderer.renderEnv(show=True, show_observations=False) # print(step) # Action for a in range(env.get_num_agents()): if demo: - eps = 0 + eps = 1 # action = agent.act(np.array(obs[a]), eps=eps) action = agent.act(agent_obs[a], eps=eps) action_prob[action] += 1 @@ -165,6 +163,7 @@ for trials in range(1, n_trials + 1): current_depth=0) data = norm_obs_clip(data) distance = norm_obs_clip(distance) + agent_data = np.clip(agent_data, -1, 1) next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data)) time_obs.append(next_obs)