diff --git a/RLLib_training/RailEnvRLLibWrapper.py b/RLLib_training/RailEnvRLLibWrapper.py index 704bb126a9f7a7f5c02d1596c822af16ce0be560..f9643a83e6b56537f0c5131c5cdc7ac20d278236 100644 --- a/RLLib_training/RailEnvRLLibWrapper.py +++ b/RLLib_training/RailEnvRLLibWrapper.py @@ -1,39 +1,44 @@ from flatland.envs.rail_env import RailEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv from flatland.envs.observations import TreeObsForRailEnv -from flatland.envs.generators import random_rail_generator from ray.rllib.utils.seed import seed as set_seed from flatland.envs.generators import complex_rail_generator, random_rail_generator import numpy as np +from flatland.envs.predictions import DummyPredictorForRailEnv class RailEnvRLLibWrapper(MultiAgentEnv): def __init__(self, config): - # width, - # height, - # rail_generator=random_rail_generator(), - # number_of_agents=1, - # obs_builder_object=TreeObsForRailEnv(max_depth=2)): + super(MultiAgentEnv, self).__init__() if hasattr(config, "vector_index"): vector_index = config.vector_index else: vector_index = 1 + self.predefined_env = False + if config['rail_generator'] == "complex_rail_generator": self.rail_generator = complex_rail_generator(nr_start_goal=config['number_of_agents'], min_dist=5, nr_extra=config['nr_extra'], seed=config['seed'] * (1+vector_index)) - else: - raise(Error) + elif config['rail_generator'] == "random_rail_generator": self.rail_generator = random_rail_generator() + elif config['rail_generator'] == "load_env": + self.predefined_env = True + + else: + raise(ValueError, f'Unknown rail generator: {config["rail_generator"]}') set_seed(config['seed'] * (1+vector_index)) self.env = RailEnv(width=config["width"], height=config["height"], number_of_agents=config["number_of_agents"], - obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator) + obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator, + prediction_builder_object=DummyPredictorForRailEnv()) - # self.env.load('/home/guillaume/EPFL/Master_Thesis/flatland/baselines/torch_training/railway/complex_scene.pkl') + if self.predefined_env: + self.env.load(config['load_env_path']) + # '/home/guillaume/EPFL/Master_Thesis/flatland/baselines/torch_training/railway/complex_scene.pkl') self.width = self.env.width self.height = self.env.height @@ -42,19 +47,47 @@ class RailEnvRLLibWrapper(MultiAgentEnv): def reset(self): self.agents_done = [] - obs = self.env.reset() + if self.predefined_env: + obs = self.env.reset(False, False) + else: + obs = self.env.reset() + + predictions = self.env.predict() + pred_pos = np.concatenate([[x[:, 1:3]] for x in list(predictions.values())], axis=0) + o = dict() - - - #for agent, _ in obs.items(): - #o[agent] = obs[agent] - # one_hot_agent_encoding = np.zeros(len(self.env.agents)) - # one_hot_agent_encoding[agent] += 1 - # o[agent] = np.append(obs[agent], one_hot_agent_encoding) - - # o['agents'] = obs - # obs[0] = [obs[0], np.ones((17, 17)) * 17] - # obs['global_obs'] = np.ones((17, 17)) * 17 + + for i_agent in range(len(self.env.agents)): + + # prediction of collision that will be added to the observation + # Allows to the agent to know which other train is is about to meet (maybe will come + # up with a priority order of trains). + pred_obs = np.zeros((len(predictions[0]), len(self.env.agents))) + + for time_offset in range(len(predictions[0])): + + # We consider a time window of t-1; t+1 to find a collision + collision_window = list(range(max(time_offset - 1, 0), min(time_offset + 2, len(predictions[0])))) + + coord_agent = pred_pos[i_agent, time_offset, 0] + 1000*pred_pos[i_agent, time_offset, 1] + + # x coordinates of all other train in the time window + x_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent+1, len(self.env.agents)))][ + :, collision_window, 0] + + # y coordinates of all other train in the time window + y_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent + 1, len(self.env.agents)))][ + :, collision_window, 1] + + coord_other_agents = x_coord_other_agents + 1000*y_coord_other_agents + + # collision_info here contains the index of the agent colliding with the current agent + for collision_info in np.argwhere(coord_agent == coord_other_agents)[:, 0]: + pred_obs[time_offset, collision_info + 1*(collision_info >= i_agent)] = 1 + + agent_id_one_hot = np.zeros(len(self.env.agents)) + agent_id_one_hot[i_agent] = 1 + o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs] self.rail = self.env.rail @@ -72,16 +105,53 @@ class RailEnvRLLibWrapper(MultiAgentEnv): o = dict() # print(self.agents_done) # print(dones) - for agent, done in dones.items(): - if agent not in self.agents_done: - if agent != '__all__': - # o[agent] = obs[agent] - #one_hot_agent_encoding = np.zeros(len(self.env.agents)) - #one_hot_agent_encoding[agent] += 1 - o[agent] = obs[agent]#np.append(obs[agent], one_hot_agent_encoding) - r[agent] = rewards[agent] - - d[agent] = dones[agent] + + for i_agent in range(len(self.env.agents)): + if i_agent not in self.agents_done: + # prediction of collision that will be added to the observation + # Allows to the agent to know which other train is is about to meet (maybe will come + # up with a priority order of trains). + pred_obs = np.zeros((len(predictions[0]), len(self.env.agents))) + + for time_offset in range(len(predictions[0])): + + # We consider a time window of t-1; t+1 to find a collision + collision_window = list(range(max(time_offset - 1, 0), min(time_offset + 2, len(predictions[0])))) + + coord_agent = pred_pos[i_agent, time_offset, 0] + 1000*pred_pos[i_agent, time_offset, 1] + + # x coordinates of all other train in the time window + x_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent+1, len(self.env.agents)))][ + :, collision_window, 0] + + # y coordinates of all other train in the time window + y_coord_other_agents = pred_pos[list(range(i_agent)) + list(range(i_agent + 1, len(self.env.agents)))][ + :, collision_window, 1] + + coord_other_agents = x_coord_other_agents + 1000*y_coord_other_agents + + # collision_info here contains the index of the agent colliding with the current agent + for collision_info in np.argwhere(coord_agent == coord_other_agents)[:, 0]: + pred_obs[time_offset, collision_info + 1*(collision_info >= i_agent)] = 1 + + agent_id_one_hot = np.zeros(len(self.env.agents)) + agent_id_one_hot[i_agent] = 1 + o[i_agent] = [obs[i_agent], agent_id_one_hot, pred_obs] + r[i_agent] = rewards[i_agent] + d[i_agent] = dones[i_agent] + + d['__all__'] = dones['__all__'] + + # for agent, done in dones.items(): + # if agent not in self.agents_done: + # if agent != '__all__': + # # o[agent] = obs[agent] + # #one_hot_agent_encoding = np.zeros(len(self.env.agents)) + # #one_hot_agent_encoding[agent] += 1 + # o[agent] = obs[agent]#np.append(obs[agent], one_hot_agent_encoding) + # + # + # d[agent] = dones[agent] for agent, done in dones.items(): if done and agent != '__all__':