From 65a41843602eec858a2bc5f8929f380c065231b1 Mon Sep 17 00:00:00 2001
From: flaurent <florian.laurent@gmail.com>
Date: Fri, 23 Oct 2020 19:04:07 +0200
Subject: [PATCH] FastTreeObs!

---
 reinforcement_learning/dddqn_policy.py        |   5 +-
 .../multi_agent_training.py                   | 153 ++++++---
 utils/fast_tree_obs.py                        | 301 ++++++++++++++++++
 3 files changed, 411 insertions(+), 48 deletions(-)
 mode change 100644 => 100755 reinforcement_learning/multi_agent_training.py
 create mode 100755 utils/fast_tree_obs.py

diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
index 3ff47b9..32d7110 100644
--- a/reinforcement_learning/dddqn_policy.py
+++ b/reinforcement_learning/dddqn_policy.py
@@ -119,10 +119,11 @@ class DDDQNPolicy(Policy):
         torch.save(self.qnetwork_target.state_dict(), filename + ".target")
 
     def load(self, filename):
-        if os.path.exists(filename + ".local"):
+        if os.path.exists(filename + ".local") and os.path.exists(filename + ".target"):
             self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
-        if os.path.exists(filename + ".target"):
             self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
+        else:
+            raise FileNotFoundError("Couldn't load policy from: '{}', '{}'".format(filename + ".local", filename + ".target"))
 
     def save_replay_buffer(self, filename):
         memory = self.memory.memory
diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
old mode 100644
new mode 100755
index 7118a3a..80e2620
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -5,18 +5,17 @@ import sys
 from argparse import ArgumentParser, Namespace
 from pathlib import Path
 from pprint import pprint
-
 import psutil
 from flatland.utils.rendertools import RenderTool
 from torch.utils.tensorboard import SummaryWriter
 import numpy as np
 import torch
+from collections import deque
 
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.envs.observations import TreeObsForRailEnv
-
 from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 
@@ -25,6 +24,7 @@ sys.path.append(str(base_dir))
 
 from utils.timer import Timer
 from utils.observation_utils import normalize_observation
+from utils.fast_tree_obs import FastTreeObs
 from reinforcement_learning.dddqn_policy import DDDQNPolicy
 
 try:
@@ -110,7 +110,30 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
     # Observation builder
     predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
-    tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor)
+    if not train_params.use_extra_observation:
+        print("\nUsing standard TreeObs")
+
+        def check_is_observation_valid(observation):
+            return observation
+
+        def get_normalized_observation(observation, tree_depth: int, observation_radius=0):
+            return normalize_observation(observation, tree_depth, observation_radius)
+
+        tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor)
+        tree_observation.check_is_observation_valid = check_is_observation_valid
+        tree_observation.get_normalized_observation = get_normalized_observation
+    else:
+        print("\nUsing FastTreeObs")
+
+        def check_is_observation_valid(observation):
+            return True
+
+        def get_normalized_observation(observation, tree_depth: int, observation_radius=0):
+            return observation
+
+        tree_observation = FastTreeObs(max_depth=observation_tree_depth)
+        tree_observation.check_is_observation_valid = check_is_observation_valid
+        tree_observation.get_normalized_observation = get_normalized_observation
 
     # Setup the environments
     train_env = create_rail_env(train_env_params, tree_observation)
@@ -118,15 +141,19 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     eval_env = create_rail_env(eval_env_params, tree_observation)
     eval_env.reset(regenerate_schedule=True, regenerate_rail=True)
 
+    if not train_params.use_extra_observation:
+        # Calculate the state size given the depth of the tree observation and the number of features
+        n_features_per_node = train_env.obs_builder.observation_dim
+        n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)])
+        state_size = n_features_per_node * n_nodes
+    else:
+        # Calculate the state size given the depth of the tree observation and the number of features
+        state_size = tree_observation.observation_dim
+
     # Setup renderer
     if train_params.render:
         env_renderer = RenderTool(train_env, gl="PGL")
 
-    # Calculate the state size given the depth of the tree observation and the number of features
-    n_features_per_node = train_env.obs_builder.observation_dim
-    n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)])
-    state_size = n_features_per_node * n_nodes
-
     # The action space of flatland is 5 discrete actions
     action_size = 5
 
@@ -144,14 +171,19 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     update_values = [False] * n_agents
 
     # Smoothed values used as target for hyperparameter tuning
-    smoothed_normalized_score = -1.0
     smoothed_eval_normalized_score = -1.0
-    smoothed_completion = 0.0
     smoothed_eval_completion = 0.0
 
+    scores_window = deque(maxlen=checkpoint_interval)  # todo smooth when rendering instead
+    completion_window = deque(maxlen=checkpoint_interval)
+
     # Double Dueling DQN policy
     policy = DDDQNPolicy(state_size, action_size, train_params)
 
+    # Load existing policy
+    if train_params.load_policy is not "":
+        policy.load(train_params.load_policy)
+
     # Loads existing replay buffer
     if restore_replay_buffer:
         try:
@@ -166,7 +198,9 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
     hdd = psutil.disk_usage('/')
     if save_replay_buffer and (hdd.free / (2 ** 30)) < 500.0:
-        print("⚠️  Careful! Saving replay buffers will quickly consume a lot of disk space. You have {:.2f}gb left.".format(hdd.free / (2 ** 30)))
+        print(
+            "⚠️  Careful! Saving replay buffers will quickly consume a lot of disk space. You have {:.2f}gb left.".format(
+                hdd.free / (2 ** 30)))
 
     # TensorBoard writer
     writer = SummaryWriter()
@@ -177,14 +211,15 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     training_timer = Timer()
     training_timer.start()
 
-    print("\n🚉 Training {} trains on {}x{} grid for {} episodes, evaluating on {} episodes every {} episodes. Training id '{}'.\n".format(
-        train_env.get_num_agents(),
-        x_dim, y_dim,
-        n_episodes,
-        n_eval_episodes,
-        checkpoint_interval,
-        training_id
-    ))
+    print(
+        "\n🚉 Training {} trains on {}x{} grid for {} episodes, evaluating on {} episodes every {} episodes. Training id '{}'.\n".format(
+            train_env.get_num_agents(),
+            x_dim, y_dim,
+            n_episodes,
+            n_eval_episodes,
+            checkpoint_interval,
+            training_id
+        ))
 
     for episode_idx in range(n_episodes + 1):
         step_timer = Timer()
@@ -207,8 +242,9 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
         # Build initial agent-specific observations
         for agent in train_env.get_agent_handles():
-            if obs[agent]:
-                agent_obs[agent] = normalize_observation(obs[agent], observation_tree_depth, observation_radius=observation_radius)
+            if tree_observation.check_is_observation_valid(obs[agent]):
+                agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], observation_tree_depth,
+                                                                               observation_radius=observation_radius)
                 agent_prev_obs[agent] = agent_obs[agent].copy()
 
         # Run episode
@@ -217,6 +253,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             for agent in train_env.get_agent_handles():
                 if info['action_required'][agent]:
                     update_values[agent] = True
+
                     action = policy.act(agent_obs[agent], eps=eps_start)
 
                     action_count[action] += 1
@@ -255,9 +292,11 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                     agent_prev_action[agent] = action_dict[agent]
 
                 # Preprocess the new observations
-                if next_obs[agent]:
+                if tree_observation.check_is_observation_valid(next_obs[agent]):
                     preproc_timer.start()
-                    agent_obs[agent] = normalize_observation(next_obs[agent], observation_tree_depth, observation_radius=observation_radius)
+                    agent_obs[agent] = tree_observation.get_normalized_observation(next_obs[agent],
+                                                                                   observation_tree_depth,
+                                                                                   observation_radius=observation_radius)
                     preproc_timer.end()
 
                 score += all_rewards[agent]
@@ -274,12 +313,12 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         tasks_finished = sum(done[idx] for idx in train_env.get_agent_handles())
         completion = tasks_finished / max(1, train_env.get_num_agents())
         normalized_score = score / (max_steps * train_env.get_num_agents())
-        action_probs = action_count / np.sum(action_count)
-        action_count = [1] * action_size
+        action_probs = action_count / max(1, np.sum(action_count))
 
-        smoothing = 0.99
-        smoothed_normalized_score = smoothed_normalized_score * smoothing + normalized_score * (1.0 - smoothing)
-        smoothed_completion = smoothed_completion * smoothing + completion * (1.0 - smoothing)
+        scores_window.append(normalized_score)
+        completion_window.append(completion)
+        smoothed_normalized_score = np.mean(scores_window)
+        smoothed_completion = np.mean(completion_window)
 
         # Print logs
         if episode_idx % checkpoint_interval == 0:
@@ -291,12 +330,15 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             if train_params.render:
                 env_renderer.close_window()
 
+            # reset action count
+            action_count = [0] * action_size
+
         print(
             '\r🚂 Episode {}'
-            '\t 🏆 Score: {:.3f}'
-            ' Avg: {:.3f}'
-            '\t 💯 Done: {:.2f}%'
-            ' Avg: {:.2f}%'
+            '\t 🏆 Score: {:7.3f}'
+            ' Avg: {:7.3f}'
+            '\t 💯 Done: {:6.2f}%'
+            ' Avg: {:6.2f}%'
             '\t 🎲 Epsilon: {:.3f} '
             '\t 🔀 Action Probs: {}'.format(
                 episode_idx,
@@ -310,7 +352,11 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
         # Evaluate policy and log results at some interval
         if episode_idx % checkpoint_interval == 0 and n_eval_episodes > 0:
-            scores, completions, nb_steps_eval = eval_policy(eval_env, policy, train_params, obs_params)
+            scores, completions, nb_steps_eval = eval_policy(eval_env,
+                                                             tree_observation,
+                                                             policy,
+                                                             train_params,
+                                                             obs_params)
 
             writer.add_scalar("evaluation/scores_min", np.min(scores), episode_idx)
             writer.add_scalar("evaluation/scores_max", np.max(scores), episode_idx)
@@ -329,7 +375,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             writer.add_histogram("evaluation/nb_steps", np.array(nb_steps_eval), episode_idx)
 
             smoothing = 0.9
-            smoothed_eval_normalized_score = smoothed_eval_normalized_score * smoothing + np.mean(scores) * (1.0 - smoothing)
+            smoothed_eval_normalized_score = smoothed_eval_normalized_score * smoothing + np.mean(scores) * (
+                    1.0 - smoothing)
             smoothed_eval_completion = smoothed_eval_completion * smoothing + np.mean(completions) * (1.0 - smoothing)
             writer.add_scalar("evaluation/smoothed_score", smoothed_eval_normalized_score, episode_idx)
             writer.add_scalar("evaluation/smoothed_completion", smoothed_eval_completion, episode_idx)
@@ -367,7 +414,7 @@ def format_action_prob(action_probs):
     return buffer
 
 
-def eval_policy(env, policy, train_params, obs_params):
+def eval_policy(env, tree_observation, policy, train_params, obs_params):
     n_eval_episodes = train_params.n_evaluation_episodes
     max_steps = env._max_episode_steps
     tree_depth = obs_params.observation_tree_depth
@@ -388,12 +435,14 @@ def eval_policy(env, policy, train_params, obs_params):
 
         for step in range(max_steps - 1):
             for agent in env.get_agent_handles():
-                if obs[agent]:
-                    agent_obs[agent] = normalize_observation(obs[agent], tree_depth=tree_depth, observation_radius=observation_radius)
+                if tree_observation.check_is_observation_valid(agent_obs[agent]):
+                    agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], tree_depth=tree_depth,
+                                                                                   observation_radius=observation_radius)
 
                 action = 0
                 if info['action_required'][agent]:
