# Forked from adrian_egli / neurips2020-flatland-starter-kit
# (37 commits behind the upstream repository)
# Author: Egli Adrian (IT-SCI-API-PFI)
# deadlockavoidance_with_decision_agent.py
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import RailEnv, RailEnvActions
from reinforcement_learning.policy import HybridPolicy
from utils.agent_action_config import map_rail_env_action
from utils.agent_can_choose_helper import AgentCanChooseHelper
from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
class DeadLockAvoidanceWithDecisionAgent(HybridPolicy):
    """Hybrid policy that routes each agent to one of two sub-policies.

    On or near a switch cell the agent faces a genuine routing decision, so
    the learning policy chooses the action; everywhere else the deadlock
    avoidance heuristic drives the train. Training callbacks (step, start/end
    of step and episode, save/load) are forwarded to both sub-policies.
    """

    def __init__(self, env: RailEnv, state_size, action_size, learning_agent):
        print(">> DeadLockAvoidanceWithDecisionAgent")
        super().__init__()
        self.env = env
        self.state_size = state_size
        self.action_size = action_size
        self.learning_agent = learning_agent
        self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, action_size, False)
        self.agent_can_choose_helper = AgentCanChooseHelper()
        # Mirror the learner's replay memory and loss so external training
        # code can monitor this hybrid policy like a plain learning policy.
        self.memory = self.learning_agent.memory
        self.loss = self.learning_agent.loss

    def step(self, handle, state, action, reward, next_state, done):
        """Feed the experience tuple to both sub-policies and refresh the loss."""
        for policy in (self.dead_lock_avoidance_agent, self.learning_agent):
            policy.step(handle, state, action, reward, next_state, done)
        self.loss = self.learning_agent.loss

    def act(self, handle, state, eps=0.):
        """Choose an action for agent *handle*.

        Learner decides at decision points (on/near a switch); the deadlock
        avoidance agent decides elsewhere; finished agents do nothing.
        """
        agent = self.env.agents[handle]
        # Equivalent to the original `status < DONE` test, inverted as a guard.
        if not (agent.status < RailAgentStatus.DONE):
            # Agent has reached (or been removed from) its target.
            return map_rail_env_action(RailEnvActions.DO_NOTHING)
        cell = agent.initial_position if agent.position is None else agent.position
        on_switch, near_switch, _, _ = \
            self.agent_can_choose_helper.check_agent_decision(cell, agent.direction)
        if on_switch or near_switch:
            return self.learning_agent.act(handle, state, eps)
        # -1.0 eps signals the heuristic agent to act greedily.
        heuristic_action = self.dead_lock_avoidance_agent.act(handle, state, -1.0)
        return map_rail_env_action(heuristic_action)

    def save(self, filename):
        """Persist both sub-policies under the same filename."""
        for policy in (self.dead_lock_avoidance_agent, self.learning_agent):
            policy.save(filename)

    def load(self, filename):
        """Restore both sub-policies from the same filename."""
        for policy in (self.dead_lock_avoidance_agent, self.learning_agent):
            policy.load(filename)

    def start_step(self, train):
        for policy in (self.dead_lock_avoidance_agent, self.learning_agent):
            policy.start_step(train)

    def end_step(self, train):
        for policy in (self.dead_lock_avoidance_agent, self.learning_agent):
            policy.end_step(train)

    def start_episode(self, train):
        for policy in (self.dead_lock_avoidance_agent, self.learning_agent):
            policy.start_episode(train)

    def end_episode(self, train):
        for policy in (self.dead_lock_avoidance_agent, self.learning_agent):
            policy.end_episode(train)

    def load_replay_buffer(self, filename):
        for policy in (self.dead_lock_avoidance_agent, self.learning_agent):
            policy.load_replay_buffer(filename)

    def test(self):
        """Switch both sub-policies into evaluation mode."""
        for policy in (self.dead_lock_avoidance_agent, self.learning_agent):
            policy.test()

    def reset(self, env: RailEnv):
        """Bind a (possibly new) environment and reset both sub-policies."""
        self.env = env
        self.agent_can_choose_helper.build_data(env)
        self.dead_lock_avoidance_agent.reset(env)
        self.learning_agent.reset(env)

    def clone(self):
        # Intentionally returns the same instance, not a copy.
        return self