Skip to content
Snippets Groups Projects
Commit 3e902906 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

checking that targets don't end up on a connecting rail

parent 67a533f2
No related branches found
No related tags found
No related merge requests found
...@@ -693,7 +693,7 @@ def realistic_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dis ...@@ -693,7 +693,7 @@ def realistic_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dis
return generator return generator
def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, max_neigbours=2, min_node_dist=20, node_radius=2, def sparse_rail_generator(nr_nodes=100, max_neigbours=2, min_node_dist=20, node_radius=2,
seed=0): seed=0):
''' '''
...@@ -706,10 +706,6 @@ def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, max_neigbours=2, mi ...@@ -706,10 +706,6 @@ def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, max_neigbours=2, mi
''' '''
def generator(width, height, num_agents, num_resets=0): def generator(width, height, num_agents, num_resets=0):
if num_agents > nr_train_stations:
num_agents = nr_train_stations
warnings.warn("complex_rail_generator: num_agents > nr_start_goal, changing num_agents")
rail_trans = RailEnvTransitions() rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
rail_array = grid_map.grid rail_array = grid_map.grid
...@@ -780,26 +776,24 @@ def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, max_neigbours=2, mi ...@@ -780,26 +776,24 @@ def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, max_neigbours=2, mi
height - 1) height - 1)
target_y = np.clip(node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0, target_y = np.clip(node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0,
width - 1) width - 1)
if agent_idx == 0: # Make sure we don't put to starts or targets on same cell
agents_position.append((start_x, start_y)) while (start_x, start_y) in agents_position or (start_x, start_y) == node_positions[start_target[0]]:
agents_target.append((target_x, target_y)) start_x = np.clip(node_positions[start_target[0]][0] + np.random.randint(-node_radius, node_radius),
else: 0,
# Make sure we don't put to starts or targets on same cell height - 1)
while (start_x, start_y) in agents_position or (target_x, target_y) in agents_target: start_y = np.clip(node_positions[start_target[0]][1] + np.random.randint(-node_radius, node_radius),
start_x = np.clip(node_positions[start_target[0]][0] + np.random.randint(-node_radius, node_radius), 0,
0, width - 1)
height - 1) while (target_x, target_y) in agents_target or (target_x, target_y) == node_positions[start_target[1]] or \
start_y = np.clip(node_positions[start_target[0]][1] + np.random.randint(-node_radius, node_radius), rail_array[(target_x, target_y)] != 0:
0, target_x = np.clip(
width - 1) node_positions[start_target[1]][0] + np.random.randint(-node_radius, node_radius), 0,
target_x = np.clip( height - 1)
node_positions[start_target[1]][0] + np.random.randint(-node_radius, node_radius), 0, target_y = np.clip(
height - 1) node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0,
target_y = np.clip( width - 1)
node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0, agents_position.append((start_x, start_y))
width - 1) agents_target.append((target_x, target_y))
agents_position.append((start_x, start_y))
agents_target.append((target_x, target_y))
new_path = connect_to_nodes(rail_trans, rail_array, agents_position[agent_idx], new_path = connect_to_nodes(rail_trans, rail_array, agents_position[agent_idx],
node_positions[start_target[0]]) node_positions[start_target[0]])
new_path = connect_from_nodes(rail_trans, rail_array, node_positions[start_target[1]], new_path = connect_from_nodes(rail_trans, rail_array, node_positions[start_target[1]],
......
...@@ -10,9 +10,9 @@ def test_sparse_rail_generator(): ...@@ -10,9 +10,9 @@ def test_sparse_rail_generator():
env = RailEnv(width=20, env = RailEnv(width=20,
height=20, height=20,
rail_generator=sparse_rail_generator(nr_train_stations=3, nr_nodes=2, min_node_dist=5, rail_generator=sparse_rail_generator(nr_nodes=5, min_node_dist=8,
node_radius=4), node_radius=3),
number_of_agents=3, number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv()) obs_builder_object=GlobalObsForRailEnv())
# reset to initialize agents_static # reset to initialize agents_static
env_renderer = RenderTool(env, gl="PILSVG", ) env_renderer = RenderTool(env, gl="PILSVG", )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment