Skip to content
Snippets Groups Projects
Commit ccd73a5b authored by u214892's avatar u214892
Browse files

merge #147 -> #141 schedule_generator for sparse_rail_generator

parent c6b273c5
No related branches found
No related tags found
No related merge requests found
import numpy as np
from flatland.envs.generators import sparse_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.schedule_generators import sparse_rail_generator_agents_placer
from flatland.utils.rendertools import RenderTool
np.random.seed(1)
......@@ -31,6 +32,7 @@ env = RailEnv(width=20,
realistic_mode=True,
enhance_intersection=True
),
agent_generator=sparse_rail_generator_agents_placer(),
number_of_agents=5,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=TreeObservation)
......@@ -75,7 +77,6 @@ class RandomAgent:
# Set action space to 4 to remove stop action
agent = RandomAgent(218, 4)
# Empty dictionary for all agent action
action_dict = dict()
......
......@@ -786,48 +786,9 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
else:
num_agents -= 1
# Place agents and targets within available train stations
agents_position = []
agents_target = []
agents_direction = []
for agent_idx in range(num_agents):
# Set target for agent
current_target_node = agent_start_targets_nodes[agent_idx][1]
target_station_idx = np.random.randint(len(train_stations[current_target_node]))
target = train_stations[current_target_node][target_station_idx]
tries = 0
while (target[0], target[1]) in agents_target:
target_station_idx = np.random.randint(len(train_stations[current_target_node]))
target = train_stations[current_target_node][target_station_idx]
tries += 1
if tries > 100:
warnings.warn("Could not set target position, removing an agent")
break
agents_target.append((target[0], target[1]))
# Set start for agent
current_start_node = agent_start_targets_nodes[agent_idx][0]
start_station_idx = np.random.randint(len(train_stations[current_start_node]))
start = train_stations[current_start_node][start_station_idx]
tries = 0
while (start[0], start[1]) in agents_position:
tries += 1
if tries > 100:
warnings.warn("Could not set start position, please change initial parameters!!!!")
break
start_station_idx = np.random.randint(len(train_stations[current_start_node]))
start = train_stations[current_start_node][start_station_idx]
agents_position.append((start[0], start[1]))
# Orient the agent correctly
for orientation in range(4):
transitions = grid_map.get_transitions(start[0], start[1], orientation)
if any(transitions) > 0:
agents_direction.append(orientation)
continue
return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
return grid_map, {'agents_hints': {
'agent_start_targets_nodes': agent_start_targets_nodes,
'train_stations': train_stations
}}
return generator
"""Schedule generators (railway undertaking, "EVU")."""
import warnings
from typing import Tuple, List, Callable, Mapping, Optional, Any
import msgpack
......@@ -55,6 +56,61 @@ def complex_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float]
return generator
def sparse_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator:
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None):
train_stations = hints['train_stations']
agent_start_targets_nodes = hints['agent_start_targets_nodes']
# Place agents and targets within available train stations
agents_position = []
agents_target = []
agents_direction = []
for agent_idx in range(num_agents):
# Set target for agent
current_target_node = agent_start_targets_nodes[agent_idx][1]
target_station_idx = np.random.randint(len(train_stations[current_target_node]))
target = train_stations[current_target_node][target_station_idx]
tries = 0
while (target[0], target[1]) in agents_target:
target_station_idx = np.random.randint(len(train_stations[current_target_node]))
target = train_stations[current_target_node][target_station_idx]
tries += 1
if tries > 100:
warnings.warn("Could not set target position, removing an agent")
break
agents_target.append((target[0], target[1]))
# Set start for agent
current_start_node = agent_start_targets_nodes[agent_idx][0]
start_station_idx = np.random.randint(len(train_stations[current_start_node]))
start = train_stations[current_start_node][start_station_idx]
tries = 0
while (start[0], start[1]) in agents_position:
tries += 1
if tries > 100:
warnings.warn("Could not set start position, please change initial parameters!!!!")
break
start_station_idx = np.random.randint(len(train_stations[current_start_node]))
start = train_stations[current_start_node][start_station_idx]
agents_position.append((start[0], start[1]))
# Orient the agent correctly
for orientation in range(4):
transitions = rail.get_transitions(start[0], start[1], orientation)
if any(transitions) > 0:
agents_direction.append(orientation)
continue
if speed_ratio_map:
speeds = speed_initialization_helper(num_agents, speed_ratio_map)
else:
speeds = [1.0] * len(agents_position)
return agents_position, agents_direction, agents_target, speeds
return generator
def get_rnd_agents_pos_tgt_dir_on_rail(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator:
"""
Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).
......
......@@ -172,7 +172,7 @@ class PILGL(GraphicsLayer):
def text(self, xPx, yPx, strText, layer=RAIL_LAYER):
xyPixLeftTop = (xPx, yPx)
self.draws[layer].text(xyPixLeftTop, strText, font=self.font, fill=(0, 0, 0, 255))
def text_rowcol(self, rcTopLeft, strText, layer=AGENT_LAYER):
print("Text:", "rc:", rcTopLeft, "text:", strText, "layer:", layer)
xyPixLeftTop = tuple((array(rcTopLeft) * self.nPixCell)[[1, 0]])
......@@ -500,9 +500,9 @@ class PILSVG(PILGL):
False)[0]
self.draw_image_row_col(colored_rail, (row, col), layer=PILGL.PREDICTION_PATH_LAYER)
def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None,
show_debug=True):
def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None,
show_debug=True):
if binary_trans in self.pil_rail:
pil_track = self.pil_rail[binary_trans]
if target is not None:
......@@ -510,7 +510,7 @@ class PILSVG(PILGL):
target_img = Image.alpha_composite(pil_track, target_img)
self.draw_image_row_col(target_img, (row, col), layer=PILGL.TARGET_LAYER)
if show_debug:
self.text_rowcol((row+0.8, col+0.0), strText=str(target), layer=PILGL.TARGET_LAYER)
self.text_rowcol((row + 0.8, col + 0.0), strText=str(target), layer=PILGL.TARGET_LAYER)
if binary_trans == 0:
if self.background_grid[col][row] <= 4:
......@@ -607,7 +607,7 @@ class PILSVG(PILGL):
if show_debug:
print("Call text:")
self.text_rowcol((row+0.2, col+0.2,), str(agent_idx))
self.text_rowcol((row + 0.2, col + 0.2,), str(agent_idx))
def set_cell_occupied(self, agent_idx, row, col):
occupied_im = self.cell_occupied[agent_idx % len(self.cell_occupied)]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment