Skip to content
Snippets Groups Projects
Commit faca100b authored by Erik Nygren's avatar Erik Nygren
Browse files

new training results added

parent 9fd05a66
No related branches found
No related tags found
No related merge requests found
No preview for this file type
No preview for this file type
...@@ -17,7 +17,7 @@ from utils.observation_utils import normalize_observation ...@@ -17,7 +17,7 @@ from utils.observation_utils import normalize_observation
random.seed(3) random.seed(3)
np.random.seed(2) np.random.seed(2)
file_name = "./railway/complex_scene.pkl" file_name = "./railway/simple_avoid.pkl"
env = RailEnv(width=10, env = RailEnv(width=10,
height=20, height=20,
rail_generator=rail_from_file(file_name), rail_generator=rail_from_file(file_name),
...@@ -27,9 +27,9 @@ y_dim = env.height ...@@ -27,9 +27,9 @@ y_dim = env.height
""" """
x_dim = 10 # np.random.randint(8, 20) x_dim = 18 # np.random.randint(8, 20)
y_dim = 10 # np.random.randint(8, 20) y_dim = 14 # np.random.randint(8, 20)
n_agents = 5 # np.random.randint(3, 8) n_agents = 7 # np.random.randint(3, 8)
n_goals = n_agents + np.random.randint(0, 3) n_goals = n_agents + np.random.randint(0, 3)
min_dist = int(0.75 * min(x_dim, y_dim)) min_dist = int(0.75 * min(x_dim, y_dim))
...@@ -53,7 +53,7 @@ for i in range(tree_depth + 1): ...@@ -53,7 +53,7 @@ for i in range(tree_depth + 1):
state_size = num_features_per_node * nr_nodes state_size = num_features_per_node * nr_nodes
action_size = 5 action_size = 5
n_trials = 1 n_trials = 10
observation_radius = 10 observation_radius = 10
max_steps = int(3 * (env.height + env.width)) max_steps = int(3 * (env.height + env.width))
eps = 1. eps = 1.
...@@ -70,7 +70,7 @@ action_prob = [0] * action_size ...@@ -70,7 +70,7 @@ action_prob = [0] * action_size
agent_obs = [None] * env.get_num_agents() agent_obs = [None] * env.get_num_agents()
agent_next_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents()
agent = Agent(state_size, action_size, "FC", 0) agent = Agent(state_size, action_size, "FC", 0)
with path(torch_training.Nets, "avoid_checkpoint52800.pth") as file_in: with path(torch_training.Nets, "avoid_checkpoint46200.pth") as file_in:
agent.qnetwork_local.load_state_dict(torch.load(file_in)) agent.qnetwork_local.load_state_dict(torch.load(file_in))
record_images = False record_images = False
...@@ -102,7 +102,7 @@ for trials in range(1, n_trials + 1): ...@@ -102,7 +102,7 @@ for trials in range(1, n_trials + 1):
next_obs, all_rewards, done, _ = env.step(action_dict) next_obs, all_rewards, done, _ = env.step(action_dict)
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
agent_obs[a] = agent_obs[a] = normalize_observation(next_obs[a], observation_radius=10) agent_obs[a] = normalize_observation(next_obs[a], observation_radius=10)
if done['__all__']: if done['__all__']:
break break
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment