From d3479be58d63a004f4254a30fd95efb715ba26d9 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Thu, 26 Nov 2020 09:00:34 +0100
Subject: [PATCH] Remove utils/extra.py; replace ExtraPolicy with DeadLockAvoidanceAgent in MultiPolicy

---
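Notes: utils/extra.py provided ExtraPolicy (a hand-coded rule policy) and the
Extra observation builder. MultiPolicy now delegates to DeadLockAvoidanceAgent
instead of ExtraPolicy, so the module can be deleted. Illustrative usage after
this patch, as a minimal sketch: the MultiPolicy constructor signature is
assumed from the hunk below, and `state_size`, `my_rail_env` and `obs` are
placeholder names, not identifiers from this repository.

    from reinforcement_learning.multi_policy import MultiPolicy

    # Assumed signature: MultiPolicy(state_size, action_size, n_agents, env).
    # action_size=5 matches the value used elsewhere in this codebase.
    policy = MultiPolicy(state_size, 5, my_rail_env.get_num_agents(), my_rail_env)
    policy.start_step()
    for handle in range(my_rail_env.get_num_agents()):
        action = policy.act(handle, obs[handle], eps=0.1)
    policy.end_step()
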
 reinforcement_learning/multi_policy.py |  23 +-
 utils/dead_lock_avoidance_agent.py     |   2 +-
 utils/extra.py                         | 366 -------------------------
 3 files changed, 12 insertions(+), 379 deletions(-)
 delete mode 100644 utils/extra.py

diff --git a/reinforcement_learning/multi_policy.py b/reinforcement_learning/multi_policy.py
index 765bcf5..52183c7 100644
--- a/reinforcement_learning/multi_policy.py
+++ b/reinforcement_learning/multi_policy.py
@@ -4,7 +4,6 @@ from flatland.envs.rail_env import RailEnvActions
 from reinforcement_learning.policy import Policy
 from reinforcement_learning.ppo.ppo_agent import PPOAgent
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
-from utils.extra import ExtraPolicy
 
 
 class MultiPolicy(Policy):
@@ -13,20 +12,20 @@ class MultiPolicy(Policy):
         self.action_size = action_size
         self.memory = []
         self.loss = 0
-        self.extra_policy = ExtraPolicy(state_size, action_size)
+        self.deadlock_avoidance_policy = DeadLockAvoidanceAgent(env, action_size, False)
         self.ppo_policy = PPOAgent(state_size + action_size, action_size, n_agents, env)
 
     def load(self, filename):
         self.ppo_policy.load(filename)
-        self.extra_policy.load(filename)
+        self.deadlock_avoidance_policy.load(filename)
 
     def save(self, filename):
         self.ppo_policy.save(filename)
-        self.extra_policy.save(filename)
+        self.deadlock_avoidance_policy.save(filename)
 
     def step(self, handle, state, action, reward, next_state, done):
-        action_extra_state = self.extra_policy.act(handle, state, 0.0)
-        action_extra_next_state = self.extra_policy.act(handle, next_state, 0.0)
+        action_extra_state = self.deadlock_avoidance_policy.act(handle, state, 0.0)
+        action_extra_next_state = self.deadlock_avoidance_policy.act(handle, next_state, 0.0)
 
         extended_state = np.copy(state)
         for action_itr in np.arange(self.action_size):
@@ -35,11 +34,11 @@ class MultiPolicy(Policy):
         for action_itr in np.arange(self.action_size):
             extended_next_state = np.append(extended_next_state, [int(action_extra_next_state == action_itr)])
 
-        self.extra_policy.step(handle, state, action, reward, next_state, done)
+        self.deadlock_avoidance_policy.step(handle, state, action, reward, next_state, done)
         self.ppo_policy.step(handle, extended_state, action, reward, extended_next_state, done)
 
     def act(self, handle, state, eps=0.):
-        action_extra_state = self.extra_policy.act(handle, state, 0.0)
+        action_extra_state = self.deadlock_avoidance_policy.act(handle, state, 0.0)
         extended_state = np.copy(state)
         for action_itr in np.arange(self.action_size):
             extended_state = np.append(extended_state, [int(action_extra_state == action_itr)])
@@ -49,16 +48,16 @@ class MultiPolicy(Policy):
 
     def reset(self):
         self.ppo_policy.reset()
-        self.extra_policy.reset()
+        self.deadlock_avoidance_policy.reset()
 
     def test(self):
         self.ppo_policy.test()
-        self.extra_policy.test()
+        self.deadlock_avoidance_policy.test()
 
     def start_step(self):
-        self.extra_policy.start_step()
+        self.deadlock_avoidance_policy.start_step()
         self.ppo_policy.start_step()
 
     def end_step(self):
-        self.extra_policy.end_step()
+        self.deadlock_avoidance_policy.end_step()
         self.ppo_policy.end_step()
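
Note on the hunks above: step() and act() both extend the raw state with a
one-hot encoding of the deadlock-avoidance action before handing it to PPO,
which is why PPOAgent is constructed with state_size + action_size inputs. A
minimal standalone sketch of that extension (a hypothetical helper, not part
of this patch; assumes `state` is a 1-D numpy array and the action is an
integer index):

    import numpy as np

    def extend_state(state, delegate_action, action_size):
        # Append one flag per action: 1 where the delegate policy's chosen
        # action matches, 0 elsewhere, mirroring the loops in step()/act().
        one_hot = np.array([int(delegate_action == a) for a in range(action_size)])
        return np.append(np.copy(state), one_hot)
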
diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py
index 4a37135..37dcd0d 100644
--- a/utils/dead_lock_avoidance_agent.py
+++ b/utils/dead_lock_avoidance_agent.py
@@ -67,7 +67,7 @@ class DeadlockAvoidanceShortestDistanceWalker(ShortestDistanceWalker):
 
 
 class DeadLockAvoidanceAgent(Policy):
-    def __init__(self, env: RailEnv, action_size, show_debug_plot=False):
+    def __init__(self, env: RailEnv, action_size, show_debug_plot=False):
         self.env = env
         self.memory = None
         self.loss = 0
diff --git a/utils/extra.py b/utils/extra.py
deleted file mode 100644
index 03cd4f9..0000000
--- a/utils/extra.py
+++ /dev/null
@@ -1,366 +0,0 @@
-import numpy as np
-from flatland.core.env_observation_builder import ObservationBuilder
-from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.agent_utils import RailAgentStatus
-from flatland.envs.rail_env import RailEnvActions, fast_argmax, fast_count_nonzero
-
-from reinforcement_learning.policy import Policy
-from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent, DeadlockAvoidanceShortestDistanceWalker
-
-
-class ExtraPolicy(Policy):
-    def __init__(self, state_size, action_size):
-        self.state_size = state_size
-        self.action_size = action_size
-        self.memory = []
-        self.loss = 0
-
-    def load(self, filename):
-        pass
-
-    def save(self, filename):
-        pass
-
-    def step(self, handle, state, action, reward, next_state, done):
-        pass
-
-    def act(self, handle, state, eps=0.):
-        a = 0
-        b = 4
-        action = RailEnvActions.STOP_MOVING
-        if state[2] == 1 and state[10 + a] == 0:
-            action = RailEnvActions.MOVE_LEFT
-        elif state[3] == 1 and state[11 + a] == 0:
-            action = RailEnvActions.MOVE_FORWARD
-        elif state[4] == 1 and state[12 + a] == 0:
-            action = RailEnvActions.MOVE_RIGHT
-        elif state[5] == 1 and state[13 + a] == 0:
-            action = RailEnvActions.MOVE_FORWARD
-
-        elif state[6] == 1 and state[10 + b] == 0:
-            action = RailEnvActions.MOVE_LEFT
-        elif state[7] == 1 and state[11 + b] == 0:
-            action = RailEnvActions.MOVE_FORWARD
-        elif state[8] == 1 and state[12 + b] == 0:
-            action = RailEnvActions.MOVE_RIGHT
-        elif state[9] == 1 and state[13 + b] == 0:
-            action = RailEnvActions.MOVE_FORWARD
-
-        return action
-
-    def test(self):
-        pass
-
-
-class Extra(ObservationBuilder):
-
-    def __init__(self, max_depth):
-        self.max_depth = max_depth
-        self.observation_dim = 31
-
-    def build_data(self):
-        self.dead_lock_avoidance_agent = None
-        if self.env is not None:
-            self.env.dev_obs_dict = {}
-            self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, 5, False)
-
-        self.switches = {}
-        self.switches_neighbours = {}
-        self.debug_render_list = []
-        self.debug_render_path_list = []
-        if self.env is not None:
-            self.find_all_cell_where_agent_can_choose()
-
-    def find_all_cell_where_agent_can_choose(self):
-
-        switches = {}
-        for h in range(self.env.height):
-            for w in range(self.env.width):
-                pos = (h, w)
-                for dir in range(4):
-                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
-                    num_transitions = fast_count_nonzero(possible_transitions)
-                    if num_transitions > 1:
-                        if pos not in switches.keys():
-                            switches.update({pos: [dir]})
-                        else:
-                            switches[pos].append(dir)
-
-        switches_neighbours = {}
-        for h in range(self.env.height):
-            for w in range(self.env.width):
-                # look one step forward
-                for dir in range(4):
-                    pos = (h, w)
-                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
-                    for d in range(4):
-                        if possible_transitions[d] == 1:
-                            new_cell = get_new_position(pos, d)
-                            if new_cell in switches.keys() and pos not in switches.keys():
-                                if pos not in switches_neighbours.keys():
-                                    switches_neighbours.update({pos: [dir]})
-                                else:
-                                    switches_neighbours[pos].append(dir)
-
-        self.switches = switches
-        self.switches_neighbours = switches_neighbours
-
-    def check_agent_descision(self, position, direction):
-        switches = self.switches
-        switches_neighbours = self.switches_neighbours
-        agents_on_switch = False
-        agents_on_switch_all = False
-        agents_near_to_switch = False
-        agents_near_to_switch_all = False
-        if position in switches.keys():
-            agents_on_switch = direction in switches[position]
-            agents_on_switch_all = True
-
-        if position in switches_neighbours.keys():
-            new_cell = get_new_position(position, direction)
-            if new_cell in switches.keys():
-                if not direction in switches[new_cell]:
-                    agents_near_to_switch = direction in switches_neighbours[position]
-            else:
-                agents_near_to_switch = direction in switches_neighbours[position]
-
-            agents_near_to_switch_all = direction in switches_neighbours[position]
-
-        return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
-
-    def required_agent_descision(self):
-        agents_can_choose = {}
-        agents_on_switch = {}
-        agents_on_switch_all = {}
-        agents_near_to_switch = {}
-        agents_near_to_switch_all = {}
-        for a in range(self.env.get_num_agents()):
-            ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all, ret_agents_on_switch_all = \
-                self.check_agent_descision(
-                    self.env.agents[a].position,
-                    self.env.agents[a].direction)
-            agents_on_switch.update({a: ret_agents_on_switch})
-            agents_on_switch_all.update({a: ret_agents_on_switch_all})
-            ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART
-            agents_near_to_switch.update({a: (ret_agents_near_to_switch and not ready_to_depart)})
-
-            agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]})
-
-            agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all and not ready_to_depart)})
-
-        return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
-
-    def debug_render(self, env_renderer):
-        agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \
-            self.required_agent_descision()
-        self.env.dev_obs_dict = {}
-        for a in range(max(3, self.env.get_num_agents())):
-            self.env.dev_obs_dict.update({a: []})
-
-        selected_agent = None
-        if agents_can_choose[0]:
-            if self.env.agents[0].position is not None:
-                self.debug_render_list.append(self.env.agents[0].position)
-            else:
-                self.debug_render_list.append(self.env.agents[0].initial_position)
-
-        if self.env.agents[0].position is not None:
-            self.debug_render_path_list.append(self.env.agents[0].position)
-        else:
-            self.debug_render_path_list.append(self.env.agents[0].initial_position)
-
-        env_renderer.gl.agent_colors[0] = env_renderer.gl.rgb_s2i("FF0000")
-        env_renderer.gl.agent_colors[1] = env_renderer.gl.rgb_s2i("666600")
-        env_renderer.gl.agent_colors[2] = env_renderer.gl.rgb_s2i("006666")
-        env_renderer.gl.agent_colors[3] = env_renderer.gl.rgb_s2i("550000")
-
-        self.env.dev_obs_dict[0] = self.debug_render_list
-        self.env.dev_obs_dict[1] = self.switches.keys()
-        self.env.dev_obs_dict[2] = self.switches_neighbours.keys()
-        self.env.dev_obs_dict[3] = self.debug_render_path_list
-
-    def reset(self):
-        self.build_data()
-        return
-
-
-    def _check_dead_lock_at_branching_position(self, handle, new_position, branch_direction):
-        _, full_shortest_distance_agent_map = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
-        opp_agents = self.dead_lock_avoidance_agent.shortest_distance_walker.opp_agent_map.get(handle, [])
-        same_agents = self.dead_lock_avoidance_agent.shortest_distance_walker.same_agent_map.get(handle,[])
-        local_walker = DeadlockAvoidanceShortestDistanceWalker(
-            self.env,
-            self.dead_lock_avoidance_agent.shortest_distance_walker.agent_positions,
-            self.dead_lock_avoidance_agent.shortest_distance_walker.switches)
-        local_walker.walk_to_target(handle, new_position, branch_direction)
-        shortest_distance_agent_map, _ = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
-        my_shortest_path_to_check = shortest_distance_agent_map[handle]
-        next_step_ok = self.dead_lock_avoidance_agent.check_agent_can_move(handle,
-                                                                           my_shortest_path_to_check,
-                                                                           opp_agents,
-                                                                           same_agents,
-                                                                           full_shortest_distance_agent_map)
-        return next_step_ok
-
-    def _explore(self, handle, new_position, new_direction, depth=0):
-
-        has_opp_agent = 0
-        has_same_agent = 0
-        has_switch = 0
-        visited = []
-
-        # stop exploring (max_depth reached)
-        if depth >= self.max_depth:
-            return has_opp_agent, has_same_agent, has_switch, visited
-
-        # max_explore_steps = 100
-        cnt = 0
-        while cnt < 100:
-            cnt += 1
-
-            visited.append(new_position)
-            opp_a = self.env.agent_positions[new_position]
-            if opp_a != -1 and opp_a != handle:
-                if self.env.agents[opp_a].direction != new_direction:
-                    # opp agent found
-                    has_opp_agent = 1
-                    return has_opp_agent, has_same_agent, has_switch, visited
-                else:
-                    has_same_agent = 1
-                    return has_opp_agent, has_same_agent, has_switch, visited
-
-            # convert one-hot encoding to 0,1,2,3
-            agents_on_switch, \
-            agents_near_to_switch, \
-            agents_near_to_switch_all, \
-            agents_on_switch_all = \
-                self.check_agent_descision(new_position, new_direction)
-            if agents_near_to_switch:
-                return has_opp_agent, has_same_agent, has_switch, visited
-
-            possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
-            if agents_on_switch:
-                f = 0
-                for dir_loop in range(4):
-                    if possible_transitions[dir_loop] == 1:
-                        f += 1
-                        hoa, hsa, hs, v = self._explore(handle,
-                                                        get_new_position(new_position, dir_loop),
-                                                        dir_loop,
-                                                        depth + 1)
-                        visited.append(v)
-                        has_opp_agent += hoa
-                        has_same_agent += hsa
-                        has_switch += hs
-                f = max(f, 1.0)
-                return has_opp_agent / f, has_same_agent / f, has_switch / f, visited
-            else:
-                new_direction = fast_argmax(possible_transitions)
-                new_position = get_new_position(new_position, new_direction)
-
-        return has_opp_agent, has_same_agent, has_switch, visited
-
-    def get(self, handle):
-
-        if handle == 0:
-            self.dead_lock_avoidance_agent.start_step()
-
-        # all values are [0,1]
-        # observation[0]  : 1 path towards target (direction 0) / otherwise 0 -> path is longer or there is no path
-        # observation[1]  : 1 path towards target (direction 1) / otherwise 0 -> path is longer or there is no path
-        # observation[2]  : 1 path towards target (direction 2) / otherwise 0 -> path is longer or there is no path
-        # observation[3]  : 1 path towards target (direction 3) / otherwise 0 -> path is longer or there is no path
-        # observation[4]  : int(agent.status == RailAgentStatus.READY_TO_DEPART)
-        # observation[5]  : int(agent.status == RailAgentStatus.ACTIVE)
-        # observation[6]  : int(agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED)
-        # observation[7]  : current agent is located at a switch, where it can take a routing decision
-        # observation[8]  : current agent is located at a cell, where it has to take a stop-or-go decision
-        # observation[9]  : current agent is located one step before/after a switch
-        # observation[10] : 1 if there is a path (track/branch) otherwise 0 (direction 0)
-        # observation[11] : 1 if there is a path (track/branch) otherwise 0 (direction 1)
-        # observation[12] : 1 if there is a path (track/branch) otherwise 0 (direction 2)
-        # observation[13] : 1 if there is a path (track/branch) otherwise 0 (direction 3)
-        # observation[14] : If there is a path with step (direction 0) and there is a agent with opposite direction -> 1
-        # observation[15] : If there is a path with step (direction 1) and there is a agent with opposite direction -> 1
-        # observation[16] : If there is a path with step (direction 2) and there is a agent with opposite direction -> 1
-        # observation[17] : If there is a path with step (direction 3) and there is a agent with opposite direction -> 1
-        # observation[18] : If there is a path with step (direction 0) and there is a agent with same direction -> 1
-        # observation[19] : If there is a path with step (direction 1) and there is a agent with same direction -> 1
-        # observation[20] : If there is a path with step (direction 2) and there is a agent with same direction -> 1
-        # observation[21] : If there is a path with step (direction 3) and there is a agent with same direction -> 1
-        # observation[22] : If there is a switch on the path which agent can not use -> 1
-        # observation[23] : If there is a switch on the path which agent can not use -> 1
-        # observation[24] : If there is a switch on the path which agent can not use -> 1
-        # observation[25] : If there is a switch on the path which agent can not use -> 1
-        # observation[26] : Is there a deadlock signal on shortest path walk(s) (direction 0)-> 1
-        # observation[27] : Is there a deadlock signal on shortest path walk(s) (direction 1)-> 1
-        # observation[28] : Is there a deadlock signal on shortest path walk(s) (direction 2)-> 1
-        # observation[29] : Is there a deadlock signal on shortest path walk(s) (direction 3)-> 1
-        # observation[30] : Is there a deadlock signal on shortest path walk(s) (current position check)-> 1
-
-        observation = np.zeros(self.observation_dim)
-        visited = []
-        agent = self.env.agents[handle]
-
-        agent_done = False
-        if agent.status == RailAgentStatus.READY_TO_DEPART:
-            agent_virtual_position = agent.initial_position
-            observation[4] = 1
-        elif agent.status == RailAgentStatus.ACTIVE:
-            agent_virtual_position = agent.position
-            observation[5] = 1
-        else:
-            observation[6] = 1
-            agent_virtual_position = (-1, -1)
-            agent_done = True
-
-        if not agent_done:
-            visited.append(agent_virtual_position)
-            distance_map = self.env.distance_map.get()
-            current_cell_dist = distance_map[handle,
-                                             agent_virtual_position[0], agent_virtual_position[1],
-                                             agent.direction]
-            possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
-            orientation = agent.direction
-            if fast_count_nonzero(possible_transitions) == 1:
-                orientation = fast_argmax(possible_transitions)
-
-            for dir_loop, branch_direction in enumerate([(orientation + dir_loop) % 4 for dir_loop in range(-1, 3)]):
-                if possible_transitions[branch_direction]:
-                    new_position = get_new_position(agent_virtual_position, branch_direction)
-                    new_cell_dist = distance_map[handle,
-                                                 new_position[0], new_position[1],
-                                                 branch_direction]
-                    if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)):
-                        observation[dir_loop] = int(new_cell_dist < current_cell_dist)
-
-                    has_opp_agent, has_same_agent, has_switch, v = self._explore(handle, new_position, branch_direction)
-                    visited.append(v)
-
-                    observation[10 + dir_loop] = 1
-                    observation[14 + dir_loop] = has_opp_agent
-                    observation[18 + dir_loop] = has_same_agent
-                    observation[22 + dir_loop] = has_switch
-
-                    next_step_ok = self._check_dead_lock_at_branching_position(handle, new_position, branch_direction)
-                    if next_step_ok:
-                        observation[26 + dir_loop] = 1
-
-        agents_on_switch, \
-        agents_near_to_switch, \
-        agents_near_to_switch_all, \
-        agents_on_switch_all = \
-            self.check_agent_descision(agent_virtual_position, agent.direction)
-        observation[7] = int(agents_on_switch)
-        observation[8] = int(agents_near_to_switch)
-        observation[9] = int(agents_near_to_switch_all)
-
-        observation[30] = int(self.dead_lock_avoidance_agent.act(handle, None, 0) == RailEnvActions.STOP_MOVING)
-
-        self.env.dev_obs_dict.update({handle: visited})
-
-        return observation
-
-    @staticmethod
-    def agent_can_choose(observation):
-        return observation[7] == 1 or observation[8] == 1
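
Note: the comment block in the deleted Extra.get() documents a 31-dimensional
observation layout. For anyone still consuming saved observations of that
shape, a hypothetical decoder follows; the index ranges are taken directly
from those comments, while the field names are illustrative and not from the
codebase.

    def decode_extra_observation(obs):
        # obs: 1-D array-like of length 31 produced by the removed Extra builder.
        assert len(obs) == 31
        return {
            "shorter_path_towards_target": obs[0:4],   # directions 0..3
            "status": obs[4:7],                        # ready-to-depart, active, done
            "on_decision_switch": obs[7],
            "stop_or_go_cell": obs[8],
            "near_switch": obs[9],
            "branch_exists": obs[10:14],               # directions 0..3
            "opp_agent_on_branch": obs[14:18],
            "same_agent_on_branch": obs[18:22],
            "unusable_switch_on_branch": obs[22:26],
            "deadlock_on_branch": obs[26:30],
            "deadlock_at_current_position": obs[30],
        }
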
-- 
GitLab