Commit 225bbed9 authored by Egli Adrian (IT-SCI-API-PFI)

refactored file name

parent b8fc444d
@@ -66,10 +66,6 @@ class DDDQNPolicy(Policy):
         # Epsilon-greedy action selection
         if random.random() >= eps:
             return np.argmax(action_values.cpu().data.numpy())
-            qvals = action_values.cpu().data.numpy()[0]
-            qvals = qvals - np.min(qvals)
-            qvals = qvals / (1e-5 + np.sum(qvals))
-            return np.argmax(np.random.multinomial(1, qvals))
         else:
             return random.choice(np.arange(self.action_size))
...
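The four deleted lines sampled an action from a multinomial over min-shifted, normalised Q-values, but they sat directly after an unconditional return and could never execute; removing the dead code leaves plain epsilon-greedy selection. A minimal standalone sketch of that selection, assuming the same names as the snippet above (the helper itself is illustrative, not part of the repository):

    import random
    import numpy as np

    def eps_greedy_action(action_values: np.ndarray, eps: float, action_size: int) -> int:
        # Exploit with probability (1 - eps): take the action with the highest Q-value.
        if random.random() >= eps:
            return int(np.argmax(action_values))
        # Explore otherwise: take a uniformly random action.
        return int(random.choice(np.arange(action_size)))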
@@ -171,9 +171,14 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     scores_window = deque(maxlen=checkpoint_interval)  # todo smooth when rendering instead
     completion_window = deque(maxlen=checkpoint_interval)

+    # If USE_SINGLE_AGENT_TRAINING is set and episode_idx <= MAX_SINGLE_TRAINING_ITERATION, training is done
+    # with a single agent. Every UPDATE_POLICY2_N_EPISODE episodes the second policy is replaced with the
+    # policy that is being trained.
+    USE_SINGLE_AGENT_TRAINING = True
+    MAX_SINGLE_TRAINING_ITERATION = 1000
+    UPDATE_POLICY2_N_EPISODE = 200
+
     # Double Dueling DQN policy
-    USE_SINGLE_AGENT_TRAINING = False
-    UPDATE_POLICY2_N_EPISODE = 1000
     policy = DDDQNPolicy(state_size, action_size, train_params)
     # policy = PPOAgent(state_size, action_size, n_agents)

     # Load existing policy
@@ -221,6 +226,9 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     preproc_timer = Timer()
     inference_timer = Timer()

+    if episode_idx > MAX_SINGLE_TRAINING_ITERATION:
+        USE_SINGLE_AGENT_TRAINING = False
+
     # Reset environment
     reset_timer.start()
     train_env_params.n_agents = episode_idx % n_agents + 1
@@ -293,6 +301,11 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     if agent_obs[agent][26] == 1:
         if act != RailEnvActions.STOP_MOVING:
             all_rewards[agent] -= 10.0
+    if agent_obs[agent][27] == 1:
+        if act == RailEnvActions.MOVE_LEFT or \
+                act == RailEnvActions.MOVE_RIGHT or \
+                act == RailEnvActions.DO_NOTHING:
+            all_rewards[agent] -= 1.0

     step_timer.end()
@@ -310,7 +323,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     if update_values[agent] or done['__all__']:
         # Only learn from timesteps where something happened
         learn_timer.start()
-        if agent in agent_to_learn:
+        if agent in agent_to_learn or not USE_SINGLE_AGENT_TRAINING:
             policy.step(agent,
                         agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
                         agent_obs[agent],
@@ -501,8 +514,8 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=54000, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1,
+    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=2000, type=int)
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2,
                         type=int)
     parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=1,
                         type=int)
...
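Taken together, the training changes gate experience collection: while USE_SINGLE_AGENT_TRAINING is active and episode_idx has not passed MAX_SINGLE_TRAINING_ITERATION, only the selected agent(s) feed transitions into policy.step; afterwards every agent contributes. A hedged, self-contained sketch of that gating (agent_to_learn is assumed to be a set of agent handles chosen elsewhere in the loop, which this commit does not show):

    USE_SINGLE_AGENT_TRAINING = True
    MAX_SINGLE_TRAINING_ITERATION = 1000

    def should_learn(agent: int, agent_to_learn: set, episode_idx: int) -> bool:
        # Single-agent training stays active only up to the iteration threshold.
        single_training = USE_SINGLE_AGENT_TRAINING and episode_idx <= MAX_SINGLE_TRAINING_ITERATION
        # While active, only the chosen handles learn; afterwards every agent's transitions are used.
        return (agent in agent_to_learn) or not single_training

    # Example: agent 3 is skipped early on, but learns once the threshold has passed.
    assert should_learn(3, {0}, episode_idx=500) is False
    assert should_learn(3, {0}, episode_idx=1500) is True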
@@ -27,12 +27,11 @@ VERBOSE = True
 # Checkpoint to use (remember to push it!)
 # checkpoint = "./checkpoints/201112143850-5400.pth"  # 21.220418678677177 DEPTH=2 AGENTS=10
-checkpoint = "./checkpoints/201113211844-6700.pth"  # 19.690047767961005 DEPTH=2 AGENTS=20
+checkpoint = "./checkpoints/201117082153-1500.pth"  # 21.570149424415636 DEPTH=2 AGENTS=10

 # Use last action cache
 USE_ACTION_CACHE = False
-USE_DEAD_LOCK_AVOIDANCE_AGENT = False
+USE_DEAD_LOCK_AVOIDANCE_AGENT = False  # 21.54485505223213

 # Observation parameters (must match training parameters!)
 observation_tree_depth = 2
...
@@ -187,6 +187,7 @@ class Extra(ObservationBuilder):
     def _check_dead_lock_at_branching_position(self, handle, new_position, branch_direction):
         _, full_shortest_distance_agent_map = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
         opp_agents = self.dead_lock_avoidance_agent.shortest_distance_walker.opp_agent_map.get(handle, [])
+        same_agents = self.dead_lock_avoidance_agent.shortest_distance_walker.same_agent_map.get(handle, [])
         local_walker = DeadlockAvoidanceShortestDistanceWalker(
             self.env,
             self.dead_lock_avoidance_agent.shortest_distance_walker.agent_positions,

@@ -196,6 +197,7 @@ class Extra(ObservationBuilder):
         my_shortest_path_to_check = shortest_distance_agent_map[handle]
         next_step_ok = self.dead_lock_avoidance_agent.check_agent_can_move(my_shortest_path_to_check,
                                                                            opp_agents,
+                                                                           same_agents,
                                                                            full_shortest_distance_agent_map)
         return next_step_ok
...
@@ -25,7 +25,7 @@ class FastTreeObs(ObservationBuilder):
     def __init__(self, max_depth):
         self.max_depth = max_depth
-        self.observation_dim = 27
+        self.observation_dim = 32

     def build_data(self):
         if self.env is not None:
@@ -40,8 +40,8 @@ class FastTreeObs(ObservationBuilder):
         else:
             self.dead_lock_avoidance_agent = None

-    def find_all_cell_where_agent_can_choose(self):
-        switches = {}
+    def find_all_switches(self):
+        self.switches = {}
         for h in range(self.env.height):
             for w in range(self.env.width):
                 pos = (h, w)
@@ -49,12 +49,13 @@ class FastTreeObs(ObservationBuilder):
                     possible_transitions = self.env.rail.get_transitions(*pos, dir)
                     num_transitions = fast_count_nonzero(possible_transitions)
                     if num_transitions > 1:
-                        if pos not in switches.keys():
-                            switches.update({pos: [dir]})
+                        if pos not in self.switches.keys():
+                            self.switches.update({pos: [dir]})
                         else:
-                            switches[pos].append(dir)
+                            self.switches[pos].append(dir)

-        switches_neighbours = {}
+    def find_all_switch_neighbours(self):
+        self.switches_neighbours = {}
         for h in range(self.env.height):
             for w in range(self.env.width):
                 # look one step forward
@@ -64,35 +65,34 @@ class FastTreeObs(ObservationBuilder):
                     for d in range(4):
                         if possible_transitions[d] == 1:
                             new_cell = get_new_position(pos, d)
-                            if new_cell in switches.keys() and pos not in switches.keys():
-                                if pos not in switches_neighbours.keys():
-                                    switches_neighbours.update({pos: [dir]})
+                            if new_cell in self.switches.keys() and pos not in self.switches.keys():
+                                if pos not in self.switches_neighbours.keys():
+                                    self.switches_neighbours.update({pos: [dir]})
                                 else:
-                                    switches_neighbours[pos].append(dir)
+                                    self.switches_neighbours[pos].append(dir)

-        self.switches = switches
-        self.switches_neighbours = switches_neighbours
+    def find_all_cell_where_agent_can_choose(self):
+        self.find_all_switches()
+        self.find_all_switch_neighbours()

     def check_agent_decision(self, position, direction):
-        switches = self.switches
-        switches_neighbours = self.switches_neighbours
         agents_on_switch = False
         agents_on_switch_all = False
         agents_near_to_switch = False
         agents_near_to_switch_all = False
-        if position in switches.keys():
-            agents_on_switch = direction in switches[position]
+        if position in self.switches.keys():
+            agents_on_switch = direction in self.switches[position]
             agents_on_switch_all = True

-        if position in switches_neighbours.keys():
+        if position in self.switches_neighbours.keys():
             new_cell = get_new_position(position, direction)
-            if new_cell in switches.keys():
-                if not direction in switches[new_cell]:
-                    agents_near_to_switch = direction in switches_neighbours[position]
+            if new_cell in self.switches.keys():
+                if not direction in self.switches[new_cell]:
+                    agents_near_to_switch = direction in self.switches_neighbours[position]
             else:
-                agents_near_to_switch = direction in switches_neighbours[position]
+                agents_near_to_switch = direction in self.switches_neighbours[position]

-            agents_near_to_switch_all = direction in switches_neighbours[position]
+            agents_near_to_switch_all = direction in self.switches_neighbours[position]

         return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
@@ -151,15 +151,6 @@ class FastTreeObs(ObservationBuilder):
             self.build_data()
             return

-    def fast_argmax(self, array):
-        if array[0] == 1:
-            return 0
-        if array[1] == 1:
-            return 1
-        if array[2] == 1:
-            return 2
-        return 3
-
     def _explore(self, handle, new_position, new_direction, depth=0):
         has_opp_agent = 0
         has_same_agent = 0
@@ -269,6 +260,7 @@ class FastTreeObs(ObservationBuilder):
         # observation[24] : If there is a switch on the path which the agent cannot use -> 1
         # observation[25] : If there is a switch on the path which the agent cannot use -> 1
         # observation[26] : If the dead-lock avoidance agent predicts a deadlock -> 1
+        # observation[27] : If the agent can only walk forward or stop -> 1

         observation = np.zeros(self.observation_dim)
         visited = []
@@ -313,18 +305,21 @@ class FastTreeObs(ObservationBuilder):
                 observation[14 + dir_loop] = has_opp_agent
                 observation[18 + dir_loop] = has_same_agent
                 observation[22 + dir_loop] = has_target
+                observation[26 + dir_loop] = int(np.math.isinf(new_cell_dist))

         agents_on_switch, \
         agents_near_to_switch, \
         agents_near_to_switch_all, \
         agents_on_switch_all = \
             self.check_agent_decision(agent_virtual_position, agent.direction)

         observation[7] = int(agents_on_switch)
         observation[8] = int(agents_near_to_switch)
         observation[9] = int(agents_near_to_switch_all)

         action = self.dead_lock_avoidance_agent.act([handle], 0.0)
-        observation[26] = int(action == RailEnvActions.STOP_MOVING)
+        observation[30] = int(action == RailEnvActions.STOP_MOVING)
+        observation[31] = int(fast_count_nonzero(possible_transitions) == 1)

         self.env.dev_obs_dict.update({handle: visited})

         return observation
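The observation vector grows from 27 to 32 entries. From the assignments visible above, indices 14-29 form direction-indexed blocks (one slot per branching direction dir_loop in 0..3, with 26-29 flagging an infinite shortest-path distance), index 30 mirrors the dead-lock avoidance agent's STOP_MOVING decision, and index 31 marks cells with exactly one outgoing transition. A small illustrative decoder for that tail, assuming exactly this layout (the helper name and dictionary keys are mine, not part of the repository):

    import numpy as np

    def decode_direction_features(observation: np.ndarray, dir_loop: int) -> dict:
        # Decode the direction-indexed tail of the 32-entry FastTreeObs vector.
        assert observation.shape[0] == 32 and 0 <= dir_loop < 4
        return {
            "has_opp_agent": observation[14 + dir_loop],
            "has_same_agent": observation[18 + dir_loop],
            "has_target": observation[22 + dir_loop],
            "shortest_path_is_inf": observation[26 + dir_loop],
            "deadlock_predicted": observation[30],       # dead-lock avoidance agent chose STOP_MOVING
            "single_transition_only": observation[31],   # current cell has exactly one outgoing transition
        }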