From ccd73a5b994666790f7875f51102193331a4a5e1 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Wed, 28 Aug 2019 13:33:11 +0200 Subject: [PATCH] merge #147 -> #141 schedule_generator for sparse_rail_generator --- examples/flatland_2_0_example.py | 5 ++- flatland/envs/rail_generators.py | 47 ++--------------------- flatland/envs/schedule_generators.py | 56 ++++++++++++++++++++++++++++ flatland/utils/graphics_pil.py | 12 +++--- 4 files changed, 69 insertions(+), 51 deletions(-) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 916e50b2..9f4d62cf 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -1,9 +1,10 @@ import numpy as np - from flatland.envs.generators import sparse_rail_generator + from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.schedule_generators import sparse_rail_generator_agents_placer from flatland.utils.rendertools import RenderTool np.random.seed(1) @@ -31,6 +32,7 @@ env = RailEnv(width=20, realistic_mode=True, enhance_intersection=True ), + agent_generator=sparse_rail_generator_agents_placer(), number_of_agents=5, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=TreeObservation) @@ -75,7 +77,6 @@ class RandomAgent: # Set action space to 4 to remove stop action agent = RandomAgent(218, 4) - # Empty dictionary for all agent action action_dict = dict() diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index d338301c..ed507dca 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -786,48 +786,9 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 else: num_agents -= 1 - # Place agents and targets within available train stations - agents_position = [] - agents_target = [] - agents_direction = [] - - for agent_idx in range(num_agents): - # Set target for agent - current_target_node = agent_start_targets_nodes[agent_idx][1] - target_station_idx = np.random.randint(len(train_stations[current_target_node])) - target = train_stations[current_target_node][target_station_idx] - tries = 0 - while (target[0], target[1]) in agents_target: - target_station_idx = np.random.randint(len(train_stations[current_target_node])) - target = train_stations[current_target_node][target_station_idx] - tries += 1 - if tries > 100: - warnings.warn("Could not set target position, removing an agent") - break - agents_target.append((target[0], target[1])) - - # Set start for agent - current_start_node = agent_start_targets_nodes[agent_idx][0] - start_station_idx = np.random.randint(len(train_stations[current_start_node])) - start = train_stations[current_start_node][start_station_idx] - tries = 0 - while (start[0], start[1]) in agents_position: - tries += 1 - if tries > 100: - warnings.warn("Could not set start position, please change initial parameters!!!!") - break - start_station_idx = np.random.randint(len(train_stations[current_start_node])) - start = train_stations[current_start_node][start_station_idx] - - agents_position.append((start[0], start[1])) - - # Orient the agent correctly - for orientation in range(4): - transitions = grid_map.get_transitions(start[0], start[1], orientation) - if any(transitions) > 0: - agents_direction.append(orientation) - continue - - return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + return grid_map, {'agents_hints': { + 'agent_start_targets_nodes': agent_start_targets_nodes, + 'train_stations': train_stations + }} return generator diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index 50f31378..ef1f9666 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -1,4 +1,5 @@ """Schedule generators (railway undertaking, "EVU").""" +import warnings from typing import Tuple, List, Callable, Mapping, Optional, Any import msgpack @@ -55,6 +56,61 @@ def complex_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float] return generator +def sparse_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: + def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None): + train_stations = hints['train_stations'] + agent_start_targets_nodes = hints['agent_start_targets_nodes'] + # Place agents and targets within available train stations + agents_position = [] + agents_target = [] + agents_direction = [] + for agent_idx in range(num_agents): + # Set target for agent + current_target_node = agent_start_targets_nodes[agent_idx][1] + target_station_idx = np.random.randint(len(train_stations[current_target_node])) + target = train_stations[current_target_node][target_station_idx] + tries = 0 + while (target[0], target[1]) in agents_target: + target_station_idx = np.random.randint(len(train_stations[current_target_node])) + target = train_stations[current_target_node][target_station_idx] + tries += 1 + if tries > 100: + warnings.warn("Could not set target position, removing an agent") + break + agents_target.append((target[0], target[1])) + + # Set start for agent + current_start_node = agent_start_targets_nodes[agent_idx][0] + start_station_idx = np.random.randint(len(train_stations[current_start_node])) + start = train_stations[current_start_node][start_station_idx] + tries = 0 + while (start[0], start[1]) in agents_position: + tries += 1 + if tries > 100: + warnings.warn("Could not set start position, please change initial parameters!!!!") + break + start_station_idx = np.random.randint(len(train_stations[current_start_node])) + start = train_stations[current_start_node][start_station_idx] + + agents_position.append((start[0], start[1])) + + # Orient the agent correctly + for orientation in range(4): + transitions = rail.get_transitions(start[0], start[1], orientation) + if any(transitions) > 0: + agents_direction.append(orientation) + continue + + if speed_ratio_map: + speeds = speed_initialization_helper(num_agents, speed_ratio_map) + else: + speeds = [1.0] * len(agents_position) + + return agents_position, agents_direction, agents_target, speeds + + return generator + + def get_rnd_agents_pos_tgt_dir_on_rail(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator: """ Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target). diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index 6a0a9282..92a0f84f 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -172,7 +172,7 @@ class PILGL(GraphicsLayer): def text(self, xPx, yPx, strText, layer=RAIL_LAYER): xyPixLeftTop = (xPx, yPx) self.draws[layer].text(xyPixLeftTop, strText, font=self.font, fill=(0, 0, 0, 255)) - + def text_rowcol(self, rcTopLeft, strText, layer=AGENT_LAYER): print("Text:", "rc:", rcTopLeft, "text:", strText, "layer:", layer) xyPixLeftTop = tuple((array(rcTopLeft) * self.nPixCell)[[1, 0]]) @@ -500,9 +500,9 @@ class PILSVG(PILGL): False)[0] self.draw_image_row_col(colored_rail, (row, col), layer=PILGL.PREDICTION_PATH_LAYER) - def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None, - show_debug=True): - + def set_rail_at(self, row, col, binary_trans, target=None, is_selected=False, rail_grid=None, + show_debug=True): + if binary_trans in self.pil_rail: pil_track = self.pil_rail[binary_trans] if target is not None: @@ -510,7 +510,7 @@ class PILSVG(PILGL): target_img = Image.alpha_composite(pil_track, target_img) self.draw_image_row_col(target_img, (row, col), layer=PILGL.TARGET_LAYER) if show_debug: - self.text_rowcol((row+0.8, col+0.0), strText=str(target), layer=PILGL.TARGET_LAYER) + self.text_rowcol((row + 0.8, col + 0.0), strText=str(target), layer=PILGL.TARGET_LAYER) if binary_trans == 0: if self.background_grid[col][row] <= 4: @@ -607,7 +607,7 @@ class PILSVG(PILGL): if show_debug: print("Call text:") - self.text_rowcol((row+0.2, col+0.2,), str(agent_idx)) + self.text_rowcol((row + 0.2, col + 0.2,), str(agent_idx)) def set_cell_occupied(self, agent_idx, row, col): occupied_im = self.cell_occupied[agent_idx % len(self.cell_occupied)] -- GitLab