diff --git a/MANIFEST.in b/MANIFEST.in index 15ae22aa6240e9fc05354c6be3fe120383a8f91c..11453ed84ca71670f795228796a2058964549b8c 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,7 +4,6 @@ include changelog.md include LICENSE include README.md include requirements_torch_training.txt -include requirements_RLLib_training.txt diff --git a/README.md b/README.md index cdd54b5229cd88c94618782bf242595b453413f8..d8dc09bde2ace60244a901506eb25d1790acb74a 100644 --- a/README.md +++ b/README.md @@ -15,9 +15,6 @@ With the above introductions you will solve tasks like these and even more...  -# RLLib Training -The `RLLib_training` folder shows an example of how to train agents with algorithm from implemented in the RLLib library available at: <https://github.com/ray-project/ray/tree/master/python/ray/rllib> - # Sequential Agent This is a very simple baseline to show you have the `complex_level_generator` generates feasible network configurations. If you run the `run_test.py` file you will see a simple agent that solves the level by sequentially running each agent along its shortest path. diff --git a/RLLib_training/README.md b/RLLib_training/README.md deleted file mode 100644 index 8bda956f226af1c7ef4c7e1237b447cf7af4327a..0000000000000000000000000000000000000000 --- a/RLLib_training/README.md +++ /dev/null @@ -1,78 +0,0 @@ -This repository allows to run Rail Environment multi agent training with the RLLib Library. - -## Installation: - -To run scripts of this repository, the deep learning library tensorflow should be installed, along with the following packages: -```sh -pip install gym ray==0.7.0 gin-config opencv-python lz4 psutil -``` - -To start a training with different parameters, you can create a folder containing a config.gin file (see example in `experiment_configs/config_example/config.gin`. - -Then, you can modify the config.gin file path at the end of the `train_experiment.py` file. - -The results will be stored inside the folder, and the learning curves can be visualized in -tensorboard: - -``` -tensorboard --logdir=/path/to/folder_containing_config_gin_file -``` - -## Gin config files - -In each config.gin files, all the parameters of the `run_experiment` functions have to be specified. -For example, to indicate the number of agents that have to be initialized at the beginning of each simulation, the following line should be added: - -``` -run_experiment.n_agents = 2 -``` - -If several number of agents have to be explored during the experiment, one can pass the following value to the `n_agents` parameter: - -``` -run_experiment.n_agents = {"grid_search": [2,5]} -``` - -which is the way to indicate to the tune library to experiment several values for a parameter. - -To reference a class or an object within gin, you should first register it from the `train_experiment.py` script adding the following line: - -``` -gin.external_configurable(TreeObsForRailEnv) -``` - -and then a `TreeObsForRailEnv` object can be referenced in the `config.gin` file: - -``` -run_experiment.obs_builder = {"grid_search": [@TreeObsForRailEnv(), @GlobalObsForRailEnv()]} -TreeObsForRailEnv.max_depth = 2 -``` - -Note that `@TreeObsForRailEnv` references the class, while `@TreeObsForRailEnv()` references instantiates an object of this class. 
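Putting the pieces above together, here is a minimal, self-contained sketch of the gin pattern this removed README describes — `TreeObs` and the inline config string are illustrative stand-ins, not code from this repository:

```
import gin  # pip install gin-config


class TreeObs:  # stand-in for TreeObsForRailEnv
    def __init__(self, max_depth=1):
        self.max_depth = max_depth


# Register the class so a config file can reference it as @TreeObs / @TreeObs()
gin.external_configurable(TreeObs)


@gin.configurable
def run_experiment(n_agents, obs_builder):
    print(f"{n_agents} agents, tree depth {obs_builder.max_depth}")


# Equivalent of a config.gin file: @TreeObs() instantiates an object,
# while @TreeObs (no parentheses) would pass the class itself.
gin.parse_config("""
run_experiment.n_agents = 2
run_experiment.obs_builder = @TreeObs()
TreeObs.max_depth = 2
""")

run_experiment()  # prints: 2 agents, tree depth 2
```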
- - - - -More documentation on how to use gin-config can be found on the github repository: https://github.com/google/gin-config - -## Run an example: -To start a training on a 20X20 map, with different numbers of agents initialized at each episode, on can run the train_experiment.py script: -``` -python RLLib_training/train_experiment.py -``` -This will load the gin config file in the folder `experiment_configs/config_examples`. - -To visualize the result of a training, one can load a training checkpoint and use the policy learned. -This is done in the `render_training_result.py` script. One has to modify the `CHECKPOINT_PATH` at the beginning of this script: - -``` -CHECKPOINT_PATH = os.path.join(__file_dirname__, 'experiment_configs', 'config_example', 'ppo_policy_two_obs_with_predictions_n_agents_4_map_size_20q58l5_f7', - 'checkpoint_101', 'checkpoint-101') -``` -and load the corresponding gin config file: - -``` -gin.parse_config_file(os.path.join(__file_dirname__, 'experiment_configs', 'config_example', 'config.gin')) -``` - - diff --git a/RLLib_training/RailEnvRLLibWrapper.py b/RLLib_training/RailEnvRLLibWrapper.py deleted file mode 100644 index f82cd42d9bbd836b681ff284a82f357b2760bb0c..0000000000000000000000000000000000000000 --- a/RLLib_training/RailEnvRLLibWrapper.py +++ /dev/null @@ -1,135 +0,0 @@ -import numpy as np -from ray.rllib.env.multi_agent_env import MultiAgentEnv -from ray.rllib.utils.seed import seed as set_seed - -from flatland.envs.rail_env import RailEnv -from flatland.envs.rail_generators import complex_rail_generator, random_rail_generator -from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator - - -class RailEnvRLLibWrapper(MultiAgentEnv): - - def __init__(self, config): - - super(MultiAgentEnv, self).__init__() - - # Environment ID if num_envs_per_worker > 1 - if hasattr(config, "vector_index"): - vector_index = config.vector_index - else: - vector_index = 1 - - self.predefined_env = False - - if config['rail_generator'] == "complex_rail_generator": - self.rail_generator = complex_rail_generator(nr_start_goal=config['number_of_agents'], - min_dist=config['min_dist'], - nr_extra=config['nr_extra'], - seed=config['seed'] * (1 + vector_index)) - self.schedule_generator = complex_schedule_generator() - - elif config['rail_generator'] == "random_rail_generator": - self.rail_generator = random_rail_generator() - self.schedule_generator = random_schedule_generator() - elif config['rail_generator'] == "load_env": - self.predefined_env = True - self.rail_generator = random_rail_generator() - self.schedule_generator = random_schedule_generator() - else: - raise (ValueError, f'Unknown rail generator: {config["rail_generator"]}') - - set_seed(config['seed'] * (1 + vector_index)) - self.env = RailEnv(width=config["width"], height=config["height"], - number_of_agents=config["number_of_agents"], - obs_builder_object=config['obs_builder'], - rail_generator=self.rail_generator, - schedule_generator=self.schedule_generator - ) - - if self.predefined_env: - self.env.load_resource('torch_training.railway', 'complex_scene.pkl') - - self.width = self.env.width - self.height = self.env.height - self.step_memory = config["step_memory"] - - # needed for the renderer - self.rail = self.env.rail - self.agents = self.env.agents - self.agents_static = self.env.agents_static - self.dev_obs_dict = self.env.dev_obs_dict - - def reset(self): - self.agents_done = [] - if self.predefined_env: - obs = self.env.reset(False, False) - else: - obs = 
self.env.reset() - - # RLLib only receives observation of agents that are not done. - o = dict() - - for i_agent in range(len(self.env.agents)): - data, distance, agent_data = self.env.obs_builder.split_tree(tree=np.array(obs[i_agent]), - current_depth=0) - o[i_agent] = [data, distance, agent_data] - - # needed for the renderer - self.rail = self.env.rail - self.agents = self.env.agents - self.agents_static = self.env.agents_static - self.dev_obs_dict = self.env.dev_obs_dict - - # If step_memory > 1, we need to concatenate it the observations in memory, only works for - # step_memory = 1 or 2 for the moment - if self.step_memory < 2: - return o - else: - self.old_obs = o - oo = dict() - - for i_agent in range(len(self.env.agents)): - oo[i_agent] = [o[i_agent], o[i_agent]] - return oo - - def step(self, action_dict): - obs, rewards, dones, infos = self.env.step(action_dict) - - d = dict() - r = dict() - o = dict() - - for i_agent in range(len(self.env.agents)): - if i_agent not in self.agents_done: - data, distance, agent_data = self.env.obs_builder.split_tree(tree=np.array(obs[i_agent]), - current_depth=0) - - o[i_agent] = [data, distance, agent_data] - r[i_agent] = rewards[i_agent] - d[i_agent] = dones[i_agent] - - d['__all__'] = dones['__all__'] - - if self.step_memory >= 2: - oo = dict() - - for i_agent in range(len(self.env.agents)): - if i_agent not in self.agents_done: - oo[i_agent] = [o[i_agent], self.old_obs[i_agent]] - - self.old_obs = o - - for agent, done in dones.items(): - if done and agent != '__all__': - self.agents_done.append(agent) - - if self.step_memory < 2: - return o, r, d, infos - else: - return oo, r, d, infos - - def get_agent_handles(self): - return self.env.get_agent_handles() - - def get_num_agents(self): - return self.env.get_num_agents() diff --git a/RLLib_training/__init__.py b/RLLib_training/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/RLLib_training/custom_preprocessors.py b/RLLib_training/custom_preprocessors.py deleted file mode 100644 index d4c81a83f1c05317315a3f71f99565006e9311e1..0000000000000000000000000000000000000000 --- a/RLLib_training/custom_preprocessors.py +++ /dev/null @@ -1,26 +0,0 @@ -import numpy as np -from ray.rllib.models.preprocessors import Preprocessor -from utils.observation_utils import norm_obs_clip - -class TreeObsPreprocessor(Preprocessor): - def _init_shape(self, obs_space, options): - print(options) - self.step_memory = options["custom_options"]["step_memory"] - return sum([space.shape[0] for space in obs_space]), - - def transform(self, observation): - - if self.step_memory == 2: - data = norm_obs_clip(observation[0][0]) - distance = norm_obs_clip(observation[0][1]) - agent_data = np.clip(observation[0][2], -1, 1) - data2 = norm_obs_clip(observation[1][0]) - distance2 = norm_obs_clip(observation[1][1]) - agent_data2 = np.clip(observation[1][2], -1, 1) - else: - data = norm_obs_clip(observation[0]) - distance = norm_obs_clip(observation[1]) - agent_data = np.clip(observation[2], -1, 1) - - return np.concatenate((np.concatenate((np.concatenate((data, distance)), agent_data)), np.concatenate((np.concatenate((data2, distance2)), agent_data2)))) - diff --git a/RLLib_training/experiment_configs/__init__.py b/RLLib_training/experiment_configs/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/RLLib_training/experiment_configs/config_example/config.gin 
b/RLLib_training/experiment_configs/config_example/config.gin deleted file mode 100644 index 59d2dfb508f13cccf4b9152f24ab06d44c290450..0000000000000000000000000000000000000000 --- a/RLLib_training/experiment_configs/config_example/config.gin +++ /dev/null @@ -1,25 +0,0 @@ -run_experiment.name = "experiment_example" -run_experiment.num_iterations = 1002 -run_experiment.save_every = 100 -run_experiment.hidden_sizes = [32, 32] - -run_experiment.map_width = 20 -run_experiment.map_height = 20 -run_experiment.n_agents = {"grid_search": [3, 4, 5, 6, 7, 8]} -run_experiment.rail_generator = "complex_rail_generator" # Change this to "load_env" in order to load a predefined complex scene -run_experiment.nr_extra = 5 -run_experiment.policy_folder_name = "ppo_policy_two_obs_with_predictions_n_agents_{config[n_agents]}_" - -run_experiment.seed = 123 - -run_experiment.conv_model = False - -run_experiment.obs_builder = @TreeObsForRailEnv() -TreeObsForRailEnv.predictor = @ShortestPathPredictorForRailEnv() -TreeObsForRailEnv.max_depth = 2 - -run_experiment.entropy_coeff = 0.001 -run_experiment.kl_coeff = 0.2 -run_experiment.lambda_gae = 0.9 -run_experiment.step_memory = 2 -run_experiment.min_dist = 10 diff --git a/RLLib_training/render_training_result.py b/RLLib_training/render_training_result.py deleted file mode 100644 index 1ee7cc1ce394f3b40791706871aa180ec0510b52..0000000000000000000000000000000000000000 --- a/RLLib_training/render_training_result.py +++ /dev/null @@ -1,169 +0,0 @@ -from RailEnvRLLibWrapper import RailEnvRLLibWrapper -from custom_preprocessors import TreeObsPreprocessor -import gym -import os - -from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG -from ray.rllib.agents.ppo.ppo import PPOTrainer as Trainer -from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph - -from ray.rllib.models import ModelCatalog - -import ray -import numpy as np - -import gin - -from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv -gin.external_configurable(DummyPredictorForRailEnv) -gin.external_configurable(ShortestPathPredictorForRailEnv) - -from ray.rllib.utils.seed import seed as set_seed -from flatland.envs.observations import TreeObsForRailEnv - -from flatland.utils.rendertools import RenderTool -import time - -gin.external_configurable(TreeObsForRailEnv) - -ModelCatalog.register_custom_preprocessor("tree_obs_prep", TreeObsPreprocessor) -ray.init() # object_store_memory=150000000000, redis_max_memory=30000000000) - -__file_dirname__ = os.path.dirname(os.path.realpath(__file__)) - -CHECKPOINT_PATH = os.path.join(__file_dirname__, 'experiment_configs', 'config_example', 'ppo_policy_two_obs_with_predictions_n_agents_4_map_size_20q58l5_f7', - 'checkpoint_101', 'checkpoint-101') # To Modify -N_EPISODES = 10 -N_STEPS_PER_EPISODE = 50 - - -def render_training_result(config): - print('Init Env') - - set_seed(config['seed'], config['seed'], config['seed']) - - # Example configuration to generate a random rail - env_config = {"width": config['map_width'], - "height": config['map_height'], - "rail_generator": config["rail_generator"], - "nr_extra": config["nr_extra"], - "number_of_agents": config['n_agents'], - "seed": config['seed'], - "obs_builder": config['obs_builder'], - "min_dist": config['min_dist'], - "step_memory": config["step_memory"]} - - # Observation space and action space definitions - if isinstance(config["obs_builder"], TreeObsForRailEnv): - obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), 
shape=(168,)),) * 2) - preprocessor = TreeObsPreprocessor - - else: - raise ValueError("Undefined observation space") - - act_space = gym.spaces.Discrete(5) - - # Dict with the different policies to train - policy_graphs = { - "ppo_policy": (PolicyGraph, obs_space, act_space, {}) - } - - def policy_mapping_fn(agent_id): - return "ppo_policy" - - # Trainer configuration - trainer_config = DEFAULT_CONFIG.copy() - - trainer_config['model'] = {"fcnet_hiddens": config['hidden_sizes']} - - trainer_config['multiagent'] = {"policy_graphs": policy_graphs, - "policy_mapping_fn": policy_mapping_fn, - "policies_to_train": list(policy_graphs.keys())} - - trainer_config["num_workers"] = 0 - trainer_config["num_cpus_per_worker"] = 4 - trainer_config["num_gpus"] = 0.2 - trainer_config["num_gpus_per_worker"] = 0.2 - trainer_config["num_cpus_for_driver"] = 1 - trainer_config["num_envs_per_worker"] = 1 - trainer_config['entropy_coeff'] = config['entropy_coeff'] - trainer_config["env_config"] = env_config - trainer_config["batch_mode"] = "complete_episodes" - trainer_config['simple_optimizer'] = False - trainer_config['postprocess_inputs'] = True - trainer_config['log_level'] = 'WARN' - trainer_config['num_sgd_iter'] = 10 - trainer_config['clip_param'] = 0.2 - trainer_config['kl_coeff'] = config['kl_coeff'] - trainer_config['lambda'] = config['lambda_gae'] - - env = RailEnvRLLibWrapper(env_config) - - trainer = Trainer(env=RailEnvRLLibWrapper, config=trainer_config) - - trainer.restore(CHECKPOINT_PATH) - - policy = trainer.get_policy("ppo_policy") - - preprocessor = preprocessor(obs_space, {"step_memory": config["step_memory"]}) - env_renderer = RenderTool(env, gl="PILSVG") - for episode in range(N_EPISODES): - - observation = env.reset() - for i in range(N_STEPS_PER_EPISODE): - preprocessed_obs = [] - for obs in observation.values(): - preprocessed_obs.append(preprocessor.transform(obs)) - action, _, infos = policy.compute_actions(preprocessed_obs, []) - logits = infos['behaviour_logits'] - actions = dict() - - # We select the greedy action. - for j, logit in enumerate(logits): - actions[j] = np.argmax(logit) - - # In case we prefer to sample an action stochastically according to the policy graph. 
- # for j, act in enumerate(action): - # actions[j] = act - - # Time to see the rendering at one step - time.sleep(1) - - env_renderer.renderEnv(show=True, frames=True, iEpisode=episode, iStep=i, - action_dict=list(actions.values())) - - observation, _, _, _ = env.step(actions) - - env_renderer.close_window() - - -@gin.configurable -def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every, - map_width, map_height, policy_folder_name, obs_builder, - entropy_coeff, seed, conv_model, rail_generator, nr_extra, kl_coeff, lambda_gae, - step_memory, min_dist): - - render_training_result( - config={"n_agents": n_agents, - "hidden_sizes": hidden_sizes, # Array containing the sizes of the network layers - "save_every": save_every, - "map_width": map_width, - "map_height": map_height, - 'policy_folder_name': policy_folder_name, - "obs_builder": obs_builder, - "entropy_coeff": entropy_coeff, - "seed": seed, - "conv_model": conv_model, - "rail_generator": rail_generator, - "nr_extra": nr_extra, - "kl_coeff": kl_coeff, - "lambda_gae": lambda_gae, - "min_dist": min_dist, - "step_memory": step_memory - } - ) - - -if __name__ == '__main__': - gin.parse_config_file(os.path.join(__file_dirname__, 'experiment_configs', 'config_example', 'config.gin')) # To Modify - run_experiment() diff --git a/RLLib_training/train_experiment.py b/RLLib_training/train_experiment.py deleted file mode 100644 index 7435a8fed728ec363321ba7a2bcf04b186513559..0000000000000000000000000000000000000000 --- a/RLLib_training/train_experiment.py +++ /dev/null @@ -1,210 +0,0 @@ -import os - -import gin -import gym -from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv - -# Import PPO trainer: we can replace these imports by any other trainer from RLLib. -from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG -from ray.rllib.agents.ppo.ppo import PPOTrainer as Trainer -from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph -from ray.rllib.models import ModelCatalog - -gin.external_configurable(DummyPredictorForRailEnv) -gin.external_configurable(ShortestPathPredictorForRailEnv) - -import ray - -from ray.tune.logger import UnifiedLogger -from ray.tune.logger import pretty_print -import os - -from RailEnvRLLibWrapper import RailEnvRLLibWrapper -import tempfile - -from ray import tune - -from ray.rllib.utils.seed import seed as set_seed -from flatland.envs.observations import TreeObsForRailEnv - -gin.external_configurable(TreeObsForRailEnv) - -import numpy as np -from custom_preprocessors import TreeObsPreprocessor - -ModelCatalog.register_custom_preprocessor("tree_obs_prep", TreeObsPreprocessor) -ray.init() # object_store_memory=150000000000, redis_max_memory=30000000000) - -__file_dirname__ = os.path.dirname(os.path.realpath(__file__)) - - -def on_episode_start(info): - episode = info['episode'] - map_width = info['env'].envs[0].width - map_height = info['env'].envs[0].height - episode.horizon = 3*(map_width + map_height) - - -def on_episode_end(info): - episode = info['episode'] - - # Calculation of a custom score metric: cum of all accumulated rewards, divided by the number of agents - # and the number of the maximum time steps of the episode. 
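    # In formula form: score = (sum over agents a and steps t of r[a][t]) / (n_agents * horizon).
    # Illustrative numbers: 3 agents with a combined return of -180 over a horizon of
    # 120 steps give score = -180 / (3 * 120) = -0.5.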
- score = 0 - for k, v in episode._agent_reward_history.items(): - score += np.sum(v) - score /= (len(episode._agent_reward_history) * episode.horizon) - - # Calculation of the proportion of solved episodes before the maximum time step - done = 0 - if len(episode._agent_reward_history[0]) <= episode.horizon-5: - done = 1 - - episode.custom_metrics["score"] = score - episode.custom_metrics["proportion_episode_solved"] = done - - -def train(config, reporter): - print('Init Env') - - set_seed(config['seed'], config['seed'], config['seed']) - - # Given the depth of the tree observation and the number of features per node we get the following state_size - num_features_per_node = config['obs_builder'].observation_dim - tree_depth = 2 - nr_nodes = 0 - for i in range(tree_depth + 1): - nr_nodes += np.power(4, i) - obs_size = num_features_per_node * nr_nodes - - - # Environment parameters - env_config = {"width": config['map_width'], - "height": config['map_height'], - "rail_generator": config["rail_generator"], - "nr_extra": config["nr_extra"], - "number_of_agents": config['n_agents'], - "seed": config['seed'], - "obs_builder": config['obs_builder'], - "min_dist": config['min_dist'], - "step_memory": config["step_memory"]} - - # Observation space and action space definitions - if isinstance(config["obs_builder"], TreeObsForRailEnv): - obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(obs_size,)),) * 2) - preprocessor = "tree_obs_prep" - else: - raise ValueError("Undefined observation space") # Only TreeObservation implemented for now. - - act_space = gym.spaces.Discrete(5) - - # Dict with the different policies to train. In this case, all trains follow the same policy - policy_graphs = { - "ppo_policy": (PolicyGraph, obs_space, act_space, {}) - } - - # Function that maps an agent id to the name of its respective policy. 
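    # Every agent id maps to the single shared "ppo_policy" defined above,
    # i.e. all trains are controlled by one set of PPO weights.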
- def policy_mapping_fn(agent_id): - return "ppo_policy" - - # Trainer configuration - trainer_config = DEFAULT_CONFIG.copy() - trainer_config['model'] = {"fcnet_hiddens": config['hidden_sizes'], "custom_preprocessor": preprocessor, - "custom_options": {"step_memory": config["step_memory"], "obs_size": obs_size}} - - trainer_config['multiagent'] = {"policy_graphs": policy_graphs, - "policy_mapping_fn": policy_mapping_fn, - "policies_to_train": list(policy_graphs.keys())} - - # Maximum time steps for an episode is set to 3*map_width*map_height - trainer_config["horizon"] = 3 * (config['map_width'] + config['map_height']) - - # Parameters for calculation parallelization - trainer_config["num_workers"] = 0 - trainer_config["num_cpus_per_worker"] = 8 - trainer_config["num_gpus"] = 0.2 - trainer_config["num_gpus_per_worker"] = 0.2 - trainer_config["num_cpus_for_driver"] = 1 - trainer_config["num_envs_per_worker"] = 1 - - # Parameters for PPO training - trainer_config['entropy_coeff'] = config['entropy_coeff'] - trainer_config["env_config"] = env_config - trainer_config["batch_mode"] = "complete_episodes" - trainer_config['simple_optimizer'] = False - trainer_config['log_level'] = 'WARN' - trainer_config['num_sgd_iter'] = 10 - trainer_config['clip_param'] = 0.2 - trainer_config['kl_coeff'] = config['kl_coeff'] - trainer_config['lambda'] = config['lambda_gae'] - trainer_config['callbacks'] = { - "on_episode_start": tune.function(on_episode_start), - "on_episode_end": tune.function(on_episode_end) - } - - - def logger_creator(conf): - """Creates a Unified logger with a default logdir prefix.""" - logdir = config['policy_folder_name'].format(**locals()) - logdir = tempfile.mkdtemp( - prefix=logdir, dir=config['local_dir']) - return UnifiedLogger(conf, logdir, None) - - logger = logger_creator - - trainer = Trainer(env=RailEnvRLLibWrapper, config=trainer_config, logger_creator=logger) - - for i in range(100000 + 2): - print("== Iteration", i, "==") - - print(pretty_print(trainer.train())) - - if i % config['save_every'] == 0: - checkpoint = trainer.save() - print("checkpoint saved at", checkpoint) - - reporter(num_iterations_trained=trainer._iteration) - - -@gin.configurable -def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every, - map_width, map_height, policy_folder_name, local_dir, obs_builder, - entropy_coeff, seed, conv_model, rail_generator, nr_extra, kl_coeff, lambda_gae, - step_memory, min_dist): - tune.run( - train, - name=name, - stop={"num_iterations_trained": num_iterations}, - config={"n_agents": n_agents, - "hidden_sizes": hidden_sizes, # Array containing the sizes of the network layers - "save_every": save_every, - "map_width": map_width, - "map_height": map_height, - "local_dir": local_dir, - 'policy_folder_name': policy_folder_name, - "obs_builder": obs_builder, - "entropy_coeff": entropy_coeff, - "seed": seed, - "conv_model": conv_model, - "rail_generator": rail_generator, - "nr_extra": nr_extra, - "kl_coeff": kl_coeff, - "lambda_gae": lambda_gae, - "min_dist": min_dist, - "step_memory": step_memory # If equal to two, the current observation plus - # the observation of last time step will be given as input the the model. 
- }, - resources_per_trial={ - "cpu": 8, - "gpu": 0.2 - }, - verbose=2, - local_dir=local_dir - ) - - -if __name__ == '__main__': - folder_name = 'config_example' # To Modify - gin.parse_config_file(os.path.join(__file_dirname__, 'experiment_configs', folder_name, 'config.gin')) - dir = os.path.join(__file_dirname__, 'experiment_configs', folder_name) - run_experiment(local_dir=dir) diff --git a/score_test.py b/score_test.py index ff4a94c5e1b82c90eec0c5bf129bad496046e595..f52a1e010aefa085d42e46155eb0618e27f05702 100644 --- a/score_test.py +++ b/score_test.py @@ -20,9 +20,6 @@ nr_trials_per_test = 100 test_results = [] test_times = [] test_dones = [] -# Load agent -# agent = Agent(state_size, action_size, "FC", 0) -# agent.qnetwork_local.load_state_dict(torch.load('./torch_training/Nets/avoid_checkpoint1700.pth')) agent = RandomAgent(state_size, action_size) start_time_scoring = time.time() test_idx = 0 diff --git a/scoring/score_test.py b/scoring/score_test.py index 5665d446047d2c2dd0f7504de6de391ade98f1b3..4baee4a176e80a3c947cbd9402a78b76c735f839 100644 --- a/scoring/score_test.py +++ b/scoring/score_test.py @@ -28,8 +28,8 @@ test_dones = [] sequential_agent_test = False # Load your agent -agent = Agent(state_size, action_size, "FC", 0) -agent.qnetwork_local.load_state_dict(torch.load('../torch_training/Nets/avoid_checkpoint60000.pth')) +agent = Agent(state_size, action_size) +agent.qnetwork_local.load_state_dict(torch.load('../torch_training/Nets/avoid_checkpoint500.pth')) # Load the necessary Observation Builder and Predictor predictor = ShortestPathPredictorForRailEnv() diff --git a/setup.py b/setup.py index 2b9b731ea02a0c9bdbea7602ea1dfa2ad6e194e2..5bc77c5188d799da8898f959c180c78f9c1496f6 100644 --- a/setup.py +++ b/setup.py @@ -2,8 +2,7 @@ from setuptools import setup, find_packages install_reqs = [] dependency_links = [] -# TODO: include requirements_RLLib_training.txt -requirements_paths = ['requirements_torch_training.txt'] # , 'requirements_RLLib_training.txt'] +requirements_paths = ['requirements_torch_training.txt'] for requirements_path in requirements_paths: with open(requirements_path, 'r') as f: install_reqs += [ diff --git a/torch_training/Nets/avoid_checkpoint15000.pth b/torch_training/Nets/avoid_checkpoint15000.pth deleted file mode 100644 index e1daf228b7f1f6b108329715c3cdbd67805e28ae..0000000000000000000000000000000000000000 Binary files a/torch_training/Nets/avoid_checkpoint15000.pth and /dev/null differ diff --git a/torch_training/Nets/avoid_checkpoint30000.pth b/torch_training/Nets/avoid_checkpoint30000.pth deleted file mode 100644 index 0e2c1b28c1655bc16c9339066b8d105282f14418..0000000000000000000000000000000000000000 Binary files a/torch_training/Nets/avoid_checkpoint30000.pth and /dev/null differ diff --git a/torch_training/Nets/avoid_checkpoint60000.pth b/torch_training/Nets/avoid_checkpoint60000.pth deleted file mode 100644 index b4fef60542f50419353047721fae31f5382e7bd4..0000000000000000000000000000000000000000 Binary files a/torch_training/Nets/avoid_checkpoint60000.pth and /dev/null differ diff --git a/torch_training/dueling_double_dqn.py b/torch_training/dueling_double_dqn.py index cf2f7d512b99aafb9fe0477bf048441efa0bff9e..b7bb4bcdc7c72fa09b352fdf5cf99258f8f9ad0c 100644 --- a/torch_training/dueling_double_dqn.py +++ b/torch_training/dueling_double_dqn.py @@ -8,7 +8,7 @@ import torch import torch.nn.functional as F import torch.optim as optim -from torch_training.model import QNetwork, QNetwork2 +from torch_training.model import QNetwork BUFFER_SIZE = 
int(1e5) # replay buffer size BATCH_SIZE = 512 # minibatch size @@ -16,43 +16,33 @@ GAMMA = 0.99 # discount factor 0.99 TAU = 1e-3 # for soft update of target parameters LR = 0.5e-4 # learning rate 0.5e-4 works UPDATE_EVERY = 10 # how often to update the network -double_dqn = True # If using double dqn algorithm -input_channels = 5 # Number of Input channels device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -#device = torch.device("cpu") print(device) class Agent: """Interacts with and learns from the environment.""" - def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5): + def __init__(self, state_size, action_size, double_dqn=True): """Initialize an Agent object. Params ====== state_size (int): dimension of each state action_size (int): dimension of each action - seed (int): random seed """ self.state_size = state_size self.action_size = action_size - self.seed = random.seed(seed) - self.version = net_type self.double_dqn = double_dqn # Q-Network - if self.version == "Conv": - self.qnetwork_local = QNetwork2(state_size, action_size, seed, input_channels).to(device) - self.qnetwork_target = copy.deepcopy(self.qnetwork_local) - else: - self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device) - self.qnetwork_target = copy.deepcopy(self.qnetwork_local) + self.qnetwork_local = QNetwork(state_size, action_size).to(device) + self.qnetwork_target = copy.deepcopy(self.qnetwork_local) self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR) # Replay memory - self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed) + self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE) # Initialize time step (for updating every UPDATE_EVERY steps) self.t_step = 0 @@ -152,7 +142,7 @@ class Agent: class ReplayBuffer: """Fixed-size buffer to store experience tuples.""" - def __init__(self, action_size, buffer_size, batch_size, seed): + def __init__(self, action_size, buffer_size, batch_size): """Initialize a ReplayBuffer object. 
Params @@ -160,13 +150,11 @@ class ReplayBuffer: action_size (int): dimension of each action buffer_size (int): maximum size of buffer batch_size (int): size of each training batch - seed (int): random seed """ self.action_size = action_size self.memory = deque(maxlen=buffer_size) self.batch_size = batch_size self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"]) - self.seed = random.seed(seed) def add(self, state, action, reward, next_state, done): """Add a new experience to memory.""" @@ -188,7 +176,7 @@ class ReplayBuffer: dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \ .float().to(device) - return (states, actions, rewards, next_states, dones) + return states, actions, rewards, next_states, dones def __len__(self): """Return the current size of internal memory.""" diff --git a/torch_training/model.py b/torch_training/model.py index 7a5b3d613342a4fba8e2c8f1f45df21381e12684..9a5afccfda50e63271b1f5d8ed5a2d74b5e169e7 100644 --- a/torch_training/model.py +++ b/torch_training/model.py @@ -3,7 +3,7 @@ import torch.nn.functional as F class QNetwork(nn.Module): - def __init__(self, state_size, action_size, seed, hidsize1=128, hidsize2=128): + def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128): super(QNetwork, self).__init__() self.fc1_val = nn.Linear(state_size, hidsize1) @@ -24,38 +24,3 @@ class QNetwork(nn.Module): adv = F.relu(self.fc2_adv(adv)) adv = self.fc3_adv(adv) return val + adv - adv.mean() - - -class QNetwork2(nn.Module): - def __init__(self, state_size, action_size, seed, input_channels, hidsize1=128, hidsize2=64): - super(QNetwork2, self).__init__() - self.conv1 = nn.Conv2d(input_channels, 16, kernel_size=3, stride=1) - self.bn1 = nn.BatchNorm2d(16) - self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=3) - self.bn2 = nn.BatchNorm2d(32) - self.conv3 = nn.Conv2d(32, 64, kernel_size=5, stride=3) - self.bn3 = nn.BatchNorm2d(64) - - self.fc1_val = nn.Linear(6400, hidsize1) - self.fc2_val = nn.Linear(hidsize1, hidsize2) - self.fc3_val = nn.Linear(hidsize2, 1) - - self.fc1_adv = nn.Linear(6400, hidsize1) - self.fc2_adv = nn.Linear(hidsize1, hidsize2) - self.fc3_adv = nn.Linear(hidsize2, action_size) - - def forward(self, x): - x = F.relu(self.conv1(x)) - x = F.relu(self.conv2(x)) - x = F.relu(self.conv3(x)) - - # value function approximation - val = F.relu(self.fc1_val(x.view(x.size(0), -1))) - val = F.relu(self.fc2_val(val)) - val = self.fc3_val(val) - - # advantage calculation - adv = F.relu(self.fc1_adv(x.view(x.size(0), -1))) - adv = F.relu(self.fc2_adv(adv)) - adv = self.fc3_adv(adv) - return val + adv - adv.mean() diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py index 2fbbe61f7fd61eed4d8589695a06da238091717a..580886b1db73ba34d539e14968deea384b5b98be 100644 --- a/torch_training/multi_agent_inference.py +++ b/torch_training/multi_agent_inference.py @@ -4,8 +4,8 @@ from collections import deque import numpy as np import torch from importlib_resources import path -from observation_builders.observations import TreeObsForRailEnv -from predictors.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv import torch_training.Nets from flatland.envs.rail_env import RailEnv @@ -87,8 +87,8 @@ dones_list = [] action_prob = [0] * action_size agent_obs = [None] * env.get_num_agents() 
agent_next_obs = [None] * env.get_num_agents() -agent = Agent(state_size, action_size, "FC", 0) -with path(torch_training.Nets, "avoid_checkpoint100.pth") as file_in: +agent = Agent(state_size, action_size) +with path(torch_training.Nets, "avoid_checkpoint500.pth") as file_in: agent.qnetwork_local.load_state_dict(torch.load(file_in)) record_images = False diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index 222430dd0af0f239ddd99d127b21349a13a2e892..fe9e27969f129313c19f54fd36738188ba376082 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -121,11 +121,11 @@ def main(argv): observation_radius = 10 # Initialize the agent - agent = Agent(state_size, action_size, "FC", 0) + agent = Agent(state_size, action_size) # Here you can pre-load an agent if False: - with path(torch_training.Nets, "avoid_checkpoint2400.pth") as file_in: + with path(torch_training.Nets, "avoid_checkpoint500.pth") as file_in: agent.qnetwork_local.load_state_dict(torch.load(file_in)) # Do training over n_episodes diff --git a/torch_training/multi_agent_two_time_step_training.py b/torch_training/multi_agent_two_time_step_training.py index 08cd84c379fe54cd4d6b71140a96623ebe2a8cbf..57f4a619d3bd3ae797fa7fe9aff7c799064bc6f6 100644 --- a/torch_training/multi_agent_two_time_step_training.py +++ b/torch_training/multi_agent_two_time_step_training.py @@ -7,16 +7,16 @@ from collections import deque import matplotlib.pyplot as plt import numpy as np import torch -from importlib_resources import path - -# Import Torch and utility functions to normalize observation -import torch_training.Nets from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator # Import Flatland/ Observations and Predictors from flatland.envs.schedule_generators import complex_schedule_generator +from importlib_resources import path + +# Import Torch and utility functions to normalize observation +import torch_training.Nets from torch_training.dueling_double_dqn import Agent from utils.observation_utils import norm_obs_clip, split_tree @@ -41,12 +41,12 @@ def main(argv): n_agents = np.random.randint(3, 8) n_goals = n_agents + np.random.randint(0, 3) min_dist = int(0.75 * min(x_dim, y_dim)) - tree_depth = 3 + tree_depth = 2 print("main2") + demo = False # Get an observation builder and predictor - predictor = ShortestPathPredictorForRailEnv() - observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=predictor()) + observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv()) env = RailEnv(width=x_dim, height=y_dim, @@ -60,7 +60,6 @@ def main(argv): handle = env.get_agent_handles() features_per_node = env.obs_builder.observation_dim - tree_depth = 2 nr_nodes = 0 for i in range(tree_depth + 1): nr_nodes += np.power(4, i) @@ -87,11 +86,11 @@ def main(argv): agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() # Initialize the agent - agent = Agent(state_size, action_size, "FC", 0) + agent = Agent(state_size, action_size) # Here you can pre-load an agent if False: - with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in: + with path(torch_training.Nets, "avoid_checkpoint500.pth") as file_in: agent.qnetwork_local.load_state_dict(torch.load(file_in)) # Do training over n_episodes @@ -132,6 +131,7 
@@ def main(argv): # Build agent specific observations for a in range(env.get_num_agents()): data, distance, agent_data = split_tree(tree=np.array(obs[a]), + num_features_per_node=features_per_node, current_depth=0) data = norm_obs_clip(data) distance = norm_obs_clip(distance) @@ -164,6 +164,7 @@ def main(argv): next_obs, all_rewards, done, _ = env.step(action_dict) for a in range(env.get_num_agents()): data, distance, agent_data = split_tree(tree=np.array(next_obs[a]), + num_features_per_node=features_per_node, current_depth=0) data = norm_obs_clip(data) distance = norm_obs_clip(distance) diff --git a/torch_training/observation_builders/observations.py b/torch_training/observation_builders/observations.py deleted file mode 100644 index 66c38301d034151aa24e306d2e432d54d2802018..0000000000000000000000000000000000000000 --- a/torch_training/observation_builders/observations.py +++ /dev/null @@ -1,865 +0,0 @@ -""" -Collection of environment-specific ObservationBuilder. -""" -import pprint -from collections import deque - -import numpy as np - -from flatland.core.env_observation_builder import ObservationBuilder -from flatland.core.grid.grid4_utils import get_new_position -from flatland.core.grid.grid_utils import coordinate_to_position - - -class TreeObsForRailEnv(ObservationBuilder): - """ - TreeObsForRailEnv object. - - This object returns observation vectors for agents in the RailEnv environment. - The information is local to each agent and exploits the graph structure of the rail - network to simplify the representation of the state of the environment for each agent. - - For details about the features in the tree observation see the get() function. - """ - - def __init__(self, max_depth, predictor=None): - super().__init__() - self.max_depth = max_depth - self.observation_dim = 11 - # Compute the size of the returned observation vector - size = 0 - pow4 = 1 - for i in range(self.max_depth + 1): - size += pow4 - pow4 *= 4 - self.observation_space = [size * self.observation_dim] - self.location_has_agent = {} - self.location_has_agent_direction = {} - self.predictor = predictor - self.agents_previous_reset = None - self.tree_explored_actions = [1, 2, 3, 0] - self.tree_explorted_actions_char = ['L', 'F', 'R', 'B'] - self.distance_map = None - self.distance_map_computed = False - - def reset(self): - agents = self.env.agents - nb_agents = len(agents) - compute_distance_map = True - if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset): - compute_distance_map = False - for i in range(nb_agents): - if agents[i].target != self.agents_previous_reset[i].target: - compute_distance_map = True - # Don't compute the distance map if it was loaded - if self.agents_previous_reset is None and self.distance_map is not None: - self.location_has_target = {tuple(agent.target): 1 for agent in agents} - compute_distance_map = False - - if compute_distance_map: - self._compute_distance_map() - - self.agents_previous_reset = agents - - def _compute_distance_map(self): - agents = self.env.agents - # For testing only --> To assert if a distance map need to be recomputed. 
- self.distance_map_computed = True - nb_agents = len(agents) - self.distance_map = np.inf * np.ones(shape=(nb_agents, - self.env.height, - self.env.width, - 4)) - self.max_dist = np.zeros(nb_agents) - self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)] - # Update local lookup table for all agents' target locations - self.location_has_target = {tuple(agent.target): 1 for agent in agents} - - def _distance_map_walker(self, position, target_nr): - """ - Utility function to compute distance maps from each cell in the rail network (and each possible - orientation within it) to each agent's target cell. - """ - # Returns max distance to target, from the farthest away node, while filling in distance_map - self.distance_map[target_nr, position[0], position[1], :] = 0 - - # Fill in the (up to) 4 neighboring nodes - # direction is the direction of movement, meaning that at least a possible orientation of an agent - # in cell (row,col) allows a movement in direction `direction` - nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1)) - - # BFS from target `position` to all the reachable nodes in the grid - # Stop the search if the target position is re-visited, in any direction - visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2), - (position[0], position[1], 3)} - - max_distance = 0 - - while nodes_queue: - node = nodes_queue.popleft() - - node_id = (node[0], node[1], node[2]) - - if node_id not in visited: - visited.add(node_id) - - # From the list of possible neighbors that have at least a path to the current node, only keep those - # whose new orientation in the current cell would allow a transition to direction node[2] - valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2]) - - for n in valid_neighbors: - nodes_queue.append(n) - - if len(valid_neighbors) > 0: - max_distance = max(max_distance, node[3] + 1) - - return max_distance - - def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1): - """ - Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the - minimum distances from each target cell. - """ - neighbors = [] - - possible_directions = [0, 1, 2, 3] - if enforce_target_direction >= 0: - # The agent must land into the current cell with orientation `enforce_target_direction`. - # This is only possible if the agent has arrived from the cell in the opposite direction! - possible_directions = [(enforce_target_direction + 2) % 4] - - for neigh_direction in possible_directions: - new_cell = get_new_position(position, neigh_direction) - - if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width: - - desired_movement_from_new_cell = (neigh_direction + 2) % 4 - - # Check all possible transitions in new_cell - for agent_orientation in range(4): - # Is a transition along movement `desired_movement_from_new_cell` to the current cell possible? - is_valid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation), - desired_movement_from_new_cell) - - if is_valid: - """ - # TODO: check that it works with deadends! -- still bugged! 
- movement = desired_movement_from_new_cell - if isNextCellDeadEnd: - movement = (desired_movement_from_new_cell+2) % 4 - """ - new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation], - current_distance + 1) - neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance)) - self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance - - return neighbors - - def get_many(self, handles=None): - """ - Called whenever an observation has to be computed for the `env` environment, for each agent with handle - in the `handles` list. - """ - - if handles is None: - handles = [] - if self.predictor: - self.max_prediction_depth = 0 - self.predicted_pos = {} - self.predicted_dir = {} - self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map}) - if self.predictions: - - for t in range(len(self.predictions[0])): - pos_list = [] - dir_list = [] - for a in handles: - pos_list.append(self.predictions[a][t][1:3]) - dir_list.append(self.predictions[a][t][3]) - self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)}) - self.predicted_dir.update({t: dir_list}) - self.max_prediction_depth = len(self.predicted_pos) - observations = {} - for h in handles: - observations[h] = self.get(h) - return observations - - def get(self, handle): - """ - Computes the current observation for agent `handle` in env - - The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible - movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv). - The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for - the transitions. The order is: - [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back'] - - Each branch data is organized as: - [root node information] + - [recursive branch data from 'left'] + - [... from 'forward'] + - [... from 'right] + - [... from 'back'] - - Each node information is composed of 9 features: - - #1: if own target lies on the explored branch the current distance from the agent in number of cells is stored. - - #2: if another agents target is detected the distance in number of cells from the agents current location - is stored - - #3: if another agent is detected the distance in number of cells from current agent position is stored. - - #4: possible conflict detected - tot_dist = Other agent predicts to pass along this cell at the same time as the agent, we store the - distance in number of cells from current agent position - - 0 = No other agent reserve the same cell at similar time - - #5: if an not usable switch (for agent) is detected we store the distance. - - #6: This feature stores the distance in number of cells to the next branching (current node) - - #7: minimum distance from node to the agent's target given the direction of the agent if this path is chosen - - #8: agent in the same direction - n = number of agents present same direction - (possible future use: number of other agents in the same direction in this branch) - 0 = no agent present same direction - - #9: agent in the opposite direction - n = number of agents present other direction than myself (so conflict) - (possible future use: number of other agents in other direction in this branch, ie. 
number of conflicts) - 0 = no agent present other direction than myself - - #10: malfunctioning/blokcing agents - n = number of time steps the oberved agent remains blocked - - #11: slowest observed speed of an agent in same direction - 1 if no agent is observed - - min_fractional speed otherwise - - - - - - Missing/padding nodes are filled in with -inf (truncated). - Missing values in present node are filled in with +inf (truncated). - - - In case of the root node, the values are [0, 0, 0, 0, distance from agent to target, own malfunction, own speed] - In case the target node is reached, the values are [0, 0, 0, 0, 0]. - """ - - # Update local lookup table for all agents' positions - self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents} - self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents} - self.location_has_agent_speed = {tuple(agent.position): agent.speed_data['speed'] for agent in self.env.agents} - self.location_has_agent_malfunction = {tuple(agent.position): agent.malfunction_data['malfunction'] for agent in - self.env.agents} - - if handle > len(self.env.agents): - print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents)) - agent = self.env.agents[handle] # TODO: handle being treated as index - possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) - num_transitions = np.count_nonzero(possible_transitions) - - # Root node - current position - # Here information about the agent itself is stored - observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0, - agent.malfunction_data['malfunction'], agent.speed_data['speed']] - - visited = set() - - # Start from the current orientation, and see which transitions are available; - # organize them as [left, forward, right, back], relative to the current orientation - # If only one transition is possible, the tree is oriented with this transition as the forward branch. - orientation = agent.direction - - if num_transitions == 1: - orientation = np.argmax(possible_transitions) - - for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]: - if possible_transitions[branch_direction]: - new_cell = get_new_position(agent.position, branch_direction) - branch_observation, branch_visited = \ - self._explore_branch(handle, new_cell, branch_direction, 1, 1) - observation = observation + branch_observation - visited = visited.union(branch_visited) - else: - # add cells filled with infinity if no transition is possible - observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth) - self.env.dev_obs_dict[handle] = visited - - return observation - - def _num_cells_to_fill_in(self, remaining_depth): - """Computes the length of observation vector: sum_{i=0,depth-1} 2^i * observation_dim.""" - num_observations = 0 - pow4 = 1 - for i in range(remaining_depth): - num_observations += pow4 - pow4 *= 4 - return num_observations * self.observation_dim - - def _explore_branch(self, handle, position, direction, tot_dist, depth): - """ - Utility function to compute tree-based observations. - We walk along the branch and collect the information documented in the get() function. - If there is a branching point a new node is created and each possible branch is explored. 
- """ - - # [Recursive branch opened] - if depth >= self.max_depth + 1: - return [], [] - - # Continue along direction until next switch or - # until no transitions are possible along the current direction (i.e., dead-ends) - # We treat dead-ends as nodes, instead of going back, to avoid loops - exploring = True - last_is_switch = False - last_is_dead_end = False - last_is_terminal = False # wrong cell OR cycle; either way, we don't want the agent to land here - last_is_target = False - - visited = set() - agent = self.env.agents[handle] - time_per_cell = np.reciprocal(agent.speed_data["speed"]) - own_target_encountered = np.inf - other_agent_encountered = np.inf - other_target_encountered = np.inf - potential_conflict = np.inf - unusable_switch = np.inf - other_agent_same_direction = 0 - other_agent_opposite_direction = 0 - malfunctioning_agent = 0. - min_fractional_speed = 1. - num_steps = 1 - while exploring: - # ############################# - # ############################# - # Modify here to compute any useful data required to build the end node's features. This code is called - # for each cell visited between the previous branching node and the next switch / target / dead-end. - if position in self.location_has_agent: - if tot_dist < other_agent_encountered: - other_agent_encountered = tot_dist - - # Check if any of the observed agents is malfunctioning, store agent with longest duration left - if self.location_has_agent_malfunction[position] > malfunctioning_agent: - malfunctioning_agent = self.location_has_agent_malfunction[position] - - if self.location_has_agent_direction[position] == direction: - # Cummulate the number of agents on branch with same direction - other_agent_same_direction += 1 - - # Check fractional speed of agents - current_fractional_speed = self.location_has_agent_speed[position] - if current_fractional_speed < min_fractional_speed: - min_fractional_speed = current_fractional_speed - - if self.location_has_agent_direction[position] != direction: - # Cummulate the number of agents on branch with other direction - other_agent_opposite_direction += 1 - - # Check number of possible transitions for agent and total number of transitions in cell (type) - cell_transitions = self.env.rail.get_transitions(*position, direction) - transition_bit = bin(self.env.rail.get_full_transitions(*position)) - total_transitions = transition_bit.count("1") - crossing_found = False - if int(transition_bit, 2) == int('1000010000100001', 2): - crossing_found = True - - # Register possible future conflict - predicted_time = int(tot_dist * time_per_cell) - if self.predictor and predicted_time < self.max_prediction_depth: - int_position = coordinate_to_position(self.env.width, [position]) - if tot_dist < self.max_prediction_depth: - - pre_step = max(0, predicted_time - 1) - post_step = min(self.max_prediction_depth - 1, predicted_time + 1) - - # Look for conflicting paths at distance tot_dist - if int_position in np.delete(self.predicted_pos[predicted_time], handle, 0): - conflicting_agent = np.where(self.predicted_pos[predicted_time] == int_position) - for ca in conflicting_agent[0]: - if direction != self.predicted_dir[predicted_time][ca] and cell_transitions[ - self._reverse_dir( - self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict: - potential_conflict = tot_dist - if self.env.dones[ca] and tot_dist < potential_conflict: - potential_conflict = tot_dist - - # Look for conflicting paths at distance num_step-1 - elif int_position in 
np.delete(self.predicted_pos[pre_step], handle, 0): - conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position) - for ca in conflicting_agent[0]: - if direction != self.predicted_dir[pre_step][ca] \ - and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \ - and tot_dist < potential_conflict: # noqa: E125 - potential_conflict = tot_dist - if self.env.dones[ca] and tot_dist < potential_conflict: - potential_conflict = tot_dist - - # Look for conflicting paths at distance num_step+1 - elif int_position in np.delete(self.predicted_pos[post_step], handle, 0): - conflicting_agent = np.where(self.predicted_pos[post_step] == int_position) - for ca in conflicting_agent[0]: - if direction != self.predicted_dir[post_step][ca] and cell_transitions[self._reverse_dir( - self.predicted_dir[post_step][ca])] == 1 \ - and tot_dist < potential_conflict: # noqa: E125 - potential_conflict = tot_dist - if self.env.dones[ca] and tot_dist < potential_conflict: - potential_conflict = tot_dist - - if position in self.location_has_target and position != agent.target: - if tot_dist < other_target_encountered: - other_target_encountered = tot_dist - - if position == agent.target and tot_dist < own_target_encountered: - own_target_encountered = tot_dist - - # ############################# - # ############################# - if (position[0], position[1], direction) in visited: - last_is_terminal = True - break - visited.add((position[0], position[1], direction)) - - # If the target node is encountered, pick that as node. Also, no further branching is possible. - if np.array_equal(position, self.env.agents[handle].target): - last_is_target = True - break - - # Check if crossing is found --> Not an unusable switch - if crossing_found: - # Treat the crossing as a straight rail cell - total_transitions = 2 - num_transitions = np.count_nonzero(cell_transitions) - - exploring = False - - # Detect Switches that can only be used by other agents. - if total_transitions > 2 > num_transitions and tot_dist < unusable_switch: - unusable_switch = tot_dist - - if num_transitions == 1: - # Check if dead-end, or if we can go forward along direction - nbits = total_transitions - if nbits == 1: - # Dead-end! - last_is_dead_end = True - - if not last_is_dead_end: - # Keep walking through the tree along `direction` - exploring = True - # convert one-hot encoding to 0,1,2,3 - direction = np.argmax(cell_transitions) - position = get_new_position(position, direction) - num_steps += 1 - tot_dist += 1 - elif num_transitions > 0: - # Switch detected - last_is_switch = True - break - - elif num_transitions == 0: - # Wrong cell type, but let's cover it and treat it as a dead-end, just in case - print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0], - position[1], direction) - last_is_terminal = True - break - - # `position` is either a terminal node or a switch - - # ############################# - # ############################# - # Modify here to append new / different features for each visited cell! 
- - if last_is_target: - observation = [own_target_encountered, - other_target_encountered, - other_agent_encountered, - potential_conflict, - unusable_switch, - tot_dist, - 0, - other_agent_same_direction, - other_agent_opposite_direction, - malfunctioning_agent, - min_fractional_speed - ] - - elif last_is_terminal: - observation = [own_target_encountered, - other_target_encountered, - other_agent_encountered, - potential_conflict, - unusable_switch, - np.inf, - self.distance_map[handle, position[0], position[1], direction], - other_agent_same_direction, - other_agent_opposite_direction, - malfunctioning_agent, - min_fractional_speed - ] - - else: - observation = [own_target_encountered, - other_target_encountered, - other_agent_encountered, - potential_conflict, - unusable_switch, - tot_dist, - self.distance_map[handle, position[0], position[1], direction], - other_agent_same_direction, - other_agent_opposite_direction, - malfunctioning_agent, - min_fractional_speed - ] - # ############################# - # ############################# - # Start from the current orientation, and see which transitions are available; - # organize them as [left, forward, right, back], relative to the current orientation - # Get the possible transitions - possible_transitions = self.env.rail.get_transitions(*position, direction) - for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]: - if last_is_dead_end and self.env.rail.get_transition((*position, direction), - (branch_direction + 2) % 4): - # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes - # it back - new_cell = get_new_position(position, (branch_direction + 2) % 4) - branch_observation, branch_visited = self._explore_branch(handle, - new_cell, - (branch_direction + 2) % 4, - tot_dist + 1, - depth + 1) - observation = observation + branch_observation - if len(branch_visited) != 0: - visited = visited.union(branch_visited) - elif last_is_switch and possible_transitions[branch_direction]: - new_cell = get_new_position(position, branch_direction) - branch_observation, branch_visited = self._explore_branch(handle, - new_cell, - branch_direction, - tot_dist + 1, - depth + 1) - observation = observation + branch_observation - if len(branch_visited) != 0: - visited = visited.union(branch_visited) - else: - # no exploring possible, add just cells with infinity - observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth - depth) - - return observation, visited - - def util_print_obs_subtree(self, tree): - """ - Utility function to pretty-print tree observations returned by this object. - """ - pp = pprint.PrettyPrinter(indent=4) - pp.pprint(self.unfold_observation_tree(tree)) - - def unfold_observation_tree(self, tree, current_depth=0, actions_for_display=True): - """ - Utility function to pretty-print tree observations returned by this object. 
- """ - if len(tree) < self.observation_dim: - return - - depth = 0 - tmp = len(tree) / self.observation_dim - 1 - pow4 = 4 - while tmp > 0: - tmp -= pow4 - depth += 1 - pow4 *= 4 - - unfolded = {} - unfolded[''] = tree[0:self.observation_dim] - child_size = (len(tree) - self.observation_dim) // 4 - for child in range(4): - child_tree = tree[(self.observation_dim + child * child_size): - (self.observation_dim + (child + 1) * child_size)] - observation_tree = self.unfold_observation_tree(child_tree, current_depth=current_depth + 1) - if observation_tree is not None: - if actions_for_display: - label = self.tree_explorted_actions_char[child] - else: - label = self.tree_explored_actions[child] - unfolded[label] = observation_tree - return unfolded - - def _set_env(self, env): - self.env = env - if self.predictor: - self.predictor._set_env(self.env) - - def _reverse_dir(self, direction): - return int((direction + 2) % 4) - - -class GlobalObsForRailEnv(ObservationBuilder): - """ - Gives a global observation of the entire rail environment. - The observation is composed of the following elements: - - - transition map array with dimensions (env.height, env.width, 16), - assuming 16 bits encoding of transitions. - - - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent - target and the positions of the other agents targets. - - - A 3D array (map_height, map_width, 8) with the 4 first channels containing the one hot encoding - of the direction of the given agent and the 4 second channels containing the positions - of the other agents at their position coordinates. - """ - - def __init__(self): - self.observation_space = () - super(GlobalObsForRailEnv, self).__init__() - - def _set_env(self, env): - super()._set_env(env) - - self.observation_space = [4, self.env.height, self.env.width] - - def reset(self): - self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) - for i in range(self.rail_obs.shape[0]): - for j in range(self.rail_obs.shape[1]): - bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]] - bitlist = [0] * (16 - len(bitlist)) + bitlist - self.rail_obs[i, j] = np.array(bitlist) - - def get(self, handle): - obs_targets = np.zeros((self.env.height, self.env.width, 2)) - obs_agents_state = np.zeros((self.env.height, self.env.width, 8)) - agents = self.env.agents - agent = agents[handle] - - direction = np.zeros(4) - direction[agent.direction] = 1 - agent_pos = agents[handle].position - obs_agents_state[agent_pos][:4] = direction - obs_targets[agent.target][0] += 1 - - for i in range(len(agents)): - if i != handle: # TODO: handle used as index...? - agent2 = agents[i] - obs_agents_state[agent2.position][4 + agent2.direction] = 1 - obs_targets[agent2.target][1] += 1 - - direction = self._get_one_hot_for_agent_direction(agent) - - return self.rail_obs, obs_agents_state, obs_targets, direction - - -class GlobalObsForRailEnvDirectionDependent(ObservationBuilder): - """ - Gives a global observation of the entire rail environment. - The observation is composed of the following elements: - - - transition map array with dimensions (env.height, env.width, 16), - assuming 16 bits encoding of transitions, flipped in the direction of the agent - (the agent is always heading north on the flipped view). - - - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent - target and the positions of the other agents targets, also flipped depending on the agent's direction. 
- - - A 3D array (map_height, map_width, 5) containing the one hot encoding of the direction of the other - agents at their position coordinates, and the last channel containing the position of the given agent. - - - A 4 elements array with one hot encoding of the direction. - """ - - def __init__(self): - self.observation_space = () - super(GlobalObsForRailEnvDirectionDependent, self).__init__() - - def _set_env(self, env): - super()._set_env(env) - - self.observation_space = [4, self.env.height, self.env.width] - - def reset(self): - self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) - for i in range(self.rail_obs.shape[0]): - for j in range(self.rail_obs.shape[1]): - bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]] - bitlist = [0] * (16 - len(bitlist)) + bitlist - self.rail_obs[i, j] = np.array(bitlist) - - def get(self, handle): - obs_targets = np.zeros((self.env.height, self.env.width, 2)) - obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - agents = self.env.agents - agent = agents[handle] - direction = agent.direction - - idx = np.tile(np.arange(16), 2) - - rail_obs = self.rail_obs[:, :, idx[direction * 4: direction * 4 + 16]] - - if direction == 1: - rail_obs = np.flip(rail_obs, axis=1) - elif direction == 2: - rail_obs = np.flip(rail_obs) - elif direction == 3: - rail_obs = np.flip(rail_obs, axis=0) - - agent_pos = agents[handle].position - obs_agents_state[agent_pos][0] = 1 - obs_targets[agent.target][0] += 1 - - idx = np.tile(np.arange(4), 2) - for i in range(len(agents)): - if i != handle: # TODO: handle used as index...? - agent2 = agents[i] - obs_agents_state[agent2.position][1 + idx[4 + (agent2.direction - direction)]] = 1 - obs_targets[agent2.target][1] += 1 - - direction = self._get_one_hot_for_agent_direction(agent) - - return rail_obs, obs_agents_state, obs_targets, direction - - -class LocalObsForRailEnv(ObservationBuilder): - """ - Gives a local observation of the rail environment around the agent. - The observation is composed of the following elements: - - - transition map array of the local environment around the given agent, - with dimensions (view_height,2*view_width+1, 16), - assuming 16 bits encoding of transitions. - - - Two 2D arrays (view_height,2*view_width+1, 2) containing respectively, - if they are in the agent's vision range, its target position, the positions of the other targets. - - - A 2D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions - of the other agents at their position coordinates, if they are in the agent's vision range. - - - A 4 elements array with one hot encoding of the direction. - - Use the parameters view_width and view_height to define the rectangular view of the agent. - The center parameters moves the agent along the height axis of this rectangle. If it is 0 the agent only has - observation in front of it. - """ - - def __init__(self, view_width, view_height, center): - - super(LocalObsForRailEnv, self).__init__() - self.view_width = view_width - self.view_height = view_height - self.center = center - self.max_padding = max(self.view_width, self.view_height - self.center) - - def reset(self): - # We build the transition map with a view_radius empty cells expansion on each side. - # This helps to collect the local transition map view when the agent is close to a border. 
- self.max_padding = max(self.view_width, self.view_height) - self.rail_obs = np.zeros((self.env.height, - self.env.width, 16)) - for i in range(self.env.height): - for j in range(self.env.width): - bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]] - bitlist = [0] * (16 - len(bitlist)) + bitlist - self.rail_obs[i, j] = np.array(bitlist) - - def get(self, handle): - agents = self.env.agents - agent = agents[handle] - - # Correct agents position for padding - # agent_rel_pos[0] = agent.position[0] + self.max_padding - # agent_rel_pos[1] = agent.position[1] + self.max_padding - - # Collect visible cells as set to be plotted - visited, rel_coords = self.field_of_view(agent.position, agent.direction, ) - local_rail_obs = None - - # Add the visible cells to the observed cells - self.env.dev_obs_dict[handle] = set(visited) - - # Locate observed agents and their coresponding targets - local_rail_obs = np.zeros((self.view_height, 2 * self.view_width + 1, 16)) - obs_map_state = np.zeros((self.view_height, 2 * self.view_width + 1, 2)) - obs_other_agents_state = np.zeros((self.view_height, 2 * self.view_width + 1, 4)) - _idx = 0 - for pos in visited: - curr_rel_coord = rel_coords[_idx] - local_rail_obs[curr_rel_coord[0], curr_rel_coord[1], :] = self.rail_obs[pos[0], pos[1], :] - if pos == agent.target: - obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 0] = 1 - else: - for tmp_agent in agents: - if pos == tmp_agent.target: - obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 1] = 1 - if pos != agent.position: - for tmp_agent in agents: - if pos == tmp_agent.position: - obs_other_agents_state[curr_rel_coord[0], curr_rel_coord[1], :] = np.identity(4)[ - tmp_agent.direction] - - _idx += 1 - - direction = np.identity(4)[agent.direction] - return local_rail_obs, obs_map_state, obs_other_agents_state, direction - - def get_many(self, handles=None): - """ - Called whenever an observation has to be computed for the `env' environment, for each agent with handle - in the `handles' list. 
- """ - - observations = {} - for h in handles: - observations[h] = self.get(h) - return observations - - def field_of_view(self, position, direction, state=None): - # Compute the local field of view for an agent in the environment - data_collection = False - if state is not None: - temp_visible_data = np.zeros(shape=(self.view_height, 2 * self.view_width + 1, 16)) - data_collection = True - if direction == 0: - origin = (position[0] + self.center, position[1] - self.view_width) - elif direction == 1: - origin = (position[0] - self.view_width, position[1] - self.center) - elif direction == 2: - origin = (position[0] - self.center, position[1] + self.view_width) - else: - origin = (position[0] + self.view_width, position[1] + self.center) - visible = list() - rel_coords = list() - for h in range(self.view_height): - for w in range(2 * self.view_width + 1): - if direction == 0: - if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width: - visible.append((origin[0] - h, origin[1] + w)) - rel_coords.append((h, w)) - # if data_collection: - # temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :] - elif direction == 1: - if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width: - visible.append((origin[0] + w, origin[1] + h)) - rel_coords.append((h, w)) - # if data_collection: - # temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :] - elif direction == 2: - if 0 <= origin[0] + h < self.env.height and 0 <= origin[1] - w < self.env.width: - visible.append((origin[0] + h, origin[1] - w)) - rel_coords.append((h, w)) - # if data_collection: - # temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :] - else: - if 0 <= origin[0] - w < self.env.height and 0 <= origin[1] - h < self.env.width: - visible.append((origin[0] - w, origin[1] - h)) - rel_coords.append((h, w)) - # if data_collection: - # temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :] - if data_collection: - return temp_visible_data - else: - return visible, rel_coords diff --git a/torch_training/predictors/predictions.py b/torch_training/predictors/predictions.py deleted file mode 100644 index 8306b726572e7a7e168e71755a760e010a6b1eb5..0000000000000000000000000000000000000000 --- a/torch_training/predictors/predictions.py +++ /dev/null @@ -1,179 +0,0 @@ -""" -Collection of environment-specific PredictionBuilder. -""" - -import numpy as np - -from flatland.core.env_prediction_builder import PredictionBuilder -from flatland.core.grid.grid4_utils import get_new_position -from flatland.envs.rail_env import RailEnvActions - - -class DummyPredictorForRailEnv(PredictionBuilder): - """ - DummyPredictorForRailEnv object. - - This object returns predictions for agents in the RailEnv environment. - The prediction acts as if no other agent is in the environment and always takes the forward action. - """ - - def get(self, custom_args=None, handle=None): - """ - Called whenever get_many in the observation build is called. - - Parameters - ------- - custom_args: dict - Not used in this dummy implementation. - handle : int (optional) - Handle of the agent for which to compute the observation vector. - - Returns - ------- - np.array - Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements: - - time_offset - - position axis 0 - - position axis 1 - - direction - - action taken to come here - The prediction at 0 is the current position, direction etc. 
- - """ - agents = self.env.agents - if handle: - agents = [self.env.agents[handle]] - - prediction_dict = {} - - for agent in agents: - action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT] - _agent_initial_position = agent.position - _agent_initial_direction = agent.direction - prediction = np.zeros(shape=(self.max_depth + 1, 5)) - prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0] - for index in range(1, self.max_depth + 1): - action_done = False - # if we're at the target, stop moving... - if agent.position == agent.target: - prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING] - - continue - for action in action_priorities: - cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ - self.env._check_action_on_agent(action, agent) - if all([new_cell_isValid, transition_isValid]): - # move and change direction to face the new_direction that was - # performed - agent.position = new_position - agent.direction = new_direction - prediction[index] = [index, *new_position, new_direction, action] - action_done = True - break - if not action_done: - raise Exception("Cannot move further. Something is wrong") - prediction_dict[agent.handle] = prediction - agent.position = _agent_initial_position - agent.direction = _agent_initial_direction - return prediction_dict - - -class ShortestPathPredictorForRailEnv(PredictionBuilder): - """ - ShortestPathPredictorForRailEnv object. - - This object returns shortest-path predictions for agents in the RailEnv environment. - The prediction acts as if no other agent is in the environment and always takes the forward action. - """ - - def __init__(self, max_depth=20): - # Initialize with depth 20 - self.max_depth = max_depth - - def get(self, custom_args=None, handle=None): - """ - Called whenever get_many in the observation build is called. - Requires distance_map to extract the shortest path. - - Parameters - ---------- - custom_args: dict - - distance_map : dict - handle : int, optional - Handle of the agent for which to compute the observation vector. - - Returns - ------- - np.array - Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements: - - time_offset - - position axis 0 - - position axis 1 - - direction - - action taken to come here - The prediction at 0 is the current position, direction etc. - """ - agents = self.env.agents - if handle: - agents = [self.env.agents[handle]] - assert custom_args is not None - distance_map = custom_args.get('distance_map') - assert distance_map is not None - - prediction_dict = {} - for agent in agents: - _agent_initial_position = agent.position - _agent_initial_direction = agent.direction - agent_speed = agent.speed_data["speed"] - times_per_cell = int(np.reciprocal(agent_speed)) - prediction = np.zeros(shape=(self.max_depth + 1, 5)) - prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0] - new_direction = _agent_initial_direction - new_position = _agent_initial_position - visited = set() - for index in range(1, self.max_depth + 1): - # if we're at the target, stop moving... 
- if agent.position == agent.target: - prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING] - visited.add((agent.position[0], agent.position[1], agent.direction)) - continue - if not agent.moving: - prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING] - visited.add((agent.position[0], agent.position[1], agent.direction)) - continue - # Take shortest possible path - cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) - - if np.sum(cell_transitions) == 1 and index % times_per_cell == 0: - new_direction = np.argmax(cell_transitions) - new_position = get_new_position(agent.position, new_direction) - elif np.sum(cell_transitions) > 1 and index % times_per_cell == 0: - min_dist = np.inf - no_dist_found = True - for direction in range(4): - if cell_transitions[direction] == 1: - neighbour_cell = get_new_position(agent.position, direction) - target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction] - if target_dist < min_dist or no_dist_found: - min_dist = target_dist - new_direction = direction - no_dist_found = False - new_position = get_new_position(agent.position, new_direction) - elif index % times_per_cell == 0: - raise Exception("No transition possible {}".format(cell_transitions)) - - # update the agent's position and direction - agent.position = new_position - agent.direction = new_direction - - # prediction is ready - prediction[index] = [index, *new_position, new_direction, 0] - visited.add((new_position[0], new_position[1], new_direction)) - self.env.dev_pred_dict[agent.handle] = visited - prediction_dict[agent.handle] = prediction - - # cleanup: reset initial position - agent.position = _agent_initial_position - agent.direction = _agent_initial_direction - - return prediction_dict diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py index 589d6109641c540c103f030e7eb139c35df25298..2649a2367367e17e39328ca8c28cc9c2f1fc0172 100644 --- a/torch_training/render_agent_behavior.py +++ b/torch_training/render_agent_behavior.py @@ -101,8 +101,8 @@ dones_list = [] action_prob = [0] * action_size agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() -agent = Agent(state_size, action_size, "FC", 0) -with path(torch_training.Nets, "navigator_checkpoint10700.pth") as file_in: +agent = Agent(state_size, action_size) +with path(torch_training.Nets, "avoid_checkpoint500.pth") as file_in: agent.qnetwork_local.load_state_dict(torch.load(file_in)) record_images = False diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index f69929f65accc53101ba28d8904cdf76b7e1cfca..bd221ae4b912b89c1d5bb242676d4f75819cfd90 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -11,7 +11,7 @@ sys.path.append(str(base_dir)) import matplotlib.pyplot as plt import numpy as np import torch -from dueling_double_dqn import Agent +from torch_training.dueling_double_dqn import Agent from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv @@ -117,7 +117,7 @@ def main(argv): cummulated_reward = np.zeros(env.get_num_agents()) # Now we load a Double dueling DQN agent - agent = Agent(state_size, action_size, "FC", 0) + agent = Agent(state_size, action_size) for trials in range(1, n_trials + 1):
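Usage note (not part of the patch itself): the last two hunks switch the scripts to the simplified two-argument `Agent(state_size, action_size)` constructor and the `avoid_checkpoint500.pth` checkpoint. The sketch below shows how that setup is exercised after this change. Only the two-argument constructor, the checkpoint name, and the `qnetwork_local.load_state_dict` call are taken from the hunks above; the `importlib_resources.path` import and the concrete state/action sizes are assumptions used for illustration.

```python
# Minimal sketch of the post-change agent setup, mirroring render_agent_behavior.py.
# Assumptions: the script's `path` helper is importlib_resources.path, and the
# state/action sizes below are placeholders (the real scripts derive them from
# the observation builder and the RailEnv action space).
import torch
from importlib_resources import path

import torch_training.Nets
from torch_training.dueling_double_dqn import Agent

state_size = 231   # placeholder: depends on the tree observation depth in use
action_size = 5    # placeholder: RailEnv action space size

# Dueling double DQN agent, now constructed without the extra "FC"/seed arguments
agent = Agent(state_size, action_size)

# Load the retrained checkpoint shipped in torch_training/Nets
with path(torch_training.Nets, "avoid_checkpoint500.pth") as file_in:
    agent.qnetwork_local.load_state_dict(torch.load(file_in))
```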