diff --git a/checkpoints/201124171810-7800.pth.local b/checkpoints/201124171810-7800.pth.local
new file mode 100644
index 0000000000000000000000000000000000000000..66312bf9d8ad449b0c236227b6533b00c8236f46
Binary files /dev/null and b/checkpoints/201124171810-7800.pth.local differ
diff --git a/checkpoints/201124171810-7800.pth.target b/checkpoints/201124171810-7800.pth.target
new file mode 100644
index 0000000000000000000000000000000000000000..4f8ea90c7422bb215605028b925e41a4e1a8c61d
Binary files /dev/null and b/checkpoints/201124171810-7800.pth.target differ
diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
index 3eb54a3dd6b0dba4f35316fd96e0f5a704bb2440..134a2c2e1ddb6c45d4ab806a3906b64c1298b529 100644
--- a/reinforcement_learning/dddqn_policy.py
+++ b/reinforcement_learning/dddqn_policy.py
@@ -44,8 +44,10 @@ class DDDQNPolicy(Policy):
             # print("🐢 Using CPU")

         # Q-Network
-        self.qnetwork_local = DuelingQNetwork(state_size, action_size, hidsize1=self.hidsize, hidsize2=self.hidsize).to(
-            self.device)
+        self.qnetwork_local = DuelingQNetwork(state_size,
+                                              action_size,
+                                              hidsize1=self.hidsize,
+                                              hidsize2=self.hidsize).to(self.device)

         if not evaluation_mode:
             self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index 3c2fd9fdb88cabb1277b8779e8b0cf25ed9feae0..ea4d6e2126049c8be58f5da65ab435c5d5bde33d 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -9,16 +9,20 @@ from pprint import pprint

 import numpy as np
 import psutil
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
-from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_env import RailEnv, RailEnvActions, fast_count_nonzero
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import RenderTool
 from torch.utils.tensorboard import SummaryWriter

 from reinforcement_learning.dddqn_policy import DDDQNPolicy
+from reinforcement_learning.ppo.ppo_agent import PPOAgent
+from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent

 base_dir = Path(__file__).resolve().parent.parent
 sys.path.append(str(base_dir))
@@ -75,6 +79,41 @@ def create_rail_env(env_params, tree_observation):
     )


+def get_agent_positions(env):
+    agent_positions: np.ndarray = np.full((env.height, env.width), -1)
+    for agent_handle in env.get_agent_handles():
+        agent = env.agents[agent_handle]
+        if agent.status == RailAgentStatus.ACTIVE:
+            position = agent.position
+            if position is None:
+                position = agent.initial_position
+            agent_positions[position] = agent_handle
+    return agent_positions
+
+
+def check_for_deadlock(handle, env, agent_positions):
+    agent = env.agents[handle]
+    if agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED:
+        return False
+
+    position = agent.position
+    if position is None:
+        position = agent.initial_position
+    possible_transitions = env.rail.get_transitions(*position, agent.direction)
+    num_transitions = fast_count_nonzero(possible_transitions)
+    for dir_loop in range(4):
+        if possible_transitions[dir_loop] == 1:
+            new_position = get_new_position(position, dir_loop)
+            opposite_agent = agent_positions[new_position]
+            if opposite_agent != handle and opposite_agent != -1:
+                num_transitions -= 1
+            else:
+                return False
+
+    is_deadlock = num_transitions <= 0
+    return is_deadlock
+
+
 def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     # Environment parameters
     n_agents = train_env_params.n_agents
@@ -150,10 +189,6 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     # Calculate the state size given the depth of the tree observation and the number of features
     state_size = tree_observation.observation_dim

-    # Setup renderer
-    if train_params.render:
-        env_renderer = RenderTool(train_env, gl="PGL")
-
     # The action space of flatland is 5 discrete actions
     action_size = 5

@@ -174,13 +209,15 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     # IF USE_SINGLE_AGENT_TRAINING is set and the episode_idx <= MAX_SINGLE_TRAINING_ITERATION then
     # the training is done with a single agent. Every UPDATE_POLICY2_N_EPISODE episodes the second policy
     # gets replaced with the policy that is being trained.
-    USE_SINGLE_AGENT_TRAINING = True
-    MAX_SINGLE_TRAINING_ITERATION = 1000
+    USE_SINGLE_AGENT_TRAINING = False
+    MAX_SINGLE_TRAINING_ITERATION = 100000
     UPDATE_POLICY2_N_EPISODE = 200
+    USE_DEADLOCK_AVOIDANCE_AS_POLICY2 = False

     # Double Dueling DQN policy
     policy = DDDQNPolicy(state_size, action_size, train_params)
-    # policy = PPOAgent(state_size, action_size, n_agents)
+    if False:
+        policy = PPOAgent(state_size, action_size, n_agents)
     # Load existing policy
     if train_params.load_policy is not "":
         policy.load(train_params.load_policy)
@@ -231,17 +268,23 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):

         # Reset environment
         reset_timer.start()
-        train_env_params.n_agents = episode_idx % n_agents + 1
+        number_of_agents = min(1 + round(n_agents * (1.0 - 0.9985 ** episode_idx)), n_agents)
+        train_env_params.n_agents = episode_idx % number_of_agents + 1
+
         train_env = create_rail_env(train_env_params, tree_observation)
         obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
         policy.reset()

-        if episode_idx % UPDATE_POLICY2_N_EPISODE == 0:
-            policy2 = policy.clone()
+        if USE_DEADLOCK_AVOIDANCE_AS_POLICY2:
+            policy2 = DeadLockAvoidanceAgent(train_env, action_size)
+        else:
+            if episode_idx % UPDATE_POLICY2_N_EPISODE == 0:
+                policy2 = policy.clone()

         reset_timer.end()

         if train_params.render:
+            # Setup renderer
+            env_renderer = RenderTool(train_env, gl="PGL")
             env_renderer.set_new_rail()

         score = 0
@@ -249,11 +292,12 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         actions_taken = []

         # Build initial agent-specific observations
-        for agent in train_env.get_agent_handles():
-            if tree_observation.check_is_observation_valid(obs[agent]):
-                agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], observation_tree_depth,
-                                                                               observation_radius=observation_radius)
-                agent_prev_obs[agent] = agent_obs[agent].copy()
+        for agent_handle in train_env.get_agent_handles():
+            if tree_observation.check_is_observation_valid(obs[agent_handle]):
+                agent_obs[agent_handle] = tree_observation.get_normalized_observation(obs[agent_handle],
+                                                                                      observation_tree_depth,
+                                                                                      observation_radius=observation_radius)
+                agent_prev_obs[agent_handle] = agent_obs[agent_handle].copy()

         # Max number of steps per episode
         # This is the official formula used during evaluations
@@ -271,21 +315,26 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             inference_timer.start()
             policy.start_step()
             policy2.start_step()
-            for agent in train_env.get_agent_handles():
-                if info['action_required'][agent]:
-                    update_values[agent] = True
-                    if agent in agent_to_learn or not USE_SINGLE_AGENT_TRAINING:
-                        action = policy.act(agent_obs[agent], eps=eps_start)
+            for agent_handle in train_env.get_agent_handles():
+                agent = train_env.agents[agent_handle]
+                if info['action_required'][agent_handle]:
+                    update_values[agent_handle] = True
+                    if (agent_handle in agent_to_learn) or (not USE_SINGLE_AGENT_TRAINING):
+                        action = policy.act(agent_obs[agent_handle], eps=eps_start)
                     else:
-                        action = policy2.act(agent_obs[agent], eps=eps_start)
+                        if USE_DEADLOCK_AVOIDANCE_AS_POLICY2:
+                            action = policy2.act([agent_handle], eps=0.0)
+                        else:
+                            action = policy2.act(agent_obs[agent_handle], eps=0.0)
+
                     action_count[action] += 1
                     actions_taken.append(action)
                 else:
                     # An action is not required if the train hasn't joined the railway network,
                     # if it already reached its target, or if it is currently malfunctioning.
-                    update_values[agent] = False
+                    update_values[agent_handle] = False
                     action = 0
-                action_dict.update({agent: action})
+                action_dict.update({agent_handle: action})
             policy.end_step()
             policy2.end_step()
             inference_timer.end()
@@ -294,23 +343,18 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             step_timer.start()
             next_obs, all_rewards, done, info = train_env.step(action_dict)

-            if True:
-                for agent in train_env.get_agent_handles():
-                    act = action_dict.get(agent, RailEnvActions.DO_NOTHING)
-                    if agent_obs[agent][5] == 1:
-                        if agent_obs[agent][26] == 1:
-                            if act != RailEnvActions.STOP_MOVING:
-                                all_rewards[agent] -= 10.0
-                        if agent_obs[agent][27] == 1:
-                            if act == RailEnvActions.MOVE_LEFT or \
-                                    act == RailEnvActions.MOVE_RIGHT or \
-                                    act == RailEnvActions.DO_NOTHING:
-                                all_rewards[agent] -= 1.0
-
+            # Deadlock found -> reward shaping
+            agent_positions = get_agent_positions(train_env)
+            for agent_handle in train_env.get_agent_handles():
+                agent = train_env.agents[agent_handle]
+                act = action_dict.get(agent_handle, RailEnvActions.MOVE_FORWARD)
+                if agent.status == RailAgentStatus.ACTIVE:
+                    if check_for_deadlock(agent_handle, train_env, agent_positions):
+                        all_rewards[agent_handle] -= 5.0
             step_timer.end()

             # Render an episode at some interval
-            if train_params.render and episode_idx % checkpoint_interval == 0:
+            if train_params.render:
                 env_renderer.render_env(
                     show=True,
                     frames=False,
@@ -319,29 +363,31 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                 )

             # Update replay buffer and train agent
-            for agent in train_env.get_agent_handles():
-                if update_values[agent] or done['__all__']:
+            for agent_handle in train_env.get_agent_handles():
+                if update_values[agent_handle] or done['__all__']:
                     # Only learn from timesteps where something happened
                     learn_timer.start()
-                    if agent in agent_to_learn or not USE_SINGLE_AGENT_TRAINING:
-                        policy.step(agent,
-                                    agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
-                                    agent_obs[agent],
-                                    done[agent])
+                    if (agent_handle in agent_to_learn) or (not USE_SINGLE_AGENT_TRAINING):
+                        policy.step(agent_handle,
+                                    agent_prev_obs[agent_handle],
+                                    agent_prev_action[agent_handle],
+                                    all_rewards[agent_handle],
+                                    agent_obs[agent_handle],
+                                    done[agent_handle])
                     learn_timer.end()

-                    agent_prev_obs[agent] = agent_obs[agent].copy()
-                    agent_prev_action[agent] = action_dict[agent]
+                    agent_prev_obs[agent_handle] = agent_obs[agent_handle].copy()
+                    agent_prev_action[agent_handle] = action_dict[agent_handle]

                 # Preprocess the new observations
-                if tree_observation.check_is_observation_valid(next_obs[agent]):
+                if tree_observation.check_is_observation_valid(next_obs[agent_handle]):
                     preproc_timer.start()
-                    agent_obs[agent] = tree_observation.get_normalized_observation(next_obs[agent],
-                                                                                   observation_tree_depth,
-                                                                                   observation_radius=observation_radius)
+                    agent_obs[agent_handle] = tree_observation.get_normalized_observation(next_obs[agent_handle],
+                                                                                          observation_tree_depth,
+                                                                                          observation_radius=observation_radius)
                     preproc_timer.end()

-                score += all_rewards[agent]
+                score += all_rewards[agent_handle]

             nb_steps = step

@@ -362,6 +408,9 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         smoothed_normalized_score = np.mean(scores_window)
         smoothed_completion = np.mean(completion_window)

+        if train_params.render:
+            env_renderer.close_window()
+
         # Print logs
         if episode_idx % checkpoint_interval == 0:
             policy.save('./checkpoints/' + training_id + '-' + str(episode_idx) + '.pth')
@@ -369,14 +418,12 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             if save_replay_buffer:
                 policy.save_replay_buffer('./replay_buffers/' + training_id + '-' + str(episode_idx) + '.pkl')

-            if train_params.render:
-                env_renderer.close_window()
-
             # reset action count
             action_count = [0] * action_size

         print(
             '\r🚂 Episode {}'
+            '\t 🚉 nAgents {}'
             '\t 🏆 Score: {:7.3f}'
             ' Avg: {:7.3f}'
             '\t 💯 Done: {:6.2f}%'
@@ -384,6 +431,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             '\t 🎲 Epsilon: {:.3f} '
             '\t 🔀 Action Probs: {}'.format(
                 episode_idx,
+                train_env_params.n_agents,
                 normalized_score,
                 smoothed_normalized_score,
                 100 * completion,
@@ -507,7 +555,7 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):

     nb_steps.append(final_step)

-    print("\t✅ Eval: score {:.3f} done {:.1f}%".format(np.mean(scores), np.mean(completions) * 100.0))
+    print(" ✅ Eval: score {:.3f} done {:.1f}%".format(np.mean(scores), np.mean(completions) * 100.0))

     return scores, completions, nb_steps

diff --git a/reinforcement_learning/ppo/ppo_agent.py b/reinforcement_learning/ppo/ppo_agent.py
index 49fe7e6f6c02a20e4ea3d7c6e7a9d3e33cff9742..09bafd79558e19f0266adba4742f42f6f3e373b1 100644
--- a/reinforcement_learning/ppo/ppo_agent.py
+++ b/reinforcement_learning/ppo/ppo_agent.py
@@ -1,3 +1,4 @@
+import copy
 import os

 import numpy as np
@@ -129,3 +130,10 @@ class PPOAgent(Policy):
         except:
             print(" >> failed!")
             pass
+
+    def clone(self):
+        policy = PPOAgent(self.state_size, self.action_size, self.num_agents)
+        policy.policy = copy.deepcopy(self.policy)
+        policy.old_policy = copy.deepcopy(self.old_policy)
+        policy.optimizer = copy.deepcopy(self.optimizer)
+        return policy
diff --git a/run.py b/run.py
index a4fa62faf053c687bd7bd0e1eaf50dff5cbb702c..b63719528c5acd1fb4098572c33f283f610aae56 100644
--- a/run.py
+++ b/run.py
@@ -1,3 +1,24 @@
+'''
+DDDQNPolicy experiments - EPSILON impact analysis
+----------------------------------------------------------------------------------------
+checkpoint = "./checkpoints/201124171810-7800.pth"  # Training on AGENTS=10 with Depth=2
+EPSILON = 0.000  # Sum Normalized Reward :  0.000000000000000 (primary score)
+EPSILON = 0.002  # Sum Normalized Reward : 18.445875081269286 (primary score)
+EPSILON = 0.005  # Sum Normalized Reward : 18.371733625865854 (primary score)
+EPSILON = 0.010  # Sum Normalized Reward : 18.249244799876152 (primary score)
+EPSILON = 0.020  # Sum Normalized Reward : 17.526987022691376 (primary score)
+EPSILON = 0.030  # Sum Normalized Reward : 16.796885571003942 (primary score)
+EPSILON = 0.040  # Sum Normalized Reward : 17.280787151431426 (primary score)
+EPSILON = 0.050  # Sum Normalized Reward : 16.256945636647025 (primary score)
+EPSILON = 0.100  # Sum Normalized Reward : 14.828347241759966 (primary score)
+EPSILON = 0.200  # Sum Normalized Reward : 11.192330074898457 (primary score)
+EPSILON = 0.300  # Sum Normalized Reward : 14.523067754608782 (primary score)
+EPSILON = 0.400  # Sum Normalized Reward : 12.901508220410834 (primary score)
+EPSILON = 0.500  # Sum Normalized Reward :  3.754660231871272 (primary score)
+EPSILON = 1.000  # Sum Normalized Reward :  1.397180159192391 (primary score)
+'''
+
+
 import sys
 import time
 from argparse import Namespace
@@ -6,7 +27,6 @@ from pathlib import Path
 import numpy as np
 from flatland.core.env_observation_builder import DummyObservationBuilder
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
-from flatland.envs.rail_env import RailEnvActions
 from flatland.evaluators.client import FlatlandRemoteClient
 from flatland.evaluators.client import TimeoutException

@@ -26,8 +46,9 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
 VERBOSE = True

 # Checkpoint to use (remember to push it!)
-# checkpoint = "./checkpoints/201112143850-5400.pth"  # 21.220418678677177 DEPTH=2 AGENTS=10
-checkpoint = "./checkpoints/201117082153-1500.pth"  # 21.570149424415636 DEPTH=2 AGENTS=10
+checkpoint = "./checkpoints/201124171810-7800.pth"  # 18.249244799876152 DEPTH=2 AGENTS=10
+
+EPSILON = 0.01

 # Use last action cache
 USE_ACTION_CACHE = False
@@ -108,7 +129,7 @@ while True:
     nb_hit = 0

     if USE_DEAD_LOCK_AVOIDANCE_AGENT:
-        policy = DeadLockAvoidanceAgent(local_env)
+        policy = DeadLockAvoidanceAgent(local_env, action_size)

     while True:
         try:
@@ -125,25 +146,26 @@ while True:

             policy.start_step()
             if USE_DEAD_LOCK_AVOIDANCE_AGENT:
                 observation = np.zeros((local_env.get_num_agents(), 2))
-            for agent in range(nb_agents):
+            for agent_handle in range(nb_agents):
                 if USE_DEAD_LOCK_AVOIDANCE_AGENT:
-                    observation[agent][0] = agent
-                    observation[agent][1] = steps
+                    observation[agent_handle][0] = agent_handle
+                    observation[agent_handle][1] = steps

-                if info['action_required'][agent]:
-                    if agent in agent_last_obs and np.all(agent_last_obs[agent] == observation[agent]):
+                if info['action_required'][agent_handle]:
+                    if agent_handle in agent_last_obs and np.all(
+                            agent_last_obs[agent_handle] == observation[agent_handle]):
                         # cache hit
-                        action = agent_last_action[agent]
+                        action = agent_last_action[agent_handle]
                         nb_hit += 1
                     else:
-                        action = policy.act(observation[agent], eps=0.01)
+                        action = policy.act(observation[agent_handle], eps=EPSILON)

-                    action_dict[agent] = action
+                    action_dict[agent_handle] = action

-                    if USE_ACTION_CACHE:
-                        agent_last_obs[agent] = observation[agent]
-                        agent_last_action[agent] = action
+                    if USE_ACTION_CACHE:
+                        agent_last_obs[agent_handle] = observation[agent_handle]
+                        agent_last_action[agent_handle] = action
             policy.end_step()

             agent_time = time.time() - time_start
diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py
index 1d0b52ccd915f730fa98fe79db9336f66cb70116..4a371350333bbe6e8295331ada53e8f8ada83b3b 100644
--- a/utils/dead_lock_avoidance_agent.py
+++ b/utils/dead_lock_avoidance_agent.py
@@ -67,11 +67,13 @@ class DeadlockAvoidanceShortestDistanceWalker(ShortestDistanceWalker):


 class DeadLockAvoidanceAgent(Policy):
-    def __init__(self, env: RailEnv, show_debug_plot=False):
+    def __init__(self, env: RailEnv, action_size, show_debug_plot=False):
         self.env = env
         self.memory = None
         self.loss = 0
+        self.action_size = action_size
         self.agent_can_move = {}
+        self.agent_can_move_value = {}
         self.switches = {}
         self.show_debug_plot = show_debug_plot

@@ -79,12 +81,19 @@ class DeadLockAvoidanceAgent(Policy):
         pass

     def act(self, state, eps=0.):
+        # Epsilon-greedy action selection
+        if np.random.random() < eps:
+            return np.random.choice(np.arange(self.action_size))
+
         # agent = self.env.agents[state[0]]
         check = self.agent_can_move.get(state[0], None)
         if check is None:
             return RailEnvActions.STOP_MOVING
         return check[3]

+    def get_agent_can_move_value(self, handle):
+        return self.agent_can_move_value.get(handle, np.inf)
+
     def reset(self):
         self.agent_positions = None
         self.shortest_distance_walker = None
@@ -136,7 +145,8 @@ class DeadLockAvoidanceAgent(Policy):
         for handle in range(self.env.get_num_agents()):
             agent = self.env.agents[handle]
             if agent.status < RailAgentStatus.DONE:
-                next_step_ok = self.check_agent_can_move(shortest_distance_agent_map[handle],
+                next_step_ok = self.check_agent_can_move(handle,
+                                                         shortest_distance_agent_map[handle],
                                                          self.shortest_distance_walker.same_agent_map.get(handle, []),
                                                          self.shortest_distance_walker.opp_agent_map.get(handle, []),
                                                          full_shortest_distance_agent_map)
@@ -154,6 +164,7 @@ class DeadLockAvoidanceAgent(Policy):
             plt.pause(0.01)

     def check_agent_can_move(self,
+                             handle,
                              my_shortest_walking_path,
                              same_agents,
                              opp_agents,
@@ -166,6 +177,9 @@ class DeadLockAvoidanceAgent(Policy):
             delta = ((my_shortest_walking_path - opp - agent_positions_map) > 0).astype(int)
             if np.sum(delta) < (3 + len(opp_agents)):
                 next_step_ok = False
+            v = self.agent_can_move_value.get(handle, np.inf)
+            v = min(v, np.sum(delta))
+            self.agent_can_move_value.update({handle: v})
         return next_step_ok

     def save(self, filename):
diff --git a/utils/extra.py b/utils/extra.py
index c4df6a8f1290f165d05616229e0bef55668d81e1..03cd4f902efc91ca0adf7569067c95f80eeb24b1 100644
--- a/utils/extra.py
+++ b/utils/extra.py
@@ -62,7 +62,7 @@ class Extra(ObservationBuilder):
         self.dead_lock_avoidance_agent = None
         if self.env is not None:
             self.env.dev_obs_dict = {}
-            self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, None, None)
+            self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, 5, False)

         self.switches = {}
         self.switches_neighbours = {}
@@ -195,7 +195,8 @@ class Extra(ObservationBuilder):
                 local_walker.walk_to_target(handle, new_position, branch_direction)
                 shortest_distance_agent_map, _ = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
                 my_shortest_path_to_check = shortest_distance_agent_map[handle]
-                next_step_ok = self.dead_lock_avoidance_agent.check_agent_can_move(my_shortest_path_to_check,
+                next_step_ok = self.dead_lock_avoidance_agent.check_agent_can_move(handle,
+                                                                                   my_shortest_path_to_check,
                                                                                    opp_agents,
                                                                                    same_agents,
                                                                                    full_shortest_distance_agent_map)
diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py
index 1e3b507cbe9128073ee4c0ba79657e1acca48fb1..fa72bcf23b2a1a887c20ea7aec7e60248f4a932b 100755
--- a/utils/fast_tree_obs.py
+++ b/utils/fast_tree_obs.py
@@ -25,7 +25,7 @@ class FastTreeObs(ObservationBuilder):

     def __init__(self, max_depth):
         self.max_depth = max_depth
-        self.observation_dim = 33
+        self.observation_dim = 36

     def build_data(self):
         if self.env is not None:
@@ -36,7 +36,7 @@ class FastTreeObs(ObservationBuilder):
         self.debug_render_path_list = []
         if self.env is not None:
             self.find_all_cell_where_agent_can_choose()
-            self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env)
+            self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, 5)
         else:
             self.dead_lock_avoidance_agent = None

@@ -163,15 +163,16 @@ class FastTreeObs(ObservationBuilder):
             self.build_data()
         return

-    def _explore(self, handle, new_position, new_direction, depth=0):
+    def _explore(self, handle, new_position, new_direction, distance_map, depth=0):
         has_opp_agent = 0
         has_same_agent = 0
         has_target = 0
         visited = []
+        min_dist = distance_map[handle, new_position[0], new_position[1], new_direction]

         # stop exploring (max_depth reached)
         if depth >= self.max_depth:
-            return has_opp_agent, has_same_agent, has_target, visited
+            return has_opp_agent, has_same_agent, has_target, visited, min_dist

         # max_explore_steps = 100 -> just to ensure that the exploration ends
         cnt = 0
@@ -184,7 +185,7 @@ class FastTreeObs(ObservationBuilder):
                 if self.env.agents[opp_a].direction != new_direction:
                     # opp agent found -> stop exploring. This would be a strong signal.
                     has_opp_agent = 1
-                    return has_opp_agent, has_same_agent, has_target, visited
+                    return has_opp_agent, has_same_agent, has_target, visited, min_dist
                 else:
                     # same agent found
                     # the agent can follow the agent, because this agent is still moving ahead and there shouldn't
@@ -193,7 +194,7 @@ class FastTreeObs(ObservationBuilder):
                     # target on this branch -> thus the agents should scan further whether there will be an opposite
                     # agent walking on same track
                     has_same_agent = 1
-                    # !NOT stop exploring! return has_opp_agent, has_same_agent, has_switch, visited
+                    # !NOT stop exploring! return has_opp_agent, has_same_agent, has_switch, visited, min_dist

             # agents_on_switch == TRUE -> Current cell is a switch where the agent can decide (branch) in exploration
             # agent_near_to_switch == TRUE -> One cell before the switch, where the agent can decide
@@ -204,7 +205,7 @@ class FastTreeObs(ObservationBuilder):
             if agents_near_to_switch:
                 # The exploration was walking on a path where the agent can not decide
                 # Best option would be MOVE_FORWARD -> Skip exploring - just walking
-                return has_opp_agent, has_same_agent, has_target, visited
+                return has_opp_agent, has_same_agent, has_target, visited, min_dist

             if self.env.agents[handle].target == new_position:
                 has_target = 1
@@ -222,20 +223,24 @@ class FastTreeObs(ObservationBuilder):
                     # --- OPEN RESEARCH QUESTION ---> is this good or shall we use full detailed information as
                     # we did in the TreeObservation (FLATLAND) ?
                     if possible_transitions[dir_loop] == 1:
-                        hoa, hsa, ht, v = self._explore(handle,
-                                                        get_new_position(new_position, dir_loop),
-                                                        dir_loop,
-                                                        depth + 1)
+                        hoa, hsa, ht, v, m_dist = self._explore(handle,
+                                                                get_new_position(new_position, dir_loop),
+                                                                dir_loop,
+                                                                distance_map,
+                                                                depth + 1)
                         visited.append(v)
-                        has_opp_agent += hoa * 2 ** (-1 - depth)
-                        has_same_agent += hsa * 2 ** (-1 - depth)
+                        has_opp_agent += max(hoa, has_opp_agent)
+                        has_same_agent += max(hsa, has_same_agent)
                         has_target = max(has_target, ht)
-                return has_opp_agent, has_same_agent, has_target, visited
+                        min_dist = min(min_dist, m_dist)
+                return has_opp_agent, has_same_agent, has_target, visited, min_dist
             else:
                 new_direction = fast_argmax(possible_transitions)
                 new_position = get_new_position(new_position, new_direction)
-        return has_opp_agent, has_same_agent, has_target, visited
+                min_dist = min(min_dist, distance_map[handle, new_position[0], new_position[1], new_direction])
+
+        return has_opp_agent, has_same_agent, has_target, visited, min_dist

     def get_many(self, handles: Optional[List[int]] = None):
         self.dead_lock_avoidance_agent.start_step()
@@ -310,9 +315,14 @@ class FastTreeObs(ObservationBuilder):
                 if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)):
                     observation[dir_loop] = int(new_cell_dist < current_cell_dist)

-                has_opp_agent, has_same_agent, has_target, v = self._explore(handle, new_position, branch_direction)
+                has_opp_agent, has_same_agent, has_target, v, min_dist = self._explore(handle,
+                                                                                       new_position,
+                                                                                       branch_direction,
+                                                                                       distance_map)
                 visited.append(v)

+                if not (np.math.isinf(min_dist) and np.math.isinf(current_cell_dist)):
+                    observation[31 + dir_loop] = int(min_dist < current_cell_dist)
                 observation[11 + dir_loop] = int(not np.math.isinf(new_cell_dist))
                 observation[15 + dir_loop] = has_opp_agent
                 observation[19 + dir_loop] = has_same_agent
@@ -332,8 +342,10 @@ class FastTreeObs(ObservationBuilder):
             action = self.dead_lock_avoidance_agent.act([handle], 0.0)
             observation[31] = int(action == RailEnvActions.STOP_MOVING)
-        observation[32] = int(fast_count_nonzero(possible_transitions) == 1)

         self.env.dev_obs_dict.update({handle: visited})

+        observation[np.isinf(observation)] = -1
+        observation[np.isnan(observation)] = -1
+
         return observation
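
Note on the agent-count curriculum (illustrative sketch, not part of the patch): the new number_of_agents line in multi_agent_training.py caps how many agents an episode may spawn and lets that cap grow from 1 toward n_agents as 0.9985 ** episode_idx decays. A minimal standalone Python sketch of the schedule; n_agents = 10 and the episode indices are assumptions chosen only for illustration:

n_agents = 10  # assumed for illustration; matches the AGENTS=10 runs referenced above

for episode_idx in (0, 100, 500, 1000, 2000, 5000):
    # Cap grows from 1 toward n_agents as 0.9985 ** episode_idx decays toward 0.
    number_of_agents = min(1 + round(n_agents * (1.0 - 0.9985 ** episode_idx)), n_agents)
    # The episode then cycles through 1..number_of_agents spawned agents.
    print(episode_idx, number_of_agents, episode_idx % number_of_agents + 1)

Note on the EPSILON sweep in run.py (illustrative sketch, not part of the patch): EPSILON is passed as eps to policy.act, where it drives epsilon-greedy exploration, the same rule this patch adds at the top of DeadLockAvoidanceAgent.act. A minimal sketch of that selection rule; epsilon_greedy and greedy_action are hypothetical names standing in for whatever the policy would pick on its own:

import numpy as np

def epsilon_greedy(greedy_action, action_size, eps):
    # With probability eps pick a uniformly random action, otherwise keep the greedy choice.
    if np.random.random() < eps:
        return int(np.random.choice(np.arange(action_size)))
    return greedy_action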