Skip to content
Snippets Groups Projects
Commit c37f4f7d authored by hagrid67's avatar hagrid67
Browse files

updated random_rail_generator to use only numpy.random not random

updated test_rendertools to only set seed of numpy.random not plain random
to confirm that only the numpy random sequence is being used.
parent b12af97a
No related branches found
No related tags found
No related merge requests found
......@@ -4,7 +4,6 @@ Definition of the RailEnv environment and related level-generation functions.
Generator functions are functions that take width, height and num_resets as arguments and return
a GridTransitionMap object.
"""
import random
import numpy as np
from flatland.core.env import Environment
......@@ -199,7 +198,8 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
num_insertions = 0
while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0:
cell = random.sample(cells_to_fill, 1)[0]
# cell = random.sample(cells_to_fill, 1)[0]
cell = cells_to_fill[np.random.choice(len(cells_to_fill), 1)[0]]
cells_to_fill.remove(cell)
row = cell[0]
col = cell[1]
......@@ -494,10 +494,14 @@ class RailEnv(Environment):
if self.rail.get_transitions((r, c)) > 0:
valid_positions.append((r, c))
self.agents_position = random.sample(valid_positions,
self.number_of_agents)
self.agents_target = random.sample(valid_positions,
self.number_of_agents)
# self.agents_position = random.sample(valid_positions,
# self.number_of_agents)
self.agents_position = [
valid_positions[i] for i in
np.random.choice(len(valid_positions), self.number_of_agents)]
self.agents_target = [
valid_positions[i] for i in
np.random.choice(len(valid_positions), self.number_of_agents)]
# agents_direction must be a direction for which a solution is
# guaranteed.
......@@ -525,8 +529,8 @@ class RailEnv(Environment):
if len(valid_starting_directions) == 0:
re_generate = True
else:
self.agents_direction[i] = random.sample(
valid_starting_directions, 1)[0]
self.agents_direction[i] = valid_starting_directions[
np.random.choice(len(valid_starting_directions), 1)[0]]
# Reset the state of the observation builder with the new environment
self.obs_builder.reset()
......
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
......@@ -6,7 +6,6 @@ Tests for `flatland` package.
from flatland.envs.rail_env import RailEnv, random_rail_generator
import numpy as np
import random
import os
import matplotlib.pyplot as plt
......@@ -35,7 +34,7 @@ def checkFrozenImage(sFileImage):
def test_render_env():
random.seed(100)
# random.seed(100)
np.random.seed(100)
oEnv = RailEnv(width=10, height=10, rail_generator=random_rail_generator(), number_of_agents=2)
oEnv.reset()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment