Skip to content
Snippets Groups Projects
Commit 007611b3 authored by Egli Adrian (IT-SCI-API-PFI)
Browse files

17.01 / DDDQN

parent 97104dee
No related branches found
No related tags found
No related merge requests found
File added
File added
......@@ -68,10 +68,10 @@ EPSILON = 0.0
# Checkpoint to use (remember to push it!)
set_action_size_reduced()
load_policy = "DDDQN"
checkpoint = "./checkpoints/210119171409-10000.pth" # 12.18162927750207
checkpoint = "./checkpoints/210122120236-3000.pth" # 17.011131341978228
EPSILON = 0.0
load_policy = "DeadLockAvoidance"
# load_policy = "DeadLockAvoidance" # 22.13346834815911
# Use last action cache
USE_ACTION_CACHE = False
......
......@@ -4,9 +4,11 @@ import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import fast_count_nonzero, fast_argmax
from flatland.envs.rail_env import fast_count_nonzero, fast_argmax, RailEnvActions
from utils.agent_action_config import get_flatland_full_action_size
from utils.agent_can_choose_helper import AgentCanChooseHelper
from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
from utils.deadlock_check import get_agent_positions, get_agent_targets
"""
......@@ -26,8 +28,9 @@ class FastTreeObs(ObservationBuilder):
def __init__(self, max_depth: Any):
self.max_depth = max_depth
self.observation_dim = 30
self.observation_dim = 35
self.agent_can_choose_helper = None
self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(None, get_flatland_full_action_size())
def debug_render(self, env_renderer):
agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \
......@@ -152,9 +155,12 @@ class FastTreeObs(ObservationBuilder):
return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
def get_many(self, handles: Optional[List[int]] = None):
self.dead_lock_avoidance_agent.reset(self.env)
self.dead_lock_avoidance_agent.start_step(False)
self.agent_positions = get_agent_positions(self.env)
self.agents_target = get_agent_targets(self.env)
observations = super().get_many(handles)
self.dead_lock_avoidance_agent.end_step(False)
return observations
def get(self, handle: int = 0):
......@@ -246,6 +252,13 @@ class FastTreeObs(ObservationBuilder):
observation[9] = int(agents_near_to_switch)
observation[10] = int(agents_near_to_switch_all)
action = self.dead_lock_avoidance_agent.act(handle, None, eps=0)
observation[30] = action == RailEnvActions.DO_NOTHING
observation[31] = action == RailEnvActions.MOVE_LEFT
observation[32] = action == RailEnvActions.MOVE_FORWARD
observation[33] = action == RailEnvActions.MOVE_RIGHT
observation[34] = action == RailEnvActions.STOP_MOVING
self.env.dev_obs_dict.update({handle: visited})
observation[np.isinf(observation)] = -1
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment