From c8bd78337bce7864e4badfd765377f20f94844e3 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Wed, 17 Jul 2019 12:28:45 -0400 Subject: [PATCH] minor updates and new training checkpoints --- torch_training/multi_agent_inference.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py index 71ca244..4ee79bb 100644 --- a/torch_training/multi_agent_inference.py +++ b/torch_training/multi_agent_inference.py @@ -3,7 +3,7 @@ from collections import deque import numpy as np import torch -from flatland.envs.generators import complex_rail_generator +from flatland.envs.generators import rail_from_file from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv @@ -16,7 +16,7 @@ from utils.observation_utils import norm_obs_clip, split_tree random.seed(3) np.random.seed(2) -""" + file_name = "./railway/flatland.pkl" env = RailEnv(width=10, height=20, @@ -41,7 +41,7 @@ env = RailEnv(width=x_dim, obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()), number_of_agents=n_agents) env.reset(True, True) - +""" tree_depth = 3 observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv()) env_renderer = RenderTool(env, gl="PILSVG", ) @@ -81,7 +81,7 @@ for trials in range(1, n_trials + 1): # Reset environment obs = env.reset(True, True) - env_renderer.set_new_rail() + env_renderer.reset() for a in range(env.get_num_agents()): data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=num_features_per_node, -- GitLab