Skip to content
Snippets Groups Projects
Commit 21d3448b authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

First impementation of sparse rail generator.

not functional yet
parent 6fade66c
No related branches found
No related tags found
No related merge requests found
...@@ -690,3 +690,93 @@ def realistic_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dis ...@@ -690,3 +690,93 @@ def realistic_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dis
return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
return generator return generator
def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, mean_node_neighbours=2, min_node_dist=20, node_radius=2,
seed=0):
'''
:param nr_train_stations:
:param nr_nodes:
:param mean_node_neighbours:
:param min_node_dist:
:param seed:
:return:
'''
def generator(width, height, num_agents, num_resets=0):
if num_agents > nr_train_stations:
num_agents = nr_train_stations
print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents")
rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
rail_array = grid_map.grid
rail_array.fill(0)
np.random.seed(seed + num_resets)
# Generate a set of nodes for the sparse network
node_positions = []
for node_idx in range(nr_nodes):
to_close = True
while to_close:
x_tmp = np.random.randint(width)
y_tmp = np.random.randint(height)
to_close = False
for node_pos in node_positions:
if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
to_close = True
if not to_close:
node_positions.append((x_tmp, y_tmp))
# Generate start and target node directory for all agents
agent_start_targets_nodes = []
for agent_idx in range(num_agents):
start_target_tuple = np.random.choice(nr_nodes, 2, replace=False)
agent_start_targets_nodes.append(start_target_tuple)
# Generate actual start and target locations from around nodes
agents_position = []
agents_target = []
agents_direction = []
agent_idx = 0
for start_target in agent_start_targets_nodes:
start_x = np.clip(node_positions[start_target[0]][0] + np.random.randint(-node_radius, node_radius), 0,
width - 1)
start_y = np.clip(node_positions[start_target[0]][1] + np.random.randint(-node_radius, node_radius), 0,
height - 1)
target_x = np.clip(node_positions[start_target[1]][0] + np.random.randint(-node_radius, node_radius), 0,
width - 1)
target_y = np.clip(node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0,
height - 1)
if agent_idx == 0:
agents_position.append((start_y, start_x))
agents_target.append((target_y, target_x))
else:
while ((start_x, start_y) in agents_position or (target_x, target_y) in agents_target):
start_x = np.clip(node_positions[start_target[0]][0] + np.random.randint(-node_radius, node_radius),
0,
width - 1)
start_y = np.clip(node_positions[start_target[0]][1] + np.random.randint(-node_radius, node_radius),
0,
height - 1)
target_x = np.clip(
node_positions[start_target[1]][0] + np.random.randint(-node_radius, node_radius), 0,
width - 1)
target_y = np.clip(
node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0,
height - 1)
agents_position.append((start_y, start_x))
agents_target.append((target_y, target_x))
agents_direction.append(0)
agent_idx += 1
print(agents_position)
print(agents_target)
print(node_positions)
for n in node_positions:
for m in node_positions:
print(distance_on_rail(n, m))
return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
return generator
from flatland.envs.generators import sparse_rail_generator
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
def test_sparse_rail_generator():
env = RailEnv(width=20,
height=20,
rail_generator=sparse_rail_generator(nr_train_stations=10, nr_nodes=5, min_node_dist=10,
node_radius=4),
number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())
# reset to initialize agents_static
env.reset()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment