diff --git a/examples/temporary_example.py b/examples/temporary_example.py index 8b53b1ddb7e3558d8a712e14d768245c56032d9a..96acb30f5820a6a729e53b787106bf55775daa16 100644 --- a/examples/temporary_example.py +++ b/examples/temporary_example.py @@ -2,20 +2,22 @@ import random import numpy as np import matplotlib.pyplot as plt -from flatland.core.env import RailEnv -from flatland.utils.rail_env_generator import * +from flatland.envs.rail_env import * +from flatland.core.env_observation_builder import TreeObsForRailEnv from flatland.utils.rendertools import * random.seed(1) np.random.seed(1) + # Example generate a random rail -env = RailEnv(width=20, height=20, rail_generator=generate_random_rail, number_of_agents=10) +env = RailEnv(width=20, height=20, rail_generator=random_rail_generator, number_of_agents=10) env.reset() env_renderer = RenderTool(env) env_renderer.renderEnv(show=True) + # Example generate a rail given a manual specification, # a map of tuples (cell_type, rotation) specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)], @@ -23,13 +25,12 @@ specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)], env = RailEnv(width=6, height=2, - rail_generator=generate_rail_from_manual_specifications(specs), - number_of_agents=1) + rail_generator=rail_from_manual_specifications_generator(specs), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=1)) handle = env.get_agent_handles() -obs = env.reset() - env.agents_position = [[1, 4]] env.agents_target = [[1, 1]] env.agents_direction = [1] @@ -37,13 +38,15 @@ env.agents_direction = [1] env.obs_builder.reset() # TODO: delete next line -#print(env.obs_builder.distance_map[0,:,:]) -#print(env.obs_builder.max_dist) +#for i in range(4): +# print(env.obs_builder.distance_map[0, :, :, i]) + +obs, all_rewards, done, _ = env.step({0:0}) +env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5) env_renderer = RenderTool(env) env_renderer.renderEnv(show=True) - print("Manual control: 
s=perform step, q=quit, [agent id] [1-2-3 action] \ (turnleft+move, move to front, turnright+move)") for step in range(100): diff --git a/flatland/core/env.py b/flatland/core/env.py index 02d912a3aba8c6ee84a2159b248e0631e194d2de..284afdffb6ce46ac481018af469c7d2e024fc792 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -3,10 +3,6 @@ The env module defines the base Environment class. The base Environment class is adapted from rllib.env.MultiAgentEnv (https://github.com/ray-project/ray). """ -import random - -from .env_observation_builder import TreeObsForRailEnv -from flatland.utils.rail_env_generator import generate_random_rail class Environment: @@ -94,327 +90,3 @@ class Environment: function. """ raise NotImplementedError() - - -class RailEnv: - """ - RailEnv environment class. - - RailEnv is an environment inspired by a (simplified version of) a rail - network, in which agents (trains) have to navigate to their target - locations in the shortest time possible, while at the same time cooperating - to avoid bottlenecks. - - The valid actions in the environment are: - 0: do nothing - 1: turn left and move to the next cell - 2: move to the next cell in front of the agent - 3: turn right and move to the next cell - - Moving forward in a dead-end cell makes the agent turn 180 degrees and step - to the cell it came from. - - The actions of the agents are executed in order of their handle to prevent - deadlocks and to allow them to learn relative priorities. - - TODO: WRITE ABOUT THE REWARD FUNCTION, and possibly allow for alpha and - beta to be passed as parameters to __init__(). - """ - - def __init__(self, - width, - height, - rail_generator=generate_random_rail, - number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2)): - """ - Environment init. 
- - Parameters - ------- - rail_generator : function - The rail_generator function is a function that takes the width and - height of a rail map along with the number of times the env has - been reset, and returns a GridTransitionMap object. - Implemented functions are: - generate_random_rail : generate a random rail of given size - TODO: generate_rail_from_saved_list --- - width : int - The width of the rail map. Potentially in the future, - a range of widths to sample from. - height : int - The height of the rail map. Potentially in the future, - a range of heights to sample from. - number_of_agents : int - Number of agents to spawn on the map. Potentially in the future, - a range of number of agents to sample from. - obs_builder_object: ObservationBuilder object - ObservationBuilder-derived object that takes builds observation - vectors for each agent. - """ - - self.rail_generator = rail_generator - self.num_resets = 0 - self.rail = None - self.width = width - self.height = height - - self.number_of_agents = number_of_agents - - self.obs_builder = obs_builder_object - self.obs_builder.set_env(self) - - self.actions = [0]*self.number_of_agents - self.rewards = [0]*self.number_of_agents - self.done = False - - self.agents_position = [] - self.agents_target = [] - self.agents_direction = [] - - self.dones = {"__all__": False} - self.obs_dict = {} - self.rewards_dict = {} - - self.agents_handles = list(range(self.number_of_agents)) - - def get_agent_handles(self): - return self.agents_handles - - def reset(self): - self.rail = self.rail_generator(self.width, self.height, self.num_resets) - self.num_resets += 1 - - self.dones = {"__all__": False} - for handle in self.agents_handles: - self.dones[handle] = False - - re_generate = True - while re_generate: - valid_positions = [] - for r in range(self.height): - for c in range(self.width): - if self.rail.get_transitions((r, c)) > 0: - valid_positions.append((r, c)) - - self.agents_position = 
random.sample(valid_positions, - self.number_of_agents) - self.agents_target = random.sample(valid_positions, - self.number_of_agents) - - # agents_direction must be a direction for which a solution is - # guaranteed. - self.agents_direction = [0]*self.number_of_agents - re_generate = False - for i in range(self.number_of_agents): - valid_movements = [] - for direction in range(4): - position = self.agents_position[i] - moves = self.rail.get_transitions( - (position[0], position[1], direction)) - for move_index in range(4): - if moves[move_index]: - valid_movements.append((direction, move_index)) - - valid_starting_directions = [] - for m in valid_movements: - new_position = self._new_position(self.agents_position[i], - m[1]) - if m[0] not in valid_starting_directions and \ - self._path_exists(new_position, m[0], - self.agents_target[i]): - valid_starting_directions.append(m[0]) - - if len(valid_starting_directions) == 0: - re_generate = True - else: - self.agents_direction[i] = random.sample( - valid_starting_directions, 1)[0] - - # Reset the state of the observation builder with the new environment - self.obs_builder.reset() - - # Return the new observation vectors for each agent - return self._get_observations() - - def step(self, action_dict): - alpha = 1.0 - beta = 1.0 - - invalid_action_penalty = -2 - step_penalty = -1 * alpha - global_reward = 1 * beta - - # Reset the step rewards - self.rewards_dict = {} - for handle in self.agents_handles: - self.rewards_dict[handle] = 0 - - if self.dones["__all__"]: - return self._get_observations(), self.rewards_dict, self.dones, {} - - for i in range(len(self.agents_handles)): - handle = self.agents_handles[i] - - if handle not in action_dict: - continue - - action = action_dict[handle] - - if action < 0 or action > 3: - print('ERROR: illegal action=', action, - 'for agent with handle=', handle) - return - - if action > 0: - pos = self.agents_position[i] - direction = self.agents_direction[i] - - movement = direction - 
if action == 1: - movement = direction - 1 - elif action == 3: - movement = direction + 1 - - if movement < 0: - movement += 4 - if movement >= 4: - movement -= 4 - - is_deadend = False - if action == 2: - # compute number of possible transitions in the current - # cell - nbits = 0 - tmp = self.rail.get_transitions((pos[0], pos[1])) - while tmp > 0: - nbits += (tmp & 1) - tmp = tmp >> 1 - if nbits == 1: - # dead-end; assuming the rail network is consistent, - # this should match the direction the agent has come - # from. But it's better to check in any case. - reverse_direction = 0 - if direction == 0: - reverse_direction = 2 - elif direction == 1: - reverse_direction = 3 - elif direction == 2: - reverse_direction = 0 - elif direction == 3: - reverse_direction = 1 - - valid_transition = self.rail.get_transition( - (pos[0], pos[1], direction), - reverse_direction) - if valid_transition: - direction = reverse_direction - movement = reverse_direction - is_deadend = True - - new_position = self._new_position(pos, movement) - - # Is it a legal move? 
1) transition allows the movement in the - # cell, 2) the new cell is not empty (case 0), 3) the cell is - # free, i.e., no agent is currently in that cell - if new_position[1] >= self.width or\ - new_position[0] >= self.height or\ - new_position[0] < 0 or new_position[1] < 0: - new_cell_isValid = False - - elif self.rail.get_transitions((new_position[0], new_position[1])) > 0: - new_cell_isValid = True - else: - new_cell_isValid = False - - transition_isValid = self.rail.get_transition( - (pos[0], pos[1], direction), - movement) or is_deadend - - cell_isFree = True - for j in range(self.number_of_agents): - if self.agents_position[j] == new_position: - cell_isFree = False - break - - if new_cell_isValid and transition_isValid and cell_isFree: - # move and change direction to face the movement that was - # performed - self.agents_position[i] = new_position - self.agents_direction[i] = movement - else: - # the action was not valid, add penalty - self.rewards_dict[handle] += invalid_action_penalty - - # if agent is not in target position, add step penalty - if self.agents_position[i][0] == self.agents_target[i][0] and \ - self.agents_position[i][1] == self.agents_target[i][1]: - self.dones[handle] = True - else: - self.rewards_dict[handle] += step_penalty - - # Check for end of episode + add global reward to all rewards! 
- num_agents_in_target_position = 0 - for i in range(self.number_of_agents): - if self.agents_position[i][0] == self.agents_target[i][0] and \ - self.agents_position[i][1] == self.agents_target[i][1]: - num_agents_in_target_position += 1 - - if num_agents_in_target_position == self.number_of_agents: - self.dones["__all__"] = True - self.rewards_dict = [r+global_reward for r in self.rewards_dict] - - # Reset the step actions (in case some agent doesn't 'register_action' - # on the next step) - self.actions = [0]*self.number_of_agents - - return self._get_observations(), self.rewards_dict, self.dones, {} - - def _new_position(self, position, movement): - if movement == 0: # NORTH - return (position[0]-1, position[1]) - elif movement == 1: # EAST - return (position[0], position[1] + 1) - elif movement == 2: # SOUTH - return (position[0]+1, position[1]) - elif movement == 3: # WEST - return (position[0], position[1] - 1) - - def _path_exists(self, start, direction, end): - # BFS - Check if a path exists between the 2 nodes - - visited = set() - stack = [(start, direction)] - while stack: - node = stack.pop() - if node[0][0] == end[0] and node[0][1] == end[1]: - return 1 - if node not in visited: - visited.add(node) - moves = self.rail.get_transitions((node[0][0], node[0][1], node[1])) - for move_index in range(4): - if moves[move_index]: - stack.append((self._new_position(node[0], move_index), - move_index)) - - # If cell is a dead-end, append previous node with reversed - # orientation! 
- nbits = 0 - tmp = self.rail.get_transitions((node[0][0], node[0][1])) - while tmp > 0: - nbits += (tmp & 1) - tmp = tmp >> 1 - if nbits == 1: - stack.append((node[0], (node[1] + 2) % 4)) - - return 0 - - def _get_observations(self): - self.obs_dict = {} - for handle in self.agents_handles: - self.obs_dict[handle] = self.obs_builder.get(handle) - return self.obs_dict - - def render(self): - # TODO: - pass diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index f8341e270c3c364109a16b768d994eb1e3dce60c..4d6da5b5608acd5284e681f376a2d53cf1bd535b 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -1,3 +1,13 @@ +""" +ObservationBuilder objects are objects that can be passed to environments designed for customizability. +The ObservationBuilder-derived custom classes implement 2 functions, reset() and get() or get(handle). + ++ Reset() is called after each environment reset, to allow for pre-computing relevant data. + ++ Get() is called whenever an observation has to be computed, potentially for each agent independently in +case of multi-agent environments. +""" + import numpy as np from collections import deque @@ -5,47 +15,82 @@ from collections import deque class ObservationBuilder: - def __init__(self, env): + """ + ObservationBuilder base class. + """ + def __init__(self): + pass + + def _set_env(self, env): self.env = env def reset(self): + """ + Called after each environment reset. + """ raise NotImplementedError() - def get(self, handle): + def get(self, handle=0): + """ + Called whenever an observation has to be computed for the `env' environment, possibly + for each agent independently (agent id `handle'). + + Parameters + ------- + handle : int (optional) + Handle of the agent for which to compute the observation vector. + + Returns + ------- + function + An observation structure, specific to the corresponding environment. 
+ """ raise NotImplementedError() class TreeObsForRailEnv(ObservationBuilder): - def __init__(self, env): - self.env = env + """ + TreeObsForRailEnv object. + + This object returns observation vectors for agents in the RailEnv environment. + The information is local to each agent and exploits the tree structure of the rail + network to simplify the representation of the state of the environment for each agent. + """ + def __init__(self, max_depth): + self.max_depth = max_depth def reset(self): self.distance_map = np.inf * np.ones(shape=(self.env.number_of_agents, self.env.height, - self.env.width)) + self.env.width, + 4)) self.max_dist = np.zeros(self.env.number_of_agents) for i in range(self.env.number_of_agents): self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i) - def _distance_map_walker(self, position, target_nr): + """ + Utility function to compute distance maps from each cell in the rail network (and each possible + orientation within it) to each agent's target cell. 
+ """ # Returns max distance to target, from the farthest away node, while filling in distance_map for ori in range(4): - self.distance_map[target_nr, position[0], position[1]] = 0 + self.distance_map[target_nr, position[0], position[1], ori] = 0 # Fill in the (up to) 4 neighboring nodes # nodes_queue = [] # list of tuples (row, col, direction, distance); - # direction is the direction of movement, meaning that at least a possible orientation - # of an agent in cell (row,col) allows a movement in direction `direction' - nodes_queue = deque(self._get_and_update_neighbors(position, - target_nr, 0, enforce_target_direction=-1)) + # direction is the direction of movement, meaning that at least a possible orientation of an agent + # in cell (row,col) allows a movement in direction `direction' + nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1)) # BFS from target `position' to all the reachable nodes in the grid # Stop the search if the target position is re-visited, in any direction - visited = set([(position[0], position[1], 0), (position[0], position[1], 1), - (position[0], position[1], 2), (position[0], position[1], 3)]) + visited = set([(position[0], position[1], 0), + (position[0], position[1], 1), + (position[0], position[1], 2), + (position[0], position[1], 3)]) max_distance = 0 @@ -54,35 +99,34 @@ class TreeObsForRailEnv(ObservationBuilder): node_id = (node[0], node[1], node[2]) - #print(node_id, visited, (node_id in visited)) - #print(nodes_queue) - if node_id not in visited: visited.add(node_id) - # From the list of possible neighbors that have at least a path to the - # current node, only keep those whose new orientation in the current cell - # would allow a transition to direction node[2] - valid_neighbors = self._get_and_update_neighbors( - (node[0], node[1]), target_nr, node[3], node[2]) + # From the list of possible neighbors that have at least a path to the current node, only keep those + # whose new 
orientation in the current cell would allow a transition to direction node[2] + valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2]) for n in valid_neighbors: nodes_queue.append(n) - if len(valid_neighbors)>0: + if len(valid_neighbors) > 0: max_distance = max(max_distance, node[3]+1) return max_distance - def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1): + """ + Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the + minimum distances from each target cell. + """ neighbors = [] for direction in range(4): - new_cell = self._new_position(position, (direction+2)%4) + new_cell = self._new_position(position, (direction+2) % 4) + + if new_cell[0] >= 0 and new_cell[0] < self.env.height and \ + new_cell[1] >= 0 and new_cell[1] < self.env.width: - if new_cell[0]>=0 and new_cell[0]<self.env.height and\ - new_cell[1]>=0 and new_cell[1]<self.env.width: # Check if the two cells are connected by a valid transition transitionValid = False for orientation in range(4): @@ -94,17 +138,16 @@ class TreeObsForRailEnv(ObservationBuilder): if not transitionValid: continue - # Check if a transition in direction node[2] is possible if an agent - # lands in the current cell with orientation `direction'; this only - # applies to cells that are not dead-ends! + # Check if a transition in direction node[2] is possible if an agent lands in the current + # cell with orientation `direction'; this only applies to cells that are not dead-ends! 
directionMatch = True - if enforce_target_direction>=0: - directionMatch = self.env.rail.get_transition( - (new_cell[0], new_cell[1], direction), enforce_target_direction) + if enforce_target_direction >= 0: + directionMatch = self.env.rail.get_transition((new_cell[0], new_cell[1], direction), + enforce_target_direction) - # If transition is found to invalid, check if perhaps it - # is a dead-end, in which case the direction of movement is rotated - # 180 degrees (moving forward turns the agents and makes it step in the previous cell) + # If transition is found to be invalid, check if perhaps it is a dead-end, in which case the + # direction of movement is rotated 180 degrees (moving forward turns the agents and makes + # it step in the previous cell) if not directionMatch: # If cell is a dead-end, append previous node with reversed # orientation! @@ -115,20 +158,28 @@ class TreeObsForRailEnv(ObservationBuilder): tmp = tmp >> 1 if nbits == 1: # Dead-end! - # Check if transition is possible in new_cell - # with orientation (direction+2)%4 in direction `direction' - directionMatch = directionMatch or self.env.rail.get_transition( - (new_cell[0], new_cell[1], (direction+2)%4), direction) + # Check if transition is possible in new_cell with orientation + # (direction+2)%4 in direction `direction' + directionMatch = directionMatch or \ + self.env.rail.get_transition((new_cell[0], new_cell[1], (direction+2) % 4), + direction) if transitionValid and directionMatch: - new_distance = min(self.distance_map[target_nr, - new_cell[0], new_cell[1]], current_distance+1) - neighbors.append((new_cell[0], new_cell[1], direction, new_distance)) - self.distance_map[target_nr, new_cell[0], new_cell[1]] = new_distance + # Append all possible orientations in new_cell that allow a transition to direction! 
+ for orientation in range(4): + moves = self.env.rail.get_transitions((new_cell[0], new_cell[1], orientation)) + if moves[direction]: + new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], orientation], + current_distance+1) + neighbors.append((new_cell[0], new_cell[1], orientation, new_distance)) + self.distance_map[target_nr, new_cell[0], new_cell[1], orientation] = new_distance return neighbors def _new_position(self, position, movement): + """ + Utility function that converts a compass movement over a 2D grid to new positions (r, c). + """ if movement == 0: # NORTH return (position[0]-1, position[1]) elif movement == 1: # EAST @@ -138,202 +189,187 @@ class TreeObsForRailEnv(ObservationBuilder): elif movement == 3: # WEST return (position[0], position[1] - 1) - def get(self, handle): - # TODO: compute the observation for agent `handle' - return [] + """ + Computes the current observation for agent `handle' in env + The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible + movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv). + The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for + the transitions. The order is: + [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back'] + Each branch data is organized as: + [root node information] + + [recursive branch data from 'left'] + + [... from 'forward'] + + [... from 'right'] + + [... 
from 'back'] -""" + Finally, each node information is composed of 5 floating point values: - def get_observation(self, agent): - # Get the current observation for an agent - current_position = self.internal_position[agent] - #target_heading = self._compass(agent).tolist() - coordinate = tuple(np.transpose(self._position_to_coordinate([current_position]))) - agent_distance = self.distance_map[agent][coordinate][0] - # Start tree search - if current_position == self.target[agent]: - agent_tree = Node(current_position, [-np.inf, -np.inf, -np.inf, -np.inf, -1]) - else: - agent_tree = Node(current_position, [0, 0, 0, 0, agent_distance]) - - initial_tree_state = Tree_State(agent, current_position, -1, 0, 0) - self._tree_search(initial_tree_state, agent_tree, agent) - observation = [] - distance_data = [] + #1: - self._flatten_tree(agent_tree, observation, distance_data, self.max_depth+1) - # This is probably very slow!!!! - #max_obs = np.max([i for i in observation if i < np.inf]) - #if max_obs != 0: - # observation = np.array(observation)/ max_obs + #2: - #print([i for i in distance_data if i >= 0]) - observation = np.concatenate((observation, distance_data)) - #observation = np.concatenate((observation, np.identity(5)[int(self.last_action[agent])])) - #return np.clip(observation / self.max_dist[agent], -1, 1) - return np.clip(observation / 15., -1, 1) + #3: + #4: + #5: minimum distance from node to the agent's target + Missing/padding nodes are filled in with -inf (truncated). + Missing values in present node are filled in with +inf (truncated). 
- def _tree_search(self, in_tree_state, parent_node, agent): - if in_tree_state.depth >= self.max_depth: - return - target_distance = np.inf - other_target = np.inf - other_agent = np.inf - coordinate = tuple(np.transpose(self._position_to_coordinate([in_tree_state.position]))) - curr_target_dist = self.distance_map[agent][coordinate][0] - forbidden_action = (in_tree_state.direction + 2) % 4 - # Update the position - failed_move = 0 - leaf_distance = in_tree_state.distance - for child_idx in range(4): - if child_idx != forbidden_action or in_tree_state.direction == -1: - tree_state = copy.deepcopy(in_tree_state) - tree_state.direction = child_idx - current_position, invalid_move = self._detect_path( - tree_state.position, tree_state.direction) - if tree_state.initial_direction == None: - tree_state.initial_direction = child_idx - if not invalid_move: - coordinate = tuple(np.transpose(self._position_to_coordinate([current_position]))) - curr_target_dist = self.distance_map[agent][coordinate][0] - #if tree_state.initial_direction == None: - # tree_state.initial_direction = child_idx - tree_state.position = current_position - tree_state.distance += 1 - - - # Collect information at the current position - detection_distance = tree_state.distance - if current_position == self.target[tree_state.agent]: - target_distance = detection_distance - - elif current_position in self.target: - other_target = detection_distance - - if current_position in self.internal_position: - other_agent = detection_distance - - tree_state.data[0] = self._min_greater_zero(target_distance, tree_state.data[0]) - tree_state.data[1] = self._min_greater_zero(other_target, tree_state.data[1]) - tree_state.data[2] = self._min_greater_zero(other_agent, tree_state.data[2]) - tree_state.data[3] = tree_state.distance - tree_state.data[4] = self._min_greater_zero(curr_target_dist, tree_state.data[4]) - - if self._switch_detection(tree_state.position): - tree_state.depth += 1 - new_tree_state = 
copy.deepcopy(tree_state) - new_node = parent_node.insert(tree_state.position, - tree_state.data, tree_state.initial_direction) - new_tree_state.initial_direction = None - new_tree_state.data = [np.inf, np.inf, np.inf, np.inf, np.inf] - self._tree_search(new_tree_state, new_node, agent) - else: - self._tree_search(tree_state, parent_node, agent) - else: - failed_move += 1 - if failed_move == 3 and in_tree_state.direction != -1: - tree_state.data[4] = self._min_greater_zero(curr_target_dist, tree_state.data[4]) - parent_node.insert(tree_state.position, tree_state.data, tree_state.initial_direction) - return - return - - def _flatten_tree(self, node, observation_vector, distance_sensor, depth): - if depth <= 0: - return - if node != None: - observation_vector.extend(node.data[:-1]) - distance_sensor.extend([node.data[-1]]) - else: - observation_vector.extend([-np.inf, -np.inf, -np.inf, -np.inf]) - distance_sensor.extend([-np.inf]) - for child_idx in range(4): - if node != None: - child = node.children[child_idx] + + In case of the root node, the values are [0, 0, 0, 0, distance from agent to target]. + In case the target node is reached, the values are [0, 0, 0, 0, 0]. 
+ """ + + position = self.env.agents_position[handle] + orientation = self.env.agents_direction[handle] + + # Root node - current position + observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]] + + # Start from the current orientation, and see which transitions are available; + # organize them as [left, forward, right, back], relative to the current orientation + for branch_direction in [(orientation+4+i) % 4 for i in range(-1, 3)]: + if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction): + new_cell = self._new_position(position, branch_direction) + + branch_observation = self._explore_branch(handle, new_cell, branch_direction, 1) + observation = observation + branch_observation else: - child = None - self._flatten_tree(child, observation_vector, distance_sensor, depth -1) + num_cells_to_fill_in = 0 + pow4 = 1 + for i in range(self.max_depth): + num_cells_to_fill_in += pow4 + pow4 *= 4 + observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf]*num_cells_to_fill_in + return observation + def _explore_branch(self, handle, position, direction, depth): + """ + Utility function to compute tree-based observations. + """ + # [Recursive branch opened] + if depth >= self.max_depth+1: + return [] + + # Continue along direction until next switch or + # until no transitions are possible along the current direction (i.e., dead-ends) + # We treat dead-ends as nodes, instead of going back, to avoid loops + exploring = True + # TODO: last_isSwitch = False + # TODO: last_isTerminal = False # dead-end + # TODO: last_isTarget = False + while exploring: + # ############################# + # ############################# + # Modify here to compute any useful data required to build the end node's features. This code is called + # for each cell visited between the previous branching node and the next switch / target / dead-end. 
+ + # TODO: update the current variables according to the current cell in the path + # (store info about other agents and targets) + + # TODO: [[[for efficiency, [make dict for hashed-lookup of coords] -- do it in the reset function!]]] + + # ############################# + # ############################# + + # If the target node is encountered, pick that as node. Also, no further branching is possible. + if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]: + # TODO: last_isTarget = True + break + + cell_transitions = self.env.rail.get_transitions((position[0], position[1], direction)) + num_transitions = 0 + for i in range(4): + if cell_transitions[i]: + num_transitions += 1 + + exploring = False + if num_transitions == 1: + # Check if dead-end, or if we can go forward along direction + if cell_transitions[direction]: + position = self._new_position(position, direction) + + # Keep walking through the tree along `direction' + exploring = True - def _switch_detection(self, position): - # Hack to detect switches - # This can later directly be derived from the transition matrix - paths = 0 - for i in range(4): - _, invalid_move = self._detect_path(position, i) - if not invalid_move: - paths +=1 - if paths >= 3: - return True - return False + else: + # If a dead-end is reached, pick that as node. Also, no further branching is possible. 
+ # TODO: last_isTerminal = True + break + elif num_transitions > 0: + # Switch detected + # TODO: last_isSwitch = True + break + elif num_transitions == 0: + # Wrong cell type, but let's cover it and treat it as a dead-end, just in case + # TODO: last_isTerminal = True + break + # `position' is either a terminal node or a switch - def _min_greater_zero(self, x, y): - if x <= 0 and y <= 0: - return 0 - if x < 0: - return y - if y < 0: - return x - return min(x, y) + observation = [] + # ############################# + # ############################# + # Modify here to append new / different features for each visited cell! + observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], direction]] + # TODO: -""" + # ############################# + # ############################# + # Start from the current orientation, and see which transitions are available; + # organize them as [left, forward, right, back], relative to the current orientation + for branch_direction in [(direction+4+i) % 4 for i in range(-1, 3)]: + if self.env.rail.get_transition((position[0], position[1], direction), branch_direction): + new_cell = self._new_position(position, branch_direction) -class Tree_State: - """ - Keep track of the current state while building the tree - """ - def __init__(self, agent, position, direction, depth, distance): - self.agent = agent - self.position = position - self.direction = direction - self.depth = depth - self.initial_direction = None - self.distance = distance - self.data = [np.inf, np.inf, np.inf, np.inf, np.inf] - -class Node(): - """ - Define a tree node to get populated during search - """ - def __init__(self, position, data): - self.n_children = 4 - self.children = [None]*self.n_children - self.data = data - self.position = position + branch_observation = self._explore_branch(handle, new_cell, branch_direction, depth+1) + observation = observation + branch_observation - def insert(self, position, data, child_idx): - """ - Insert new 
node with data + else: + num_cells_to_fill_in = 0 + pow4 = 1 + for i in range(self.max_depth-depth): + num_cells_to_fill_in += pow4 + pow4 *= 4 + observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf]*num_cells_to_fill_in - @param data node data object to insert - """ - new_node = Node(position, data) - self.children[child_idx] = new_node - return new_node + return observation - def print_tree(self, i=0, depth = 0): + def util_print_obs_subtree(self, tree, num_elements_per_node=5, prompt='', current_depth=0): """ - Print tree content inorder + Utility function to pretty-print tree observations returned by this object. """ - current_i = i - curr_depth = depth+1 - if i < self.n_children: - if self.children[i] != None: - self.children[i].print_tree(depth=curr_depth) - current_i += 1 - if self.children[i] != None: - self.children[i].print_tree(i, depth=curr_depth) - + if len(tree) < num_elements_per_node: + return + depth = 0 + tmp = len(tree)/num_elements_per_node-1 + pow4 = 4 + while tmp > 0: + tmp -= pow4 + depth += 1 + pow4 *= 4 + + prompt_ = ['L:', 'F:', 'R:', 'B:'] + + print(" "*current_depth + prompt, tree[0:num_elements_per_node]) + child_size = (len(tree)-num_elements_per_node)//4 + for children in range(4): + child_tree = tree[(num_elements_per_node+children*child_size): + (num_elements_per_node+(children+1)*child_size)] + self.util_print_obs_subtree(child_tree, + num_elements_per_node, + prompt=prompt_[children], + current_depth=current_depth+1) diff --git a/flatland/envs/__init__.py b/flatland/envs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py new file mode 100644 index 0000000000000000000000000000000000000000..8b34bf3d4ea51dfd1be57c70297b45bedc7906c3 --- /dev/null +++ b/flatland/envs/rail_env.py @@ -0,0 +1,676 @@ +""" +Definition of the RailEnv environment and related level-generation functions. 
+ +Generator functions are functions that take width, height and num_resets as arguments and return +a GridTransitionMap object. +""" +import random +import numpy as np + +from flatland.core.env import Environment +from flatland.core.env_observation_builder import TreeObsForRailEnv + +from flatland.core.transitions import RailEnvTransitions +from flatland.core.transition_map import GridTransitionMap + + +def rail_from_manual_specifications_generator(rail_spec): + """ + Utility to convert a rail given by manual specification as a map of tuples + (cell_type, rotation), to a transition map with the correct 16-bit + transitions specifications. + + Parameters + ------- + rail_spec : list of list of tuples + List (rows) of lists (columns) of tuples, each specifying a cell for + the RailEnv environment as (cell_type, rotation), with rotation being + clock-wise and in [0, 90, 180, 270]. + + Returns + ------- + function + Generator function that always returns a GridTransitionMap object with + the matrix of correct 16-bit bitmaps for each cell. + """ + def generator(width, height, num_resets=0): + t_utils = RailEnvTransitions() + + height = len(rail_spec) + width = len(rail_spec[0]) + rail = GridTransitionMap(width=width, height=height, transitions=t_utils) + + for r in range(height): + for c in range(width): + cell = rail_spec[r][c] + if cell[0] < 0 or cell[0] >= len(t_utils.transitions): + print("ERROR - invalid cell type=", cell[0]) + return [] + rail.set_transitions((r, c), t_utils.rotate_transition( + t_utils.transitions[cell[0]], cell[1])) + + return rail + + return generator + + +def rail_from_GridTransitionMap_generator(rail_map): + """ + Utility to convert a rail given by a GridTransitionMap map with the correct + 16-bit transitions specifications. + + Parameters + ------- + rail_map : GridTransitionMap object + GridTransitionMap object to return when the generator is called. 
+ + Returns + ------- + function + Generator function that always returns the given `rail_map' object. + """ + def generator(width, height, num_resets=0): + return rail_map + + return generator + + +""" +def generate_rail_from_list_of_manual_specifications(list_of_specifications) + def generator(width, height, num_resets=0): + return generate_rail_from_manual_specifications(list_of_specifications) + + return generator +""" + + +def random_rail_generator(width, height, num_resets=0): + """ + Dummy random level generator: + - fill in cells at random in [width-2, height-2] + - keep filling cells in among the unfilled ones, such that all transitions + are legit; if no cell can be filled in without violating some + transitions, pick one among those that can satisfy most transitions + (1,2,3 or 4), and delete (+mark to be re-filled) the cells that were + incompatible. + - keep trying for a total number of insertions + (e.g., (W-2)*(H-2)*MAX_REPETITIONS ); if no solution is found, empty the + board and try again from scratch. + - finally pad the border of the map with dead-ends to avoid border issues. + + Dead-ends are not allowed inside the grid, only at the border; however, if + no cell type can be inserted in a given cell (because of the neighboring + transitions), deadends are allowed if they solve the problem. This was + found to turn most un-genereatable levels into valid ones. + + Parameters + ------- + width : int + The width (number of cells) of the grid to generate. + height : int + The height (number of cells) of the grid to generate. + + Returns + ------- + numpy.ndarray of type numpy.uint16 + The matrix with the correct 16-bit bitmaps for each cell. 
+ """ + + t_utils = RailEnvTransitions() + + transitions_templates_ = [] + for i in range(len(t_utils.transitions)-1): # don't include dead-ends + all_transitions = 0 + for dir_ in range(4): + trans = t_utils.get_transitions(t_utils.transitions[i], dir_) + all_transitions |= (trans[0] << 3) | \ + (trans[1] << 2) | \ + (trans[2] << 1) | \ + (trans[3]) + + template = [int(x) for x in bin(all_transitions)[2:]] + template = [0]*(4-len(template)) + template + + # add all rotations + for rot in [0, 90, 180, 270]: + transitions_templates_.append((template, + t_utils.rotate_transition( + t_utils.transitions[i], + rot))) + template = [template[-1]]+template[:-1] + + def get_matching_templates(template): + ret = [] + for i in range(len(transitions_templates_)): + is_match = True + for j in range(4): + if template[j] >= 0 and \ + template[j] != transitions_templates_[i][0][j]: + is_match = False + break + if is_match: + ret.append(transitions_templates_[i][1]) + return ret + + MAX_INSERTIONS = (width-2) * (height-2) * 10 + MAX_ATTEMPTS_FROM_SCRATCH = 10 + + attempt_number = 0 + while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH: + cells_to_fill = [] + rail = [] + for r in range(height): + rail.append([None]*width) + if r > 0 and r < height-1: + cells_to_fill = cells_to_fill \ + + [(r, c) for c in range(1, width-1)] + + num_insertions = 0 + while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0: + cell = random.sample(cells_to_fill, 1)[0] + cells_to_fill.remove(cell) + row = cell[0] + col = cell[1] + + # look at its neighbors and see what are the possible transitions + # that can be chosen from, if any. 
+ valid_template = [-1, -1, -1, -1] + + for el in [(0, 2, (-1, 0)), + (1, 3, (0, 1)), + (2, 0, (1, 0)), + (3, 1, (0, -1))]: # N, E, S, W + neigh_trans = rail[row+el[2][0]][col+el[2][1]] + if neigh_trans is not None: + # select transition coming from facing direction el[1] and + # moving to direction el[1] + max_bit = 0 + for k in range(4): + max_bit |= \ + t_utils.get_transition(neigh_trans, k, el[1]) + + if max_bit: + valid_template[el[0]] = 1 + else: + valid_template[el[0]] = 0 + + possible_cell_transitions = get_matching_templates(valid_template) + + if len(possible_cell_transitions) == 0: # NO VALID TRANSITIONS + # no cell can be filled in without violating some transitions + # can a dead-end solve the problem? + if valid_template.count(1) == 1: + for k in range(4): + if valid_template[k] == 1: + rot = 0 + if k == 0: + rot = 180 + elif k == 1: + rot = 270 + elif k == 2: + rot = 0 + elif k == 3: + rot = 90 + + rail[row][col] = t_utils.rotate_transition( + int('0000000000100000', 2), rot) + num_insertions += 1 + + break + + else: + # can I get valid transitions by removing a single + # neighboring cell? + bestk = -1 + besttrans = [] + for k in range(4): + tmp_template = valid_template[:] + tmp_template[k] = -1 + possible_cell_transitions = get_matching_templates( + tmp_template) + if len(possible_cell_transitions) > len(besttrans): + besttrans = possible_cell_transitions + bestk = k + + if bestk >= 0: + # Replace the corresponding cell with None, append it + # to cells to fill, fill in a transition in the current + # cell. 
+ replace_row = row - 1 + replace_col = col + if bestk == 1: + replace_row = row + replace_col = col + 1 + elif bestk == 2: + replace_row = row + 1 + replace_col = col + elif bestk == 3: + replace_row = row + replace_col = col - 1 + + cells_to_fill.append((replace_row, replace_col)) + rail[replace_row][replace_col] = None + + rail[row][col] = random.sample( + besttrans, 1)[0] + num_insertions += 1 + + else: + print('WARNING: still nothing!') + rail[row][col] = int('0000000000000000', 2) + num_insertions += 1 + pass + + else: + rail[row][col] = random.sample( + possible_cell_transitions, 1)[0] + num_insertions += 1 + + if num_insertions == MAX_INSERTIONS: + # Failed to generate a valid level; try again for a number of times + attempt_number += 1 + else: + break + + if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH: + print('ERROR: failed to generate level') + + # Finally pad the border of the map with dead-ends to avoid border issues; + # at most 1 transition in the neigh cell + for r in range(height): + # Check for transitions coming from [r][1] to WEST + max_bit = 0 + neigh_trans = rail[r][1] + if neigh_trans is not None: + for k in range(4): + neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ + & (2**4-1) + max_bit = max_bit | (neigh_trans_from_direction & 1) + if max_bit: + rail[r][0] = t_utils.rotate_transition( + int('0000000000100000', 2), 270) + else: + rail[r][0] = int('0000000000000000', 2) + + # Check for transitions coming from [r][-2] to EAST + max_bit = 0 + neigh_trans = rail[r][-2] + if neigh_trans is not None: + for k in range(4): + neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ + & (2**4-1) + max_bit = max_bit | (neigh_trans_from_direction & (1 << 2)) + if max_bit: + rail[r][-1] = t_utils.rotate_transition(int('0000000000100000', 2), + 90) + else: + rail[r][-1] = int('0000000000000000', 2) + + for c in range(width): + # Check for transitions coming from [1][c] to NORTH + max_bit = 0 + neigh_trans = rail[1][c] + if neigh_trans is 
not None: + for k in range(4): + neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ + & (2**4-1) + max_bit = max_bit | (neigh_trans_from_direction & (1 << 3)) + if max_bit: + rail[0][c] = int('0000000000100000', 2) + else: + rail[0][c] = int('0000000000000000', 2) + + # Check for transitions coming from [-2][c] to SOUTH + max_bit = 0 + neigh_trans = rail[-2][c] + if neigh_trans is not None: + for k in range(4): + neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ + & (2**4-1) + max_bit = max_bit | (neigh_trans_from_direction & (1 << 1)) + if max_bit: + rail[-1][c] = t_utils.rotate_transition( + int('0000000000100000', 2), 180) + else: + rail[-1][c] = int('0000000000000000', 2) + + # For display only, wrong levels + for r in range(height): + for c in range(width): + if rail[r][c] is None: + rail[r][c] = int('0000000000000000', 2) + + tmp_rail = np.asarray(rail, dtype=np.uint16) + return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils) + return_rail.grid = tmp_rail + return return_rail + + +class RailEnv(Environment): + """ + RailEnv environment class. + + RailEnv is an environment inspired by a (simplified version of) a rail + network, in which agents (trains) have to navigate to their target + locations in the shortest time possible, while at the same time cooperating + to avoid bottlenecks. + + The valid actions in the environment are: + 0: do nothing + 1: turn left and move to the next cell + 2: move to the next cell in front of the agent + 3: turn right and move to the next cell + + Moving forward in a dead-end cell makes the agent turn 180 degrees and step + to the cell it came from. + + The actions of the agents are executed in order of their handle to prevent + deadlocks and to allow them to learn relative priorities. + + TODO: WRITE ABOUT THE REWARD FUNCTION, and possibly allow for alpha and + beta to be passed as parameters to __init__(). 
+ """ + + def __init__(self, + width, + height, + rail_generator=random_rail_generator, + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2)): + """ + Environment init. + + Parameters + ------- + rail_generator : function + The rail_generator function is a function that takes the width and + height of a rail map along with the number of times the env has + been reset, and returns a GridTransitionMap object. + Implemented functions are: + random_rail_generator : generate a random rail of given size + rail_from_GridTransitionMap_generator(rail_map) : generate a rail from + a GridTransitionMap object + rail_from_manual_specifications_generator(rail_spec) : generate a rail from + a rail specifications array + TODO: generate_rail_from_saved_list or from list of ndarray bitmaps --- + width : int + The width of the rail map. Potentially in the future, + a range of widths to sample from. + height : int + The height of the rail map. Potentially in the future, + a range of heights to sample from. + number_of_agents : int + Number of agents to spawn on the map. Potentially in the future, + a range of number of agents to sample from. + obs_builder_object: ObservationBuilder object + ObservationBuilder-derived object that takes builds observation + vectors for each agent. 
+ """ + + self.rail_generator = rail_generator + self.rail = None + self.width = width + self.height = height + + self.number_of_agents = number_of_agents + + self.obs_builder = obs_builder_object + self.obs_builder._set_env(self) + + self.actions = [0]*self.number_of_agents + self.rewards = [0]*self.number_of_agents + self.done = False + + self.dones = {"__all__": False} + self.obs_dict = {} + self.rewards_dict = {} + + self.agents_handles = list(range(self.number_of_agents)) + + # self.agents_position = [] + # self.agents_target = [] + # self.agents_direction = [] + self.num_resets = 0 + self.reset() + self.num_resets = 0 + + def get_agent_handles(self): + return self.agents_handles + + def reset(self): + self.rail = self.rail_generator(self.width, self.height, self.num_resets) + self.num_resets += 1 + + self.dones = {"__all__": False} + for handle in self.agents_handles: + self.dones[handle] = False + + re_generate = True + while re_generate: + valid_positions = [] + for r in range(self.height): + for c in range(self.width): + if self.rail.get_transitions((r, c)) > 0: + valid_positions.append((r, c)) + + self.agents_position = random.sample(valid_positions, + self.number_of_agents) + self.agents_target = random.sample(valid_positions, + self.number_of_agents) + + # agents_direction must be a direction for which a solution is + # guaranteed. 
+ self.agents_direction = [0]*self.number_of_agents + re_generate = False + for i in range(self.number_of_agents): + valid_movements = [] + for direction in range(4): + position = self.agents_position[i] + moves = self.rail.get_transitions( + (position[0], position[1], direction)) + for move_index in range(4): + if moves[move_index]: + valid_movements.append((direction, move_index)) + + valid_starting_directions = [] + for m in valid_movements: + new_position = self._new_position(self.agents_position[i], + m[1]) + if m[0] not in valid_starting_directions and \ + self._path_exists(new_position, m[0], + self.agents_target[i]): + valid_starting_directions.append(m[0]) + + if len(valid_starting_directions) == 0: + re_generate = True + else: + self.agents_direction[i] = random.sample( + valid_starting_directions, 1)[0] + + # Reset the state of the observation builder with the new environment + self.obs_builder.reset() + + # Return the new observation vectors for each agent + return self._get_observations() + + def step(self, action_dict): + alpha = 1.0 + beta = 1.0 + + invalid_action_penalty = -2 + step_penalty = -1 * alpha + global_reward = 1 * beta + + # Reset the step rewards + self.rewards_dict = {} + for handle in self.agents_handles: + self.rewards_dict[handle] = 0 + + if self.dones["__all__"]: + return self._get_observations(), self.rewards_dict, self.dones, {} + + for i in range(len(self.agents_handles)): + handle = self.agents_handles[i] + + if handle not in action_dict: + continue + + action = action_dict[handle] + + if action < 0 or action > 3: + print('ERROR: illegal action=', action, + 'for agent with handle=', handle) + return + + if action > 0: + pos = self.agents_position[i] + direction = self.agents_direction[i] + + movement = direction + if action == 1: + movement = direction - 1 + elif action == 3: + movement = direction + 1 + + if movement < 0: + movement += 4 + if movement >= 4: + movement -= 4 + + is_deadend = False + if action == 2: + # compute 
number of possible transitions in the current + # cell + nbits = 0 + tmp = self.rail.get_transitions((pos[0], pos[1])) + while tmp > 0: + nbits += (tmp & 1) + tmp = tmp >> 1 + if nbits == 1: + # dead-end; assuming the rail network is consistent, + # this should match the direction the agent has come + # from. But it's better to check in any case. + reverse_direction = 0 + if direction == 0: + reverse_direction = 2 + elif direction == 1: + reverse_direction = 3 + elif direction == 2: + reverse_direction = 0 + elif direction == 3: + reverse_direction = 1 + + valid_transition = self.rail.get_transition( + (pos[0], pos[1], direction), + reverse_direction) + if valid_transition: + direction = reverse_direction + movement = reverse_direction + is_deadend = True + + new_position = self._new_position(pos, movement) + + # Is it a legal move? 1) transition allows the movement in the + # cell, 2) the new cell is not empty (case 0), 3) the cell is + # free, i.e., no agent is currently in that cell + if new_position[1] >= self.width or\ + new_position[0] >= self.height or\ + new_position[0] < 0 or new_position[1] < 0: + new_cell_isValid = False + + elif self.rail.get_transitions((new_position[0], new_position[1])) > 0: + new_cell_isValid = True + else: + new_cell_isValid = False + + transition_isValid = self.rail.get_transition( + (pos[0], pos[1], direction), + movement) or is_deadend + + cell_isFree = True + for j in range(self.number_of_agents): + if self.agents_position[j] == new_position: + cell_isFree = False + break + + if new_cell_isValid and transition_isValid and cell_isFree: + # move and change direction to face the movement that was + # performed + self.agents_position[i] = new_position + self.agents_direction[i] = movement + else: + # the action was not valid, add penalty + self.rewards_dict[handle] += invalid_action_penalty + + # if agent is not in target position, add step penalty + if self.agents_position[i][0] == self.agents_target[i][0] and \ + 
self.agents_position[i][1] == self.agents_target[i][1]: + self.dones[handle] = True + else: + self.rewards_dict[handle] += step_penalty + + # Check for end of episode + add global reward to all rewards! + num_agents_in_target_position = 0 + for i in range(self.number_of_agents): + if self.agents_position[i][0] == self.agents_target[i][0] and \ + self.agents_position[i][1] == self.agents_target[i][1]: + num_agents_in_target_position += 1 + + if num_agents_in_target_position == self.number_of_agents: + self.dones["__all__"] = True + self.rewards_dict = [r+global_reward for r in self.rewards_dict] + + # Reset the step actions (in case some agent doesn't 'register_action' + # on the next step) + self.actions = [0]*self.number_of_agents + + return self._get_observations(), self.rewards_dict, self.dones, {} + + def _new_position(self, position, movement): + if movement == 0: # NORTH + return (position[0]-1, position[1]) + elif movement == 1: # EAST + return (position[0], position[1] + 1) + elif movement == 2: # SOUTH + return (position[0]+1, position[1]) + elif movement == 3: # WEST + return (position[0], position[1] - 1) + + def _path_exists(self, start, direction, end): + # BFS - Check if a path exists between the 2 nodes + + visited = set() + stack = [(start, direction)] + while stack: + node = stack.pop() + if node[0][0] == end[0] and node[0][1] == end[1]: + return 1 + if node not in visited: + visited.add(node) + moves = self.rail.get_transitions((node[0][0], node[0][1], node[1])) + for move_index in range(4): + if moves[move_index]: + stack.append((self._new_position(node[0], move_index), + move_index)) + + # If cell is a dead-end, append previous node with reversed + # orientation! 
+ nbits = 0 + tmp = self.rail.get_transitions((node[0][0], node[0][1])) + while tmp > 0: + nbits += (tmp & 1) + tmp = tmp >> 1 + if nbits == 1: + stack.append((node[0], (node[1] + 2) % 4)) + + return 0 + + def _get_observations(self): + self.obs_dict = {} + for handle in self.agents_handles: + self.obs_dict[handle] = self.obs_builder.get(handle) + return self.obs_dict + + def render(self): + # TODO: + pass diff --git a/flatland/utils/rail_env_generator.py b/flatland/utils/rail_env_generator.py deleted file mode 100644 index ab11ac5e7df8281294c02fd4ef8192519b061e4d..0000000000000000000000000000000000000000 --- a/flatland/utils/rail_env_generator.py +++ /dev/null @@ -1,341 +0,0 @@ -""" -The rail_env_generator module defines provides utilities to generate env -bitmaps for the RailEnv environment. -""" -import random -import numpy as np - -from flatland.core.transitions import RailEnvTransitions -from flatland.core.transition_map import GridTransitionMap - - -def generate_rail_from_manual_specifications(rail_spec): - """ - Utility to convert a rail given by manual specification as a map of tuples - (cell_type, rotation), to a transition map with the correct 16-bit - transitions specifications. - - Parameters - ------- - rail_spec : list of list of tuples - List (rows) of lists (columns) of tuples, each specifying a cell for - the RailEnv environment as (cell_type, rotation), with rotation being - clock-wise and in [0, 90, 180, 270]. - - Returns - ------- - function - Generator function that always returns a GridTransitionMap object with - the matrix of correct 16-bit bitmaps for each cell. 
- """ - def generator(width, height, num_resets=0): - t_utils = RailEnvTransitions() - - height = len(rail_spec) - width = len(rail_spec[0]) - rail = GridTransitionMap(width=width, height=height, transitions=t_utils) - - for r in range(height): - for c in range(width): - cell = rail_spec[r][c] - if cell[0] < 0 or cell[0] >= len(t_utils.transitions): - print("ERROR - invalid cell type=", cell[0]) - return [] - rail.set_transitions((r, c), t_utils.rotate_transition( - t_utils.transitions[cell[0]], cell[1])) - - return rail - - return generator - - -def generate_rail_from_GridTransitionMap(rail_map): - """ - Utility to convert a rail given by a GridTransitionMap map with the correct - 16-bit transitions specifications. - - Parameters - ------- - rail_map : GridTransitionMap object - GridTransitionMap object to return when the generator is called. - - Returns - ------- - function - Generator function that always returns the given `rail_map' object. - """ - def generator(width, height, num_resets=0): - return rail_map - - return generator - - -""" -def generate_rail_from_list_of_manual_specifications(list_of_specifications) - def generator(width, height, num_resets=0): - return generate_rail_from_manual_specifications(list_of_specifications) - - return generator -""" - - -def generate_random_rail(width, height, num_resets=0): - """ - Dummy random level generator: - - fill in cells at random in [width-2, height-2] - - keep filling cells in among the unfilled ones, such that all transitions - are legit; if no cell can be filled in without violating some - transitions, pick one among those that can satisfy most transitions - (1,2,3 or 4), and delete (+mark to be re-filled) the cells that were - incompatible. - - keep trying for a total number of insertions - (e.g., (W-2)*(H-2)*MAX_REPETITIONS ); if no solution is found, empty the - board and try again from scratch. - - finally pad the border of the map with dead-ends to avoid border issues. 
- - Dead-ends are not allowed inside the grid, only at the border; however, if - no cell type can be inserted in a given cell (because of the neighboring - transitions), deadends are allowed if they solve the problem. This was - found to turn most un-genereatable levels into valid ones. - - Parameters - ------- - width : int - The width (number of cells) of the grid to generate. - height : int - The height (number of cells) of the grid to generate. - - Returns - ------- - numpy.ndarray of type numpy.uint16 - The matrix with the correct 16-bit bitmaps for each cell. - """ - - t_utils = RailEnvTransitions() - - transitions_templates_ = [] - for i in range(len(t_utils.transitions)-1): # don't include dead-ends - all_transitions = 0 - for dir_ in range(4): - trans = t_utils.get_transitions(t_utils.transitions[i], dir_) - all_transitions |= (trans[0] << 3) | \ - (trans[1] << 2) | \ - (trans[2] << 1) | \ - (trans[3]) - - template = [int(x) for x in bin(all_transitions)[2:]] - template = [0]*(4-len(template)) + template - - # add all rotations - for rot in [0, 90, 180, 270]: - transitions_templates_.append((template, - t_utils.rotate_transition( - t_utils.transitions[i], - rot))) - template = [template[-1]]+template[:-1] - - def get_matching_templates(template): - ret = [] - for i in range(len(transitions_templates_)): - is_match = True - for j in range(4): - if template[j] >= 0 and \ - template[j] != transitions_templates_[i][0][j]: - is_match = False - break - if is_match: - ret.append(transitions_templates_[i][1]) - return ret - - MAX_INSERTIONS = (width-2) * (height-2) * 10 - MAX_ATTEMPTS_FROM_SCRATCH = 10 - - attempt_number = 0 - while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH: - cells_to_fill = [] - rail = [] - for r in range(height): - rail.append([None]*width) - if r > 0 and r < height-1: - cells_to_fill = cells_to_fill \ - + [(r, c) for c in range(1, width-1)] - - num_insertions = 0 - while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0: - cell = 
random.sample(cells_to_fill, 1)[0] - cells_to_fill.remove(cell) - row = cell[0] - col = cell[1] - - # look at its neighbors and see what are the possible transitions - # that can be chosen from, if any. - valid_template = [-1, -1, -1, -1] - - for el in [(0, 2, (-1, 0)), - (1, 3, (0, 1)), - (2, 0, (1, 0)), - (3, 1, (0, -1))]: # N, E, S, W - neigh_trans = rail[row+el[2][0]][col+el[2][1]] - if neigh_trans is not None: - # select transition coming from facing direction el[1] and - # moving to direction el[1] - max_bit = 0 - for k in range(4): - max_bit |= \ - t_utils.get_transition(neigh_trans, k, el[1]) - - if max_bit: - valid_template[el[0]] = 1 - else: - valid_template[el[0]] = 0 - - possible_cell_transitions = get_matching_templates(valid_template) - - if len(possible_cell_transitions) == 0: # NO VALID TRANSITIONS - # no cell can be filled in without violating some transitions - # can a dead-end solve the problem? - if valid_template.count(1) == 1: - for k in range(4): - if valid_template[k] == 1: - rot = 0 - if k == 0: - rot = 180 - elif k == 1: - rot = 270 - elif k == 2: - rot = 0 - elif k == 3: - rot = 90 - - rail[row][col] = t_utils.rotate_transition( - int('0000000000100000', 2), rot) - num_insertions += 1 - - break - - else: - # can I get valid transitions by removing a single - # neighboring cell? - bestk = -1 - besttrans = [] - for k in range(4): - tmp_template = valid_template[:] - tmp_template[k] = -1 - possible_cell_transitions = get_matching_templates( - tmp_template) - if len(possible_cell_transitions) > len(besttrans): - besttrans = possible_cell_transitions - bestk = k - - if bestk >= 0: - # Replace the corresponding cell with None, append it - # to cells to fill, fill in a transition in the current - # cell. 
- replace_row = row - 1 - replace_col = col - if bestk == 1: - replace_row = row - replace_col = col + 1 - elif bestk == 2: - replace_row = row + 1 - replace_col = col - elif bestk == 3: - replace_row = row - replace_col = col - 1 - - cells_to_fill.append((replace_row, replace_col)) - rail[replace_row][replace_col] = None - - rail[row][col] = random.sample( - besttrans, 1)[0] - num_insertions += 1 - - else: - print('WARNING: still nothing!') - rail[row][col] = int('0000000000000000', 2) - num_insertions += 1 - pass - - else: - rail[row][col] = random.sample( - possible_cell_transitions, 1)[0] - num_insertions += 1 - - if num_insertions == MAX_INSERTIONS: - # Failed to generate a valid level; try again for a number of times - attempt_number += 1 - else: - break - - if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH: - print('ERROR: failed to generate level') - - # Finally pad the border of the map with dead-ends to avoid border issues; - # at most 1 transition in the neigh cell - for r in range(height): - # Check for transitions coming from [r][1] to WEST - max_bit = 0 - neigh_trans = rail[r][1] - if neigh_trans is not None: - for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ - & (2**4-1) - max_bit = max_bit | (neigh_trans_from_direction & 1) - if max_bit: - rail[r][0] = t_utils.rotate_transition( - int('0000000000100000', 2), 270) - else: - rail[r][0] = int('0000000000000000', 2) - - # Check for transitions coming from [r][-2] to EAST - max_bit = 0 - neigh_trans = rail[r][-2] - if neigh_trans is not None: - for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ - & (2**4-1) - max_bit = max_bit | (neigh_trans_from_direction & (1 << 2)) - if max_bit: - rail[r][-1] = t_utils.rotate_transition(int('0000000000100000', 2), - 90) - else: - rail[r][-1] = int('0000000000000000', 2) - - for c in range(width): - # Check for transitions coming from [1][c] to NORTH - max_bit = 0 - neigh_trans = rail[1][c] - if neigh_trans is 
not None: - for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ - & (2**4-1) - max_bit = max_bit | (neigh_trans_from_direction & (1 << 3)) - if max_bit: - rail[0][c] = int('0000000000100000', 2) - else: - rail[0][c] = int('0000000000000000', 2) - - # Check for transitions coming from [-2][c] to SOUTH - max_bit = 0 - neigh_trans = rail[-2][c] - if neigh_trans is not None: - for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ - & (2**4-1) - max_bit = max_bit | (neigh_trans_from_direction & (1 << 1)) - if max_bit: - rail[-1][c] = t_utils.rotate_transition( - int('0000000000100000', 2), 180) - else: - rail[-1][c] = int('0000000000000000', 2) - - # For display only, wrong levels - for r in range(height): - for c in range(width): - if rail[r][c] is None: - rail[r][c] = int('0000000000000000', 2) - - tmp_rail = np.asarray(rail, dtype=np.uint16) - return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils) - return_rail.grid = tmp_rail - return return_rail diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index d0a78917dcee833eac2d5ac7567436546635d044..5365800b6cc3ae81ab1c5e7936c96b703c933d79 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -6,6 +6,8 @@ import xarray as xr import matplotlib.pyplot as plt +# TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv! 
+ class RenderTool(object): Visit = recordtype("Visit", ["rc", "iDir", "iDepth", "prev"]) diff --git a/tests/test_environments.py b/tests/test_environments.py index ce9fbd4f010413140ef09a843db0c7657005524b..b62a2e60d03b16a89e026a2ac99ea7f9504d6ec3 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -1,10 +1,9 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -from flatland.core.env import RailEnv +from flatland.core.env import RailEnv, rail_from_GridTransitionMap_generator from flatland.core.transitions import Grid4Transitions from flatland.core.transition_map import GridTransitionMap -from flatland.utils.rail_env_generator import generate_rail_from_GridTransitionMap import numpy as np """Tests for `flatland` package.""" @@ -47,7 +46,10 @@ def test_rail_environment_single_agent(): rail = GridTransitionMap(width=3, height=3, transitions=transitions) rail.grid = rail_map - rail_env = RailEnv(width=3, height=3, rail_generator=generate_rail_from_GridTransitionMap(rail), number_of_agents=1) + rail_env = RailEnv(width=3, + height=3, + rail_generator=rail_from_GridTransitionMap_generator(rail), + number_of_agents=1) for _ in range(200): _ = rail_env.reset() @@ -121,7 +123,7 @@ def test_dead_end(): rail.grid = rail_map rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], - rail_generator=generate_rail_from_GridTransitionMap(rail), + rail_generator=rail_from_GridTransitionMap_generator(rail), number_of_agents=1) def check_consistency(rail_env): @@ -170,7 +172,7 @@ def test_dead_end(): rail.grid = rail_map rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], - rail_generator=generate_rail_from_GridTransitionMap(rail), + rail_generator=rail_from_GridTransitionMap_generator(rail), number_of_agents=1) rail_env.reset() diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py index 5fecd085c646e8a19e7be553eea258105865a8f4..427ff2f70b12e7f2f1f51f5ed4960b549470e739 100644 --- a/tests/test_rendertools.py +++ 
b/tests/test_rendertools.py @@ -4,14 +4,13 @@ Tests for `flatland` package. """ -from flatland.core.env import RailEnv +from flatland.envs.rail_env import RailEnv, random_rail_generator import numpy as np import random import os import matplotlib.pyplot as plt -from flatland.utils import rail_env_generator import flatland.utils.rendertools as rt @@ -37,7 +36,7 @@ def checkFrozenImage(sFileImage): def test_render_env(): random.seed(100) - oEnv = RailEnv(width=10, height=10, rail_generator=rail_env_generator.generate_random_rail, number_of_agents=2) + oEnv = RailEnv(width=10, height=10, rail_generator=random_rail_generator, number_of_agents=2) oEnv.reset() oRT = rt.RenderTool(oEnv) plt.figure(figsize=(10, 10))