Commit 3f321add authored by nilabha

Merge branch 'rllib-IL' into 'flatland-paper-baselines'

Rllib IL Changes

See merge request !14
parents 50e24f72 48a2454e
from .registry import CUSTOM_ALGORITHMS
"""
Registry of custom implemented algorithms names
Please refer to the following examples to add your custom algorithms :
- AlphaZero : https://github.com/ray-project/ray/tree/master/rllib/contrib/alpha_zero
- bandits : https://github.com/ray-project/ray/tree/master/rllib/contrib/bandits
- maddpg : https://github.com/ray-project/ray/tree/master/rllib/contrib/maddpg
- random_agent: https://github.com/ray-project/ray/tree/master/rllib/contrib/random_agent
An example integration of the random agent is shown here :
- https://github.com/AIcrowd/neurips2020-procgen-starter-kit/tree/master/algorithms/custom_random_agent
"""
def _import_imitation_trainer():
    from .imitation_agent.imitation_trainer import ImitationAgent
    return ImitationAgent

# Map algorithm name -> lazy import function, so the heavy trainer module
# is only loaded when the algorithm is actually requested.
CUSTOM_ALGORITHMS = {
    "ImitationAgent": _import_imitation_trainer,
}
\ No newline at end of file
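For context, a hedged sketch of how such a registry is typically consumed. The `register_custom_algorithms` helper below is an assumption for illustration, not code from this MR; only `CUSTOM_ALGORITHMS` and Ray Tune's `register_trainable` are taken as given.

# Hypothetical glue code (not part of this MR): register the custom
# trainers with Ray Tune so YAML configs can reference them by name.
from ray.tune.registry import register_trainable

from .registry import CUSTOM_ALGORITHMS


def register_custom_algorithms():
    for name, import_fn in CUSTOM_ALGORITHMS.items():
        # import_fn() resolves the trainer class lazily (see registry above).
        register_trainable(name, import_fn())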
flatland-sparse-small-tree-fc-apex-il-trainer:
  run: ImitationAgent
  env: flatland_sparse
  stop:
    timesteps_total: 15000000  # 1.5e7
  checkpoint_freq: 50
  checkpoint_at_end: True
  keep_checkpoints_num: 50
  checkpoint_score_attr: episode_reward_mean
  num_samples: 3
  config:
    num_workers: 6
    num_envs_per_worker: 5
    num_gpus: 0
    clip_rewards: False
    vf_clip_param: 500.0
    entropy_coeff: 0.01
    # effective batch_size: train_batch_size * num_agents_in_each_environment [5, 10]
    # see https://github.com/ray-project/ray/issues/4628
    train_batch_size: 1000  # 5000
    rollout_fragment_length: 50  # 100
    sgd_minibatch_size: 100  # 500
    vf_share_layers: False
    env_config:
      custom_fn: imitation_ppo_train_fn
      expert:
        ratio: 0.5
        ratio_decay: 1
        min_ratio: 0.5
      observation: tree
      observation_config:
        max_depth: 2
        shortest_path_max_depth: 30
      generator: sparse_rail_generator
      generator_config: small_v0
      wandb:
        project: flatland-paper
        entity: aicrowd
        tags: ["small_v0", "tree_obs", "apex_rllib_il"]  # TODO should be set programmatically
    model:
      fcnet_activation: relu
      fcnet_hiddens: [256, 256]
      vf_share_layers: False  # should match the PPO-level vf_share_layers above
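The `expert` block above controls how much of each batch comes from the expert policy. A minimal sketch of the schedule these keys suggest; the function name and the exact decay rule are assumptions (the real logic lives in the imitation trainer), only the key names come from the config.

# Hypothetical illustration of the expert-ratio schedule implied by
# ratio / ratio_decay / min_ratio above; not the MR's actual code.
def expert_ratio_at(iteration: int,
                    ratio: float = 0.5,
                    ratio_decay: float = 1.0,
                    min_ratio: float = 0.5) -> float:
    """Share of expert samples used at a given training iteration."""
    return max(min_ratio, ratio * (ratio_decay ** iteration))

# With ratio_decay = 1, the expert share stays fixed at 0.5 forever.
assert expert_ratio_at(100) == 0.5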
flatland-sparse-small-tree-fc-apex-il-trainer:
  run: ImitationAgent
  env: flatland_sparse
  stop:
    timesteps_total: 15000000  # 1.5e7
  checkpoint_freq: 50
  checkpoint_at_end: True
  keep_checkpoints_num: 50
  checkpoint_score_attr: episode_reward_mean
  num_samples: 3
  config:
    num_workers: 13
    num_envs_per_worker: 5
    num_gpus: 0
    clip_rewards: False
    vf_clip_param: 500.0
    entropy_coeff: 0.01
    # effective batch_size: train_batch_size * num_agents_in_each_environment [5, 10]
    # see https://github.com/ray-project/ray/issues/4628
    train_batch_size: 1000  # 5000
    rollout_fragment_length: 50  # 100
    sgd_minibatch_size: 100  # 500
    vf_share_layers: False
    env_config:
      observation: tree
      observation_config:
        max_depth: 2
        shortest_path_max_depth: 30
      generator: sparse_rail_generator
      generator_config: small_v0
      wandb:
        project: flatland-paper
        entity: aicrowd
        tags: ["small_v0", "tree_obs", "apex_rllib_il"]  # TODO should be set programmatically
    model:
      fcnet_activation: relu
      fcnet_hiddens: [256, 256]
      vf_share_layers: False  # should match the PPO-level vf_share_layers above
    evaluation_num_workers: 2
    # Enable evaluation, once per training iteration.
    evaluation_interval: 1
    # Run 2 episodes each time evaluation runs.
    evaluation_num_episodes: 2
    # Override the env config for evaluation.
......
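For reference, experiment files like the two above are normally launched through the repo's training entry point. A hedged sketch of that flow; the function name and file path are assumptions based on the usual RLlib baselines layout, while `ray`, `yaml`, and `tune.run_experiments` are real APIs.

# Hypothetical launcher (not part of this MR): run an experiment file
# with Ray Tune, mirroring what a baselines train.py typically does.
import yaml
import ray
from ray import tune


def run_experiment_file(path="apex_il_tree_obs.yaml"):  # path is an assumption
    with open(path) as f:
        experiments = yaml.safe_load(f)  # top-level key = experiment name
    ray.init()
    # The custom "ImitationAgent" trainable must be registered first,
    # e.g. via the register_custom_algorithms() sketch shown earlier.
    tune.run_experiments(experiments)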
import numpy as np

from flatland.envs.rail_env import RailEnv, RailEnvActions


class Vertex:
    """A rail cell, with incoming and outgoing edges grouped by direction."""
    def __init__(self, y, x, idx):
        self.point = (y, x)
        self.idx = idx
        self.out = [[], [], [], []]       # outgoing edges, indexed by incoming direction
        self.in_edges = [[], [], [], []]  # incoming edges, indexed by arrival direction


class Edge:
    def __init__(self, start_v, end_v, start_dir, end_dir, action_type):
        self.start_v = start_v
        self.end_v = end_v
        self.start_direction = start_dir
        self.end_direction = end_dir
        self.action_type = action_type


class CellGraph:
    """Directed graph over the rail grid: one vertex per rail cell, one edge
    per allowed (cell, direction) -> next-cell transition."""

    def __init__(self, env: RailEnv):
        self.env = env
        self._build_graph()

    def _build_graph(self):
        width = self.env.width
        height = self.env.height

        # np.int was removed in NumPy 1.20+; plain int is equivalent here.
        self.vertex_idx = np.zeros((height, width), dtype=int)
        self.vertex_idx.fill(-1)

        self.vertexes = []
        for y in range(height):
            for x in range(width):
                if self._is_rail(y, x):
                    idx = len(self.vertexes)
                    self.vertexes.append(Vertex(y, x, idx))
                    self.vertex_idx[y, x] = idx
        # print('vertexes:', len(self.vertexes))

        edges_cnt = 0
        for v_idx, v in enumerate(self.vertexes):
            start_point = v.point
            for direction in range(4):
                directions = self._possible_directions(start_point, direction)
                # assert len(directions) <= 2
                for end_direction in directions:
                    next_point = self._next_point(start_point, end_direction)
                    end_v = self._vertex_idx_from_point(next_point)
                    action_type = self._action_from_directions(direction, end_direction)

                    e = Edge(v_idx, end_v, direction, end_direction, action_type)
                    v.out[direction].append(e)
                    self.vertexes[end_v].in_edges[end_direction].append(e)
                    edges_cnt += 1
        # print('edges_cnt', edges_cnt)

    def _is_rail(self, y, x):
        return self.env.rail.grid[y, x] != 0

    def _next_point(self, point, direction):
        # Directions: 0 = North, 1 = East, 2 = South, 3 = West.
        if direction == 0:
            return (point[0] - 1, point[1])
        elif direction == 1:
            return (point[0], point[1] + 1)
        elif direction == 2:
            return (point[0] + 1, point[1])
        else:
            return (point[0], point[1] - 1)

    def _possible_directions(self, point, in_direction):
        return np.flatnonzero(self.env.rail.get_transitions(point[0], point[1], in_direction))

    def _vertex_idx_from_point(self, point):
        assert (point[0] >= 0) and (point[0] < self.vertex_idx.shape[0])
        assert (point[1] >= 0) and (point[1] < self.vertex_idx.shape[1])
        return self.vertex_idx[point[0], point[1]]

    def position_from_vertexid(self, vertexid: int):
        return self.vertexes[vertexid].point

    def _action_from_directions(self, in_direction, new_direction):
        if in_direction == new_direction:
            return RailEnvActions.MOVE_FORWARD

        if (in_direction + 1) % 4 == new_direction:
            return RailEnvActions.MOVE_RIGHT
        elif (in_direction - 1) % 4 == new_direction:
            return RailEnvActions.MOVE_LEFT
        else:
            # 180-degree turn (dead end): Flatland treats it as MOVE_FORWARD.
            return RailEnvActions.MOVE_FORWARD
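A hedged usage sketch for the graph above. The helper is hypothetical (not in this MR) and assumes only the `CellGraph`/`Vertex`/`Edge` structures just defined: it enumerates every (cell, heading) state reachable from a start state by following outgoing edges.

# Hypothetical helper: breadth-first search over CellGraph edges.
from collections import deque


def reachable_states(graph, start_vertex_idx, start_direction):
    seen = {(start_vertex_idx, start_direction)}
    queue = deque(seen)
    while queue:
        v_idx, direction = queue.popleft()
        for edge in graph.vertexes[v_idx].out[direction]:
            state = (edge.end_v, edge.end_direction)
            if state not in seen:
                seen.add(state)
                queue.append(state)
    return seen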
import traceback
from copy import deepcopy
from typing import Dict

from flatland.envs.rail_env import RailEnv, RailAgentStatus, RailEnvActions

from libs import cell_graph_rescheduling, cell_graph_partial_rescheduling, cell_graph_rescheduling_data
from libs.cell_graph import CellGraph
from libs.cell_graph_agent import CellGraphAgent
from libs.cell_graph_locker import CellGraphLocker


class CellGraphDispatcher:
    def __init__(self, env: RailEnv, sort_function=None):
        self.env = env
        self.graph = CellGraph(env)
        self.locker = CellGraphLocker(self.graph)

        max_steps = env._max_episode_steps
        self.controllers = [CellGraphAgent(agent, self.graph, self.locker, i, max_steps)
                            for i, agent in enumerate(env.agents)]
        self.action_dict = {}

        if sort_function is None:
            # Default priority: distance to target, with faster agents strongly preferred.
            sort_function = lambda idx: self.controllers[idx].dist_to_target[
                self.graph._vertex_idx_from_point(env.agents[idx].initial_position),
                env.agents[idx].initial_direction] \
                - 10000 * env.agents[idx].speed_data['speed']
        else:
            sort_function = sort_function(self)

        self.agents_order = sorted(range(len(env.agents)), key=sort_function)

        self.agent_locked_by_malfunction = [agent.malfunction_data['malfunction'] > 0
                                            for agent in env.agents]

        self.crashed = False
        self.blocked_agents = set()

    def step(self, step) -> Dict[int, RailEnvActions]:
        try:
            # Detect agents that started malfunctioning this step.
            has_new_malfunctions = False
            for i, agent in enumerate(self.env.agents):
                is_locked = agent.malfunction_data['malfunction']
                if agent.status == RailAgentStatus.ACTIVE:
                    if (not self.agent_locked_by_malfunction[i]) and is_locked:
                        has_new_malfunctions = True
                self.agent_locked_by_malfunction[i] = is_locked

            updated = set()
            full_recalc_needed = False
            # old_locker = None
            try:
                if has_new_malfunctions:
                    # print('new malfunction at step', step)
                    # old_locker = deepcopy(self.locker)
                    cached_ways, vertex_agent_order, agent_way_position, agent_position_duration = \
                        cell_graph_rescheduling_data.get_rescheduling_data(
                            self.env, step, self.controllers, self.graph, self.locker)

                    vertex_agent_order2 = deepcopy(vertex_agent_order)
                    agent_way_position2 = deepcopy(agent_way_position)
                    agent_position_duration2 = deepcopy(agent_position_duration)

                    new_way, full_recalc_needed = cell_graph_rescheduling.reschedule(
                        cached_ways, vertex_agent_order, agent_way_position, agent_position_duration,
                        self.env, step, self.controllers, self.graph, self.locker)

                    for i in self.agents_order:
                        if len(new_way[i]):
                            changed = cell_graph_rescheduling.recover_agent_way(
                                self.controllers[i], self.env.agents[i], self.graph, new_way[i])
                            if changed:
                                updated.add(i)
            # Rescheduling failed; fall back to a partial rescheduling.
            except Exception:
                print("-----------------Rescheduling Exception----------------")
                print("Step: ", step)
                # traceback.print_exc()
                print("-----------------Rescheduling Exception----------------")
                updated.clear()
                full_recalc_needed = False
                # if old_locker is not None:
                #     self.locker.data = old_locker.data
                self.partial_rescheduling(cached_ways, vertex_agent_order2, agent_way_position2,
                                          agent_position_duration2, step)
                self.limit_max_visited()

            for i in self.agents_order:
                try:
                    agent = self.env.agents[i]
                    # if agent.speed_data['position_fraction'] >= 1.0:
                    #     print('agent', i, 'blocked by another agent, fraction:', agent.speed_data['position_fraction'])

                    force_new_path = full_recalc_needed or self.crashed or i in updated
                    # force_new_path = full_recalc_needed or i in updated

                    # if (force_new_path and i in self.blocked_agents):
                    #     # self.action_dict.update({i: RailEnvActions.DO_NOTHING})
                    #     force_new_path = False
                    #     # continue

                    if i in self.blocked_agents:
                        force_new_path = True

                    # An agent mid-cell can't change its plan; let it keep rolling.
                    if agent.speed_data['position_fraction'] > 0.0 and not force_new_path:
                        self.action_dict.update({i: RailEnvActions.DO_NOTHING})
                        continue

                    # action = self.controllers[i].act(agent, step, force_new_path=has_new_malfunctions)
                    action = self.controllers[i].act(agent, step, force_new_path=force_new_path)
                    self.action_dict.update({i: action})
                # act crashed for one agent
                except Exception:
                    print("-----------------Agent step Exception----------------", i)
                    print("Step: ", step)
                    # traceback.print_exc()
                    print("-----------------Agent step Exception----------------")
                    self.action_dict.update({i: RailEnvActions.DO_NOTHING})
                    self.limit_max_visited()
                    # pass

            self.blocked_agents.clear()
            self.crashed = False
        # Global step exception handling: stop all agents and mark the step as crashed.
        except Exception:
            # except ArithmeticError:
            self.crashed = True
            print("-----------------Step Exception----------------")
            print("Step: ", step)
            traceback.print_exc()
            print("-----------------Step Exception----------------")
            # hit_problem = False
            # for j in self.agents_order:
            #     if j == i:
            #         hit_problem = True
            #     if hit_problem:
            #         self.action_dict.update({j: RailEnvActions.STOP_MOVING})
            self.action_dict = {i: RailEnvActions.STOP_MOVING for i in self.agents_order}
            self.limit_max_visited()
            # raise e

        return self.action_dict

    def partial_rescheduling(self, cached_ways, vertex_agent_order2, agent_way_position2,
                             agent_position_duration2, step):
        print('partial_rescheduling')
        try:
            new_way, blocked_agents = cell_graph_partial_rescheduling.partial_reschedule(
                cached_ways, vertex_agent_order2, agent_way_position2, agent_position_duration2,
                self.env, step, self.controllers, self.graph, self.locker)

            for i in self.agents_order:
                if len(new_way[i]):
                    cell_graph_rescheduling.recover_agent_way(
                        self.controllers[i], self.env.agents[i], self.graph, new_way[i])

            self.blocked_agents.update(blocked_agents)
            print('blocked agents', self.blocked_agents)
        except Exception:
            self.crashed = True
            print("-----------------Partial rescheduling Exception----------------")
            traceback.print_exc()
            print("-----------------Partial rescheduling Exception----------------")
            self.limit_max_visited()

    def limit_max_visited(self):
        for c in self.controllers:
            c.set_max_visited(100)
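A hedged sketch of how the dispatcher is driven; this loop is hypothetical (not in this MR) and assumes a flatland 2.x `RailEnv` whose `reset()` returns `(obs, info)` and whose `step()` returns `(obs, rewards, done, info)` with a `'__all__'` done flag.

# Hypothetical driver loop for CellGraphDispatcher.
def run_episode(env):
    obs, info = env.reset()
    dispatcher = CellGraphDispatcher(env)  # graph is built from the reset env
    done = {'__all__': False}
    step = 0
    while not done['__all__']:
        action_dict = dispatcher.step(step)
        obs, rewards, done, info = env.step(action_dict)
        step += 1
    return step  # episode length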
import numpy as np


class CellGraphLocker:
    """Per-vertex reservation table: for every cell, a time-sorted list of
    ((start, end), agent_idx) locks."""

    def __init__(self, graph):
        self.graph = graph
        self.data = []
        self.reset()

    def reset(self):
        vertexes = len(self.graph.vertexes)
        self.data = [[] for i in range(vertexes)]

    def lock(self, vertex_idx, agent_idx, duration):
        # assert not self.is_locked(vertex_idx, agent_idx, duration)
        if len(self.data[vertex_idx]) == 0:
            self.data[vertex_idx].append((duration, agent_idx))
            return

        # index = self.equal_or_greater_index(vertex_idx, duration[0])
        index = self.equal_or_greater_index_end(vertex_idx, duration[1])
        if index < len(self.data[vertex_idx]):
            curr_lock_info = self.data[vertex_idx][index]
            if (curr_lock_info[1] == agent_idx) and self._has_intersection(curr_lock_info[0], duration):
                # An agent may only re-lock an interval nested inside its own lock.
                assert (curr_lock_info[0][0] <= duration[0]) and (duration[1] <= curr_lock_info[0][1])
                return

            assert curr_lock_info[0][0] >= duration[1]
            self.data[vertex_idx].insert(index, (duration, agent_idx))
            # if (curr_lock_info[1] == agent_idx) and (curr_lock_info[0][1] == duration[1]) and (curr_lock_info[0][0] <= duration[0]):
            #     self.data[vertex_idx][index] = (duration, agent_idx)
            # else:
            #     self.data[vertex_idx].insert(index, (duration, agent_idx))
        else:
            self.data[vertex_idx].append((duration, agent_idx))

    def is_locked(self, vertex_idx, agent_idx, duration):
        if len(self.data[vertex_idx]) == 0:
            return False

        new_lock = (duration, agent_idx)
        left_lock = None
        right_lock = None

        index = self.equal_or_greater_index(vertex_idx, duration[0])
        if index < len(self.data[vertex_idx]):
            # Note: the original compared a start time to the whole duration
            # tuple (always False); comparing against duration[0] is intended.
            if self.data[vertex_idx][index][0][0] == duration[0]:
                return True
            right_lock = self.data[vertex_idx][index]
            if index > 0:
                left_lock = self.data[vertex_idx][index - 1]
        else:
            left_lock = self.data[vertex_idx][index - 1]

        return self._has_conflict(left_lock, new_lock) or self._has_conflict(new_lock, right_lock)

    def _has_conflict(self, left, right):
        if (left is None) or (right is None):
            return False
        d1 = left[0]
        d2 = right[0]
        # If the earlier lock belongs to a higher-indexed agent, it must also
        # clear the cell one step before the next agent enters, hence the +1.
        if left[1] > right[1]:
            d1 = (d1[0], d1[1] + 1)
        return self._has_intersection(d1, d2)

    def next_free_time(self, vertex_idx, agent_idx, duration):
        index = self.equal_or_greater_index_end(vertex_idx, duration[1])
        if index < len(self.data[vertex_idx]):
            if self._has_conflict((duration, agent_idx), self.data[vertex_idx][index]):
                return self.data[vertex_idx][index][0][1]
        if index > 0:
            if self._has_conflict(self.data[vertex_idx][index - 1], (duration, agent_idx)):
                return self.data[vertex_idx][index - 1][0][1]
        # print('already free')
        return duration[0]

    def unlock(self, vertex_idx, agent_idx, duration):
        assert len(self.data[vertex_idx])
        index = self.equal_or_greater_index(vertex_idx, duration[0])
        assert (index >= 0) and (index < len(self.data[vertex_idx]))
        lock_duration, lock_agent_idx = self.data[vertex_idx][index]
        assert (lock_duration == duration) and (lock_agent_idx == agent_idx)
        self.data[vertex_idx].pop(index)

    def equal_or_greater_index(self, vertex_idx, start_time):
        # First lock whose start is >= start_time (linear scan; a binary
        # search over the sorted list would also work).
        for i, (lock_duration, lock_agent_idx) in enumerate(self.data[vertex_idx]):
            if lock_duration[0] >= start_time:
                return i
        return len(self.data[vertex_idx])

    def equal_or_greater_index_end(self, vertex_idx, end_time):
        # First lock whose end is >= end_time.
        for i, (lock_duration, lock_agent_idx) in enumerate(self.data[vertex_idx]):
            if lock_duration[1] >= end_time:
                return i
        return len(self.data[vertex_idx])

    def _has_intersection(self, a, b):
        # Half-open intervals [start, end): touching endpoints don't conflict.
        return not ((a[1] <= b[0]) or (b[1] <= a[0]))

    def unlock_agent(self, agent_id):
        # Remove every lock held by agent_id across all vertices.
        for i in range(len(self.data)):
            for j in reversed(range(len(self.data[i]))):
                if self.data[i][j][1] == agent_id:
                    self.data[i].pop(j)

    def unlock_agent_with_list(self, agent_id, vertex_list):
        for i in vertex_list:
            for j in reversed(range(len(self.data[i]))):
                if self.data[i][j][1] == agent_id:
                    self.data[i].pop(j)

    def last_time_step(self, vertex_idx, agent_idx):
        if not len(self.data[vertex_idx]):
            return 0
        res = self.data[vertex_idx][-1][0][1]
        if self.data[vertex_idx][-1][1] != agent_idx:
            res += 1
        return res
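A hedged, self-contained illustration of the reservation semantics above. The stub graph is an assumption (the locker only needs `len(graph.vertexes)`); the interval behavior shown follows directly from `_has_intersection` and `_has_conflict`.

# Hypothetical demo (not in this MR) of CellGraphLocker's interval logic.
class _StubGraph:
    vertexes = [None] * 3  # pretend the rail has three cells


locker = CellGraphLocker(_StubGraph())
locker.lock(vertex_idx=0, agent_idx=7, duration=(10, 15))

# Half-open intervals: a touching interval of the same agent doesn't conflict...
assert not locker.is_locked(0, agent_idx=7, duration=(15, 20))
# ...while a competing agent is pushed past the end of the existing lock.
assert locker.next_free_time(0, agent_idx=3, duration=(12, 14)) == 15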
from typing import List

from flatland.envs.rail_env import RailEnv, RailAgentStatus, RailEnvActions
from flatland.envs.agent_utils import EnvAgent

import libs.cell_graph_agent
from libs.cell_graph import CellGraph
from libs.cell_graph_agent import AgentWayStep, CellGraphAgent
from libs.cell_graph_locker import CellGraphLocker


def partial_reschedule(cached_ways, vertex_agent_order, agent_way_position, agent_position_duration,
                       env: RailEnv, step_idx, controllers, graph: CellGraph, locker: CellGraphLocker):
    locker.reset()
    new_way = [[] for i in range(len(controllers))]

    def rescheduling_main():
        # Recalculate the new duration for each agent on each cell of its cached way.
        position_updated = True
        full_recalc_needed = False
        blocked_vertexes = set()

        for i, agent in enumerate(env.agents):
            if agent.status == RailAgentStatus.ACTIVE:
                controller = controllers[i]
                way = controller.get_cached_way()
                if len(way):
                    # assert len(way)
                    first_vertex = way[-1].vertex_idx
                    if vertex_agent_order[first_vertex][0] != i:
                        print('blocked at start', i, first_vertex)
                        blocked_vertexes.add(first_vertex)
                    # assert vertex_agent_order[first_vertex][0] == i
                    # assert way[0].arrival_time <= step_idx, (way[0].arrival_time, step_idx)

        while position_updated:
            position_updated = False

            for i in range(len(controllers)):
                agent = env.agents[i]
                if (agent_way_position[i] >= len(cached_ways[i])) or agent_done(env, i):
                    continue

                vertex_idx = cached_ways[i][agent_way_position[i]].vertex_idx
                duration = agent_position_duration[i]

                if vertex_idx in blocked_vertexes:
                    continue

                if agent_way_position[i] == len(cached_ways[i]) - 1:
                    if vertex_idx == controllers[i].target_vertex:  # target vertex
                        new_way[i].append(AgentWayStep(vertex_idx=vertex_idx,
                                                       direction=None,
                                                       arrival_time=duration[0],
                                                       departure_time=duration[1],
                                                       wait_steps=0,
                                                       action=None,
                                                       prev_way_id=-1))
                        locker.lock(vertex_idx, i, (duration[0], duration[1]))

                        # assert len(vertex_agent_order[vertex_idx]) and vertex_agent_order[vertex_idx][0] == i
                        if not (len(vertex_agent_order[vertex_idx]) and vertex_agent_order[vertex_idx][0] == i):
                            blocked_vertexes.add(vertex_idx)
                            continue

                        vertex_agent_order[vertex_idx].pop(0)
                        agent_position_duration[i] = None
                        agent_way_position[i] += 1
                        position_updated = True
                else:
                    next_vertex_idx = cached_ways[i][agent_way_position[i] + 1].vertex_idx
                    ticks_per_step = int(round(1 / env.agents[i].speed_data['speed']))

                    # if vertex_agent_order[next_vertex_idx][0] == i and vertex_agent_order[vertex_idx][0] == i:  # possible move to next vertex
                    if vertex_agent_order[next_vertex_idx][0] == i and next_vertex_idx not in blocked_vertexes:  # possible move to next vertex
                        new_duration = (duration[0], max(duration[1], locker.last_time_step(next_vertex_idx, i)))
                        # if agent_way_position[i] == 0 and agent.speed_data['position_fraction'] > 0: