From 493b8a2027c83ce4683853a29bd9813a9be51eea Mon Sep 17 00:00:00 2001 From: Guillaume Mollard <guillaume.mollard2@gmail.com> Date: Thu, 18 Apr 2019 16:46:00 +0200 Subject: [PATCH] started implementation of global view and the corresponding test --- flatland/core/env_observation_builder.py | 24 +++++++++ tests/test_env_observation_builder.py | 69 ++++++++++++++++++++++++ 2 files changed, 93 insertions(+) create mode 100644 tests/test_env_observation_builder.py diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index a83c0c51..ac62946d 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -1,3 +1,5 @@ +import numpy as np + # TODO: add docstrings, pylint, etc... @@ -23,3 +25,25 @@ class TreeObsForRailEnv(ObservationBuilder): # raise NotImplementedError() return [] + + +class GlobalObsForRailEnv(ObservationBuilder): + """ + Gives a global observation of the entire rail environment. + The observation is composed of the following elements: + + - transition map array with dimensions (env.height, env.width, 16), + assuming 16 bits encoding of transitions. + + - Four 2D arrays containing respectively the position of the given agent, + the position of its target, the positions of the other agents and of + their target. + """ + def __init__(self, env): + super(GlobalObsForRailEnv, self).__init__(env) + self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) + for i in range(self.rail_obs.shape[0]): + for j in range(self.rail_obs.shape[1]): + self.rail_obs[i, j] = self.env.rail.get_transitions((i, j)) + + diff --git a/tests/test_env_observation_builder.py b/tests/test_env_observation_builder.py new file mode 100644 index 00000000..c3f5dfd5 --- /dev/null +++ b/tests/test_env_observation_builder.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from flatland.core.env_observation_builder import GlobalObsForRailEnv +from flatland.core.transitions import Grid4Transitions +import numpy as np + +"""Tests for `flatland` package.""" + + +def test_global_obs(): + # We instantiate a very simple rail network on a 7x10 grid: + # | + # | + # | + # _ _ _ /_\ _ _ _ _ _ _ + # \ / + # | + # | + # | + + cells = [int('0000000000000000', 2), # empty cell - Case 0 + int('1000000000100000', 2), # Case 1 - straight + int('1001001000100000', 2), # Case 2 - simple switch + int('1000010000100001', 2), # Case 3 - diamond drossing + int('1001011000100001', 2), # Case 4 - single slip switch + int('1100110000110011', 2), # Case 5 - double slip switch + int('0101001000000010', 2), # Case 6 - symmetrical switch + int('0010000000000000', 2)] # Case 7 - dead end + + transitions = Grid4Transitions([]) + empty = cells[0] + + dead_end_from_south = cells[7] + dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) + dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180) + dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) + + vertical_straight = cells[1] + horizontal_straight = transitions.rotate_transition(vertical_straight, 90) + + double_switch_south_horizontal_straight = horizontal_straight + cells[6] + double_switch_north_horizontal_straight = transitions.rotate_transition( + double_switch_south_horizontal_straight, 180) + + + + rail_map = np.array( + [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + + [[empty] * 3 + [vertical_straight] + [empty] * 6]*2 + + [[horizontal_straight] * 3 + [double_switch_north_horizontal_straight] + + [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] + + [horizontal_straight] * 3] + + [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + + [[empty] * 3 + [dead_end_from_south] + [empty] * 6], dtype=np.uint16) + + print(rail_map.shape) + +test_global_obs() + + + + + + + + + + -- GitLab