Commit 507f0e86 authored by gmollard

added simple conflict detection

parent b9836597
from flatland.envs.rail_env import RailEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.generators import complex_rail_generator, random_rail_generator
from ray.rllib.utils.seed import seed as set_seed
import numpy as np
from flatland.envs.predictions import DummyPredictorForRailEnv

class RailEnvRLLibWrapper(MultiAgentEnv):

    def __init__(self, config):
        # width,
        # height,
        # rail_generator=random_rail_generator(),
        # number_of_agents=1,
        # obs_builder_object=TreeObsForRailEnv(max_depth=2)):
        super(MultiAgentEnv, self).__init__()

        if hasattr(config, "vector_index"):
            vector_index = config.vector_index
        else:
            vector_index = 1

        self.predefined_env = False

        if config['rail_generator'] == "complex_rail_generator":
            self.rail_generator = complex_rail_generator(nr_start_goal=config['number_of_agents'], min_dist=5,
                                                         nr_extra=config['nr_extra'],
                                                         seed=config['seed'] * (1 + vector_index))
        elif config['rail_generator'] == "random_rail_generator":
            self.rail_generator = random_rail_generator()
        elif config['rail_generator'] == "load_env":
            self.predefined_env = True
        else:
            raise ValueError(f'Unknown rail generator: {config["rail_generator"]}')

        set_seed(config['seed'] * (1 + vector_index))
        self.env = RailEnv(width=config["width"], height=config["height"],
                           number_of_agents=config["number_of_agents"],
                           obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator,
                           prediction_builder_object=DummyPredictorForRailEnv())
        # self.env.load('/home/guillaume/EPFL/Master_Thesis/flatland/baselines/torch_training/railway/complex_scene.pkl')

        if self.predefined_env:
            self.env.load(config['load_env_path'])
            # '/home/guillaume/EPFL/Master_Thesis/flatland/baselines/torch_training/railway/complex_scene.pkl')

        self.width = self.env.width
        self.height = self.env.height
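For reference, this is roughly the env_config dictionary that __init__ above expects; a hedged sketch with illustrative values, where only the keys are taken from the code (vector_index is supplied by RLlib's EnvContext when the environment is vectorised, not by this dict):

# Illustrative config for RailEnvRLLibWrapper; keys mirror the lookups in __init__ above,
# the values are placeholders and not part of the commit.
example_env_config = {
    "width": 20,
    "height": 20,
    "rail_generator": "complex_rail_generator",  # or "random_rail_generator" / "load_env"
    "nr_extra": 30,                              # only read for complex_rail_generator
    "number_of_agents": 5,
    "seed": 1,
    "obs_builder": TreeObsForRailEnv(max_depth=2),
    "load_env_path": "path/to/scene.pkl",        # only read when rail_generator == "load_env"
}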
@@ -42,19 +47,47 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
    def reset(self):
        self.agents_done = []
        if self.predefined_env:
            obs = self.env.reset(False, False)
        else:
            obs = self.env.reset()

        predictions = self.env.predict()
        # Stack the predicted positions (columns 1 and 2 of each agent's prediction)
        # into an array of shape (n_agents, n_timesteps, 2)
        pred_pos = np.concatenate([[x[:, 1:3]] for x in list(predictions.values())], axis=0)

        o = dict()
        # for agent, _ in obs.items():
        #     o[agent] = obs[agent]
        #     one_hot_agent_encoding = np.zeros(len(self.env.agents))
        #     one_hot_agent_encoding[agent] += 1
        #     o[agent] = np.append(obs[agent], one_hot_agent_encoding)
        # o['agents'] = obs
        # obs[0] = [obs[0], np.ones((17, 17)) * 17]
        # obs['global_obs'] = np.ones((17, 17)) * 17

        for i_agent in range(len(self.env.agents)):

            # Prediction of collisions, added to the observation.
            # Allows the agent to know which other trains it is about to meet (maybe this will
            # lead to a priority order of trains).
            pred_obs = np.zeros((len(predictions[0]), len(self.env.agents)))

            for time_offset in range(len(predictions[0])):

                # We consider a time window of [t-1, t+1] to find a collision
                collision_window = list(range(max(time_offset - 1, 0), min(time_offset + 2, len(predictions[0]))))

                # Encode (x, y) as a single integer; assumes grid coordinates stay below 1000
                coord_agent = pred_pos[i_agent, time_offset, 0] + 1000 * pred_pos[i_agent, time_offset, 1]

                # x coordinates of all other trains in the time window
                x_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent + 1, len(self.env.agents)))][
                                       :, collision_window, 0]

                # y coordinates of all other trains in the time window
                y_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent + 1, len(self.env.agents)))][
                                       :, collision_window, 1]

                coord_other_agents = x_coord_other_agents + 1000 * y_coord_other_agents

                # collision_info is the index (among the other agents) of an agent colliding with the current one
                for collision_info in np.argwhere(coord_agent == coord_other_agents)[:, 0]:
                    pred_obs[time_offset, collision_info + 1 * (collision_info >= i_agent)] = 1

            agent_id_one_hot = np.zeros(len(self.env.agents))
            agent_id_one_hot[i_agent] = 1
            o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs]

        self.rail = self.env.rail
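The collision-window computation above is duplicated almost verbatim in step() below; a minimal sketch of how it could be factored into a standalone helper (hypothetical function, not part of the commit, assuming pred_pos has shape (n_agents, n_timesteps, 2)):

import numpy as np

# Hypothetical helper, not part of the commit: builds the (n_timesteps, n_agents) conflict
# matrix for one agent from the stacked predicted positions.
def conflict_observation(pred_pos, i_agent, cell_hash=1000):
    n_agents, n_steps, _ = pred_pos.shape
    pred_obs = np.zeros((n_steps, n_agents))
    others = [a for a in range(n_agents) if a != i_agent]
    for t in range(n_steps):
        # time window [t-1, t+1], clipped to the prediction horizon
        window = list(range(max(t - 1, 0), min(t + 2, n_steps)))
        # encode (x, y) as a single integer; assumes coordinates stay below cell_hash
        own = pred_pos[i_agent, t, 0] + cell_hash * pred_pos[i_agent, t, 1]
        other = pred_pos[others][:, window, 0] + cell_hash * pred_pos[others][:, window, 1]
        # rows of the boolean match are the indices of colliding agents among `others`
        for hit in np.argwhere(own == other)[:, 0]:
            pred_obs[t, hit + 1 * (hit >= i_agent)] = 1
    return pred_obs

Both reset() and step() could then call conflict_observation(pred_pos, i_agent) instead of repeating the nested loops.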
@@ -72,16 +105,53 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
        o = dict()
        # print(self.agents_done)
        # print(dones)

        for i_agent in range(len(self.env.agents)):
            if i_agent not in self.agents_done:

                # Prediction of collisions, added to the observation.
                # Allows the agent to know which other trains it is about to meet (maybe this will
                # lead to a priority order of trains).
                pred_obs = np.zeros((len(predictions[0]), len(self.env.agents)))

                for time_offset in range(len(predictions[0])):

                    # We consider a time window of [t-1, t+1] to find a collision
                    collision_window = list(range(max(time_offset - 1, 0), min(time_offset + 2, len(predictions[0]))))

                    # Encode (x, y) as a single integer; assumes grid coordinates stay below 1000
                    coord_agent = pred_pos[i_agent, time_offset, 0] + 1000 * pred_pos[i_agent, time_offset, 1]

                    # x coordinates of all other trains in the time window
                    x_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent + 1, len(self.env.agents)))][
                                           :, collision_window, 0]

                    # y coordinates of all other trains in the time window
                    y_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent + 1, len(self.env.agents)))][
                                           :, collision_window, 1]

                    coord_other_agents = x_coord_other_agents + 1000 * y_coord_other_agents

                    # collision_info is the index (among the other agents) of an agent colliding with the current one
                    for collision_info in np.argwhere(coord_agent == coord_other_agents)[:, 0]:
                        pred_obs[time_offset, collision_info + 1 * (collision_info >= i_agent)] = 1

                agent_id_one_hot = np.zeros(len(self.env.agents))
                agent_id_one_hot[i_agent] = 1
                o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs]
                r[i_agent] = rewards[i_agent]
                d[i_agent] = dones[i_agent]

        d['__all__'] = dones['__all__']

        # for agent, done in dones.items():
        #     if agent not in self.agents_done:
        #         if agent != '__all__':
        #             o[agent] = obs[agent]  # np.append(obs[agent], one_hot_agent_encoding)
        #             # one_hot_agent_encoding = np.zeros(len(self.env.agents))
        #             # one_hot_agent_encoding[agent] += 1
        #
        #             d[agent] = dones[agent]

        for agent, done in dones.items():
            if done and agent != '__all__':
......
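For context, a wrapper like this is typically made available to RLlib through its environment registry; a hedged usage sketch, where register_env is Ray API but the environment name string is illustrative:

from ray.tune.registry import register_env

# Register the wrapper so RLlib trainers can refer to it by name; the name is a placeholder.
register_env("flatland_rllib", lambda env_config: RailEnvRLLibWrapper(env_config))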