Commit e28c57e5 authored by Egli Adrian (IT-SCI-API-PFI)

Policy updated

parent 388822a0
@@ -2,8 +2,8 @@ from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from reinforcement_learning.policy import HybridPolicy
+from reinforcement_learning.ppo_agent import PPOPolicy
 from utils.agent_action_config import map_rail_env_action
-from utils.agent_can_choose_helper import AgentCanChooseHelper
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
@@ -17,69 +17,69 @@ class DeadLockAvoidanceWithDecisionAgent(HybridPolicy):
         self.action_size = action_size
         self.learning_agent = learning_agent
         self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, action_size, False)
-        self.agent_can_choose_helper = AgentCanChooseHelper()
+        self.policy_selector = PPOPolicy(state_size, 2)
         self.memory = self.learning_agent.memory
         self.loss = self.learning_agent.loss

     def step(self, handle, state, action, reward, next_state, done):
+        select = self.policy_selector.act(handle, state, 0.0)
+        self.policy_selector.step(handle, state, select, reward, next_state, done)
         self.dead_lock_avoidance_agent.step(handle, state, action, reward, next_state, done)
         self.learning_agent.step(handle, state, action, reward, next_state, done)
         self.loss = self.learning_agent.loss

     def act(self, handle, state, eps=0.):
-        agent = self.env.agents[handle]
-        position = agent.position
-        if position is None:
-            position = agent.initial_position
-        direction = agent.direction
-        if agent.status < RailAgentStatus.DONE:
-            agents_on_switch, agents_near_to_switch, _, _ = \
-                self.agent_can_choose_helper.check_agent_decision(position, direction)
-            if agents_on_switch or agents_near_to_switch:
-                return self.learning_agent.act(handle, state, eps)
-            else:
-                act = self.dead_lock_avoidance_agent.act(handle, state, -1.0)
-                return map_rail_env_action(act)
-        # Agent is still at target cell
-        return map_rail_env_action(RailEnvActions.DO_NOTHING)
+        select = self.policy_selector.act(handle, state, eps)
+        if select == 0:
+            return self.learning_agent.act(handle, state, eps)
+        return self.dead_lock_avoidance_agent.act(handle, state, -1.0)

     def save(self, filename):
         self.dead_lock_avoidance_agent.save(filename)
         self.learning_agent.save(filename)
+        self.policy_selector.save(filename + '.selector')

     def load(self, filename):
         self.dead_lock_avoidance_agent.load(filename)
         self.learning_agent.load(filename)
+        self.policy_selector.load(filename + '.selector')

     def start_step(self, train):
         self.dead_lock_avoidance_agent.start_step(train)
         self.learning_agent.start_step(train)
+        self.policy_selector.start_step(train)

     def end_step(self, train):
         self.dead_lock_avoidance_agent.end_step(train)
         self.learning_agent.end_step(train)
+        self.policy_selector.end_step(train)

     def start_episode(self, train):
         self.dead_lock_avoidance_agent.start_episode(train)
         self.learning_agent.start_episode(train)
+        self.policy_selector.start_episode(train)

     def end_episode(self, train):
         self.dead_lock_avoidance_agent.end_episode(train)
         self.learning_agent.end_episode(train)
+        self.policy_selector.end_episode(train)

     def load_replay_buffer(self, filename):
         self.dead_lock_avoidance_agent.load_replay_buffer(filename)
         self.learning_agent.load_replay_buffer(filename)
+        self.policy_selector.load_replay_buffer(filename + ".selector")

     def test(self):
         self.dead_lock_avoidance_agent.test()
         self.learning_agent.test()
+        self.policy_selector.test()

     def reset(self, env: RailEnv):
         self.env = env
-        self.agent_can_choose_helper.build_data(env)
         self.dead_lock_avoidance_agent.reset(env)
         self.learning_agent.reset(env)
+        self.policy_selector.reset(env)

     def clone(self):
         return self
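The new act() is a plain two-way dispatch: the PPO selector picks which sub-policy answers for a handle, select == 0 routes to the learning agent, and anything else goes to the dead-lock avoidance agent (queried with eps = -1.0). A self-contained sketch of that dispatch pattern, using stub policies rather than the repo classes:

# Self-contained sketch of the selector-based dispatch in act() above.
# StubSelector/StubPolicy stand in for PPOPolicy, the learning agent and DeadLockAvoidanceAgent.
import random


class StubSelector:
    def act(self, handle, state, eps):
        # 0 -> learning agent, anything else -> dead-lock avoidance agent
        return random.randint(0, 1)


class StubPolicy:
    def __init__(self, name):
        self.name = name

    def act(self, handle, state, eps):
        return f"{self.name}-action"


def hybrid_act(selector, learning_agent, dead_lock_avoidance_agent, handle, state, eps=0.0):
    select = selector.act(handle, state, eps)
    if select == 0:
        return learning_agent.act(handle, state, eps)
    return dead_lock_avoidance_agent.act(handle, state, -1.0)


print(hybrid_act(StubSelector(), StubPolicy("learning"), StubPolicy("deadlock-avoidance"),
                 handle=0, state=None))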
@@ -178,6 +178,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     if False:
         policy = DeadLockAvoidanceAgent(train_env, get_action_size())
     if True:
+        # inter_policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
         inter_policy = DDDQNPolicy(state_size, get_action_size(), train_params)
         policy = DeadLockAvoidanceWithDecisionAgent(train_env, state_size, get_action_size(), inter_policy)
     if False:
@@ -234,7 +235,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         # Reset environment
         reset_timer.start()
-        number_of_agents = n_agents  # int(min(n_agents, 1 + np.floor(episode_idx / 500)))
+        number_of_agents = int(min(n_agents, 1 + np.floor(episode_idx / 200)))
         train_env_params.n_agents = episode_idx % number_of_agents + 1
         train_env = create_rail_env(train_env_params, tree_observation)
...
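The replaced line turns the fixed agent count into a curriculum: the cap grows by one every 200 episodes up to n_agents, and train_env_params.n_agents then cycles through 1 .. number_of_agents as episodes advance. A minimal sketch of the resulting schedule, assuming n_agents = 5:

# Sketch of the agent-count curriculum introduced above (n_agents = 5 is an assumption).
import numpy as np

n_agents = 5
for episode_idx in (0, 199, 200, 400, 801, 1000):
    number_of_agents = int(min(n_agents, 1 + np.floor(episode_idx / 200)))
    env_n_agents = episode_idx % number_of_agents + 1
    print(episode_idx, number_of_agents, env_n_agents)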
@@ -22,9 +22,9 @@ class MultiDecisionAgent(LearningPolicy):

     def step(self, handle, state, action, reward, next_state, done):
+        select = self.policy_selector.act(handle, state, 0.0)
         self.ppo_policy.step(handle, state, action, reward, next_state, done)
         self.dddqn_policy.step(handle, state, action, reward, next_state, done)
-        select = self.policy_selector.act(handle, state, 0.0)
         self.policy_selector.step(handle, state, select, reward, next_state, done)

     def act(self, handle, state, eps=0.):
...
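As in DeadLockAvoidanceWithDecisionAgent.step() above, the selector is trained as an ordinary policy over the same transition stream: its action is the index of the chosen sub-policy, and it is credited with the environment reward and done flag for that choice. A minimal sketch of the call pattern with a stub selector (illustrative only, not the repo's PPOPolicy):

# Sketch: feeding the policy selector the same transitions as the sub-policies.
class StubSelector:
    def __init__(self, n_choices=2):
        self.n_choices = n_choices
        self.transitions = []

    def act(self, handle, state, eps):
        return handle % self.n_choices              # placeholder decision rule

    def step(self, handle, state, action, reward, next_state, done):
        self.transitions.append((handle, action, reward, done))


selector = StubSelector()
state, next_state = [0.0], [1.0]
select = selector.act(0, state, 0.0)                # which sub-policy to trust in this state
selector.step(0, state, select, reward=-1.0, next_state=next_state, done=False)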
@@ -4,11 +4,10 @@ import numpy as np
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.agent_utils import RailAgentStatus
-from flatland.envs.rail_env import fast_count_nonzero, fast_argmax, RailEnvActions
+from flatland.envs.rail_env import fast_count_nonzero, fast_argmax
 from utils.agent_can_choose_helper import AgentCanChooseHelper
-from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
-from utils.deadlock_check import check_for_deadlock, get_agent_positions
+from utils.deadlock_check import get_agent_positions, get_agent_targets

 """
 LICENCE for the FastTreeObs Observation Builder
@@ -27,7 +26,7 @@ class FastTreeObs(ObservationBuilder):

     def __init__(self, max_depth: Any):
         self.max_depth = max_depth
-        self.observation_dim = 41
+        self.observation_dim = 30
         self.agent_can_choose_helper = None

     def debug_render(self, env_renderer):
@@ -65,21 +64,18 @@ class FastTreeObs(ObservationBuilder):
         self.agent_can_choose_helper.build_data(self.env)

         self.debug_render_list = []
         self.debug_render_path_list = []
-        if self.env is not None:
-            self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, 5, False)
-        else:
-            self.dead_lock_avoidance_agent = None

     def _explore(self, handle, new_position, new_direction, distance_map, depth=0):
         has_opp_agent = 0
         has_same_agent = 0
         has_target = 0
+        has_opp_target = 0
         visited = []
         min_dist = distance_map[handle, new_position[0], new_position[1], new_direction]

         # stop exploring (max_depth reached)
         if depth >= self.max_depth:
-            return has_opp_agent, has_same_agent, has_target, visited, min_dist
+            return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist

         # max_explore_steps = 100 -> just to ensure that the exploration ends
         cnt = 0
@@ -92,7 +88,7 @@ class FastTreeObs(ObservationBuilder):
                 if self.env.agents[opp_a].direction != new_direction:
                     # opp agent found -> stop exploring. This would be a strong signal.
                     has_opp_agent = 1
-                    return has_opp_agent, has_same_agent, has_target, visited, min_dist
+                    return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
                 else:
                     # same agent found
                     # the agent can follow the agent, because this agent is still moving ahead and there shouldn't
@@ -101,7 +97,8 @@ class FastTreeObs(ObservationBuilder):
                     # target on this branch -> thus the agents should scan further whether there will be an opposite
                     # agent walking on same track
                     has_same_agent = 1
-                    # !NOT stop exploring! return has_opp_agent, has_same_agent, has_switch, visited, min_dist
+                    # !NOT stop exploring!
+                    return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist

             # agents_on_switch == TRUE -> Current cell is a switch where the agent can decide (branch) in exploration
             # agent_near_to_switch == TRUE -> One cell before the switch, where the agent can decide
@@ -112,10 +109,14 @@ class FastTreeObs(ObservationBuilder):
             if agents_near_to_switch:
                 # The exploration was walking on a path where the agent can not decide
                 # Best option would be MOVE_FORWARD -> Skip exploring - just walking
-                return has_opp_agent, has_same_agent, has_target, visited, min_dist
+                return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
+
+            if self.env.agents[handle].target in self.agents_target:
+                has_opp_target = 1

             if self.env.agents[handle].target == new_position:
                 has_target = 1
+                return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist

             possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
             if agents_on_switch:
@@ -130,30 +131,30 @@ class FastTreeObs(ObservationBuilder):
                     # --- OPEN RESEARCH QUESTION ---> is this good or shall we use full detailed information as
                    # we did in the TreeObservation (FLATLAND) ?
                    if possible_transitions[dir_loop] == 1:
-                        hoa, hsa, ht, v, m_dist = self._explore(handle,
+                        hoa, hsa, ht, hot, v, m_dist = self._explore(handle,
                                                                 get_new_position(new_position, dir_loop),
                                                                 dir_loop,
                                                                 distance_map,
                                                                 depth + 1)
                        visited.append(v)
-                        has_opp_agent += max(hoa, has_opp_agent)
-                        has_same_agent += max(hsa, has_same_agent)
+                        has_opp_agent = max(hoa, has_opp_agent)
+                        has_same_agent = max(hsa, has_same_agent)
                        has_target = max(has_target, ht)
+                        has_opp_target = max(has_opp_target, hot)
                        min_dist = min(min_dist, m_dist)
-                return has_opp_agent, has_same_agent, has_target, visited, min_dist
+                return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
             else:
                 new_direction = fast_argmax(possible_transitions)
                 new_position = get_new_position(new_position, new_direction)
                 min_dist = min(min_dist, distance_map[handle, new_position[0], new_position[1], new_direction])

-        return has_opp_agent, has_same_agent, has_target, visited, min_dist
+        return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist

     def get_many(self, handles: Optional[List[int]] = None):
-        self.dead_lock_avoidance_agent.start_step(train=False)
         self.agent_positions = get_agent_positions(self.env)
+        self.agents_target = get_agent_targets(self.env)
         observations = super().get_many(handles)
-        self.dead_lock_avoidance_agent.end_step(train=False)
         return observations
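get_agent_targets is imported from utils.deadlock_check but its body is not part of this diff; _explore uses the cached self.agents_target to set the new has_opp_target flag. A plausible sketch of the helper, assuming it mirrors get_agent_positions and collects the target cells of agents that are not yet done (the real implementation may differ):

# Hypothetical sketch of utils.deadlock_check.get_agent_targets (not shown in this commit):
# collect target cells of all agents that still have to reach their target.
from flatland.envs.agent_utils import RailAgentStatus


def get_agent_targets(env):
    targets = []
    for agent in env.agents:
        if agent.status in (RailAgentStatus.READY_TO_DEPART, RailAgentStatus.ACTIVE):
            targets.append(agent.target)
    return targets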
     def get(self, handle: int = 0):
@@ -184,8 +185,6 @@ class FastTreeObs(ObservationBuilder):
         # observation[23] : If there is a switch on the path which agent can not use -> 1
         # observation[24] : If there is a switch on the path which agent can not use -> 1
         # observation[25] : If there is a switch on the path which agent can not use -> 1
-        # observation[26] : If there the dead-lock avoidance agent predicts a deadlock -> 1
-        # observation[27] : If there the agent can only walk forward or stop -> 1

         observation = np.zeros(self.observation_dim)
         visited = []
@@ -223,24 +222,18 @@ class FastTreeObs(ObservationBuilder):
                 if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)):
                     observation[dir_loop] = int(new_cell_dist < current_cell_dist)

-                has_opp_agent, has_same_agent, has_target, v, min_dist = self._explore(handle,
+                has_opp_agent, has_same_agent, has_target, has_opp_target, v, min_dist = self._explore(handle,
                                                                                        new_position,
                                                                                        branch_direction,
                                                                                        distance_map)
                 visited.append(v)

                 if not (np.math.isinf(min_dist) and np.math.isinf(current_cell_dist)):
-                    observation[31 + dir_loop] = int(min_dist < current_cell_dist)
-                observation[11 + dir_loop] = int(not np.math.isinf(new_cell_dist))
+                    observation[11 + dir_loop] = int(min_dist < current_cell_dist)
                 observation[15 + dir_loop] = has_opp_agent
                 observation[19 + dir_loop] = has_same_agent
                 observation[23 + dir_loop] = has_target
-                observation[27 + dir_loop] = int(np.math.isinf(new_cell_dist))
-                observation[36] = int(check_for_deadlock(handle,
-                                                         self.env,
-                                                         self.agent_positions,
-                                                         new_position,
-                                                         branch_direction))
+                observation[27 + dir_loop] = has_opp_target

         agents_on_switch, \
         agents_near_to_switch, \
@@ -253,11 +246,6 @@ class FastTreeObs(ObservationBuilder):
         observation[9] = int(agents_near_to_switch)
         observation[10] = int(agents_near_to_switch_all)

-        action = self.dead_lock_avoidance_agent.act(handle, None, 0.0)
-        observation[35] = int(action == RailEnvActions.STOP_MOVING)
-
-        observation[40] = int(check_for_deadlock(handle, self.env, self.agent_positions))
-
         self.env.dev_obs_dict.update({handle: visited})
         observation[np.isinf(observation)] = -1
...
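With the dead-lock avoidance features removed, the per-direction branch features written in get() pack into slots 11..30. A short summary of the slots touched by this commit; the labels reuse the flag names from _explore, and slots 0..10 keep their previous meaning:

# Per-direction slots written by get() after this commit (dir_loop = 0..3).
# Labels are descriptive, taken from the assignments shown in the diff above.
PER_DIRECTION_SLOTS = {
    11: "min_dist over the branch is shorter than current_cell_dist",  # observation[11 + dir_loop]
    15: "has_opp_agent",                                               # observation[15 + dir_loop]
    19: "has_same_agent",                                              # observation[19 + dir_loop]
    23: "has_target",                                                  # observation[23 + dir_loop]
    27: "has_opp_target",                                              # observation[27 + dir_loop]
}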