From 2e05cbe1dd86794a3f4bb4e669371f6388749904 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Wed, 3 Jul 2019 14:57:41 -0400
Subject: [PATCH] using new utility functions

---
 torch_training/training_navigation.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index a6ee613..0e5ad18 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -43,7 +43,7 @@ env = RailEnv(width=15,
 """
 env = RailEnv(width=10,
               height=20,
               obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
-env.load("./railway/flatland.pkl")
+env.load("./railway/complex_scene.pkl")
 file_load = True
 """
@@ -80,7 +80,7 @@ agent = Agent(state_size, action_size, "FC", 0)
 agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
 
 demo = True
-record_images = True
+record_images = False
 
 
 
@@ -129,15 +129,15 @@ for trials in range(1, n_trials + 1):
             if demo:
                 eps = 0
             # action = agent.act(np.array(obs[a]), eps=eps)
-            action = 2 #agent.act(agent_obs[a], eps=eps)
+            action = agent.act(agent_obs[a], eps=eps)
             action_prob[action] += 1
             action_dict.update({a: action})
         # Environment step
         next_obs, all_rewards, done, _ = env.step(action_dict)
 
         for a in range(env.get_num_agents()):
-            data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=8,
-                                                                    current_depth=0)
+            data, distance, agent_data = split_tree(tree=np.array(next_obs[a]), num_features_per_node=8,
+                                                    current_depth=0)
             data = norm_obs_clip(data)
             distance = norm_obs_clip(distance)
             agent_data = np.clip(agent_data, -1, 1)
--
GitLab
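
For context on the "new utility functions": the patch swaps the `env.obs_builder.split_tree` method call for a standalone `split_tree` helper and keeps `norm_obs_clip` for scaling. A minimal sketch of the per-agent normalization step these hunks produce, assuming the helpers are importable from the baselines' `utils.observation_utils` module (that import path and the `normalize_observation` wrapper name are assumptions, not shown in the diff):

    # Sketch: normalize one agent's tree observation, mirroring the last hunk.
    # Assumed: split_tree and norm_obs_clip live in utils.observation_utils.
    import numpy as np

    from utils.observation_utils import norm_obs_clip, split_tree

    def normalize_observation(tree_obs, num_features_per_node=8):
        # Split the flat tree observation into general features,
        # distance features, and agent-specific flags.
        data, distance, agent_data = split_tree(tree=np.array(tree_obs),
                                                num_features_per_node=num_features_per_node,
                                                current_depth=0)
        # Rescale the unbounded feature values into [-1, 1].
        data = norm_obs_clip(data)
        distance = norm_obs_clip(distance)
        # Agent flags are already near unit range; a plain clip suffices.
        agent_data = np.clip(agent_data, -1, 1)
        # Concatenate back into the single vector fed to the Q-network.
        return np.concatenate((np.concatenate((data, distance)), agent_data))

With a wrapper like this, the per-agent loop in the patch reduces to `agent_obs[a] = normalize_observation(next_obs[a])`.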