Skip to content
Snippets Groups Projects
Commit 5bf451eb authored by spiglerg's avatar spiglerg
Browse files

prevent stopping in the middle of a cell

parent 65397f68
No related branches found
No related tags found
No related merge requests found
...@@ -28,19 +28,32 @@ class EnvAgentStatic(object): ...@@ -28,19 +28,32 @@ class EnvAgentStatic(object):
position = attrib() position = attrib()
direction = attrib() direction = attrib()
target = attrib() target = attrib()
moving = attrib() moving = attrib(default=False)
# speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
def __init__(self, position, direction, target, moving=False): # after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
# cell if speed=1, as default)
speed_data = attrib(default=dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0}))
def __init__(self,
position,
direction,
target,
moving=False,
speed_data={'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0}):
self.position = position self.position = position
self.direction = direction self.direction = direction
self.target = target self.target = target
self.moving = moving self.moving = moving
self.speed_data = speed_data
@classmethod @classmethod
def from_lists(cls, positions, directions, targets): def from_lists(cls, positions, directions, targets):
""" Create a list of EnvAgentStatics from lists of positions, directions and targets """ Create a list of EnvAgentStatics from lists of positions, directions and targets
""" """
return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions)))) speed_datas = []
for i in range(len(positions)):
speed_datas.append({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})
return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions), speed_datas)))
def to_list(self): def to_list(self):
...@@ -54,7 +67,7 @@ class EnvAgentStatic(object): ...@@ -54,7 +67,7 @@ class EnvAgentStatic(object):
if type(lTarget) is np.ndarray: if type(lTarget) is np.ndarray:
lTarget = lTarget.tolist() lTarget = lTarget.tolist()
return [lPos, int(self.direction), lTarget, int(self.moving)] return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data]
@attrs @attrs
...@@ -78,7 +91,7 @@ class EnvAgent(EnvAgentStatic): ...@@ -78,7 +91,7 @@ class EnvAgent(EnvAgentStatic):
def to_list(self): def to_list(self):
return [ return [
self.position, self.direction, self.target, self.handle, self.position, self.direction, self.target, self.handle,
self.old_direction, self.old_position, self.moving] self.old_direction, self.old_position, self.moving, self.speed_data]
@classmethod @classmethod
def from_static(cls, oStatic): def from_static(cls, oStatic):
......
...@@ -73,7 +73,7 @@ class RailEnv(Environment): ...@@ -73,7 +73,7 @@ class RailEnv(Environment):
random_rail_generator : generate a random rail of given size random_rail_generator : generate a random rail of given size
rail_from_GridTransitionMap_generator(rail_map) : generate a rail from rail_from_GridTransitionMap_generator(rail_map) : generate a rail from
a GridTransitionMap object a GridTransitionMap object
rail_from_manual_specifications_generator(rail_spec) : generate a rail from rail_from_manual_sp ecifications_generator(rail_spec) : generate a rail from
a rail specifications array a rail specifications array
TODO: generate_rail_from_saved_list or from list of ndarray bitmaps --- TODO: generate_rail_from_saved_list or from list of ndarray bitmaps ---
width : int width : int
...@@ -101,7 +101,6 @@ class RailEnv(Environment): ...@@ -101,7 +101,6 @@ class RailEnv(Environment):
self.action_space = [1] self.action_space = [1]
self.observation_space = self.obs_builder.observation_space # updated on resets? self.observation_space = self.obs_builder.observation_space # updated on resets?
self.actions = [0] * number_of_agents
self.rewards = [0] * number_of_agents self.rewards = [0] * number_of_agents
self.done = False self.done = False
...@@ -192,29 +191,33 @@ class RailEnv(Environment): ...@@ -192,29 +191,33 @@ class RailEnv(Environment):
# for i in range(len(self.agents_handles)): # for i in range(len(self.agents_handles)):
for iAgent in range(self.get_num_agents()): for iAgent in range(self.get_num_agents()):
agent = self.agents[iAgent] agent = self.agents[iAgent]
agent.speed_data['speed']=0.5
if iAgent not in action_dict: # no action has been supplied for this agent
if agent.moving:
# Keep moving
# Change MOVE_FORWARD to DO_NOTHING
action_dict[iAgent] = RailEnvActions.DO_NOTHING
else:
action_dict[iAgent] = RailEnvActions.DO_NOTHING
if self.dones[iAgent]: # this agent has already completed... if self.dones[iAgent]: # this agent has already completed...
continue continue
action = action_dict[iAgent]
if action < 0 or action > len(RailEnvActions): if np.equal(agent.position, agent.target).all():
print('ERROR: illegal action=', action, self.dones[iAgent] = True
'for agent with index=', iAgent) else:
return self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed']
if iAgent not in action_dict: # no action has been supplied for this agent
action_dict[iAgent] = RailEnvActions.DO_NOTHING
if action_dict[iAgent] < 0 or action_dict[iAgent] > len(RailEnvActions):
print('ERROR: illegal action=', action_dict[iAgent],
'for agent with index=', iAgent,
'"DO NOTHING" will be executed instead')
action_dict[iAgent] = RailEnvActions.DO_NOTHING
action = action_dict[iAgent]
if action == RailEnvActions.DO_NOTHING and agent.moving: if action == RailEnvActions.DO_NOTHING and agent.moving:
# Keep moving # Keep moving
action = RailEnvActions.MOVE_FORWARD action = RailEnvActions.MOVE_FORWARD
if action == RailEnvActions.STOP_MOVING and agent.moving: if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] < 0.01:
# Only allow halting an agent on entering new cells.
agent.moving = False agent.moving = False
self.rewards_dict[iAgent] += stop_penalty self.rewards_dict[iAgent] += stop_penalty
...@@ -223,47 +226,73 @@ class RailEnv(Environment): ...@@ -223,47 +226,73 @@ class RailEnv(Environment):
agent.moving = True agent.moving = True
self.rewards_dict[iAgent] += start_penalty self.rewards_dict[iAgent] += start_penalty
if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING: # Now perform a movement.
cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ # If the agent is in an initial position within a new cell (agent.speed_data['position_fraction']<eps)
self._check_action_on_agent(action, agent) # store the desired action in `transition_action_on_cellexit' (only if the desired transition is
if all([new_cell_isValid, transition_isValid, cell_isFree]): # allowed! otherwise DO_NOTHING!)
agent.old_direction = agent.direction # Then in any case (if agent.moving) and the `transition_action_on_cellexit' is valid, increment the
agent.old_position = agent.position # position_fraction by the speed of the agent (regardless of action taken, as long as no
agent.position = new_position # STOP_MOVING, but that makes agent.moving=False)
agent.direction = new_direction # If the new position fraction is >= 1, reset to 0, and perform the stored
else: # transition_action_on_cellexit
# Logic: if the chosen action is invalid,
# and it was LEFT or RIGHT, and the agent was moving, then keep moving FORWARD. # If the agent can make an action
if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT) and agent.moving: action_selected = False
cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ if agent.speed_data['position_fraction'] < 0.01:
self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent) if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING:
cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
if all([new_cell_isValid, transition_isValid, cell_isFree]): self._check_action_on_agent(action, agent)
agent.old_direction = agent.direction
agent.old_position = agent.position if all([new_cell_isValid, transition_isValid, cell_isFree]):
agent.position = new_position agent.speed_data['transition_action_on_cellexit'] = action
agent.direction = new_direction action_selected = True
else:
# But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving,
# try to keep moving forward!
if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT) and agent.moving:
cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)
if all([new_cell_isValid, transition_isValid, cell_isFree]):
agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
action_selected = True
else:
# TODO: an invalid action was chosen after entering the cell. The agent cannot move.
self.rewards_dict[iAgent] += invalid_action_penalty
agent.moving = False
self.rewards_dict[iAgent] += stop_penalty
continue
else: else:
# the action was not valid, add penalty # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
self.rewards_dict[iAgent] += invalid_action_penalty self.rewards_dict[iAgent] += invalid_action_penalty
agent.moving = False
self.rewards_dict[iAgent] += stop_penalty
continue
else: if agent.moving and (action_selected or agent.speed_data['position_fraction'] >= 0.01):
# the action was not valid, add penalty agent.speed_data['position_fraction'] += agent.speed_data['speed']
self.rewards_dict[iAgent] += invalid_action_penalty
if np.equal(agent.position, agent.target).all(): if agent.speed_data['position_fraction'] >= 1.0:
self.dones[iAgent] = True agent.speed_data['position_fraction'] = 0.0
else:
self.rewards_dict[iAgent] += step_penalty # Perform stored action to transition to the next cell
# Now 'transition_action_on_cellexit' will be guaranteed to be valid; it was checked on entering
# the cell
cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
self._check_action_on_agent(agent.speed_data['transition_action_on_cellexit'], agent)
agent.old_direction = agent.direction
agent.old_position = agent.position
agent.position = new_position
agent.direction = new_direction
# Check for end of episode + add global reward to all rewards! # Check for end of episode + add global reward to all rewards!
if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]): if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
self.dones["__all__"] = True self.dones["__all__"] = True
self.rewards_dict = [0 * r + global_reward for r in self.rewards_dict] self.rewards_dict = [0 * r + global_reward for r in self.rewards_dict]
# Reset the step actions (in case some agent doesn't 'register_action'
# on the next step)
self.actions = [0] * self.get_num_agents()
return self._get_observations(), self.rewards_dict, self.dones, {} return self._get_observations(), self.rewards_dict, self.dones, {}
def _check_action_on_agent(self, action, agent): def _check_action_on_agent(self, action, agent):
...@@ -271,6 +300,7 @@ class RailEnv(Environment): ...@@ -271,6 +300,7 @@ class RailEnv(Environment):
# cell used to check for invalid actions # cell used to check for invalid actions
new_direction, transition_isValid = self.check_action(agent, action) new_direction, transition_isValid = self.check_action(agent, action)
new_position = get_new_position(agent.position, new_direction) new_position = get_new_position(agent.position, new_direction)
# Is it a legal move? # Is it a legal move?
# 1) transition allows the new_direction in the cell, # 1) transition allows the new_direction in the cell,
# 2) the new cell is not empty (case 0), # 2) the new cell is not empty (case 0),
...@@ -281,11 +311,13 @@ class RailEnv(Environment): ...@@ -281,11 +311,13 @@ class RailEnv(Environment):
np.clip(new_position, [0, 0], [self.height - 1, self.width - 1])) np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
and # check the new position has some transitions (ie is not an empty cell) and # check the new position has some transitions (ie is not an empty cell)
self.rail.get_transitions(new_position) > 0) self.rail.get_transitions(new_position) > 0)
# If transition validity hasn't been checked yet. # If transition validity hasn't been checked yet.
if transition_isValid is None: if transition_isValid is None:
transition_isValid = self.rail.get_transition( transition_isValid = self.rail.get_transition(
(*agent.position, agent.direction), (*agent.position, agent.direction),
new_direction) new_direction)
# Check the new position is not the same as any of the existing agent positions # Check the new position is not the same as any of the existing agent positions
# (including itself, for simplicity, since it is moving) # (including itself, for simplicity, since it is moving)
cell_isFree = not np.any( cell_isFree = not np.any(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment