Skip to content
Snippets Groups Projects
Commit 9fd05a66 authored by Erik Nygren's avatar Erik Nygren
Browse files

enhanced functionality

parent cd27006a
No related branches found
No related tags found
No related merge requests found
......@@ -8,11 +8,11 @@ import numpy as np
np.random.seed(2)
"""
file_name = "./railway/complex_scene.pkl"
file_name = "../torch_training/railway/complex_scene.pkl"
env = RailEnv(width=10,
height=20,
rail_generator=rail_from_file(file_name),
obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))
obs_builder_object=TreeObsForRailEnv(max_depth=1, predictor=ShortestPathPredictorForRailEnv()))
x_dim = env.width
y_dim = env.height
......@@ -38,8 +38,8 @@ observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestP
env_renderer = RenderTool(env, gl="PILSVG", )
handle = env.get_agent_handles()
n_trials = 1
max_steps = 3 * (env.height + env.width)
record_images = True
max_steps = 100 * (env.height + env.width)
record_images = False
agent = OrderedAgent()
action_dict = dict()
......@@ -63,6 +63,7 @@ for trials in range(1, n_trials + 1):
for a in range(env.get_num_agents()):
if done[a]:
acting_agent += 1
print(acting_agent)
if a == acting_agent:
action = agent.act(obs[a], eps=0)
else:
......
......@@ -18,7 +18,7 @@ class OrderedAgent:
min_dist = min_lt(distance, 0)
min_direction = np.where(distance == min_dist)
if len(min_direction[0]) > 1:
return min_direction[0][0] + 1
return min_direction[0][-1] + 1
return min_direction[0] + 1
def step(self, memories):
......
......@@ -98,7 +98,6 @@ for trials in range(1, n_trials + 1):
for a in range(env.get_num_agents()):
action = agent.act(agent_obs[a], eps=0)
action_dict.update({a: action})
# Environment step
next_obs, all_rewards, done, _ = env.step(action_dict)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment