Skip to content
Snippets Groups Projects
Commit 493b8a20 authored by gmollard's avatar gmollard
Browse files

started implementation of global view and the corresponding test

parent 3e030163
No related branches found
No related tags found
2 merge requests!10Global observation builder,!9Global observation builder
Pipeline #285 failed
import numpy as np
# TODO: add docstrings, pylint, etc...
......@@ -23,3 +25,25 @@ class TreeObsForRailEnv(ObservationBuilder):
# raise NotImplementedError()
return []
class GlobalObsForRailEnv(ObservationBuilder):
"""
Gives a global observation of the entire rail environment.
The observation is composed of the following elements:
- transition map array with dimensions (env.height, env.width, 16),
assuming 16 bits encoding of transitions.
- Four 2D arrays containing respectively the position of the given agent,
the position of its target, the positions of the other agents and of
their target.
"""
def __init__(self, env):
super(GlobalObsForRailEnv, self).__init__(env)
self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
for i in range(self.rail_obs.shape[0]):
for j in range(self.rail_obs.shape[1]):
self.rail_obs[i, j] = self.env.rail.get_transitions((i, j))
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from flatland.core.env_observation_builder import GlobalObsForRailEnv
from flatland.core.transitions import Grid4Transitions
import numpy as np
"""Tests for `flatland` package."""
def test_global_obs():
# We instantiate a very simple rail network on a 7x10 grid:
# |
# |
# |
# _ _ _ /_\ _ _ _ _ _ _
# \ /
# |
# |
# |
cells = [int('0000000000000000', 2), # empty cell - Case 0
int('1000000000100000', 2), # Case 1 - straight
int('1001001000100000', 2), # Case 2 - simple switch
int('1000010000100001', 2), # Case 3 - diamond drossing
int('1001011000100001', 2), # Case 4 - single slip switch
int('1100110000110011', 2), # Case 5 - double slip switch
int('0101001000000010', 2), # Case 6 - symmetrical switch
int('0010000000000000', 2)] # Case 7 - dead end
transitions = Grid4Transitions([])
empty = cells[0]
dead_end_from_south = cells[7]
dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
vertical_straight = cells[1]
horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
double_switch_south_horizontal_straight = horizontal_straight + cells[6]
double_switch_north_horizontal_straight = transitions.rotate_transition(
double_switch_south_horizontal_straight, 180)
rail_map = np.array(
[[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
[[empty] * 3 + [vertical_straight] + [empty] * 6]*2 +
[[horizontal_straight] * 3 + [double_switch_north_horizontal_straight] +
[horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
[horizontal_straight] * 3] +
[[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
[[empty] * 3 + [dead_end_from_south] + [empty] * 6], dtype=np.uint16)
print(rail_map.shape)
test_global_obs()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment