Skip to content
Snippets Groups Projects
Commit 080b5d03 authored by spiglerg's avatar spiglerg
Browse files

Merge branch '115_agent_not_moving' into 'master'

115 agent not moving

Closes #116

See merge request flatland/flatland!119
parents 2b0cbd96 254c74bd
No related branches found
No related tags found
No related merge requests found
...@@ -358,21 +358,26 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -358,21 +358,26 @@ class TreeObsForRailEnv(ObservationBuilder):
if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0): if int_position in np.delete(self.predicted_pos[tot_dist], handle, 0):
conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position) conflicting_agent = np.where(self.predicted_pos[tot_dist] == int_position)
for ca in conflicting_agent[0]: for ca in conflicting_agent[0]:
if direction != self.predicted_dir[tot_dist][ca] and tot_dist < potential_conflict: if direction != self.predicted_dir[tot_dist][ca] and tot_dist < potential_conflict:
potential_conflict = tot_dist potential_conflict = tot_dist
if self.env.dones[ca] and tot_dist < potential_conflict:
potential_conflict = tot_dist
# Look for opposing paths at distance num_step-1 # Look for opposing paths at distance num_step-1
elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0): elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position) conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
for ca in conflicting_agent[0]: for ca in conflicting_agent[0]:
if direction != self.predicted_dir[pre_step][ca] and tot_dist < potential_conflict: if direction != self.predicted_dir[pre_step][ca] and tot_dist < potential_conflict:
potential_conflict = tot_dist potential_conflict = tot_dist
if self.env.dones[ca] and tot_dist < potential_conflict:
potential_conflict = tot_dist
# Look for opposing paths at distance num_step+1 # Look for opposing paths at distance num_step+1
elif int_position in np.delete(self.predicted_pos[post_step], handle, 0): elif int_position in np.delete(self.predicted_pos[post_step], handle, 0):
conflicting_agent = np.where(self.predicted_pos[post_step] == int_position) conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
for ca in conflicting_agent[0]: for ca in conflicting_agent[0]:
if direction != self.predicted_dir[post_step][ca] and tot_dist < potential_conflict: if direction != self.predicted_dir[post_step][ca] and tot_dist < potential_conflict:
potential_conflict = tot_dist potential_conflict = tot_dist
if self.env.dones[ca] and tot_dist < potential_conflict:
potential_conflict = tot_dist
if position in self.location_has_target and position != agent.target: if position in self.location_has_target and position != agent.target:
if tot_dist < other_target_encountered: if tot_dist < other_target_encountered:
......
...@@ -334,6 +334,7 @@ class RailEnv(Environment): ...@@ -334,6 +334,7 @@ class RailEnv(Environment):
if np.equal(agent.position, agent.target).all(): if np.equal(agent.position, agent.target).all():
self.dones[i_agent] = True self.dones[i_agent] = True
agent.moving = False
else: else:
self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed'] self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment