Forked from adrian_egli / neurips2020-flatland-starter-kit
multi_policy.py
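The module below combines two sub-policies: a hand-crafted ExtraPolicy and a learned PPOAgent. For every agent handle, the ExtraPolicy is queried greedily, its suggested action is one-hot encoded and appended to the raw observation, and the PPO agent acts on (and is trained on) this extended state. Loading, saving, resetting and the per-step hooks are simply forwarded to both sub-policies.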
import numpy as np

from reinforcement_learning.policy import Policy
from reinforcement_learning.ppo.ppo_agent import PPOAgent
from utils.extra import ExtraPolicy

class MultiPolicy(Policy):
    """Hybrid policy that feeds the ExtraPolicy's one-hot encoded action into a PPO agent."""

    def __init__(self, state_size, action_size, n_agents, env):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = []
        self.loss = 0
        self.extra_policy = ExtraPolicy(state_size, action_size)
        # The PPO agent observes the raw state plus the one-hot encoded ExtraPolicy action.
        self.ppo_policy = PPOAgent(state_size + action_size, action_size, n_agents, env)

    def load(self, filename):
        self.ppo_policy.load(filename)
        self.extra_policy.load(filename)

    def save(self, filename):
        self.ppo_policy.save(filename)
        self.extra_policy.save(filename)

    def step(self, handle, state, action, reward, next_state, done):
        # Query the ExtraPolicy greedily (eps=0.0) for both states and append
        # its suggested action as a one-hot vector to the raw observations.
        action_extra_state = self.extra_policy.act(handle, state, 0.0)
        action_extra_next_state = self.extra_policy.act(handle, next_state, 0.0)
        extended_state = np.copy(state)
        for action_itr in np.arange(self.action_size):
            extended_state = np.append(extended_state, [int(action_extra_state == action_itr)])
        extended_next_state = np.copy(next_state)
        for action_itr in np.arange(self.action_size):
            extended_next_state = np.append(extended_next_state, [int(action_extra_next_state == action_itr)])
        # Let both sub-policies learn from the transition: the ExtraPolicy gets
        # the raw states, the PPO agent the extended ones.
        self.extra_policy.step(handle, state, action, reward, next_state, done)
        self.ppo_policy.step(handle, extended_state, action, reward, extended_next_state, done)

    def act(self, handle, state, eps=0.):
        # Extend the observation with the ExtraPolicy's one-hot action, then
        # let the PPO agent choose the action that is actually executed.
        action_extra_state = self.extra_policy.act(handle, state, 0.0)
        extended_state = np.copy(state)
        for action_itr in np.arange(self.action_size):
            extended_state = np.append(extended_state, [int(action_extra_state == action_itr)])
        action_ppo = self.ppo_policy.act(handle, extended_state, eps)
        self.loss = self.ppo_policy.loss
        return action_ppo

    def reset(self):
        self.ppo_policy.reset()
        self.extra_policy.reset()

    def test(self):
        self.ppo_policy.test()
        self.extra_policy.test()

    def start_step(self):
        self.extra_policy.start_step()
        self.ppo_policy.start_step()

    def end_step(self):
        self.extra_policy.end_step()
        self.ppo_policy.end_step()
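

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the upstream file): one way
# MultiPolicy could drive a single Flatland episode. It assumes `env` is a
# flatland RailEnv whose reset()/step() return per-handle observations that
# have already been flattened into vectors of length `state_size`; the helper
# name and the `max_steps` / `eps` defaults are illustrative assumptions only.
def _example_episode(env, policy, max_steps=200, eps=0.1):
    obs, info = env.reset()
    policy.reset()
    score = 0.0
    for _ in range(max_steps):
        policy.start_step()
        # Ask the combined policy for an action per agent handle.
        action_dict = {handle: policy.act(handle, obs[handle], eps)
                       for handle in range(env.get_num_agents())}
        next_obs, rewards, dones, info = env.step(action_dict)
        policy.end_step()
        # Feed every transition back so both sub-policies can learn from it.
        for handle in range(env.get_num_agents()):
            policy.step(handle, obs[handle], action_dict[handle],
                        rewards[handle], next_obs[handle], dones[handle])
            score += rewards[handle]
        obs = next_obs
        if dones['__all__']:
            break
    return score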