diff --git a/flatland/baselines/dueling_double_dqn.py b/flatland/baselines/dueling_double_dqn.py index 084d0a22b8c3ae39a2503e849ed346f50a7c8aa9..ee75a615cda4d26a06810d2b7d109fe5691d5ac4 100644 --- a/flatland/baselines/dueling_double_dqn.py +++ b/flatland/baselines/dueling_double_dqn.py @@ -22,12 +22,12 @@ device = torch.device("cpu") print(device) -class Agent(): +class Agent: """Interacts with and learns from the environment.""" def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5): """Initialize an Agent object. - + Params ====== state_size (int): dimension of each state @@ -78,7 +78,7 @@ class Agent(): def act(self, state, eps=0.): """Returns actions for given state as per current policy. - + Params ====== state (array_like): current state @@ -140,7 +140,7 @@ class Agent(): ====== local_model (PyTorch model): weights will be copied from target_model (PyTorch model): weights will be copied to - tau (float): interpolation parameter + tau (float): interpolation parameter """ for target_param, local_param in zip(target_model.parameters(), local_model.parameters()): target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data) diff --git a/flatland/baselines/model.py b/flatland/baselines/model.py index 3b52e9f5ed691aeb3be01e69f04f92d615829b0b..7a5b3d613342a4fba8e2c8f1f45df21381e12684 100644 --- a/flatland/baselines/model.py +++ b/flatland/baselines/model.py @@ -1,4 +1,3 @@ -import torch import torch.nn as nn import torch.nn.functional as F diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index ce0a28ae9f984249233e9086f78debd3c1cbc54e..b0730d8a4cd5e7d8567e243765b258a111070296 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -9,6 +9,7 @@ case of multi-agent environments. """ import numpy as np + from collections import deque # TODO: add docstrings, pylint, etc... @@ -103,6 +104,7 @@ class TreeObsForRailEnv(ObservationBuilder): node = nodes_queue.popleft() node_id = (node[0], node[1], node[2]) + if node_id not in visited: visited.add(node_id) @@ -125,6 +127,53 @@ class TreeObsForRailEnv(ObservationBuilder): """ neighbors = [] + for direction in range(4): + new_cell = self._new_position(position, (direction+2) % 4) + + if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width: + # Check if the two cells are connected by a valid transition + transitionValid = False + for orientation in range(4): + moves = self.env.rail.get_transitions((new_cell[0], new_cell[1], orientation)) + if moves[direction]: + transitionValid = True + break + + if not transitionValid: + continue + + # Check if a transition in direction node[2] is possible if an agent + # lands in the current cell with orientation `direction'; this only + # applies to cells that are not dead-ends! + directionMatch = True + if enforce_target_direction >= 0: + directionMatch = self.env.rail.get_transition( + (new_cell[0], new_cell[1], direction), enforce_target_direction) + + # If transition is found to invalid, check if perhaps it + # is a dead-end, in which case the direction of movement is rotated + # 180 degrees (moving forward turns the agents and makes it step in the previous cell) + if not directionMatch: + # If cell is a dead-end, append previous node with reversed + # orientation! + nbits = 0 + tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1])) + while tmp > 0: + nbits += (tmp & 1) + tmp = tmp >> 1 + if nbits == 1: + # Dead-end! + # Check if transition is possible in new_cell + # with orientation (direction+2)%4 in direction `direction' + directionMatch = directionMatch or self.env.rail.get_transition( + (new_cell[0], new_cell[1], (direction+2) % 4), direction) + + if transitionValid and directionMatch: + new_distance = min(self.distance_map[target_nr, + new_cell[0], new_cell[1]], current_distance+1) + neighbors.append((new_cell[0], new_cell[1], direction, new_distance)) + self.distance_map[target_nr, new_cell[0], new_cell[1]] = new_distance + possible_directions = [0, 1, 2, 3] if enforce_target_direction >= 0: # The agent must land into the current cell with orientation `enforce_target_direction'. @@ -195,6 +244,10 @@ class TreeObsForRailEnv(ObservationBuilder): the transitions. The order is: [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back'] + + + + Each branch data is organized as: [root node information] + [recursive branch data from 'left'] + @@ -410,3 +463,51 @@ class TreeObsForRailEnv(ObservationBuilder): num_features_per_node, prompt=prompt_[children], current_depth=current_depth+1) + + +class GlobalObsForRailEnv(ObservationBuilder): + """ + Gives a global observation of the entire rail environment. + The observation is composed of the following elements: + + - transition map array with dimensions (env.height, env.width, 16), + assuming 16 bits encoding of transitions. + + - Four 2D arrays containing respectively the position of the given agent, + the position of its target, the positions of the other agents and of + their target. + + - A 4 elements array with one of encoding of the direction. + """ + def __init__(self): + super(GlobalObsForRailEnv, self).__init__() + + def reset(self): + self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) + for i in range(self.rail_obs.shape[0]): + for j in range(self.rail_obs.shape[1]): + self.rail_obs[i, j] = np.array( + list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int) + + # self.targets = np.zeros(self.env.height, self.env.width) + # for target_pos in self.env.agents_target: + # self.targets[target_pos] += 1 + + def get(self, handle): + obs_agents_targets_pos = np.zeros((4, self.env.height, self.env.width)) + agent_pos = self.env.agents_position[handle] + obs_agents_targets_pos[0][agent_pos] += 1 + for i in range(len(self.env.agents_position)): + if i != handle: + obs_agents_targets_pos[3][self.env.agents_position[i]] += 1 + + agent_target_pos = self.env.agents_target[handle] + obs_agents_targets_pos[1][agent_target_pos] += 1 + for i in range(len(self.env.agents_target)): + if i != handle: + obs_agents_targets_pos[2][self.env.agents_target[i]] += 1 + + direction = np.zeros(4) + direction[self.env.agents_direction[handle]] = 1 + + return self.rail_obs, obs_agents_targets_pos, direction diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 24f365820b06698931b46afb6266deca03d6834b..9f008f00c261f76e71771732aa02c3c3071f9542 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -406,6 +406,7 @@ class RenderTool(object): plt.clf() # if oFigure is None: # oFigure = plt.figure() + def drawTrans(oFrom, oTo, sColor="gray"): plt.plot( [oFrom[0], oTo[0]], # x @@ -554,8 +555,6 @@ class RenderTool(object): plt.pause(0.00001) return - - def _draw_square(self, center, size, color): x0 = center[0]-size/2 x1 = center[0]+size/2 diff --git a/tests/test_env_observation_builder.py b/tests/test_env_observation_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..44dbfc6d8f14a7293e148f125b47eed66c9ca08d --- /dev/null +++ b/tests/test_env_observation_builder.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from flatland.core.env_observation_builder import GlobalObsForRailEnv +from flatland.core.transition_map import GridTransitionMap, Grid4Transitions +from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator +from flatland.utils.rendertools import * + +"""Tests for `flatland` package.""" + + +def test_global_obs(): + # We instantiate a very simple rail network on a 7x10 grid: + # | + # | + # | + # _ _ _ /_\ _ _ _ _ _ _ + # \ / + # | + # | + # | + + cells = [int('0000000000000000', 2), # empty cell - Case 0 + int('1000000000100000', 2), # Case 1 - straight + int('1001001000100000', 2), # Case 2 - simple switch + int('1000010000100001', 2), # Case 3 - diamond drossing + int('1001011000100001', 2), # Case 4 - single slip switch + int('1100110000110011', 2), # Case 5 - double slip switch + int('0101001000000010', 2), # Case 6 - symmetrical switch + int('0010000000000000', 2)] # Case 7 - dead end + + transitions = Grid4Transitions([]) + empty = cells[0] + + dead_end_from_south = cells[7] + dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) + dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180) + dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) + + vertical_straight = cells[1] + horizontal_straight = transitions.rotate_transition(vertical_straight, 90) + + double_switch_south_horizontal_straight = horizontal_straight + cells[6] + double_switch_north_horizontal_straight = transitions.rotate_transition( + double_switch_south_horizontal_straight, 180) + + rail_map = np.array( + [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + + [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + + [[dead_end_from_east] + [horizontal_straight] * 2 + + [double_switch_north_horizontal_straight] + + [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] + + [horizontal_straight] * 2 + [dead_end_from_west]] + + [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + + [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) + + rail = GridTransitionMap(width=rail_map.shape[1], + height=rail_map.shape[0], transitions=transitions) + rail.grid = rail_map + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_GridTransitionMap_generator(rail), + number_of_agents=1, + obs_builder_object=GlobalObsForRailEnv()) + + global_obs = env.reset() + # env_renderer = RenderTool(env) + # env_renderer.renderEnv(show=True) + + # global_obs.reset() + assert(global_obs[0][0].shape == rail_map.shape + (16,)) + + rail_map_recons = np.zeros_like(rail_map) + for i in range(global_obs[0][0].shape[0]): + for j in range(global_obs[0][0].shape[1]): + rail_map_recons[i, j] = int( + ''.join(global_obs[0][0][i, j].astype(int).astype(str)), 2) + + assert(rail_map_recons.all() == rail_map.all()) + + # If this assertion is wrong, it means that the observation returned + # places the agent on an empty cell + assert(np.sum(rail_map * global_obs[0][1][0]) > 0) + + + +test_global_obs() + + + + + + + + + + diff --git a/tests/test_environments.py b/tests/test_environments.py index 4f31fe6196846ff843aee3a445d1347d8975799d..ea8748b8aa4b50a1371a013be98f3b42d0d01228 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -5,6 +5,7 @@ import numpy as np from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator from flatland.core.transitions import Grid4Transitions from flatland.core.transition_map import GridTransitionMap +from flatland.core.env_observation_builder import GlobalObsForRailEnv """Tests for `flatland` package.""" @@ -49,7 +50,9 @@ def test_rail_environment_single_agent(): rail_env = RailEnv(width=3, height=3, rail_generator=rail_from_GridTransitionMap_generator(rail), - number_of_agents=1) + number_of_agents=1, + obs_builder_object=GlobalObsForRailEnv()) + for _ in range(200): _ = rail_env.reset() @@ -124,7 +127,8 @@ def test_dead_end(): rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_GridTransitionMap_generator(rail), - number_of_agents=1) + number_of_agents=1, + obs_builder_object=GlobalObsForRailEnv()) def check_consistency(rail_env): # We run step to check that trains do not move anymore @@ -147,14 +151,14 @@ def test_dead_end(): # We try the configuration in the 4 directions: rail_env.reset() - rail_env.agents_target[0] = [0, 0] - rail_env.agents_position[0] = [0, 2] + rail_env.agents_target[0] = (0, 0) + rail_env.agents_position[0] = (0, 2) rail_env.agents_direction[0] = 1 check_consistency(rail_env) rail_env.reset() - rail_env.agents_target[0] = [0, 4] - rail_env.agents_position[0] = [0, 2] + rail_env.agents_target[0] = (0, 4) + rail_env.agents_position[0] = (0, 2) rail_env.agents_direction[0] = 3 check_consistency(rail_env) @@ -173,16 +177,17 @@ def test_dead_end(): rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_GridTransitionMap_generator(rail), - number_of_agents=1) + number_of_agents=1, + obs_builder_object=GlobalObsForRailEnv()) rail_env.reset() - rail_env.agents_target[0] = [0, 0] - rail_env.agents_position[0] = [2, 0] + rail_env.agents_target[0] = (0, 0) + rail_env.agents_position[0] = (2, 0) rail_env.agents_direction[0] = 2 check_consistency(rail_env) rail_env.reset() - rail_env.agents_target[0] = [4, 0] - rail_env.agents_position[0] = [2, 0] + rail_env.agents_target[0] = (4, 0) + rail_env.agents_position[0] = (2, 0) rail_env.agents_direction[0] = 0 check_consistency(rail_env) diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py index 0bdc47bace289656181b01b1f44344e4322363a0..e45b7d1815365afda98f699d628a7e6f51c92395 100644 --- a/tests/test_rendertools.py +++ b/tests/test_rendertools.py @@ -11,6 +11,7 @@ import os import matplotlib.pyplot as plt import flatland.utils.rendertools as rt +from flatland.core.env_observation_builder import GlobalObsForRailEnv def checkFrozenImage(sFileImage): @@ -36,7 +37,10 @@ def checkFrozenImage(sFileImage): def test_render_env(): # random.seed(100) np.random.seed(100) - oEnv = RailEnv(width=10, height=10, rail_generator=random_rail_generator(), number_of_agents=2) + oEnv = RailEnv(width=10, height=10, + rail_generator=random_rail_generator(), + number_of_agents=2, + obs_builder_object=GlobalObsForRailEnv()) oEnv.reset() oRT = rt.RenderTool(oEnv) plt.figure(figsize=(10, 10))