Skip to content
Snippets Groups Projects
multi_policy.py 2.50 KiB
import numpy as np
from flatland.envs.rail_env import RailEnvActions

from reinforcement_learning.policy import Policy
from reinforcement_learning.ppo.ppo_agent import PPOAgent
from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
from utils.extra import ExtraPolicy


class MultiPolicy(Policy):
    """Policy that feeds an auxiliary policy's suggestion into a PPO agent.

    For every observation, ``ExtraPolicy`` is queried greedily (``eps=0.0``)
    and its chosen action is appended to the observation as a one-hot vector
    of length ``action_size``. The resulting extended observation
    (``state_size + action_size`` entries) is what the ``PPOAgent`` receives,
    both when acting and when learning. Both sub-policies are trained in
    ``step`` and are loaded/saved/reset together.
    """

    def __init__(self, state_size, action_size, n_agents, env):
        """Create the extra policy and the PPO agent.

        Args:
            state_size: Size of a raw observation (presumably a flat vector;
                passed unchanged to ``ExtraPolicy``).
            action_size: Number of discrete actions; also the length of the
                one-hot suggestion appended to every observation.
            n_agents: Forwarded to ``PPOAgent``.
            env: Forwarded to ``PPOAgent``.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.memory = []  # NOTE(review): not used anywhere in this class
        self.loss = 0  # copy of ppo_policy.loss, refreshed in act()
        self.extra_policy = ExtraPolicy(state_size, action_size)
        # The PPO input is the raw observation plus the one-hot suggestion.
        self.ppo_policy = PPOAgent(state_size + action_size, action_size, n_agents, env)

    def _extend_state(self, handle, state):
        """Return ``state`` with a one-hot of the extra policy's greedy action appended.

        Like ``np.append`` with no axis, the observation is flattened. The
        one-hot is an integer vector, so the result's dtype is the promotion of
        the observation's dtype with int (e.g. float32 -> float64).
        """
        suggested_action = self.extra_policy.act(handle, state, 0.0)
        # A single vectorised one-hot replaces action_size separate np.append
        # calls, each of which reallocated the whole array.
        one_hot = (np.arange(self.action_size) == suggested_action).astype(int)
        return np.append(state, one_hot)

    def load(self, filename):
        """Load both sub-policies from ``filename``.

        NOTE(review): the same filename is passed to both sub-policies; confirm
        they derive distinct file names from it.
        """
        self.ppo_policy.load(filename)
        self.extra_policy.load(filename)

    def save(self, filename):
        """Save both sub-policies under ``filename``.

        NOTE(review): the same filename is passed to both sub-policies; confirm
        they do not overwrite each other.
        """
        self.ppo_policy.save(filename)
        self.extra_policy.save(filename)

    def step(self, handle, state, action, reward, next_state, done):
        """Train both sub-policies on one transition.

        The extra policy learns from the raw observations. The PPO agent learns
        from the observations extended with the extra policy's suggestions, so
        its training inputs match what ``act`` feeds it.
        """
        # Suggestions are computed before the extra policy is updated with
        # this transition (state first, then next_state).
        extended_state = self._extend_state(handle, state)
        extended_next_state = self._extend_state(handle, next_state)

        self.extra_policy.step(handle, state, action, reward, next_state, done)
        self.ppo_policy.step(handle, extended_state, action, reward, extended_next_state, done)

    def act(self, handle, state, eps=0.):
        """Return the PPO agent's action for ``state``.

        Args:
            handle: Agent identifier, forwarded to both sub-policies.
            state: Raw observation.
            eps: Exploration parameter, forwarded to the PPO agent only. The
                extra policy is always queried with ``0.0``.
        """
        extended_state = self._extend_state(handle, state)
        action_ppo = self.ppo_policy.act(handle, extended_state, eps)
        # NOTE(review): loss is only mirrored here, not after step()/end_step();
        # confirm this is the intended refresh point.
        self.loss = self.ppo_policy.loss
        return action_ppo

    def reset(self):
        """Reset both sub-policies (PPO first, then the extra policy)."""
        self.ppo_policy.reset()
        self.extra_policy.reset()

    def test(self):
        """Call ``test()`` on both sub-policies (PPO first, then the extra policy)."""
        self.ppo_policy.test()
        self.extra_policy.test()

    def start_step(self):
        """Forward ``start_step`` to both sub-policies (extra policy first)."""
        self.extra_policy.start_step()
        self.ppo_policy.start_step()

    def end_step(self):
        """Forward ``end_step`` to both sub-policies (extra policy first)."""
        self.extra_policy.end_step()
        self.ppo_policy.end_step()