From 70c2f8694d2ec70b1042a174aa1973a1d9ef89b9 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Fri, 19 Apr 2019 15:56:35 +0200
Subject: [PATCH] moved agent files and initial implementation of training

---
 examples/training_navigation.py               | 75 +++++++++++++++++++
 flatland/agents/__init__.py                   |  0
 .../agents}/dueling_double_dqn.py             |  0
 {agents => flatland/agents}/model.py          |  0
 4 files changed, 75 insertions(+)
 create mode 100644 examples/training_navigation.py
 create mode 100644 flatland/agents/__init__.py
 rename {agents => flatland/agents}/dueling_double_dqn.py (100%)
 rename {agents => flatland/agents}/model.py (100%)

diff --git a/examples/training_navigation.py b/examples/training_navigation.py
new file mode 100644
index 0000000..f81a50d
--- /dev/null
+++ b/examples/training_navigation.py
@@ -0,0 +1,75 @@
+from flatland.envs.rail_env import *
+from flatland.core.env_observation_builder import TreeObsForRailEnv
+from flatland.utils.rendertools import *
+from flatland.agents.dueling_double_dqn import Agent
+random.seed(1)
+np.random.seed(1)
+
+"""
+transition_probability = [1.0,  # empty cell - Case 0
+                          3.0,  # Case 1 - straight
+                          1.0,  # Case 2 - simple switch
+                          3.0,  # Case 3 - diamond crossing
+                          2.0,  # Case 4 - single slip
+                          1.0,  # Case 5 - double slip
+                          1.0,  # Case 6 - symmetrical
+                          1.0]  # Case 7 - dead end
+"""
+transition_probability = [1.0,  # empty cell - Case 0
+                          1.0,  # Case 1 - straight
+                          0.5,  # Case 2 - simple switch
+                          0.2,  # Case 3 - diamond crossing
+                          0.5,  # Case 4 - single slip
+                          0.1,  # Case 5 - double slip
+                          0.2,  # Case 6 - symmetrical
+                          0.01]  # Case 7 - dead end
+
+# Example: generate a random 20x20 rail with 10 agents and render it
+env = RailEnv(width=20,
+              height=20,
+              rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
+              number_of_agents=10)
+env.reset()
+
+env_renderer = RenderTool(env)
+env_renderer.renderEnv(show=True)
+
+state_size = 5
+action_size = 4
+agent = Agent(state_size, action_size, "FC", 0)  # NOTE(review): agent is not referenced again in this script
+
+# Example: generate a rail from a manual specification,
+# given as a map (list of rows) of tuples (cell_type, rotation)
+specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
+         [(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)]]
+
+env = RailEnv(width=6,
+              height=2,
+              rail_generator=rail_from_manual_specifications_generator(specs),
+              number_of_agents=1,
+              obs_builder_object=TreeObsForRailEnv(max_depth=2))
+
+handle = env.get_agent_handles()  # NOTE(review): handle is not used later in this script
+
+env.agents_position[0] = [1, 4]
+env.agents_target[0] = [1, 1]
+env.agents_direction[0] = 1
+# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
+env.obs_builder.reset()
+
+# TODO: delete the commented-out debug loop below
+#for i in range(4):
+#    print(env.obs_builder.distance_map[0, :, :, i])
+
+obs, all_rewards, done, _ = env.step({0:0})  # presumably {agent_handle: action} - TODO confirm
+print(len(obs[0]))
+env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
+
+env_renderer = RenderTool(env)  # rebinds env_renderer to a renderer for the manually specified env
+env_renderer.renderEnv(show=True)
+
+for step in range(100):
+    obs, all_rewards, done, _ = env.step(action_dict)
+    action_dict = {}
+    print("Rewards: ", all_rewards, "  [done=", done, "]")
+    env_renderer.renderEnv(show=True)
diff --git a/flatland/agents/__init__.py b/flatland/agents/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/agents/dueling_double_dqn.py b/flatland/agents/dueling_double_dqn.py
similarity index 100%
rename from agents/dueling_double_dqn.py
rename to flatland/agents/dueling_double_dqn.py
diff --git a/agents/model.py b/flatland/agents/model.py
similarity index 100%
rename from agents/model.py
rename to flatland/agents/model.py
-- 
GitLab