Commit 02d8a6d7 authored by Egli Adrian (IT-SCI-API-PFI)

Q&D

parent 800f847f
@@ -194,8 +194,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
policy = PPOAgent(state_size, action_size, n_agents)
if False:
policy = MultiPolicy(state_size, action_size, n_agents, train_env)
if True:
policy = DeadLockAvoidanceAgent(train_env,state_size, action_size)
if False:
policy = DeadLockAvoidanceAgent(train_env, state_size, action_size)
# Load existing policy
if train_params.load_policy is not None:
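# Hedged reading, not part of the diff itself: the if True / if False guards above are
# hard-coded switches for picking which policy object drives training (PPOAgent,
# MultiPolicy or DeadLockAvoidanceAgent); with both guards set to False, only the
# unconditional PPOAgent assignment takes effect.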
@@ -253,6 +253,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
reset_timer.start()
obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
policy.reset()
deadLockAvoidanceAgent = DeadLockAvoidanceAgent(train_env, state_size, action_size)
reset_timer.end()
if train_params.render:
@@ -273,20 +274,25 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
for step in range(max_steps - 1):
inference_timer.start()
policy.start_step()
deadLockAvoidanceAgent.start_step()
for agent in train_env.get_agent_handles():
if info['action_required'][agent]:
update_values[agent] = True
action = policy.act(agent,agent_obs[agent], eps=eps_start)
action_count[action] += 1
actions_taken.append(action)
else:
# An action is not required if the train hasn't joined the railway network,
# if it already reached its target, or if it is currently malfunctioning.
update_values[agent] = False
action = 0
action = deadLockAvoidanceAgent.act(agent, None, 0.0)
update_values[agent] = False
if action != RailEnvActions.STOP_MOVING:
if info['action_required'][agent]:
update_values[agent] = True
action = policy.act(agent, agent_obs[agent], eps=eps_start)
action_count[action] += 1
actions_taken.append(action)
else:
# An action is not required if the train hasn't joined the railway network,
# if it already reached its target, or if it is currently malfunctioning.
action = 0
action_dict.update({agent: action})
policy.end_step()
deadLockAvoidanceAgent.end_step()
inference_timer.end()
# Environment step
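A condensed sketch of the per-agent action selection used in the loop above (illustrative only; the variable names are taken from the surrounding code):

action = deadLockAvoidanceAgent.act(agent, None, 0.0)  # heuristic consulted first
update_values[agent] = False
if action != RailEnvActions.STOP_MOVING:
    if info['action_required'][agent]:
        update_values[agent] = True
        action = policy.act(agent, agent_obs[agent], eps=eps_start)
    else:
        action = 0  # no decision needed for this agent in this step
action_dict.update({agent: action})  # a STOP_MOVING verdict from the heuristic is kept as-is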
@@ -458,22 +464,26 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
score = 0.0
obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)
deadLockAvoidanceAgent = DeadLockAvoidanceAgent(env, None, None)
final_step = 0
for step in range(max_steps - 1):
deadLockAvoidanceAgent.start_step()
for agent in env.get_agent_handles():
if tree_observation.check_is_observation_valid(agent_obs[agent]):
agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], tree_depth=tree_depth,
observation_radius=observation_radius)
action = 0
if info['action_required'][agent]:
if tree_observation.check_is_observation_valid(agent_obs[agent]):
action = policy.act(agent,agent_obs[agent], eps=0.0)
action = deadLockAvoidanceAgent.act(agent, None, 0)
if action != RailEnvActions.STOP_MOVING:
if info['action_required'][agent]:
if tree_observation.check_is_observation_valid(agent_obs[agent]):
action = policy.act(agent, agent_obs[agent], eps=0.0)
action_dict.update({agent: action})
obs, all_rewards, done, info = env.step(action_dict)
deadLockAvoidanceAgent.end_step()
for agent in env.get_agent_handles():
score += all_rewards[agent]
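# Hedged note: eval_policy applies the same veto pattern as the training loop, but queries
# the heuristic with act(agent, None, 0) and the learned policy greedily (eps=0.0).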
@@ -505,7 +515,7 @@ if __name__ == "__main__":
parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=200000, type=int)
parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
type=int)
parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=25, type=int)
parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=5, type=int)
parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=200, type=int)
parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float)
parser.add_argument("--eps_end", help="min exploration", default=0.0001, type=float)
@@ -525,9 +535,9 @@ if __name__ == "__main__":
parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=1, type=int)
parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
parser.add_argument("--use_extra_observation", help="extra observation", default=True, type=bool)
parser.add_argument("--max_depth", help="max depth", default=-1, type=int)
parser.add_argument("--max_depth", help="max depth", default=1, type=int)
parser.add_argument("--close_following", help="enable close following feature", default=True, type=bool)
parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2, type=int)
parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1, type=int)
parser.add_argument("--render", help="render 1 episode in 100", default=False, type=bool)
training_params = parser.parse_args()
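# Hedged note on the boolean arguments above: argparse's type=bool converts any non-empty
# string to True, so flags such as --use_extra_observation and --render only become False
# when passed an empty string; in practice the defaults act as constants.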
@@ -13,7 +13,6 @@ class MultiPolicy(Policy):
self.action_size = action_size
self.memory = []
self.loss = 0
self.dead_lock_avoidance_policy = DeadLockAvoidanceAgent(env, state_size, action_size)
self.extra_policy = ExtraPolicy(state_size, action_size)
self.ppo_policy = PPOAgent(state_size + action_size, action_size, n_agents, env)
@@ -40,9 +39,6 @@ class MultiPolicy(Policy):
self.ppo_policy.step(handle, extended_state, action, reward, extended_next_state, done)
def act(self, handle, state, eps=0.):
dead_lock_avoidance_action = self.dead_lock_avoidance_policy.act(handle, state, 0.0)
if dead_lock_avoidance_action == RailEnvActions.STOP_MOVING:
return RailEnvActions.STOP_MOVING
action_extra_state = self.extra_policy.act(handle, state, 0.0)
extended_state = np.copy(state)
for action_itr in np.arange(self.action_size):
@@ -60,11 +56,9 @@ class MultiPolicy(Policy):
self.extra_policy.test()
def start_step(self):
self.dead_lock_avoidance_policy.start_step()
self.extra_policy.start_step()
self.ppo_policy.start_step()
def end_step(self):
self.dead_lock_avoidance_policy.end_step()
self.extra_policy.end_step()
self.ppo_policy.end_step()
import matplotlib.pyplot as plt
import numpy as np
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_env import RailEnv, RailEnvActions, fast_count_nonzero
from reinforcement_learning.policy import Policy
from utils.shortest_Distance_walker import ShortestDistanceWalker
class MyWalker(ShortestDistanceWalker):
def __init__(self, env: RailEnv, agent_positions):
def __init__(self, env: RailEnv, agent_positions, switches):
super().__init__(env)
self.shortest_distance_agent_map = np.zeros((self.env.get_num_agents(),
self.env.height,
@@ -22,25 +22,38 @@ class MyWalker(ShortestDistanceWalker):
self.agent_positions = agent_positions
self.agent_map = {}
self.opp_agent_map = {}
self.same_agent_map = {}
self.switches = switches
def get_action(self, handle, min_distances):
if min_distances[0] != np.inf:
m = min(min_distances)
if min_distances[0] < m + 5:
return 0
return np.argmin(min_distances)
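# Hedged reading of get_action above: stick with branch 0 whenever it is reachable and
# less than 5 steps longer than the overall shortest branch; otherwise take the shortest.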
def getData(self):
return self.shortest_distance_agent_map, self.full_shortest_distance_agent_map
def callback(self, handle, agent, position, direction, action):
def callback(self, handle, agent, position, direction, action, possible_transitions):
opp_a = self.agent_positions[position]
if opp_a != -1 and opp_a != handle:
if self.env.agents[opp_a].direction != direction:
d = self.agent_map.get(handle, [])
d = self.opp_agent_map.get(handle, [])
if opp_a not in d:
d.append(opp_a)
self.agent_map.update({handle: d})
d = self.agent_map.get(handle, [])
if len(d) == 0:
self.shortest_distance_agent_map[(handle, position[0], position[1])] = 1
self.opp_agent_map.update({handle: d})
else:
if len(self.opp_agent_map.get(handle, [])) == 0:
d = self.same_agent_map.get(handle, [])
if opp_a not in d:
d.append(opp_a)
self.same_agent_map.update({handle: d})
if len(self.opp_agent_map.get(handle, [])) == 0:
if self.switches.get(position, None) is None:
self.shortest_distance_agent_map[(handle, position[0], position[1])] = 1
self.full_shortest_distance_agent_map[(handle, position[0], position[1])] = 1
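# Hedged reading of the callback above:
# - opp_agent_map[handle] collects agents met on the walk that face the opposite direction;
#   same_agent_map[handle] collects same-direction agents, recorded only while no opposing
#   agent has been seen yet.
# - shortest_distance_agent_map marks a path cell only while no opposing agent has been seen
#   and the cell is not a switch; full_shortest_distance_agent_map marks every walked cell.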
@@ -52,6 +65,7 @@ class DeadLockAvoidanceAgent(Policy):
self.memory = []
self.loss = 0
self.agent_can_move = {}
self.switches = {}
def step(self, handle, state, action, reward, next_state, done):
pass
@@ -61,11 +75,21 @@ class DeadLockAvoidanceAgent(Policy):
check = self.agent_can_move.get(handle, None)
if check is None:
return RailEnvActions.STOP_MOVING
return check[3]
def reset(self):
pass
self.switches = {}
for h in range(self.env.height):
for w in range(self.env.width):
pos = (h, w)
for dir in range(4):
possible_transitions = self.env.rail.get_transitions(*pos, dir)
num_transitions = fast_count_nonzero(possible_transitions)
if num_transitions > 1:
if pos not in self.switches.keys():
self.switches.update({pos: [dir]})
else:
self.switches[pos].append(dir)
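# Hedged illustration: after reset(), self.switches maps each switch cell to the incoming
# directions that offer more than one outgoing transition, e.g. (made-up coordinates)
#     {(11, 4): [1, 3], (7, 9): [0]}
# MyWalker uses this dict to skip switch cells when marking shortest_distance_agent_map.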
def start_step(self):
self.shortest_distance_mapper()
@@ -86,7 +110,7 @@ class DeadLockAvoidanceAgent(Policy):
if agent.position is not None:
agent_positions[agent.position] = handle
my_walker = MyWalker(self.env, agent_positions)
my_walker = MyWalker(self.env, agent_positions, self.switches)
for handle in range(self.env.get_num_agents()):
agent = self.env.agents[handle]
if agent.status <= RailAgentStatus.ACTIVE:
@@ -96,14 +120,15 @@ class DeadLockAvoidanceAgent(Policy):
self.agent_can_move = {}
agent_positions_map = (agent_positions > -1).astype(int)
for handle in range(self.env.get_num_agents()):
opp_agents = my_walker.agent_map.get(handle, [])
opp_agents = my_walker.opp_agent_map.get(handle, [])
same_agents = my_walker.same_agent_map.get(handle, [])
me = shortest_distance_agent_map[handle]
delta = me
next_step_ok = True
next_position, next_direction, action = my_walker.walk_one_step(handle)
next_position, next_direction, action, possible_transitions = my_walker.walk_one_step(handle)
for opp_a in opp_agents:
opp = full_shortest_distance_agent_map[opp_a]
delta = (delta - opp - agent_positions_map > 0).astype(int)
delta = ((delta - opp - agent_positions_map) > 0).astype(int)
if (np.sum(delta) < 3):
next_step_ok = False
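A toy numpy illustration of the free-cell test above (a hedged sketch; the array shapes and values are invented for illustration):

import numpy as np

me  = np.array([[1, 1, 1, 1, 0],      # cells on my shortest path (switches excluded)
                [0, 0, 0, 1, 0]])
opp = np.array([[0, 0, 1, 1, 1],      # full shortest path of one opposing agent
                [0, 0, 0, 1, 0]])
occ = np.array([[0, 0, 0, 0, 0],      # cells currently occupied by any agent
                [0, 0, 0, 1, 0]])

delta = ((me - opp - occ) > 0).astype(int)   # cells only this agent can still claim
print(int(delta.sum()))                      # 2 -> fewer than 3 free cells, so next_step_ok = False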
@@ -84,7 +84,7 @@ class Extra(ObservationBuilder):
def getData(self):
return self.shortest_distance_agent_counter, self.shortest_distance_agent_direction_counter
def callback(self, handle, agent, position, direction, action):
def callback(self, handle, agent, position, direction, action, possible_transitions):
self.shortest_distance_agent_counter[position] += 1
self.shortest_distance_agent_direction_counter[(position[0], position[1], direction)] += 1
@@ -16,7 +16,7 @@ class ShortestDistanceWalker:
new_position = get_new_position(position, new_direction)
dist = self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction]
return new_position, new_direction, dist, RailEnvActions.MOVE_FORWARD
return new_position, new_direction, dist, RailEnvActions.MOVE_FORWARD, possible_transitions
else:
min_distances = []
positions = []
@@ -34,28 +34,31 @@ class ShortestDistanceWalker:
directions.append(None)
a = self.get_action(handle, min_distances)
return positions[a], directions[a], min_distances[a], a + 1
return positions[a], directions[a], min_distances[a], a + 1, possible_transitions
def get_action(self, handle, min_distances):
return np.argmin(min_distances)
def callback(self, handle, agent, position, direction, action):
def callback(self, handle, agent, position, direction, action, possible_transitions):
pass
def walk_to_target(self, handle):
def walk_to_target(self, handle, max_step=500):
agent = self.env.agents[handle]
if agent.position is not None:
position = agent.position
else:
position = agent.initial_position
direction = agent.direction
while (position != agent.target):
position, direction, dist, action = self.walk(handle, position, direction)
step = 0
while (position != agent.target) and (step < max_step):
position, direction, dist, action, possible_transitions = self.walk(handle, position, direction)
if position is None:
break
self.callback(handle, agent, position, direction, action)
self.callback(handle, agent, position, direction, action, possible_transitions)
step += 1
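# Hedged note: the new max_step argument bounds the walk so that walk_to_target cannot
# loop indefinitely when the target is unreachable from the current cell.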
def callback_one_step(self, handle, agent, position, direction, action):
def callback_one_step(self, handle, agent, position, direction, action, possible_transitions):
pass
def walk_one_step(self, handle):
@@ -65,9 +68,10 @@ class ShortestDistanceWalker:
else:
position = agent.initial_position
direction = agent.direction
possible_transitions = (0, 1, 0, 0)
if (position != agent.target):
new_position, new_direction, dist, action = self.walk(handle, position, direction)
new_position, new_direction, dist, action, possible_transitions = self.walk(handle, position, direction)
if new_position is None:
return position, direction, RailEnvActions.STOP_MOVING
self.callback_one_step(handle, agent, new_position, new_direction, action)
return new_position, new_direction, action
return position, direction, RailEnvActions.STOP_MOVING, possible_transitions
self.callback_one_step(handle, agent, new_position, new_direction, action, possible_transitions)
return new_position, new_direction, action, possible_transitions
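Illustrative use of the extended walker return value (hedged; my_walker and handle are assumed to exist):

position, direction, action, possible_transitions = my_walker.walk_one_step(handle)
# possible_transitions is the 4-tuple of allowed transitions at the evaluated cell;
# (0, 1, 0, 0) is the default returned when the agent already stands on its target.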