diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 71a185c765bcab831e7b104124a164bcf2398b14..a99009a5d5e4f564c15efe981d76b2c0c6cebdfa 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -1,9 +1,9 @@ import numpy as np -from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.schedule_generators import sparse_schedule_generator from flatland.utils.rendertools import RenderTool @@ -13,27 +13,35 @@ np.random.seed(1) # Training on simple small tasks is the best way to get familiar with the environment # Use a the malfunction generator to break agents from time to time -stochastic_data = {'prop_malfunction': 0.5, # Percentage of defective agents +stochastic_data = {'prop_malfunction': 0.1, # Percentage of defective agents 'malfunction_rate': 30, # Rate of malfunction occurence 'min_duration': 3, # Minimal duration of malfunction - 'max_duration': 10 # Max duration of malfunction + 'max_duration': 20 # Max duration of malfunction } +# Custom observation builder TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) -env = RailEnv(width=20, - height=20, - rail_generator=sparse_rail_generator(num_cities=2, # Number of cities in map (where train stations are) - num_intersections=1, # Number of intersections (no start / target) + +# Different agent types (trains) with different speeds. +speed_ration_map = {1.: 0.25, # Fast passenger train + 1. / 2.: 0.25, # Fast freight train + 1. / 3.: 0.25, # Slow commuter train + 1. / 4.: 0.25} # Slow freight train + +env = RailEnv(width=50, + height=50, + rail_generator=sparse_rail_generator(num_cities=20, # Number of cities in map (where train stations are) + num_intersections=5, # Number of intersections (no start / target) num_trainstations=15, # Number of possible start/targets on map min_node_dist=3, # Minimal distance of nodes - node_radius=3, # Proximity of stations to city center - num_neighb=2, # Number of connections to other cities/intersections + node_radius=2, # Proximity of stations to city center + num_neighb=4, # Number of connections to other cities/intersections seed=15, # Random seed realistic_mode=True, enhance_intersection=True ), - schedule_generator=sparse_schedule_generator(), - number_of_agents=5, + schedule_generator=sparse_schedule_generator(speed_ration_map), + number_of_agents=10, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=TreeObservation) @@ -83,10 +91,6 @@ action_dict = dict() print("Start episode...") # Reset environment and get initial observations for all agents obs = env.reset() -# Update/Set agent's speed -for idx in range(env.get_num_agents()): - speed = 1.0 / ((idx % 5) + 1.0) - env.agents[idx].speed_data["speed"] = speed # Reset the rendering sytem env_renderer.reset() diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 4c4070088c59499f885c16db68976c163ec91001..b228e10b6c146f5692166e179bb9f574a68c9134 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -30,10 +30,11 @@ class EnvAgentStatic(object): lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0}))) @classmethod - def from_lists(cls, positions, directions, targets, speeds=None): + def from_lists(cls, positions, directions, targets, speeds=None, malfunction_rates=None): """ Create a list of EnvAgentStatics from lists of positions, directions and targets """ speed_datas = [] + for i in range(len(positions)): speed_datas.append({'position_fraction': 0.0, 'speed': speeds[i] if speeds is not None else 1.0, @@ -41,10 +42,11 @@ class EnvAgentStatic(object): # TODO: on initialization, all agents are re-set as non-broken. Perhaps it may be desirable to set # some as broken? + malfunction_datas = [] for i in range(len(positions)): malfunction_datas.append({'malfunction': 0, - 'malfunction_rate': 0, + 'malfunction_rate': malfunction_rates[i] if malfunction_rates is not None else 0., 'next_malfunction': 0, 'nr_malfunctions': 0}) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 62efbdc5bc0781c4c7482412dafd98710ed9d14e..280fd345d8c1db206c42dc30ba2d7b5fa2e8a69e 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -238,7 +238,6 @@ class RailEnv(Environment): agents_hints = optionals['agents_hints'] self.agents_static = EnvAgentStatic.from_lists( *self.schedule_generator(self.rail, self.get_num_agents(), hints=agents_hints)) - self.restart_agents() for i_agent in range(self.get_num_agents()): @@ -248,7 +247,6 @@ class RailEnv(Environment): if np.random.random() < self.proportion_malfunctioning_trains: agent.malfunction_data['malfunction_rate'] = self.mean_malfunction_rate - agent.speed_data['position_fraction'] = 0.0 agent.malfunction_data['malfunction'] = 0 self._agent_malfunction(agent) @@ -510,11 +508,9 @@ class RailEnv(Environment): grid_data = self.rail.grid.tolist() agent_static_data = [agent.to_list() for agent in self.agents_static] agent_data = [agent.to_list() for agent in self.agents] - - msgpack.packb(grid_data) - msgpack.packb(agent_data) - msgpack.packb(agent_static_data) - + msgpack.packb(grid_data, use_bin_type=True) + msgpack.packb(agent_data, use_bin_type=True) + msgpack.packb(agent_static_data, use_bin_type=True) msg_data = { "grid": grid_data, "agents_static": agent_static_data, @@ -528,11 +524,11 @@ class RailEnv(Environment): return msgpack.packb(msg_data, use_bin_type=True) def set_full_state_msg(self, msg_data): - data = msgpack.unpackb(msg_data, use_list=False) - self.rail.grid = np.array(data[b"grid"]) + data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') + self.rail.grid = np.array(data["grid"]) # agents are always reset as not moving - self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] - self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]] + self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]] + self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]] # setup with loaded data self.height, self.width = self.rail.grid.shape self.rail.height = self.height @@ -540,13 +536,13 @@ class RailEnv(Environment): self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) def set_full_state_dist_msg(self, msg_data): - data = msgpack.unpackb(msg_data, use_list=False) - self.rail.grid = np.array(data[b"grid"]) + data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') + self.rail.grid = np.array(data["grid"]) # agents are always reset as not moving - self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] - self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]] - if hasattr(self.obs_builder, 'distance_map') and b"distance_maps" in data.keys(): - self.obs_builder.distance_map = data[b"distance_maps"] + self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]] + self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]] + if hasattr(self.obs_builder, 'distance_map') and "distance_maps" in data.keys(): + self.obs_builder.distance_map = data["distance_maps"] # setup with loaded data self.height, self.width = self.rail.grid.shape self.rail.height = self.height @@ -557,13 +553,12 @@ class RailEnv(Environment): grid_data = self.rail.grid.tolist() agent_static_data = [agent.to_list() for agent in self.agents_static] agent_data = [agent.to_list() for agent in self.agents] - - msgpack.packb(grid_data) - msgpack.packb(agent_data) - msgpack.packb(agent_static_data) + msgpack.packb(grid_data, use_bin_type=True) + msgpack.packb(agent_data, use_bin_type=True) + msgpack.packb(agent_static_data, use_bin_type=True) if hasattr(self.obs_builder, 'distance_map'): distance_map_data = self.obs_builder.distance_map - msgpack.packb(distance_map_data) + msgpack.packb(distance_map_data, use_bin_type=True) msg_data = { "grid": grid_data, "agents_static": agent_static_data, diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 40ec2e0df89a3de48bf0a2a4430de3e30fa556e5..b69d80a065670ebafb60da2972d31efb05add047 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -573,9 +573,14 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int) y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int) + fraction = 0 + city_fraction = num_cities / tot_num_node + step = np.gcd(num_intersections, num_cities) / tot_num_node + for node_idx in range(num_cities + num_intersections): to_close = True tries = 0 + if not realistic_mode: while to_close: x_tmp = node_radius + np.random.randint(height - node_radius) @@ -587,7 +592,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: to_close = True - # CHeck distance to intersections + # Check distance to intersections for node_pos in intersection_positions: if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: to_close = True @@ -603,13 +608,13 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 warnings.warn("Could not set nodes, please change initial parameters!!!!") break else: + fraction = (fraction + step) % 1. x_tmp = x_positions[node_idx % nodes_per_row] y_tmp = y_positions[node_idx // nodes_per_row] - if len(city_positions) < num_cities and (node_idx % (tot_num_node // num_cities)) == 0: + if len(city_positions) < num_cities and fraction < city_fraction: city_positions.append((x_tmp, y_tmp)) else: intersection_positions.append((x_tmp, y_tmp)) - node_positions = city_positions + intersection_positions # Chose node connection @@ -627,14 +632,15 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 current_node = node_stack[0] delete_idx = np.where(available_nodes_full == current_node) available_nodes_full = np.delete(available_nodes_full, delete_idx, 0) + # Priority city to intersection connections - if current_node < num_cities and len(available_intersections) > 0: + if False and current_node < num_cities and len(available_intersections) > 0: available_nodes = available_intersections delete_idx = np.where(available_cities == current_node) available_cities = np.delete(available_cities, delete_idx, 0) # Priority intersection to city connections - elif current_node >= num_cities and len(available_cities) > 0: + elif False and current_node >= num_cities and len(available_cities) > 0: available_nodes = available_cities delete_idx = np.where(available_intersections == current_node) available_intersections = np.delete(available_intersections, delete_idx, 0) @@ -672,8 +678,8 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 if num_cities > 1: train_stations = [[] for i in range(num_cities)] built_num_trainstation = 0 - spot_found = True for station in range(num_trainstations): + spot_found = True trainstation_node = int(station / num_trainstations * num_cities) station_x = np.clip(node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius), @@ -699,7 +705,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 if tries > 100: warnings.warn("Could not set trainstations, please change initial parameters!!!!") spot_found = False - break + if spot_found: train_stations[trainstation_node].append((station_x, station_y)) @@ -708,12 +714,12 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 (station_x, station_y)) # Check if connection was made if len(connection) == 0: - train_stations[trainstation_node].pop(-1) + if len(train_stations[trainstation_node]) > 0: + train_stations[trainstation_node].pop(-1) else: built_num_trainstation += 1 # Adjust the number of agents if you could not build enough trainstations - if num_agents > built_num_trainstation: num_agents = built_num_trainstation warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents") diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index 4843e0040d80b79de54e8ed57674a37884ef6809..a3a6dc1e4813a8c41284d46b6c03c5898cdcc63e 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -60,7 +60,10 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None): train_stations = hints['train_stations'] agent_start_targets_nodes = hints['agent_start_targets_nodes'] - num_agents = hints['num_agents'] + max_num_agents = hints['num_agents'] + if num_agents > max_num_agents: + num_agents = max_num_agents + warnings.warn("Too many agents! Changes number of agents.") # Place agents and targets within available train stations agents_position = [] agents_target = [] @@ -191,7 +194,8 @@ def random_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> if len(valid_starting_directions) == 0: update_agents[i] = 1 - warnings.warn("reset position for agent[{}]: {} -> {}".format(i, agents_position[i], agents_target[i])) + warnings.warn( + "reset position for agent[{}]: {} -> {}".format(i, agents_position[i], agents_target[i])) re_generate = True break else: @@ -221,15 +225,25 @@ def schedule_from_file(filename) -> ScheduleGenerator: def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct: with open(filename, "rb") as file_in: load_data = file_in.read() - data = msgpack.unpackb(load_data, use_list=False) + data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8') # agents are always reset as not moving - agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] + if len(data['agents_static'][0]) > 5: + print(len(data['agents_static'][0])) + agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3], d[4], d[5]) for d in data["agents_static"]] + else: + agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3]) for d in data["agents_static"]] + # setup with loaded data agents_position = [a.position for a in agents_static] agents_direction = [a.direction for a in agents_static] agents_target = [a.target for a in agents_static] - - return agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + if len(data['agents_static'][0]) > 5: + agents_speed = [a.speed_data['speed'] for a in agents_static] + agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents_static] + else: + agents_speed = None + agents_malfunction = None + return agents_position, agents_direction, agents_target, agents_speed, agents_malfunction return generator diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py index f2dac1c8705a651e2a7be026b4bd82a961efbbdd..92f1a49d777552b2108aff3963aa5c1bc84fdbfc 100644 --- a/flatland/evaluators/client.py +++ b/flatland/evaluators/client.py @@ -254,14 +254,12 @@ class FlatlandRemoteClient(object): if __name__ == "__main__": remote_client = FlatlandRemoteClient() - def my_controller(obs, _env): _action = {} for _idx, _ in enumerate(_env.agents): _action[_idx] = np.random.randint(0, 5) return _action - my_observation_builder = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()) diff --git a/flatland/utils/simple_rail.py b/flatland/utils/simple_rail.py index 67bd93dd35c8f53ef3cdef23dbae0f0d785b9a64..3dfa6bad0d05d8519b233422593cd7f9c2c460f3 100644 --- a/flatland/utils/simple_rail.py +++ b/flatland/utils/simple_rail.py @@ -81,6 +81,7 @@ def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]: rail.grid = rail_map return rail, rail_map + def make_simple_rail_unconnected() -> Tuple[GridTransitionMap, np.array]: # We instantiate a very simple rail network on a 7x10 grid: # Note that that cells have invalid RailEnvTransitions! @@ -103,13 +104,13 @@ def make_simple_rail_unconnected() -> Tuple[GridTransitionMap, np.array]: horizontal_straight = transitions.rotate_transition(vertical_straight, 90) simple_switch_north_left = cells[2] simple_switch_north_right = cells[10] - simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270) + # simple_switch_east_west_north = transitions.rotate_transition(simple_switch_north_right, 270) simple_switch_east_west_south = transitions.rotate_transition(simple_switch_north_left, 270) rail_map = np.array( [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + - [[empty] * 3 + [vertical_straight] + [empty] * 6] + - [[empty] * 3 + [dead_end_from_north] + [empty] * 6] + - [[dead_end_from_east] + [horizontal_straight] * 5 + [simple_switch_east_west_south] + + [[empty] * 3 + [vertical_straight] + [empty] * 6] + + [[empty] * 3 + [dead_end_from_north] + [empty] * 6] + + [[dead_end_from_east] + [horizontal_straight] * 5 + [simple_switch_east_west_south] + [horizontal_straight] * 2 + [dead_end_from_west]] + [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py index 8812c847e61d81f6614f37d26489b4c17ea7fd14..1137b8816973c12601029543c221810c9acd157c 100644 --- a/tests/test_flatland_core_transition_map.py +++ b/tests/test_flatland_core_transition_map.py @@ -116,18 +116,18 @@ def test_path_exists(rendering=False): check_path( env, rail, - (1,3), # east dead-end + (1, 3), # east dead-end 2, # south - (3,3), # north dead-end + (3, 3), # north dead-end True ) check_path( env, rail, - (1,3), # east dead-end + (1, 3), # east dead-end 0, # north - (3,3), # north dead-end + (3, 3), # north dead-end True )