-                    action = policy.act(agent_obs[agent], eps=0.0)
+                    if tree_observation.check_is_observation_valid(agent_obs[agent]):
+                        action = policy.act(agent_obs[agent], eps=0.0)
                 action_dict.update({agent: action})
 
             obs, all_rewards, done, info = env.step(action_dict)
@@ -424,8 +473,9 @@ if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=2500, type=int)
     parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=0, type=int)
-    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0, type=int)
-    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=25, type=int)
+    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
+                        type=int)
+    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=50, type=int)
     parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
     parser.add_argument("--eps_start", help="max exploration", default=1.0, type=float)
     parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
@@ -433,7 +483,8 @@ if __name__ == "__main__":
     parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e5), type=int)
     parser.add_argument("--buffer_min_size", help="min buffer size to start training", default=0, type=int)
     parser.add_argument("--restore_replay_buffer", help="replay buffer to restore", default="", type=str)
-    parser.add_argument("--save_replay_buffer", help="save replay buffer at each evaluation interval", default=False, type=bool)
+    parser.add_argument("--save_replay_buffer", help="save replay buffer at each evaluation interval", default=False,
+                        type=bool)
     parser.add_argument("--batch_size", help="minibatch size", default=128, type=int)
     parser.add_argument("--gamma", help="discount factor", default=0.99, type=float)
     parser.add_argument("--tau", help="soft update of target parameters", default=1e-3, type=float)
@@ -442,9 +493,12 @@ if __name__ == "__main__":
     parser.add_argument("--update_every", help="how often to update the network", default=8, type=int)
     parser.add_argument("--use_gpu", help="use GPU if available", default=False, type=bool)
     parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=1, type=int)
-    parser.add_argument("--render", help="render 1 episode in 100", default=False, type=bool)
-    training_params = parser.parse_args()
+    parser.add_argument("--render", help="render 1 episode in 100", action='store_true')
+    parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
+    parser.add_argument("--use_extra_observation", help="extra observation", action='store_true')
+    parser.add_argument("--max_depth", help="max depth", default=2, type=int)
 
+    training_params = parser.parse_args()
     env_params = [
         {
             # Test_0
@@ -482,14 +536,16 @@ if __name__ == "__main__":
     ]
 
     obs_params = {
-        "observation_tree_depth": 2,
+        "observation_tree_depth": training_params.max_depth,
         "observation_radius": 10,
         "observation_max_path_depth": 30
     }
 
+
     def check_env_config(id):
         if id >= len(env_params) or id < 0:
-            print("\n🛑 Invalid environment configuration, only Test_0 to Test_{} are supported.".format(len(env_params) - 1))
+            print("\n🛑 Invalid environment configuration, only Test_0 to Test_{} are supported.".format(
+                len(env_params) - 1))
             exit(1)
 
 
@@ -499,6 +555,10 @@ if __name__ == "__main__":
     training_env_params = env_params[training_params.training_env_config]
     evaluation_env_params = env_params[training_params.evaluation_env_config]
 
+    # FIXME hard-coded for sweep search
+    # see https://wb-forum.slack.com/archives/CL4V2QE59/p1602931982236600 to implement properly
+    # training_params.use_extra_observation = True
+
     print("\nTraining parameters:")
     pprint(vars(training_params))
     print("\nTraining environment parameters (Test_{}):".format(training_params.training_env_config))
@@ -509,4 +569,5 @@ if __name__ == "__main__":
     pprint(obs_params)
 
     os.environ["OMP_NUM_THREADS"] = str(training_params.num_threads)
-    train_agent(training_params, Namespace(**training_env_params), Namespace(**evaluation_env_params), Namespace(**obs_params))
+    train_agent(training_params, Namespace(**training_env_params), Namespace(**evaluation_env_params),
+                Namespace(**obs_params))
diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py
new file mode 100755
index 0000000..12c91ca
--- /dev/null
+++ b/utils/fast_tree_obs.py
@@ -0,0 +1,301 @@
+import numpy as np
+from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.rail_env import fast_count_nonzero, fast_argmax
+
+
+"""
+LICENCE for the FastTreeObs Observation Builder  
+
+The observation can be used freely and reused for further submissions. Only the author needs to be referred to
+/mentioned in any submissions - if the entire observation or parts, or the main idea is used.
+
+Author: Adrian Egli (adrian.egli@gmail.com)
+
+[Linkedin](https://www.researchgate.net/profile/Adrian_Egli2)
+[Researchgate](https://www.linkedin.com/in/adrian-egli-733a9544/)
+"""
+
+class FastTreeObs(ObservationBuilder):
+
+    def __init__(self, max_depth):
+        self.max_depth = max_depth
+        self.observation_dim = 26
+
+    def build_data(self):
+        if self.env is not None:
+            self.env.dev_obs_dict = {}
+        self.switches = {}
+        self.switches_neighbours = {}
+        self.debug_render_list = []
+        self.debug_render_path_list = []
+        if self.env is not None:
+            self.find_all_cell_where_agent_can_choose()
+
+    def find_all_cell_where_agent_can_choose(self):
+        switches = {}
+        for h in range(self.env.height):
+            for w in range(self.env.width):
+                pos = (h, w)
+                for dir in range(4):
+                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
+                    num_transitions = fast_count_nonzero(possible_transitions)
+                    if num_transitions > 1:
+                        if pos not in switches.keys():
+                            switches.update({pos: [dir]})
+                        else:
+                            switches[pos].append(dir)
+
+        switches_neighbours = {}
+        for h in range(self.env.height):
+            for w in range(self.env.width):
+                # look one step forward
+                for dir in range(4):
+                    pos = (h, w)
+                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
+                    for d in range(4):
+                        if possible_transitions[d] == 1:
+                            new_cell = get_new_position(pos, d)
+                            if new_cell in switches.keys() and pos not in switches.keys():
+                                if pos not in switches_neighbours.keys():
+                                    switches_neighbours.update({pos: [dir]})
+                                else:
+                                    switches_neighbours[pos].append(dir)
+
+        self.switches = switches
+        self.switches_neighbours = switches_neighbours
+
+    def check_agent_decision(self, position, direction):
+        switches = self.switches
+        switches_neighbours = self.switches_neighbours
+        agents_on_switch = False
+        agents_on_switch_all = False
+        agents_near_to_switch = False
+        agents_near_to_switch_all = False
+        if position in switches.keys():
+            agents_on_switch = direction in switches[position]
+            agents_on_switch_all = True
+
+        if position in switches_neighbours.keys():
+            new_cell = get_new_position(position, direction)
+            if new_cell in switches.keys():
+                if not direction in switches[new_cell]:
+                    agents_near_to_switch = direction in switches_neighbours[position]
+            else:
+                agents_near_to_switch = direction in switches_neighbours[position]
+
+            agents_near_to_switch_all = direction in switches_neighbours[position]
+
+        return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
+
+    def required_agent_decision(self):
+        agents_can_choose = {}
+        agents_on_switch = {}
+        agents_on_switch_all = {}
+        agents_near_to_switch = {}
+        agents_near_to_switch_all = {}
+        for a in range(self.env.get_num_agents()):
+            ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all, ret_agents_on_switch_all = \
+                self.check_agent_decision(
+                    self.env.agents[a].position,
+                    self.env.agents[a].direction)
+            agents_on_switch.update({a: ret_agents_on_switch})
+            agents_on_switch_all.update({a: ret_agents_on_switch_all})
+            ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART
+            agents_near_to_switch.update({a: (ret_agents_near_to_switch and not ready_to_depart)})
+
+            agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]})
+
+            agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all and not ready_to_depart)})
+
+        return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
+
+    def debug_render(self, env_renderer):
+        agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \
+            self.required_agent_decision()
+        self.env.dev_obs_dict = {}
+        for a in range(max(3, self.env.get_num_agents())):
+            self.env.dev_obs_dict.update({a: []})
+
+        selected_agent = None
+        if agents_can_choose[0]:
+            if self.env.agents[0].position is not None:
+                self.debug_render_list.append(self.env.agents[0].position)
+            else:
+                self.debug_render_list.append(self.env.agents[0].initial_position)
+
+        if self.env.agents[0].position is not None:
+            self.debug_render_path_list.append(self.env.agents[0].position)
+        else:
+            self.debug_render_path_list.append(self.env.agents[0].initial_position)
+
+        env_renderer.gl.agent_colors[0] = env_renderer.gl.rgb_s2i("FF0000")
+        env_renderer.gl.agent_colors[1] = env_renderer.gl.rgb_s2i("666600")
+        env_renderer.gl.agent_colors[2] = env_renderer.gl.rgb_s2i("006666")
+        env_renderer.gl.agent_colors[3] = env_renderer.gl.rgb_s2i("550000")
+
+        self.env.dev_obs_dict[0] = self.debug_render_list
+        self.env.dev_obs_dict[1] = self.switches.keys()
+        self.env.dev_obs_dict[2] = self.switches_neighbours.keys()
+        self.env.dev_obs_dict[3] = self.debug_render_path_list
+
+    def reset(self):
+        self.build_data()
+        return
+
+    def fast_argmax(self, array):
+        if array[0] == 1:
+            return 0
+        if array[1] == 1:
+            return 1
+        if array[2] == 1:
+            return 2
+        return 3
+
+    def _explore(self, handle, new_position, new_direction, depth=0):
+        has_opp_agent = 0
+        has_same_agent = 0
+        has_switch = 0
+        visited = []
+
+        # stop exploring (max_depth reached)
+        if depth >= self.max_depth:
+            return has_opp_agent, has_same_agent, has_switch, visited
+
+        # max_explore_steps = 100
+        cnt = 0
+        while cnt < 100:
+            cnt += 1
+
+            visited.append(new_position)
+            opp_a = self.env.agent_positions[new_position]
+            if opp_a != -1 and opp_a != handle:
+                if self.env.agents[opp_a].direction != new_direction:
+                    # opp agent found
+                    has_opp_agent = 1
+                    return has_opp_agent, has_same_agent, has_switch, visited
+                else:
+                    has_same_agent = 1
+                    return has_opp_agent, has_same_agent, has_switch, visited
+
+            # convert one-hot encoding to 0,1,2,3
+            agents_on_switch, \
+            agents_near_to_switch, \
+            agents_near_to_switch_all, \
+            agents_on_switch_all = \
+                self.check_agent_decision(new_position, new_direction)
+            if agents_near_to_switch:
+                return has_opp_agent, has_same_agent, has_switch, visited
+
+            possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
+            if agents_on_switch:
+                f = 0
+                for dir_loop in range(4):
+                    if possible_transitions[dir_loop] == 1:
+                        f += 1
+                        hoa, hsa, hs, v = self._explore(handle,
+                                                        get_new_position(new_position, dir_loop),
+                                                        dir_loop,
+                                                        depth + 1)
+                        visited.append(v)
+                        has_opp_agent += hoa
+                        has_same_agent += hsa
+                        has_switch += hs
+                f = max(f, 1.0)
+                return has_opp_agent / f, has_same_agent / f, has_switch / f, visited
+            else:
+                new_direction = fast_argmax(possible_transitions)
+                new_position = get_new_position(new_position, new_direction)
+
+        return has_opp_agent, has_same_agent, has_switch, visited
+
+    def get(self, handle):
+        # all values are [0,1]
+        # observation[0]  : 1 path towards target (direction 0) / otherwise 0 -> path is longer or there is no path
+        # observation[1]  : 1 path towards target (direction 1) / otherwise 0 -> path is longer or there is no path
+        # observation[2]  : 1 path towards target (direction 2) / otherwise 0 -> path is longer or there is no path
+        # observation[3]  : 1 path towards target (direction 3) / otherwise 0 -> path is longer or there is no path
+        # observation[4]  : int(agent.status == RailAgentStatus.READY_TO_DEPART)
+        # observation[5]  : int(agent.status == RailAgentStatus.ACTIVE)
+        # observation[6]  : int(agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED)
+        # observation[7]  : current agent is located at a switch, where it can take a routing decision
+        # observation[8]  : current agent is located at a cell, where it has to take a stop-or-go decision
+        # observation[9]  : current agent is located one step before/after a switch
+        # observation[10] : 1 if there is a path (track/branch) otherwise 0 (direction 0)
+        # observation[11] : 1 if there is a path (track/branch) otherwise 0 (direction 1)
+        # observation[12] : 1 if there is a path (track/branch) otherwise 0 (direction 2)
+        # observation[13] : 1 if there is a path (track/branch) otherwise 0 (direction 3)
+        # observation[14] : If there is a path with step (direction 0) and there is a agent with opposite direction -> 1
+        # observation[15] : If there is a path with step (direction 1) and there is a agent with opposite direction -> 1
+        # observation[16] : If there is a path with step (direction 2) and there is a agent with opposite direction -> 1
+        # observation[17] : If there is a path with step (direction 3) and there is a agent with opposite direction -> 1
+        # observation[18] : If there is a path with step (direction 0) and there is a agent with same direction -> 1
+        # observation[19] : If there is a path with step (direction 1) and there is a agent with same direction -> 1
+        # observation[20] : If there is a path with step (direction 2) and there is a agent with same direction -> 1
+        # observation[21] : If there is a path with step (direction 3) and there is a agent with same direction -> 1
+        # observation[22] : If there is a switch on the path which agent can not use -> 1
+        # observation[23] : If there is a switch on the path which agent can not use -> 1
+        # observation[24] : If there is a switch on the path which agent can not use -> 1
+        # observation[25] : If there is a switch on the path which agent can not use -> 1
+
+        observation = np.zeros(self.observation_dim)
+        visited = []
+        agent = self.env.agents[handle]
+
+        agent_done = False
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            agent_virtual_position = agent.initial_position
+            observation[4] = 1
+        elif agent.status == RailAgentStatus.ACTIVE:
+            agent_virtual_position = agent.position
+            observation[5] = 1
+        else:
+            observation[6] = 1
+            agent_virtual_position = (-1, -1)
+            agent_done = True
+
+        if not agent_done:
+            visited.append(agent_virtual_position)
+            distance_map = self.env.distance_map.get()
+            current_cell_dist = distance_map[handle,
+                                             agent_virtual_position[0], agent_virtual_position[1],
+                                             agent.direction]
+            possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
+            orientation = agent.direction
+            if fast_count_nonzero(possible_transitions) == 1:
+                orientation = fast_argmax(possible_transitions)
+
+            for dir_loop, branch_direction in enumerate([(orientation + dir_loop) % 4 for dir_loop in range(-1, 3)]):
+                if possible_transitions[branch_direction]:
+                    new_position = get_new_position(agent_virtual_position, branch_direction)
+                    new_cell_dist = distance_map[handle,
+                                                 new_position[0], new_position[1],
+                                                 branch_direction]
+                    if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)):
+                        observation[dir_loop] = int(new_cell_dist < current_cell_dist)
+
+                    has_opp_agent, has_same_agent, has_switch, v = self._explore(handle, new_position, branch_direction)
+                    visited.append(v)
+
+                    observation[10 + dir_loop] = 1
+                    observation[14 + dir_loop] = has_opp_agent
+                    observation[18 + dir_loop] = has_same_agent
+                    observation[22 + dir_loop] = has_switch
+
+        agents_on_switch, \
+        agents_near_to_switch, \
+        agents_near_to_switch_all, \
+        agents_on_switch_all = \
+            self.check_agent_decision(agent_virtual_position, agent.direction)
+        observation[7] = int(agents_on_switch)
+        observation[8] = int(agents_near_to_switch)
+        observation[9] = int(agents_near_to_switch_all)
+
+        self.env.dev_obs_dict.update({handle: visited})
+
+        return observation
+
+    @staticmethod
+    def agent_can_choose(observation):
+        return observation[7] == 1 or observation[8] == 1
-- 
GitLab