Commit 546342c1 authored by gmollard

added step memory in preprocessor

parent f04c2dfd
@@ -50,15 +50,23 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
 class TreeObsPreprocessor(Preprocessor):
     def _init_shape(self, obs_space, options):
+        print(options)
+        self.step_memory = options["custom_options"]["step_memory"]
         return sum([space.shape[0] for space in obs_space]),

     def transform(self, observation):
-        data = norm_obs_clip(observation[0][0])
-        distance = norm_obs_clip(observation[0][1])
-        agent_data = np.clip(observation[0][2], -1, 1)
-        data2 = norm_obs_clip(observation[1][0])
-        distance2 = norm_obs_clip(observation[1][1])
-        agent_data2 = np.clip(observation[1][2], -1, 1)
+        if self.step_memory == 2:
+            data = norm_obs_clip(observation[0][0])
+            distance = norm_obs_clip(observation[0][1])
+            agent_data = np.clip(observation[0][2], -1, 1)
+            data2 = norm_obs_clip(observation[1][0])
+            distance2 = norm_obs_clip(observation[1][1])
+            agent_data2 = np.clip(observation[1][2], -1, 1)
+        else:
+            data = norm_obs_clip(observation[0])
+            distance = norm_obs_clip(observation[1])
+            agent_data = np.clip(observation[2], -1, 1)

         return np.concatenate((np.concatenate((np.concatenate((data, distance)), agent_data)), np.concatenate((np.concatenate((data2, distance2)), agent_data2))))
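
For context, when step_memory == 2 the transform above expects observation to be a pair of per-step tuples, each holding (data, distance, agent_data). A minimal sketch of how a caller could assemble that pair from the current and previous tree observations; the helper and variable names below are illustrative, not part of this commit:

# Illustrative sketch only: pair the current split tree observation with the
# previous one so that TreeObsPreprocessor.transform() can concatenate both steps.
prev_split_obs = None  # assumed to be reset to None at the start of every episode

def build_two_step_obs(current_split_obs):
    """current_split_obs is a (data, distance, agent_data) triple for this step."""
    global prev_split_obs
    if prev_split_obs is None:
        # First step of the episode: duplicate the current observation.
        prev_split_obs = current_split_obs
    two_step_obs = (current_split_obs, prev_split_obs)
    prev_split_obs = current_split_obs
    return two_step_obs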
@@ -105,8 +105,8 @@ def render_training_result(config):
     policy = trainer.get_policy("ppo_policy")

-    preprocessor = preprocessor(obs_space)
-    env_renderer = RenderTool(env, gl="PIL")
+    preprocessor = preprocessor(obs_space, {"step_memory": config["step_memory"]})
+    env_renderer = RenderTool(env, gl="PILSVG")

     for episode in range(N_EPISODES):
         observation = env.reset()
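
The render script now builds the preprocessor instance directly with the same step_memory option. A hedged sketch of how that instance could then be applied per agent inside the episode loop (the agent iteration and variable names are assumptions, not shown in this hunk):

# Illustrative only: flatten each agent's raw tree observation with the
# standalone preprocessor before querying the restored PPO policy.
flat_obs = {agent_id: preprocessor.transform(raw_obs)
            for agent_id, raw_obs in observation.items()}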
@@ -99,7 +99,8 @@ def train(config, reporter):
     # Trainer configuration
     trainer_config = DEFAULT_CONFIG.copy()
-    trainer_config['model'] = {"fcnet_hiddens": config['hidden_sizes'], "custom_preprocessor": preprocessor}
+    trainer_config['model'] = {"fcnet_hiddens": config['hidden_sizes'], "custom_preprocessor": preprocessor,
+                               "custom_options": {"step_memory": config["step_memory"]}}
     trainer_config['multiagent'] = {"policy_graphs": policy_graphs,
                                     "policy_mapping_fn": policy_mapping_fn,
@@ -131,6 +132,7 @@ def train(config, reporter):
         "on_episode_end": tune.function(on_episode_end)
     }

     def logger_creator(conf):
         """Creates a Unified logger with a default logdir prefix."""
         logdir = config['policy_folder_name'].format(**locals())
@@ -179,7 +181,8 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
                "kl_coeff": kl_coeff,
                "lambda_gae": lambda_gae,
                "min_dist": min_dist,
-               "step_memory": step_memory
+               "step_memory": step_memory  # If equal to two, the current observation plus
+                                           # the observation of the last time step will be given as input to the model.
                },
            resources_per_trial={
                "cpu": 3,