Commit 640962b0 authored by Egli Adrian (IT-SCI-API-PFI)

single agent learning in multi agent environment

parent 234a823c
Tags submission-v7.8
File deleted
File deleted
File added
File added
@@ -19,7 +19,6 @@ from flatland.utils.rendertools import RenderTool
from torch.utils.tensorboard import SummaryWriter
from reinforcement_learning.dddqn_policy import DDDQNPolicy
from reinforcement_learning.ppo.ppo_agent import PPOAgent
from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
base_dir = Path(__file__).resolve().parent.parent
@@ -200,9 +199,6 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
# TensorBoard writer
writer = SummaryWriter()
writer.add_hparams(vars(train_params), {})
writer.add_hparams(vars(train_env_params), {})
writer.add_hparams(vars(obs_params), {})
training_timer = Timer()
training_timer.start()
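
Note on the add_hparams lines above: SummaryWriter.add_hparams expects scalar-typed values (int, float, str, bool), so passing vars() of a Namespace that contains lists, dicts or None can fail, and each call also creates its own run entry in TensorBoard's HPARAMS view. If the environment and observation parameters should still be logged, a minimal sketch (the helper name and the type filtering are assumptions, not part of this commit):

from torch.utils.tensorboard import SummaryWriter

def add_scalar_hparams(writer: SummaryWriter, params, metrics=None):
    # Hypothetical helper: keep only the entries whose types add_hparams() accepts.
    scalar_params = {k: v for k, v in vars(params).items()
                     if isinstance(v, (int, float, str, bool))}
    writer.add_hparams(scalar_params, metrics or {})
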
@@ -287,6 +283,22 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
            # Environment step
            step_timer.start()
            next_obs, all_rewards, done, info = train_env.step(action_dict)

            for agent in train_env.get_agent_handles():
                act = action_dict.get(agent, RailEnvActions.DO_NOTHING)
                if agent_obs[agent][26] == 1:
                    if act == RailEnvActions.STOP_MOVING:
                        all_rewards[agent] *= 0.01
                else:
                    if act == RailEnvActions.MOVE_LEFT:
                        all_rewards[agent] *= 0.9
                    else:
                        if agent_obs[agent][7] == 0 and agent_obs[agent][8] == 0:
                            if act == RailEnvActions.MOVE_FORWARD:
                                all_rewards[agent] *= 0.01
                if done[agent]:
                    all_rewards[agent] += 100.0

            step_timer.end()
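
The new block above is a reward-shaping pass on top of Flatland's per-step rewards, keyed to the FastTreeObs features: slot 26 is the deadlock-avoidance "stop" hint, and slots 7/8 are the flags that full_action_required (further down in this commit) treats as requiring an explicit decision. Step rewards in Flatland are typically negative, so scaling by 0.01 makes the corresponding step far less costly. The same rules restated as a standalone helper, for readability only (the helper itself is not part of the commit):

from flatland.envs.rail_env import RailEnvActions

def shape_reward(obs, act, reward, is_done):
    # obs: FastTreeObs vector for one agent; reward: the env's step reward.
    if obs[26] == 1:
        # Deadlock-avoidance agent recommends stopping: reward compliance.
        if act == RailEnvActions.STOP_MOVING:
            reward *= 0.01
    elif act == RailEnvActions.MOVE_LEFT:
        reward *= 0.9
    elif obs[7] == 0 and obs[8] == 0 and act == RailEnvActions.MOVE_FORWARD:
        # obs[7]/obs[8] both clear (no explicit action required per full_action_required).
        reward *= 0.01
    if is_done:
        reward += 100.0  # bonus once the agent has reached its target
    return reward
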
# Render an episode at some interval
@@ -495,11 +507,11 @@ if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=5400, type=int)
parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1, type=int)
parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=1,
parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
type=int)
parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=25, type=int)
parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=1, type=int)
parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
parser.add_argument("--eps_start", help="max exploration", default=1.0, type=float)
parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float)
parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
parser.add_argument("--eps_decay", help="exploration decay", default=0.9998, type=float)
parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e7), type=int)
@@ -519,7 +531,7 @@ if __name__ == "__main__":
parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs",
action='store_true')
parser.add_argument("--max_depth", help="max depth", default=2, type=int)
parser.add_argument("--max_depth", help="max depth", default=1, type=int)
training_params = parser.parse_args()
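
Net effect of the new defaults: evaluation now runs on Test_0 with a single episode, exploration starts at 0.1 instead of 1.0, and the FastTreeObs depth drops to 1. A quick sanity-check sketch (not part of the commit; it simply parses with no command-line overrides):

defaults = parser.parse_args([])  # no CLI arguments -> the new defaults
assert defaults.evaluation_env_config == 0
assert defaults.n_evaluation_episodes == 1
assert defaults.eps_start == 0.1
assert defaults.max_depth == 1
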
env_params = [
...
@@ -10,7 +10,6 @@ from flatland.envs.rail_env import RailEnvActions
from flatland.evaluators.client import FlatlandRemoteClient
from flatland.evaluators.client import TimeoutException
from reinforcement_learning.ppo.ppo_agent import PPOAgent
from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
from utils.deadlock_check import check_if_all_blocked
from utils.fast_tree_obs import FastTreeObs
@@ -27,9 +26,8 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
VERBOSE = True
# Checkpoint to use (remember to push it!)
checkpoint = "./checkpoints/201105222046-5400.pth" # 17.66104361971127 Depth 1
checkpoint = "./checkpoints/201106073658-4400.pth" # 15.64082361736683 Depth 1
checkpoint = "./checkpoints/201106170544-5400.pth" # 15.64082361736683 Depth 1
checkpoint = "./checkpoints/201106234244-400.pth" # 15.64082361736683 Depth 1
checkpoint = "./checkpoints/201106234900-100.pth" # 15.64082361736683 Depth 1
# Use last action cache
USE_ACTION_CACHE = False
@@ -140,9 +138,8 @@ while True:
                    nb_hit += 1
                else:
                    action = policy.act(observation[agent], eps=0.01)
                    if observation[agent][26] == 1:
                        action = RailEnvActions.STOP_MOVING
                if observation[agent][26] == 1:
                    action = RailEnvActions.STOP_MOVING
                action_dict[agent] = action
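
This hunk reshuffles where the deadlock-avoidance override sits relative to the action cache; either way, the effect at submission time is that observation slot 26 forces STOP_MOVING over whatever the learned policy picked. A minimal standalone sketch of that selection rule (the policy.act(observation, eps) signature is taken from the hunk; the helper itself is assumed):

from flatland.envs.rail_env import RailEnvActions

def select_action(policy, obs_vec, eps=0.01):
    action = policy.act(obs_vec, eps=eps)
    if obs_vec[26] == 1:
        # FastTreeObs slot 26: the deadlock-avoidance agent says stop here.
        action = RailEnvActions.STOP_MOVING
    return action
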
...
@@ -23,7 +23,7 @@ class FastTreeObs(ObservationBuilder):
    def __init__(self, max_depth):
        self.max_depth = max_depth
        self.observation_dim = 32
        self.observation_dim = 27

    def build_data(self):
        if self.env is not None:
@@ -303,23 +303,6 @@
        action = self.dead_lock_avoidance_agent.act([handle], 0.0)
        observation[26] = int(action == RailEnvActions.STOP_MOVING)
        observation[27] = int(action == RailEnvActions.MOVE_LEFT)
        observation[28] = int(action == RailEnvActions.MOVE_FORWARD)
        observation[29] = int(action == RailEnvActions.MOVE_RIGHT)
        observation[30] = int(self.full_action_required(observation))
        observation[31] = int(fast_tree_obs_check_agent_deadlock(observation))

        self.env.dev_obs_dict.update({handle: visited})

        return observation

    def full_action_required(self, observation):
        return observation[7] == 1 or observation[8] == 1 or observation[4] == 1


def fast_tree_obs_check_agent_deadlock(observation):
    nbr_of_path = 0
    nbr_of_blocked_path = 0
    for dir_loop in range(4):
        nbr_of_path += observation[10 + dir_loop]
        nbr_of_blocked_path += int(observation[14 + dir_loop] > 0)
    return nbr_of_path <= nbr_of_blocked_path
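
With slots 27-31 removed, the observation shrinks from 32 to 27 features and slot 26 (the deadlock-avoidance "stop" hint) is the only such signal that remains. For reference, a commented restatement of the deleted heuristic, with slot meanings inferred from the code above (a path flag per direction in slots 10-13, a blocked marker per direction in slots 14-17):

def looks_deadlocked(obs):
    # Count directions that offer a path versus directions flagged as blocked;
    # if no unblocked path remains, treat the agent as potentially deadlocked.
    nbr_of_path = sum(obs[10 + d] for d in range(4))
    nbr_of_blocked_path = sum(int(obs[14 + d] > 0) for d in range(4))
    return nbr_of_path <= nbr_of_blocked_path
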