Commit 0d3e32ed authored by MasterScrat's avatar MasterScrat
Browse files

Use random actions if more than 100 agents

parent beac6e15
......@@ -25,6 +25,13 @@ def infer_action(obs):
return policy.act(norm_obs, eps=0.0)
def random_controller(obs, number_of_agents):
    """Return a uniformly random action (0-4) for every agent.

    The observation is accepted for interface compatibility with the
    RL controller but is not used.
    """
    return {agent_id: np.random.randint(0, 5) for agent_id in range(number_of_agents)}
def rl_controller(obs, number_of_agents):
obs_list = []
for agent in range(number_of_agents):
......@@ -52,7 +59,7 @@ if __name__ == "__main__":
# Observation parameters
observation_tree_depth = 1
observation_radius = 10
observation_max_path_depth = 30
observation_max_path_depth = 20
# Observation builder
predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
......@@ -116,29 +123,35 @@ if __name__ == "__main__":
#####################################################################
# Evaluation of a single episode
#####################################################################
time_start = time.time()
action = rl_controller(observation, number_of_agents)
agent_time = time.time() - time_start
time_taken_by_controller.append(agent_time)
time_start = time.time()
_, all_rewards, done, info = remote_client.env_step(action)
steps += 1
step_time = time.time() - time_start
time_taken_per_step.append(step_time)
time_start = time.time()
tree_observation.prepare_get_many(list(range(number_of_agents)))
prepare_time = time.time() - time_start
time_start = time.time()
observation_list = []
for h in range(number_of_agents):
observation_list.append(tree_observation.get(h))
observation = dict(zip(range(number_of_agents), observation_list))
obs_time = time.time() - time_start
# print("Step {}\t Prepare time {:.3f}\t Obs time {:.3f}\t Inference time {:.3f}\t Step time {:.3f}".format(str(steps).zfill(3), prepare_time, obs_time, agent_time, step_time))
if number_of_agents < 100:
time_start = time.time()
action = rl_controller(observation, number_of_agents)
agent_time = time.time() - time_start
time_taken_by_controller.append(agent_time)
time_start = time.time()
_, all_rewards, done, info = remote_client.env_step(action)
steps += 1
step_time = time.time() - time_start
time_taken_per_step.append(step_time)
time_start = time.time()
tree_observation.prepare_get_many(list(range(number_of_agents)))
prepare_time = time.time() - time_start
time_start = time.time()
observation_list = []
for h in range(number_of_agents):
observation_list.append(tree_observation.get(h))
observation = dict(zip(range(number_of_agents), observation_list))
obs_time = time.time() - time_start
else:
# too many agents: just act randomly
action = random_controller(observation, number_of_agents)
_, all_rewards, done, info = remote_client.env_step(action)
#print("Step {}\t Predictor time {:.3f}\t Obs time {:.3f}\t Inference time {:.3f}\t Step time {:.3f}".format(str(steps).zfill(3), prepare_time, obs_time, agent_time, step_time))
# if check_if_all_blocked(local_env):
# print("DEADLOCKED!!")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment