From 976867032529a0570d9dfa181253d6f1cda2830b Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Tue, 9 Jul 2019 15:12:49 +0200
Subject: [PATCH] #42 run baselines in ci: num_features_per_node

---
 RLLib_training/RailEnvRLLibWrapper.py  | 11 +++++------
 torch_training/multi_agent_training.py |  4 ++--
 torch_training/training_navigation.py  | 10 ++--------
 utils/misc_utils.py                    |  3 +--
 4 files changed, 10 insertions(+), 18 deletions(-)

diff --git a/RLLib_training/RailEnvRLLibWrapper.py b/RLLib_training/RailEnvRLLibWrapper.py
index d065063..9898000 100644
--- a/RLLib_training/RailEnvRLLibWrapper.py
+++ b/RLLib_training/RailEnvRLLibWrapper.py
@@ -1,9 +1,10 @@
 import numpy as np
-from flatland.envs.generators import complex_rail_generator, random_rail_generator
-from flatland.envs.rail_env import RailEnv
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from ray.rllib.utils.seed import seed as set_seed
 
+from flatland.envs.generators import complex_rail_generator, random_rail_generator
+from flatland.envs.rail_env import RailEnv
+
 
 class RailEnvRLLibWrapper(MultiAgentEnv):
 
@@ -63,7 +64,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
 
         for i_agent in range(len(self.env.agents)):
             data, distance, agent_data = self.env.obs_builder.split_tree(tree=np.array(obs[i_agent]),
-                                                                          num_features_per_node=8, current_depth=0)
+                                                                          current_depth=0)
             o[i_agent] = [data, distance, agent_data]
 
         # needed for the renderer
@@ -72,8 +73,6 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         self.agents_static = self.env.agents_static
         self.dev_obs_dict = self.env.dev_obs_dict
 
-
-
         # If step_memory > 1, we need to concatenate it the observations in memory, only works for
         # step_memory = 1 or 2 for the moment
         if self.step_memory < 2:
@@ -96,7 +95,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         for i_agent in range(len(self.env.agents)):
             if i_agent not in self.agents_done:
                 data, distance, agent_data = self.env.obs_builder.split_tree(tree=np.array(obs[i_agent]),
-                                                                              num_features_per_node=8, current_depth=0)
+                                                                              current_depth=0)
                 o[i_agent] = [data, distance, agent_data]
                 r[i_agent] = rewards[i_agent]
 
diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index a7104d3..7b9470d 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -100,7 +100,7 @@ for trials in range(1, n_trials + 1):
     final_obs = obs.copy()
     final_obs_next = obs.copy()
     for a in range(env.get_num_agents()):
-        data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=features_per_node,
+        data, distance, agent_data = split_tree(tree=np.array(obs[a]),
                                                 current_depth=0)
         data = norm_obs_clip(data)
         distance = norm_obs_clip(distance)
@@ -141,7 +141,7 @@ for trials in range(1, n_trials + 1):
         # print(all_rewards,action)
         obs_original = next_obs.copy()
         for a in range(env.get_num_agents()):
-            data, distance, agent_data = split_tree(tree=np.array(next_obs[a]), num_features_per_node=features_per_node,
+            data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
                                                     current_depth=0)
             data = norm_obs_clip(data)
             distance = norm_obs_clip(distance)
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 8410bf2..88aa57e 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -1,13 +1,10 @@
 from collections import deque
+
 import matplotlib.pyplot as plt
 import numpy as np
 import random
-
 import torch
 from dueling_double_dqn import Agent
-from importlib_resources import path
-
-import torch_training.Nets
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 
@@ -91,7 +88,6 @@ for trials in range(1, n_trials + 1):
     # Build agent specific local observation
     for a in range(env.get_num_agents()):
         rail_data, distance_data, agent_data = split_tree(tree=np.array(obs[a]),
-                                                           num_features_per_node=features_per_node,
                                                            current_depth=0)
         rail_data = norm_obs_clip(rail_data)
         distance_data = norm_obs_clip(distance_data)
@@ -125,7 +121,6 @@ for trials in range(1, n_trials + 1):
 
         for a in range(env.get_num_agents()):
             rail_data, distance_data, agent_data = split_tree(tree=np.array(next_obs[a]),
-                                                               num_features_per_node=features_per_node,
                                                                current_depth=0)
             rail_data = norm_obs_clip(rail_data)
             distance_data = norm_obs_clip(distance_data)
@@ -185,7 +180,7 @@ env_renderer.set_new_rail()
 # Split the observation tree into its parts and normalize the observation using the utility functions.
 # Build agent specific local observation
 for a in range(env.get_num_agents()):
-    rail_data, distance_data, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=features_per_node,
+    rail_data, distance_data, agent_data = split_tree(tree=np.array(obs[a]),
                                                       current_depth=0)
     rail_data = norm_obs_clip(rail_data)
     distance_data = norm_obs_clip(distance_data)
@@ -211,7 +206,6 @@ for step in range(max_steps):
 
     for a in range(env.get_num_agents()):
         rail_data, distance_data, agent_data = split_tree(tree=np.array(next_obs[a]),
-                                                           num_features_per_node=features_per_node,
                                                            current_depth=0)
         rail_data = norm_obs_clip(rail_data)
         distance_data = norm_obs_clip(distance_data)
diff --git a/utils/misc_utils.py b/utils/misc_utils.py
index 03c9fdd..5b29c6b 100644
--- a/utils/misc_utils.py
+++ b/utils/misc_utils.py
@@ -101,7 +101,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
         lp_reset(True, True)
         obs = env.reset(True, True)
         for a in range(env.get_num_agents()):
-            data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=9,
+            data, distance, agent_data = split_tree(tree=np.array(obs[a]),
                                                     current_depth=0)
             data = norm_obs_clip(data)
             distance = norm_obs_clip(distance)
@@ -127,7 +127,6 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
 
             for a in range(env.get_num_agents()):
                 data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
-                                                         num_features_per_node=features_per_node,
                                                          current_depth=0)
                 data = norm_obs_clip(data)
                 distance = norm_obs_clip(distance)
--
GitLab
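
Editor's note (not part of the patch): every hunk above makes the same API adjustment, dropping the explicit `num_features_per_node` argument from `split_tree` so the feature count is determined by the observation builder itself. The sketch below is a minimal illustration of the resulting call pattern; the helper name `split_and_normalize` is hypothetical, and the import path `utils.observation_utils` is assumed to be where this repo keeps `split_tree` and `norm_obs_clip`.

```python
import numpy as np

# Assumed location of the helpers used throughout the training scripts.
from utils.observation_utils import norm_obs_clip, split_tree


def split_and_normalize(obs, handle):
    """Hypothetical helper mirroring the "+" lines of the patch:
    split_tree is called without num_features_per_node, then the rail and
    distance parts are clipped/normalized for the network input."""
    data, distance, agent_data = split_tree(tree=np.array(obs[handle]),
                                            current_depth=0)
    return norm_obs_clip(data), norm_obs_clip(distance), agent_data
```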