diff --git a/Makefile b/Makefile index 43d86866d18f6270b5a88be7492b7ab24885b193..e23fd6b899d9396945aabec2fdbf6ad6269e9f51 100644 --- a/Makefile +++ b/Makefile @@ -71,7 +71,8 @@ docs: ## generate Sphinx HTML documentation, including API docs sphinx-apidoc -o docs/ flatland $(MAKE) -C docs clean $(MAKE) -C docs html - export HOME=$(pwd) + # N.B. HOME variable required by pydeps! + export HOME=${PWD} python3 -m pydeps flatland -o docs/_build/html/flatland.svg $(BROWSER) docs/_build/html/index.html diff --git a/examples/play_model.py b/examples/play_model.py index f80d9a14c54c1fa411a8736e4891ba5b95c188d8..62726c24c96be0e5dae2f4840e18da452163b7ac 100644 --- a/examples/play_model.py +++ b/examples/play_model.py @@ -1,6 +1,5 @@ from flatland.envs.rail_env import RailEnv from flatland.envs.generators import complex_rail_generator -# from flatland.core.env_observation_builder import TreeObsForRailEnv from flatland.utils.rendertools import RenderTool from flatland.baselines.dueling_double_dqn import Agent from collections import deque diff --git a/examples/temporary_example.py b/examples/temporary_example.py index 0ed2f6207683b0983a6c8a9783c6677834437bd1..1f3504f221d59d0205974bb135c2237364a22e07 100644 --- a/examples/temporary_example.py +++ b/examples/temporary_example.py @@ -4,7 +4,7 @@ import matplotlib.pyplot as plt from flatland.envs.rail_env import * from flatland.envs.generators import * -from flatland.core.env_observation_builder import TreeObsForRailEnv +from flatland.envs.observations import TreeObsForRailEnv from flatland.utils.rendertools import * random.seed(0) diff --git a/examples/training_navigation.py b/examples/training_navigation.py index 9d45cd175c7d324804a2a34e278428aebad69e28..23970a9059f15ce89b59b1cbcb9b862d248d1cd5 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -1,6 +1,6 @@ from flatland.envs.rail_env import * from flatland.envs.generators import * -from flatland.core.env_observation_builder import 
TreeObsForRailEnv +from flatland.envs.observations import TreeObsForRailEnv from flatland.utils.rendertools import * from flatland.baselines.dueling_double_dqn import Agent from collections import deque @@ -32,8 +32,8 @@ env = RailEnv(width=10, """ env = RailEnv(width=15, height=15, - rail_generator=complex_rail_generator(nr_start_goal=5, min_dist=5, max_dist=99999, seed=0), - number_of_agents=5) + rail_generator=complex_rail_generator(nr_start_goal=10, min_dist=5, max_dist=99999, seed=0), + number_of_agents=3) """ env = RailEnv(width=20, height=20, @@ -50,7 +50,7 @@ action_size = 4 n_trials = 15000 eps = 1. eps_end = 0.005 -eps_decay = 0.998 +eps_decay = 0.9995 action_dict = dict() final_action_dict = dict() scores_window = deque(maxlen=100) @@ -62,9 +62,9 @@ action_prob = [0] * 4 agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() agent = Agent(state_size, action_size, "FC", 0) -# agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth')) +agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint15000.pth')) -demo = False +demo = True def max_lt(seq, val): diff --git a/flatland/baselines/Nets/avoid_checkpoint15000.pth b/flatland/baselines/Nets/avoid_checkpoint15000.pth index adcfe61576553bbf0e2b4ba00d9fffafbfd9d7da..14882a37a86085b137f4422b6bba75f387a2d3b5 100644 Binary files a/flatland/baselines/Nets/avoid_checkpoint15000.pth and b/flatland/baselines/Nets/avoid_checkpoint15000.pth differ diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index 8e7f2ae57f5e27b586365e870e2acbc8666b912e..09a624e872200e29ede834d272f7de506d6de076 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -8,10 +8,6 @@ The ObservationBuilder-derived custom classes implement 2 functions, reset() and case of multi-agent environments. 
""" -import numpy as np - -from collections import deque - class ObservationBuilder: """ @@ -46,511 +42,3 @@ class ObservationBuilder: An observation structure, specific to the corresponding environment. """ raise NotImplementedError() - - -class TreeObsForRailEnv(ObservationBuilder): - """ - TreeObsForRailEnv object. - - This object returns observation vectors for agents in the RailEnv environment. - The information is local to each agent and exploits the tree structure of the rail - network to simplify the representation of the state of the environment for each agent. - """ - - def __init__(self, max_depth): - self.max_depth = max_depth - - def reset(self): - agents = self.env.agents - nAgents = len(agents) - self.distance_map = np.inf * np.ones(shape=(nAgents, # self.env.number_of_agents, - self.env.height, - self.env.width, - 4)) - self.max_dist = np.zeros(nAgents) - - # for i in range(nAgents): - # self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i) - self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)] - - # Update local lookup table for all agents' target locations - self.location_has_target = {} - # for loc in self.env.agents_target: - # self.location_has_target[(loc[0], loc[1])] = 1 - self.location_has_target = {tuple(agent.target): 1 for agent in agents} - - def _distance_map_walker(self, position, target_nr): - """ - Utility function to compute distance maps from each cell in the rail network (and each possible - orientation within it) to each agent's target cell. 
- """ - # Returns max distance to target, from the farthest away node, while filling in distance_map - - self.distance_map[target_nr, position[0], position[1], :] = 0 - - # Fill in the (up to) 4 neighboring nodes - # nodes_queue = [] # list of tuples (row, col, direction, distance); - # direction is the direction of movement, meaning that at least a possible orientation of an agent - # in cell (row,col) allows a movement in direction `direction' - nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1)) - - # BFS from target `position' to all the reachable nodes in the grid - # Stop the search if the target position is re-visited, in any direction - visited = set([(position[0], position[1], 0), - (position[0], position[1], 1), - (position[0], position[1], 2), - (position[0], position[1], 3)]) - - max_distance = 0 - - while nodes_queue: - node = nodes_queue.popleft() - - node_id = (node[0], node[1], node[2]) - - if node_id not in visited: - visited.add(node_id) - - # From the list of possible neighbors that have at least a path to the current node, only keep those - # whose new orientation in the current cell would allow a transition to direction node[2] - valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2]) - - for n in valid_neighbors: - nodes_queue.append(n) - - if len(valid_neighbors) > 0: - max_distance = max(max_distance, node[3] + 1) - - return max_distance - - def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1): - """ - Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the - minimum distances from each target cell. - """ - neighbors = [] - - possible_directions = [0, 1, 2, 3] - if enforce_target_direction >= 0: - # The agent must land into the current cell with orientation `enforce_target_direction'. 
- # This is only possible if the agent has arrived from the cell in the opposite direction! - possible_directions = [(enforce_target_direction + 2) % 4] - - for neigh_direction in possible_directions: - new_cell = self._new_position(position, neigh_direction) - - if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width: - - desired_movement_from_new_cell = (neigh_direction + 2) % 4 - - """ - # Is the next cell a dead-end? - isNextCellDeadEnd = False - nbits = 0 - tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1])) - while tmp > 0: - nbits += (tmp & 1) - tmp = tmp >> 1 - if nbits == 1: - # Dead-end! - isNextCellDeadEnd = True - """ - - # Check all possible transitions in new_cell - for agent_orientation in range(4): - # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible? - isValid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation), - desired_movement_from_new_cell) - - if isValid: - """ - # TODO: check that it works with deadends! -- still bugged! - movement = desired_movement_from_new_cell - if isNextCellDeadEnd: - movement = (desired_movement_from_new_cell+2) % 4 - """ - new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation], - current_distance + 1) - neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance)) - self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance - - return neighbors - - def _new_position(self, position, movement): - """ - Utility function that converts a compass movement over a 2D grid to new positions (r, c). 
- """ - if movement == 0: # NORTH - return (position[0] - 1, position[1]) - elif movement == 1: # EAST - return (position[0], position[1] + 1) - elif movement == 2: # SOUTH - return (position[0] + 1, position[1]) - elif movement == 3: # WEST - return (position[0], position[1] - 1) - - def get(self, handle): - """ - Computes the current observation for agent `handle' in env - - The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible - movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv). - The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for - the transitions. The order is: - [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back'] - - - - - - Each branch data is organized as: - [root node information] + - [recursive branch data from 'left'] + - [... from 'forward'] + - [... from 'right] + - [... from 'back'] - - Finally, each node information is composed of 5 floating point values: - - #1: - - #2: 1 if a target of another agent is detected between the previous node and the current one. - - #3: 1 if another agent is detected between the previous node and the current one. - - #4: distance of agent to the current branch node - - #5: minimum distance from node to the agent's target (when landing to the node following the corresponding - branch. - - Missing/padding nodes are filled in with -inf (truncated). - Missing values in present node are filled in with +inf (truncated). - - - In case of the root node, the values are [0, 0, 0, 0, distance from agent to target]. - In case the target node is reached, the values are [0, 0, 0, 0, 0]. 
- """ - - # Update local lookup table for all agents' positions - # self.location_has_agent = {} - # for loc in self.env.agents_position: - # self.location_has_agent[(loc[0], loc[1])] = 1 - self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents} - - agent = self.env.agents[handle] # TODO: handle being treated as index - # position = self.env.agents_position[handle] - # orientation = self.env.agents_direction[handle] - possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction)) - num_transitions = np.count_nonzero(possible_transitions) - # Root node - current position - # observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]] - observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)]] - root_observation = observation[:] - - # Start from the current orientation, and see which transitions are available; - # organize them as [left, forward, right, back], relative to the current orientation - # If only one transition is possible, the tree is oriented with this transition as the forward branch. - # TODO: Test if this works as desired! 
- orientation = agent.direction - if num_transitions == 1: - orientation == np.argmax(possible_transitions) - - # for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]: - for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]: - if possible_transitions[branch_direction]: - new_cell = self._new_position(agent.position, branch_direction) - - branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1) - observation = observation + branch_observation - else: - num_cells_to_fill_in = 0 - pow4 = 1 - for i in range(self.max_depth): - num_cells_to_fill_in += pow4 - pow4 *= 4 - observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in - return observation - - def _explore_branch(self, handle, position, direction, root_observation, depth): - """ - Utility function to compute tree-based observations. - """ - # [Recursive branch opened] - if depth >= self.max_depth + 1: - return [] - - # Continue along direction until next switch or - # until no transitions are possible along the current direction (i.e., dead-ends) - # We treat dead-ends as nodes, instead of going back, to avoid loops - exploring = True - last_isSwitch = False - last_isDeadEnd = False - last_isTerminal = False # wrong cell OR cycle; either way, we don't want the agent to land here - last_isTarget = False - - visited = set() - - # other_agent_encountered = False - # other_target_encountered = False - other_agent_encountered = np.inf - other_target_encountered = np.inf - - num_steps = 1 - while exploring: - # ############################# - # ############################# - # Modify here to compute any useful data required to build the end node's features. This code is called - # for each cell visited between the previous branching node and the next switch / target / dead-end. 
- if position in self.location_has_agent: - # other_agent_encountered = True - if num_steps < other_agent_encountered: - other_agent_encountered = num_steps - - if position in self.location_has_target: - # other_target_encountered = True - if num_steps < other_target_encountered: - other_target_encountered = num_steps - # ############################# - # ############################# - - if (position[0], position[1], direction) in visited: - last_isTerminal = True - break - visited.add((position[0], position[1], direction)) - - # If the target node is encountered, pick that as node. Also, no further branching is possible. - # if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]: - if np.array_equal(position, self.env.agents[handle].target): - last_isTarget = True - break - - cell_transitions = self.env.rail.get_transitions((*position, direction)) - num_transitions = np.count_nonzero(cell_transitions) - exploring = False - if num_transitions == 1: - # Check if dead-end, or if we can go forward along direction - nbits = 0 - tmp = self.env.rail.get_transitions(tuple(position)) - while tmp > 0: - nbits += (tmp & 1) - tmp = tmp >> 1 - if nbits == 1: - # Dead-end! 
- last_isDeadEnd = True - - if not last_isDeadEnd: - # Keep walking through the tree along `direction' - exploring = True - direction = np.argmax(cell_transitions) - position = self._new_position(position, direction) - num_steps += 1 - - elif num_transitions > 0: - # Switch detected - last_isSwitch = True - break - - elif num_transitions == 0: - # Wrong cell type, but let's cover it and treat it as a dead-end, just in case - print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0], - position[1], direction) - last_isTerminal = True - break - - # `position' is either a terminal node or a switch - - observation = [] - - # ############################# - # ############################# - # Modify here to append new / different features for each visited cell! - """ - if last_isTarget: - observation = [0, - 1 if other_target_encountered else 0, - 1 if other_agent_encountered else 0, - root_observation[3] + num_steps, - 0] - - elif last_isTerminal: - observation = [0, - 1 if other_target_encountered else 0, - 1 if other_agent_encountered else 0, - np.inf, - np.inf] - else: - observation = [0, - 1 if other_target_encountered else 0, - 1 if other_agent_encountered else 0, - root_observation[3] + num_steps, - self.distance_map[handle, position[0], position[1], direction]] - """ - if last_isTarget: - observation = [0, - other_target_encountered, - other_agent_encountered, - root_observation[3] + num_steps, - 0] - - elif last_isTerminal: - observation = [0, - other_target_encountered, - other_agent_encountered, - np.inf, - np.inf] - else: - observation = [0, - other_target_encountered, - other_agent_encountered, - root_observation[3] + num_steps, - self.distance_map[handle, position[0], position[1], direction]] - # ############################# - # ############################# - - new_root_observation = observation[:] - # Start from the current orientation, and see which transitions are available; - # organize them as [left, forward, 
right, back], relative to the current orientation - # Get the possible transitions - possible_transitions = self.env.rail.get_transitions((*position, direction)) - for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]: - if last_isDeadEnd and self.env.rail.get_transition((*position, direction), - (branch_direction + 2) % 4): - # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes - # it back - new_cell = self._new_position(position, (branch_direction + 2) % 4) - branch_observation = self._explore_branch(handle, - new_cell, - (branch_direction + 2) % 4, - new_root_observation, - depth + 1) - observation = observation + branch_observation - - elif last_isSwitch and possible_transitions[branch_direction]: - new_cell = self._new_position(position, branch_direction) - branch_observation = self._explore_branch(handle, - new_cell, - branch_direction, - new_root_observation, - depth + 1) - observation = observation + branch_observation - - else: - num_cells_to_fill_in = 0 - pow4 = 1 - for i in range(self.max_depth - depth): - num_cells_to_fill_in += pow4 - pow4 *= 4 - observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in - - return observation - - def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0): - """ - Utility function to pretty-print tree observations returned by this object. 
- """ - if len(tree) < num_features_per_node: - return - - depth = 0 - tmp = len(tree) / num_features_per_node - 1 - pow4 = 4 - while tmp > 0: - tmp -= pow4 - depth += 1 - pow4 *= 4 - - prompt_ = ['L:', 'F:', 'R:', 'B:'] - - print(" " * current_depth + prompt, tree[0:num_features_per_node]) - child_size = (len(tree) - num_features_per_node) // 4 - for children in range(4): - child_tree = tree[(num_features_per_node + children * child_size): - (num_features_per_node + (children + 1) * child_size)] - self.util_print_obs_subtree(child_tree, - num_features_per_node, - prompt=prompt_[children], - current_depth=current_depth + 1) - - def split_tree(self, tree, num_features_per_node=5, current_depth=0): - """ - - :param tree: - :param num_features_per_node: - :param prompt: - :param current_depth: - :return: - """ - - if len(tree) < num_features_per_node: - return [], [] - - depth = 0 - tmp = len(tree) / num_features_per_node - 1 - pow4 = 4 - while tmp > 0: - tmp -= pow4 - depth += 1 - pow4 *= 4 - child_size = (len(tree) - num_features_per_node) // 4 - tree_data = tree[0:num_features_per_node - 1].tolist() - distance_data = [tree[num_features_per_node - 1]] - for children in range(4): - child_tree = tree[(num_features_per_node + children * child_size): - (num_features_per_node + (children + 1) * child_size)] - tmp_tree_data, tmp_distance_data = self.split_tree(child_tree, - num_features_per_node, - current_depth=current_depth + 1) - if len(tmp_tree_data) > 0: - tree_data.extend(tmp_tree_data) - distance_data.extend(tmp_distance_data) - return tree_data, distance_data - - -class GlobalObsForRailEnv(ObservationBuilder): - """ - Gives a global observation of the entire rail environment. - The observation is composed of the following elements: - - - transition map array with dimensions (env.height, env.width, 16), - assuming 16 bits encoding of transitions. 
- - - Four 2D arrays containing respectively the position of the given agent, - the position of its target, the positions of the other agents and of - their target. - - - A 4 elements array with one of encoding of the direction. - """ - - def __init__(self): - super(GlobalObsForRailEnv, self).__init__() - - def reset(self): - self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) - for i in range(self.rail_obs.shape[0]): - for j in range(self.rail_obs.shape[1]): - self.rail_obs[i, j] = np.array( - list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int) - - # self.targets = np.zeros(self.env.height, self.env.width) - # for target_pos in self.env.agents_target: - # self.targets[target_pos] += 1 - - def get(self, handle): - obs = np.zeros((4, self.env.height, self.env.width)) - agents = self.env.agents - agent = agents[handle] - - agent_pos = agents[handle].position - obs[0][agent_pos] += 1 - obs[1][agent.target] += 1 - - for i in range(len(agents)): - if i != handle: # TODO: handle used as index...? 
- agent2 = agents[i] - obs[3][agent2.position] += 1 - obs[2][agent2.target] += 1 - - direction = np.zeros(4) - direction[agent.direction] = 1 - - return self.rail_obs, obs, direction diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 49209460dd191f1628aabac49d3ce5b72df7e817..1f1bc1db065057a5ced1c917c5a24e2decebf2fa 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -23,14 +23,21 @@ class EnvAgentStatic(object): position = attrib() direction = attrib() target = attrib() - old_direction = attrib(default=None) + + def __init__(self, position, direction, target): + self.position = position + self.direction = direction + self.target = target @classmethod def from_lists(cls, positions, directions, targets): """ Create a list of EnvAgentStatics from lists of positions, directions and targets """ return list(starmap(EnvAgentStatic, zip(positions, directions, targets))) - + + def to_list(self): + return [self.position, self.direction, self.target] + @attrs class EnvAgent(EnvAgentStatic): @@ -41,6 +48,15 @@ class EnvAgent(EnvAgentStatic): forcing the env to refer to it in the EnvAgentStatic """ handle = attrib(default=None) + old_direction = attrib(default=None) + + def __init__(self, position, direction, target, handle, old_direction): + super(EnvAgent, self).__init__(position, direction, target) + self.handle = handle + self.old_direction = old_direction + + def to_list(self): + return [self.position, self.direction, self.target, self.handle, self.old_direction] @classmethod def from_static(cls, oStatic): @@ -56,7 +72,6 @@ class EnvAgent(EnvAgentStatic): """ if handles is None: handles = range(len(lEnvAgentStatic)) - + return [EnvAgent(**oEAS.__dict__, handle=handle) for handle, oEAS in zip(handles, lEnvAgentStatic)] - diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py index a1e46db6cf05f3093c403ecb772bf8bf4bca5469..b58604c6d7ededa28a33d30e87e13777a3cd54ec 100644 --- a/flatland/envs/env_utils.py +++ 
b/flatland/envs/env_utils.py @@ -8,7 +8,7 @@ a GridTransitionMap object. import numpy as np # from flatland.core.env import Environment -# from flatland.core.env_observation_builder import TreeObsForRailEnv +# from flatland.envs.observations import TreeObsForRailEnv # from flatland.core.transitions import Grid8Transitions, RailEnvTransitions # from flatland.core.transition_map import GridTransitionMap diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index e9e2d3da95e6bf56948b6610fa25544efd3fe90c..04e9a8fefac4ada79527f00200dc6fbfa3d7b924 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -1,7 +1,7 @@ import numpy as np # from flatland.core.env import Environment -# from flatland.core.env_observation_builder import TreeObsForRailEnv +# from flatland.envs.observations import TreeObsForRailEnv from flatland.core.transitions import Grid8Transitions, RailEnvTransitions from flatland.core.transition_map import GridTransitionMap diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py new file mode 100644 index 0000000000000000000000000000000000000000..8e4be0ba49d8405aa420d2bc8e4854f6300e3837 --- /dev/null +++ b/flatland/envs/observations.py @@ -0,0 +1,515 @@ +""" +Collection of environment-specific ObservationBuilder. +""" +import numpy as np +from collections import deque + +from flatland.core.env_observation_builder import ObservationBuilder + + +class TreeObsForRailEnv(ObservationBuilder): + """ + TreeObsForRailEnv object. + + This object returns observation vectors for agents in the RailEnv environment. + The information is local to each agent and exploits the tree structure of the rail + network to simplify the representation of the state of the environment for each agent. 
+ """ + + def __init__(self, max_depth): + self.max_depth = max_depth + + def reset(self): + agents = self.env.agents + nAgents = len(agents) + self.distance_map = np.inf * np.ones(shape=(nAgents, # self.env.number_of_agents, + self.env.height, + self.env.width, + 4)) + self.max_dist = np.zeros(nAgents) + + # for i in range(nAgents): + # self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i) + self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)] + + # Update local lookup table for all agents' target locations + self.location_has_target = {} + # for loc in self.env.agents_target: + # self.location_has_target[(loc[0], loc[1])] = 1 + self.location_has_target = {tuple(agent.target): 1 for agent in agents} + + def _distance_map_walker(self, position, target_nr): + """ + Utility function to compute distance maps from each cell in the rail network (and each possible + orientation within it) to each agent's target cell. + """ + # Returns max distance to target, from the farthest away node, while filling in distance_map + + self.distance_map[target_nr, position[0], position[1], :] = 0 + + # Fill in the (up to) 4 neighboring nodes + # nodes_queue = [] # list of tuples (row, col, direction, distance); + # direction is the direction of movement, meaning that at least a possible orientation of an agent + # in cell (row,col) allows a movement in direction `direction' + nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1)) + + # BFS from target `position' to all the reachable nodes in the grid + # Stop the search if the target position is re-visited, in any direction + visited = set([(position[0], position[1], 0), + (position[0], position[1], 1), + (position[0], position[1], 2), + (position[0], position[1], 3)]) + + max_distance = 0 + + while nodes_queue: + node = nodes_queue.popleft() + + node_id = (node[0], node[1], node[2]) + + if node_id not in visited: + 
visited.add(node_id) + + # From the list of possible neighbors that have at least a path to the current node, only keep those + # whose new orientation in the current cell would allow a transition to direction node[2] + valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2]) + + for n in valid_neighbors: + nodes_queue.append(n) + + if len(valid_neighbors) > 0: + max_distance = max(max_distance, node[3] + 1) + + return max_distance + + def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1): + """ + Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the + minimum distances from each target cell. + """ + neighbors = [] + + possible_directions = [0, 1, 2, 3] + if enforce_target_direction >= 0: + # The agent must land into the current cell with orientation `enforce_target_direction'. + # This is only possible if the agent has arrived from the cell in the opposite direction! + possible_directions = [(enforce_target_direction + 2) % 4] + + for neigh_direction in possible_directions: + new_cell = self._new_position(position, neigh_direction) + + if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width: + + desired_movement_from_new_cell = (neigh_direction + 2) % 4 + + """ + # Is the next cell a dead-end? + isNextCellDeadEnd = False + nbits = 0 + tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1])) + while tmp > 0: + nbits += (tmp & 1) + tmp = tmp >> 1 + if nbits == 1: + # Dead-end! + isNextCellDeadEnd = True + """ + + # Check all possible transitions in new_cell + for agent_orientation in range(4): + # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible? + isValid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation), + desired_movement_from_new_cell) + + if isValid: + """ + # TODO: check that it works with deadends! 
-- still bugged! + movement = desired_movement_from_new_cell + if isNextCellDeadEnd: + movement = (desired_movement_from_new_cell+2) % 4 + """ + new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation], + current_distance + 1) + neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance)) + self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance + + return neighbors + + def _new_position(self, position, movement): + """ + Utility function that converts a compass movement over a 2D grid to new positions (r, c). + """ + if movement == 0: # NORTH + return (position[0] - 1, position[1]) + elif movement == 1: # EAST + return (position[0], position[1] + 1) + elif movement == 2: # SOUTH + return (position[0] + 1, position[1]) + elif movement == 3: # WEST + return (position[0], position[1] - 1) + + def get(self, handle): + """ + Computes the current observation for agent `handle' in env + + The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible + movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv). + The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for + the transitions. The order is: + [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back'] + + + + + + Each branch data is organized as: + [root node information] + + [recursive branch data from 'left'] + + [... from 'forward'] + + [... from 'right] + + [... from 'back'] + + Finally, each node information is composed of 5 floating point values: + + #1: + + #2: 1 if a target of another agent is detected between the previous node and the current one. + + #3: 1 if another agent is detected between the previous node and the current one. 
+ + #4: distance of agent to the current branch node + + #5: minimum distance from node to the agent's target (when landing to the node following the corresponding + branch. + + Missing/padding nodes are filled in with -inf (truncated). + Missing values in present node are filled in with +inf (truncated). + + + In case of the root node, the values are [0, 0, 0, 0, distance from agent to target]. + In case the target node is reached, the values are [0, 0, 0, 0, 0]. + """ + + # Update local lookup table for all agents' positions + # self.location_has_agent = {} + # for loc in self.env.agents_position: + # self.location_has_agent[(loc[0], loc[1])] = 1 + self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents} + + agent = self.env.agents[handle] # TODO: handle being treated as index + # position = self.env.agents_position[handle] + # orientation = self.env.agents_direction[handle] + possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction)) + num_transitions = np.count_nonzero(possible_transitions) + # Root node - current position + # observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]] + observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)]] + root_observation = observation[:] + + # Start from the current orientation, and see which transitions are available; + # organize them as [left, forward, right, back], relative to the current orientation + # If only one transition is possible, the tree is oriented with this transition as the forward branch. + # TODO: Test if this works as desired! 
+ orientation = agent.direction + if num_transitions == 1: + orientation = np.argmax(possible_transitions) + + # for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]: + for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]: + if possible_transitions[branch_direction]: + new_cell = self._new_position(agent.position, branch_direction) + + branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1) + observation = observation + branch_observation + else: + num_cells_to_fill_in = 0 + pow4 = 1 + for i in range(self.max_depth): + num_cells_to_fill_in += pow4 + pow4 *= 4 + observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in + return observation + + def _explore_branch(self, handle, position, direction, root_observation, depth): + """ + Utility function to compute tree-based observations. + """ + # [Recursive branch opened] + if depth >= self.max_depth + 1: + return [] + + # Continue along direction until next switch or + # until no transitions are possible along the current direction (i.e., dead-ends) + # We treat dead-ends as nodes, instead of going back, to avoid loops + exploring = True + last_isSwitch = False + last_isDeadEnd = False + last_isTerminal = False # wrong cell OR cycle; either way, we don't want the agent to land here + last_isTarget = False + + visited = set() + + # other_agent_encountered = False + # other_target_encountered = False + other_agent_encountered = np.inf + other_target_encountered = np.inf + + num_steps = 1 + while exploring: + # ############################# + # ############################# + # Modify here to compute any useful data required to build the end node's features. This code is called + # for each cell visited between the previous branching node and the next switch / target / dead-end.
+ if position in self.location_has_agent: + # other_agent_encountered = True + if num_steps < other_agent_encountered: + other_agent_encountered = num_steps + + if position in self.location_has_target: + # other_target_encountered = True + if num_steps < other_target_encountered: + other_target_encountered = num_steps + # ############################# + # ############################# + + if (position[0], position[1], direction) in visited: + last_isTerminal = True + break + visited.add((position[0], position[1], direction)) + + # If the target node is encountered, pick that as node. Also, no further branching is possible. + # if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]: + if np.array_equal(position, self.env.agents[handle].target): + last_isTarget = True + break + + cell_transitions = self.env.rail.get_transitions((*position, direction)) + num_transitions = np.count_nonzero(cell_transitions) + exploring = False + if num_transitions == 1: + # Check if dead-end, or if we can go forward along direction + nbits = 0 + tmp = self.env.rail.get_transitions(tuple(position)) + while tmp > 0: + nbits += (tmp & 1) + tmp = tmp >> 1 + if nbits == 1: + # Dead-end! 
+ last_isDeadEnd = True + + if not last_isDeadEnd: + # Keep walking through the tree along `direction' + exploring = True + direction = np.argmax(cell_transitions) + position = self._new_position(position, direction) + num_steps += 1 + + elif num_transitions > 0: + # Switch detected + last_isSwitch = True + break + + elif num_transitions == 0: + # Wrong cell type, but let's cover it and treat it as a dead-end, just in case + print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0], + position[1], direction) + last_isTerminal = True + break + + # `position' is either a terminal node or a switch + + observation = [] + + # ############################# + # ############################# + # Modify here to append new / different features for each visited cell! + """ + if last_isTarget: + observation = [0, + 1 if other_target_encountered else 0, + 1 if other_agent_encountered else 0, + root_observation[3] + num_steps, + 0] + + elif last_isTerminal: + observation = [0, + 1 if other_target_encountered else 0, + 1 if other_agent_encountered else 0, + np.inf, + np.inf] + else: + observation = [0, + 1 if other_target_encountered else 0, + 1 if other_agent_encountered else 0, + root_observation[3] + num_steps, + self.distance_map[handle, position[0], position[1], direction]] + """ + if last_isTarget: + observation = [0, + other_target_encountered, + other_agent_encountered, + root_observation[3] + num_steps, + 0] + + elif last_isTerminal: + observation = [0, + other_target_encountered, + other_agent_encountered, + np.inf, + np.inf] + else: + observation = [0, + other_target_encountered, + other_agent_encountered, + root_observation[3] + num_steps, + self.distance_map[handle, position[0], position[1], direction]] + # ############################# + # ############################# + + new_root_observation = observation[:] + # Start from the current orientation, and see which transitions are available; + # organize them as [left, forward, 
right, back], relative to the current orientation + # Get the possible transitions + possible_transitions = self.env.rail.get_transitions((*position, direction)) + for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]: + if last_isDeadEnd and self.env.rail.get_transition((*position, direction), + (branch_direction + 2) % 4): + # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes + # it back + new_cell = self._new_position(position, (branch_direction + 2) % 4) + branch_observation = self._explore_branch(handle, + new_cell, + (branch_direction + 2) % 4, + new_root_observation, + depth + 1) + observation = observation + branch_observation + + elif last_isSwitch and possible_transitions[branch_direction]: + new_cell = self._new_position(position, branch_direction) + branch_observation = self._explore_branch(handle, + new_cell, + branch_direction, + new_root_observation, + depth + 1) + observation = observation + branch_observation + + else: + num_cells_to_fill_in = 0 + pow4 = 1 + for i in range(self.max_depth - depth): + num_cells_to_fill_in += pow4 + pow4 *= 4 + observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf] * num_cells_to_fill_in + + return observation + + def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0): + """ + Utility function to pretty-print tree observations returned by this object. 
+ """ + if len(tree) < num_features_per_node: + return + + depth = 0 + tmp = len(tree) / num_features_per_node - 1 + pow4 = 4 + while tmp > 0: + tmp -= pow4 + depth += 1 + pow4 *= 4 + + prompt_ = ['L:', 'F:', 'R:', 'B:'] + + print(" " * current_depth + prompt, tree[0:num_features_per_node]) + child_size = (len(tree) - num_features_per_node) // 4 + for children in range(4): + child_tree = tree[(num_features_per_node + children * child_size): + (num_features_per_node + (children + 1) * child_size)] + self.util_print_obs_subtree(child_tree, + num_features_per_node, + prompt=prompt_[children], + current_depth=current_depth + 1) + + def split_tree(self, tree, num_features_per_node=5, current_depth=0): + """ + + :param tree: + :param num_features_per_node: + :param prompt: + :param current_depth: + :return: + """ + + if len(tree) < num_features_per_node: + return [], [] + + depth = 0 + tmp = len(tree) / num_features_per_node - 1 + pow4 = 4 + while tmp > 0: + tmp -= pow4 + depth += 1 + pow4 *= 4 + child_size = (len(tree) - num_features_per_node) // 4 + tree_data = tree[0:num_features_per_node - 1].tolist() + distance_data = [tree[num_features_per_node - 1]] + for children in range(4): + child_tree = tree[(num_features_per_node + children * child_size): + (num_features_per_node + (children + 1) * child_size)] + tmp_tree_data, tmp_distance_data = self.split_tree(child_tree, + num_features_per_node, + current_depth=current_depth + 1) + if len(tmp_tree_data) > 0: + tree_data.extend(tmp_tree_data) + distance_data.extend(tmp_distance_data) + return tree_data, distance_data + + +class GlobalObsForRailEnv(ObservationBuilder): + """ + Gives a global observation of the entire rail environment. + The observation is composed of the following elements: + + - transition map array with dimensions (env.height, env.width, 16), + assuming 16 bits encoding of transitions. 
+ + - Four 2D arrays containing respectively the position of the given agent, + the position of its target, the targets of the other agents and the + positions of the other agents. + + - A 4 elements array with one-hot encoding of the direction. + """ + + def __init__(self): + super(GlobalObsForRailEnv, self).__init__() + + def reset(self): + self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) + for i in range(self.rail_obs.shape[0]): + for j in range(self.rail_obs.shape[1]): + self.rail_obs[i, j] = np.array( + list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int) + + # self.targets = np.zeros(self.env.height, self.env.width) + # for target_pos in self.env.agents_target: + # self.targets[target_pos] += 1 + + def get(self, handle): + obs = np.zeros((4, self.env.height, self.env.width)) + agents = self.env.agents + agent = agents[handle] + + agent_pos = agents[handle].position + obs[0][agent_pos] += 1 + obs[1][agent.target] += 1 + + for i in range(len(agents)): + if i != handle: # TODO: handle used as index...? + agent2 = agents[i] + obs[3][agent2.position] += 1 + obs[2][agent2.target] += 1 + + direction = np.zeros(4) + direction[agent.direction] = 1 + + return self.rail_obs, obs, direction diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index adb40f03b2ee100d8e12cfead660f9600dd9968b..a5248a8fb3b5a7ee2d433aa12c5daca84753e6e9 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -5,10 +5,10 @@ Generator functions are functions that take width, height and num_resets as argu a GridTransitionMap object.
""" import numpy as np -import pickle +import msgpack from flatland.core.env import Environment -from flatland.core.env_observation_builder import TreeObsForRailEnv +from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.generators import random_rail_generator from flatland.envs.env_utils import get_new_position from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent @@ -157,7 +157,7 @@ class RailEnv(Environment): # self.dones[handle] = False self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) # perhaps dones should be part of each agent. - + # Reset the state of the observation builder with the new environment self.obs_builder.reset() @@ -215,7 +215,7 @@ class RailEnv(Environment): # 1) transition allows the new_direction in the cell, # 2) the new cell is not empty (case 0), # 3) the cell is free, i.e., no agent is currently in that cell - + # if ( # new_position[1] >= self.width or # new_position[0] >= self.height or @@ -226,11 +226,11 @@ class RailEnv(Environment): # new_cell_isValid = False new_cell_isValid = ( - np.array_equal( # Check the new position is still in the grid - new_position, - np.clip(new_position, [0, 0], [self.height-1, self.width-1])) - and # check the new position has some transitions (ie is not an empty cell) - self.rail.get_transitions(new_position) > 0) + np.array_equal( # Check the new position is still in the grid + new_position, + np.clip(new_position, [0, 0], [self.height - 1, self.width - 1])) + and # check the new position has some transitions (ie is not an empty cell) + self.rail.get_transitions(new_position) > 0) # If transition validity hasn't been checked yet. 
if transition_isValid is None: @@ -246,7 +246,7 @@ class RailEnv(Environment): # Check the new position is not the same as any of the existing agent positions # (including itself, for simplicity, since it is moving) cell_isFree = not np.any( - np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) + np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) if all([new_cell_isValid, transition_isValid, cell_isFree]): # move and change direction to face the new_direction that was @@ -278,7 +278,7 @@ class RailEnv(Environment): # if num_agents_in_target_position == self.number_of_agents: if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]): self.dones["__all__"] = True - self.rewards_dict = [0*r+global_reward for r in self.rewards_dict] + self.rewards_dict = [0 * r + global_reward for r in self.rewards_dict] # Reset the step actions (in case some agent doesn't 'register_action' # on the next step) @@ -324,20 +324,37 @@ class RailEnv(Environment): # TODO: pass - def save(self, sFilename): - dSave = { - "grid": self.rail.grid, - "agents_static": self.agents_static - } - with open(sFilename, "wb") as fOut: - pickle.dump(dSave, fOut) - - def load(self, sFilename): - with open(sFilename, "rb") as fIn: - dLoad = pickle.load(fIn) - self.rail.grid = dLoad["grid"] - self.height, self.width = self.rail.grid.shape - self.agents_static = dLoad["agents_static"] - self.agents = [None] * self.get_num_agents() - self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) - + def get_full_state_msg(self): + grid_data = self.rail.grid.tolist() + agent_static_data = [agent.to_list() for agent in self.agents_static] + agent_data = [agent.to_list() for agent in self.agents] + msg_data = { + "grid": grid_data, + "agents_static": agent_static_data, + "agents": agent_data} + return msgpack.packb(msg_data, use_bin_type=True) + + def get_agent_state_msg(self): + agent_data = [agent.to_list() for 
agent in self.agents] + msg_data = { + "agents": agent_data} + return msgpack.packb(msg_data, use_bin_type=True) + + def set_full_state_msg(self, msg_data): + data = msgpack.unpackb(msg_data, use_list=False) + self.rail.grid = np.array(data[b"grid"]) + self.agents_static = [EnvAgentStatic(d[0], d[1], d[2]) for d in data[b"agents_static"]] + self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]] + # setup with loaded data + self.height, self.width = self.rail.grid.shape + # self.agents = [None] * self.get_num_agents() + self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) + + def save(self, filename): + with open(filename, "wb") as file_out: + file_out.write(self.get_full_state_msg()) + + def load(self, filename): + with open(filename, "rb") as file_in: + load_data = file_in.read() + self.set_full_state_msg(load_data) diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index 3c7f47012e383e1e2bca9af120152b65cfa43e81..a33a0116a93b9c1aa25c9da8103925f7c5ed51ae 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -17,7 +17,7 @@ import os from flatland.envs.rail_env import RailEnv, random_rail_generator from flatland.envs.generators import complex_rail_generator # from flatland.core.transitions import RailEnvTransitions -from flatland.core.env_observation_builder import TreeObsForRailEnv +from flatland.envs.observations import TreeObsForRailEnv import flatland.utils.rendertools as rt from examples.play_model import Player from flatland.envs.env_utils import mirror @@ -110,8 +110,7 @@ class View(object): self.wTab.set_title(i, title) self.wTab.children = [ VBox([self.wDebug, self.wDebug_move]), - VBox([self.wRegenMethod, self.wReplaceAgents]) - ] + VBox([self.wRegenMethod, self.wReplaceAgents])] # Progress bar intended for stepping in the background (not yet working) self.wProg_steps = ipywidgets.IntProgress(value=0, min=0, max=20, step=1, description="Step") @@ -121,21 +120,20 @@ class 
View(object): dict(name="Refresh", method=self.controller.refresh, tip="Redraw only"), dict(name="Clear", method=self.controller.clear, tip="Clear rails and agents"), dict(name="Reset", method=self.controller.reset, - tip="Standard env reset, including regen rail + agents"), + tip="Standard env reset, including regen rail + agents"), dict(name="Restart Agents", method=self.controller.restartAgents, - tip="Move agents back to start positions"), + tip="Move agents back to start positions"), dict(name="Regenerate", method=self.controller.regenerate, - tip="Regenerate the rails using the method selected below"), + tip="Regenerate the rails using the method selected below"), dict(name="Load", method=self.controller.load), dict(name="Save", method=self.controller.save), dict(name="Step", method=self.controller.step), - dict(name="Run Steps", method=self.controller.start_run), - ] + dict(name="Run Steps", method=self.controller.start_run)] self.lwButtons = [] for dButton in ldButtons: wButton = ipywidgets.Button(description=dButton["name"], - tooltip=dButton["tip"] if "tip" in dButton else dButton["name"]) + tooltip=dButton["tip"] if "tip" in dButton else dButton["name"]) wButton.on_click(dButton["method"]) self.lwButtons.append(wButton) @@ -145,8 +143,7 @@ class View(object): self.wSize, self.wNAgents, self.wProg_steps, - self.wTab - ]) + self.wTab]) self.wMain = HBox([self.wImage, self.wVbox_controls]) @@ -164,7 +161,7 @@ class View(object): with self.wOutput: # plt.figure(figsize=(10, 10)) self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray", - show=False, iSelectedAgent=self.model.iSelectedAgent) + show=False, iSelectedAgent=self.model.iSelectedAgent) img = self.oRT.getImage() # plt.clf() # plt.close() @@ -607,7 +604,7 @@ class EditorModel(object): # Has the user clicked on an existing agent? 
iAgent = self.find_agent_at(rcCell) - + if iAgent is None: # No if self.iSelectedAgent is None: @@ -656,7 +653,7 @@ class EditorModel(object): # self.log("step ", i) self.step() time.sleep(0.2) - wProg_steps.value = i+1 # indicate progress on bar + wProg_steps.value = i + 1 # indicate progress on bar finally: self.thread = None @@ -683,6 +680,3 @@ class EditorModel(object): binTrans, sbinTrans, [sbinTrans[i:(i + 4)] for i in range(0, len(sbinTrans), 4)]) - - - \ No newline at end of file diff --git a/flatland/utils/graphics_layer.py b/flatland/utils/graphics_layer.py index c0d390c34c66567b7ae216fa53e03abb6baf8a06..4cfcc64bffb82f91a0f36822188db297bc1ed37e 100644 --- a/flatland/utils/graphics_layer.py +++ b/flatland/utils/graphics_layer.py @@ -54,7 +54,7 @@ class GraphicsLayer(object): color = tuple((gcolor[:3] * 255).astype(int)) else: color = self.tColGrid - + if lighten: color = tuple([int(255 - (255 - iRGB) / 3) for iRGB in color]) @@ -65,6 +65,6 @@ class GraphicsLayer(object): def setRailAt(self, row, col, binTrans): pass - + def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut): pass diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index 01cc5f03c7d8f015ff24a57c4ba53f65ddc00684..41516fd94737556a4b8abbc7ccfce0fd503a3d6e 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -39,7 +39,7 @@ class PILGL(GraphicsLayer): r = np.sqrt(s) gPoints = np.stack([np.atleast_1d(gX), -np.atleast_1d(gY)]).T * self.nPixCell for x, y in gPoints: - self.draw.rectangle([(x-r, y-r), (x+r, y+r)], fill=color, outline=color) + self.draw.rectangle([(x - r, y - r), (x + r, y + r)], fill=color, outline=color) def text(self, *args, **kwargs): pass diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py index af40c05cbcf3628a9a6bdb370dae4161d70a8a31..ea9613968c473c351849509bfc3e277dd7fe0701 100644 --- a/flatland/utils/render_qt.py +++ b/flatland/utils/render_qt.py @@ -111,7 +111,7 @@ class QTSVG(GraphicsLayer): 
self.wWinMain.resize(1000, 1000) self.wWinMain.show() self.wWinMain.setFocus() - + self.track = self.track = Track() self.lwTrack = [] self.zug = Zug() @@ -152,21 +152,21 @@ class QTSVG(GraphicsLayer): def processEvents(self): self.app.processEvents() time.sleep(0.001) - + def clear_rails(self): # print("Clear rails: ", len(self.lwTrack)) for wRail in self.lwTrack: self.layout.removeWidget(wRail) self.lwTrack = [] self.clear_agents() - + def clear_agents(self): # print("Clear Agents: ", len(self.lwAgents)) for wAgent in self.lwAgents: self.layout.removeWidget(wAgent) self.lwAgents = [] self.agents_prev = [] - + def setRailAt(self, row, col, binTrans): if binTrans in self.track.dSvg: sSVG = self.track.dSvg[binTrans].to_string() @@ -204,7 +204,7 @@ class QTSVG(GraphicsLayer): return # Ensure we have adequate slots in the list lwAgents - for i in range(len(self.lwAgents), iAgent+1): + for i in range(len(self.lwAgents), iAgent + 1): self.lwAgents.append(None) self.agents_prev.append(None) @@ -226,7 +226,7 @@ def main2(): gl = QTGL(10, 10) for i in range(10): gl.beginFrame() - gl.plot([3+i, 4], [-4-i, -5], color="r") + gl.plot([3 + i, 4], [-4 - i, -5], color="r") gl.endFrame() time.sleep(1) diff --git a/flatland/utils/svg.py b/flatland/utils/svg.py index b219560c00f6a1e08b7384260274101ee5b430e6..32d5631839f43964cf5ff7d94520aacf67db0488 100644 --- a/flatland/utils/svg.py +++ b/flatland/utils/svg.py @@ -30,10 +30,10 @@ class SVG(object): def merge(self, svg2): svg3 = svg2.copy() - + svg3.renumber_styles(offset=10) svg3.eStyle.text = self.eStyle.text + "\n" + svg3.eStyle.text - + for child in self.svg.root: if not child.tag.endswith("style"): svg3.svg.root.append(child) @@ -48,12 +48,12 @@ class SVG(object): lEl = self.svg.root.xpath("//*[@class='{}']".format(sClass)) for el in lEl: el.attrib["class"] = "st{}".format(iStyle + offset) - - sStyle2 = str(iStyle+offset) + + sStyle2 = str(iStyle + offset) sNewStyle = "\t.st" + sStyle2 + "{" + self.dStyles[sStyle] + "}\n" 
sNewStyles += sNewStyle - + self.eStyle.text = sNewStyles def set_style_color(self, style_name, color): @@ -63,7 +63,7 @@ class SVG(object): sValue = "fill:#" + "".join([format(col, "#04x")[2:] for col in color]) + ";" sNewStyle = "\t.st" + sKey + "{" + sValue + "}\n" sNewStyles += sNewStyle - + self.eStyle.text = sNewStyles def set_rotate(self, angle): @@ -87,7 +87,7 @@ class Zug(object): if delta_dir in (0, 2): svg = self.svg_straight.copy() svg.set_rotate(iDirIn * 90) - + if delta_dir == 1: # bend to right, eg N->E, E->S svg = self.svg_curve1.copy() svg.set_rotate((iDirIn - 1) * 90) @@ -124,8 +124,7 @@ class Track(object): "NN SS EN SW": "Weiche_vertikal_oben_links.svg", "NN SS SE WN": "Weiche_vertikal_oben_rechts.svg", "NN SS NW ES": "Weiche_vertikal_unten_links.svg", - "NN SS NE WS": "Weiche_vertikal_unten_rechts.svg", - } + "NN SS NE WS": "Weiche_vertikal_unten_rechts.svg"} self.dSvg = {} @@ -136,8 +135,7 @@ class Track(object): svgBG = SVG("./svg/Background_#91D1DD.svg") for sTrans, sFile in dFiles.items(): - - svg = SVG("./svg/"+sFile) + svg = SVG("./svg/" + sFile) lTrans16 = ["0"] * 16 for sTran in sTrans.split(" "): @@ -165,7 +163,7 @@ class Track(object): def main(): # svg1 = SVG("./svg/Gleis_vertikal.svg") # svg2 = SVG("./svg/Zug_1_Weiche_#0091ea.svg") - + # svg3 = svg2.merge(svg1) # svg3.set_rotate(90) @@ -188,4 +186,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/requirements_dev.txt b/requirements_dev.txt index b0dafbffc9d1b507e4c67fdae91b5ed0c0d452f7..69a03322b1aa316db9e406235e4fe2c39b5282f5 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -19,5 +19,6 @@ matplotlib==3.0.2 PyQt5==5.12 Pillow==5.4.1 +msgpack==0.6.1 svgutils==0.3.1 diff --git a/tests/test_environments.py b/tests/test_environments.py index f12dfa3d6b57f76ce490c2c748fefc008ba371a5..4c55eac7afb44d95f0e49d665eeeb4bc36becea9 100644 --- a/tests/test_environments.py +++
b/tests/test_environments.py @@ -4,6 +4,7 @@ import numpy as np from flatland.envs.rail_env import RailEnv from flatland.envs.generators import rail_from_GridTransitionMap_generator +from flatland.envs.generators import complex_rail_generator from flatland.core.transitions import Grid4Transitions from flatland.core.transition_map import GridTransitionMap from flatland.core.env_observation_builder import GlobalObsForRailEnv @@ -12,6 +13,30 @@ from flatland.envs.agent_utils import EnvAgent """Tests for `flatland` package.""" +def test_save_load(): + env = RailEnv(width=10, height=10, + rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=0), + number_of_agents=2) + env.reset() + agent_1_pos = env.agents_static[0].position + agent_1_dir = env.agents_static[0].direction + agent_1_tar = env.agents_static[0].target + agent_2_pos = env.agents_static[1].position + agent_2_dir = env.agents_static[1].direction + agent_2_tar = env.agents_static[1].target + env.save("test_save.dat") + env.load("test_save.dat") + assert(env.width == 10) + assert(env.height == 10) + assert(len(env.agents) == 2) + assert(agent_1_pos == env.agents_static[0].position) + assert(agent_1_dir == env.agents_static[0].direction) + assert(agent_1_tar == env.agents_static[0].target) + assert(agent_2_pos == env.agents_static[1].position) + assert(agent_2_dir == env.agents_static[1].direction) + assert(agent_2_tar == env.agents_static[1].target) + + def test_rail_environment_single_agent(): cells = [int('0000000000000000', 2), # empty cell - Case 0 @@ -205,4 +230,4 @@ def test_dead_end(): if __name__ == "__main__": test_rail_environment_single_agent() - test_dead_end() \ No newline at end of file + test_dead_end() diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py index ea4ad3c46f3465c2f30aabaf8ae9cff7e2efd7f7..f6defb2de400f1a007da237b2f91ec38df5db07b 100644 --- a/tests/test_rendertools.py +++ b/tests/test_rendertools.py @@ -6,13 +6,12 @@ Tests for `flatland`
package. from flatland.envs.rail_env import RailEnv, random_rail_generator import numpy as np -import os import sys import matplotlib.pyplot as plt import flatland.utils.rendertools as rt -from flatland.core.env_observation_builder import TreeObsForRailEnv +from flatland.envs.observations import TreeObsForRailEnv def checkFrozenImage(oRT, sFileImage, resave=False): @@ -49,7 +48,7 @@ def test_render_env(save_new_images=False): oEnv.rail.load_transition_map(sfTestEnv) oRT = rt.RenderTool(oEnv) oRT.renderEnv() - + checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images) oRT = rt.RenderTool(oEnv, gl="PIL") @@ -86,4 +85,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main()