Skip to content
Snippets Groups Projects
Commit 3504c01f authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

finally fixed all direction errors with the x-axis...

parent 9b69afb6
No related branches found
No related tags found
No related merge requests found
......@@ -32,14 +32,14 @@ speed_ration_map = {1.: 0.25, # Fast passenger train
env = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(num_cities=11, # Number of cities in map (where train stations are)
num_trainstations=50, # Number of possible start/targets on map
rail_generator=sparse_rail_generator(num_cities=9, # Number of cities in map (where train stations are)
num_trainstations=0, # Number of possible start/targets on map
min_node_dist=8, # Minimal distance of nodes
node_radius=3, # Proximity of stations to city center
seed=15, # Random seed
grid_mode=False,
grid_mode=True,
max_connection_points_per_side=2,
max_nr_connection_directions=2,
max_nr_connection_directions=4
),
schedule_generator=sparse_schedule_generator(),
number_of_agents=50,
......
......@@ -293,7 +293,7 @@ def distance_on_rail(pos1, pos2):
return np.sqrt(np.power(pos1[0] - pos2[0], 2) + np.power(pos1[1] - pos2[1], 2))
def closest_direction(pos1, pos2):
def direction_to_point(pos1, pos2):
"""
Returns the closest direction orientation of position 2 relative to position 1
:param pos1: position we are interested in
......@@ -305,9 +305,9 @@ def closest_direction(pos1, pos2):
direction = np.sign(diff_vec[axis])
if axis == 0:
if direction > 0:
return 2
else:
return 0
else:
return 2
else:
if direction > 0:
return 3
......
......@@ -6,7 +6,7 @@ import msgpack
import numpy as np
from flatland.core.grid.grid4_utils import get_direction, mirror
from flatland.core.grid.grid_utils import distance_on_rail, closest_direction
from flatland.core.grid.grid_utils import distance_on_rail, direction_to_point
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes
......@@ -701,7 +701,8 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
# TODO: Change the way this code works! Check that we get sufficient direction.
# TODO: Check if this works as expected
while len(connection_sides_idx) < max_nr_connection_directions and idx < len(neighb_dist):
current_closest_direction = closest_direction(node_position, node_positions[closest_neighb_idx[idx]])
current_closest_direction = direction_to_point(node_position, node_positions[closest_neighb_idx[idx]])
print(node_position)
if current_closest_direction not in connection_sides_idx:
connection_sides_idx.append(current_closest_direction)
idx += 1
......@@ -711,7 +712,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
connections_per_direction = np.zeros(4, dtype=int)
for idx in connection_sides_idx:
nr_of_connection_points = np.random.randint(1, max_nr_connection_points + 1)
nr_of_connection_points = max_nr_connection_points # np.random.randint(1, max_nr_connection_points + 1)
connections_per_direction[idx] = nr_of_connection_points
connection_points_coordinates = []
......@@ -722,13 +723,13 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
for connection_idx in range(connections_per_direction[direction]):
if direction == 0:
connection_points_coordinates.append(
(node_position[0] + node_size, node_position[1] + connection_slots[connection_idx]))
(node_position[0] - node_size, node_position[1] + connection_slots[connection_idx]))
if direction == 1:
connection_points_coordinates.append(
(node_position[0] + connection_slots[connection_idx], node_position[1] + node_size))
if direction == 2:
connection_points_coordinates.append(
(node_position[0] - node_size, node_position[1] + connection_slots[connection_idx]))
(node_position[0] + node_size, node_position[1] + connection_slots[connection_idx]))
if direction == 3:
connection_points_coordinates.append(
(node_position[0] + connection_slots[connection_idx], node_position[1] - node_size))
......@@ -737,8 +738,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
connection_info.append(connections_per_direction)
return connection_points, connection_info
def _connect_cities(node_positions, connection_points, connection_info, rail_trans,
grid_map):
def _connect_cities(node_positions, connection_points, connection_info, rail_trans, grid_map):
"""
Function to connect the different cities through their connection points
:param node_positions: Positions of city centers
......@@ -755,7 +755,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
if nbr_connection_points > 0:
neighb_idx = _closest_neigh_in_direction(current_node, direction, node_positions)
print(current_node, node_positions[current_node], direction, neighb_idx,
connection_info[current_node])
connection_info[current_node], connection_points[current_node])
else:
direction += 1
continue
......@@ -914,17 +914,17 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
def _closest_neigh_in_direction(current_node, direction, node_positions):
# Sort available neighbors according to their distance.
available_nodes = np.arange(len(node_positions))
node_dist = []
for av_node in available_nodes:
for av_node in range(len(node_positions)):
node_dist.append(distance_on_rail(node_positions[current_node], node_positions[av_node]))
sorted_neighbours = available_nodes[np.argsort(node_dist)]
sorted_neighbours = np.argsort(node_dist)
for neighb in sorted_neighbours[1:]:
distance_0 = np.abs(node_positions[current_node][0] - node_positions[neighb][0])
distance_1 = np.abs(node_positions[current_node][1] - node_positions[neighb][1])
if direction == 0:
if node_positions[neighb][0] > node_positions[current_node][0] and distance_1 <= distance_0:
if node_positions[neighb][0] < node_positions[current_node][0] and distance_1 <= distance_0:
return neighb
if direction == 1:
......@@ -932,7 +932,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
return neighb
if direction == 2:
if node_positions[neighb][0] < node_positions[current_node][0] and distance_1 <= distance_0:
if node_positions[neighb][0] > node_positions[current_node][0] and distance_1 <= distance_0:
return neighb
if direction == 3:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment