diff --git a/checkpoints/210122120236-3000.pth.local b/checkpoints/210122120236-3000.pth.local new file mode 100644 index 0000000000000000000000000000000000000000..fc041a1fb19c59269b14012b7f4a99cdf059f19f Binary files /dev/null and b/checkpoints/210122120236-3000.pth.local differ diff --git a/checkpoints/210122120236-3000.pth.target b/checkpoints/210122120236-3000.pth.target new file mode 100644 index 0000000000000000000000000000000000000000..c68bab3e29942319e053d9a4022fcf93042fa6e9 Binary files /dev/null and b/checkpoints/210122120236-3000.pth.target differ diff --git a/run.py b/run.py index 57f8fdc3d8ef11c8490f0b3e06d572f0e2da1744..6e2a3d88c4945b6b3a9c7159e7ba7f965ffc20cc 100644 --- a/run.py +++ b/run.py @@ -68,10 +68,10 @@ EPSILON = 0.0 # Checkpoint to use (remember to push it!) set_action_size_reduced() load_policy = "DDDQN" -checkpoint = "./checkpoints/210119171409-10000.pth" # 12.18162927750207 +checkpoint = "./checkpoints/210122120236-3000.pth" # 17.011131341978228 EPSILON = 0.0 -load_policy = "DeadLockAvoidance" +# load_policy = "DeadLockAvoidance" # 22.13346834815911 # Use last action cache USE_ACTION_CACHE = False diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py index f0b6277a808ad029e1cd8e50b3ce4097ee93beb0..4d4ce4b0396f790b2e71879801f2821792116052 100755 --- a/utils/fast_tree_obs.py +++ b/utils/fast_tree_obs.py @@ -4,9 +4,11 @@ import numpy as np from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid4_utils import get_new_position from flatland.envs.agent_utils import RailAgentStatus -from flatland.envs.rail_env import fast_count_nonzero, fast_argmax +from flatland.envs.rail_env import fast_count_nonzero, fast_argmax, RailEnvActions +from utils.agent_action_config import get_flatland_full_action_size from utils.agent_can_choose_helper import AgentCanChooseHelper +from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent from utils.deadlock_check import get_agent_positions, get_agent_targets """ @@ -26,8 +28,9 @@ class FastTreeObs(ObservationBuilder): def __init__(self, max_depth: Any): self.max_depth = max_depth - self.observation_dim = 30 + self.observation_dim = 35 self.agent_can_choose_helper = None + self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(None, get_flatland_full_action_size()) def debug_render(self, env_renderer): agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \ @@ -152,9 +155,12 @@ class FastTreeObs(ObservationBuilder): return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist def get_many(self, handles: Optional[List[int]] = None): + self.dead_lock_avoidance_agent.reset(self.env) + self.dead_lock_avoidance_agent.start_step(False) self.agent_positions = get_agent_positions(self.env) self.agents_target = get_agent_targets(self.env) observations = super().get_many(handles) + self.dead_lock_avoidance_agent.end_step(False) return observations def get(self, handle: int = 0): @@ -246,6 +252,13 @@ class FastTreeObs(ObservationBuilder): observation[9] = int(agents_near_to_switch) observation[10] = int(agents_near_to_switch_all) + action = self.dead_lock_avoidance_agent.act(handle, None, eps=0) + observation[30] = action == RailEnvActions.DO_NOTHING + observation[31] = action == RailEnvActions.MOVE_LEFT + observation[32] = action == RailEnvActions.MOVE_FORWARD + observation[33] = action == RailEnvActions.MOVE_RIGHT + observation[34] = action == RailEnvActions.STOP_MOVING + self.env.dev_obs_dict.update({handle: visited}) observation[np.isinf(observation)] = -1