Skip to content
Snippets Groups Projects
Commit e2c3a494 authored by spiglerg's avatar spiglerg
Browse files

new generators that take agents_handles as extra input and return the agent's...

new generators that take agents_handles as extra input and return the agent's initial position, direction and target
parent f7501eb5
No related branches found
No related tags found
No related merge requests found
......@@ -54,8 +54,6 @@ env.agents_direction[0] = 1
# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
env.obs_builder.reset()
"""
"""
# INFINITE-LOOP TEST
specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
......@@ -84,6 +82,7 @@ env.obs_builder.reset()
env = RailEnv(width=7,
height=7,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
# rail_generator=complex_rail_generator(nr_start_goal=2),
number_of_agents=2)
# Print the distance map of each cell to the target of the first agent
......
......@@ -5,7 +5,7 @@ Definition of the RailEnv environment and related level-generation functions.
Generator functions are functions that take width, height, agents_handles and num_resets as arguments
and return a GridTransitionMap object together with the agents' initial positions, directions and targets.
"""
# import numpy as np
import numpy as np
# from flatland.core.env import Environment
# from flatland.core.env_observation_builder import TreeObsForRailEnv
......@@ -271,3 +271,93 @@ def connect_rail(rail_trans, rail_array, start, end):
def distance_on_rail(pos1, pos2):
    """Return the Manhattan (L1) distance between two grid positions.

    Only the first two components of each position are used, i.e. both
    arguments are indexed as `pos[0]` and `pos[1]`.
    """
    first_axis_gap = abs(pos1[0] - pos2[0])
    second_axis_gap = abs(pos1[1] - pos2[1])
    return first_axis_gap + second_axis_gap
def get_new_position(position, movement):
    """Return the cell reached from `position` by one step in `movement`.

    `position` is indexed as (row, column).  `movement` is a direction index:
    0 = NORTH (row - 1), 1 = EAST (column + 1), 2 = SOUTH (row + 1),
    3 = WEST (column - 1).  Any other value yields None.
    """
    if movement == 0:  # NORTH
        return (position[0] - 1, position[1])
    if movement == 1:  # EAST
        return (position[0], position[1] + 1)
    if movement == 2:  # SOUTH
        return (position[0] + 1, position[1])
    if movement == 3:  # WEST
        return (position[0], position[1] - 1)
    # Unknown direction index: no new position can be computed.
    return None
def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
    """
    Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).

    Start cells and target cells are drawn independently, with replacement,
    from the cells whose transition bitmap is non-zero; two agents may
    therefore share a cell, and a target may coincide with a start cell.
    A starting direction is accepted for an agent only if at least one move
    allowed from (start cell, direction) leads to a state from which the
    target is reachable.  If some agent has no such direction, the whole
    placement (all positions and targets) is drawn again.

    Parameters
    ----------
    rail : GridTransitionMap
        Object exposing `height`, `width` and `get_transitions()`.  The latter
        is called with `(row, col)` (result compared against 0 and bit-counted)
        and with `(row, col, direction)` (result indexed by exit direction 0..3).
    num_agents : int
        Number of agents to place.

    Returns
    -------
    (agents_position, agents_direction, agents_target)
        Three lists with one entry per agent: (row, col) start cells,
        direction indices in 0..3, and (row, col) target cells.

    Raises
    ------
    ValueError
        If agents are requested but the rail has no cell with transitions.

    NOTE: the sampling loop has no retry limit; it does not terminate if no
    placement with a reachable target exists on the given rail.

    TODO: add extensive documentation, as users may need this function to simplify their custom level generators.
    """
    def _path_exists(rail, start, direction, end):
        # Depth-first search (LIFO stack) over (cell, heading) states; returns
        # True as soon as a state located on the `end' cell is popped.
        visited = set()
        stack = [(start, direction)]
        while stack:
            node = stack.pop()
            cell, heading = node
            if cell[0] == end[0] and cell[1] == end[1]:
                return True
            if node in visited:
                continue
            visited.add(node)

            moves = rail.get_transitions((cell[0], cell[1], heading))
            for move_index in range(4):
                if moves[move_index]:
                    # After a move, the heading is the direction of that move.
                    stack.append((get_new_position(cell, move_index), move_index))

            # If cell is a dead-end (exactly one bit set in its transition
            # bitmap), also explore the same cell with reversed orientation.
            nbits = 0
            tmp = rail.get_transitions((cell[0], cell[1]))
            while tmp > 0:
                nbits += (tmp & 1)
                tmp = tmp >> 1
            if nbits == 1:
                stack.append((cell, (heading + 2) % 4))
        return False

    valid_positions = [
        (r, c)
        for r in range(rail.height)
        for c in range(rail.width)
        if rail.get_transitions((r, c)) > 0
    ]
    if num_agents > 0 and not valid_positions:
        # np.random.choice would raise a ValueError here as well, but with a
        # message that does not point at the actual problem.
        raise ValueError("Cannot place agents: the rail has no cell with transitions.")

    while True:
        agents_position = [
            valid_positions[i] for i in
            np.random.choice(len(valid_positions), num_agents)]
        agents_target = [
            valid_positions[i] for i in
            np.random.choice(len(valid_positions), num_agents)]

        # agents_direction must be a direction for which a solution is
        # guaranteed.
        agents_direction = [0] * num_agents

        for i, (position, target) in enumerate(zip(agents_position, agents_target)):
            valid_starting_directions = []
            for direction in range(4):
                moves = rail.get_transitions((position[0], position[1], direction))
                for move_index in range(4):
                    # The state reached by the first step is (new cell,
                    # move_index): the search must continue with the heading
                    # of the move just taken, not the initial heading.
                    if moves[move_index] and _path_exists(
                            rail, get_new_position(position, move_index), move_index, target):
                        valid_starting_directions.append(direction)
                        break  # one feasible first move is enough for this direction

            if not valid_starting_directions:
                # This agent cannot reach its target: abandon the placement
                # right away and draw a new one.
                break
            agents_direction[i] = valid_starting_directions[
                np.random.choice(len(valid_starting_directions), 1)[0]]
        else:
            # Loop finished without `break': every agent has a valid direction.
            return agents_position, agents_direction, agents_target
......@@ -5,7 +5,7 @@ import numpy as np
from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.env_utils import distance_on_rail, connect_rail
from flatland.envs.env_utils import distance_on_rail, connect_rail, get_rnd_agents_pos_tgt_dir_on_rail
def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
......@@ -23,7 +23,7 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
The matrix with the correct 16-bit bitmaps for each cell.
"""
def generator(width, height, num_resets=0):
def generator(width, height, agents_handles, num_resets=0):
rail_trans = RailEnvTransitions()
rail_array = np.zeros(shape=(width, height), dtype=np.uint16)
......@@ -106,8 +106,12 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
return_rail = GridTransitionMap(width=width, height=height, transitions=rail_trans)
return_rail.grid = rail_array
# TODO: return start_goal
return return_rail
# TODO: return agents_position, agents_direction and agents_target!
# NOTE: the initial direction must be such that the target can be reached.
# See env_utils.get_rnd_agents_pos_tgt_dir_on_rail() for hints, if required.
return return_rail, [], [], []
return generator
......@@ -132,7 +136,7 @@ def rail_from_manual_specifications_generator(rail_spec):
the matrix of correct 16-bit bitmaps for each cell.
"""
def generator(width, height, num_resets=0):
def generator(width, height, agents_handles, num_resets=0):
t_utils = RailEnvTransitions()
height = len(rail_spec)
......@@ -147,7 +151,11 @@ def rail_from_manual_specifications_generator(rail_spec):
return []
rail.set_transitions((r, c), t_utils.rotate_transition(t_utils.transitions[cell[0]], cell[1]))
return rail
agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
rail,
len(agents_handles))
return rail, agents_position, agents_direction, agents_target
return generator
......@@ -168,8 +176,12 @@ def rail_from_GridTransitionMap_generator(rail_map):
Generator function that always returns the given `rail_map' object.
"""
def generator(width, height, num_resets=0):
return rail_map
def generator(width, height, agents_handles, num_resets=0):
    """Return the enclosing `rail_map' plus a random agent placement on it.

    `width', `height' and `num_resets' are accepted for interface
    compatibility but are not used here; one agent is placed per entry of
    `agents_handles'.  Returns (rail_map, positions, directions, targets).
    """
    placement = get_rnd_agents_pos_tgt_dir_on_rail(rail_map, len(agents_handles))
    start_cells, start_directions, target_cells = placement
    return rail_map, start_cells, start_directions, target_cells
return generator
......@@ -189,7 +201,7 @@ def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames):
Generator function that always returns the given `rail_map' object.
"""
def generator(width, height, num_resets=0):
def generator(width, height, agents_handles, num_resets=0):
t_utils = RailEnvTransitions()
rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils)
rail_map.load_transition_map(list_of_filenames[num_resets % len(list_of_filenames)], override_gridsize=False)
......@@ -197,7 +209,11 @@ def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames):
if rail_map.grid.dtype == np.uint64:
rail_map.transitions = Grid8Transitions()
return rail_map
agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
rail_map,
len(agents_handles))
return rail_map, agents_position, agents_direction, agents_target
return generator
......@@ -243,7 +259,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
The matrix with the correct 16-bit bitmaps for each cell.
"""
def generator(width, height, num_resets=0):
def generator(width, height, agents_handles, num_resets=0):
t_utils = RailEnvTransitions()
transition_probability = cell_type_relative_proportion
......@@ -472,6 +488,11 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
return_rail.grid = tmp_rail
return return_rail
agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
return_rail,
len(agents_handles))
return return_rail, agents_position, agents_direction, agents_target
return generator
......@@ -9,6 +9,7 @@ import numpy as np
from flatland.core.env import Environment
from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.envs.generators import random_rail_generator
from flatland.envs.env_utils import get_new_position
# from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
# from flatland.core.transition_map import GridTransitionMap
......@@ -74,9 +75,10 @@ class RailEnv(Environment):
Parameters
-------
rail_generator : function
The rail_generator function is a function that takes the width and
height of a rail map along with the number of times the env has
been reset, and returns a GridTransitionMap object.
The rail_generator function is a function that takes the width,
height and agents handles of a rail environment, along with the number of times
the env has been reset, and returns a GridTransitionMap object and the lists of
starting positions, initial directions and targets, with one entry per agent handle.
Implemented functions are:
random_rail_generator : generate a random rail of given size
rail_from_GridTransitionMap_generator(rail_map) : generate a rail from
......@@ -130,107 +132,16 @@ class RailEnv(Environment):
def get_agent_handles(self):
    """Return the environment's `agents_handles` attribute (the stored object itself, not a copy)."""
    handles = self.agents_handles
    return handles
def fill_valid_positions(self):
    ''' Rebuild self.valid_positions for the current TransitionMap.

    The list is replaced by a fresh one holding every (row, col) cell, in
    row-major order, whose `self.rail.get_transitions((row, col))` value is
    greater than zero.
    '''
    self.valid_positions = rail_cells = []
    rail_cells.extend(
        (row, col)
        for row in range(self.height)
        for col in range(self.width)
        if self.rail.get_transitions((row, col)) > 0
    )
def check_agent_lists(self):
    ''' Assert that the per-agent lists (handles, positions, directions) each
    contain exactly self.number_of_agents entries.
    (Suggest this is replaced with a single list of Agent objects :)
    '''
    labelled_lists = (
        ("handles", self.agents_handles),
        ("positions", self.agents_position),
        ("directions", self.agents_direction),
    )
    for label, agent_list in labelled_lists:
        assert self.number_of_agents == len(agent_list), "Inconsistent agent list:" + label
def check_agent_locdirpath(self, iAgent):
    ''' Check that agent iAgent has a valid location and direction,
    with a path to its target.

    For every direction 0..3 and every move the rail allows from
    (agent cell, direction), the cell reached by that move is tested with
    self._path_exists against the agent's target.  The heading passed to the
    path search is the direction of the move just taken, which is the same
    convention the path search uses for its own successor states.

    Returns True as soon as one feasible first move is found, otherwise False.
    (Not currently used?)
    '''
    position = self.agents_position[iAgent]
    for direction in range(4):
        moves = self.rail.get_transitions((position[0], position[1], direction))
        for move_index in range(4):
            if not moves[move_index]:
                continue
            new_position = self._new_position(position, move_index)
            # Continue the search with the heading of the move, not with the
            # heading the agent had before moving.
            if self._path_exists(new_position, move_index, self.agents_target[iAgent]):
                return True
    return False
def pick_agent_direction(self, rcPos, rcTarget):
""" Pick and return a valid direction index (0..3) for an agent starting at
row,col rcPos with target rcTarget.
Return None if no path exists.
Picks random direction if more than one exists (uniformly).
def reset(self, regen_rail=True, replace_agents=True):
"""
valid_movements = []
for direction in range(4):
moves = self.rail.get_transitions((*rcPos, direction))
for move_index in range(4):
if moves[move_index]:
valid_movements.append((direction, move_index))
# print("pos", rcPos, "targ", rcTarget, "valid movements", valid_movements)
valid_starting_directions = []
for m in valid_movements:
new_position = self._new_position(rcPos, m[1])
if m[0] not in valid_starting_directions and self._path_exists(new_position, m[0], rcTarget):
valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0:
return None
else:
return valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]]
def add_agent(self, rcPos=None, rcTarget=None, iDir=None):
""" Add a new agent at position rcPos with target rcTarget and
initial direction index iDir.
Should also store this initial position etc as environment "meta-data"
but this does not yet exist.
TODO: replace_agents is ignored at the moment; agents will always be replaced.
"""
self.check_agent_lists()
if rcPos is None:
rcPos = np.random.choice(len(self.valid_positions))
iAgent = self.number_of_agents
if iDir is None:
iDir = self.pick_agent_direction(rcPos, rcTarget)
if iDir is None:
print("Error picking agent direction at pos:", rcPos)
return None
self.agents_position.append(tuple(rcPos)) # ensure it's a tuple not a list
self.agents_handles.append(max(self.agents_handles + [-1]) + 1) # max(handles) + 1, starting at 0
self.agents_direction.append(iDir)
self.agents_target.append(rcPos) # set the target to the origin initially
self.number_of_agents += 1
self.check_agent_lists()
return iAgent
def reset(self, regen_rail=True, replace_agents=True):
if regen_rail or self.rail is None:
# TODO: Import not only rail information but also start and goal positions
self.rail = self.rail_generator(self.width, self.height, self.num_resets)
self.fill_valid_positions()
self.rail, self.agents_position, self.agents_direction, self.agents_target = self.rail_generator(
self.width,
self.height,
self.agents_handles,
self.num_resets)
self.num_resets += 1
......@@ -238,37 +149,6 @@ class RailEnv(Environment):
for handle in self.agents_handles:
self.dones[handle] = False
# Use a TreeObsForRailEnv to compute distance maps to each agent's target, to sample initial
# agent's orientations that allow a valid solution.
# TODO: Possibility to fill valid positions from list of goals and start
self.fill_valid_positions()
if replace_agents:
re_generate = True
while re_generate:
# self.agents_position = random.sample(valid_positions,
# self.number_of_agents)
self.agents_position = [
self.valid_positions[i] for i in
np.random.choice(len(self.valid_positions), self.number_of_agents)]
self.agents_target = [
self.valid_positions[i] for i in
np.random.choice(len(self.valid_positions), self.number_of_agents)]
# agents_direction must be a direction for which a solution is
# guaranteed.
self.agents_direction = [0] * self.number_of_agents
re_generate = False
for i in range(self.number_of_agents):
direction = self.pick_agent_direction(self.agents_position[i], self.agents_target[i])
if direction is None:
re_generate = True
break
else:
self.agents_direction[i] = direction
# Reset the state of the observation builder with the new environment
self.obs_builder.reset()
......@@ -342,7 +222,7 @@ class RailEnv(Environment):
movement = np.argmax(possible_transitions)
transition_isValid = True
new_position = self._new_position(pos, movement)
new_position = get_new_position(pos, movement)
# Is it a legal move? 1) transition allows the movement in the
# cell, 2) the new cell is not empty (case 0), 3) the cell is
# free, i.e., no agent is currently in that cell
......@@ -401,45 +281,6 @@ class RailEnv(Environment):
self.actions = [0] * self.number_of_agents
return self._get_observations(), self.rewards_dict, self.dones, {}
def _new_position(self, position, movement):
if movement == 0: # NORTH
return (position[0] - 1, position[1])
elif movement == 1: # EAST
return (position[0], position[1] + 1)
elif movement == 2: # SOUTH
return (position[0] + 1, position[1])
elif movement == 3: # WEST
return (position[0], position[1] - 1)
def _path_exists(self, start, direction, end):
# BFS - Check if a path exists between the 2 nodes
visited = set()
stack = [(start, direction)]
while stack:
node = stack.pop()
if node[0][0] == end[0] and node[0][1] == end[1]:
return 1
if node not in visited:
visited.add(node)
moves = self.rail.get_transitions((node[0][0], node[0][1], node[1]))
for move_index in range(4):
if moves[move_index]:
stack.append((self._new_position(node[0], move_index),
move_index))
# If cell is a dead-end, append previous node with reversed
# orientation!
nbits = 0
tmp = self.rail.get_transitions((node[0][0], node[0][1]))
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
if nbits == 1:
stack.append((node[0], (node[1] + 2) % 4))
return 0
def _get_observations(self):
self.obs_dict = {}
for handle in self.agents_handles:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment