diff --git a/sequential_agent/run_test.py b/sequential_agent/run_test.py
index 970d6aadb8aeec1086afb16c257ad3cd65902f45..d0b9ce70f50465a58885a9b1feb754791bb49f34
--- a/sequential_agent/run_test.py
+++ b/sequential_agent/run_test.py
@@ -8,11 +8,11 @@ import numpy as np
 
 np.random.seed(2)
 """
-file_name = "./railway/complex_scene.pkl"
+file_name = "../torch_training/railway/complex_scene.pkl"
 env = RailEnv(width=10,
               height=20,
               rail_generator=rail_from_file(file_name),
-              obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))
+              obs_builder_object=TreeObsForRailEnv(max_depth=1, predictor=ShortestPathPredictorForRailEnv()))
 
 x_dim = env.width
 y_dim = env.height
@@ -38,8 +38,8 @@ observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv())
 env_renderer = RenderTool(env, gl="PILSVG", )
 handle = env.get_agent_handles()
 n_trials = 1
-max_steps = 3 * (env.height + env.width)
-record_images = True
+max_steps = 100 * (env.height + env.width)
+record_images = False
 agent = OrderedAgent()
 action_dict = dict()
 
@@ -63,6 +63,7 @@ for trials in range(1, n_trials + 1):
         for a in range(env.get_num_agents()):
             if done[a]:
                 acting_agent += 1
+                print(acting_agent)
             if a == acting_agent:
                 action = agent.act(obs[a], eps=0)
             else:
diff --git a/sequential_agent/simple_order_agent.py b/sequential_agent/simple_order_agent.py
index 6e888c51ab7210062ee6efb9862cd78e5a61ca5a..3feff350e94226f157559036abdaea8d5dc18bf9
--- a/sequential_agent/simple_order_agent.py
+++ b/sequential_agent/simple_order_agent.py
@@ -18,7 +18,7 @@ class OrderedAgent:
         min_dist = min_lt(distance, 0)
         min_direction = np.where(distance == min_dist)
         if len(min_direction[0]) > 1:
-            return min_direction[0][0] + 1
+            return min_direction[0][-1] + 1
         return min_direction[0] + 1
 
     def step(self, memories):
diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py
index e8fd6d4cd607e23c2a32f4e789e5f800d9c0461e..84a0846cef45454f574c03db8c8a77d264fd798d
--- a/torch_training/multi_agent_inference.py
+++ b/torch_training/multi_agent_inference.py
@@ -98,7 +98,6 @@ for trials in range(1, n_trials + 1):
         for a in range(env.get_num_agents()):
             action = agent.act(agent_obs[a], eps=0)
             action_dict.update({a: action})
 
-        # Environment step
         next_obs, all_rewards, done, _ = env.step(action_dict)
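
A minimal standalone sketch of the tie-break change in `OrderedAgent.act` (sequential_agent/simple_order_agent.py): when several branch directions share the minimal distance, the old code returned the first such index, while the patched code returns the last. The `distance` values below and the local `min_lt` helper are illustrative stand-ins for this sketch, not the actual flatland utilities.

```python
import numpy as np


def min_lt(seq, val):
    """Stand-in for the min_lt helper used by OrderedAgent: smallest entry of
    seq strictly greater than val (entries <= val mark unreachable branches)."""
    best = np.inf
    for item in seq:
        if val < item < best:
            best = item
    return best


# Hypothetical per-branch distances, one entry per possible direction.
distance = np.array([7.0, 3.0, 3.0, -np.inf])

min_dist = min_lt(distance, 0)                   # 3.0
min_direction = np.where(distance == min_dist)   # (array([1, 2]),)

old_action = min_direction[0][0] + 1    # 2: first branch with minimal distance
new_action = min_direction[0][-1] + 1   # 3: last branch with minimal distance (patched behavior)
print(old_action, new_action)
```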