test_flatland_malfunction.py 5.68 KB
Newer Older
1
2
import numpy as np

3
from flatland.core.grid.grid4_utils import get_new_position
4
5
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
u214892's avatar
u214892 committed
6
from flatland.envs.rail_generators import complex_rail_generator
7
from flatland.envs.schedule_generators import complex_schedule_generator
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43


class SingleAgentNavigationObs(TreeObsForRailEnv):
    """
    We derive our observation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
    the minimum distances from each grid node to each agent's target.

    We then build a representation vector with 3 binary components, indicating which of the 3 available directions
    for each agent (Left, Forward, Right) lead to the shortest path to its target.
    E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
    will be [1, 0, 0].
    """

    def __init__(self):
        # No tree expansion is needed; only the parent's distance map is used.
        super().__init__(max_depth=0)
        self.observation_space = [3]

    def reset(self):
        """Delegate to the parent builder's reset."""
        # Recompute the distance map, if the environment has changed.
        super().reset()

    def get(self, handle):
        """
        Build the 3-component navigation observation for one agent.

        :param handle: index of the agent in ``self.env.agents``
        :return: a list ``[left, forward, right]`` with exactly one entry set to 1
        """
        agent = self.env.agents[handle]

        possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right], relative to the current orientation
        # If only one transition is possible, the forward branch is aligned with it.
        if num_transitions == 1:
            return [0, 1, 0]

        min_distances = []
        for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
            if possible_transitions[direction]:
                new_position = get_new_position(agent.position, direction)
                min_distances.append(
                    self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
            else:
                # Unavailable branches must never win the argmin below.
                min_distances.append(np.inf)

        # np.argmin returns the first minimum, so ties resolve towards the left-most branch.
        observation = [0, 0, 0]
        observation[np.argmin(min_distances)] = 1
        return observation


def test_malfunction_process():
    """Check that a forced malfunction halts agent 0 and is counted by the environment."""
    # Set fixed malfunction duration for this test
    stochastic_data = {'prop_malfunction': 1.,
                       'malfunction_rate': 1000,
                       'min_duration': 3,
                       'max_duration': 3}
    np.random.seed(5)

    env = RailEnv(width=20,
                  height=20,
                  rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
                                                        seed=0),
                  schedule_generator=complex_schedule_generator(),
                  number_of_agents=2,
                  obs_builder_object=SingleAgentNavigationObs(),
                  stochastic_data=stochastic_data)

    obs = env.reset()

    # Check that an initial duration for malfunction was assigned
    assert env.agents[0].malfunction_data['next_malfunction'] > 0

    agent_halts = 0
    total_down_time = 0
    agent_old_position = env.agents[0].position
    for step in range(100):
        # Index by position rather than iterating obs directly, so this works
        # whether obs is a list or a dict keyed by agent handle.
        # The +1 shifts the [left, forward, right] argmax onto the action ids.
        actions = {i: np.argmax(obs[i]) + 1 for i in range(len(obs))}

        if step % 5 == 0:
            # Stop the agent and set it to be malfunctioning
            env.agents[0].malfunction_data['malfunction'] = -1
            env.agents[0].malfunction_data['next_malfunction'] = 0
            agent_halts += 1

        obs, _, _, _ = env.step(actions)

        if env.agents[0].malfunction_data['malfunction'] > 0:
            # Check that agent is not moving while malfunctioning
            assert agent_old_position == env.agents[0].position

        agent_old_position = env.agents[0].position
        total_down_time += env.agents[0].malfunction_data['malfunction']

    # Check that the appropriate number of malfunctions is achieved
    assert env.agents[0].malfunction_data['nr_malfunctions'] == 21

    # Check that 20 stops were performed (one every 5th step over 100 steps)
    assert agent_halts == 20

    # Check that the agent accumulated some down time while malfunctioning
    assert total_down_time > 0
u214892's avatar
u214892 committed
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146


def test_malfunction_process_statistically():
    """Tests that malfunctions are produced by stochastic_data!"""
    # Set fixed malfunction duration for this test
    stochastic_data = {'prop_malfunction': 1.,
                       'malfunction_rate': 2,
                       'min_duration': 3,
                       'max_duration': 3}
    np.random.seed(5)

    env = RailEnv(width=20,
                  height=20,
                  rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
                                                        seed=0),
                  schedule_generator=complex_schedule_generator(),
                  number_of_agents=2,
                  obs_builder_object=SingleAgentNavigationObs(),
                  stochastic_data=stochastic_data)

    env.reset()
    # Counts (agent, step) pairs in which the agent was malfunctioning before the step.
    nb_malfunction = 0
    for _ in range(100):
        action_dict = {}
        for agent in env.agents:
            if agent.malfunction_data['malfunction'] > 0:
                nb_malfunction += 1
            # We randomly select an action
            action_dict[agent.handle] = np.random.randint(4)

        env.step(action_dict)

    # check that generation of malfunctions works as expected
    # results are different in py36 and py37, therefore no exact test on nb_malfunction
    assert nb_malfunction > 150