diff --git a/src/extra.py b/src/extra.py index a6e237102cc64dd475e0b25a01a518cfca8149dd..044c32f757a397ec20387ce8f0f707418848c37e 100644 --- a/src/extra.py +++ b/src/extra.py @@ -3,7 +3,6 @@ from flatland.core.grid.grid4_utils import get_new_position from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnvActions -from flatland.utils.rendertools import RenderTool, AgentRenderVariant from src.agent.dueling_double_dqn import Agent from src.observations import normalize_observation @@ -150,6 +149,7 @@ class Extra: position = self.env.agents[a].initial_position first_step = True direction = self.env.agents[a].direction + cnt = 0 while position is not None: # and position != self.env.agents[a].target: possible_transitions = self.env.rail.get_transitions(*position, direction) # num_transitions = np.count_nonzero(possible_transitions) @@ -200,6 +200,10 @@ class Extra: else: position = None + cnt += 1 + if cnt > 100: + position = None + return agents_with_deadlock def generate_state(self, handle: int, root, max_depth: int): diff --git a/src/observations.py b/src/observations.py index cc89198440dd932f1ee474deff63a2450242106e..fd9659cd41d72e8c92d2a96ac40a961524fb27ce 100644 --- a/src/observations.py +++ b/src/observations.py @@ -264,6 +264,7 @@ class MyTreeObsForRailEnv(ObservationBuilder): last_is_a_decision_cell = False target_encountered = 0 + cnt = 0 while exploring: dist_min_to_target = min(dist_min_to_target, self.env.distance_map.get()[handle, position[0], position[1], @@ -327,6 +328,10 @@ class MyTreeObsForRailEnv(ObservationBuilder): direction = np.argmax(cell_transitions) position = get_new_position(position, direction) + cnt += 1 + if cnt > 1000: + exploring = False + # ############################# # ############################# # Modify here to append new / different features for each visited cell!