Skip to content
Snippets Groups Projects
Commit 507f0e86 authored by gmollard's avatar gmollard
Browse files

added simple conflict detection

parent b9836597
No related branches found
No related tags found
No related merge requests found
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.generators import random_rail_generator
from ray.rllib.utils.seed import seed as set_seed from ray.rllib.utils.seed import seed as set_seed
from flatland.envs.generators import complex_rail_generator, random_rail_generator from flatland.envs.generators import complex_rail_generator, random_rail_generator
import numpy as np import numpy as np
from flatland.envs.predictions import DummyPredictorForRailEnv
class RailEnvRLLibWrapper(MultiAgentEnv): class RailEnvRLLibWrapper(MultiAgentEnv):
def __init__(self, config): def __init__(self, config):
# width,
# height,
# rail_generator=random_rail_generator(),
# number_of_agents=1,
# obs_builder_object=TreeObsForRailEnv(max_depth=2)):
super(MultiAgentEnv, self).__init__() super(MultiAgentEnv, self).__init__()
if hasattr(config, "vector_index"): if hasattr(config, "vector_index"):
vector_index = config.vector_index vector_index = config.vector_index
else: else:
vector_index = 1 vector_index = 1
self.predefined_env = False
if config['rail_generator'] == "complex_rail_generator": if config['rail_generator'] == "complex_rail_generator":
self.rail_generator = complex_rail_generator(nr_start_goal=config['number_of_agents'], min_dist=5, self.rail_generator = complex_rail_generator(nr_start_goal=config['number_of_agents'], min_dist=5,
nr_extra=config['nr_extra'], seed=config['seed'] * (1+vector_index)) nr_extra=config['nr_extra'], seed=config['seed'] * (1+vector_index))
else: elif config['rail_generator'] == "random_rail_generator":
raise(Error)
self.rail_generator = random_rail_generator() self.rail_generator = random_rail_generator()
elif config['rail_generator'] == "load_env":
self.predefined_env = True
else:
raise(ValueError, f'Unknown rail generator: {config["rail_generator"]}')
set_seed(config['seed'] * (1+vector_index)) set_seed(config['seed'] * (1+vector_index))
self.env = RailEnv(width=config["width"], height=config["height"], self.env = RailEnv(width=config["width"], height=config["height"],
number_of_agents=config["number_of_agents"], number_of_agents=config["number_of_agents"],
obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator) obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator,
prediction_builder_object=DummyPredictorForRailEnv())
# self.env.load('/home/guillaume/EPFL/Master_Thesis/flatland/baselines/torch_training/railway/complex_scene.pkl') if self.predefined_env:
self.env.load(config['load_env_path'])
# '/home/guillaume/EPFL/Master_Thesis/flatland/baselines/torch_training/railway/complex_scene.pkl')
self.width = self.env.width self.width = self.env.width
self.height = self.env.height self.height = self.env.height
...@@ -42,19 +47,47 @@ class RailEnvRLLibWrapper(MultiAgentEnv): ...@@ -42,19 +47,47 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
def reset(self): def reset(self):
self.agents_done = [] self.agents_done = []
obs = self.env.reset() if self.predefined_env:
obs = self.env.reset(False, False)
else:
obs = self.env.reset()
predictions = self.env.predict()
pred_pos = np.concatenate([[x[:, 1:3]] for x in list(predictions.values())], axis=0)
o = dict() o = dict()
for i_agent in range(len(self.env.agents)):
#for agent, _ in obs.items():
#o[agent] = obs[agent] # prediction of collision that will be added to the observation
# one_hot_agent_encoding = np.zeros(len(self.env.agents)) # Allows the agent to know which other train it is about to meet (maybe it will come
# one_hot_agent_encoding[agent] += 1 # up with a priority order of trains).
# o[agent] = np.append(obs[agent], one_hot_agent_encoding) pred_obs = np.zeros((len(predictions[0]), len(self.env.agents)))
# o['agents'] = obs for time_offset in range(len(predictions[0])):
# obs[0] = [obs[0], np.ones((17, 17)) * 17]
# obs['global_obs'] = np.ones((17, 17)) * 17 # We consider a time window of [t-1, t+1] (clipped to the prediction horizon) to find a collision
collision_window = list(range(max(time_offset - 1, 0), min(time_offset + 2, len(predictions[0]))))
coord_agent = pred_pos[i_agent, time_offset, 0] + 1000*pred_pos[i_agent, time_offset, 1]
# x coordinates of all other trains in the time window
x_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent+1, len(self.env.agents)))][
:, collision_window, 0]
# y coordinates of all other trains in the time window
y_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent + 1, len(self.env.agents)))][
:, collision_window, 1]
coord_other_agents = x_coord_other_agents + 1000*y_coord_other_agents
# collision_info here is the index of the colliding agent within the "other agents" array
# (current agent removed), so it is shifted by one below when it is >= i_agent
for collision_info in np.argwhere(coord_agent == coord_other_agents)[:, 0]:
pred_obs[time_offset, collision_info + 1*(collision_info >= i_agent)] = 1
agent_id_one_hot = np.zeros(len(self.env.agents))
agent_id_one_hot[i_agent] = 1
o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs]
self.rail = self.env.rail self.rail = self.env.rail
...@@ -72,16 +105,53 @@ class RailEnvRLLibWrapper(MultiAgentEnv): ...@@ -72,16 +105,53 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
o = dict() o = dict()
# print(self.agents_done) # print(self.agents_done)
# print(dones) # print(dones)
for agent, done in dones.items():
if agent not in self.agents_done: for i_agent in range(len(self.env.agents)):
if agent != '__all__': if i_agent not in self.agents_done:
# o[agent] = obs[agent] # prediction of collision that will be added to the observation
#one_hot_agent_encoding = np.zeros(len(self.env.agents)) # Allows the agent to know which other train it is about to meet (maybe it will come
#one_hot_agent_encoding[agent] += 1 # up with a priority order of trains).
o[agent] = obs[agent]#np.append(obs[agent], one_hot_agent_encoding) pred_obs = np.zeros((len(predictions[0]), len(self.env.agents)))
r[agent] = rewards[agent]
for time_offset in range(len(predictions[0])):
d[agent] = dones[agent]
# We consider a time window of [t-1, t+1] (clipped to the prediction horizon) to find a collision
collision_window = list(range(max(time_offset - 1, 0), min(time_offset + 2, len(predictions[0]))))
coord_agent = pred_pos[i_agent, time_offset, 0] + 1000*pred_pos[i_agent, time_offset, 1]
# x coordinates of all other trains in the time window
x_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent+1, len(self.env.agents)))][
:, collision_window, 0]
# y coordinates of all other trains in the time window
y_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent + 1, len(self.env.agents)))][
:, collision_window, 1]
coord_other_agents = x_coord_other_agents + 1000*y_coord_other_agents
# collision_info here is the index of the colliding agent within the "other agents" array
# (current agent removed), so it is shifted by one below when it is >= i_agent
for collision_info in np.argwhere(coord_agent == coord_other_agents)[:, 0]:
pred_obs[time_offset, collision_info + 1*(collision_info >= i_agent)] = 1
agent_id_one_hot = np.zeros(len(self.env.agents))
agent_id_one_hot[i_agent] = 1
o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs]
r[i_agent] = rewards[i_agent]
d[i_agent] = dones[i_agent]
d['__all__'] = dones['__all__']
# for agent, done in dones.items():
# if agent not in self.agents_done:
# if agent != '__all__':
# # o[agent] = obs[agent]
# #one_hot_agent_encoding = np.zeros(len(self.env.agents))
# #one_hot_agent_encoding[agent] += 1
# o[agent] = obs[agent]#np.append(obs[agent], one_hot_agent_encoding)
#
#
# d[agent] = dones[agent]
for agent, done in dones.items(): for agent, done in dones.items():
if done and agent != '__all__': if done and agent != '__all__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment