From fcc163ab51722c9aedb6affba7380b2e15a66438 Mon Sep 17 00:00:00 2001 From: hagrid67 <jdhwatson@gmail.com> Date: Wed, 1 May 2019 19:46:23 +0100 Subject: [PATCH] Added comments and fixed lint --- flatland/core/env_observation_builder.py | 2 +- flatland/envs/rail_env.py | 103 ++++++++++++++--------- 2 files changed, 62 insertions(+), 43 deletions(-) diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index 1ae2819d..de3ee932 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -137,7 +137,7 @@ class TreeObsForRailEnv(ObservationBuilder): new_cell = self._new_position(position, neigh_direction) if new_cell[0] >= 0 and new_cell[0] < self.env.height and \ - new_cell[1] >= 0 and new_cell[1] < self.env.width: + new_cell[1] >= 0 and new_cell[1] < self.env.width: desired_movement_from_new_cell = (neigh_direction + 2) % 4 diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index f1a3d20c..eff1c108 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -163,9 +163,9 @@ def a_star(rail_trans, rail_array, start, end): for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]: node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1]) if node_pos[0] >= rail_shape[0] or \ - node_pos[0] < 0 or \ - node_pos[1] >= rail_shape[1] or \ - node_pos[1] < 0: + node_pos[0] < 0 or \ + node_pos[1] >= rail_shape[1] or \ + node_pos[1] < 0: continue # validate positions @@ -540,8 +540,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): for i in range(len(transitions_templates_)): is_match = True for j in range(4): - if template[j] >= 0 and \ - template[j] != transitions_templates_[i][0][j]: + if template[j] >= 0 and template[j] != transitions_templates_[i][0][j]: is_match = False break if is_match: @@ -742,6 +741,29 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8): return generator +class EnvAgentStatic(object): 
+ """ TODO: EnvAgentStatic - To store initial position, direction and target. + This is like static data for the environment - it's where an agent starts, + rather than where it is at the moment. + The target should also be stored here. + """ + def __init__(self, rcPos, iDir, rcTarget): + self.rcPos = rcPos + self.iDir = iDir + self.rcTarget = rcTarget + + +class EnvAgent(object): + """ TODO: EnvAgent - replace separate agent lists with a single list + of agent objects. The EnvAgent represents the environment's view + of the dynamic agent state. So target is not part of it - target is + static. + """ + def __init__(self, rcPos, iDir): + self.rcPos = rcPos + self.iDir = iDir + + class RailEnv(Environment): """ RailEnv environment class. @@ -836,6 +858,8 @@ class RailEnv(Environment): return self.agents_handles def fill_valid_positions(self): + ''' Populate the valid_positions list for the current TransitionMap. + ''' self.valid_positions = valid_positions = [] for r in range(self.height): for c in range(self.width): @@ -843,12 +867,20 @@ class RailEnv(Environment): valid_positions.append((r, c)) def check_agent_lists(self): + ''' Check that the agent_handles, position and direction lists are all of length + number_of_agents. + (Suggest this is replaced with a single list of Agent objects :) + ''' for lAgents, name in zip( - [self.agents_handles, self.agents_position, self.agents_direction], - ["handles", "positions", "directions"]): + [self.agents_handles, self.agents_position, self.agents_direction], + ["handles", "positions", "directions"]): assert self.number_of_agents == len(lAgents), "Inconsistent agent list:" + name def check_agent_locdirpath(self, iAgent): + ''' Check that agent iAgent has a valid location and direction, + with a path to its target. + (Not currently used?) 
+ ''' valid_movements = [] for direction in range(4): position = self.agents_position[iAgent] @@ -861,13 +893,20 @@ class RailEnv(Environment): for m in valid_movements: new_position = self._new_position(self.agents_position[iAgent], m[1]) if m[0] not in valid_starting_directions and \ - self._path_exists(new_position, m[0], self.agents_target[iAgent]): + self._path_exists(new_position, m[0], self.agents_target[iAgent]): valid_starting_directions.append(m[0]) if len(valid_starting_directions) == 0: return False + else: + return True def pick_agent_direction(self, rcPos, rcTarget): + """ Pick and return a valid direction index (0..3) for an agent starting at + row,col rcPos with target rcTarget. + Return None if no path exists. + Picks random direction if more than one exists (uniformly). + """ valid_movements = [] for direction in range(4): moves = self.rail.get_transitions((*rcPos, direction)) @@ -879,8 +918,7 @@ class RailEnv(Environment): valid_starting_directions = [] for m in valid_movements: new_position = self._new_position(rcPos, m[1]) - if m[0] not in valid_starting_directions and \ - self._path_exists(new_position, m[0], rcTarget): + if m[0] not in valid_starting_directions and self._path_exists(new_position, m[0], rcTarget): valid_starting_directions.append(m[0]) if len(valid_starting_directions) == 0: @@ -889,6 +927,11 @@ class RailEnv(Environment): return valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]] def add_agent(self, rcPos=None, rcTarget=None, iDir=None): + """ Add a new agent at position rcPos with target rcTarget and + initial direction index iDir. + Should also store this initial position etc as environment "meta-data" + but this does not yet exist. 
+ """ self.check_agent_lists() if rcPos is None: @@ -949,31 +992,7 @@ class RailEnv(Environment): break else: self.agents_direction[i] = direction - - # Jeremy extracted this into the method pick_agent_direction - if False: - for i in range(self.number_of_agents): - valid_movements = [] - for direction in range(4): - position = self.agents_position[i] - moves = self.rail.get_transitions((position[0], position[1], direction)) - for move_index in range(4): - if moves[move_index]: - valid_movements.append((direction, move_index)) - - valid_starting_directions = [] - for m in valid_movements: - new_position = self._new_position(self.agents_position[i], m[1]) - if m[0] not in valid_starting_directions and \ - self._path_exists(new_position, m[0], self.agents_target[i]): - valid_starting_directions.append(m[0]) - - if len(valid_starting_directions) == 0: - re_generate = True - else: - self.agents_direction[i] = valid_starting_directions[ - np.random.choice(len(valid_starting_directions), 1)[0]] - + # Reset the state of the observation builder with the new environment self.obs_builder.reset() @@ -1079,14 +1098,14 @@ class RailEnv(Environment): movement = curv_dir curv_dir = (curv_dir + 1) % 4 - new_position = self._new_position(pos, movement) # Is it a legal move? 1) transition allows the movement in the # cell, 2) the new cell is not empty (case 0), 3) the cell is # free, i.e., no agent is currently in that cell - if new_position[1] >= self.width or \ - new_position[0] >= self.height or \ - new_position[0] < 0 or new_position[1] < 0: + if ( + new_position[1] >= self.width or + new_position[0] >= self.height or + new_position[0] < 0 or new_position[1] < 0): new_cell_isValid = False elif self.rail.get_transitions((new_position[0], new_position[1])) > 0: @@ -1095,7 +1114,7 @@ class RailEnv(Environment): new_cell_isValid = False # If transition validity hasn't been checked yet. 
- if transition_isValid == None: + if transition_isValid is None: transition_isValid = self.rail.get_transition( (pos[0], pos[1], direction), movement) or is_deadend @@ -1117,7 +1136,7 @@ class RailEnv(Environment): # if agent is not in target position, add step penalty if self.agents_position[i][0] == self.agents_target[i][0] and \ - self.agents_position[i][1] == self.agents_target[i][1]: + self.agents_position[i][1] == self.agents_target[i][1]: self.dones[handle] = True else: self.rewards_dict[handle] += step_penalty @@ -1126,7 +1145,7 @@ class RailEnv(Environment): num_agents_in_target_position = 0 for i in range(self.number_of_agents): if self.agents_position[i][0] == self.agents_target[i][0] and \ - self.agents_position[i][1] == self.agents_target[i][1]: + self.agents_position[i][1] == self.agents_target[i][1]: num_agents_in_target_position += 1 if num_agents_in_target_position == self.number_of_agents: -- GitLab