Skip to content
Snippets Groups Projects
Commit a4bcd315 authored by Erik Nygren's avatar Erik Nygren
Browse files

Merge branch 'master' of gitlab.aicrowd.com:flatland/flatland into dqn_on_flatland

parents bd5177e8 460c1e5d
No related branches found
No related tags found
No related merge requests found
......@@ -22,12 +22,12 @@ device = torch.device("cpu")
print(device)
class Agent():
class Agent:
"""Interacts with and learns from the environment."""
def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5):
"""Initialize an Agent object.
Params
======
state_size (int): dimension of each state
......@@ -78,7 +78,7 @@ class Agent():
def act(self, state, eps=0.):
"""Returns actions for given state as per current policy.
Params
======
state (array_like): current state
......@@ -140,7 +140,7 @@ class Agent():
======
local_model (PyTorch model): weights will be copied from
target_model (PyTorch model): weights will be copied to
tau (float): interpolation parameter
tau (float): interpolation parameter
"""
for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)
......
import torch
import torch.nn as nn
import torch.nn.functional as F
......
......@@ -9,6 +9,7 @@ case of multi-agent environments.
"""
import numpy as np
from collections import deque
# TODO: add docstrings, pylint, etc...
......@@ -103,6 +104,7 @@ class TreeObsForRailEnv(ObservationBuilder):
node = nodes_queue.popleft()
node_id = (node[0], node[1], node[2])
if node_id not in visited:
visited.add(node_id)
......@@ -125,6 +127,53 @@ class TreeObsForRailEnv(ObservationBuilder):
"""
neighbors = []
for direction in range(4):
new_cell = self._new_position(position, (direction+2) % 4)
if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width:
# Check if the two cells are connected by a valid transition
transitionValid = False
for orientation in range(4):
moves = self.env.rail.get_transitions((new_cell[0], new_cell[1], orientation))
if moves[direction]:
transitionValid = True
break
if not transitionValid:
continue
# Check if a transition in direction node[2] is possible if an agent
# lands in the current cell with orientation `direction'; this only
# applies to cells that are not dead-ends!
directionMatch = True
if enforce_target_direction >= 0:
directionMatch = self.env.rail.get_transition(
(new_cell[0], new_cell[1], direction), enforce_target_direction)
# If the transition is found to be invalid, check if perhaps it
# is a dead-end, in which case the direction of movement is rotated
# 180 degrees (moving forward turns the agent around and makes it step into the previous cell)
if not directionMatch:
# If cell is a dead-end, append previous node with reversed
# orientation!
nbits = 0
tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1]))
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
if nbits == 1:
# Dead-end!
# Check if transition is possible in new_cell
# with orientation (direction+2)%4 in direction `direction'
directionMatch = directionMatch or self.env.rail.get_transition(
(new_cell[0], new_cell[1], (direction+2) % 4), direction)
if transitionValid and directionMatch:
new_distance = min(self.distance_map[target_nr,
new_cell[0], new_cell[1]], current_distance+1)
neighbors.append((new_cell[0], new_cell[1], direction, new_distance))
self.distance_map[target_nr, new_cell[0], new_cell[1]] = new_distance
possible_directions = [0, 1, 2, 3]
if enforce_target_direction >= 0:
# The agent must land into the current cell with orientation `enforce_target_direction'.
......@@ -195,6 +244,10 @@ class TreeObsForRailEnv(ObservationBuilder):
the transitions. The order is:
[data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']
Each branch data is organized as:
[root node information] +
[recursive branch data from 'left'] +
......@@ -410,3 +463,51 @@ class TreeObsForRailEnv(ObservationBuilder):
num_features_per_node,
prompt=prompt_[children],
current_depth=current_depth+1)
class GlobalObsForRailEnv(ObservationBuilder):
    """
    Gives a global observation of the entire rail environment.

    The observation returned by `get` is composed of the following elements:

        - transition map array with dimensions (env.height, env.width, 16),
          assuming 16 bits encoding of transitions.

        - An array of shape (4, env.height, env.width) whose four 2D channels
          contain, in this order:
              [0] the position of the given agent,
              [1] the position of its target,
              [2] the targets of the other agents,
              [3] the positions of the other agents.

        - A 4 elements array with one-hot encoding of the direction.
    """

    def __init__(self):
        super(GlobalObsForRailEnv, self).__init__()

    def reset(self):
        """Rebuild the cached bit-wise encoding of the rail grid.

        Each cell's transition value is formatted as a 16-character binary
        string (most significant bit first) and stored as 16 zeros/ones in
        `self.rail_obs[i, j]`.
        """
        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
        for i in range(self.rail_obs.shape[0]):
            for j in range(self.rail_obs.shape[1]):
                cell_bits = f'{self.env.rail.get_transitions((i, j)):016b}'
                self.rail_obs[i, j] = np.array(list(cell_bits)).astype(int)

        # NOTE: per-target counting is currently disabled.
        # self.targets = np.zeros(self.env.height, self.env.width)
        # for target_pos in self.env.agents_target:
        #     self.targets[target_pos] += 1

    def get(self, handle):
        """Build the global observation for the agent identified by `handle`.

        Params
        ======
            handle: index of the agent in `env.agents_position`,
                `env.agents_target` and `env.agents_direction`.

        Returns
        =======
            A tuple `(rail_obs, obs_agents_targets_pos, direction)` as
            described in the class docstring. `rail_obs` is the array cached
            by `reset` (not a copy).
        """
        obs_agents_targets_pos = np.zeros((4, self.env.height, self.env.width))

        # Coordinates are wrapped in tuple() so that each one addresses a
        # single cell: indexing a 2D array with a *list* [row, col] would
        # instead select (and increment) two whole rows.
        agent_pos = tuple(self.env.agents_position[handle])
        obs_agents_targets_pos[0][agent_pos] += 1
        for i in range(len(self.env.agents_position)):
            if i != handle:
                other_pos = tuple(self.env.agents_position[i])
                obs_agents_targets_pos[3][other_pos] += 1

        agent_target_pos = tuple(self.env.agents_target[handle])
        obs_agents_targets_pos[1][agent_target_pos] += 1
        for i in range(len(self.env.agents_target)):
            if i != handle:
                other_target = tuple(self.env.agents_target[i])
                obs_agents_targets_pos[2][other_target] += 1

        # One-hot encoding of the agent's current direction.
        direction = np.zeros(4)
        direction[self.env.agents_direction[handle]] = 1

        return self.rail_obs, obs_agents_targets_pos, direction
......@@ -406,6 +406,7 @@ class RenderTool(object):
plt.clf()
# if oFigure is None:
# oFigure = plt.figure()
def drawTrans(oFrom, oTo, sColor="gray"):
plt.plot(
[oFrom[0], oTo[0]], # x
......@@ -554,8 +555,6 @@ class RenderTool(object):
plt.pause(0.00001)
return
def _draw_square(self, center, size, color):
x0 = center[0]-size/2
x1 = center[0]+size/2
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from flatland.core.env_observation_builder import GlobalObsForRailEnv
from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator
from flatland.utils.rendertools import *
"""Tests for `flatland` package."""
def test_global_obs():
    """Check the observation produced by GlobalObsForRailEnv on a small map.

    The test builds a hand-made 7x10 rail grid, resets a single-agent RailEnv
    on it and verifies that:
      - the transition-map part of the observation has shape (7, 10, 16);
      - decoding the 16 bits of every cell gives back the original grid;
      - the agent-position channel marks a cell that contains rail.

    `env.reset()` is assumed to return the observations indexed by agent
    handle, so `global_obs[0]` is the observation of agent 0.
    """
    # We instantiate a very simple rail network on a 7x10 grid:
    #        |
    #        |
    #        |
    # _ _ _ /_\ _ _  _  _ _ _
    #               \ /
    #                |
    #                |
    #                |

    cells = [int('0000000000000000', 2),  # empty cell - Case 0
             int('1000000000100000', 2),  # Case 1 - straight
             int('1001001000100000', 2),  # Case 2 - simple switch
             int('1000010000100001', 2),  # Case 3 - diamond crossing
             int('1001011000100001', 2),  # Case 4 - single slip switch
             int('1100110000110011', 2),  # Case 5 - double slip switch
             int('0101001000000010', 2),  # Case 6 - symmetrical switch
             int('0010000000000000', 2)]  # Case 7 - dead end

    transitions = Grid4Transitions([])
    empty = cells[0]

    dead_end_from_south = cells[7]
    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)

    vertical_straight = cells[1]
    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)

    double_switch_south_horizontal_straight = horizontal_straight + cells[6]
    double_switch_north_horizontal_straight = transitions.rotate_transition(
        double_switch_south_horizontal_straight, 180)

    rail_map = np.array(
        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
        [[dead_end_from_east] + [horizontal_straight] * 2 +
         [double_switch_north_horizontal_straight] +
         [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
         [horizontal_straight] * 2 + [dead_end_from_west]] +
        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)

    rail = GridTransitionMap(width=rail_map.shape[1],
                             height=rail_map.shape[0], transitions=transitions)
    rail.grid = rail_map
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_GridTransitionMap_generator(rail),
                  number_of_agents=1,
                  obs_builder_object=GlobalObsForRailEnv())

    global_obs = env.reset()

    # env_renderer = RenderTool(env)
    # env_renderer.renderEnv(show=True)

    # global_obs.reset()
    rail_obs = global_obs[0][0]
    assert rail_obs.shape == rail_map.shape + (16,)

    # Rebuild every cell value from its 16 observed bits (MSB first).
    rail_map_recons = np.zeros_like(rail_map)
    for i in range(rail_obs.shape[0]):
        for j in range(rail_obs.shape[1]):
            rail_map_recons[i, j] = int(
                ''.join(rail_obs[i, j].astype(int).astype(str)), 2)

    # Element-wise comparison. Comparing `.all()` of both arrays would only
    # compare two booleans (both False here, since the map has empty cells)
    # and therefore could never fail.
    assert np.array_equal(rail_map_recons, rail_map)

    # If this assertion is wrong, it means that the observation returned
    # places the agent on an empty cell
    assert np.sum(rail_map * global_obs[0][1][0]) > 0
# NOTE(review): this call runs the test whenever the module is imported;
# presumably a test runner that imports this file will execute the test a
# second time on top of its own collected run. Consider guarding it with
# `if __name__ == "__main__":` — TODO confirm nothing relies on the import-time run.
test_global_obs()
......@@ -5,6 +5,7 @@ import numpy as np
from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator
from flatland.core.transitions import Grid4Transitions
from flatland.core.transition_map import GridTransitionMap
from flatland.core.env_observation_builder import GlobalObsForRailEnv
"""Tests for `flatland` package."""
......@@ -49,7 +50,9 @@ def test_rail_environment_single_agent():
rail_env = RailEnv(width=3,
height=3,
rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=1)
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
for _ in range(200):
_ = rail_env.reset()
......@@ -124,7 +127,8 @@ def test_dead_end():
rail_env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=1)
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
def check_consistency(rail_env):
# We run step to check that trains do not move anymore
......@@ -147,14 +151,14 @@ def test_dead_end():
# We try the configuration in the 4 directions:
rail_env.reset()
rail_env.agents_target[0] = [0, 0]
rail_env.agents_position[0] = [0, 2]
rail_env.agents_target[0] = (0, 0)
rail_env.agents_position[0] = (0, 2)
rail_env.agents_direction[0] = 1
check_consistency(rail_env)
rail_env.reset()
rail_env.agents_target[0] = [0, 4]
rail_env.agents_position[0] = [0, 2]
rail_env.agents_target[0] = (0, 4)
rail_env.agents_position[0] = (0, 2)
rail_env.agents_direction[0] = 3
check_consistency(rail_env)
......@@ -173,16 +177,17 @@ def test_dead_end():
rail_env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=1)
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
rail_env.reset()
rail_env.agents_target[0] = [0, 0]
rail_env.agents_position[0] = [2, 0]
rail_env.agents_target[0] = (0, 0)
rail_env.agents_position[0] = (2, 0)
rail_env.agents_direction[0] = 2
check_consistency(rail_env)
rail_env.reset()
rail_env.agents_target[0] = [4, 0]
rail_env.agents_position[0] = [2, 0]
rail_env.agents_target[0] = (4, 0)
rail_env.agents_position[0] = (2, 0)
rail_env.agents_direction[0] = 0
check_consistency(rail_env)
......@@ -11,6 +11,7 @@ import os
import matplotlib.pyplot as plt
import flatland.utils.rendertools as rt
from flatland.core.env_observation_builder import GlobalObsForRailEnv
def checkFrozenImage(sFileImage):
......@@ -36,7 +37,10 @@ def checkFrozenImage(sFileImage):
def test_render_env():
# random.seed(100)
np.random.seed(100)
oEnv = RailEnv(width=10, height=10, rail_generator=random_rail_generator(), number_of_agents=2)
oEnv = RailEnv(width=10, height=10,
rail_generator=random_rail_generator(),
number_of_agents=2,
obs_builder_object=GlobalObsForRailEnv())
oEnv.reset()
oRT = rt.RenderTool(oEnv)
plt.figure(figsize=(10, 10))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment