import random

import numpy as np

from flatland.core.env_observation_builder import ObservationBuilder
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import random_rail_generator

# Seed both Python's and NumPy's global random number generators with the same
# fixed value so repeated runs of this script draw the same random numbers.
# NOTE(review): presumably this is what makes the randomly generated rail
# reproducible -- confirm which RNG random_rail_generator actually uses.
random.seed(100)
np.random.seed(100)


class SimpleObs(ObservationBuilder):
    """
    Minimal example of a custom observation builder.

    Every agent is given a length-5 vector in which each entry is the agent's own
    handle (its ID), so the observation carries no information about the environment.
    """

    def __init__(self) -> None:
        super().__init__()

    def reset(self) -> None:
        """Nothing to reset: this builder keeps no state of its own."""
        return None

    def get(self, handle: int = 0) -> np.ndarray:
        """
        Build the observation for a single agent.

        :param handle: ID of the agent the observation is for; defaults to 0.
        :return: float array of shape (5,) with every entry equal to ``handle``.
        """
        return np.full((5,), handle, dtype=np.float64)


def main():
    """
    Build a 7x7 rail environment with three agents that uses ``SimpleObs``,
    advance it by one step and print the observation returned for each agent.
    """
    rail_env = RailEnv(
        width=7,
        height=7,
        rail_generator=random_rail_generator(),
        number_of_agents=3,
        obs_builder_object=SimpleObs(),
    )
    rail_env.reset()

    # Take a single step so the environment hands back observations. Only agent 0
    # is given an action (0); the rewards, done flags and info are not needed here.
    # NOTE(review): action 0 is presumably "do nothing" -- confirm against the
    # environment's action definitions.
    observations, _rewards, _dones, _info = rail_env.step({0: 0})

    # Print the observation vector of every agent.
    agent_count = rail_env.get_num_agents()
    for agent_handle in range(agent_count):
        print("Agent ", agent_handle, "'s observation: ", observations[agent_handle])


# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()