From b8c300f929f02ba869ba2b255913a8eab0a34e91 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Fri, 19 Apr 2019 16:07:10 +0200
Subject: [PATCH] Set up training environment

---
 examples/temporary_example.py         |  2 +-
 examples/training_navigation.py       | 16 ++++++++--------
 flatland/agents/dueling_double_dqn.py |  2 +-
 3 files changed, 10 insertions(+), 10 deletions(-)
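
Note: the loop added to examples/training_navigation.py runs the agent with
eps=1, i.e. purely random exploration, and never updates the network. As a
rough sketch of where this is heading - assuming the Agent class also exposes
a step(state, action, reward, next_state, done) replay/learn hook (only
act(state, eps) is visible in this patch) and that the rewards and done
values returned by env.step() are dicts keyed by agent handle, like the
observations - a follow-up training loop could look like:

    eps = 1.0
    obs, all_rewards, done, _ = env.step({0: 0})
    state = np.array(obs[0])
    for step in range(100):
        action = agent.act(state, eps=eps)                 # epsilon-greedy action
        obs, all_rewards, done, _ = env.step({0: action})  # advance the environment
        next_state = np.array(obs[0])
        # store the transition and (periodically) learn from replay
        agent.step(state, action, all_rewards[0], next_state, done[0])
        state = next_state
        eps = max(0.05, eps * 0.998)                        # decay exploration over time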

diff --git a/examples/temporary_example.py b/examples/temporary_example.py
index c015f614..67fa4616 100644
--- a/examples/temporary_example.py
+++ b/examples/temporary_example.py
@@ -26,7 +26,7 @@ transition_probability = [1.0,  # empty cell - Case 0
                           0.5,  # Case 4 - single slip
                           0.1,  # Case 5 - double slip
                           0.2,  # Case 6 - symmetrical
-                          0.01]  # Case 7 - dead end
+                          1.0]  # Case 7 - dead end
 
 # Example generate a random rail
 env = RailEnv(width=20,
diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index f81a50dd..33fe287d 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -32,9 +32,9 @@ env = RailEnv(width=20,
 env.reset()
 
 env_renderer = RenderTool(env)
-env_renderer.renderEnv(show=True)
+handle = env.get_agent_handles()
 
-state_size = 5
+state_size = 105  # 21 nodes of the depth-2 tree observation * 5 features per node
 action_size = 4
 agent = Agent(state_size, action_size, "FC", 0)
 
@@ -49,7 +49,6 @@ env = RailEnv(width=6,
               number_of_agents=1,
               obs_builder_object=TreeObsForRailEnv(max_depth=2))
 
-handle = env.get_agent_handles()
 
 env.agents_position[0] = [1, 4]
 env.agents_target[0] = [1, 1]
@@ -62,14 +61,15 @@ env.obs_builder.reset()
 #    print(env.obs_builder.distance_map[0, :, :, i])
 
 obs, all_rewards, done, _ = env.step({0:0})
-print(len(obs[0]))
-env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
+#env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
 
 env_renderer = RenderTool(env)
-env_renderer.renderEnv(show=True)
+action_dict = {0: 0}
 
 for step in range(100):
     obs, all_rewards, done, _ = env.step(action_dict)
-    action_dict = {}
+    action = agent.act(np.array(obs[0]), eps=1)
+
+    action_dict = {0: action}
     print("Rewards: ", all_rewards, "  [done=", done, "]")
-    env_renderer.renderEnv(show=True)
+
diff --git a/flatland/agents/dueling_double_dqn.py b/flatland/agents/dueling_double_dqn.py
index 63a1badb..3eacf4c9 100644
--- a/flatland/agents/dueling_double_dqn.py
+++ b/flatland/agents/dueling_double_dqn.py
@@ -2,7 +2,7 @@ import numpy as np
 import random
 from collections import namedtuple, deque
 import os
-from agent.model import QNetwork, QNetwork2
+from flatland.agents.model import QNetwork, QNetwork2
 import torch
 import torch.nn.functional as F
 import torch.optim as optim
-- 
GitLab