# test_flatland_malfunction.py
import numpy as np

from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv


class SingleAgentNavigationObs(TreeObsForRailEnv):
    """
    Observation builder telling each agent which branch leads to its target fastest.

    It derives from TreeObsForRailEnv (constructed with max_depth=0) only to reuse its
    distance map, i.e. the minimum distances from each grid node to each agent's target.

    The observation is a 3-element one-hot list ordered [left, forward, right], relative
    to the agent's current orientation. The entry set to 1 marks the branch with the
    smallest remaining distance. For example, [1, 0, 0] means taking the left branch is
    the shortest route to the agent's target.
    """

    def __init__(self):
        super().__init__(max_depth=0)
        # Length of the observation vector returned by `get`.
        self.observation_space = [3]

    def reset(self):
        """Delegate to the parent reset (recomputes the distance map if the environment has changed)."""
        super().reset()

    def get(self, handle):
        """
        Build the [left, forward, right] one-hot observation for one agent.

        :param handle: index of the agent in ``self.env.agents``.
        :return: list of three ints, exactly one of which is 1.
        """
        agent = self.env.agents[handle]
        position, heading = agent.position, agent.direction

        transitions = self.env.rail.get_transitions(*position, heading)

        # With only one way to go, report it as "forward" regardless of its actual bearing.
        if np.count_nonzero(transitions) == 1:
            return [0, 1, 0]

        # Absolute directions for left (-1), forward (0) and right (+1) of the current heading.
        candidates = [(heading + offset) % 4 for offset in (-1, 0, 1)]

        def remaining_distance(new_dir):
            # Branches that cannot be taken get an infinite cost so argmin avoids them
            # (if none can be taken, np.argmin returns the first index, i.e. "left").
            if not transitions[new_dir]:
                return np.inf
            next_cell = self._new_position(position, new_dir)
            return self.distance_map[handle, next_cell[0], next_cell[1], new_dir]

        costs = [remaining_distance(d) for d in candidates]
        best = np.argmin(costs)
        return [int(k == best) for k in range(3)]


def test_malfunction_process():
    """
    Periodically force agent 0 into a malfunction and check the malfunction bookkeeping.

    Every 5th step agent 0 is given action 4 (stop) and its ``next_malfunction``
    countdown is set to 0. The test checks that:
      * an initial ``next_malfunction`` countdown is assigned on reset,
      * agent 0 keeps its position on any step that starts while it is malfunctioning,
      * ``nr_malfunctions`` and the number of forced halts match the expected counts,
      * a non-zero amount of ``malfunction`` was observed over the run.
    """
    # Fixed malfunction duration for this test (min_duration == max_duration).
    stochastic_data = {'prop_malfunction': 1.,
                       'malfunction_rate': 5,
                       'min_duration': 3,
                       'max_duration': 3}

    # Seed numpy's global RNG so the run is reproducible.
    np.random.seed(5)

    env = RailEnv(width=14,
                  height=14,
                  rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
                                                        seed=0),
                  number_of_agents=2,
                  obs_builder_object=SingleAgentNavigationObs(),
                  stochastic_data=stochastic_data)

    obs = env.reset()

    # Check that an initial countdown until the next malfunction was assigned.
    assert env.agents[0].malfunction_data['next_malfunction'] > 0

    agent_halts = 0
    total_down_time = 0
    agent_old_position = env.agents[0].position

    for step in range(100):
        # Map each one-hot [left, forward, right] observation to actions 1 / 2 / 3.
        actions = {i: np.argmax(obs[i]) + 1 for i in range(len(obs))}

        if step % 5 == 0:
            # Stop agent 0 (action 4) and zero its countdown so it becomes malfunctioning.
            actions[0] = 4
            env.agents[0].malfunction_data['next_malfunction'] = 0
            agent_halts += 1

        # Sample the malfunction state *before* stepping: an agent that is already
        # malfunctioning at the start of this step must not move during it.
        agent_malfunctioning = env.agents[0].malfunction_data['malfunction'] > 0

        obs, _, _, _ = env.step(actions)

        if agent_malfunctioning:
            assert agent_old_position == env.agents[0].position

        agent_old_position = env.agents[0].position
        total_down_time += env.agents[0].malfunction_data['malfunction']

    # Check that agent 0's malfunction counter reached five.
    assert env.agents[0].malfunction_data['nr_malfunctions'] == 5

    # Check that 20 stops were performed (one at each of steps 0, 5, ..., 95).
    assert agent_halts == 20

    # Check that agent 0 actually spent time malfunctioning during the run.
    assert total_down_time > 0