From 4de49a72a8942a4397668f0c046e860454320b9d Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Fri, 5 Jul 2019 18:10:45 +0200 Subject: [PATCH] #62 increase coverage #83 cleanup --- flatland/core/transition_map.py | 6 +- flatland/envs/observations.py | 81 +++++++++++----------- flatland/envs/rail_env.py | 42 ++++++----- tests/test_flatland_core_transition_map.py | 6 +- 4 files changed, 65 insertions(+), 70 deletions(-) diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 6c0b92a7..cb09a628 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -336,8 +336,4 @@ class GridTransitionMap(TransitionMap): return True -# TODO: GIACOMO: is it better to provide those methods with lists of cell_ids -# (most general implementation) or to make Grid-class specific methods for -# slicing over the 3 dimensions? I'd say both perhaps. - -# TODO: override __getitem__ and __setitem__ (cell contents, not transitions?) +# TODO: improvement override __getitem__ and __setitem__ (cell contents, not transitions?) 
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index a10c58e6..5a19deeb 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -23,6 +23,7 @@ class TreeObsForRailEnv(ObservationBuilder): observation_dim = 9 def __init__(self, max_depth, predictor=None): + super().__init__() self.max_depth = max_depth # Compute the size of the returned observation vector @@ -41,15 +42,14 @@ class TreeObsForRailEnv(ObservationBuilder): def reset(self): agents = self.env.agents - nAgents = len(agents) + nb_agents = len(agents) compute_distance_map = True - if self.agents_previous_reset is not None: - if nAgents == len(self.agents_previous_reset): - compute_distance_map = False - for i in range(nAgents): - if agents[i].target != self.agents_previous_reset[i].target: - compute_distance_map = True + if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset): + compute_distance_map = False + for i in range(nb_agents): + if agents[i].target != self.agents_previous_reset[i].target: + compute_distance_map = True self.agents_previous_reset = agents if compute_distance_map: @@ -57,12 +57,12 @@ class TreeObsForRailEnv(ObservationBuilder): def _compute_distance_map(self): agents = self.env.agents - nAgents = len(agents) - self.distance_map = np.inf * np.ones(shape=(nAgents, # self.env.number_of_agents, + nb_agents = len(agents) + self.distance_map = np.inf * np.ones(shape=(nb_agents, self.env.height, self.env.width, 4)) - self.max_dist = np.zeros(nAgents) + self.max_dist = np.zeros(nb_agents) self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)] # Update local lookup table for all agents' target locations self.location_has_target = {tuple(agent.target): 1 for agent in agents} @@ -83,10 +83,8 @@ class TreeObsForRailEnv(ObservationBuilder): # BFS from target `position' to all the reachable nodes in the grid # Stop the search if the target position is re-visited, in any 
direction - visited = set([(position[0], position[1], 0), - (position[0], position[1], 1), - (position[0], position[1], 2), - (position[0], position[1], 3)]) + visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2), + (position[0], position[1], 3)} max_distance = 0 @@ -133,10 +131,10 @@ class TreeObsForRailEnv(ObservationBuilder): # Check all possible transitions in new_cell for agent_orientation in range(4): # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible? - isValid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation), - desired_movement_from_new_cell) + is_valid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation), + desired_movement_from_new_cell) - if isValid: + if is_valid: """ # TODO: check that it works with deadends! -- still bugged! movement = desired_movement_from_new_cell @@ -163,12 +161,14 @@ class TreeObsForRailEnv(ObservationBuilder): elif movement == Grid4TransitionsEnum.WEST: return (position[0], position[1] - 1) - def get_many(self, handles=[]): + def get_many(self, handles=None): """ Called whenever an observation has to be computed for the `env' environment, for each agent with handle in the `handles' list. 
""" + if handles is None: + handles = [] if self.predictor: self.predicted_pos = {} self.predicted_dir = {} @@ -259,7 +259,6 @@ class TreeObsForRailEnv(ObservationBuilder): # Root node - current position observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0] - root_observation = observation[:] visited = set() # Start from the current orientation, and see which transitions are available; # organize them as [left, forward, right, back], relative to the current orientation @@ -273,7 +272,7 @@ class TreeObsForRailEnv(ObservationBuilder): if possible_transitions[branch_direction]: new_cell = self._new_position(agent.position, branch_direction) branch_observation, branch_visited = \ - self._explore_branch(handle, new_cell, branch_direction, root_observation, 1, 1) + self._explore_branch(handle, new_cell, branch_direction, 1, 1) observation = observation + branch_observation visited = visited.union(branch_visited) else: @@ -291,7 +290,7 @@ class TreeObsForRailEnv(ObservationBuilder): pow4 *= 4 return num_observations * self.observation_dim - def _explore_branch(self, handle, position, direction, root_observation, tot_dist, depth): + def _explore_branch(self, handle, position, direction, tot_dist, depth): """ Utility function to compute tree-based observations. We walk along the branch and collect the information documented in the get() function. 
@@ -305,10 +304,10 @@ class TreeObsForRailEnv(ObservationBuilder): # until no transitions are possible along the current direction (i.e., dead-ends) # We treat dead-ends as nodes, instead of going back, to avoid loops exploring = True - last_isSwitch = False - last_isDeadEnd = False - last_isTerminal = False # wrong cell OR cycle; either way, we don't want the agent to land here - last_isTarget = False + last_is_switch = False + last_is_dead_end = False + last_is_terminal = False # wrong cell OR cycle; either way, we don't want the agent to land here + last_is_target = False visited = set() agent = self.env.agents[handle] @@ -369,21 +368,19 @@ class TreeObsForRailEnv(ObservationBuilder): if tot_dist < other_target_encountered: other_target_encountered = tot_dist - if position == agent.target: - if tot_dist < own_target_encountered: - own_target_encountered = tot_dist + if position == agent.target and tot_dist < own_target_encountered: + own_target_encountered = tot_dist # ############################# # ############################# - if (position[0], position[1], direction) in visited: - last_isTerminal = True + last_is_terminal = True break visited.add((position[0], position[1], direction)) # If the target node is encountered, pick that as node. Also, no further branching is possible. if np.array_equal(position, self.env.agents[handle].target): - last_isTarget = True + last_is_target = True break cell_transitions = self.env.rail.get_transitions((*position, direction)) @@ -403,9 +400,9 @@ class TreeObsForRailEnv(ObservationBuilder): tmp = tmp >> 1 if nbits == 1: # Dead-end! 
- last_isDeadEnd = True + last_is_dead_end = True - if not last_isDeadEnd: + if not last_is_dead_end: # Keep walking through the tree along `direction' exploring = True # convert one-hot encoding to 0,1,2,3 @@ -415,14 +412,14 @@ class TreeObsForRailEnv(ObservationBuilder): tot_dist += 1 elif num_transitions > 0: # Switch detected - last_isSwitch = True + last_is_switch = True break elif num_transitions == 0: # Wrong cell type, but let's cover it and treat it as a dead-end, just in case print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0], position[1], direction) - last_isTerminal = True + last_is_terminal = True break # `position' is either a terminal node or a switch @@ -433,7 +430,7 @@ class TreeObsForRailEnv(ObservationBuilder): # ############################# # Modify here to append new / different features for each visited cell! - if last_isTarget: + if last_is_target: observation = [own_target_encountered, other_target_encountered, other_agent_encountered, @@ -445,7 +442,7 @@ class TreeObsForRailEnv(ObservationBuilder): other_agent_opposite_direction ] - elif last_isTerminal: + elif last_is_terminal: observation = [own_target_encountered, other_target_encountered, other_agent_encountered, @@ -476,25 +473,25 @@ class TreeObsForRailEnv(ObservationBuilder): # Get the possible transitions possible_transitions = self.env.rail.get_transitions((*position, direction)) for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]: - if last_isDeadEnd and self.env.rail.get_transition((*position, direction), - (branch_direction + 2) % 4): + if last_is_dead_end and self.env.rail.get_transition((*position, direction), + (branch_direction + 2) % 4): # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes # it back new_cell = self._new_position(position, (branch_direction + 2) % 4) branch_observation, branch_visited = self._explore_branch(handle, new_cell, (branch_direction + 2) % 4, 
- new_root_observation, tot_dist + 1, + tot_dist + 1, depth + 1) observation = observation + branch_observation if len(branch_visited) != 0: visited = visited.union(branch_visited) - elif last_isSwitch and possible_transitions[branch_direction]: + elif last_is_switch and possible_transitions[branch_direction]: new_cell = self._new_position(position, branch_direction) branch_observation, branch_visited = self._explore_branch(handle, new_cell, branch_direction, - new_root_observation, tot_dist + 1, + tot_dist + 1, depth + 1) observation = observation + branch_observation if len(branch_visited) != 0: diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index b4a56a8d..b31be94f 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -109,7 +109,7 @@ class RailEnv(Environment): self.obs_builder._set_env(self) self.action_space = [1] - self.observation_space = self.obs_builder.observation_space # updated on resets? + self.observation_space = self.obs_builder.observation_space self.rewards = [0] * number_of_agents self.done = False @@ -195,31 +195,29 @@ class RailEnv(Environment): # Reset the step rewards self.rewards_dict = dict() - for iAgent in range(self.get_num_agents()): - self.rewards_dict[iAgent] = 0 + for i_agent in range(self.get_num_agents()): + self.rewards_dict[i_agent] = 0 if self.dones["__all__"]: self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()} return self._get_observations(), self.rewards_dict, self.dones, {} - # for i in range(len(self.agents_handles)): - for iAgent in range(self.get_num_agents()): - agent = self.agents[iAgent] + for i_agent, agent in enumerate(self.agents): agent.old_direction = agent.direction agent.old_position = agent.position - if self.dones[iAgent]: # this agent has already completed... + if self.dones[i_agent]: # this agent has already completed... 
continue - if iAgent not in action_dict: # no action has been supplied for this agent - action_dict[iAgent] = RailEnvActions.DO_NOTHING + if i_agent not in action_dict: # no action has been supplied for this agent + action_dict[i_agent] = RailEnvActions.DO_NOTHING - if action_dict[iAgent] < 0 or action_dict[iAgent] > len(RailEnvActions): - print('ERROR: illegal action=', action_dict[iAgent], - 'for agent with index=', iAgent, + if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions): + print('ERROR: illegal action=', action_dict[i_agent], + 'for agent with index=', i_agent, '"DO NOTHING" will be executed instead') - action_dict[iAgent] = RailEnvActions.DO_NOTHING + action_dict[i_agent] = RailEnvActions.DO_NOTHING - action = action_dict[iAgent] + action = action_dict[i_agent] if action == RailEnvActions.DO_NOTHING and agent.moving: # Keep moving @@ -228,12 +226,12 @@ class RailEnv(Environment): if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.: # Only allow halting an agent on entering new cells. agent.moving = False - self.rewards_dict[iAgent] += stop_penalty + self.rewards_dict[i_agent] += stop_penalty if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING): # Only allow agent to start moving by pressing forward. agent.moving = True - self.rewards_dict[iAgent] += start_penalty + self.rewards_dict[i_agent] += start_penalty # Now perform a movement. # If the agent is in an initial position within a new cell (agent.speed_data['position_fraction']<eps) @@ -269,16 +267,16 @@ class RailEnv(Environment): else: # TODO: an invalid action was chosen after entering the cell. The agent cannot move. 
- self.rewards_dict[iAgent] += invalid_action_penalty + self.rewards_dict[i_agent] += invalid_action_penalty agent.moving = False - self.rewards_dict[iAgent] += stop_penalty + self.rewards_dict[i_agent] += stop_penalty continue else: # TODO: an invalid action was chosen after entering the cell. The agent cannot move. - self.rewards_dict[iAgent] += invalid_action_penalty + self.rewards_dict[i_agent] += invalid_action_penalty agent.moving = False - self.rewards_dict[iAgent] += stop_penalty + self.rewards_dict[i_agent] += stop_penalty continue @@ -300,9 +298,9 @@ class RailEnv(Environment): agent.speed_data['position_fraction'] = 0.0 if np.equal(agent.position, agent.target).all(): - self.dones[iAgent] = True + self.dones[i_agent] = True else: - self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed'] + self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed'] # Check for end of episode + add global reward to all rewards! if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]): diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py index 5117b12a..2f721ca2 100644 --- a/tests/test_flatland_core_transition_map.py +++ b/tests/test_flatland_core_transition_map.py @@ -3,7 +3,7 @@ from flatland.core.grid.grid8 import Grid8Transitions, Grid8TransitionsEnum from flatland.core.transition_map import GridTransitionMap -def test_grid4_set_transitions(): +def test_grid4_get_transitions(): grid4_map = GridTransitionMap(2, 2, Grid4Transitions([])) assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (0, 0, 0, 0) grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 1) @@ -19,3 +19,7 @@ def test_grid8_set_transitions(): assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (1, 0, 0, 0, 0, 0, 0, 0) grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 0) assert 
grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (0, 0, 0, 0, 0, 0, 0, 0) + + + +# TODO GridTransitionMap -- GitLab