Skip to content
Snippets Groups Projects
Commit 3aa177e9 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

Merge branch '140_predictor_multi_speed' into 'master'

140 predictor multi speed

See merge request flatland/flatland!141
parents 8775f09c b935770a
No related branches found
No related tags found
No related merge requests found
......@@ -10,7 +10,6 @@ np.random.seed(1)
# Use the complex_rail_generator to generate feasible network configurations with corresponding tasks
# Training on simple small tasks is the best way to get familiar with the environment
#
TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
......
......@@ -124,8 +124,12 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
for agent in agents:
_agent_initial_position = agent.position
_agent_initial_direction = agent.direction
agent_speed = agent.speed_data["speed"]
times_per_cell = int(np.reciprocal(agent_speed))
prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
new_direction = _agent_initial_direction
new_position = _agent_initial_position
visited = set()
for index in range(1, self.max_depth + 1):
# if we're at the target, stop moving...
......@@ -140,12 +144,10 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
# Take shortest possible path
cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
new_position = None
new_direction = None
if np.sum(cell_transitions) == 1:
if np.sum(cell_transitions) == 1 and index % times_per_cell == 0:
new_direction = np.argmax(cell_transitions)
new_position = get_new_position(agent.position, new_direction)
elif np.sum(cell_transitions) > 1:
elif np.sum(cell_transitions) > 1 and index % times_per_cell == 0:
min_dist = np.inf
no_dist_found = True
for direction in range(4):
......@@ -157,7 +159,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
new_direction = direction
no_dist_found = False
new_position = get_new_position(agent.position, new_direction)
else:
elif index % times_per_cell == 0:
raise Exception("No transition possible {}".format(cell_transitions))
# update the agent's position and direction
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment