Commit 388822a0 authored by Egli Adrian (IT-SCI-API-PFI)

Policy updated

parent 0a273a94
@@ -10,6 +10,8 @@ from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 class DeadLockAvoidanceWithDecisionAgent(HybridPolicy):
     def __init__(self, env: RailEnv, state_size, action_size, learning_agent):
         print(">> DeadLockAvoidanceWithDecisionAgent")
+        super(DeadLockAvoidanceWithDecisionAgent, self).__init__()
         self.env = env
         self.state_size = state_size
         self.action_size = action_size
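
The added super(...) call chains to the HybridPolicy initializer. A minimal sketch of why that matters, using stand-in classes rather than the real flatland ones: any state the base __init__ sets up simply never exists if the subclass skips the call.

# Stand-in classes, not the real HybridPolicy hierarchy.
class BasePolicy:
    def __init__(self):
        self.memory = []          # state the base initializer provides

class WithoutSuper(BasePolicy):
    def __init__(self):
        self.name = "broken"      # BasePolicy.__init__ never runs

class WithSuper(BasePolicy):
    def __init__(self):
        super().__init__()        # chain to the base initializer
        self.name = "ok"

print(hasattr(WithoutSuper(), "memory"))  # False
print(hasattr(WithSuper(), "memory"))     # True
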
@@ -33,7 +35,7 @@ class DeadLockAvoidanceWithDecisionAgent(HybridPolicy):
         if agent.status < RailAgentStatus.DONE:
             agents_on_switch, agents_near_to_switch, _, _ = \
                 self.agent_can_choose_helper.check_agent_decision(position, direction)
-            if agents_on_switch:
+            if agents_on_switch or agents_near_to_switch:
                 return self.learning_agent.act(handle, state, eps)
             else:
                 act = self.dead_lock_avoidance_agent.act(handle, state, -1.0)
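
The hunk above widens the routing rule: after this commit the learned policy decides not only for agents standing on a switch but also for agents close to one; everywhere else the heuristic dead-lock avoidance agent acts. A self-contained sketch of that rule with stub policies (stand-ins, not the flatland classes):

# Stubs standing in for the learning and heuristic policies.
class LearningStub:
    def act(self, handle, state, eps):
        return 2                  # e.g. a MOVE_FORWARD picked by the network

class DeadLockAvoidanceStub:
    def act(self, handle, state, eps):
        return 4                  # e.g. a STOP_MOVING picked by the heuristic

def route_action(on_switch, near_switch, handle, state, eps,
                 learning_agent, heuristic_agent):
    # New rule: the learned policy also covers agents *near* a switch.
    if on_switch or near_switch:
        return learning_agent.act(handle, state, eps)
    return heuristic_agent.act(handle, state, -1.0)

print(route_action(False, True, 0, None, 0.1,
                   LearningStub(), DeadLockAvoidanceStub()))  # -> 2
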
@@ -177,12 +177,16 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
     if False:
         policy = DeadLockAvoidanceAgent(train_env, get_action_size())
+    if False:
+        inter_policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
+        policy = DeadLockAvoidanceWithDecisionAgent(train_env, state_size, get_action_size(), inter_policy)
+    if True:
+        inter_policy = DDDQNPolicy(state_size, get_action_size(), train_params)
+        policy = DeadLockAvoidanceWithDecisionAgent(train_env, state_size, get_action_size(), inter_policy)
     if False:
         policy = MultiDecisionAgent(state_size, get_action_size(), train_params)
     # make sure that at least one policy is set
     if policy is None:
         policy = DDDQNPolicy(state_size, get_action_size(), train_params)
     # Load existing policy
     if train_params.load_policy != "":
         policy.load(train_params.load_policy)
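
The if True:/if False: chain above is a hand-edited toggle for experiments: exactly one branch is flipped to True, and the `policy is None` guard supplies a DDDQN fallback. A hedged sketch of the same selection done table-driven (hypothetical names and stand-in constructors; the real code keeps the booleans):

# Sketch only: the factory values are stand-in lambdas, not real constructors.
def select_policy(policy_name, factories, fallback):
    # Mirror of the "make sure that at least one policy is set" guard:
    # unknown names fall back to the default constructor.
    return factories.get(policy_name, fallback)()

factories = {
    "ppo": lambda: "PPOPolicy(...)",
    "dead_lock_avoidance_with_decision":
        lambda: "DeadLockAvoidanceWithDecisionAgent(...)",
    "multi_decision": lambda: "MultiDecisionAgent(...)",
}
print(select_policy("unknown", factories, lambda: "DDDQNPolicy(...)"))
# -> DDDQNPolicy(...)
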
@@ -6,7 +6,7 @@ from flatland.core.env_observation_builder import DummyObservationBuilder
 from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.rail_env import RailEnv, RailEnvActions, fast_count_nonzero
-from reinforcement_learning.policy import HeuristicPolicy
+from reinforcement_learning.policy import HeuristicPolicy, DummyMemory
 from utils.agent_action_config import map_rail_env_action
 from utils.shortest_distance_walker import ShortestDistanceWalker
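
DummyMemory is pulled in alongside HeuristicPolicy; its definition is not part of this diff. A guess at its shape, hedged accordingly: a no-op buffer so heuristic policies expose the same memory interface as learning policies.

# Assumption: DummyMemory is roughly a no-op replay buffer like this stub.
class DummyMemoryStub:
    def __init__(self):
        self.memory = []

    def add(self, *args, **kwargs):
        pass                      # discard transitions instead of storing

    def __len__(self):
        return 0
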
@@ -66,10 +66,6 @@ class DeadlockAvoidanceShortestDistanceWalker(ShortestDistanceWalker):
         self.shortest_distance_agent_map[(handle, position[0], position[1])] = 1
         self.full_shortest_distance_agent_map[(handle, position[0], position[1])] = 1

 class DeadLockAvoidanceAgent(HeuristicPolicy):
     def __init__(self, env: RailEnv, action_size, show_debug_plot=False):
         self.env = env
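
The two maps above rasterize each agent's shortest path into per-agent occupancy grids. This hunk does not show how DeadLockAvoidanceAgent consumes them, but the general idea such maps enable can be sketched (assumed mechanics, not the file's actual logic): cells claimed by two agents heading toward each other are deadlock candidates, so one of the agents should be stopped.

import numpy as np

# Toy grid: agent A walks east along row 1, agent B walks west.
height, width = 4, 6
path_a = [(1, c) for c in range(0, 4)]
path_b = [(1, c) for c in range(5, 1, -1)]

map_a = np.zeros((height, width), dtype=np.int8)
map_b = np.zeros((height, width), dtype=np.int8)
for r, c in path_a:
    map_a[r, c] = 1
for r, c in path_b:
    map_b[r, c] = 1

overlap = map_a & map_b           # cells both opposing paths need
print("deadlock risk:", bool(overlap.any()))  # -> True on this toy grid
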