diff --git a/env-data/railway/example_network_000.pkl b/env-data/railway/example_network_000.pkl index ab7764b1ea250633e5a0cd2ec9b93a4e07479c9f..e102e21735416747cb8bd9f231ce6e20fdf514c0 100644 Binary files a/env-data/railway/example_network_000.pkl and b/env-data/railway/example_network_000.pkl differ diff --git a/env-data/railway/example_network_001.pkl b/env-data/railway/example_network_001.pkl index af9dc43246449562065be64c750a06c75c3e7571..a9c5cc97c9c4bf4159db2134756f17fa0c4fce87 100644 Binary files a/env-data/railway/example_network_001.pkl and b/env-data/railway/example_network_001.pkl differ diff --git a/env-data/railway/example_network_002.pkl b/env-data/railway/example_network_002.pkl index d39e44798066e0a0b753006775c5727326fc58da..37647ac2871801d2d08fd65276889e2b232c1170 100644 Binary files a/env-data/railway/example_network_002.pkl and b/env-data/railway/example_network_002.pkl differ diff --git a/examples/demo.py b/examples/demo.py index a53370aac5b3668cb12f7f794b1d91d563ef96e4..c2485de46902304d375f80b0a3d689b8a18b6d0f 100644 --- a/examples/demo.py +++ b/examples/demo.py @@ -132,39 +132,16 @@ class Demo: handle = self.env.get_agent_handles() return handle - def run_demo(self, max_nbr_of_steps=100): + def run_demo(self, max_nbr_of_steps=30): action_dict = dict() - time_obs = deque(maxlen=2) - action_prob = [0] * 4 - agent_obs = [None] * self.env.get_num_agents() - agent_next_obs = [None] * self.env.get_num_agents() # Reset environment obs = self.env.reset(False, False) - for a in range(self.env.get_num_agents()): - data, distance = self.env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0) - - data = norm_obs_clip(data) - distance = norm_obs_clip(distance) - obs[a] = np.concatenate((data, distance)) - - for i in range(2): - time_obs.append(obs) - - # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5) - for a in range(self.env.get_num_agents()): - agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a])) - for step in range(max_nbr_of_steps): - - time.sleep(.2) - - # print(step) # Action for a in range(self.env.get_num_agents()): - action = np.random.choice(self.action_size) #self.agent.act(agent_obs[a]) - action_prob[action] += 1 + action = 2 #np.random.choice(self.action_size) #self.agent.act(agent_obs[a]) action_dict.update({a: action}) print(action_dict) @@ -173,20 +150,7 @@ class Demo: # Environment step next_obs, all_rewards, done, _ = self.env.step(action_dict) - for a in range(self.env.get_num_agents()): - data, distance = self.env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5, - current_depth=0) - data = norm_obs_clip(data) - distance = norm_obs_clip(distance) - next_obs[a] = np.concatenate((data, distance)) - - # Update replay buffer and train agent - for a in range(self.env.get_num_agents()): - agent_next_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a])) - - time_obs.append(next_obs) - agent_obs = agent_next_obs.copy() if done['__all__']: break diff --git a/notebooks/Editor2.ipynb b/notebooks/Editor2.ipynb index 20286e886c5d73f0ce427e8feeb209a8ad99c000..71b74b793e0ae3ba41dcb1f18ff501eb10595667 100644 --- a/notebooks/Editor2.ipynb +++ b/notebooks/Editor2.ipynb @@ -9,9 +9,18 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], "source": [ "%load_ext autoreload\n", "%autoreload 2" @@ -19,7 +28,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -32,7 +41,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -54,7 +63,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -63,7 +72,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -97,7 +106,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 14, "metadata": { "scrolled": false }, @@ -105,7 +114,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "7b66ea9348c9477f881ff27456987363", + "model_id": "31e3248d9a0e4b5da8f2439abd13558d", "version_major": 2, "version_minor": 0 }, @@ -123,7 +132,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 15, "metadata": { "scrolled": false }, @@ -131,7 +140,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "ffa0f869fe8a4921a7415384b75c1ded", + "model_id": "c22754b330ce490383eb05972bc96afe", "version_major": 2, "version_minor": 0 }, @@ -150,7 +159,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -159,7 +168,7 @@ "(0, 0)" ] }, - "execution_count": 8, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" }