diff --git a/checkpoints/201112143850-5400.pth.local b/checkpoints/201112143850-5400.pth.local deleted file mode 100644 index 13ff4c91a142fa7712f07fc251650f1fe4bf83f0..0000000000000000000000000000000000000000 Binary files a/checkpoints/201112143850-5400.pth.local and /dev/null differ diff --git a/checkpoints/201112143850-5400.pth.target b/checkpoints/201112143850-5400.pth.target deleted file mode 100644 index 5c302cdd1fd9c0f480444b49d266313af07a0c23..0000000000000000000000000000000000000000 Binary files a/checkpoints/201112143850-5400.pth.target and /dev/null differ diff --git a/checkpoints/201113211844-6100.pth.local b/checkpoints/201113211844-6100.pth.local deleted file mode 100644 index 1d069b2b5ccdabe52732a6ebb34a2e4bdcd08247..0000000000000000000000000000000000000000 Binary files a/checkpoints/201113211844-6100.pth.local and /dev/null differ diff --git a/checkpoints/201113211844-6100.pth.target b/checkpoints/201113211844-6100.pth.target deleted file mode 100644 index adc2a88d9992133f95162f4596cbdda98fd1da66..0000000000000000000000000000000000000000 Binary files a/checkpoints/201113211844-6100.pth.target and /dev/null differ diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py index b3271b5867f618154668bd12d7c9b61dc1636b40..3c2fd9fdb88cabb1277b8779e8b0cf25ed9feae0 100755 --- a/reinforcement_learning/multi_agent_training.py +++ b/reinforcement_learning/multi_agent_training.py @@ -541,7 +541,7 @@ if __name__ == "__main__": parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str) parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs", action='store_true') - parser.add_argument("--max_depth", help="max depth", default=2, type=int) + parser.add_argument("--max_depth", help="max depth", default=1, type=int) training_params = parser.parse_args() env_params = [ diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py index 625a21e2b3a1d05951cd67fad9b30aa0401fde56..1e3b507cbe9128073ee4c0ba79657e1acca48fb1 100755 --- a/utils/fast_tree_obs.py +++ b/utils/fast_tree_obs.py @@ -25,7 +25,7 @@ class FastTreeObs(ObservationBuilder): def __init__(self, max_depth): self.max_depth = max_depth - self.observation_dim = 32 + self.observation_dim = 33 def build_data(self): if self.env is not None: @@ -41,6 +41,8 @@ class FastTreeObs(ObservationBuilder): self.dead_lock_avoidance_agent = None def find_all_switches(self): + # Search the environment (rail grid) for all switch cells. A switch is a cell where more than one tranisation + # exists and collect all direction where the switch is a switch. self.switches = {} for h in range(self.env.height): for w in range(self.env.width): @@ -55,6 +57,8 @@ class FastTreeObs(ObservationBuilder): self.switches[pos].append(dir) def find_all_switch_neighbours(self): + # Collect all cells where is a neighbour to a switch cell. All cells are neighbour where the agent can make + # just one step and he stands on a switch. A switch is a cell where the agents has more than one transition. self.switches_neighbours = {} for h in range(self.env.height): for w in range(self.env.width): @@ -72,10 +76,18 @@ class FastTreeObs(ObservationBuilder): self.switches_neighbours[pos].append(dir) def find_all_cell_where_agent_can_choose(self): + # prepare the data - collect all cells where the agent can choose more than FORWARD/STOP. self.find_all_switches() self.find_all_switch_neighbours() def check_agent_decision(self, position, direction): + # Decide whether the agent is + # - on a switch + # - at a switch neighbour (near to switch). The switch must be a switch where the agent has more option than + # FORWARD/STOP + # - all switch : doesn't matter whether the agent has more options than FORWARD/STOP + # - all switch neightbors : doesn't matter the agent has more then one options (transistion) when he reach the + # switch agents_on_switch = False agents_on_switch_all = False agents_near_to_switch = False @@ -301,24 +313,26 @@ class FastTreeObs(ObservationBuilder): has_opp_agent, has_same_agent, has_target, v = self._explore(handle, new_position, branch_direction) visited.append(v) - observation[10 + dir_loop] = int(not np.math.isinf(new_cell_dist)) - observation[14 + dir_loop] = has_opp_agent - observation[18 + dir_loop] = has_same_agent - observation[22 + dir_loop] = has_target - observation[26 + dir_loop] = int(np.math.isinf(new_cell_dist)) + observation[11 + dir_loop] = int(not np.math.isinf(new_cell_dist)) + observation[15 + dir_loop] = has_opp_agent + observation[19 + dir_loop] = has_same_agent + observation[23 + dir_loop] = has_target + observation[27 + dir_loop] = int(np.math.isinf(new_cell_dist)) agents_on_switch, \ agents_near_to_switch, \ agents_near_to_switch_all, \ agents_on_switch_all = \ self.check_agent_decision(agent_virtual_position, agent.direction) + observation[7] = int(agents_on_switch) - observation[8] = int(agents_near_to_switch) - observation[9] = int(agents_near_to_switch_all) + observation[8] = int(agents_on_switch_all) + observation[9] = int(agents_near_to_switch) + observation[10] = int(agents_near_to_switch_all) action = self.dead_lock_avoidance_agent.act([handle], 0.0) - observation[30] = int(action == RailEnvActions.STOP_MOVING) - observation[31] = int(fast_count_nonzero(possible_transitions) == 1) + observation[31] = int(action == RailEnvActions.STOP_MOVING) + observation[32] = int(fast_count_nonzero(possible_transitions) == 1) self.env.dev_obs_dict.update({handle: visited})