Commit 625403b0 authored by Egli Adrian (IT-SCI-API-PFI)

10 Agents ~0.9257%

parent 8eb24851
@@ -285,21 +285,13 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             step_timer.start()
             next_obs, all_rewards, done, info = train_env.step(action_dict)

-            if False:
+            if True:
                 for agent in train_env.get_agent_handles():
                     act = action_dict.get(agent, RailEnvActions.DO_NOTHING)
-                    if agent_obs[agent][26] == 1:
-                        if act == RailEnvActions.STOP_MOVING:
-                            all_rewards[agent] *= 0.01
-                    else:
-                        if act == RailEnvActions.MOVE_LEFT:
-                            all_rewards[agent] *= 0.9
-                        else:
-                            if agent_obs[agent][7] == 0 and agent_obs[agent][8] == 0:
-                                if act == RailEnvActions.MOVE_FORWARD:
-                                    all_rewards[agent] *= 0.01
-                    if done[agent]:
-                        all_rewards[agent] += 100.0
+                    if agent_obs[agent][5] == 1:
+                        if agent_obs[agent][26] == 1:
+                            if act != RailEnvActions.STOP_MOVING:
+                                all_rewards[agent] -= 10.0

             step_timer.end()
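Note on the hunk above: the commit switches the reward-shaping block from disabled (if False) to enabled (if True) and replaces the old multiplicative shaping with a single subtractive penalty. A minimal standalone sketch of the new rule, assuming agent_obs[agent] is the custom observation vector and treating indices 5 and 26 as opaque binary flags (their semantics live in the observation builder, not in this diff):

# Hedged sketch of the reward shaping enabled by this commit, pulled out
# into a standalone helper. Indices 5 and 26 are treated as opaque binary
# flags; penalty=10.0 mirrors the constant in the diff.
from flatland.envs.rail_env import RailEnvActions

def shape_rewards(all_rewards, action_dict, agent_obs, agent_handles, penalty=10.0):
    for agent in agent_handles:
        act = action_dict.get(agent, RailEnvActions.DO_NOTHING)
        # Penalize an agent that keeps moving while both flags are raised.
        if agent_obs[agent][5] == 1 and agent_obs[agent][26] == 1:
            if act != RailEnvActions.STOP_MOVING:
                all_rewards[agent] -= penalty
    return all_rewards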
@@ -508,11 +500,11 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=5400, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1,
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2,
                         type=int)
-    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=1,
+    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=2,
                         type=int)
-    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=10, type=int)
+    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=25, type=int)
     parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
     parser.add_argument("--eps_start", help="max exploration", default=1.0, type=float)
     parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
@@ -26,7 +26,7 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
 VERBOSE = True

 # Checkpoint to use (remember to push it!)
-checkpoint = "./checkpoints/201111175340-5400.pth"
+checkpoint = "./checkpoints/201112143850-4100.pth"  # 21.543589381053096 DEPTH=2

 # Use last action cache
 USE_ACTION_CACHE = False
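The run script now points at the 201112143850-4100.pth checkpoint; the trailing comment records its score and the observation tree depth. The script loads it through DDDQNPolicy, but a hedged sketch for peeking at such a .pth file with plain PyTorch (path taken from the diff above):

# Hedged sketch: inspect the referenced checkpoint with plain torch.
import torch

state = torch.load("./checkpoints/201112143850-4100.pth", map_location="cpu")
print(type(state))                 # typically an OrderedDict state_dict
if hasattr(state, "keys"):
    print(list(state.keys())[:5])  # first few parameter names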
@@ -137,14 +137,13 @@ while True:
                     nb_hit += 1
                 else:
                     action = policy.act(observation[agent], eps=0.01)
-                    #if observation[agent][26] == 1:
-                    #    action = RailEnvActions.STOP_MOVING
                 action_dict[agent] = action

                 if USE_ACTION_CACHE:
                     agent_last_obs[agent] = observation[agent]
                     agent_last_action[agent] = action
             policy.end_step()

             agent_time = time.time() - time_start
             time_taken_by_controller.append(agent_time)
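For context, this hunk (which only drops a commented-out STOP_MOVING override) sits inside the last-action cache: when USE_ACTION_CACHE is on and an agent's observation is unchanged since the previous step, the cached action is reused and nb_hit is incremented instead of calling policy.act again. A self-contained sketch of that pattern; the wrapper function itself is hypothetical, but policy.act(obs, eps=...) matches the call visible in the diff:

# Hedged sketch of the USE_ACTION_CACHE pattern from the run loop.
import numpy as np

agent_last_obs = {}
agent_last_action = {}
nb_hit = 0

def act_with_cache(policy, agent, obs, eps=0.01, use_cache=True):
    global nb_hit
    if (use_cache and agent in agent_last_obs
            and np.all(agent_last_obs[agent] == obs)):
        nb_hit += 1                      # cache hit: skip the network forward pass
        action = agent_last_action[agent]
    else:
        action = policy.act(obs, eps=eps)
    if use_cache:
        agent_last_obs[agent] = obs
        agent_last_action[agent] = action
    return action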