# deadlockavoidance_with_decision_agent.py
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import RailEnv, RailEnvActions

from reinforcement_learning.policy import HybridPolicy
from utils.agent_action_config import map_rail_env_action
from utils.agent_can_choose_helper import AgentCanChooseHelper
from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent


class DeadLockAvoidanceWithDecisionAgent(HybridPolicy):
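    """
    Hybrid policy combining a heuristic DeadLockAvoidanceAgent with a trained
    learning agent. The learned policy decides only where the agent has a
    real choice (on or next to a switch); everywhere else the
    deadlock-avoidance heuristic drives.
    """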

    def __init__(self, env: RailEnv, state_size, action_size, learning_agent):
        print(">> DeadLockAvoidanceWithDecisionAgent")
        super().__init__()
        self.env = env
        self.state_size = state_size
        self.action_size = action_size
        self.learning_agent = learning_agent
        self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, action_size, False)
        self.agent_can_choose_helper = AgentCanChooseHelper()
        # Mirror the learning agent's replay memory and loss so callers can
        # treat this hybrid policy like the underlying learning policy.
        self.memory = self.learning_agent.memory
        self.loss = self.learning_agent.loss

    def step(self, handle, state, action, reward, next_state, done):
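        """Feed the experience tuple to both sub-policies and mirror the loss."""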
        self.dead_lock_avoidance_agent.step(handle, state, action, reward, next_state, done)
        self.learning_agent.step(handle, state, action, reward, next_state, done)
        self.loss = self.learning_agent.loss

    def act(self, handle, state, eps=0.):
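        """
        Choose an action for agent `handle`.

        The learned policy decides only at decision points (on or next to a
        switch); everywhere else the deadlock-avoidance heuristic acts.
        Finished agents always receive DO_NOTHING.
        """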
        agent = self.env.agents[handle]
        # Agents that have not departed yet have no grid position; fall back
        # to the initial position.
        position = agent.position
        if position is None:
            position = agent.initial_position
        direction = agent.direction
        if agent.status < RailAgentStatus.DONE:
            agents_on_switch, agents_near_to_switch, _, _ = \
                self.agent_can_choose_helper.check_agent_decision(position, direction)
            if agents_on_switch or agents_near_to_switch:
                # On or next to a switch the agent has a real routing choice:
                # defer to the learned policy.
                return self.learning_agent.act(handle, state, eps)
            else:
                # No routing choice here: follow the deadlock-avoidance heuristic.
                act = self.dead_lock_avoidance_agent.act(handle, state, -1.0)
                return map_rail_env_action(act)
        # Agent has reached its target (DONE / DONE_REMOVED): nothing left to do.
        return map_rail_env_action(RailEnvActions.DO_NOTHING)

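    # The persistence and lifecycle hooks below simply fan out to both
    # sub-policies.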
    def save(self, filename):
        self.dead_lock_avoidance_agent.save(filename)
        self.learning_agent.save(filename)

    def load(self, filename):
        self.dead_lock_avoidance_agent.load(filename)
        self.learning_agent.load(filename)

    def start_step(self, train):
        self.dead_lock_avoidance_agent.start_step(train)
        self.learning_agent.start_step(train)

    def end_step(self, train):
        self.dead_lock_avoidance_agent.end_step(train)
        self.learning_agent.end_step(train)

    def start_episode(self, train):
        self.dead_lock_avoidance_agent.start_episode(train)
        self.learning_agent.start_episode(train)

    def end_episode(self, train):
        self.dead_lock_avoidance_agent.end_episode(train)
        self.learning_agent.end_episode(train)

    def load_replay_buffer(self, filename):
        self.dead_lock_avoidance_agent.load_replay_buffer(filename)
        self.learning_agent.load_replay_buffer(filename)

    def test(self):
        self.dead_lock_avoidance_agent.test()
        self.learning_agent.test()

    def reset(self, env: RailEnv):
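        """Bind the new environment and rebuild the switch/decision-point data."""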
        self.env = env
        self.agent_can_choose_helper.build_data(env)
        self.dead_lock_avoidance_agent.reset(env)
        self.learning_agent.reset(env)

    def clone(self):
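        # Return the same instance (shared sub-policies and environment
        # reference) rather than a deep copy.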
        return self
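

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only; names below are assumptions).
# DDDQNPolicy is one of the learning policies in this repository's
# reinforcement_learning package; the environment setup, state/action sizes
# and training parameters are placeholders, not values from this file.
#
#   from flatland.envs.rail_env import RailEnv
#   from reinforcement_learning.dddqn_policy import DDDQNPolicy
#
#   env = RailEnv(width=30, height=30)  # plus rail/line generators as needed
#   learning_policy = DDDQNPolicy(state_size, action_size, train_params)
#   policy = DeadLockAvoidanceWithDecisionAgent(env, state_size, action_size,
#                                               learning_policy)
#   policy.reset(env)
#   action = policy.act(handle=0, state=obs[0], eps=0.1)
# ---------------------------------------------------------------------------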