diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index a875304aded202b09d133db342a647f5fbefb6fd..22c072a838085d40f045a496b1aa6cc8aa778fc3 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -4,7 +4,6 @@ Definition of the RailEnv environment and related level-generation functions. Generator functions are functions that take width, height and num_resets as arguments and return a GridTransitionMap object. """ -import random import numpy as np from flatland.core.env import Environment @@ -199,7 +198,8 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): num_insertions = 0 while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0: - cell = random.sample(cells_to_fill, 1)[0] + # cell = random.sample(cells_to_fill, 1)[0] + cell = cells_to_fill[np.random.choice(len(cells_to_fill), 1)[0]] cells_to_fill.remove(cell) row = cell[0] col = cell[1] @@ -494,10 +494,14 @@ class RailEnv(Environment): if self.rail.get_transitions((r, c)) > 0: valid_positions.append((r, c)) - self.agents_position = random.sample(valid_positions, - self.number_of_agents) - self.agents_target = random.sample(valid_positions, - self.number_of_agents) + # self.agents_position = random.sample(valid_positions, + # self.number_of_agents) + self.agents_position = [ + valid_positions[i] for i in + np.random.choice(len(valid_positions), self.number_of_agents)] + self.agents_target = [ + valid_positions[i] for i in + np.random.choice(len(valid_positions), self.number_of_agents)] # agents_direction must be a direction for which a solution is # guaranteed. @@ -525,8 +529,8 @@ class RailEnv(Environment): if len(valid_starting_directions) == 0: re_generate = True else: - self.agents_direction[i] = random.sample( - valid_starting_directions, 1)[0] + self.agents_direction[i] = valid_starting_directions[ + np.random.choice(len(valid_starting_directions), 1)[0]] # Reset the state of the observation builder with the new environment self.obs_builder.reset() diff --git a/images/basic-env.png b/images/basic-env.png index 2dc0e66e377f8125738fc991ee81c714e25e2ded..850d6ecad2d1adb6d3d4f829116acee67b9441db 100644 Binary files a/images/basic-env.png and b/images/basic-env.png differ diff --git a/images/env-path.png b/images/env-path.png index 385133fd94202c8e11d1e62301c2d1b49b722ef4..95b9faa9fbe78e36c49216058274dbd18495cc12 100644 Binary files a/images/env-path.png and b/images/env-path.png differ diff --git a/images/env-tree-graph.png b/images/env-tree-graph.png index 80ffd084718ab98dc36461e437bb01b8dd2e9ddd..f33b5f4c8d69ab2c028e6cc1689f4ccbb7600ce3 100644 Binary files a/images/env-tree-graph.png and b/images/env-tree-graph.png differ diff --git a/images/env-tree-spatial.png b/images/env-tree-spatial.png index 578d0c62be980bee6d9119e4d85de24c8a988b65..54ac9bfc0c2cb0853a319368add4d6ac5514fd28 100644 Binary files a/images/env-tree-spatial.png and b/images/env-tree-spatial.png differ diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py index 9c76f3c23cde79263cb9754135865c866a43d804..0bdc47bace289656181b01b1f44344e4322363a0 100644 --- a/tests/test_rendertools.py +++ b/tests/test_rendertools.py @@ -6,7 +6,6 @@ Tests for `flatland` package. from flatland.envs.rail_env import RailEnv, random_rail_generator import numpy as np -import random import os import matplotlib.pyplot as plt @@ -35,7 +34,7 @@ def checkFrozenImage(sFileImage): def test_render_env(): - random.seed(100) + # random.seed(100) np.random.seed(100) oEnv = RailEnv(width=10, height=10, rail_generator=random_rail_generator(), number_of_agents=2) oEnv.reset()