Skip to content
Snippets Groups Projects
Commit 5a13cff6 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

updated inner city connections

parent 95ba74c8
No related branches found
No related tags found
No related merge requests found
......@@ -9,7 +9,7 @@ from flatland.core.grid.grid4_utils import get_direction, mirror
from flatland.core.grid.grid_utils import distance_on_rail
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_from_nodes
from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes
RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]]
RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
......@@ -691,9 +691,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
0,
width - 1)
tries = 0
while (station_x, station_y) in train_stations[trainstation_node] \
or (station_x, station_y) == node_positions[trainstation_node] \
or rail_array[(station_x, station_y)] != 0: # noqa: E125
while (station_x, station_y) in train_stations[trainstation_node]:
station_x = np.clip(
node_positions[trainstation_node][0] + np.random.randint(-reduced_node_radius,
......@@ -717,14 +715,13 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
# Connect train station to random nodes
rand_corner_nodes = np.random.choice(range(len(connection_points[trainstation_node])), 2, replace=False)
connection_1 = connect_from_nodes(rail_trans, grid_map,
connection_points[trainstation_node][rand_corner_nodes[0]],
(station_x, station_y))
connection_2 = connect_from_nodes(rail_trans, grid_map,
connection_points[trainstation_node][rand_corner_nodes[1]],
(station_x, station_y))
for corner_node_idx in rand_corner_nodes:
connection = connect_nodes(rail_trans, grid_map,
connection_points[trainstation_node][corner_node_idx],
(station_x, station_y))
grid_map.fix_transitions((station_x, station_y))
# Check if connection was made
if len(connection_1) == 0 and len(connection_2) == 0:
if len(connection) == 0:
if len(train_stations[trainstation_node]) > 0:
train_stations[trainstation_node].pop(-1)
else:
......@@ -777,7 +774,9 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
num_agents -= 1
return grid_map, {'agents_hints': {
'num_agents': num_agents
'num_agents': num_agents,
'agent_start_targets_nodes': agent_start_targets_nodes,
'train_stations': train_stations_slots
}}
def _generate_node_positions_not_grid_mode(city_positions, height, intersection_positions, nb_nodes,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment