Skip to content
Snippets Groups Projects
Commit 8112097a authored by Erik Nygren's avatar Erik Nygren
Browse files

Merge branch 'master' of gitlab.aicrowd.com:flatland/flatland into flatland_navigation_training

parents a9440cfb ea939b74
No related branches found
No related tags found
No related merge requests found
LICENSE 0 → 100644
MIT License
Copyright (c) 2019 SBB AG
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
File added
...@@ -6,39 +6,35 @@ from flatland.envs.rail_env import * ...@@ -6,39 +6,35 @@ from flatland.envs.rail_env import *
from flatland.core.env_observation_builder import TreeObsForRailEnv from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.utils.rendertools import * from flatland.utils.rendertools import *
random.seed(1) random.seed(0)
np.random.seed(1) np.random.seed(0)
"""
transition_probability = [1.0, # empty cell - Case 0
3.0, # Case 1 - straight
1.0, # Case 2 - simple switch
3.0, # Case 3 - diamond drossing
2.0, # Case 4 - single slip
1.0, # Case 5 - double slip
1.0, # Case 6 - symmetrical
1.0] # Case 7 - dead end
"""
transition_probability = [1.0, # empty cell - Case 0 transition_probability = [1.0, # empty cell - Case 0
1.0, # Case 1 - straight 1.0, # Case 1 - straight
0.5, # Case 2 - simple switch 1.0, # Case 2 - simple switch
0.2, # Case 3 - diamond drossing 0.3, # Case 3 - diamond drossing
0.5, # Case 4 - single slip 0.5, # Case 4 - single slip
0.1, # Case 5 - double slip 0.5, # Case 5 - double slip
0.2, # Case 6 - symmetrical 0.2, # Case 6 - symmetrical
1.0] # Case 7 - dead end 0.0] # Case 7 - dead end
# Example generate a random rail # Example generate a random rail
env = RailEnv(width=20, env = RailEnv(width=20,
height=20, height=20,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=10) number_of_agents=10)
# env = RailEnv(width=20,
# height=20,
# rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(['examples/sample_10_10_rail.npy']),
# number_of_agents=10)
env.reset() env.reset()
env_renderer = RenderTool(env) env_renderer = RenderTool(env)
env_renderer.renderEnv(show=True) env_renderer.renderEnv(show=True)
"""
# Example generate a rail given a manual specification, # Example generate a rail given a manual specification,
# a map of tuples (cell_type, rotation) # a map of tuples (cell_type, rotation)
specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)], specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
...@@ -51,19 +47,25 @@ env = RailEnv(width=6, ...@@ -51,19 +47,25 @@ env = RailEnv(width=6,
obs_builder_object=TreeObsForRailEnv(max_depth=2)) obs_builder_object=TreeObsForRailEnv(max_depth=2))
handle = env.get_agent_handles() handle = env.get_agent_handles()
env.agents_position[0] = [1, 4] env.agents_position[0] = [1, 4]
env.agents_target[0] = [1, 1] env.agents_target[0] = [1, 1]
env.agents_direction[0] = 1 env.agents_direction[0] = 1
# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too! # TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
env.obs_builder.reset() env.obs_builder.reset()
"""
env = RailEnv(width=7,
height=7,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=2)
# TODO: delete next line # Print the distance map of each cell to the target of the first agent
#for i in range(4): # for i in range(4):
# print(env.obs_builder.distance_map[0, :, :, i]) # print(env.obs_builder.distance_map[0, :, :, i])
# Print the observation vector for agent 0
obs, all_rewards, done, _ = env.step({0:0}) obs, all_rewards, done, _ = env.step({0:0})
env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5) for i in range(env.number_of_agents):
env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5)
env_renderer = RenderTool(env) env_renderer = RenderTool(env)
env_renderer.renderEnv(show=True) env_renderer.renderEnv(show=True)
......
...@@ -103,7 +103,6 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -103,7 +103,6 @@ class TreeObsForRailEnv(ObservationBuilder):
node = nodes_queue.popleft() node = nodes_queue.popleft()
node_id = (node[0], node[1], node[2]) node_id = (node[0], node[1], node[2])
if node_id not in visited: if node_id not in visited:
visited.add(node_id) visited.add(node_id)
...@@ -126,58 +125,50 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -126,58 +125,50 @@ class TreeObsForRailEnv(ObservationBuilder):
""" """
neighbors = [] neighbors = []
for direction in range(4): possible_directions = [0, 1, 2, 3]
new_cell = self._new_position(position, (direction+2) % 4) if enforce_target_direction >= 0:
# The agent must land into the current cell with orientation `enforce_target_direction'.
# This is only possible if the agent has arrived from the cell in the opposite direction!
possible_directions = [(enforce_target_direction+2) % 4]
for neigh_direction in possible_directions:
new_cell = self._new_position(position, neigh_direction)
if new_cell[0] >= 0 and new_cell[0] < self.env.height and \ if new_cell[0] >= 0 and new_cell[0] < self.env.height and \
new_cell[1] >= 0 and new_cell[1] < self.env.width: new_cell[1] >= 0 and new_cell[1] < self.env.width:
# Check if the two cells are connected by a valid transition desired_movement_from_new_cell = (neigh_direction+2) % 4
transitionValid = False
for orientation in range(4): """
moves = self.env.rail.get_transitions((new_cell[0], new_cell[1], orientation)) # Is the next cell a dead-end?
if moves[direction]: isNextCellDeadEnd = False
transitionValid = True nbits = 0
break tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1]))
while tmp > 0:
if not transitionValid: nbits += (tmp & 1)
continue tmp = tmp >> 1
if nbits == 1:
# Check if a transition in direction node[2] is possible if an agent lands in the current # Dead-end!
# cell with orientation `direction'; this only applies to cells that are not dead-ends! isNextCellDeadEnd = True
directionMatch = True """
if enforce_target_direction >= 0:
directionMatch = self.env.rail.get_transition((new_cell[0], new_cell[1], direction), # Check all possible transitions in new_cell
enforce_target_direction) for agent_orientation in range(4):
# Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
# If transition is found to invalid, check if perhaps it is a dead-end, in which case the isValid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
# direction of movement is rotated 180 degrees (moving forward turns the agents and makes desired_movement_from_new_cell)
# it step in the previous cell)
if not directionMatch: if isValid:
# If cell is a dead-end, append previous node with reversed """
# orientation! # TODO: check that it works with deadends! -- still bugged!
nbits = 0 movement = desired_movement_from_new_cell
tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1])) if isNextCellDeadEnd:
while tmp > 0: movement = (desired_movement_from_new_cell+2) % 4
nbits += (tmp & 1) """
tmp = tmp >> 1 new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
if nbits == 1: current_distance+1)
# Dead-end! neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
# Check if transition is possible in new_cell with orientation self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance
# (direction+2)%4 in direction `direction'
directionMatch = directionMatch or \
self.env.rail.get_transition((new_cell[0], new_cell[1], (direction+2) % 4),
direction)
if transitionValid and directionMatch:
# Append all possible orientations in new_cell that allow a transition to direction!
for orientation in range(4):
moves = self.env.rail.get_transitions((new_cell[0], new_cell[1], orientation))
if moves[direction]:
new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], orientation],
current_distance+1)
neighbors.append((new_cell[0], new_cell[1], orientation, new_distance))
self.distance_map[target_nr, new_cell[0], new_cell[1], orientation] = new_distance
return neighbors return neighbors
...@@ -309,16 +300,24 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -309,16 +300,24 @@ class TreeObsForRailEnv(ObservationBuilder):
exploring = False exploring = False
if num_transitions == 1: if num_transitions == 1:
# Check if dead-end, or if we can go forward along direction # Check if dead-end, or if we can go forward along direction
if cell_transitions[direction]: nbits = 0
position = self._new_position(position, direction) tmp = self.env.rail.get_transitions((position[0], position[1]))
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
if nbits == 1:
# Dead-end!
last_isDeadEnd = True
if not last_isDeadEnd:
# Keep walking through the tree along `direction' # Keep walking through the tree along `direction'
exploring = True exploring = True
else: for i in range(4):
# If a dead-end is reached, pick that as node. Also, no further branching is possible. if cell_transitions[i]:
last_isDeadEnd = True position = self._new_position(position, i)
break direction = i
break
elif num_transitions > 0: elif num_transitions > 0:
# Switch detected # Switch detected
...@@ -352,8 +351,6 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -352,8 +351,6 @@ class TreeObsForRailEnv(ObservationBuilder):
0, 0,
self.distance_map[handle, position[0], position[1], direction]] self.distance_map[handle, position[0], position[1], direction]]
# TODO:
# ############################# # #############################
# ############################# # #############################
...@@ -386,15 +383,15 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -386,15 +383,15 @@ class TreeObsForRailEnv(ObservationBuilder):
return observation return observation
def util_print_obs_subtree(self, tree, num_elements_per_node=5, prompt='', current_depth=0): def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0):
""" """
Utility function to pretty-print tree observations returned by this object. Utility function to pretty-print tree observations returned by this object.
""" """
if len(tree) < num_elements_per_node: if len(tree) < num_features_per_node:
return return
depth = 0 depth = 0
tmp = len(tree)/num_elements_per_node-1 tmp = len(tree)/num_features_per_node-1
pow4 = 4 pow4 = 4
while tmp > 0: while tmp > 0:
tmp -= pow4 tmp -= pow4
...@@ -403,12 +400,12 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -403,12 +400,12 @@ class TreeObsForRailEnv(ObservationBuilder):
prompt_ = ['L:', 'F:', 'R:', 'B:'] prompt_ = ['L:', 'F:', 'R:', 'B:']
print(" "*current_depth + prompt, tree[0:num_elements_per_node]) print(" "*current_depth + prompt, tree[0:num_features_per_node])
child_size = (len(tree)-num_elements_per_node)//4 child_size = (len(tree)-num_features_per_node)//4
for children in range(4): for children in range(4):
child_tree = tree[(num_elements_per_node+children*child_size): child_tree = tree[(num_features_per_node+children*child_size):
(num_elements_per_node+(children+1)*child_size)] (num_features_per_node+(children+1)*child_size)]
self.util_print_obs_subtree(child_tree, self.util_print_obs_subtree(child_tree,
num_elements_per_node, num_features_per_node,
prompt=prompt_[children], prompt=prompt_[children],
current_depth=current_depth+1) current_depth=current_depth+1)
...@@ -118,7 +118,7 @@ class GridTransitionMap(TransitionMap): ...@@ -118,7 +118,7 @@ class GridTransitionMap(TransitionMap):
Width of the grid. Width of the grid.
height : int height : int
Height of the grid. Height of the grid.
transitions_class : Transitions object transitions : Transitions object
The Transitions object to use to encode/decode transitions over the The Transitions object to use to encode/decode transitions over the
grid. grid.
...@@ -243,6 +243,54 @@ class GridTransitionMap(TransitionMap): ...@@ -243,6 +243,54 @@ class GridTransitionMap(TransitionMap):
return return
self.transitions.set_transition(self.grid[cell_id[0]][cell_id[1]], cell_id[2], transition_index, new_transition) self.transitions.set_transition(self.grid[cell_id[0]][cell_id[1]], cell_id[2], transition_index, new_transition)
def save_transition_map(self, filename):
"""
Save the transitions grid as `filename', in npy format.
Parameters
----------
filename : string
Name of the file to which to save the transitions grid.
"""
np.save(filename, self.grid)
def load_transition_map(self, filename, override_gridsize=True):
"""
Load the transitions grid from `filename' (npy format).
The load function only updates the transitions grid, and possibly width and height, but the object has to be
initialized with the correct `transitions' object anyway.
Parameters
----------
filename : string
Name of the file from which to load the transitions grid.
override_gridsize : bool
If override_gridsize=True, the width and height of the GridTransitionMap object are replaced with the size
of the map loaded from `filename'. If override_gridsize=False, the transitions grid is either cropped (if
the grid size is larger than (height,width) ) or padded with zeros (if the grid size is smaller than
(height,width) )
"""
new_grid = np.load(filename)
new_height = new_grid.shape[0]
new_width = new_grid.shape[1]
if override_gridsize:
self.width = new_width
self.height = new_height
self.grid = new_grid
else:
if new_grid.dtype == np.uint16:
self.grid = np.zeros((self.height, self.width), dtype=np.uint16)
elif new_grid.dtype == np.uint64:
self.grid = np.zeros((self.height, self.width), dtype=np.uint64)
self.grid[0:min(self.height, new_height),
0:min(self.width, new_width)] = new_grid[0:min(self.height, new_height),
0:min(self.width, new_width)]
# TODO: GIACOMO: is it better to provide those methods with lists of cell_ids # TODO: GIACOMO: is it better to provide those methods with lists of cell_ids
# (most general implementation) or to make Grid-class specific methods for # (most general implementation) or to make Grid-class specific methods for
......
...@@ -4,13 +4,12 @@ Definition of the RailEnv environment and related level-generation functions. ...@@ -4,13 +4,12 @@ Definition of the RailEnv environment and related level-generation functions.
Generator functions are functions that take width, height and num_resets as arguments and return Generator functions are functions that take width, height and num_resets as arguments and return
a GridTransitionMap object. a GridTransitionMap object.
""" """
import random
import numpy as np import numpy as np
from flatland.core.env import Environment from flatland.core.env import Environment
from flatland.core.env_observation_builder import TreeObsForRailEnv from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.core.transitions import RailEnvTransitions from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap from flatland.core.transition_map import GridTransitionMap
...@@ -75,6 +74,33 @@ def rail_from_GridTransitionMap_generator(rail_map): ...@@ -75,6 +74,33 @@ def rail_from_GridTransitionMap_generator(rail_map):
return generator return generator
def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames):
"""
Utility to sequentially and cyclically return GridTransitionMap-s from a list of files, on each environment reset.
Parameters
-------
list_of_filenames : list
List of filenames with the saved grids to load.
Returns
-------
function
Generator function that always returns the given `rail_map' object.
"""
def generator(width, height, num_resets=0):
t_utils = RailEnvTransitions()
rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils)
rail_map.load_transition_map(list_of_filenames[num_resets % len(list_of_filenames)], override_gridsize=False)
if rail_map.grid.dtype == np.uint64:
rail_map.transitions = Grid8Transitions()
return rail_map
return generator
""" """
def generate_rail_from_list_of_manual_specifications(list_of_specifications) def generate_rail_from_list_of_manual_specifications(list_of_specifications)
def generator(width, height, num_resets=0): def generator(width, height, num_resets=0):
...@@ -172,7 +198,8 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): ...@@ -172,7 +198,8 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
num_insertions = 0 num_insertions = 0
while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0: while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0:
cell = random.sample(cells_to_fill, 1)[0] # cell = random.sample(cells_to_fill, 1)[0]
cell = cells_to_fill[np.random.choice(len(cells_to_fill), 1)[0]]
cells_to_fill.remove(cell) cells_to_fill.remove(cell)
row = cell[0] row = cell[0]
col = cell[1] col = cell[1]
...@@ -218,7 +245,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): ...@@ -218,7 +245,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
rot = 90 rot = 90
rail[row][col] = t_utils.rotate_transition( rail[row][col] = t_utils.rotate_transition(
int('0000000000100000', 2), rot) int('0010000000000000', 2), rot)
num_insertions += 1 num_insertions += 1
break break
...@@ -257,8 +284,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): ...@@ -257,8 +284,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
rail[replace_row][replace_col] = None rail[replace_row][replace_col] = None
possible_transitions, possible_probabilities = zip(*besttrans) possible_transitions, possible_probabilities = zip(*besttrans)
possible_probabilities = \ possible_probabilities = [p/sum(possible_probabilities) for p in possible_probabilities]
np.exp(possible_probabilities) / sum(np.exp(possible_probabilities))
rail[row][col] = np.random.choice(possible_transitions, rail[row][col] = np.random.choice(possible_transitions,
p=possible_probabilities) p=possible_probabilities)
...@@ -272,7 +298,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): ...@@ -272,7 +298,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
else: else:
possible_transitions, possible_probabilities = zip(*possible_cell_transitions) possible_transitions, possible_probabilities = zip(*possible_cell_transitions)
possible_probabilities = np.exp(possible_probabilities) / sum(np.exp(possible_probabilities)) possible_probabilities = [p/sum(possible_probabilities) for p in possible_probabilities]
rail[row][col] = np.random.choice(possible_transitions, rail[row][col] = np.random.choice(possible_transitions,
p=possible_probabilities) p=possible_probabilities)
...@@ -300,7 +326,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): ...@@ -300,7 +326,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
max_bit = max_bit | (neigh_trans_from_direction & 1) max_bit = max_bit | (neigh_trans_from_direction & 1)
if max_bit: if max_bit:
rail[r][0] = t_utils.rotate_transition( rail[r][0] = t_utils.rotate_transition(
int('0000000000100000', 2), 270) int('0010000000000000', 2), 270)
else: else:
rail[r][0] = int('0000000000000000', 2) rail[r][0] = int('0000000000000000', 2)
...@@ -313,7 +339,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): ...@@ -313,7 +339,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
& (2**4-1) & (2**4-1)
max_bit = max_bit | (neigh_trans_from_direction & (1 << 2)) max_bit = max_bit | (neigh_trans_from_direction & (1 << 2))
if max_bit: if max_bit:
rail[r][-1] = t_utils.rotate_transition(int('0000000000100000', 2), rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2),
90) 90)
else: else:
rail[r][-1] = int('0000000000000000', 2) rail[r][-1] = int('0000000000000000', 2)
...@@ -328,7 +354,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): ...@@ -328,7 +354,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
& (2**4-1) & (2**4-1)
max_bit = max_bit | (neigh_trans_from_direction & (1 << 3)) max_bit = max_bit | (neigh_trans_from_direction & (1 << 3))
if max_bit: if max_bit:
rail[0][c] = int('0000000000100000', 2) rail[0][c] = int('0010000000000000', 2)
else: else:
rail[0][c] = int('0000000000000000', 2) rail[0][c] = int('0000000000000000', 2)
...@@ -342,7 +368,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): ...@@ -342,7 +368,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
max_bit = max_bit | (neigh_trans_from_direction & (1 << 1)) max_bit = max_bit | (neigh_trans_from_direction & (1 << 1))
if max_bit: if max_bit:
rail[-1][c] = t_utils.rotate_transition( rail[-1][c] = t_utils.rotate_transition(
int('0000000000100000', 2), 180) int('0010000000000000', 2), 180)
else: else:
rail[-1][c] = int('0000000000000000', 2) rail[-1][c] = int('0000000000000000', 2)
...@@ -353,6 +379,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): ...@@ -353,6 +379,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
rail[r][c] = int('0000000000000000', 2) rail[r][c] = int('0000000000000000', 2)
tmp_rail = np.asarray(rail, dtype=np.uint16) tmp_rail = np.asarray(rail, dtype=np.uint16)
return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils) return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
return_rail.grid = tmp_rail return_rail.grid = tmp_rail
return return_rail return return_rail
...@@ -388,7 +415,7 @@ class RailEnv(Environment): ...@@ -388,7 +415,7 @@ class RailEnv(Environment):
def __init__(self, def __init__(self,
width, width,
height, height,
rail_generator=random_rail_generator, rail_generator=random_rail_generator(),
number_of_agents=1, number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2)): obs_builder_object=TreeObsForRailEnv(max_depth=2)):
""" """
...@@ -467,10 +494,14 @@ class RailEnv(Environment): ...@@ -467,10 +494,14 @@ class RailEnv(Environment):
if self.rail.get_transitions((r, c)) > 0: if self.rail.get_transitions((r, c)) > 0:
valid_positions.append((r, c)) valid_positions.append((r, c))
self.agents_position = random.sample(valid_positions, # self.agents_position = random.sample(valid_positions,
self.number_of_agents) # self.number_of_agents)
self.agents_target = random.sample(valid_positions, self.agents_position = [
self.number_of_agents) valid_positions[i] for i in
np.random.choice(len(valid_positions), self.number_of_agents)]
self.agents_target = [
valid_positions[i] for i in
np.random.choice(len(valid_positions), self.number_of_agents)]
# agents_direction must be a direction for which a solution is # agents_direction must be a direction for which a solution is
# guaranteed. # guaranteed.
...@@ -498,8 +529,8 @@ class RailEnv(Environment): ...@@ -498,8 +529,8 @@ class RailEnv(Environment):
if len(valid_starting_directions) == 0: if len(valid_starting_directions) == 0:
re_generate = True re_generate = True
else: else:
self.agents_direction[i] = random.sample( self.agents_direction[i] = valid_starting_directions[
valid_starting_directions, 1)[0] np.random.choice(len(valid_starting_directions), 1)[0]]
# Reset the state of the observation builder with the new environment # Reset the state of the observation builder with the new environment
self.obs_builder.reset() self.obs_builder.reset()
......
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
...@@ -6,7 +6,6 @@ Tests for `flatland` package. ...@@ -6,7 +6,6 @@ Tests for `flatland` package.
from flatland.envs.rail_env import RailEnv, random_rail_generator from flatland.envs.rail_env import RailEnv, random_rail_generator
import numpy as np import numpy as np
import random
import os import os
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -35,8 +34,9 @@ def checkFrozenImage(sFileImage): ...@@ -35,8 +34,9 @@ def checkFrozenImage(sFileImage):
def test_render_env(): def test_render_env():
random.seed(100) # random.seed(100)
oEnv = RailEnv(width=10, height=10, rail_generator=random_rail_generator, number_of_agents=2) np.random.seed(100)
oEnv = RailEnv(width=10, height=10, rail_generator=random_rail_generator(), number_of_agents=2)
oEnv.reset() oEnv.reset()
oRT = rt.RenderTool(oEnv) oRT = rt.RenderTool(oEnv)
plt.figure(figsize=(10, 10)) plt.figure(figsize=(10, 10))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment