Commit 3f2468c2 authored by Egli Adrian (IT-SCI-API-PFI)


DeadLockAvoidance used for extra obs (current position/direction check, as well as branching checks one step ahead)
parent fc32274d
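The change in a nutshell: the scripted DeadLockAvoidance policy is now queried while building the extra observation, once for the agent's current position/direction and once per branch direction one step ahead. Below is a minimal sketch of that encoding idea with hypothetical helper and parameter names (only an illustration, not code from this commit):

import numpy as np

# Hypothetical sketch: fold deadlock-avoidance verdicts into extra
# observation bits. 'can_move_now' is the check at the current
# position/direction; 'branch_ok' holds one verdict per branch direction
# (0..3) evaluated one step ahead.
def deadlock_bits(can_move_now, branch_ok):
    bits = np.zeros(5, dtype=np.float32)
    bits[0] = float(can_move_now)
    for direction, ok in branch_ok.items():
        bits[1 + direction] = float(ok)
    return bits

print(deadlock_bits(True, {0: True, 2: False}))  # [1. 1. 0. 0. 0.]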
@@ -9,6 +9,7 @@ from pprint import pprint
 import numpy as np
 import psutil
+from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
@@ -196,8 +197,6 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, close_following):
     policy = PPOAgent(state_size, action_size, n_agents, train_env)
     if False:
         policy = MultiPolicy(state_size, action_size, n_agents, train_env)
-    if False:
-        policy = DeadLockAvoidanceAgent(train_env, state_size, action_size)
 
     # Load existing policy
     if train_params.load_policy is not None:
@@ -244,6 +243,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, close_following):
         training_id
     ))
 
+    rl_policy = policy
     for episode_idx in range(n_episodes + 1):
         step_timer = Timer()
         reset_timer = Timer()
@@ -254,6 +254,18 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, close_following):
         # Reset environment
         reset_timer.start()
         obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
 
+        # train a different number of agents each episode: 1, 2, 3, ..., n_agents
+        for handle in range(train_env.get_num_agents()):
+            if (episode_idx % n_agents) < handle:
+                train_env.agents[handle].status = RailAgentStatus.DONE_REMOVED
+
+        # start with the simple deadlock-avoidance agent policy (imitation learning?)
+        if episode_idx < 500:
+            policy = DeadLockAvoidanceAgent(train_env, state_size, action_size)
+        else:
+            policy = rl_policy
+        policy.reset()
 
         reset_timer.end()
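Two mechanisms are introduced here: a modulo schedule that activates only part of the fleet in early episodes, and a warm start in which the scripted DeadLockAvoidanceAgent drives the first 500 episodes (an imitation-learning-style bootstrap) before rl_policy takes over. A stand-alone sketch of both, with function names of my own choosing:

WARMUP_EPISODES = 500  # matches the hard-coded threshold in the hunk above

def pick_policy(episode_idx, heuristic_policy, rl_policy):
    # Early episodes: the scripted deadlock-avoidance agent acts (warm start);
    # afterwards the learned policy takes over.
    return heuristic_policy if episode_idx < WARMUP_EPISODES else rl_policy

def active_handles(episode_idx, n_agents, num_env_agents):
    # Modulo curriculum: in episode e only handles h with h <= e % n_agents
    # stay active; all others get marked DONE_REMOVED after reset.
    return [h for h in range(num_env_agents) if h <= episode_idx % n_agents]

print(active_handles(episode_idx=4, n_agents=3, num_env_agents=5))  # [0, 1]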
@@ -512,7 +524,7 @@ if __name__ == "__main__":
     parser.add_argument("--eps_start", help="max exploration", default=0.5, type=float)
     parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
     parser.add_argument("--eps_decay", help="exploration decay", default=0.9985, type=float)
-    parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e5), type=int)
+    parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e6), type=int)
     parser.add_argument("--buffer_min_size", help="min buffer size to start training", default=0, type=int)
     parser.add_argument("--restore_replay_buffer", help="replay buffer to restore", default="", type=str)
     parser.add_argument("--save_replay_buffer", help="save replay buffer at each evaluation interval", default=False,
@@ -528,8 +540,8 @@ if __name__ == "__main__":
     parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
     parser.add_argument("--use_extra_observation", help="extra observation", default=True, type=bool)
     parser.add_argument("--close_following", help="enable close following feature", default=True, type=bool)
-    parser.add_argument("--max_depth", help="max depth", default=1, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1,
+    parser.add_argument("--max_depth", help="max depth", default=2, type=int)
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2,
                         type=int)
     parser.add_argument("--render", help="render 1 episode in 100", default=False, type=bool)
@@ -148,6 +148,6 @@ class DeadLockAvoidanceAgent(Policy):
         for opp_a in opp_agents:
             opp = full_shortest_distance_agent_map[opp_a]
             delta = ((delta - opp - agent_positions_map) > 0).astype(int)
-            if (np.sum(delta) < 1 + len(opp_agents)):
+            if (np.sum(delta) < 2 + len(opp_agents)):
                 next_step_ok = False
         return next_step_ok
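The loop above repeatedly subtracts each opposing agent's shortest-path map and the currently occupied cells from this agent's own path map; the move is vetoed as soon as fewer free cells remain than a margin plus one cell per opposing agent. This commit tightens that margin from 1 to 2. A self-contained sketch with plain NumPy arrays standing in for the walker's occupancy maps:

import numpy as np

def check_agent_can_move(my_path_map, opp_path_maps, agent_positions_map):
    # my_path_map: 0/1 grid of cells on this agent's shortest path
    # opp_path_maps: list of 0/1 grids, one per opposing agent
    # agent_positions_map: 0/1 grid of currently occupied cells
    next_step_ok = True
    delta = my_path_map
    for opp in opp_path_maps:
        delta = ((delta - opp - agent_positions_map) > 0).astype(int)
        if np.sum(delta) < 2 + len(opp_path_maps):  # tightened margin
            next_step_ok = False
    return next_step_ok

grid = np.zeros((4, 4), dtype=int)
mine, theirs = grid.copy(), grid.copy()
mine[0, :] = 1     # my path: the top row
theirs[0, 2:] = 1  # opponent path overlaps two of my cells
print(check_agent_can_move(mine, [theirs], grid))  # False: 2 free cells < 3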
@@ -183,14 +183,21 @@ class Extra(ObservationBuilder):
         self.build_data()
         return
 
     def fast_argmax(self, array):
         if array[0] == 1:
             return 0
         if array[1] == 1:
             return 1
         if array[2] == 1:
             return 2
         return 3
 
+    def _check_dead_lock_at_branching_position(self, handle, new_position, branch_direction):
+        _, full_shortest_distance_agent_map = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
+        opp_agents = self.dead_lock_avoidance_agent.shortest_distance_walker.opp_agent_map.get(handle, [])
+        local_walker = DeadlockAvoidanceShortestDistanceWalker(
+            self.env,
+            self.dead_lock_avoidance_agent.shortest_distance_walker.agent_positions,
+            self.dead_lock_avoidance_agent.shortest_distance_walker.switches)
+        local_walker.walk_to_target(handle, new_position, branch_direction)
+        shortest_distance_agent_map, _ = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
+        my_shortest_path_to_check = shortest_distance_agent_map[handle]
+        next_step_ok = self.dead_lock_avoidance_agent.check_agent_can_move(my_shortest_path_to_check,
+                                                                           opp_agents,
+                                                                           full_shortest_distance_agent_map)
+        return next_step_ok
 
     def _explore(self, handle, new_position, new_direction, depth=0):
@@ -332,18 +339,7 @@ class Extra(ObservationBuilder):
             observation[18 + dir_loop] = has_same_agent
             observation[22 + dir_loop] = has_switch
 
-            _, full_shortest_distance_agent_map = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
-            opp_agents = self.dead_lock_avoidance_agent.shortest_distance_walker.opp_agent_map.get(handle, [])
-            local_walker = DeadlockAvoidanceShortestDistanceWalker(
-                self.env,
-                self.dead_lock_avoidance_agent.shortest_distance_walker.agent_positions,
-                self.dead_lock_avoidance_agent.shortest_distance_walker.switches)
-            local_walker.walk_to_target(handle, new_position, branch_direction)
-            shortest_distance_agent_map, _ = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
-            my_shortest_path_to_check = shortest_distance_agent_map[handle]
-            next_step_ok = self.dead_lock_avoidance_agent.check_agent_can_move(my_shortest_path_to_check,
-                                                                               opp_agents,
-                                                                               full_shortest_distance_agent_map)
+            next_step_ok = self._check_dead_lock_at_branching_position(handle, new_position, branch_direction)
             if next_step_ok:
                 observation[26 + dir_loop] = 1
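The twelve inline lines are thus replaced by a single call to the new helper, and the observation simply records one deadlock verdict per branch direction in slots 26..29. A tiny self-contained illustration of that flow with a stubbed checker (not repo code):

import numpy as np

class StubChecker:
    # Stand-in for Extra and its new helper; the real method simulates a walk
    # from the branch cell and runs check_agent_can_move on the result.
    def _check_dead_lock_at_branching_position(self, handle, new_position, branch_direction):
        return branch_direction != 2  # pretend heading 2 would deadlock

checker = StubChecker()
observation = np.zeros(30, dtype=np.float32)
for dir_loop, branch_direction in enumerate(range(4)):
    if checker._check_dead_lock_at_branching_position(0, (0, 0), branch_direction):
        observation[26 + dir_loop] = 1

print(observation[26:30])  # [1. 1. 0. 1.]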