Skip to content
Snippets Groups Projects
Commit 0d3e32ed authored by MasterScrat's avatar MasterScrat
Browse files

Use random actions if more than 100 agents

parent beac6e15
No related branches found
No related tags found
No related merge requests found
......@@ -25,6 +25,13 @@ def infer_action(obs):
return policy.act(norm_obs, eps=0.0)
def random_controller(obs, number_of_agents):
    """Choose a random action for every agent.

    Args:
        obs: Accepted so the signature matches ``rl_controller``; it is
            not read by this function.
        number_of_agents: How many agents need an action; agents are
            keyed ``0 .. number_of_agents - 1``.

    Returns:
        A dict mapping each agent index to an integer drawn with
        ``np.random.randint(0, 5)`` (i.e. a value in 0..4 inclusive).
    """
    # One independent draw per agent, made in ascending index order.
    return {
        agent_id: np.random.randint(0, 5)
        for agent_id in range(number_of_agents)
    }
def rl_controller(obs, number_of_agents):
obs_list = []
for agent in range(number_of_agents):
......@@ -52,7 +59,7 @@ if __name__ == "__main__":
# Observation parameters
observation_tree_depth = 1
observation_radius = 10
observation_max_path_depth = 30
observation_max_path_depth = 20
# Observation builder
predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
......@@ -116,29 +123,35 @@ if __name__ == "__main__":
#####################################################################
# Evaluation of a single episode
#####################################################################
time_start = time.time()
action = rl_controller(observation, number_of_agents)
agent_time = time.time() - time_start
time_taken_by_controller.append(agent_time)
time_start = time.time()
_, all_rewards, done, info = remote_client.env_step(action)
steps += 1
step_time = time.time() - time_start
time_taken_per_step.append(step_time)
time_start = time.time()
tree_observation.prepare_get_many(list(range(number_of_agents)))
prepare_time = time.time() - time_start
time_start = time.time()
observation_list = []
for h in range(number_of_agents):
observation_list.append(tree_observation.get(h))
observation = dict(zip(range(number_of_agents), observation_list))
obs_time = time.time() - time_start
# print("Step {}\t Prepare time {:.3f}\t Obs time {:.3f}\t Inference time {:.3f}\t Step time {:.3f}".format(str(steps).zfill(3), prepare_time, obs_time, agent_time, step_time))
if number_of_agents < 100:
time_start = time.time()
action = rl_controller(observation, number_of_agents)
agent_time = time.time() - time_start
time_taken_by_controller.append(agent_time)
time_start = time.time()
_, all_rewards, done, info = remote_client.env_step(action)
steps += 1
step_time = time.time() - time_start
time_taken_per_step.append(step_time)
time_start = time.time()
tree_observation.prepare_get_many(list(range(number_of_agents)))
prepare_time = time.time() - time_start
time_start = time.time()
observation_list = []
for h in range(number_of_agents):
observation_list.append(tree_observation.get(h))
observation = dict(zip(range(number_of_agents), observation_list))
obs_time = time.time() - time_start
else:
# too many agents: just act randomly
action = random_controller(observation, number_of_agents)
_, all_rewards, done, info = remote_client.env_step(action)
#print("Step {}\t Predictor time {:.3f}\t Obs time {:.3f}\t Inference time {:.3f}\t Step time {:.3f}".format(str(steps).zfill(3), prepare_time, obs_time, agent_time, step_time))
# if check_if_all_blocked(local_env):
# print("DEADLOCKED!!")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment