Commit 8c8f6ef8 authored by Egli Adrian (IT-SCI-API-PFI)

FastTreeObs (fix) -> 0.7887

parent e67bb6c3
File deleted (×4)
@@ -264,7 +264,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             if info['action_required'][agent]:
                 update_values[agent] = True
-                if agent == agent_to_learn:
+                if agent == agent_to_learn or True:
                     action = policy.act(agent_obs[agent], eps=eps_start)
                 else:
                     action = policy2.act([agent], eps=eps_start)
@@ -284,20 +284,21 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         step_timer.start()
         next_obs, all_rewards, done, info = train_env.step(action_dict)

-        for agent in train_env.get_agent_handles():
-            act = action_dict.get(agent, RailEnvActions.DO_NOTHING)
-            if agent_obs[agent][26] == 1:
-                if act == RailEnvActions.STOP_MOVING:
-                    all_rewards[agent] *= 0.01
-            else:
-                if act == RailEnvActions.MOVE_LEFT:
-                    all_rewards[agent] *= 0.9
-                else:
-                    if agent_obs[agent][7] == 0 and agent_obs[agent][8] == 0:
-                        if act == RailEnvActions.MOVE_FORWARD:
-                            all_rewards[agent] *= 0.01
-            if done[agent]:
-                all_rewards[agent] += 100.0
+        if False:
+            for agent in train_env.get_agent_handles():
+                act = action_dict.get(agent, RailEnvActions.DO_NOTHING)
+                if agent_obs[agent][26] == 1:
+                    if act == RailEnvActions.STOP_MOVING:
+                        all_rewards[agent] *= 0.01
+                else:
+                    if act == RailEnvActions.MOVE_LEFT:
+                        all_rewards[agent] *= 0.9
+                    else:
+                        if agent_obs[agent][7] == 0 and agent_obs[agent][8] == 0:
+                            if act == RailEnvActions.MOVE_FORWARD:
+                                all_rewards[agent] *= 0.01
+                if done[agent]:
+                    all_rewards[agent] += 100.0

         step_timer.end()
@@ -531,7 +532,7 @@ if __name__ == "__main__":
     parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
     parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs",
                         action='store_true')
-    parser.add_argument("--max_depth", help="max depth", default=1, type=int)
+    parser.add_argument("--max_depth", help="max depth", default=2, type=int)

     training_params = parser.parse_args()
     env_params = [
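The reward-shaping block above is disabled by the new "if False:" guard rather than deleted. For readability, the same logic is restated below as a standalone helper. This is only a sketch: shape_reward does not exist in the repository, and the meaning of observation indices 7, 8 and 26 is not documented in this commit.

from flatland.envs.rail_env import RailEnvActions

def shape_reward(obs, act, reward, is_done):
    # Restatement of the disabled block: obs is the per-agent FastTreeObs vector,
    # act the chosen RailEnvActions value, reward the raw step reward.
    if obs[26] == 1:
        if act == RailEnvActions.STOP_MOVING:
            reward *= 0.01
    else:
        if act == RailEnvActions.MOVE_LEFT:
            reward *= 0.9
        elif obs[7] == 0 and obs[8] == 0:
            if act == RailEnvActions.MOVE_FORWARD:
                reward *= 0.01
    if is_done:
        # large bonus once the agent is flagged done
        reward += 100.0
    return reward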
@@ -26,8 +26,7 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
 VERBOSE = True

 # Checkpoint to use (remember to push it!)
-checkpoint = "./checkpoints/201106234244-400.pth"  # 15.64082361736683 Depth 1
-checkpoint = "./checkpoints/201106234900-300.pth"  # 15.64082361736683 Depth 1
+checkpoint = "./checkpoints/201106234900-5400.pth"  # 15.64082361736683 Depth 1

 # Use last action cache
 USE_ACTION_CACHE = False
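For context, a rough sketch of how such a checkpoint is typically restored for evaluation with the DDDQNPolicy imported above. The constructor arguments, the Namespace parameters, and the load() helper are assumptions based on the baseline DDDQN implementation and should be checked against the actual class; state_size is a placeholder that must match the FastTreeObs feature-vector length.

from argparse import Namespace

from reinforcement_learning.dddqn_policy import DDDQNPolicy

checkpoint = "./checkpoints/201106234900-5400.pth"

state_size = 27   # placeholder: length of the FastTreeObs observation vector
action_size = 5   # the five RailEnvActions

policy = DDDQNPolicy(state_size, action_size,
                     Namespace(**{'use_gpu': False}),
                     evaluation_mode=True)
policy.load(checkpoint)  # assumed helper that restores the saved network weights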
@@ -168,7 +168,7 @@ class FastTreeObs(ObservationBuilder):
         if depth >= self.max_depth:
             return has_opp_agent, has_same_agent, has_switch, visited

-        # max_explore_steps = 100
+        # max_explore_steps = 100 -> just to ensure that the exploration ends
         cnt = 0
         while cnt < 100:
             cnt += 1
@@ -177,26 +177,41 @@
             opp_a = self.env.agent_positions[new_position]
             if opp_a != -1 and opp_a != handle:
                 if self.env.agents[opp_a].direction != new_direction:
-                    # opp agent found
+                    # opposite agent found -> stop exploring; this is a strong signal
                     has_opp_agent = 1
                     return has_opp_agent, has_same_agent, has_switch, visited
                 else:
+                    # same-direction agent found
+                    # The agent can follow it, because that agent is still moving ahead and there
+                    # shouldn't be any deadlock or other issue -> the agent is just walking. If the
+                    # other agent is deadlocked, other agents have to avoid it. One edge case is when
+                    # the other agent has its target on this branch -> the exploration therefore has to
+                    # scan further for an opposite agent walking on the same track.
                     has_same_agent = 1
-                    return has_opp_agent, has_same_agent, has_switch, visited
+                    # !do NOT stop exploring! return has_opp_agent, has_same_agent, has_switch, visited

-            # convert one-hot encoding to 0,1,2,3
-            agents_on_switch, \
-            agents_near_to_switch, \
-            agents_near_to_switch_all, \
-            agents_on_switch_all = \
+            # agents_on_switch == True      -> the current cell is a switch where the agent can decide
+            #                                  (branch) during exploration
+            # agents_near_to_switch == True -> the agent is one cell before such a switch
+            agents_on_switch, agents_near_to_switch, _, _ = \
                 self.check_agent_decision(new_position, new_direction)
             if agents_near_to_switch:
+                # The exploration is walking on a path where the agent cannot decide; the best option
+                # would be MOVE_FORWARD -> stop exploring, it is just walking
                 return has_opp_agent, has_same_agent, has_switch, visited

             possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
             if agents_on_switch:
                 f = 0
-                for dir_loop in range(4):
+                orientation = new_direction
+                if fast_count_nonzero(possible_transitions) == 1:
+                    orientation = fast_argmax(possible_transitions)
+                for dir_loop, branch_direction in enumerate(
+                        [(orientation + dir_loop) % 4 for dir_loop in range(-1, 3)]):
+                    # branch the exploration path and aggregate the information found on each branch
+                    # --- OPEN RESEARCH QUESTION --- is this enough, or should the full detailed
+                    # information be used as in the stock TreeObservation (FLATLAND)?
                     if possible_transitions[dir_loop] == 1:
                         f += 1
                         hoa, hsa, hs, v = self._explore(handle,
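A minimal usage sketch for the observation builder whose exploration is shown above, assuming the FastTreeObs constructor accepts max_depth (mirroring the --max_depth argument, whose default moved from 1 to 2 in this commit) and a 2020-era flatland-rl API; the import path and rail-generator arguments may differ in the actual repository.

from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator

from utils.fast_tree_obs import FastTreeObs  # assumed import path

# explore up to two branching levels per direction (the new default depth)
obs_builder = FastTreeObs(max_depth=2)

env = RailEnv(width=30, height=30,
              rail_generator=sparse_rail_generator(max_num_cities=3),
              number_of_agents=5,
              obs_builder_object=obs_builder)

obs, info = env.reset()
# obs[handle] is the compact per-agent feature vector built via _explore()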