custom_observation_example_02_SingleAgentNavigationObs.py 3.72 KB
Newer Older
1 2 3 4
import getopt
import random
import sys
import time
5
from typing import List
6 7 8

import numpy as np

9
from flatland.core.env_observation_builder import ObservationBuilder
10
from flatland.core.grid.grid4_utils import get_new_position
11 12 13
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
u214892's avatar
u214892 committed
14
from flatland.utils.misc import str2bool
15 16 17 18 19 20
from flatland.utils.rendertools import RenderTool

# Seed Python's and NumPy's global random generators with the same fixed value.
_SEED = 100
random.seed(_SEED)
np.random.seed(_SEED)


21
class SingleAgentNavigationObs(ObservationBuilder):
    """
    We build a representation vector with 3 binary components, indicating which of the 3 available directions
    for each agent (Left, Forward, Right) lead to the shortest path to its target.
    E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
    will be [1, 0, 0].
    """

    def __init__(self):
        super().__init__()

    def reset(self):
        """Nothing to do: this builder sets no attributes of its own."""
        pass

    def get(self, handle: int = 0) -> List[int]:
        """
        Build the observation for one agent.

        :param handle: index of the agent in ``self.env.agents``.
        :return: a list ``[left, forward, right]`` with exactly one entry set to 1. It is ``[0, 1, 0]``
            when exactly one transition is possible; otherwise the entry of the branch with the smallest
            distance-map value is set (unavailable branches count as infinitely far).
        """
        agent = self.env.agents[handle]

        # ``agent.position`` may be unset (falsy); in that case work from ``agent.initial_position``.
        # The same cell must be used for the transition lookup AND for computing the neighbouring cells
        # below, otherwise the neighbour computation is handed an unset position.
        position = agent.position if agent.position else agent.initial_position
        possible_transitions = self.env.rail.get_transitions(*position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        # If only one transition is possible, the forward branch is aligned with it.
        if num_transitions == 1:
            return [0, 1, 0]

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right], relative to the current orientation
        distance_map = self.env.distance_map.get()
        min_distances = []
        for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
            if possible_transitions[direction]:
                new_position = get_new_position(position, direction)
                min_distances.append(distance_map[handle, new_position[0], new_position[1], direction])
            else:
                # Branch not available: it can never be the best choice.
                min_distances.append(np.inf)

        observation = [0, 0, 0]
        observation[np.argmin(min_distances)] = 1
        return observation


def main(args):
    """
    Run the single-agent navigation example and render every step.

    :param args: command-line arguments without the program name. The only supported option is
        ``--sleep-for-animation=<value>``; the value is converted with ``str2bool``. Sleeping between
        rendered steps is enabled by default.

    Prints the parser message and exits with status 2 when the arguments cannot be parsed.
    """
    try:
        opts, _ = getopt.getopt(args, "", ["sleep-for-animation="])
    except getopt.GetoptError as err:
        print(str(err))  # will print something like "option -a not recognized"
        sys.exit(2)

    sleep_for_animation = True
    for o, a in opts:
        # Compare for equality: ``o in ("...")`` would be a substring test on a plain string.
        if o == "--sleep-for-animation":
            sleep_for_animation = str2bool(a)
        else:
            # Explicit raise instead of ``assert False`` so the check survives ``python -O``.
            raise AssertionError("unhandled option")

    env = RailEnv(width=7, height=7,
                  rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
                                                        seed=1), schedule_generator=complex_schedule_generator(),
                  number_of_agents=1, obs_builder_object=SingleAgentNavigationObs())

    obs, _ = env.reset()
    env_renderer = RenderTool(env)
    try:
        env_renderer.render_env(show=True, frames=True, show_observations=True)
        for _ in range(100):
            # Index of the 1 in agent 0's [left, forward, right] observation, shifted by one to form
            # the action (presumably matching RailEnv's left/forward/right action codes - confirm).
            action = np.argmax(obs[0]) + 1
            obs, all_rewards, done, _ = env.step({0: action})
            print("Rewards: ", all_rewards, "  [done=", done, "]")
            env_renderer.render_env(show=True, frames=True, show_observations=True)
            if sleep_for_animation:
                time.sleep(0.1)
            if done["__all__"]:
                break
    finally:
        # Close the render window even when stepping or rendering raises.
        env_renderer.close_window()


if __name__ == '__main__':
    # Use an ``argv`` found in the module globals when there is one (presumably injected by a
    # caller that executes this file with pre-populated globals - confirm); otherwise fall back
    # to the process command-line arguments without the program name.
    main(globals().get('argv', sys.argv[1:]))