diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
index 3ff47b920f770a61369b46d6bbf0690af1cbc81d..32d71101ee54f10d481dba3362594998db52d716 100644
--- a/reinforcement_learning/dddqn_policy.py
+++ b/reinforcement_learning/dddqn_policy.py
@@ -119,10 +119,11 @@ class DDDQNPolicy(Policy):
         torch.save(self.qnetwork_target.state_dict(), filename + ".target")
 
     def load(self, filename):
-        if os.path.exists(filename + ".local"):
+        if os.path.exists(filename + ".local") and os.path.exists(filename + ".target"):
             self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
-        if os.path.exists(filename + ".target"):
             self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
+        else:
+            raise FileNotFoundError("Couldn't load policy from: '{}', '{}'".format(filename + ".local", filename + ".target"))
 
     def save_replay_buffer(self, filename):
         memory = self.memory.memory
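With this change, `load()` only succeeds when both halves of a checkpoint pair are present and otherwise raises instead of silently keeping randomly initialised networks. A minimal sketch of how a caller can handle that (assuming `policy` is an already constructed `DDDQNPolicy`; the checkpoint prefix is a hypothetical example):

```python
# Sketch only: DDDQNPolicy.save("checkpoints/run_1") writes run_1.local and
# run_1.target; the new load() requires both files or raises FileNotFoundError.
try:
    policy.load("checkpoints/run_1")
except FileNotFoundError as err:
    print("No complete checkpoint pair found, training from scratch:", err)
```
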
diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
old mode 100644
new mode 100755
index 7118a3a47855a5f77814e294f9067fd6feb16ce1..80e262018db4a3bdc38790fb5f1fd588c2e85516
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -5,18 +5,17 @@ import sys
 from argparse import ArgumentParser, Namespace
 from pathlib import Path
 from pprint import pprint
-
 import psutil
 from flatland.utils.rendertools import RenderTool
 from torch.utils.tensorboard import SummaryWriter
 import numpy as np
 import torch
+from collections import deque
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.envs.observations import TreeObsForRailEnv
-
 from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 
@@ -25,6 +24,7 @@ sys.path.append(str(base_dir))
 
 from utils.timer import Timer
 from utils.observation_utils import normalize_observation
+from utils.fast_tree_obs import FastTreeObs
 from reinforcement_learning.dddqn_policy import DDDQNPolicy
 
 try:
@@ -110,7 +110,30 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
     # Observation builder
     predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
-    tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor)
+    if not train_params.use_extra_observation:
+        print("\nUsing standard TreeObs")
+
+        def check_is_observation_valid(observation):
+            return observation
+
+        def get_normalized_observation(observation, tree_depth: int, observation_radius=0):
+            return normalize_observation(observation, tree_depth, observation_radius)
+
+        tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor)
+        tree_observation.check_is_observation_valid = check_is_observation_valid
+        tree_observation.get_normalized_observation = get_normalized_observation
+    else:
+        print("\nUsing FastTreeObs")
+
+        def check_is_observation_valid(observation):
+            return True
+
+        def get_normalized_observation(observation, tree_depth: int, observation_radius=0):
+            return observation
+
+        tree_observation = FastTreeObs(max_depth=observation_tree_depth)
+        tree_observation.check_is_observation_valid = check_is_observation_valid
+        tree_observation.get_normalized_observation = get_normalized_observation
 
     # Setup the environments
     train_env = create_rail_env(train_env_params, tree_observation)
@@ -118,15 +141,19 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     eval_env = create_rail_env(eval_env_params, tree_observation)
     eval_env.reset(regenerate_schedule=True, regenerate_rail=True)
 
+    if not train_params.use_extra_observation:
+        # Calculate the state size given the depth of the tree observation and the number of features
+        n_features_per_node = train_env.obs_builder.observation_dim
+        n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)])
+        state_size = n_features_per_node * n_nodes
+    else:
+        # Calculate the state size given the depth of the tree observation and the number of features
+        state_size = tree_observation.observation_dim
+
     # Setup renderer
     if train_params.render:
         env_renderer = RenderTool(train_env, gl="PGL")
 
-    # Calculate the state size given the depth of the tree observation and the number of features
-    n_features_per_node = train_env.obs_builder.observation_dim
-    n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)])
-    state_size = n_features_per_node * n_nodes
-
     # The action space of flatland is 5 discrete actions
     action_size = 5
 
@@ -144,14 +171,19 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     update_values = [False] * n_agents
 
     # Smoothed values used as target for hyperparameter tuning
-    smoothed_normalized_score = -1.0
     smoothed_eval_normalized_score = -1.0
-    smoothed_completion = 0.0
     smoothed_eval_completion = 0.0
 
+    scores_window = deque(maxlen=checkpoint_interval)  # todo smooth when rendering instead
+    completion_window = deque(maxlen=checkpoint_interval)
+
     # Double Dueling DQN policy
     policy = DDDQNPolicy(state_size, action_size, train_params)
 
+    # Load existing policy
+    if train_params.load_policy != "":
+        policy.load(train_params.load_policy)
+
     # Loads existing replay buffer
     if restore_replay_buffer:
         try:
@@ -166,7 +198,9 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
     hdd = psutil.disk_usage('/')
     if save_replay_buffer and (hdd.free / (2 ** 30)) < 500.0:
-        print("⚠️ Careful! Saving replay buffers will quickly consume a lot of disk space. You have {:.2f}gb left.".format(hdd.free / (2 ** 30)))
+        print(
+            "⚠️ Careful! Saving replay buffers will quickly consume a lot of disk space. You have {:.2f}gb left.".format(
+                hdd.free / (2 ** 30)))
 
     # TensorBoard writer
     writer = SummaryWriter()
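For the standard TreeObs branch, the state size grows geometrically with the tree depth, while FastTreeObs keeps a fixed 26-dimensional vector. A small worked example of the computation above, assuming Flatland's standard tree observation with 11 features per node:

```python
# Worked example of the TreeObs state-size formula used above
# (11 features per node is the usual TreeObsForRailEnv.observation_dim).
observation_tree_depth = 2                                          # --max_depth default
n_features_per_node = 11
n_nodes = sum(4 ** i for i in range(observation_tree_depth + 1))    # 1 + 4 + 16 = 21
state_size = n_features_per_node * n_nodes                          # 231 network inputs
print(n_nodes, state_size)                                          # 21 231
```
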
@@ -177,14 +211,15 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     training_timer = Timer()
     training_timer.start()
 
-    print("\n🚉 Training {} trains on {}x{} grid for {} episodes, evaluating on {} episodes every {} episodes. Training id '{}'.\n".format(
-        train_env.get_num_agents(),
-        x_dim, y_dim,
-        n_episodes,
-        n_eval_episodes,
-        checkpoint_interval,
-        training_id
-    ))
+    print(
+        "\n🚉 Training {} trains on {}x{} grid for {} episodes, evaluating on {} episodes every {} episodes. Training id '{}'.\n".format(
+            train_env.get_num_agents(),
+            x_dim, y_dim,
+            n_episodes,
+            n_eval_episodes,
+            checkpoint_interval,
+            training_id
+        ))
 
     for episode_idx in range(n_episodes + 1):
         step_timer = Timer()
@@ -207,8 +242,9 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
         # Build initial agent-specific observations
         for agent in train_env.get_agent_handles():
-            if obs[agent]:
-                agent_obs[agent] = normalize_observation(obs[agent], observation_tree_depth, observation_radius=observation_radius)
+            if tree_observation.check_is_observation_valid(obs[agent]):
+                agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], observation_tree_depth,
+                                                                               observation_radius=observation_radius)
                 agent_prev_obs[agent] = agent_obs[agent].copy()
 
         # Run episode
@@ -217,6 +253,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
             for agent in train_env.get_agent_handles():
                 if info['action_required'][agent]:
                     update_values[agent] = True
+
                     action = policy.act(agent_obs[agent], eps=eps_start)
                     action_count[action] += 1
@@ -255,9 +292,11 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                     agent_prev_action[agent] = action_dict[agent]
 
                 # Preprocess the new observations
-                if next_obs[agent]:
+                if tree_observation.check_is_observation_valid(next_obs[agent]):
                     preproc_timer.start()
-                    agent_obs[agent] = normalize_observation(next_obs[agent], observation_tree_depth, observation_radius=observation_radius)
+                    agent_obs[agent] = tree_observation.get_normalized_observation(next_obs[agent],
+                                                                                   observation_tree_depth,
+                                                                                   observation_radius=observation_radius)
                     preproc_timer.end()
 
                 score += all_rewards[agent]
@@ -274,12 +313,12 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         tasks_finished = sum(done[idx] for idx in train_env.get_agent_handles())
         completion = tasks_finished / max(1, train_env.get_num_agents())
         normalized_score = score / (max_steps * train_env.get_num_agents())
-        action_probs = action_count / np.sum(action_count)
-        action_count = [1] * action_size
+        action_probs = action_count / max(1, np.sum(action_count))
 
-        smoothing = 0.99
-        smoothed_normalized_score = smoothed_normalized_score * smoothing + normalized_score * (1.0 - smoothing)
-        smoothed_completion = smoothed_completion * smoothing + completion * (1.0 - smoothing)
+        scores_window.append(normalized_score)
+        completion_window.append(completion)
+        smoothed_normalized_score = np.mean(scores_window)
+        smoothed_completion = np.mean(completion_window)
 
         # Print logs
         if episode_idx % checkpoint_interval == 0:
@@ -291,12 +330,15 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             if train_params.render:
                 env_renderer.close_window()
 
+            # reset action count
+            action_count = [0] * action_size
+
             print(
                 '\r🚂 Episode {}'
-                '\t 🏆 Score: {:.3f}'
-                ' Avg: {:.3f}'
-                '\t 💯 Done: {:.2f}%'
-                ' Avg: {:.2f}%'
+                '\t 🏆 Score: {:7.3f}'
+                ' Avg: {:7.3f}'
+                '\t 💯 Done: {:6.2f}%'
+                ' Avg: {:6.2f}%'
                 '\t 🎲 Epsilon: {:.3f} '
                 '\t 🔀 Action Probs: {}'.format(
                     episode_idx,
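The exponential smoothing of the training score is replaced here by a plain moving average over the last `checkpoint_interval` episodes, which gives every recent episode equal weight. A minimal sketch of the difference, with illustrative values only:

```python
from collections import deque

import numpy as np

checkpoint_interval = 100
scores_window = deque(maxlen=checkpoint_interval)   # oldest entry is dropped once full

# New behaviour: moving average over the window
for normalized_score in (-0.9, -0.5, -0.3):
    scores_window.append(normalized_score)
smoothed_normalized_score = np.mean(scores_window)  # approx. -0.567

# Previous behaviour, shown for comparison: exponential smoothing
smoothing = 0.99
smoothed = -1.0
for normalized_score in (-0.9, -0.5, -0.3):
    smoothed = smoothed * smoothing + normalized_score * (1.0 - smoothing)
```
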
@@ -310,7 +352,11 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
         # Evaluate policy and log results at some interval
         if episode_idx % checkpoint_interval == 0 and n_eval_episodes > 0:
-            scores, completions, nb_steps_eval = eval_policy(eval_env, policy, train_params, obs_params)
+            scores, completions, nb_steps_eval = eval_policy(eval_env,
+                                                             tree_observation,
+                                                             policy,
+                                                             train_params,
+                                                             obs_params)
             writer.add_scalar("evaluation/scores_min", np.min(scores), episode_idx)
             writer.add_scalar("evaluation/scores_max", np.max(scores), episode_idx)
@@ -329,7 +375,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             writer.add_histogram("evaluation/nb_steps", np.array(nb_steps_eval), episode_idx)
 
             smoothing = 0.9
-            smoothed_eval_normalized_score = smoothed_eval_normalized_score * smoothing + np.mean(scores) * (1.0 - smoothing)
+            smoothed_eval_normalized_score = smoothed_eval_normalized_score * smoothing + np.mean(scores) * (
+                    1.0 - smoothing)
             smoothed_eval_completion = smoothed_eval_completion * smoothing + np.mean(completions) * (1.0 - smoothing)
             writer.add_scalar("evaluation/smoothed_score", smoothed_eval_normalized_score, episode_idx)
             writer.add_scalar("evaluation/smoothed_completion", smoothed_eval_completion, episode_idx)
@@ -367,7 +414,7 @@ def format_action_prob(action_probs):
     return buffer
 
 
-def eval_policy(env, policy, train_params, obs_params):
+def eval_policy(env, tree_observation, policy, train_params, obs_params):
     n_eval_episodes = train_params.n_evaluation_episodes
     max_steps = env._max_episode_steps
     tree_depth = obs_params.observation_tree_depth
@@ -388,12 +435,14 @@
 
         for step in range(max_steps - 1):
             for agent in env.get_agent_handles():
-                if obs[agent]:
-                    agent_obs[agent] = normalize_observation(obs[agent], tree_depth=tree_depth, observation_radius=observation_radius)
+                if tree_observation.check_is_observation_valid(obs[agent]):
+                    agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], tree_depth=tree_depth,
+                                                                                   observation_radius=observation_radius)
 
                 action = 0
                 if info['action_required'][agent]:
-                    action = policy.act(agent_obs[agent], eps=0.0)
+                    if tree_observation.check_is_observation_valid(agent_obs[agent]):
+                        action = policy.act(agent_obs[agent], eps=0.0)
                 action_dict.update({agent: action})
 
             obs, all_rewards, done, info = env.step(action_dict)
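`eval_policy` now takes the observation builder as its second argument, so evaluation goes through the same `check_is_observation_valid` / `get_normalized_observation` hooks as training; when an observation is invalid the agent keeps action 0, which is `RailEnvActions.DO_NOTHING` in Flatland. A sketch of the new call, reusing the objects already built in `train_agent` above (not runnable on its own):

```python
# Names (eval_env, tree_observation, policy, train_params, obs_params) come from
# train_agent(); this only illustrates the extended signature.
scores, completions, nb_steps_eval = eval_policy(eval_env, tree_observation, policy,
                                                 train_params, obs_params)
print("eval: score {:.3f}, done {:.2f}%".format(np.mean(scores), 100 * np.mean(completions)))
```
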
@@ -424,8 +473,9 @@ if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=2500, type=int)
     parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=0, type=int)
-    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0, type=int)
-    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=25, type=int)
+    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
+                        type=int)
+    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=50, type=int)
     parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
     parser.add_argument("--eps_start", help="max exploration", default=1.0, type=float)
     parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
@@ -433,7 +483,8 @@ if __name__ == "__main__":
     parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e5), type=int)
     parser.add_argument("--buffer_min_size", help="min buffer size to start training", default=0, type=int)
     parser.add_argument("--restore_replay_buffer", help="replay buffer to restore", default="", type=str)
-    parser.add_argument("--save_replay_buffer", help="save replay buffer at each evaluation interval", default=False, type=bool)
+    parser.add_argument("--save_replay_buffer", help="save replay buffer at each evaluation interval", default=False,
+                        type=bool)
     parser.add_argument("--batch_size", help="minibatch size", default=128, type=int)
     parser.add_argument("--gamma", help="discount factor", default=0.99, type=float)
     parser.add_argument("--tau", help="soft update of target parameters", default=1e-3, type=float)
@@ -442,9 +493,12 @@ if __name__ == "__main__":
     parser.add_argument("--update_every", help="how often to update the network", default=8, type=int)
     parser.add_argument("--use_gpu", help="use GPU if available", default=False, type=bool)
     parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=1, type=int)
-    parser.add_argument("--render", help="render 1 episode in 100", default=False, type=bool)
-    training_params = parser.parse_args()
+    parser.add_argument("--render", help="render 1 episode in 100", action='store_true')
+    parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
+    parser.add_argument("--use_extra_observation", help="extra observation", action='store_true')
+    parser.add_argument("--max_depth", help="max depth", default=2, type=int)
+    training_params = parser.parse_args()
 
     env_params = [
         {
             # Test_0
@@ -482,14 +536,16 @@ if __name__ == "__main__":
     ]
 
     obs_params = {
-        "observation_tree_depth": 2,
+        "observation_tree_depth": training_params.max_depth,
        "observation_radius": 10,
        "observation_max_path_depth": 30
     }
 
+
     def check_env_config(id):
         if id >= len(env_params) or id < 0:
-            print("\n🛑 Invalid environment configuration, only Test_0 to Test_{} are supported.".format(len(env_params) - 1))
+            print("\n🛑 Invalid environment configuration, only Test_0 to Test_{} are supported.".format(
+                len(env_params) - 1))
             exit(1)
@@ -499,6 +555,10 @@ if __name__ == "__main__":
     training_env_params = env_params[training_params.training_env_config]
     evaluation_env_params = env_params[training_params.evaluation_env_config]
 
+    # FIXME hard-coded for sweep search
+    # see https://wb-forum.slack.com/archives/CL4V2QE59/p1602931982236600 to implement properly
+    # training_params.use_extra_observation = True
+
     print("\nTraining parameters:")
     pprint(vars(training_params))
     print("\nTraining environment parameters (Test_{}):".format(training_params.training_env_config))
@@ -509,4 +569,5 @@ if __name__ == "__main__":
     pprint(obs_params)
 
     os.environ["OMP_NUM_THREADS"] = str(training_params.num_threads)
-    train_agent(training_params, Namespace(**training_env_params), Namespace(**evaluation_env_params), Namespace(**obs_params))
+    train_agent(training_params, Namespace(**training_env_params), Namespace(**evaluation_env_params),
+                Namespace(**obs_params))
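Switching `--render` and the new `--use_extra_observation` flag to `action='store_true'` avoids a common argparse pitfall: with `type=bool`, any non-empty string (including "False") parses as True. Note that `--use_gpu` and `--save_replay_buffer` keep `type=bool` in this patch, so the same caveat still applies to them. A minimal illustration (flag names here are made up for the example):

```python
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--render_bool", default=False, type=bool)   # old style
parser.add_argument("--render_flag", action='store_true')        # new style

args = parser.parse_args(["--render_bool", "False"])
print(args.render_bool)   # True  -- bool("False") is truthy, not what the user meant
print(args.render_flag)   # False -- flag not given, so it stays False
```
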
diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py
new file mode 100755
index 0000000000000000000000000000000000000000..12c91cabb481ed77bc69704ab1c548e6b447b24f
--- /dev/null
+++ b/utils/fast_tree_obs.py
@@ -0,0 +1,301 @@
+import numpy as np
+from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.rail_env import fast_count_nonzero, fast_argmax
+
+
+"""
+LICENCE for the FastTreeObs Observation Builder
+
+The observation can be used freely and reused for further submissions. Only the author needs to be referred to
+/mentioned in any submissions - if the entire observation or parts, or the main idea is used.
+
+Author: Adrian Egli (adrian.egli@gmail.com)
+
+[LinkedIn](https://www.linkedin.com/in/adrian-egli-733a9544/)
+[Researchgate](https://www.researchgate.net/profile/Adrian_Egli2)
+"""
+
+class FastTreeObs(ObservationBuilder):
+
+    def __init__(self, max_depth):
+        self.max_depth = max_depth
+        self.observation_dim = 26
+
+    def build_data(self):
+        if self.env is not None:
+            self.env.dev_obs_dict = {}
+        self.switches = {}
+        self.switches_neighbours = {}
+        self.debug_render_list = []
+        self.debug_render_path_list = []
+        if self.env is not None:
+            self.find_all_cell_where_agent_can_choose()
+
+    def find_all_cell_where_agent_can_choose(self):
+        switches = {}
+        for h in range(self.env.height):
+            for w in range(self.env.width):
+                pos = (h, w)
+                for dir in range(4):
+                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
+                    num_transitions = fast_count_nonzero(possible_transitions)
+                    if num_transitions > 1:
+                        if pos not in switches.keys():
+                            switches.update({pos: [dir]})
+                        else:
+                            switches[pos].append(dir)
+
+        switches_neighbours = {}
+        for h in range(self.env.height):
+            for w in range(self.env.width):
+                # look one step forward
+                for dir in range(4):
+                    pos = (h, w)
+                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
+                    for d in range(4):
+                        if possible_transitions[d] == 1:
+                            new_cell = get_new_position(pos, d)
+                            if new_cell in switches.keys() and pos not in switches.keys():
+                                if pos not in switches_neighbours.keys():
+                                    switches_neighbours.update({pos: [dir]})
+                                else:
+                                    switches_neighbours[pos].append(dir)
+
+        self.switches = switches
+        self.switches_neighbours = switches_neighbours
+
+    def check_agent_decision(self, position, direction):
+        switches = self.switches
+        switches_neighbours = self.switches_neighbours
+        agents_on_switch = False
+        agents_on_switch_all = False
+        agents_near_to_switch = False
+        agents_near_to_switch_all = False
+        if position in switches.keys():
+            agents_on_switch = direction in switches[position]
+            agents_on_switch_all = True
+
+        if position in switches_neighbours.keys():
+            new_cell = get_new_position(position, direction)
+            if new_cell in switches.keys():
+                if not direction in switches[new_cell]:
+                    agents_near_to_switch = direction in switches_neighbours[position]
+            else:
+                agents_near_to_switch = direction in switches_neighbours[position]
+
+            agents_near_to_switch_all = direction in switches_neighbours[position]
+
+        return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
+
+    def required_agent_decision(self):
+        agents_can_choose = {}
+        agents_on_switch = {}
+        agents_on_switch_all = {}
+        agents_near_to_switch = {}
+        agents_near_to_switch_all = {}
+        for a in range(self.env.get_num_agents()):
+            ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all, ret_agents_on_switch_all = \
+                self.check_agent_decision(
+                    self.env.agents[a].position,
+                    self.env.agents[a].direction)
+            agents_on_switch.update({a: ret_agents_on_switch})
+            agents_on_switch_all.update({a: ret_agents_on_switch_all})
+            ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART
+            agents_near_to_switch.update({a: (ret_agents_near_to_switch and not ready_to_depart)})
+
+            agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]})
+
+            agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all and not ready_to_depart)})
+
+        return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
+
+    def debug_render(self, env_renderer):
+        agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all = \
+            self.required_agent_decision()
+        self.env.dev_obs_dict = {}
+        for a in range(max(3, self.env.get_num_agents())):
+            self.env.dev_obs_dict.update({a: []})
+
+        selected_agent = None
+        if agents_can_choose[0]:
+            if self.env.agents[0].position is not None:
+                self.debug_render_list.append(self.env.agents[0].position)
+            else:
+                self.debug_render_list.append(self.env.agents[0].initial_position)
+
+        if self.env.agents[0].position is not None:
+            self.debug_render_path_list.append(self.env.agents[0].position)
+        else:
+            self.debug_render_path_list.append(self.env.agents[0].initial_position)
+
+        env_renderer.gl.agent_colors[0] = env_renderer.gl.rgb_s2i("FF0000")
+        env_renderer.gl.agent_colors[1] = env_renderer.gl.rgb_s2i("666600")
+        env_renderer.gl.agent_colors[2] = env_renderer.gl.rgb_s2i("006666")
+        env_renderer.gl.agent_colors[3] = env_renderer.gl.rgb_s2i("550000")
+
+        self.env.dev_obs_dict[0] = self.debug_render_list
+        self.env.dev_obs_dict[1] = self.switches.keys()
+        self.env.dev_obs_dict[2] = self.switches_neighbours.keys()
+        self.env.dev_obs_dict[3] = self.debug_render_path_list
+
+    def reset(self):
+        self.build_data()
+        return
+
+    def fast_argmax(self, array):
+        if array[0] == 1:
+            return 0
+        if array[1] == 1:
+            return 1
+        if array[2] == 1:
+            return 2
+        return 3
+
+    def _explore(self, handle, new_position, new_direction, depth=0):
+        has_opp_agent = 0
+        has_same_agent = 0
+        has_switch = 0
+        visited = []
+
+        # stop exploring (max_depth reached)
+        if depth >= self.max_depth:
+            return has_opp_agent, has_same_agent, has_switch, visited
+
+        # max_explore_steps = 100
+        cnt = 0
+        while cnt < 100:
+            cnt += 1
+
+            visited.append(new_position)
+            opp_a = self.env.agent_positions[new_position]
+            if opp_a != -1 and opp_a != handle:
+                if self.env.agents[opp_a].direction != new_direction:
+                    # opp agent found
+                    has_opp_agent = 1
+                    return has_opp_agent, has_same_agent, has_switch, visited
+                else:
+                    has_same_agent = 1
+                    return has_opp_agent, has_same_agent, has_switch, visited
+
+            # convert one-hot encoding to 0,1,2,3
+            agents_on_switch, \
+            agents_near_to_switch, \
+            agents_near_to_switch_all, \
+            agents_on_switch_all = \
+                self.check_agent_decision(new_position, new_direction)
+            if agents_near_to_switch:
+                return has_opp_agent, has_same_agent, has_switch, visited
+
+            possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
+            if agents_on_switch:
+                f = 0
+                for dir_loop in range(4):
+                    if possible_transitions[dir_loop] == 1:
+                        f += 1
+                        hoa, hsa, hs, v = self._explore(handle,
+                                                        get_new_position(new_position, dir_loop),
+                                                        dir_loop,
+                                                        depth + 1)
+                        visited.append(v)
+                        has_opp_agent += hoa
+                        has_same_agent += hsa
+                        has_switch += hs
+                f = max(f, 1.0)
+                return has_opp_agent / f, has_same_agent / f, has_switch / f, visited
+            else:
+                new_direction = fast_argmax(possible_transitions)
+                new_position = get_new_position(new_position, new_direction)
+
+        return has_opp_agent, has_same_agent, has_switch, visited
+
+    def get(self, handle):
+        # all values are [0,1]
+        # observation[0] : 1 path towards target (direction 0) / otherwise 0 -> path is longer or there is no path
+        # observation[1] : 1 path towards target (direction 1) / otherwise 0 -> path is longer or there is no path
+        # observation[2] : 1 path towards target (direction 2) / otherwise 0 -> path is longer or there is no path
+        # observation[3] : 1 path towards target (direction 3) / otherwise 0 -> path is longer or there is no path
+        # observation[4] : int(agent.status == RailAgentStatus.READY_TO_DEPART)
+        # observation[5] : int(agent.status == RailAgentStatus.ACTIVE)
+        # observation[6] : int(agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED)
+        # observation[7] : current agent is located at a switch, where it can take a routing decision
+        # observation[8] : current agent is located at a cell, where it has to take a stop-or-go decision
+        # observation[9] : current agent is located one step before/after a switch
+        # observation[10] : 1 if there is a path (track/branch) otherwise 0 (direction 0)
+        # observation[11] : 1 if there is a path (track/branch) otherwise 0 (direction 1)
+        # observation[12] : 1 if there is a path (track/branch) otherwise 0 (direction 2)
+        # observation[13] : 1 if there is a path (track/branch) otherwise 0 (direction 3)
+        # observation[14] : if there is a path (direction 0) and there is an agent with opposite direction -> 1
+        # observation[15] : if there is a path (direction 1) and there is an agent with opposite direction -> 1
+        # observation[16] : if there is a path (direction 2) and there is an agent with opposite direction -> 1
+        # observation[17] : if there is a path (direction 3) and there is an agent with opposite direction -> 1
+        # observation[18] : if there is a path (direction 0) and there is an agent with same direction -> 1
+        # observation[19] : if there is a path (direction 1) and there is an agent with same direction -> 1
+        # observation[20] : if there is a path (direction 2) and there is an agent with same direction -> 1
+        # observation[21] : if there is a path (direction 3) and there is an agent with same direction -> 1
+        # observation[22] : if there is a switch on the path (direction 0) which the agent cannot use -> 1
+        # observation[23] : if there is a switch on the path (direction 1) which the agent cannot use -> 1
+        # observation[24] : if there is a switch on the path (direction 2) which the agent cannot use -> 1
+        # observation[25] : if there is a switch on the path (direction 3) which the agent cannot use -> 1
+
+        observation = np.zeros(self.observation_dim)
+        visited = []
+        agent = self.env.agents[handle]
+
+        agent_done = False
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            agent_virtual_position = agent.initial_position
+            observation[4] = 1
+        elif agent.status == RailAgentStatus.ACTIVE:
+            agent_virtual_position = agent.position
+            observation[5] = 1
+        else:
+            observation[6] = 1
+            agent_virtual_position = (-1, -1)
+            agent_done = True
+
+        if not agent_done:
+            visited.append(agent_virtual_position)
+            distance_map = self.env.distance_map.get()
+            current_cell_dist = distance_map[handle,
+                                             agent_virtual_position[0], agent_virtual_position[1],
+                                             agent.direction]
+            possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
+            orientation = agent.direction
+            if fast_count_nonzero(possible_transitions) == 1:
+                orientation = fast_argmax(possible_transitions)
+
+            for dir_loop, branch_direction in enumerate([(orientation + dir_loop) % 4 for dir_loop in range(-1, 3)]):
+                if possible_transitions[branch_direction]:
+                    new_position = get_new_position(agent_virtual_position, branch_direction)
+                    new_cell_dist = distance_map[handle,
+                                                 new_position[0], new_position[1],
+                                                 branch_direction]
+                    if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)):
+                        observation[dir_loop] = int(new_cell_dist < current_cell_dist)
+
+                    has_opp_agent, has_same_agent, has_switch, v = self._explore(handle, new_position, branch_direction)
+                    visited.append(v)
+
+                    observation[10 + dir_loop] = 1
+                    observation[14 + dir_loop] = has_opp_agent
+                    observation[18 + dir_loop] = has_same_agent
+                    observation[22 + dir_loop] = has_switch
+
+            agents_on_switch, \
+            agents_near_to_switch, \
+            agents_near_to_switch_all, \
+            agents_on_switch_all = \
+                self.check_agent_decision(agent_virtual_position, agent.direction)
+            observation[7] = int(agents_on_switch)
+            observation[8] = int(agents_near_to_switch)
+            observation[9] = int(agents_near_to_switch_all)
+
+        self.env.dev_obs_dict.update({handle: visited})
+
+        return observation
+
+    @staticmethod
+    def agent_can_choose(observation):
+        return observation[7] == 1 or observation[8] == 1
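A minimal usage sketch of FastTreeObs outside the training script, showing how it plugs into a `RailEnv` in place of `TreeObsForRailEnv`; the grid size, generator settings, and agent count below are made-up example values, not taken from the patch:

```python
# Illustrative sketch: each agent receives the fixed 26-dimensional vector
# documented in FastTreeObs.get(), so no tree normalization step is needed.
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator

from utils.fast_tree_obs import FastTreeObs

env = RailEnv(width=25, height=25,
              rail_generator=sparse_rail_generator(max_num_cities=2),
              schedule_generator=sparse_schedule_generator(),
              number_of_agents=2,
              obs_builder_object=FastTreeObs(max_depth=2))
obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)
print(obs[0].shape)   # (26,) -- one fixed-size vector per agent handle
```
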