test_global_observation.py 7.84 KB
Newer Older
1
2
import numpy as np

3
from flatland.envs.agent_utils import EnvAgent
4
from flatland.envs.observations import GlobalObsForRailEnv
u214892's avatar
u214892 committed
5
from flatland.envs.rail_env import RailEnv, RailEnvActions
6
from flatland.envs.rail_generators import sparse_rail_generator
7
from flatland.envs.line_generators import sparse_line_generator
8
from flatland.envs.step_utils.states import TrainState
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24


def test_get_global_observation():
    """Verify every channel of the global observation against the agents' real state.

    A 50x50 sparse rail environment with 20 agents is built using
    ``GlobalObsForRailEnv`` as observation builder. The environment is stepped
    with empty action dicts until the latest ``earliest_departure`` has passed,
    then once with ``MOVE_FORWARD`` for every agent. The observation returned
    by that last step is checked for each agent ``i``:

    * ``obs[i][2]`` (targets), channel 0: exactly one non-zero cell (own target).
    * ``obs[i][2]``, channel 1: 1 on every cell that is some agent's target,
      0 elsewhere.
    * ``obs[i][1]`` (agent state), channel 0: own direction at own position
      (or at the initial position while READY_TO_DEPART), -1 elsewhere.
    * channel 1: direction of another agent at its position, -1 on cells
      without another agent.
    * channels 2 and 3: malfunction value and speed of whichever agent is on
      the cell, -1 on cells without an agent.
    * channel 4: number of READY_TO_DEPART agents whose initial position is
      the cell.
    """
    number_of_agents = 20

    speed_ratio_map = {1.: 0.25,  # Fast passenger train
                       1. / 2.: 0.25,  # Fast freight train
                       1. / 3.: 0.25,  # Slow commuter train
                       1. / 4.: 0.25}  # Slow freight train

    env = RailEnv(width=50, height=50,
                  rail_generator=sparse_rail_generator(max_num_cities=6,
                                                       max_rails_between_cities=4,
                                                       seed=15,
                                                       grid_mode=False),
                  line_generator=sparse_line_generator(speed_ratio_map),
                  number_of_agents=number_of_agents,
                  obs_builder_object=GlobalObsForRailEnv())
    env.reset()

    # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
    for _ in range(max(agent.earliest_departure for agent in env.agents)):
        env.step({})  # DO_NOTHING for all agents

    # Only the observation of this step is inspected; rewards/dones/info are not needed.
    obs, _, _, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)})

    for i, agent in enumerate(env.agents):
        print("[{}] state={}, position={}, target={}, initial_position={}".format(i, agent.state, agent.position,
                                                                                   agent.target,
                                                                                   agent.initial_position))

    # All grid cells in row-major order; every per-channel check walks the full grid.
    cells = [(r, c) for r in range(env.height) for c in range(env.width)]

    # States in which an agent is expected to show up in channels 1-3 at its position.
    visible_states = (TrainState.MOVING, TrainState.MALFUNCTION, TrainState.STOPPED, TrainState.DONE)

    for i, agent in enumerate(env.agents):
        obs_agents_state = obs[i][1]
        obs_targets = obs[i][2]

        # test first channel of obs_targets: own target
        nr_agents = np.count_nonzero(obs_targets[:, :, 0])
        assert nr_agents == 1, "agent {}: something wrong with own target, found {}".format(i, nr_agents)

        # test second channel of obs_targets: other agent's target
        # (the agent's own target is included in the expectation as well)
        for cell in cells:
            _other_agent_target = 1 if any(other_agent.target == cell for other_agent in env.agents) else 0
            assert obs_targets[cell][1] == _other_agent_target, \
                "agent {}: at {} expected to be other agent's target = {}".format(i, cell, _other_agent_target)

        # test first channel of obs_agents_state: direction at own position
        for cell in cells:
            if (agent.state.is_on_map_state() or agent.state == TrainState.DONE) and cell == agent.position:
                assert np.isclose(obs_agents_state[cell][0], agent.direction), \
                    "agent {} in state {} at {} expected to contain own direction {}, found {}" \
                        .format(i, agent.state, cell, agent.direction, obs_agents_state[cell][0])
            elif (agent.state == TrainState.READY_TO_DEPART) and cell == agent.initial_position:
                assert np.isclose(obs_agents_state[cell][0], agent.direction), \
                    "agent {} in state {} at {} expected to contain own direction {}, found {}" \
                        .format(i, agent.state, cell, agent.direction, obs_agents_state[cell][0])
            else:
                assert np.isclose(obs_agents_state[cell][0], -1), \
                    "agent {} in state {} at {} expected contain -1 found {}" \
                        .format(i, agent.state, cell, obs_agents_state[cell][0])

        # test second channel of obs_agents_state: direction at other agents position
        for cell in cells:
            has_agent = False
            for other_i, other_agent in enumerate(env.agents):
                if i == other_i:
                    continue
                if other_agent.state in visible_states and cell == other_agent.position:
                    assert np.isclose(obs_agents_state[cell][1], other_agent.direction), \
                        "agent {} in state {} at {} should see other agent with direction {}, found = {}" \
                            .format(i, agent.state, cell, other_agent.direction, obs_agents_state[cell][1])
                    # Mark the cell as occupied only when another agent really is here;
                    # setting this for every other agent would make the -1 check below
                    # unreachable as soon as the environment has more than one agent.
                    has_agent = True
            if not has_agent:
                assert np.isclose(obs_agents_state[cell][1], -1), \
                    "agent {} in state {} at {} should see no other agent direction (-1), found = {}" \
                        .format(i, agent.state, cell, obs_agents_state[cell][1])

        # test third and fourth channel of obs_agents_state: malfunction and speed of own or other agent in the grid
        for cell in cells:
            has_agent = False
            for other_agent in env.agents:
                if other_agent.state in visible_states and other_agent.position == cell:
                    assert np.isclose(obs_agents_state[cell][2], other_agent.malfunction_data['malfunction']), \
                        "agent {} in state {} at {} should see agent malfunction {}, found = {}" \
                            .format(i, agent.state, cell, other_agent.malfunction_data['malfunction'],
                                    obs_agents_state[cell][2])
                    assert np.isclose(obs_agents_state[cell][3], other_agent.speed_counter.speed)
                    has_agent = True
            if not has_agent:
                assert np.isclose(obs_agents_state[cell][2], -1), \
                    "agent {} in state {} at {} should see no agent malfunction (-1), found = {}" \
                        .format(i, agent.state, cell, obs_agents_state[cell][2])
                assert np.isclose(obs_agents_state[cell][3], -1), \
                    "agent {} in state {} at {} should see no agent speed (-1), found = {}" \
                        .format(i, agent.state, cell, obs_agents_state[cell][3])

        # test fifth channel of obs_agents_state: number of agents ready to depart in to this cell
        for cell in cells:
            count = sum(1 for other_agent in env.agents
                        if other_agent.state == TrainState.READY_TO_DEPART and other_agent.initial_position == cell)
            assert np.isclose(obs_agents_state[cell][4], count), \
                "agent {} in state {} at {} should see {} agents ready to depart, found{}" \
                    .format(i, agent.state, cell, count, obs_agents_state[cell][4])