Commit f426762a authored by MasterScrat

Submission with a bugfix when creating the env, using a checkpoint, and with parallel inference disabled

parent efb0ec8a
@@ -6,6 +6,7 @@ from pathlib import Path
import numpy as np
import time
import torch
from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.evaluators.client import FlatlandRemoteClient
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
@@ -20,25 +21,30 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
from utils.observation_utils import normalize_observation
def infer_action(obs):
norm_obs = normalize_observation(obs, tree_depth=observation_tree_depth, observation_radius=observation_radius)
return policy.act(norm_obs, eps=0.0)
def random_controller(obs, number_of_agents):
_action = {}
for _idx in range(number_of_agents):
_action[_idx] = np.random.randint(0, 5)
return _action
# TODO check if the parallel version works properly
#
# def infer_action(obs):
# norm_obs = normalize_observation(obs, tree_depth=observation_tree_depth, observation_radius=observation_radius)
# return policy.act(norm_obs, eps=0.0)
#
#
# def rl_controller(obs, number_of_agents):
# obs_list = []
# for agent in range(number_of_agents):
# if obs[agent] and info['action_required'][agent]:
# obs_list.append(obs[agent])
#
# return dict(zip(range(number_of_agents), pool.map(infer_action, obs_list)))
def rl_controller(obs, number_of_agents):
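# NOTE: `policy`, `observation_tree_depth`, `observation_radius` and `info` are read from the
# surrounding script scope (set up in the __main__ block below) rather than passed in as arguments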
obs_list = []
action_dict = {}
for agent in range(number_of_agents):
if obs[agent] and info['action_required'][agent]:
obs_list.append(obs[agent])
norm_obs = normalize_observation(obs[agent], tree_depth=observation_tree_depth, observation_radius=observation_radius)
action = policy.act(norm_obs, eps=0.0)
action_dict[agent] = action
return dict(zip(range(number_of_agents), pool.map(infer_action, obs_list)))
return action_dict
if __name__ == "__main__":
@@ -49,7 +55,7 @@ if __name__ == "__main__":
remote_client = FlatlandRemoteClient()
# Checkpoint to use
checkpoint = "checkpoints/multi-100.pth"
checkpoint = "checkpoints/multi-300.pth"
# Evaluation is faster on CPU (except if you use a really huge network)
parameters = {
@@ -73,7 +79,7 @@ if __name__ == "__main__":
action_size = 5
policy = DDDQNPolicy(state_size, action_size, Namespace(**parameters), evaluation_mode=True)
# policy.qnetwork_local = torch.load(checkpoint)
policy.qnetwork_local = torch.load(checkpoint)
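# NB: the checkpoint appears to store the full serialized network, so the object returned by
# torch.load() is assigned to qnetwork_local directly instead of loading a state_dict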
# Controller
pool = Pool()
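# With parallel inference disabled, this pool is not used by the active rl_controller;
# only the commented-out parallel controller above relied on pool.map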
@@ -85,7 +91,6 @@ if __name__ == "__main__":
#####################################################################
evaluation_number = 0
while True:
evaluation_number += 1
time_start = time.time()
@@ -94,23 +99,25 @@
)
env_creation_time = time.time() - time_start
local_env = remote_client.env
number_of_agents = len(local_env.agents)
tree_observation.set_env(local_env)
tree_observation.reset()
tree_observation.prepare_get_many(list(range(number_of_agents)))
observation = tree_observation.get_many(list(range(number_of_agents)))
if not observation:
#
# If the remote_client returns False on an `env_create` call,
# it means that your agent has already been evaluated on all
# the required evaluation environments, and hence it is safe
# to break out of the main evaluation loop
break
print("Evaluation Number : {}".format(evaluation_number))
local_env = remote_client.env
nb_agents = len(local_env.agents)
tree_observation.set_env(local_env)
tree_observation.reset()
tree_observation.prepare_get_many(list(range(nb_agents)))
observation = tree_observation.get_many(list(range(nb_agents)))
nb_cities = 2 # FIXME get true value
max_nb_steps = int(4 * 2 * (local_env.width + local_env.height + (nb_agents / nb_cities)))
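# e.g. a 40x40 grid with 10 agents and 2 cities gives int(4 * 2 * (40 + 40 + 10 / 2)) = 680 steps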
print("Evaluation {}: {} agents in {}x{}".format(evaluation_number, nb_agents, local_env.width, local_env.height))
# Now we enter into another infinite loop where we
# compute the actions for all the individual steps in this episode
@@ -120,43 +127,61 @@
steps = 0
while True:
prepare_time, obs_time, agent_time, step_time = 0, 0, 0, 0
#####################################################################
# Evaluation of a single episode
#####################################################################
if number_of_agents < 100:
steps += 1
if nb_agents < 100:
time_start = time.time()
action = rl_controller(observation, number_of_agents)
action = rl_controller(observation, nb_agents)
agent_time = time.time() - time_start
time_taken_by_controller.append(agent_time)
time_start = time.time()
_, all_rewards, done, info = remote_client.env_step(action)
steps += 1
step_time = time.time() - time_start
time_taken_per_step.append(step_time)
time_start = time.time()
tree_observation.prepare_get_many(list(range(number_of_agents)))
tree_observation.prepare_get_many(list(range(nb_agents)))
prepare_time = time.time() - time_start
time_start = time.time()
observation_list = []
for h in range(number_of_agents):
for h in range(nb_agents):
observation_list.append(tree_observation.get(h))
observation = dict(zip(range(number_of_agents), observation_list))
observation = dict(zip(range(nb_agents), observation_list))
obs_time = time.time() - time_start
else:
# too many agents: just act randomly
action = random_controller(observation, number_of_agents)
_, all_rewards, done, info = remote_client.env_step(action)
# too many agents: just wait for it to pass 🏃💨
time_start = time.time()
_, all_rewards, done, info = remote_client.env_step({})
step_time = time.time() - time_start
time_taken_per_step.append(step_time)
#print("Step {}\t Predictor time {:.3f}\t Obs time {:.3f}\t Inference time {:.3f}\t Step time {:.3f}".format(str(steps).zfill(3), prepare_time, obs_time, agent_time, step_time))
nb_agents_done = sum(done[idx] for idx in local_env.get_agent_handles())
print("Step {}/{}\tAgents done: {}\t Predictor time {:.3f}s\t Obs time {:.3f}s\t Inference time {:.3f}s\t Step time {:.3f}s".format(
str(steps).zfill(4),
max_nb_steps,
nb_agents_done,
prepare_time,
obs_time,
agent_time,
step_time
), end="" if done['__all__'] else "\r")
# if check_if_all_blocked(local_env):
# print("DEADLOCKED!!")
# print(evaluation_number, steps, done['__all__'])
if done['__all__']:
print()
print("Reward : ", sum(list(all_rewards.values())))
#
# When done['__all__'] == True, then the evaluation of this
......