Commit 236eba12 authored by spiglerg

initial (working) tree observations, railenv+generators moved to envs/rail_env.py, fixes

parent 5064de8e
@@ -2,20 +2,22 @@ import random
import numpy as np
import matplotlib.pyplot as plt
from flatland.core.env import RailEnv
from flatland.utils.rail_env_generator import *
from flatland.envs.rail_env import *
from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.utils.rendertools import *
random.seed(1)
np.random.seed(1)
# Example: generate a random rail
env = RailEnv(width=20, height=20, rail_generator=generate_random_rail, number_of_agents=10)
env = RailEnv(width=20, height=20, rail_generator=random_rail_generator, number_of_agents=10)
env.reset()
env_renderer = RenderTool(env)
env_renderer.renderEnv(show=True)
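The generated rail can also be inspected directly: env.rail holds the GridTransitionMap returned by the generator. A small sketch using the accessors that appear later in this commit (the printed values depend on the random seed):

# 16-bit transition word for cell (0, 0); 0 means the cell carries no rail
print(env.rail.get_transitions((0, 0)))
# the four allowed moves (N, E, S, W) when in cell (0, 0) facing East (direction 1)
print(env.rail.get_transitions((0, 0, 1)))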
# Example: generate a rail from a manual specification,
# a map of tuples (cell_type, rotation)
specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
@@ -23,13 +25,12 @@ specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
env = RailEnv(width=6,
height=2,
rail_generator=generate_rail_from_manual_specifications(specs),
number_of_agents=1)
rail_generator=rail_from_manual_specifications_generator(specs),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=1))
handle = env.get_agent_handles()
obs = env.reset()
env.agents_position = [[1, 4]]
env.agents_target = [[1, 1]]
env.agents_direction = [1]
@@ -37,13 +38,15 @@ env.agents_direction = [1]
env.obs_builder.reset()
# TODO: delete next line
#print(env.obs_builder.distance_map[0,:,:])
#print(env.obs_builder.max_dist)
#for i in range(4):
# print(env.obs_builder.distance_map[0, :, :, i])
obs, all_rewards, done, _ = env.step({0:0})
env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
env_renderer = RenderTool(env)
env_renderer.renderEnv(show=True)
print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \
(turnleft+move, move to front, turnright+move)")
for step in range(100):
......
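The loop body is truncated in this diff. A hypothetical reconstruction of the manual-control loop, matching the prompt string printed above (the input parsing is an illustrative assumption, not the committed code):

for step in range(100):
    cmd = input(">> ")
    if cmd == 'q':
        break
    elif cmd == 's':
        obs, all_rewards, done, _ = env.step({0: 0})  # advance one step with a no-op
    else:
        # "[agent id] [1-2-3 action]", e.g. "0 2" moves agent 0 straight ahead
        agent_id, action = (int(x) for x in cmd.split())
        obs, all_rewards, done, _ = env.step({agent_id: action})
    env_renderer.renderEnv(show=True)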
@@ -3,10 +3,6 @@ The env module defines the base Environment class.
The base Environment class is adapted from rllib.env.MultiAgentEnv
(https://github.com/ray-project/ray).
"""
import random
from .env_observation_builder import TreeObsForRailEnv
from flatland.utils.rail_env_generator import generate_random_rail
class Environment:
@@ -94,327 +90,3 @@ class Environment:
function.
"""
raise NotImplementedError()
class RailEnv:
"""
RailEnv environment class.
RailEnv is an environment inspired by a (simplified version of) a rail
network, in which agents (trains) have to navigate to their target
locations in the shortest time possible, while at the same time cooperating
to avoid bottlenecks.
The valid actions in the environment are:
0: do nothing
1: turn left and move to the next cell
2: move to the next cell in front of the agent
3: turn right and move to the next cell
Moving forward in a dead-end cell makes the agent turn 180 degrees and step
to the cell it came from.
The actions of the agents are executed in order of their handle to prevent
deadlocks and to allow them to learn relative priorities.
TODO: WRITE ABOUT THE REWARD FUNCTION, and possibly allow for alpha and
beta to be passed as parameters to __init__().
"""
def __init__(self,
width,
height,
rail_generator=generate_random_rail,
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2)):
"""
Environment init.
Parameters
----------
rail_generator : function
The rail_generator function is a function that takes the width and
height of a rail map along with the number of times the env has
been reset, and returns a GridTransitionMap object.
Implemented functions are:
generate_random_rail : generate a random rail of given size
TODO: generate_rail_from_saved_list ---
width : int
The width of the rail map. Potentially in the future,
a range of widths to sample from.
height : int
The height of the rail map. Potentially in the future,
a range of heights to sample from.
number_of_agents : int
Number of agents to spawn on the map. Potentially in the future,
a range of number of agents to sample from.
obs_builder_object: ObservationBuilder object
ObservationBuilder-derived object that builds observation
vectors for each agent.
"""
self.rail_generator = rail_generator
self.num_resets = 0
self.rail = None
self.width = width
self.height = height
self.number_of_agents = number_of_agents
self.obs_builder = obs_builder_object
self.obs_builder.set_env(self)
self.actions = [0]*self.number_of_agents
self.rewards = [0]*self.number_of_agents
self.done = False
self.agents_position = []
self.agents_target = []
self.agents_direction = []
self.dones = {"__all__": False}
self.obs_dict = {}
self.rewards_dict = {}
self.agents_handles = list(range(self.number_of_agents))
def get_agent_handles(self):
return self.agents_handles
def reset(self):
self.rail = self.rail_generator(self.width, self.height, self.num_resets)
self.num_resets += 1
self.dones = {"__all__": False}
for handle in self.agents_handles:
self.dones[handle] = False
re_generate = True
while re_generate:
valid_positions = []
for r in range(self.height):
for c in range(self.width):
if self.rail.get_transitions((r, c)) > 0:
valid_positions.append((r, c))
self.agents_position = random.sample(valid_positions,
self.number_of_agents)
self.agents_target = random.sample(valid_positions,
self.number_of_agents)
# agents_direction must be a direction for which a solution is
# guaranteed.
self.agents_direction = [0]*self.number_of_agents
re_generate = False
for i in range(self.number_of_agents):
valid_movements = []
for direction in range(4):
position = self.agents_position[i]
moves = self.rail.get_transitions(
(position[0], position[1], direction))
for move_index in range(4):
if moves[move_index]:
valid_movements.append((direction, move_index))
valid_starting_directions = []
for m in valid_movements:
new_position = self._new_position(self.agents_position[i],
m[1])
if m[0] not in valid_starting_directions and \
self._path_exists(new_position, m[0],
self.agents_target[i]):
valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0:
re_generate = True
else:
self.agents_direction[i] = random.sample(
valid_starting_directions, 1)[0]
# Reset the state of the observation builder with the new environment
self.obs_builder.reset()
# Return the new observation vectors for each agent
return self._get_observations()
def step(self, action_dict):
alpha = 1.0
beta = 1.0
invalid_action_penalty = -2
step_penalty = -1 * alpha
global_reward = 1 * beta
# Reset the step rewards
self.rewards_dict = {}
for handle in self.agents_handles:
self.rewards_dict[handle] = 0
if self.dones["__all__"]:
return self._get_observations(), self.rewards_dict, self.dones, {}
for i in range(len(self.agents_handles)):
handle = self.agents_handles[i]
if handle not in action_dict:
continue
action = action_dict[handle]
if action < 0 or action > 3:
print('ERROR: illegal action=', action,
'for agent with handle=', handle)
return self._get_observations(), self.rewards_dict, self.dones, {}
if action > 0:
pos = self.agents_position[i]
direction = self.agents_direction[i]
movement = direction
if action == 1:
movement = direction - 1
elif action == 3:
movement = direction + 1
if movement < 0:
movement += 4
if movement >= 4:
movement -= 4
is_deadend = False
if action == 2:
# compute number of possible transitions in the current
# cell
nbits = 0
tmp = self.rail.get_transitions((pos[0], pos[1]))
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
if nbits == 1:
# dead-end; assuming the rail network is consistent,
# this should match the direction the agent has come
# from. But it's better to check in any case.
reverse_direction = 0
if direction == 0:
reverse_direction = 2
elif direction == 1:
reverse_direction = 3
elif direction == 2:
reverse_direction = 0
elif direction == 3:
reverse_direction = 1
valid_transition = self.rail.get_transition(
(pos[0], pos[1], direction),
reverse_direction)
if valid_transition:
direction = reverse_direction
movement = reverse_direction
is_deadend = True
new_position = self._new_position(pos, movement)
# Is it a legal move? 1) transition allows the movement in the
# cell, 2) the new cell is not empty (case 0), 3) the cell is
# free, i.e., no agent is currently in that cell
if new_position[1] >= self.width or\
new_position[0] >= self.height or\
new_position[0] < 0 or new_position[1] < 0:
new_cell_isValid = False
elif self.rail.get_transitions((new_position[0], new_position[1])) > 0:
new_cell_isValid = True
else:
new_cell_isValid = False
transition_isValid = self.rail.get_transition(
(pos[0], pos[1], direction),
movement) or is_deadend
cell_isFree = True
for j in range(self.number_of_agents):
if self.agents_position[j] == new_position:
cell_isFree = False
break
if new_cell_isValid and transition_isValid and cell_isFree:
# move and change direction to face the movement that was
# performed
self.agents_position[i] = new_position
self.agents_direction[i] = movement
else:
# the action was not valid, add penalty
self.rewards_dict[handle] += invalid_action_penalty
# if agent is not in target position, add step penalty
if self.agents_position[i][0] == self.agents_target[i][0] and \
self.agents_position[i][1] == self.agents_target[i][1]:
self.dones[handle] = True
else:
self.rewards_dict[handle] += step_penalty
# Check for end of episode + add global reward to all rewards!
num_agents_in_target_position = 0
for i in range(self.number_of_agents):
if self.agents_position[i][0] == self.agents_target[i][0] and \
self.agents_position[i][1] == self.agents_target[i][1]:
num_agents_in_target_position += 1
if num_agents_in_target_position == self.number_of_agents:
self.dones["__all__"] = True
self.rewards_dict = {handle: r + global_reward for handle, r in self.rewards_dict.items()}
# Reset the step actions (in case some agent doesn't 'register_action'
# on the next step)
self.actions = [0]*self.number_of_agents
return self._get_observations(), self.rewards_dict, self.dones, {}
def _new_position(self, position, movement):
if movement == 0: # NORTH
return (position[0]-1, position[1])
elif movement == 1: # EAST
return (position[0], position[1] + 1)
elif movement == 2: # SOUTH
return (position[0]+1, position[1])
elif movement == 3: # WEST
return (position[0], position[1] - 1)
def _path_exists(self, start, direction, end):
# Depth-first search (explicit stack): check whether a path exists between the 2 nodes
visited = set()
stack = [(start, direction)]
while stack:
node = stack.pop()
if node[0][0] == end[0] and node[0][1] == end[1]:
return 1
if node not in visited:
visited.add(node)
moves = self.rail.get_transitions((node[0][0], node[0][1], node[1]))
for move_index in range(4):
if moves[move_index]:
stack.append((self._new_position(node[0], move_index),
move_index))
# If cell is a dead-end, append previous node with reversed
# orientation!
nbits = 0
tmp = self.rail.get_transitions((node[0][0], node[0][1]))
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
if nbits == 1:
stack.append((node[0], (node[1] + 2) % 4))
return 0
def _get_observations(self):
self.obs_dict = {}
for handle in self.agents_handles:
self.obs_dict[handle] = self.obs_builder.get(handle)
return self.obs_dict
def render(self):
# TODO:
pass
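The class docstring above leaves the reward function as a TODO. As a stopgap, here is a minimal sketch of the per-agent reward that step() actually computes (alpha = beta = 1.0 are hard-coded locals in step(); the function name is illustrative):

def sketch_step_reward(action_was_valid, on_target, all_on_target, alpha=1.0, beta=1.0):
    # Mirrors the bookkeeping in RailEnv.step() above for a single agent.
    reward = 0
    if not action_was_valid:
        reward += -2            # invalid_action_penalty
    if not on_target:
        reward += -1 * alpha    # step_penalty; skipped once the agent is on its target
    if all_on_target:
        reward += 1 * beta      # global_reward, added to every agent at episode end
    return reward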
"""
The rail_env_generator module provides utilities to generate env
bitmaps for the RailEnv environment.
Definition of the RailEnv environment and related level-generation functions.
Generator functions are functions that take width, height and num_resets as arguments and return
a GridTransitionMap object.
"""
import random
import numpy as np
from flatland.core.env import Environment
from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.core.transitions import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
def generate_rail_from_manual_specifications(rail_spec):
def rail_from_manual_specifications_generator(rail_spec):
"""
Utility to convert a rail given by manual specification as a map of tuples
(cell_type, rotation), to a transition map with the correct 16-bit
@@ -49,7 +54,7 @@ def generate_rail_from_manual_specifications(rail_spec):
return generator
def generate_rail_from_GridTransitionMap(rail_map):
def rail_from_GridTransitionMap_generator(rail_map):
"""
Utility to convert a rail given by a GridTransitionMap, with the correct
16-bit transition specifications, into a rail generator.
@@ -79,7 +84,7 @@ def generate_rail_from_list_of_manual_specifications(list_of_specifications)
"""
def generate_random_rail(width, height, num_resets=0):
def random_rail_generator(width, height, num_resets=0):
"""
Dummy random level generator:
- fill in cells at random in [width-2, height-2]
@@ -339,3 +344,333 @@ def generate_random_rail(width, height, num_resets=0):
return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
return_rail.grid = tmp_rail
return return_rail
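Per the module docstring, any callable with the signature (width, height, num_resets) -> GridTransitionMap can be passed as rail_generator. A minimal template reusing the GridTransitionMap/RailEnvTransitions API from random_rail_generator above (the grid dtype is an assumption):

def empty_rail_generator(width, height, num_resets=0):
    # Deliberately trivial: a grid with no rail at all. Useful only as a
    # starting point, since RailEnv.reset() needs enough rail cells to place
    # agents and targets; a real generator must lay down track.
    t_utils = RailEnvTransitions()
    rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
    rail.grid = np.zeros((height, width), dtype=np.uint16)  # 16-bit transition words
    return rail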
class RailEnv(Environment):
"""
RailEnv environment class.
RailEnv is an environment inspired by a (simplified version of) a rail
network, in which agents (trains) have to navigate to their target
locations in the shortest time possible, while at the same time cooperating
to avoid bottlenecks.
The valid actions in the environment are:
0: do nothing
1: turn left and move to the next cell
2: move to the next cell in front of the agent
3: turn right and move to the next cell
Moving forward in a dead-end cell makes the agent turn 180 degrees and step
to the cell it came from.
The actions of the agents are executed in order of their handle to prevent
deadlocks and to allow them to learn relative priorities.
TODO: WRITE ABOUT THE REWARD FUNCTION, and possibly allow for alpha and
beta to be passed as parameters to __init__().
"""
def __init__(self,
width,
height,
rail_generator=random_rail_generator,
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2)):
"""
Environment init.
Parameters
----------
rail_generator : function
The rail_generator function is a function that takes the width and
height of a rail map along with the number of times the env has
been reset, and returns a GridTransitionMap object.
Implemented functions are:
random_rail_generator : generate a random rail of given size
rail_from_GridTransitionMap_generator(rail_map) : generate a rail from
a GridTransitionMap object
rail_from_manual_specifications_generator(rail_spec) : generate a rail from
a rail specifications array
TODO: generate_rail_from_saved_list or from list of ndarray bitmaps ---
width : int
The width of the rail map. Potentially in the future,
a range of widths to sample from.
height : int
The height of the rail map. Potentially in the future,
a range of heights to sample from.
number_of_agents : int
Number of agents to spawn on the map. Potentially in the future,
a range of number of agents to sample from.
obs_builder_object: ObservationBuilder object
ObservationBuilder-derived object that builds observation
vectors for each agent.
"""
self.rail_generator = rail_generator
self.rail = None
self.width = width
self.height = height
self.number_of_agents = number_of_agents
self.obs_builder = obs_builder_object
self.obs_builder._set_env(self)
self.actions = [0]*self.number_of_agents
self.rewards = [0]*self.number_of_agents
self.done = False
self.dones = {"__all__": False}
self.obs_dict = {}
self.rewards_dict = {}
self.agents_handles = list(range(self.number_of_agents))
# self.agents_position = []
# self.agents_target = []
# self.agents_direction = []
self.num_resets = 0
self.reset()
self.num_resets = 0
def get_agent_handles(self):
return self.agents_handles
def reset(self):
self.rail = self.rail_generator(self.width, self.height, self.num_resets)
self.num_resets += 1
self.dones = {"__all__": False}
for handle in self.agents_handles:
self.dones[handle] = False
re_generate = True
while re_generate:
valid_positions = []
for r in range(self.height):
for c in range(self.width):
if self.rail.get_transitions((r, c)) > 0:
valid_positions.append((r, c))
self.agents_position = random.sample(valid_positions,
self.number_of_agents)
self.agents_target = random.sample(valid_positions,
self.number_of_agents)
# agents_direction must be a direction for which a solution is
# guaranteed.
self.agents_direction = [0]*self.number_of_agents
re_generate = False
for i in range(self.number_of_agents):
valid_movements = []
for direction in range(4):
position = self.agents_position[i]
moves = self.rail.get_transitions(
(position[0], position[1], direction))
for move_index in range(4):
if moves[move_index]:
valid_movements.append((direction, move_index))
valid_starting_directions = []
for m in valid_movements:
new_position = self._new_position(self.agents_position[i],
m[1])
if m[0] not in valid_starting_directions and \
self._path_exists(new_position, m[0],
self.agents_target[i]):
valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0:
re_generate = True
else:
self.agents_direction[i] = random.sample(
valid_starting_directions, 1)[0]
# Reset the state of the observation builder with the new environment
self.obs_builder.reset()
# Return the new observation vectors for each agent
return self._get_observations()
def step(self, action_dict):
alpha = 1.0
beta = 1.0
invalid_action_penalty = -2
step_penalty = -1 * alpha
global_reward = 1 * beta
# Reset the step rewards
self.rewards_dict = {}
for handle in self.agents_handles:
self.rewards_dict[handle] = 0
if self.dones["__all__"]:
return self._get_observations(), self.rewards_dict, self.dones, {}
for i in range(len(self.agents_handles)):
handle = self.agents_handles[i]
if handle not in action_dict:
continue
action = action_dict[handle]
if action < 0 or action > 3:
print('ERROR: illegal action=', action,
'for agent with handle=', handle)
return self._get_observations(), self.rewards_dict, self.dones, {}
if action > 0:
pos = self.agents_position[i]
direction = self.agents_direction[i]
movement = direction
if action == 1:
movement = direction - 1
elif action == 3:
movement = direction + 1
if movement < 0:
movement += 4
if movement >= 4:
movement -= 4
is_deadend = False
if action == 2:
# compute number of possible transitions in the current
# cell
nbits = 0
tmp = self.rail.get_transitions((pos[0], pos[1]))
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
if nbits == 1:
# dead-end; assuming the rail network is consistent,
# this should match the direction the agent has come
# from. But it's better to check in any case.
reverse_direction = 0
if direction == 0:
reverse_direction = 2
elif direction == 1:
reverse_direction = 3
elif direction == 2:
reverse_direction = 0
elif direction == 3:
reverse_direction = 1
valid_transition = self.rail.get_transition(
(pos[0], pos[1], direction),
reverse_direction)
if valid_transition:
direction = reverse_direction
movement = reverse_direction
is_deadend = True
new_position = self._new_position(pos, movement)
# Is it a legal move? 1) transition allows the movement in the
# cell, 2) the new cell is not empty (case 0), 3) the cell is
# free, i.e., no agent is currently in that cell
if new_position[1] >= self.width or\
new_position[0] >= self.height or\
new_position[0] < 0 or new_position[1] < 0:
new_cell_isValid = False
elif self.rail.get_transitions((new_position[0], new_position[1])) > 0:
new_cell_isValid = True
else:
new_cell_isValid = False
transition_isValid = self.rail.get_transition(
(pos[0], pos[1], direction),
movement) or is_deadend
cell_isFree = True
for j in range(self.number_of_agents):
if self.agents_position[j] == new_position:
cell_isFree = False
break
if new_cell_isValid and transition_isValid and cell_isFree:
# move and change direction to face the movement that was
# performed
self.agents_position[i] = new_position
self.agents_direction[i] = movement
else:
# the action was not valid, add penalty
self.rewards_dict[handle] += invalid_action_penalty
# if agent is not in target position, add step penalty
if self.agents_position[i][0] == self.agents_target[i][0] and \
self.agents_position[i][1] == self.agents_target[i][1]:
self.dones[handle] = True
else:
self.rewards_dict[handle] += step_penalty
# Check for end of episode + add global reward to all rewards!
num_agents_in_target_position = 0
for i in range(self.number_of_agents):
if self.agents_position[i][0] == self.agents_target[i][0] and \
self.agents_position[i][1] == self.agents_target[i][1]:
num_agents_in_target_position += 1
if num_agents_in_target_position == self.number_of_agents:
self.dones["__all__"] = True
self.rewards_dict = {handle: r + global_reward for handle, r in self.rewards_dict.items()}
# Reset the step actions (in case some agent doesn't 'register_action'
# on the next step)
self.actions = [0]*self.number_of_agents
return self._get_observations(), self.rewards_dict, self.dones, {}
def _new_position(self, position, movement):
if movement == 0: # NORTH
return (position[0]-1, position[1])
elif movement == 1: # EAST
return (position[0], position[1] + 1)
elif movement == 2: # SOUTH
return (position[0]+1, position[1])
elif movement == 3: # WEST
return (position[0], position[1] - 1)
def _path_exists(self, start, direction, end):
# Depth-first search (explicit stack): check whether a path exists between the 2 nodes
visited = set()
stack = [(start, direction)]
while stack:
node = stack.pop()
if node[0][0] == end[0] and node[0][1] == end[1]:
return 1
if node not in visited:
visited.add(node)
moves = self.rail.get_transitions((node[0][0], node[0][1], node[1]))
for move_index in range(4):
if moves[move_index]:
stack.append((self._new_position(node[0], move_index),
move_index))
# If cell is a dead-end, append previous node with reversed
# orientation!
nbits = 0
tmp = self.rail.get_transitions((node[0][0], node[0][1]))
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
if nbits == 1:
stack.append((node[0], (node[1] + 2) % 4))
return 0
def _get_observations(self):
self.obs_dict = {}
for handle in self.agents_handles:
self.obs_dict[handle] = self.obs_builder.get(handle)
return self.obs_dict
def render(self):
# TODO:
pass
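A minimal usage sketch of the relocated class, following the examples earlier in this commit (action encoding per the docstring: 0 = no-op, 1 = turn left and move, 2 = move forward, 3 = turn right and move; moving forward in a dead end reverses the agent):

import random
import numpy as np
from flatland.envs.rail_env import RailEnv, random_rail_generator

random.seed(1)
np.random.seed(1)
env = RailEnv(width=10, height=10,
              rail_generator=random_rail_generator,
              number_of_agents=2)
obs = env.reset()
for _ in range(20):
    # drive every agent straight ahead; step() returns (obs, rewards, dones, info)
    actions = {handle: 2 for handle in env.get_agent_handles()}
    obs, rewards, dones, _ = env.step(actions)
    if dones["__all__"]:
        break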
@@ -6,6 +6,8 @@ import xarray as xr
import matplotlib.pyplot as plt
# TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv!
class RenderTool(object):
Visit = recordtype("Visit", ["rc", "iDir", "iDepth", "prev"])
......
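Visit above is built with recordtype, a mutable variant of namedtuple from the third-party recordtype package. A small illustrative sketch of how such records behave (the values are made up):

from recordtype import recordtype

Visit = recordtype("Visit", ["rc", "iDir", "iDepth", "prev"])
v0 = Visit(rc=(0, 0), iDir=1, iDepth=0, prev=None)  # root of a path search
v1 = Visit(rc=(0, 1), iDir=1, iDepth=1, prev=v0)    # one cell further East
v1.iDepth += 1                                      # fields are mutable, unlike namedtuple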
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from flatland.core.env import RailEnv
from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator
from flatland.core.transitions import Grid4Transitions
from flatland.core.transition_map import GridTransitionMap
from flatland.utils.rail_env_generator import generate_rail_from_GridTransitionMap
import numpy as np
"""Tests for `flatland` package."""
@@ -47,7 +46,10 @@ def test_rail_environment_single_agent():
rail = GridTransitionMap(width=3, height=3, transitions=transitions)
rail.grid = rail_map
rail_env = RailEnv(width=3, height=3, rail_generator=generate_rail_from_GridTransitionMap(rail), number_of_agents=1)
rail_env = RailEnv(width=3,
height=3,
rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=1)
for _ in range(200):
_ = rail_env.reset()
@@ -121,7 +123,7 @@ def test_dead_end():
rail.grid = rail_map
rail_env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=generate_rail_from_GridTransitionMap(rail),
rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=1)
def check_consistency(rail_env):
@@ -170,7 +172,7 @@ def test_dead_end():
rail.grid = rail_map
rail_env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=generate_rail_from_GridTransitionMap(rail),
rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=1)
rail_env.reset()
......
@@ -4,14 +4,13 @@
Tests for `flatland` package.
"""
from flatland.core.env import RailEnv
from flatland.envs.rail_env import RailEnv, random_rail_generator
import numpy as np
import random
import os
import matplotlib.pyplot as plt
from flatland.utils import rail_env_generator
import flatland.utils.rendertools as rt
@@ -37,7 +36,7 @@ def checkFrozenImage(sFileImage):
def test_render_env():
random.seed(100)
oEnv = RailEnv(width=10, height=10, rail_generator=rail_env_generator.generate_random_rail, number_of_agents=2)
oEnv = RailEnv(width=10, height=10, rail_generator=random_rail_generator, number_of_agents=2)
oEnv.reset()
oRT = rt.RenderTool(oEnv)
plt.figure(figsize=(10, 10))
......
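The rest of the test is truncated here. A hedged sketch of how the rendered figure might be written out and compared against a frozen image using checkFrozenImage from the same file (the savefig call, file name, and show=False behavior are assumptions, not the committed code):

# inside test_render_env(), after the plt.figure(...) call above:
oRT.renderEnv(show=False)             # draw onto the current figure without displaying it
plt.savefig("test_render_env.png")    # hypothetical output path
checkFrozenImage("test_render_env.png")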