From 3ba970a03be9c806cf5e5af768fd4924e5f3a971 Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Thu, 23 May 2019 21:22:35 +0200 Subject: [PATCH] demo.py is no simply working --- env-data/railway/example_network_000.pkl | Bin 180 -> 180 bytes env-data/railway/example_network_001.pkl | Bin 218 -> 218 bytes env-data/railway/example_network_002.pkl | Bin 282 -> 290 bytes examples/demo.py | 40 ++--------------------- notebooks/Editor2.ipynb | 33 ++++++++++++------- 5 files changed, 23 insertions(+), 50 deletions(-) diff --git a/env-data/railway/example_network_000.pkl b/env-data/railway/example_network_000.pkl index ab7764b1ea250633e5a0cd2ec9b93a4e07479c9f..e102e21735416747cb8bd9f231ce6e20fdf514c0 100644 GIT binary patch delta 20 ccmdnOxP@`Ta$ZKpNvxd964O)jN{S}{07exD2><{9 delta 20 ccmdnOxP@`Ta$W|;Nvxd964O)jN{S}{07eK02mk;8 diff --git a/env-data/railway/example_network_001.pkl b/env-data/railway/example_network_001.pkl index af9dc43246449562065be64c750a06c75c3e7571..a9c5cc97c9c4bf4159db2134756f17fa0c4fce87 100644 GIT binary patch delta 20 ccmcb`c#CnuSzboQNgNEz64O)jN{S}{08PLLbpQYW delta 20 ccmcb`c#CnuSzZRlNgNEz64O)jN{S}{08O(8bN~PV diff --git a/env-data/railway/example_network_002.pkl b/env-data/railway/example_network_002.pkl index d39e44798066e0a0b753006775c5727326fc58da..37647ac2871801d2d08fd65276889e2b232c1170 100644 GIT binary patch delta 33 pcmbQmw1{Z~C!^$~$&)xam?tqXOrFHS!7zz~eOY39YF<h41OT6X3I6~9 delta 25 hcmZ3)G>d5iC!^rR$&)xanI|zYEK5vJ%_}LM003eJ2q6Ff diff --git a/examples/demo.py b/examples/demo.py index a53370a..c2485de 100644 --- a/examples/demo.py +++ b/examples/demo.py @@ -132,39 +132,16 @@ class Demo: handle = self.env.get_agent_handles() return handle - def run_demo(self, max_nbr_of_steps=100): + def run_demo(self, max_nbr_of_steps=30): action_dict = dict() - time_obs = deque(maxlen=2) - action_prob = [0] * 4 - agent_obs = [None] * self.env.get_num_agents() - agent_next_obs = [None] * self.env.get_num_agents() # Reset environment obs = self.env.reset(False, False) - for a in range(self.env.get_num_agents()): - data, distance = self.env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0) - - data = norm_obs_clip(data) - distance = norm_obs_clip(distance) - obs[a] = np.concatenate((data, distance)) - - for i in range(2): - time_obs.append(obs) - - # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5) - for a in range(self.env.get_num_agents()): - agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a])) - for step in range(max_nbr_of_steps): - - time.sleep(.2) - - # print(step) # Action for a in range(self.env.get_num_agents()): - action = np.random.choice(self.action_size) #self.agent.act(agent_obs[a]) - action_prob[action] += 1 + action = 2 #np.random.choice(self.action_size) #self.agent.act(agent_obs[a]) action_dict.update({a: action}) print(action_dict) @@ -173,20 +150,7 @@ class Demo: # Environment step next_obs, all_rewards, done, _ = self.env.step(action_dict) - for a in range(self.env.get_num_agents()): - data, distance = self.env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5, - current_depth=0) - data = norm_obs_clip(data) - distance = norm_obs_clip(distance) - next_obs[a] = np.concatenate((data, distance)) - - # Update replay buffer and train agent - for a in range(self.env.get_num_agents()): - agent_next_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a])) - - time_obs.append(next_obs) - agent_obs = agent_next_obs.copy() if done['__all__']: break diff --git a/notebooks/Editor2.ipynb b/notebooks/Editor2.ipynb index 20286e8..71b74b7 100644 --- a/notebooks/Editor2.ipynb +++ b/notebooks/Editor2.ipynb @@ -9,9 +9,18 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], "source": [ "%load_ext autoreload\n", "%autoreload 2" @@ -19,7 +28,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -32,7 +41,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -54,7 +63,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -63,7 +72,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -97,7 +106,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 14, "metadata": { "scrolled": false }, @@ -105,7 +114,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "7b66ea9348c9477f881ff27456987363", + "model_id": "31e3248d9a0e4b5da8f2439abd13558d", "version_major": 2, "version_minor": 0 }, @@ -123,7 +132,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 15, "metadata": { "scrolled": false }, @@ -131,7 +140,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "ffa0f869fe8a4921a7415384b75c1ded", + "model_id": "c22754b330ce490383eb05972bc96afe", "version_major": 2, "version_minor": 0 }, @@ -150,7 +159,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -159,7 +168,7 @@ "(0, 0)" ] }, - "execution_count": 8, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } -- GitLab