diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..e7e9d11d4bf243bffe4bb60b4ac1f9392297f4bf --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,2 @@ +# Default ignored files +/workspace.xml diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000000000000000000000000000000000000..105ce2da2d6447d11dfe32bfb846c3d5b199fc99 --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ +<component name="InspectionProjectProfileManager"> + <settings> + <option name="USE_PROJECT_PROFILE" value="false" /> + <version value="1.0" /> + </settings> +</component> \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000000000000000000000000000000000000..65531ca992813bbfedbe43dfae5a5f4337168ed8 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project version="4"> + <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6" project-jdk-type="Python SDK" /> +</project> \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000000000000000000000000000000000000..925278cb163e3cc6c725cda433b8df8b625c3f0b --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project version="4"> + <component name="ProjectModuleManager"> + <modules> + <module fileurl="file://$PROJECT_DIR$/.idea/neurips2020-flatland-starter-kit.iml" filepath="$PROJECT_DIR$/.idea/neurips2020-flatland-starter-kit.iml" /> + </modules> + </component> +</project> \ No newline at end of file diff --git a/.idea/neurips2020-flatland-starter-kit.iml b/.idea/neurips2020-flatland-starter-kit.iml new file mode 100644 index 0000000000000000000000000000000000000000..8dc09e5476bcb840206461450ae44f23421d964a --- /dev/null +++ b/.idea/neurips2020-flatland-starter-kit.iml @@ -0,0 +1,11 @@ +<?xml version="1.0" encoding="UTF-8"?> +<module type="PYTHON_MODULE" version="4"> + <component name="NewModuleRootManager"> + <content url="file://$MODULE_DIR$" /> + <orderEntry type="inheritedJdk" /> + <orderEntry type="sourceFolder" forTests="false" /> + </component> + <component name="TestRunnerService"> + <option name="PROJECT_TEST_RUNNER" value="pytest" /> + </component> +</module> \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000000000000000000000000000000000000..94a25f7f4cb416c083d265558da75d457237d671 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project version="4"> + <component name="VcsDirectoryMappings"> + <mapping directory="$PROJECT_DIR$" vcs="Git" /> + </component> +</project> \ No newline at end of file diff --git a/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.meta b/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.meta deleted file mode 100644 index 4f703391c1f977f6a63e4d8320ad8cdfb10e9a97..0000000000000000000000000000000000000000 Binary files a/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.meta and /dev/null differ diff --git a/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.optimizer b/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.optimizer deleted file mode 100644 index d92c6b5ad82ef53eca4332ccc23056a80a3699a2..0000000000000000000000000000000000000000 Binary files a/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.optimizer and /dev/null differ diff --git a/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.policy b/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.policy deleted file mode 100644 index 1db257665e9b57e80c24c1e1bcc165fbcdb80d71..0000000000000000000000000000000000000000 Binary files a/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.policy and /dev/null differ diff --git a/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.meta b/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.meta deleted file mode 100644 index 84ffd934079e5114f881aad914112c35e5b0f777..0000000000000000000000000000000000000000 Binary files a/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.meta and /dev/null differ diff --git a/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.optimizer b/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.optimizer deleted file mode 100644 index 63b4fc065ee61195c655c5aef7b41b6e354bf75d..0000000000000000000000000000000000000000 Binary files a/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.optimizer and /dev/null differ diff --git a/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.policy b/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.policy deleted file mode 100644 index 7430a2d9b643bf30f9522c8ef3390bac8009f4c1..0000000000000000000000000000000000000000 Binary files a/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.policy and /dev/null differ diff --git a/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.meta b/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.meta deleted file mode 100644 index 8f5843e3dc95213ff24db6375e9a4ba65dfb29ef..0000000000000000000000000000000000000000 Binary files a/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.meta and /dev/null differ diff --git a/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.optimizer b/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.optimizer deleted file mode 100644 index 7bc95369b9607e0cd3528f8be5db73eac34c0e17..0000000000000000000000000000000000000000 Binary files a/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.optimizer and /dev/null differ diff --git a/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.policy b/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.policy deleted file mode 100644 index 339210ff779ad9a89efbf3620151939facd12c77..0000000000000000000000000000000000000000 Binary files a/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.policy and /dev/null differ diff --git a/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.meta b/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.meta deleted file mode 100644 index 7617876cf3d7031f066a779fde687404b0a1cc6f..0000000000000000000000000000000000000000 Binary files a/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.meta and /dev/null differ diff --git a/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.optimizer b/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.optimizer deleted file mode 100644 index b93d28155360433cbb1574b2e797bf1e293c2f6c..0000000000000000000000000000000000000000 Binary files a/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.optimizer and /dev/null differ diff --git a/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.policy b/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.policy deleted file mode 100644 index bc21bc40897b530a65966b1cbbbaeb41835f7b69..0000000000000000000000000000000000000000 Binary files a/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.policy and /dev/null differ diff --git a/checkpoints/No_col_20/model_checkpoint.meta b/checkpoints/No_col_20/model_checkpoint.meta deleted file mode 100644 index fb226078f5ccd715028992499f09f6edfcc4857e..0000000000000000000000000000000000000000 Binary files a/checkpoints/No_col_20/model_checkpoint.meta and /dev/null differ diff --git a/checkpoints/No_col_20/model_checkpoint.optimizer b/checkpoints/No_col_20/model_checkpoint.optimizer deleted file mode 100644 index c725a9ccd7791016f89cf8426746dd549dc23410..0000000000000000000000000000000000000000 Binary files a/checkpoints/No_col_20/model_checkpoint.optimizer and /dev/null differ diff --git a/checkpoints/No_col_20/model_checkpoint.policy b/checkpoints/No_col_20/model_checkpoint.policy deleted file mode 100644 index 91e0a1abe40721649a98ccc76271f34295f7932c..0000000000000000000000000000000000000000 Binary files a/checkpoints/No_col_20/model_checkpoint.policy and /dev/null differ diff --git a/checkpoints/best_0.4757/ppo/model_checkpoint.meta b/checkpoints/best_0.4757/ppo/model_checkpoint.meta deleted file mode 100644 index 8c645aa2e8794ff4edf9974b297ccca3eed5b013..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.4757/ppo/model_checkpoint.meta and /dev/null differ diff --git a/checkpoints/best_0.4757/ppo/model_checkpoint.optimizer b/checkpoints/best_0.4757/ppo/model_checkpoint.optimizer deleted file mode 100644 index bd1ca4d0a9df1ccd8eff82001b542b0edadc3796..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.4757/ppo/model_checkpoint.optimizer and /dev/null differ diff --git a/checkpoints/best_0.4757/ppo/model_checkpoint.policy b/checkpoints/best_0.4757/ppo/model_checkpoint.policy deleted file mode 100644 index d0bcce75eb2b058ef3a63abdfb48d1d78e5f1ef9..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.4757/ppo/model_checkpoint.policy and /dev/null differ diff --git a/checkpoints/best_0.4893/ppo/model_checkpoint.meta b/checkpoints/best_0.4893/ppo/model_checkpoint.meta deleted file mode 100644 index 17b21c3612387374893f2a5ce283dc671e0bb66c..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.4893/ppo/model_checkpoint.meta and /dev/null differ diff --git a/checkpoints/best_0.4893/ppo/model_checkpoint.optimizer b/checkpoints/best_0.4893/ppo/model_checkpoint.optimizer deleted file mode 100644 index bf6ef14481c4832e7fd8289f775cd1da08d27d25..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.4893/ppo/model_checkpoint.optimizer and /dev/null differ diff --git a/checkpoints/best_0.4893/ppo/model_checkpoint.policy b/checkpoints/best_0.4893/ppo/model_checkpoint.policy deleted file mode 100644 index 91722141aa59d4376ad3a7695a08d016908af6ab..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.4893/ppo/model_checkpoint.policy and /dev/null differ diff --git a/checkpoints/best_0.5003/ppo/model_checkpoint.meta b/checkpoints/best_0.5003/ppo/model_checkpoint.meta deleted file mode 100644 index 0cbaf636a4b6be40fe1bca6e522239cc733171b4..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.5003/ppo/model_checkpoint.meta and /dev/null differ diff --git a/checkpoints/best_0.5003/ppo/model_checkpoint.optimizer b/checkpoints/best_0.5003/ppo/model_checkpoint.optimizer deleted file mode 100644 index 87a0386cbdbedd9cdaa4d8de81e7131caa71b815..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.5003/ppo/model_checkpoint.optimizer and /dev/null differ diff --git a/checkpoints/best_0.5003/ppo/model_checkpoint.policy b/checkpoints/best_0.5003/ppo/model_checkpoint.policy deleted file mode 100644 index dfde8badb9c05dc8599bcccf81f27580b4e0ff08..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.5003/ppo/model_checkpoint.policy and /dev/null differ diff --git a/checkpoints/best_0.5109/ppo/model_checkpoint.meta b/checkpoints/best_0.5109/ppo/model_checkpoint.meta deleted file mode 100644 index 4079d5979b350916b55315bc305fdaf842b7be27..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.5109/ppo/model_checkpoint.meta and /dev/null differ diff --git a/checkpoints/best_0.5109/ppo/model_checkpoint.optimizer b/checkpoints/best_0.5109/ppo/model_checkpoint.optimizer deleted file mode 100644 index a7f7fb013513f9a3dfe716acd7ef49c2fa0d57d9..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.5109/ppo/model_checkpoint.optimizer and /dev/null differ diff --git a/checkpoints/best_0.5109/ppo/model_checkpoint.policy b/checkpoints/best_0.5109/ppo/model_checkpoint.policy deleted file mode 100644 index a613c4b42171de1169cbfa42d68a932457399cba..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.5109/ppo/model_checkpoint.policy and /dev/null differ diff --git a/checkpoints/best_0.5172/ppo/model_checkpoint.meta b/checkpoints/best_0.5172/ppo/model_checkpoint.meta deleted file mode 100644 index 40125f2717372d6f7b2d9d3c86cb11c592ba99b5..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.5172/ppo/model_checkpoint.meta and /dev/null differ diff --git a/checkpoints/best_0.5172/ppo/model_checkpoint.optimizer b/checkpoints/best_0.5172/ppo/model_checkpoint.optimizer deleted file mode 100644 index 7feedff1803645d8b131523fc21f990f96bc681a..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.5172/ppo/model_checkpoint.optimizer and /dev/null differ diff --git a/checkpoints/best_0.5172/ppo/model_checkpoint.policy b/checkpoints/best_0.5172/ppo/model_checkpoint.policy deleted file mode 100644 index fe9741cdd7a7abd2d171c84aa95b1f1d0a17841c..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.5172/ppo/model_checkpoint.policy and /dev/null differ diff --git a/checkpoints/best_0.5355/ppo/model_checkpoint.meta b/checkpoints/best_0.5355/ppo/model_checkpoint.meta deleted file mode 100644 index b45d9b8937c572df6febfa2f0ac5a9d4cda4eb0e..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.5355/ppo/model_checkpoint.meta and /dev/null differ diff --git a/checkpoints/best_0.5355/ppo/model_checkpoint.optimizer b/checkpoints/best_0.5355/ppo/model_checkpoint.optimizer deleted file mode 100644 index 915d64ab8e782a47c0de68a8ccdb81842a074c85..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.5355/ppo/model_checkpoint.optimizer and /dev/null differ diff --git a/checkpoints/best_0.5355/ppo/model_checkpoint.policy b/checkpoints/best_0.5355/ppo/model_checkpoint.policy deleted file mode 100644 index 3300d40d010b85f3e4395c9aed630f6beea486af..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.5355/ppo/model_checkpoint.policy and /dev/null differ diff --git a/checkpoints/best_0.5435/ppo/model_checkpoint.meta b/checkpoints/best_0.5435/ppo/model_checkpoint.meta deleted file mode 100644 index 56a27a0763598ba9748c4b337fcb59e95ccdf612..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.5435/ppo/model_checkpoint.meta and /dev/null differ diff --git a/checkpoints/best_0.5435/ppo/model_checkpoint.optimizer b/checkpoints/best_0.5435/ppo/model_checkpoint.optimizer deleted file mode 100644 index 1cec17631653a3834677441f541024d582d406b4..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.5435/ppo/model_checkpoint.optimizer and /dev/null differ diff --git a/checkpoints/best_0.5435/ppo/model_checkpoint.policy b/checkpoints/best_0.5435/ppo/model_checkpoint.policy deleted file mode 100644 index 8a3f4e113a4fd5ca7fe4ef27b5dd81d682939df7..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.5435/ppo/model_checkpoint.policy and /dev/null differ diff --git a/checkpoints/best_0.8620/ppo/model_checkpoint.meta b/checkpoints/best_0.8620/ppo/model_checkpoint.meta deleted file mode 100644 index 8f5843e3dc95213ff24db6375e9a4ba65dfb29ef..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.8620/ppo/model_checkpoint.meta and /dev/null differ diff --git a/checkpoints/best_0.8620/ppo/model_checkpoint.optimizer b/checkpoints/best_0.8620/ppo/model_checkpoint.optimizer deleted file mode 100644 index 13b15ba48dde84af9e70f18cb0e4395737351f00..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.8620/ppo/model_checkpoint.optimizer and /dev/null differ diff --git a/checkpoints/best_0.8620/ppo/model_checkpoint.policy b/checkpoints/best_0.8620/ppo/model_checkpoint.policy deleted file mode 100644 index 7049ae616c44eb2c20edc1eb6aaf26f16e839851..0000000000000000000000000000000000000000 Binary files a/checkpoints/best_0.8620/ppo/model_checkpoint.policy and /dev/null differ diff --git a/checkpoints/dqn/README.md b/checkpoints/dqn/README.md deleted file mode 100644 index 7877436bc8fc766ce1409c3810a48b13da31ac39..0000000000000000000000000000000000000000 --- a/checkpoints/dqn/README.md +++ /dev/null @@ -1 +0,0 @@ -DQN checkpoints will be saved here diff --git a/checkpoints/dqn/model_checkpoint.local b/checkpoints/dqn/model_checkpoint.local deleted file mode 100644 index cca8687b9d333d2a050d8f6910960ef80f0680b7..0000000000000000000000000000000000000000 Binary files a/checkpoints/dqn/model_checkpoint.local and /dev/null differ diff --git a/checkpoints/dqn/model_checkpoint.meta b/checkpoints/dqn/model_checkpoint.meta deleted file mode 100644 index 502812ecee5de1aa38caa05cc9766f7d2f04ba7e..0000000000000000000000000000000000000000 Binary files a/checkpoints/dqn/model_checkpoint.meta and /dev/null differ diff --git a/checkpoints/dqn/model_checkpoint.optimizer b/checkpoints/dqn/model_checkpoint.optimizer deleted file mode 100644 index 5badbab158a8d003b069c5a2bb0b8628f25dd89b..0000000000000000000000000000000000000000 Binary files a/checkpoints/dqn/model_checkpoint.optimizer and /dev/null differ diff --git a/checkpoints/dqn/model_checkpoint.target b/checkpoints/dqn/model_checkpoint.target deleted file mode 100644 index 6a853ac88d2554eef6d0b6a414d87d9993c22998..0000000000000000000000000000000000000000 Binary files a/checkpoints/dqn/model_checkpoint.target and /dev/null differ diff --git a/checkpoints/ppo/model_checkpoint.meta b/checkpoints/ppo/model_checkpoint.meta index b45d9b8937c572df6febfa2f0ac5a9d4cda4eb0e..e998bcda155cbd11e6f1b70e77966ed92812930c 100644 Binary files a/checkpoints/ppo/model_checkpoint.meta and b/checkpoints/ppo/model_checkpoint.meta differ diff --git a/checkpoints/ppo/model_checkpoint.optimizer b/checkpoints/ppo/model_checkpoint.optimizer index 190ef25976343f4c1cca9b751f78fc8fdcadfa28..0fbd49e2e1fa62c34dff663c9d47bd53f61128f3 100644 Binary files a/checkpoints/ppo/model_checkpoint.optimizer and b/checkpoints/ppo/model_checkpoint.optimizer differ diff --git a/checkpoints/ppo/model_checkpoint.policy b/checkpoints/ppo/model_checkpoint.policy index c4492df60aaec91709c87ae729bf71480866b31e..4868691cf2f12669df395b845b2a903c1d917336 100644 Binary files a/checkpoints/ppo/model_checkpoint.policy and b/checkpoints/ppo/model_checkpoint.policy differ diff --git a/dump.rdb b/dump.rdb new file mode 100644 index 0000000000000000000000000000000000000000..ba7e97e6ff27c20f75b47463e01777453f724a57 Binary files /dev/null and b/dump.rdb differ diff --git a/src/extra.py b/src/extra.py index 025c8f1355f5c4c1122a3b5272049234e53c71da..e5a75c8d6eb82997aff52f7b14263e78edc39693 100644 --- a/src/extra.py +++ b/src/extra.py @@ -66,8 +66,10 @@ def fast_count_nonzero(possible_transitions: (int, int, int, int)): class Extra(ObservationBuilder): def __init__(self, max_depth): - self.max_depth = max_depth - self.observation_dim = 30 + self.max_depth = max_depthmodel_checkpoint.meta +model_checkpoint.optimizer +model_checkpoint.policy + self.observation_dim = 22 self.agent = None def build_data(self): @@ -188,6 +190,10 @@ class Extra(ObservationBuilder): return obsData def is_collision(self, obsData): + if np.sum(obsData[10:14]) == 0: + return False + if np.sum(obsData[10:14]) == np.sum(obsData[14:18]): + return True return False def reset(self): @@ -203,16 +209,14 @@ class Extra(ObservationBuilder): return 2 return 3 - def _explore(self, handle, distance_map, new_position, new_direction, depth=0): + def _explore(self, handle, new_position, new_direction, depth=0): has_opp_agent = 0 has_same_agent = 0 visited = [] - visited_direction = [] - visited_min_distance = np.inf # stop exploring (max_depth reached) if depth >= self.max_depth: - return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance + return has_opp_agent, has_same_agent, visited # max_explore_steps = 100 cnt = 0 @@ -220,22 +224,15 @@ class Extra(ObservationBuilder): cnt += 1 visited.append(new_position) - visited_direction.append(new_direction) - - new_cell_dist = distance_map[handle, - new_position[0], new_position[1], - new_direction] - visited_min_distance = min(visited_min_distance, new_cell_dist) - opp_a = self.env.agent_positions[new_position] if opp_a != -1 and opp_a != handle: if self.env.agents[opp_a].direction != new_direction: # opp agent found has_opp_agent = 1 - return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance + return has_opp_agent, has_same_agent, visited else: has_same_agent = 1 - return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance + return has_opp_agent, has_same_agent, visited # convert one-hot encoding to 0,1,2,3 possible_transitions = self.env.rail.get_transitions(*new_position, new_direction) @@ -244,28 +241,23 @@ class Extra(ObservationBuilder): agents_near_to_switch_all = \ self.check_agent_descision(new_position, new_direction) if agents_near_to_switch: - return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance + return has_opp_agent, has_same_agent, visited if agents_on_switch: for dir_loop in range(4): if possible_transitions[dir_loop] == 1: - hoa, hsa, v, d, min_dist = self._explore(handle, - distance_map, - get_new_position(new_position, dir_loop), - dir_loop, - depth + 1) - if np.math.isinf(min_dist) == False: - visited_min_distance = min(visited_min_distance, min_dist) - - visited = visited + v - visited_direction = visited_direction + d + hoa, hsa, v = self._explore(handle, + get_new_position(new_position, dir_loop), + dir_loop, + depth + 1) + visited.append(v) has_opp_agent = 0.5 * (has_opp_agent + hoa) has_same_agent = 0.5 * (has_same_agent + hsa) - return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance + return has_opp_agent, has_same_agent, visited else: new_direction = fast_argmax(possible_transitions) new_position = get_new_position(new_position, new_direction) - return has_opp_agent, has_same_agent, visited, visited_direction, visited_min_distance + return has_opp_agent, has_same_agent, visited def get(self, handle): # all values are [0,1] @@ -294,7 +286,6 @@ class Extra(ObservationBuilder): observation = np.zeros(self.observation_dim) visited = [] - visited_direction = [] agent = self.env.agents[handle] agent_done = False @@ -311,7 +302,6 @@ class Extra(ObservationBuilder): if not agent_done: visited.append(agent_virtual_position) - visited_direction.append(agent.direction) distance_map = self.env.distance_map.get() current_cell_dist = distance_map[handle, agent_virtual_position[0], agent_virtual_position[1], @@ -330,12 +320,8 @@ class Extra(ObservationBuilder): if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)): observation[dir_loop] = int(new_cell_dist < current_cell_dist) - has_opp_agent, has_same_agent, vis, dir, min_dist = self._explore(handle, - distance_map, - new_position, - branch_direction) - visited = visited + vis - visited_direction = visited_direction + dir + has_opp_agent, has_same_agent, v = self._explore(handle, new_position, branch_direction) + visited.append(v) observation[10 + dir_loop] = 1 observation[14 + dir_loop] = has_opp_agent @@ -349,16 +335,6 @@ class Extra(ObservationBuilder): observation[8] = int(agents_near_to_switch) observation[9] = int(agents_near_to_switch_all) - observation[22] = int(self.env._elapsed_steps % 4 == 0) - observation[23] = int(self.env._elapsed_steps % 4 == 1) - observation[24] = int(self.env._elapsed_steps % 4 == 2) - observation[25] = int(self.env._elapsed_steps % 4 == 3) - - observation[26] = int(agent.direction % 4 == 0) - observation[27] = int(agent.direction % 4 == 1) - observation[28] = int(agent.direction % 4 == 2) - observation[29] = int(agent.direction % 4 == 3) - self.env.dev_obs_dict.update({handle: visited}) return observation diff --git a/src/ppo/model.py b/src/ppo/model.py index 421423df6739bbc4b4ed94487de7e3dfa9d973a8..51b86ff16691c03f6a754405352bb4cf48e4b914 100644 --- a/src/ppo/model.py +++ b/src/ppo/model.py @@ -3,7 +3,7 @@ import torch.nn.functional as F class PolicyNetwork(nn.Module): - def __init__(self, state_size, action_size, hidsize1=128, hidsize2=256, hidsize3=32): + def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128, hidsize3=32): super().__init__() self.fc1 = nn.Linear(state_size, hidsize1) self.fc2 = nn.Linear(hidsize1, hidsize2)