Commit fe26db7d authored by Erik Nygren's avatar Erik Nygren 🚅
Browse files

Updated predictor to respect differential speed in env.

parent a447381f
Pipeline #1690 canceled with stages
......@@ -84,7 +84,7 @@ for trials in range(1, n_trials + 1):
# Environment step which returns the observations for all agents, their corresponding
# reward and whether their are done
next_obs, all_rewards, done, _ = env.step(action_dict)
env_renderer.render_env(show=True, show_observations=True, show_predictions=True)
env_renderer.render_env(show=True, show_observations=False, show_predictions=True)
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
......
......@@ -124,8 +124,12 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
for agent in agents:
_agent_initial_position = agent.position
_agent_initial_direction = agent.direction
agent_speed = agent.speed_data["speed"]
times_per_cell = int(np.reciprocal(agent_speed))
prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
new_direction = _agent_initial_direction
new_position = _agent_initial_position
visited = set()
for index in range(1, self.max_depth + 1):
# if we're at the target, stop moving...
......@@ -140,12 +144,10 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
# Take shortest possible path
cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
new_position = None
new_direction = None
if np.sum(cell_transitions) == 1:
if np.sum(cell_transitions) == 1 and index % times_per_cell == 0:
new_direction = np.argmax(cell_transitions)
new_position = get_new_position(agent.position, new_direction)
elif np.sum(cell_transitions) > 1:
elif np.sum(cell_transitions) > 1 and index % times_per_cell == 0:
min_dist = np.inf
no_dist_found = True
for direction in range(4):
......@@ -157,7 +159,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
new_direction = direction
no_dist_found = False
new_position = get_new_position(agent.position, new_direction)
else:
elif index % times_per_cell == 0:
raise Exception("No transition possible {}".format(cell_transitions))
# update the agent's position and direction
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment