Skip to content
Snippets Groups Projects
Commit b8846789 authored by Erik Nygren's avatar Erik Nygren
Browse files

new training results

parent 1e034d49
No related branches found
No related tags found
No related merge requests found
No preview for this file type
...@@ -17,7 +17,7 @@ from utils.observation_utils import norm_obs_clip, split_tree ...@@ -17,7 +17,7 @@ from utils.observation_utils import norm_obs_clip, split_tree
random.seed(3) random.seed(3)
np.random.seed(2) np.random.seed(2)
file_name = "./railway/navigate_and_avoid.pkl" file_name = "./railway/complex_scene.pkl"
env = RailEnv(width=10, env = RailEnv(width=10,
height=20, height=20,
rail_generator=rail_from_file(file_name), rail_generator=rail_from_file(file_name),
...@@ -27,9 +27,9 @@ y_dim = env.height ...@@ -27,9 +27,9 @@ y_dim = env.height
""" """
x_dim = np.random.randint(8, 20) x_dim = 20 #np.random.randint(8, 20)
y_dim = np.random.randint(8, 20) y_dim = 20 #np.random.randint(8, 20)
n_agents = np.random.randint(3, 8) n_agents = 10 #np.random.randint(3, 8)
n_goals = n_agents + np.random.randint(0, 3) n_goals = n_agents + np.random.randint(0, 3)
min_dist = int(0.75 * min(x_dim, y_dim)) min_dist = int(0.75 * min(x_dim, y_dim))
...@@ -53,7 +53,7 @@ for i in range(tree_depth + 1): ...@@ -53,7 +53,7 @@ for i in range(tree_depth + 1):
state_size = num_features_per_node * nr_nodes state_size = num_features_per_node * nr_nodes
action_size = 5 action_size = 5
n_trials = 5 n_trials = 1
observation_radius = 10 observation_radius = 10
max_steps = int(3 * (env.height + env.width)) max_steps = int(3 * (env.height + env.width))
eps = 1. eps = 1.
...@@ -70,10 +70,10 @@ action_prob = [0] * action_size ...@@ -70,10 +70,10 @@ action_prob = [0] * action_size
agent_obs = [None] * env.get_num_agents() agent_obs = [None] * env.get_num_agents()
agent_next_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents()
agent = Agent(state_size, action_size, "FC", 0) agent = Agent(state_size, action_size, "FC", 0)
with path(torch_training.Nets, "avoid_checkpoint53400.pth") as file_in: with path(torch_training.Nets, "avoid_checkpoint60000.pth") as file_in:
agent.qnetwork_local.load_state_dict(torch.load(file_in)) agent.qnetwork_local.load_state_dict(torch.load(file_in))
record_images = False record_images = True
frame_step = 0 frame_step = 0
for trials in range(1, n_trials + 1): for trials in range(1, n_trials + 1):
...@@ -93,7 +93,7 @@ for trials in range(1, n_trials + 1): ...@@ -93,7 +93,7 @@ for trials in range(1, n_trials + 1):
# Run episode # Run episode
for step in range(max_steps): for step in range(max_steps):
env_renderer.render_env(show=True, show_observations=False, show_predictions=True) env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
if record_images: if record_images:
env_renderer.gl.save_image("./Images/Avoiding/flatland_frame_{:04d}.bmp".format(frame_step)) env_renderer.gl.save_image("./Images/Avoiding/flatland_frame_{:04d}.bmp".format(frame_step))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment