From 9fd05a66e9cc29d0821d5f4c21f8de48a7e67887 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Thu, 18 Jul 2019 16:38:43 -0400
Subject: [PATCH] enhanced functionality

---
 sequential_agent/run_test.py            | 9 +++++----
 sequential_agent/simple_order_agent.py  | 2 +-
 torch_training/multi_agent_inference.py | 1 -
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/sequential_agent/run_test.py b/sequential_agent/run_test.py
index 970d6aa..d0b9ce7 100644
--- a/sequential_agent/run_test.py
+++ b/sequential_agent/run_test.py
@@ -8,11 +8,11 @@ import numpy as np
 
 np.random.seed(2)
 
 """
-file_name = "./railway/complex_scene.pkl"
+file_name = "../torch_training/railway/complex_scene.pkl"
 env = RailEnv(width=10,
               height=20,
               rail_generator=rail_from_file(file_name),
-              obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))
+              obs_builder_object=TreeObsForRailEnv(max_depth=1, predictor=ShortestPathPredictorForRailEnv()))
 x_dim = env.width
 y_dim = env.height
@@ -38,8 +38,8 @@ observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestP
 env_renderer = RenderTool(env, gl="PILSVG", )
 handle = env.get_agent_handles()
 n_trials = 1
-max_steps = 3 * (env.height + env.width)
-record_images = True
+max_steps = 100 * (env.height + env.width)
+record_images = False
 agent = OrderedAgent()
 action_dict = dict()
 
@@ -63,6 +63,7 @@ for trials in range(1, n_trials + 1):
         for a in range(env.get_num_agents()):
             if done[a]:
                 acting_agent += 1
+            print(acting_agent)
             if a == acting_agent:
                 action = agent.act(obs[a], eps=0)
             else:
diff --git a/sequential_agent/simple_order_agent.py b/sequential_agent/simple_order_agent.py
index 6e888c5..3feff35 100644
--- a/sequential_agent/simple_order_agent.py
+++ b/sequential_agent/simple_order_agent.py
@@ -18,7 +18,7 @@ class OrderedAgent:
         min_dist = min_lt(distance, 0)
         min_direction = np.where(distance == min_dist)
         if len(min_direction[0]) > 1:
-            return min_direction[0][0] + 1
+            return min_direction[0][-1] + 1
         return min_direction[0] + 1
 
     def step(self, memories):
diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py
index e8fd6d4..84a0846 100644
--- a/torch_training/multi_agent_inference.py
+++ b/torch_training/multi_agent_inference.py
@@ -98,7 +98,6 @@ for trials in range(1, n_trials + 1):
         for a in range(env.get_num_agents()):
             action = agent.act(agent_obs[a], eps=0)
             action_dict.update({a: action})
-        # Environment step
 
         next_obs, all_rewards, done, _ = env.step(action_dict)
--
GitLab