Commit fcc163ab authored by hagrid67's avatar hagrid67
Browse files

Added comments and fixed lint

parent 38423832
......@@ -137,7 +137,7 @@ class TreeObsForRailEnv(ObservationBuilder):
new_cell = self._new_position(position, neigh_direction)
if new_cell[0] >= 0 and new_cell[0] < self.env.height and \
new_cell[1] >= 0 and new_cell[1] < self.env.width:
new_cell[1] >= 0 and new_cell[1] < self.env.width:
desired_movement_from_new_cell = (neigh_direction + 2) % 4
......
......@@ -163,9 +163,9 @@ def a_star(rail_trans, rail_array, start, end):
for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1])
if node_pos[0] >= rail_shape[0] or \
node_pos[0] < 0 or \
node_pos[1] >= rail_shape[1] or \
node_pos[1] < 0:
node_pos[0] < 0 or \
node_pos[1] >= rail_shape[1] or \
node_pos[1] < 0:
continue
# validate positions
......@@ -540,8 +540,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
for i in range(len(transitions_templates_)):
is_match = True
for j in range(4):
if template[j] >= 0 and \
template[j] != transitions_templates_[i][0][j]:
if template[j] >= 0 and template[j] != transitions_templates_[i][0][j]:
is_match = False
break
if is_match:
......@@ -742,6 +741,29 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
return generator
class EnvAgentStatic(object):
""" TODO: EnvAgentStatic - To store initial position, direction and target.
This is like static data for the environment - it's where an agent starts,
rather than where it is at the moment.
The target should also be stored here.
"""
def __init__(self, rcPos, iDir, rcTarget):
self.rcPos = rcPos
self.iDir = iDir
self.rcTarget = rcTarget
class EnvAgent(object):
""" TODO: EnvAgent - replace separate agent lists with a single list
of agent objects. The EnvAgent represent's the environment's view
of the dynamic agent state. So target is not part of it - target is
static.
"""
def __init__(self, rcPos, iDir):
self.rcPos = rcPos
self.iDir = iDir
class RailEnv(Environment):
"""
RailEnv environment class.
......@@ -836,6 +858,8 @@ class RailEnv(Environment):
return self.agents_handles
def fill_valid_positions(self):
''' Populate the valid_positions list for the current TransitionMap.
'''
self.valid_positions = valid_positions = []
for r in range(self.height):
for c in range(self.width):
......@@ -843,12 +867,20 @@ class RailEnv(Environment):
valid_positions.append((r, c))
def check_agent_lists(self):
''' Check that the agent_handles, position and direction lists are all of length
number_of_agents.
(Suggest this is replaced with a single list of Agent objects :)
'''
for lAgents, name in zip(
[self.agents_handles, self.agents_position, self.agents_direction],
["handles", "positions", "directions"]):
[self.agents_handles, self.agents_position, self.agents_direction],
["handles", "positions", "directions"]):
assert self.number_of_agents == len(lAgents), "Inconsistent agent list:" + name
def check_agent_locdirpath(self, iAgent):
''' Check that agent iAgent has a valid location and direction,
with a path to its target.
(Not currently used?)
'''
valid_movements = []
for direction in range(4):
position = self.agents_position[iAgent]
......@@ -861,13 +893,20 @@ class RailEnv(Environment):
for m in valid_movements:
new_position = self._new_position(self.agents_position[iAgent], m[1])
if m[0] not in valid_starting_directions and \
self._path_exists(new_position, m[0], self.agents_target[iAgent]):
self._path_exists(new_position, m[0], self.agents_target[iAgent]):
valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0:
return False
else:
return True
def pick_agent_direction(self, rcPos, rcTarget):
""" Pick and return a valid direction index (0..3) for an agent starting at
row,col rcPos with target rcTarget.
Return None if no path exists.
Picks random direction if more than one exists (uniformly).
"""
valid_movements = []
for direction in range(4):
moves = self.rail.get_transitions((*rcPos, direction))
......@@ -879,8 +918,7 @@ class RailEnv(Environment):
valid_starting_directions = []
for m in valid_movements:
new_position = self._new_position(rcPos, m[1])
if m[0] not in valid_starting_directions and \
self._path_exists(new_position, m[0], rcTarget):
if m[0] not in valid_starting_directions and self._path_exists(new_position, m[0], rcTarget):
valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0:
......@@ -889,6 +927,11 @@ class RailEnv(Environment):
return valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]]
def add_agent(self, rcPos=None, rcTarget=None, iDir=None):
""" Add a new agent at position rcPos with target rcTarget and
initial direction index iDir.
Should also store this initial position etc as environment "meta-data"
but this does not yet exist.
"""
self.check_agent_lists()
if rcPos is None:
......@@ -949,31 +992,7 @@ class RailEnv(Environment):
break
else:
self.agents_direction[i] = direction
# Jeremy extracted this into the method pick_agent_direction
if False:
for i in range(self.number_of_agents):
valid_movements = []
for direction in range(4):
position = self.agents_position[i]
moves = self.rail.get_transitions((position[0], position[1], direction))
for move_index in range(4):
if moves[move_index]:
valid_movements.append((direction, move_index))
valid_starting_directions = []
for m in valid_movements:
new_position = self._new_position(self.agents_position[i], m[1])
if m[0] not in valid_starting_directions and \
self._path_exists(new_position, m[0], self.agents_target[i]):
valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0:
re_generate = True
else:
self.agents_direction[i] = valid_starting_directions[
np.random.choice(len(valid_starting_directions), 1)[0]]
# Reset the state of the observation builder with the new environment
self.obs_builder.reset()
......@@ -1079,14 +1098,14 @@ class RailEnv(Environment):
movement = curv_dir
curv_dir = (curv_dir + 1) % 4
new_position = self._new_position(pos, movement)
# Is it a legal move? 1) transition allows the movement in the
# cell, 2) the new cell is not empty (case 0), 3) the cell is
# free, i.e., no agent is currently in that cell
if new_position[1] >= self.width or \
new_position[0] >= self.height or \
new_position[0] < 0 or new_position[1] < 0:
if (
new_position[1] >= self.width or
new_position[0] >= self.height or
new_position[0] < 0 or new_position[1] < 0):
new_cell_isValid = False
elif self.rail.get_transitions((new_position[0], new_position[1])) > 0:
......@@ -1095,7 +1114,7 @@ class RailEnv(Environment):
new_cell_isValid = False
# If transition validity hasn't been checked yet.
if transition_isValid == None:
if transition_isValid is None:
transition_isValid = self.rail.get_transition(
(pos[0], pos[1], direction),
movement) or is_deadend
......@@ -1117,7 +1136,7 @@ class RailEnv(Environment):
# if agent is not in target position, add step penalty
if self.agents_position[i][0] == self.agents_target[i][0] and \
self.agents_position[i][1] == self.agents_target[i][1]:
self.agents_position[i][1] == self.agents_target[i][1]:
self.dones[handle] = True
else:
self.rewards_dict[handle] += step_penalty
......@@ -1126,7 +1145,7 @@ class RailEnv(Environment):
num_agents_in_target_position = 0
for i in range(self.number_of_agents):
if self.agents_position[i][0] == self.agents_target[i][0] and \
self.agents_position[i][1] == self.agents_target[i][1]:
self.agents_position[i][1] == self.agents_target[i][1]:
num_agents_in_target_position += 1
if num_agents_in_target_position == self.number_of_agents:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment