Skip to content
Snippets Groups Projects
Commit c37f4f7d authored by hagrid67's avatar hagrid67
Browse files

updated random_rail_generator to use only numpy.random not random

updated test_rendertools to only set seed of numpy.random not plain random
to confirm that only the numpy random sequence is being used.
parent b12af97a
No related branches found
No related tags found
No related merge requests found
...@@ -4,7 +4,6 @@ Definition of the RailEnv environment and related level-generation functions. ...@@ -4,7 +4,6 @@ Definition of the RailEnv environment and related level-generation functions.
Generator functions are functions that take width, height and num_resets as arguments and return Generator functions are functions that take width, height and num_resets as arguments and return
a GridTransitionMap object. a GridTransitionMap object.
""" """
import random
import numpy as np import numpy as np
from flatland.core.env import Environment from flatland.core.env import Environment
...@@ -199,7 +198,8 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): ...@@ -199,7 +198,8 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
num_insertions = 0 num_insertions = 0
while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0: while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0:
cell = random.sample(cells_to_fill, 1)[0] # cell = random.sample(cells_to_fill, 1)[0]
cell = cells_to_fill[np.random.choice(len(cells_to_fill), 1)[0]]
cells_to_fill.remove(cell) cells_to_fill.remove(cell)
row = cell[0] row = cell[0]
col = cell[1] col = cell[1]
...@@ -494,10 +494,14 @@ class RailEnv(Environment): ...@@ -494,10 +494,14 @@ class RailEnv(Environment):
if self.rail.get_transitions((r, c)) > 0: if self.rail.get_transitions((r, c)) > 0:
valid_positions.append((r, c)) valid_positions.append((r, c))
self.agents_position = random.sample(valid_positions, # self.agents_position = random.sample(valid_positions,
self.number_of_agents) # self.number_of_agents)
self.agents_target = random.sample(valid_positions, self.agents_position = [
self.number_of_agents) valid_positions[i] for i in
np.random.choice(len(valid_positions), self.number_of_agents)]
self.agents_target = [
valid_positions[i] for i in
np.random.choice(len(valid_positions), self.number_of_agents)]
# agents_direction must be a direction for which a solution is # agents_direction must be a direction for which a solution is
# guaranteed. # guaranteed.
...@@ -525,8 +529,8 @@ class RailEnv(Environment): ...@@ -525,8 +529,8 @@ class RailEnv(Environment):
if len(valid_starting_directions) == 0: if len(valid_starting_directions) == 0:
re_generate = True re_generate = True
else: else:
self.agents_direction[i] = random.sample( self.agents_direction[i] = valid_starting_directions[
valid_starting_directions, 1)[0] np.random.choice(len(valid_starting_directions), 1)[0]]
# Reset the state of the observation builder with the new environment # Reset the state of the observation builder with the new environment
self.obs_builder.reset() self.obs_builder.reset()
......
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
...@@ -6,7 +6,6 @@ Tests for `flatland` package. ...@@ -6,7 +6,6 @@ Tests for `flatland` package.
from flatland.envs.rail_env import RailEnv, random_rail_generator from flatland.envs.rail_env import RailEnv, random_rail_generator
import numpy as np import numpy as np
import random
import os import os
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -35,7 +34,7 @@ def checkFrozenImage(sFileImage): ...@@ -35,7 +34,7 @@ def checkFrozenImage(sFileImage):
def test_render_env(): def test_render_env():
random.seed(100) # random.seed(100)
np.random.seed(100) np.random.seed(100)
oEnv = RailEnv(width=10, height=10, rail_generator=random_rail_generator(), number_of_agents=2) oEnv = RailEnv(width=10, height=10, rail_generator=random_rail_generator(), number_of_agents=2)
oEnv.reset() oEnv.reset()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment