Skip to content
Snippets Groups Projects
Commit ba13fac0 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

updated start orientation. currently agents start orientation depends on the track they start on.

parent 90fa3a42
No related branches found
No related tags found
No related merge requests found
import time
import numpy as np
from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
......@@ -39,7 +37,7 @@ env = RailEnv(width=50,
max_tracks_in_city=4,
),
schedule_generator=sparse_schedule_generator(),
number_of_agents=10,
number_of_agents=50,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=GlobalObsForRailEnv())
......@@ -113,7 +111,6 @@ for step in range(500):
# reward and whether their are done
next_obs, all_rewards, done, _ = env.step(action_dict)
env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
time.sleep(100)
frame_step += 1
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
......
......@@ -575,7 +575,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
print("City position time", time.time() - node_time_start, "Seconds")
# Set up connection points for all cities
node_connection_time = time.time()
inner_connection_points, outer_connection_points, connection_info = _generate_node_connection_points(
inner_connection_points, outer_connection_points, connection_info, city_orientations = _generate_node_connection_points(
node_positions, node_radius, max_inter_city_rails_allowed,
max_tracks_in_city)
print("Connection points", time.time() - node_connection_time)
......@@ -594,8 +594,10 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
print("City build time", time.time() - city_build_time)
# Populate cities
train_station_time = time.time()
train_stations, built_num_trainstation = _set_trainstation_positions(node_positions, through_tracks,
node_radius, grid_map)
train_stations, track_numbers, built_num_trainstation = _set_trainstation_positions(node_positions,
city_orientations,
through_tracks,
node_radius, grid_map)
print("Trainstation placing time", time.time() - train_station_time)
# Adjust the number of agents if you could not build enough trainstations
......@@ -615,7 +617,9 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
return grid_map, {'agents_hints': {
'num_agents': num_agents,
'agent_start_targets_nodes': agent_start_targets_nodes,
'train_stations': train_stations
'train_stations': train_stations,
'city_orientations': city_orientations,
'track_numbers': track_numbers
}}
def _generate_random_node_positions(nb_nodes, node_radius, height, width):
......@@ -669,6 +673,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
inner_connection_points = []
outer_connection_points = []
connection_info = []
city_orientations = []
for node_position in node_positions:
# Chose the directions where close cities are situated
......@@ -683,7 +688,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
current_closest_direction = direction_to_point(node_position, node_positions[closest_neighb_idx[idx]])
connection_sides_idx.append(current_closest_direction)
connection_sides_idx.append((current_closest_direction + 2) % 4)
city_orientations.append(current_closest_direction)
# set the number of tracks within a city, at least 2 tracks per city
connections_per_direction = np.zeros(4, dtype=int)
nr_of_connection_points = np.random.randint(2, tracks_in_city + 1)
......@@ -716,7 +721,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
inner_connection_points.append(connection_points_coordinates_inner)
outer_connection_points.append(connection_points_coordinates_outer)
connection_info.append(connections_per_direction)
return inner_connection_points, outer_connection_points, connection_info
return inner_connection_points, outer_connection_points, connection_info, city_orientations
def _connect_cities(node_positions, connection_points, connection_info, city_cells,
rail_trans, grid_map):
......@@ -792,7 +797,6 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
break
opposite_boarder = (boarder + 2) % 4
track_direction = opposite_boarder
boarder_one = inner_connection_points[current_city][boarder]
boarder_two = inner_connection_points[current_city][opposite_boarder]
......@@ -820,7 +824,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
return through_path_cells
def _set_trainstation_positions(node_positions, through_tracks, node_radius, grid_map):
def _set_trainstation_positions(node_positions, city_orientations, through_tracks, node_radius, grid_map):
"""
:param node_positions:
......@@ -829,6 +833,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
"""
nb_nodes = len(node_positions)
train_stations = [[] for i in range(nb_nodes)]
train_station_orientations = [[] for i in range(nb_nodes)]
built_num_trainstations = 0
for current_city in range(len(node_positions)):
for possible_location in _city_cells(node_positions[current_city], node_radius - 1):
......@@ -841,8 +846,11 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
cell_type = cell_type >> 1
if 1 <= nbits <= 2:
built_num_trainstations += 1
track_nbr = _track_number(node_positions[current_city], city_orientations[current_city],
possible_location)
train_stations[current_city].append(possible_location)
return train_stations, built_num_trainstations
train_station_orientations[current_city].append(track_nbr)
return train_stations, train_station_orientations, built_num_trainstations
def _generate_start_target_pairs(num_agents, nb_nodes, train_stations):
"""
......@@ -970,4 +978,17 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
def _city_overlap(center_1, center_2, radius):
return (np.abs(center_1[0] - center_2[0]) < radius and np.abs(center_1[1] - center_2[1]) < radius)
def _track_number(city_position, city_orientation, position):
"""
FUnction that tells you if you are on even or uneven track number
:param city_position:
:param city_orientation:
:param position:
:return:
"""
if city_orientation % 2 == 0:
return np.abs(city_position[1] - position[1]) % 2
else:
return np.abs(city_position[0] - position[0]) % 2
return generator
......@@ -62,6 +62,8 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) ->
train_stations = hints['train_stations']
agent_start_targets_nodes = hints['agent_start_targets_nodes']
max_num_agents = hints['num_agents']
city_orientations = hints['city_orientations']
track_numbers = hints['track_numbers']
if num_agents > max_num_agents:
num_agents = max_num_agents
warnings.warn("Too many agents! Changes number of agents.")
......@@ -89,6 +91,7 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) ->
current_start_node = agent_start_targets_nodes[agent_idx][0]
start_station_idx = np.random.randint(len(train_stations[current_start_node]))
start = train_stations[current_start_node][start_station_idx]
current_track_nbr = track_numbers[current_start_node][start_station_idx]
tries = 0
while (start[0], start[1]) in agents_position:
tries += 1
......@@ -97,15 +100,16 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) ->
break
start_station_idx = np.random.randint(len(train_stations[current_start_node]))
start = train_stations[current_start_node][start_station_idx]
current_track_nbr = track_numbers[current_start_node][start_station_idx]
agents_position.append((start[0], start[1]))
# Orient the agent correctly
for orientation in range(4):
transitions = rail.get_transitions(start[0], start[1], orientation)
if any(transitions) > 0 and rail.check_path_exists(start, orientation, target):
agents_direction.append(orientation)
break
if current_track_nbr % 2 != 0:
agents_direction.append(city_orientations[current_start_node])
else:
agents_direction.append((city_orientations[current_start_node] + 2) % 2)
if speed_ratio_map:
speeds = speed_initialization_helper(num_agents, speed_ratio_map)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment