Skip to content
Snippets Groups Projects
shortest_Distance_walker.py 3.93 KiB
import numpy as np
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_env import fast_count_nonzero, fast_argmax


class ShortestDistanceWalker:
    def __init__(self, env: RailEnv):
        self.env = env

    def walk(self, handle, position, direction):
        possible_transitions = self.env.rail.get_transitions(*position, direction)
        num_transitions = fast_count_nonzero(possible_transitions)
        if num_transitions == 1:
            new_direction = fast_argmax(possible_transitions)
            new_position = get_new_position(position, new_direction)

            dist = self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction]
            return new_position, new_direction, dist, RailEnvActions.MOVE_FORWARD, possible_transitions
        else:
            min_distances = []
            positions = []
            directions = []
            for new_direction in [(direction + i) % 4 for i in range(-1, 2)]:
                if possible_transitions[new_direction]:
                    new_position = get_new_position(position, new_direction)
                    min_distances.append(
                        self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction])
                    positions.append(new_position)
                    directions.append(new_direction)
                else:
                    min_distances.append(np.inf)
                    positions.append(None)
                    directions.append(None)

        a = self.get_action(handle, min_distances)
        return positions[a], directions[a], min_distances[a], a + 1, possible_transitions

    def get_action(self, handle, min_distances):
        return np.argmin(min_distances)

    def callback(self, handle, agent, position, direction, action, possible_transitions):
        pass

    def get_agent_position_and_direction(self, handle):
        agent = self.env.agents[handle]
        if agent.position is not None:
            position = agent.position
        else:
            position = agent.initial_position
        direction = agent.direction
        return position, direction

    def walk_to_target(self, handle, position=None, direction=None, max_step=500):
        if position is None and direction is None:
            position, direction = self.get_agent_position_and_direction(handle)
        elif position is None:
            position, _ = self.get_agent_position_and_direction(handle)
        elif direction is None:
            _, direction = self.get_agent_position_and_direction(handle)

        agent = self.env.agents[handle]
        step = 0
        while (position != agent.target) and (step < max_step):
            position, direction, dist, action, possible_transitions = self.walk(handle, position, direction)
            if position is None:
                break
            self.callback(handle, agent, position, direction, action, possible_transitions)
            step += 1

    def callback_one_step(self, handle, agent, position, direction, action, possible_transitions):
        pass

    def walk_one_step(self, handle):
        agent = self.env.agents[handle]
        if agent.position is not None:
            position = agent.position
        else:
            position = agent.initial_position
        direction = agent.direction
        possible_transitions = (0, 1, 0, 0)
        if (position != agent.target):
            new_position, new_direction, dist, action, possible_transitions = self.walk(handle, position, direction)
            if new_position is None:
                return position, direction, RailEnvActions.STOP_MOVING, possible_transitions
            self.callback_one_step(handle, agent, new_position, new_direction, action, possible_transitions)
        return new_position, new_direction, action, possible_transitions