test_env_observation_builder.py 3.3 KB
Newer Older
1
2
3
#!/usr/bin/env python
# -*- coding: utf-8 -*-

spiglerg's avatar
spiglerg committed
4
5
import numpy as np

gmollard's avatar
gmollard committed
6
from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
7
from flatland.envs.generators import rail_from_GridTransitionMap_generator
8
9
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49

"""Tests for `flatland` package."""


def test_global_obs():
    """Check that ``GlobalObsForRailEnv`` encodes the rail map and the agent.

    A small hand-built rail network is wrapped in a single-agent ``RailEnv``
    that uses ``GlobalObsForRailEnv``.  The test then verifies that:

    * the rail part of agent 0's observation has shape ``(height, width, 16)``;
    * decoding the 16 bits stored for each cell reproduces ``rail_map``
      exactly;
    * the agent layer of the observation overlaps at least one non-empty
      rail cell (i.e. the agent is not placed on an empty cell).
    """
    # We instantiate a very simple rail network on a 7x10 grid:
    #        |
    #        |
    #        |
    # _ _ _ /_\ _ _  _  _ _ _
    #               \ /
    #                |
    #                |
    #                |

    cells = [int('0000000000000000', 2),  # empty cell - Case 0
             int('1000000000100000', 2),  # Case 1 - straight
             int('1001001000100000', 2),  # Case 2 - simple switch
             int('1000010000100001', 2),  # Case 3 - diamond crossing
             int('1001011000100001', 2),  # Case 4 - single slip switch
             int('1100110000110011', 2),  # Case 5 - double slip switch
             int('0101001000000010', 2),  # Case 6 - symmetrical switch
             int('0010000000000000', 2)]  # Case 7 - dead end

    transitions = Grid4Transitions([])
    empty = cells[0]

    dead_end_from_south = cells[7]
    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)

    vertical_straight = cells[1]
    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)

    double_switch_south_horizontal_straight = horizontal_straight + cells[6]
    double_switch_north_horizontal_straight = transitions.rotate_transition(
        double_switch_south_horizontal_straight, 180)

    # 7 rows x 10 columns, matching the diagram above.
    rail_map = np.array(
        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
        [[dead_end_from_east] + [horizontal_straight] * 2 +
         [double_switch_north_horizontal_straight] +
         [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
         [horizontal_straight] * 2 + [dead_end_from_west]] +
        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)

    rail = GridTransitionMap(width=rail_map.shape[1],
                             height=rail_map.shape[0], transitions=transitions)
    rail.grid = rail_map

    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_GridTransitionMap_generator(rail),
                  number_of_agents=1,
                  obs_builder_object=GlobalObsForRailEnv())

    global_obs = env.reset()

    # Element 0 of agent 0's observation holds 16 transition bits per cell.
    rail_obs = global_obs[0][0]
    assert rail_obs.shape == rail_map.shape + (16,)

    # Every entry must be a single bit for the decoding below to be
    # meaningful.
    assert np.isin(rail_obs, (0, 1)).all(), "rail observation is not binary"

    # Decode each cell's bits, most significant bit first (the same order
    # that int(''.join(bits), 2) would use), back into a 16-bit integer.
    bit_weights = 1 << np.arange(15, -1, -1)
    rail_map_recons = (rail_obs.astype(np.int64) @ bit_weights).astype(
        rail_map.dtype)

    # Compare element-wise.  The previous `a.all() == b.all()` only compared
    # two booleans and therefore could not detect a wrong encoding.
    assert np.array_equal(rail_map_recons, rail_map), \
        "decoded rail observation differs from the rail map"

    # If this assertion fails, the observation places the agent on an empty
    # cell.  NOTE(review): assumes channels 0-3 of element 1 form the agent
    # layer, as the original test did -- confirm against GlobalObsForRailEnv.
    assert np.sum(rail_map * global_obs[0][1][:, :, :4].sum(2)) > 0, \
        "agent is placed on an empty cell"