Commit 640962b0 authored by Egli Adrian (IT-SCI-API-PFI)

single agent learning in multi agent environment

parent 234a823c
File deleted
File deleted
File added
File added
@@ -19,7 +19,6 @@ from flatland.utils.rendertools import RenderTool
from torch.utils.tensorboard import SummaryWriter
from reinforcement_learning.dddqn_policy import DDDQNPolicy
from reinforcement_learning.ppo.ppo_agent import PPOAgent
from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
base_dir = Path(__file__).resolve().parent.parent
@@ -200,9 +199,6 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
# TensorBoard writer
writer = SummaryWriter()
writer.add_hparams(vars(train_params), {})
writer.add_hparams(vars(train_env_params), {})
writer.add_hparams(vars(obs_params), {})
training_timer = Timer()
training_timer.start()
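Note on the three add_hparams calls shown in this hunk: SummaryWriter.add_hparams only accepts int, float, str, bool or Tensor values, and each call opens its own hparams run. A minimal sketch of how the three namespaces could be logged in a single, type-safe call is shown below; log_hparams is a hypothetical helper, not part of this repository.

def log_hparams(writer, *namespaces):
    # Hypothetical helper: merge several argparse Namespaces and drop any
    # value whose type SummaryWriter.add_hparams cannot serialise.
    merged = {}
    for ns in namespaces:
        merged.update(vars(ns))
    clean = {k: v for k, v in merged.items() if isinstance(v, (int, float, str, bool))}
    writer.add_hparams(clean, {})

# log_hparams(writer, train_params, train_env_params, obs_params)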
@@ -287,6 +283,22 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
# Environment step
step_timer.start()
next_obs, all_rewards, done, info = train_env.step(action_dict)
for agent in train_env.get_agent_handles():
    act = action_dict.get(agent, RailEnvActions.DO_NOTHING)
    if agent_obs[agent][26] == 1:
        if act == RailEnvActions.STOP_MOVING:
            all_rewards[agent] *= 0.01
    else:
        if act == RailEnvActions.MOVE_LEFT:
            all_rewards[agent] *= 0.9
        else:
            if agent_obs[agent][7] == 0 and agent_obs[agent][8] == 0:
                if act == RailEnvActions.MOVE_FORWARD:
                    all_rewards[agent] *= 0.01
    if done[agent]:
        all_rewards[agent] += 100.0
step_timer.end()
# Render an episode at some interval
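The shaping loop above scales the per-step reward (which is typically negative in Flatland), so factors below 1 act as a bonus for the matching action, and the +100 is a terminal bonus. Observation index 26 mirrors the DeadLockAvoidanceAgent's STOP hint written by FastTreeObs (see the hunk further down); indices 7 and 8 are assumed to flag that a routing decision is required, as in the removed full_action_required helper. A hedged sketch of the same rules as a standalone helper, with those assumptions made explicit (shape_reward is not a name used in this commit):

from flatland.envs.rail_env import RailEnvActions

def shape_reward(obs, act, reward, is_done):
    # Sketch of the in-line shaping above; all index meanings are assumptions
    # taken from the FastTreeObs hunks in this commit.
    if obs[26] == 1:                       # deadlock-avoidance agent suggests stopping
        if act == RailEnvActions.STOP_MOVING:
            reward *= 0.01                 # shrink the negative step reward -> bonus
    elif act == RailEnvActions.MOVE_LEFT:
        reward *= 0.9
    elif obs[7] == 0 and obs[8] == 0 and act == RailEnvActions.MOVE_FORWARD:
        reward *= 0.01                     # bonus for moving forward on plain track
    if is_done:
        reward += 100.0                    # large bonus once the agent reaches its target
    return reward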
@@ -495,11 +507,11 @@ if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=5400, type=int)
parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1, type=int)
parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=1,
parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
type=int)
parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=25, type=int)
parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=1, type=int)
parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
parser.add_argument("--eps_start", help="max exploration", default=1.0, type=float)
parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float)
parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
parser.add_argument("--eps_decay", help="exploration decay", default=0.9998, type=float)
parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e7), type=int)
@@ -519,7 +531,7 @@ if __name__ == "__main__":
parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs",
action='store_true')
parser.add_argument("--max_depth", help="max depth", default=2, type=int)
parser.add_argument("--max_depth", help="max depth", default=1, type=int)
training_params = parser.parse_args()
env_params = [
......
@@ -10,7 +10,6 @@ from flatland.envs.rail_env import RailEnvActions
from flatland.evaluators.client import FlatlandRemoteClient
from flatland.evaluators.client import TimeoutException
from reinforcement_learning.ppo.ppo_agent import PPOAgent
from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
from utils.deadlock_check import check_if_all_blocked
from utils.fast_tree_obs import FastTreeObs
@@ -27,9 +26,8 @@ from reinforcement_learning.dddqn_policy import DDDQNPolicy
VERBOSE = True
# Checkpoint to use (remember to push it!)
checkpoint = "./checkpoints/201105222046-5400.pth" # 17.66104361971127 Depth 1
checkpoint = "./checkpoints/201106073658-4400.pth" # 15.64082361736683 Depth 1
checkpoint = "./checkpoints/201106170544-5400.pth" # 15.64082361736683 Depth 1
checkpoint = "./checkpoints/201106234244-400.pth" # 15.64082361736683 Depth 1
checkpoint = "./checkpoints/201106234900-100.pth" # 15.64082361736683 Depth 1
# Use last action cache
USE_ACTION_CACHE = False
@@ -140,9 +138,8 @@ while True:
nb_hit += 1
else:
action = policy.act(observation[agent], eps=0.01)
if observation[agent][26] == 1:
action = RailEnvActions.STOP_MOVING
if observation[agent][26] == 1:
action = RailEnvActions.STOP_MOVING
action_dict[agent] = action
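The override in this hunk forces STOP_MOVING whenever observation index 26 is set; per the FastTreeObs hunk below, that entry is derived from the DeadLockAvoidanceAgent's suggested action. A minimal sketch of the combined selection under that assumption (select_action is a hypothetical name, not part of the commit):

from flatland.envs.rail_env import RailEnvActions

def select_action(policy, obs, eps=0.01):
    # Let the trained policy choose, but always honour the deadlock-avoidance
    # STOP hint stored at observation index 26 (assumption from this commit).
    action = policy.act(obs, eps=eps)
    return RailEnvActions.STOP_MOVING if obs[26] == 1 else action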
......
@@ -23,7 +23,7 @@ class FastTreeObs(ObservationBuilder):
def __init__(self, max_depth):
self.max_depth = max_depth
self.observation_dim = 32
self.observation_dim = 27
def build_data(self):
if self.env is not None:
@@ -303,23 +303,6 @@ class FastTreeObs(ObservationBuilder):
action = self.dead_lock_avoidance_agent.act([handle], 0.0)
observation[26] = int(action == RailEnvActions.STOP_MOVING)
observation[27] = int(action == RailEnvActions.MOVE_LEFT)
observation[28] = int(action == RailEnvActions.MOVE_FORWARD)
observation[29] = int(action == RailEnvActions.MOVE_RIGHT)
observation[30] = int(self.full_action_required(observation))
observation[31] = int(fast_tree_obs_check_agent_deadlock(observation))
self.env.dev_obs_dict.update({handle: visited})
return observation
def full_action_required(self, observation):
    return observation[7] == 1 or observation[8] == 1 or observation[4] == 1

def fast_tree_obs_check_agent_deadlock(observation):
    nbr_of_path = 0
    nbr_of_blocked_path = 0
    for dir_loop in range(4):
        nbr_of_path += observation[10 + dir_loop]
        nbr_of_blocked_path += int(observation[14 + dir_loop] > 0)
    return nbr_of_path <= nbr_of_blocked_path
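With observation_dim reduced from 32 to 27, index 26 (the STOP hint) becomes the last slot, and the derived flags at 27-31, including the deadlock check removed above, disappear. For reference, a sketch of that removed check with the assumed index meanings spelled out; this is an annotation, not repository code:

def is_deadlocked(observation):
    # Indices 10-13 are assumed to count usable paths per direction and
    # 14-17 to hold per-direction blocking indicators (> 0 means blocked).
    nbr_of_path = sum(observation[10 + d] for d in range(4))
    nbr_of_blocked = sum(int(observation[14 + d] > 0) for d in range(4))
    return nbr_of_path <= nbr_of_blocked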