"""
Collection of environment-specific PredictionBuilder.
"""

import numpy as np

from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.rail_env import RailEnvActions


class ShortestPathPredictorForRailEnv(PredictionBuilder):
    """
    ShortestPathPredictorForRailEnv object.

    This object returns shortest-path predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment and always takes the forward action.
    """

    def __init__(self, max_depth):
        self.max_depth = max_depth

    def get(self, custom_args=None, handle=None):
        """
        Called whenever get_many in the observation build is called.
        Requires distance_map to extract the shortest path.

        Parameters
        ----------
        custom_args: dict
            - distance_map : dict
        handle : int, optional
            Handle of the agent for which to compute the observation vector.

        Returns
        -------
        np.array
            Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here
            The prediction at 0 is the current position, direction etc.
        """
        agents = self.env.agents
        if handle:
            agents = [self.env.agents[handle]]
        assert custom_args is not None
        distance_map = custom_args.get('distance_map')
        assert distance_map is not None

        prediction_dict = {}
        for agent in agents:
            _agent_initial_position = agent.position
            _agent_initial_direction = agent.direction
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
            visited = set()

            for index in range(1, self.max_depth + 1):
                # if we're at the target, stop moving...
                if agent.position == agent.target:
                    prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
                    visited.add((agent.position[0], agent.position[1], agent.direction))
                    continue
                if not agent.moving:
                    prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
                    visited.add((agent.position[0], agent.position[1], agent.direction))
                    continue
                # Take shortest possible path
                cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)

                new_position = None
                new_direction = None
                if np.sum(cell_transitions) == 1:
                    new_direction = np.argmax(cell_transitions)
                    new_position = get_new_position(agent.position, new_direction)
                elif np.sum(cell_transitions) > 1:
                    min_dist = np.inf
                    no_dist_found = True
                    for direction in range(4):
                        if cell_transitions[direction] == 1:
                            neighbour_cell = get_new_position(agent.position, direction)
                            target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
                            if target_dist < min_dist or no_dist_found:
                                min_dist = target_dist
                                new_direction = direction
                                no_dist_found = False
                    new_position = get_new_position(agent.position, new_direction)
                else:
                    raise Exception("No transition possible {}".format(cell_transitions))

                # update the agent's position and direction
                agent.position = new_position
                agent.direction = new_direction

                # prediction is ready
                prediction[index] = [index, *new_position, new_direction, 0]
                visited.add((new_position[0], new_position[1], new_direction))
            self.env.dev_pred_dict[agent.handle] = visited
            prediction_dict[agent.handle] = prediction

            # cleanup: reset initial position
            agent.position = _agent_initial_position
            agent.direction = _agent_initial_direction

        return prediction_dict