Skip to content
Snippets Groups Projects
Commit e2c3a494 authored by spiglerg's avatar spiglerg
Browse files

new generators that take agents_handles as extra input and return the agent's...

new generators that take agents_handles as extra input and return the agent's initial position, direction and target
parent f7501eb5
No related branches found
No related tags found
No related merge requests found
...@@ -54,8 +54,6 @@ env.agents_direction[0] = 1 ...@@ -54,8 +54,6 @@ env.agents_direction[0] = 1
# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too! # TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
env.obs_builder.reset() env.obs_builder.reset()
""" """
""" """
# INFINITE-LOOP TEST # INFINITE-LOOP TEST
specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)], specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
...@@ -84,6 +82,7 @@ env.obs_builder.reset() ...@@ -84,6 +82,7 @@ env.obs_builder.reset()
env = RailEnv(width=7, env = RailEnv(width=7,
height=7, height=7,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
# rail_generator=complex_rail_generator(nr_start_goal=2),
number_of_agents=2) number_of_agents=2)
# Print the distance map of each cell to the target of the first agent # Print the distance map of each cell to the target of the first agent
......
...@@ -5,7 +5,7 @@ Definition of the RailEnv environment and related level-generation functions. ...@@ -5,7 +5,7 @@ Definition of the RailEnv environment and related level-generation functions.
Generator functions are functions that take width, height and num_resets as arguments and return Generator functions are functions that take width, height and num_resets as arguments and return
a GridTransitionMap object. a GridTransitionMap object.
""" """
# import numpy as np import numpy as np
# from flatland.core.env import Environment # from flatland.core.env import Environment
# from flatland.core.env_observation_builder import TreeObsForRailEnv # from flatland.core.env_observation_builder import TreeObsForRailEnv
...@@ -271,3 +271,93 @@ def connect_rail(rail_trans, rail_array, start, end): ...@@ -271,3 +271,93 @@ def connect_rail(rail_trans, rail_array, start, end):
def distance_on_rail(pos1, pos2): def distance_on_rail(pos1, pos2):
return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1]) return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])
def get_new_position(position, movement):
if movement == 0: # NORTH
return (position[0] - 1, position[1])
elif movement == 1: # EAST
return (position[0], position[1] + 1)
elif movement == 2: # SOUTH
return (position[0] + 1, position[1])
elif movement == 3: # WEST
return (position[0], position[1] - 1)
def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
"""
Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).
TODO: add extensive documentation, as users may need this function to simplify their custom level generators.
"""
def _path_exists(rail, start, direction, end):
# BFS - Check if a path exists between the 2 nodes
visited = set()
stack = [(start, direction)]
while stack:
node = stack.pop()
if node[0][0] == end[0] and node[0][1] == end[1]:
return 1
if node not in visited:
visited.add(node)
moves = rail.get_transitions((node[0][0], node[0][1], node[1]))
for move_index in range(4):
if moves[move_index]:
stack.append((get_new_position(node[0], move_index),
move_index))
# If cell is a dead-end, append previous node with reversed
# orientation!
nbits = 0
tmp = rail.get_transitions((node[0][0], node[0][1]))
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
if nbits == 1:
stack.append((node[0], (node[1] + 2) % 4))
return 0
valid_positions = []
for r in range(rail.height):
for c in range(rail.width):
if rail.get_transitions((r, c)) > 0:
valid_positions.append((r, c))
re_generate = True
while re_generate:
agents_position = [
valid_positions[i] for i in
np.random.choice(len(valid_positions), num_agents)]
agents_target = [
valid_positions[i] for i in
np.random.choice(len(valid_positions), num_agents)]
# agents_direction must be a direction for which a solution is
# guaranteed.
agents_direction = [0] * num_agents
re_generate = False
for i in range(num_agents):
valid_movements = []
for direction in range(4):
position = agents_position[i]
moves = rail.get_transitions((position[0], position[1], direction))
for move_index in range(4):
if moves[move_index]:
valid_movements.append((direction, move_index))
valid_starting_directions = []
for m in valid_movements:
new_position = get_new_position(agents_position[i], m[1])
if m[0] not in valid_starting_directions and \
_path_exists(rail, new_position, m[0], agents_target[i]):
valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0:
re_generate = True
else:
agents_direction[i] = valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]]
return agents_position, agents_direction, agents_target
...@@ -5,7 +5,7 @@ import numpy as np ...@@ -5,7 +5,7 @@ import numpy as np
from flatland.core.transitions import Grid8Transitions, RailEnvTransitions from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap from flatland.core.transition_map import GridTransitionMap
from flatland.envs.env_utils import distance_on_rail, connect_rail from flatland.envs.env_utils import distance_on_rail, connect_rail, get_rnd_agents_pos_tgt_dir_on_rail
def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
...@@ -23,7 +23,7 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): ...@@ -23,7 +23,7 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
The matrix with the correct 16-bit bitmaps for each cell. The matrix with the correct 16-bit bitmaps for each cell.
""" """
def generator(width, height, num_resets=0): def generator(width, height, agents_handles, num_resets=0):
rail_trans = RailEnvTransitions() rail_trans = RailEnvTransitions()
rail_array = np.zeros(shape=(width, height), dtype=np.uint16) rail_array = np.zeros(shape=(width, height), dtype=np.uint16)
...@@ -106,8 +106,12 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): ...@@ -106,8 +106,12 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
return_rail = GridTransitionMap(width=width, height=height, transitions=rail_trans) return_rail = GridTransitionMap(width=width, height=height, transitions=rail_trans)
return_rail.grid = rail_array return_rail.grid = rail_array
# TODO: return start_goal
return return_rail # TODO: return agents_position, agents_direction and agents_target!
# NOTE: the initial direction must be such that the target can be reached.
# See env_utils.get_rnd_agents_pos_tgt_dir_on_rail() for hints, if required.
return return_rail, [], [], []
return generator return generator
...@@ -132,7 +136,7 @@ def rail_from_manual_specifications_generator(rail_spec): ...@@ -132,7 +136,7 @@ def rail_from_manual_specifications_generator(rail_spec):
the matrix of correct 16-bit bitmaps for each cell. the matrix of correct 16-bit bitmaps for each cell.
""" """
def generator(width, height, num_resets=0): def generator(width, height, agents_handles, num_resets=0):
t_utils = RailEnvTransitions() t_utils = RailEnvTransitions()
height = len(rail_spec) height = len(rail_spec)
...@@ -147,7 +151,11 @@ def rail_from_manual_specifications_generator(rail_spec): ...@@ -147,7 +151,11 @@ def rail_from_manual_specifications_generator(rail_spec):
return [] return []
rail.set_transitions((r, c), t_utils.rotate_transition(t_utils.transitions[cell[0]], cell[1])) rail.set_transitions((r, c), t_utils.rotate_transition(t_utils.transitions[cell[0]], cell[1]))
return rail agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
rail,
len(agents_handles))
return rail, agents_position, agents_direction, agents_target
return generator return generator
...@@ -168,8 +176,12 @@ def rail_from_GridTransitionMap_generator(rail_map): ...@@ -168,8 +176,12 @@ def rail_from_GridTransitionMap_generator(rail_map):
Generator function that always returns the given `rail_map' object. Generator function that always returns the given `rail_map' object.
""" """
def generator(width, height, num_resets=0): def generator(width, height, agents_handles, num_resets=0):
return rail_map agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
rail_map,
len(agents_handles))
return rail_map, agents_position, agents_direction, agents_target
return generator return generator
...@@ -189,7 +201,7 @@ def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames): ...@@ -189,7 +201,7 @@ def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames):
Generator function that always returns the given `rail_map' object. Generator function that always returns the given `rail_map' object.
""" """
def generator(width, height, num_resets=0): def generator(width, height, agents_handles, num_resets=0):
t_utils = RailEnvTransitions() t_utils = RailEnvTransitions()
rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils) rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils)
rail_map.load_transition_map(list_of_filenames[num_resets % len(list_of_filenames)], override_gridsize=False) rail_map.load_transition_map(list_of_filenames[num_resets % len(list_of_filenames)], override_gridsize=False)
...@@ -197,7 +209,11 @@ def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames): ...@@ -197,7 +209,11 @@ def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames):
if rail_map.grid.dtype == np.uint64: if rail_map.grid.dtype == np.uint64:
rail_map.transitions = Grid8Transitions() rail_map.transitions = Grid8Transitions()
return rail_map agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
rail_map,
len(agents_handles))
return rail_map, agents_position, agents_direction, agents_target
return generator return generator
...@@ -243,7 +259,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): ...@@ -243,7 +259,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
The matrix with the correct 16-bit bitmaps for each cell. The matrix with the correct 16-bit bitmaps for each cell.
""" """
def generator(width, height, num_resets=0): def generator(width, height, agents_handles, num_resets=0):
t_utils = RailEnvTransitions() t_utils = RailEnvTransitions()
transition_probability = cell_type_relative_proportion transition_probability = cell_type_relative_proportion
...@@ -472,6 +488,11 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): ...@@ -472,6 +488,11 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils) return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
return_rail.grid = tmp_rail return_rail.grid = tmp_rail
return return_rail
agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
return_rail,
len(agents_handles))
return return_rail, agents_position, agents_direction, agents_target
return generator return generator
...@@ -9,6 +9,7 @@ import numpy as np ...@@ -9,6 +9,7 @@ import numpy as np
from flatland.core.env import Environment from flatland.core.env import Environment
from flatland.core.env_observation_builder import TreeObsForRailEnv from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.envs.generators import random_rail_generator from flatland.envs.generators import random_rail_generator
from flatland.envs.env_utils import get_new_position
# from flatland.core.transitions import Grid8Transitions, RailEnvTransitions # from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
# from flatland.core.transition_map import GridTransitionMap # from flatland.core.transition_map import GridTransitionMap
...@@ -74,9 +75,10 @@ class RailEnv(Environment): ...@@ -74,9 +75,10 @@ class RailEnv(Environment):
Parameters Parameters
------- -------
rail_generator : function rail_generator : function
The rail_generator function is a function that takes the width and The rail_generator function is a function that takes the width,
height of a rail map along with the number of times the env has height and agents handles of a rail environment, along with the number of times
been reset, and returns a GridTransitionMap object. the env has been reset, and returns a GridTransitionMap object and a list of
starting positions, targets, and initial orientations for agent handle.
Implemented functions are: Implemented functions are:
random_rail_generator : generate a random rail of given size random_rail_generator : generate a random rail of given size
rail_from_GridTransitionMap_generator(rail_map) : generate a rail from rail_from_GridTransitionMap_generator(rail_map) : generate a rail from
...@@ -130,107 +132,16 @@ class RailEnv(Environment): ...@@ -130,107 +132,16 @@ class RailEnv(Environment):
def get_agent_handles(self): def get_agent_handles(self):
return self.agents_handles return self.agents_handles
def fill_valid_positions(self): def reset(self, regen_rail=True, replace_agents=True):
''' Populate the valid_positions list for the current TransitionMap.
'''
self.valid_positions = valid_positions = []
for r in range(self.height):
for c in range(self.width):
if self.rail.get_transitions((r, c)) > 0:
valid_positions.append((r, c))
def check_agent_lists(self):
''' Check that the agent_handles, position and direction lists are all of length
number_of_agents.
(Suggest this is replaced with a single list of Agent objects :)
'''
for lAgents, name in zip(
[self.agents_handles, self.agents_position, self.agents_direction],
["handles", "positions", "directions"]):
assert self.number_of_agents == len(lAgents), "Inconsistent agent list:" + name
def check_agent_locdirpath(self, iAgent):
''' Check that agent iAgent has a valid location and direction,
with a path to its target.
(Not currently used?)
'''
valid_movements = []
for direction in range(4):
position = self.agents_position[iAgent]
moves = self.rail.get_transitions((position[0], position[1], direction))
for move_index in range(4):
if moves[move_index]:
valid_movements.append((direction, move_index))
valid_starting_directions = []
for m in valid_movements:
new_position = self._new_position(self.agents_position[iAgent], m[1])
if m[0] not in valid_starting_directions and \
self._path_exists(new_position, m[0], self.agents_target[iAgent]):
valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0:
return False
else:
return True
def pick_agent_direction(self, rcPos, rcTarget):
""" Pick and return a valid direction index (0..3) for an agent starting at
row,col rcPos with target rcTarget.
Return None if no path exists.
Picks random direction if more than one exists (uniformly).
""" """
valid_movements = [] TODO: replace_agents is ignored at the moment; agents will always be replaced.
for direction in range(4):
moves = self.rail.get_transitions((*rcPos, direction))
for move_index in range(4):
if moves[move_index]:
valid_movements.append((direction, move_index))
# print("pos", rcPos, "targ", rcTarget, "valid movements", valid_movements)
valid_starting_directions = []
for m in valid_movements:
new_position = self._new_position(rcPos, m[1])
if m[0] not in valid_starting_directions and self._path_exists(new_position, m[0], rcTarget):
valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0:
return None
else:
return valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]]
def add_agent(self, rcPos=None, rcTarget=None, iDir=None):
""" Add a new agent at position rcPos with target rcTarget and
initial direction index iDir.
Should also store this initial position etc as environment "meta-data"
but this does not yet exist.
""" """
self.check_agent_lists()
if rcPos is None:
rcPos = np.random.choice(len(self.valid_positions))
iAgent = self.number_of_agents
if iDir is None:
iDir = self.pick_agent_direction(rcPos, rcTarget)
if iDir is None:
print("Error picking agent direction at pos:", rcPos)
return None
self.agents_position.append(tuple(rcPos)) # ensure it's a tuple not a list
self.agents_handles.append(max(self.agents_handles + [-1]) + 1) # max(handles) + 1, starting at 0
self.agents_direction.append(iDir)
self.agents_target.append(rcPos) # set the target to the origin initially
self.number_of_agents += 1
self.check_agent_lists()
return iAgent
def reset(self, regen_rail=True, replace_agents=True):
if regen_rail or self.rail is None: if regen_rail or self.rail is None:
# TODO: Import not only rail information but also start and goal positions self.rail, self.agents_position, self.agents_direction, self.agents_target = self.rail_generator(
self.rail = self.rail_generator(self.width, self.height, self.num_resets) self.width,
self.fill_valid_positions() self.height,
self.agents_handles,
self.num_resets)
self.num_resets += 1 self.num_resets += 1
...@@ -238,37 +149,6 @@ class RailEnv(Environment): ...@@ -238,37 +149,6 @@ class RailEnv(Environment):
for handle in self.agents_handles: for handle in self.agents_handles:
self.dones[handle] = False self.dones[handle] = False
# Use a TreeObsForRailEnv to compute distance maps to each agent's target, to sample initial
# agent's orientations that allow a valid solution.
# TODO: Possibility ot fill valid positions from list of goals and start
self.fill_valid_positions()
if replace_agents:
re_generate = True
while re_generate:
# self.agents_position = random.sample(valid_positions,
# self.number_of_agents)
self.agents_position = [
self.valid_positions[i] for i in
np.random.choice(len(self.valid_positions), self.number_of_agents)]
self.agents_target = [
self.valid_positions[i] for i in
np.random.choice(len(self.valid_positions), self.number_of_agents)]
# agents_direction must be a direction for which a solution is
# guaranteed.
self.agents_direction = [0] * self.number_of_agents
re_generate = False
for i in range(self.number_of_agents):
direction = self.pick_agent_direction(self.agents_position[i], self.agents_target[i])
if direction is None:
re_generate = True
break
else:
self.agents_direction[i] = direction
# Reset the state of the observation builder with the new environment # Reset the state of the observation builder with the new environment
self.obs_builder.reset() self.obs_builder.reset()
...@@ -342,7 +222,7 @@ class RailEnv(Environment): ...@@ -342,7 +222,7 @@ class RailEnv(Environment):
movement = np.argmax(possible_transitions) movement = np.argmax(possible_transitions)
transition_isValid = True transition_isValid = True
new_position = self._new_position(pos, movement) new_position = get_new_position(pos, movement)
# Is it a legal move? 1) transition allows the movement in the # Is it a legal move? 1) transition allows the movement in the
# cell, 2) the new cell is not empty (case 0), 3) the cell is # cell, 2) the new cell is not empty (case 0), 3) the cell is
# free, i.e., no agent is currently in that cell # free, i.e., no agent is currently in that cell
...@@ -401,45 +281,6 @@ class RailEnv(Environment): ...@@ -401,45 +281,6 @@ class RailEnv(Environment):
self.actions = [0] * self.number_of_agents self.actions = [0] * self.number_of_agents
return self._get_observations(), self.rewards_dict, self.dones, {} return self._get_observations(), self.rewards_dict, self.dones, {}
def _new_position(self, position, movement):
if movement == 0: # NORTH
return (position[0] - 1, position[1])
elif movement == 1: # EAST
return (position[0], position[1] + 1)
elif movement == 2: # SOUTH
return (position[0] + 1, position[1])
elif movement == 3: # WEST
return (position[0], position[1] - 1)
def _path_exists(self, start, direction, end):
# BFS - Check if a path exists between the 2 nodes
visited = set()
stack = [(start, direction)]
while stack:
node = stack.pop()
if node[0][0] == end[0] and node[0][1] == end[1]:
return 1
if node not in visited:
visited.add(node)
moves = self.rail.get_transitions((node[0][0], node[0][1], node[1]))
for move_index in range(4):
if moves[move_index]:
stack.append((self._new_position(node[0], move_index),
move_index))
# If cell is a dead-end, append previous node with reversed
# orientation!
nbits = 0
tmp = self.rail.get_transitions((node[0][0], node[0][1]))
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
if nbits == 1:
stack.append((node[0], (node[1] + 2) % 4))
return 0
def _get_observations(self): def _get_observations(self):
self.obs_dict = {} self.obs_dict = {}
for handle in self.agents_handles: for handle in self.agents_handles:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment