diff --git a/examples/training_example.py b/examples/training_example.py index 8f7d18f55ec63f2ed2d57584817d14f8bc722259..c038e7b477069957efdec622b2c56e9e84cb7ac0 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -10,7 +10,6 @@ np.random.seed(1) # Use the complex_rail_generator to generate feasible network configurations with corresponding tasks # Training on simple small tasks is the best way to get familiar with the environment -# TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2) diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 88a79ea72606e7c1b46f92f6d73429979d67a4e6..4718ad9906db9b479123b53e9e9df0ff4db3b462 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -124,8 +124,12 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): for agent in agents: _agent_initial_position = agent.position _agent_initial_direction = agent.direction + agent_speed = agent.speed_data["speed"] + times_per_cell = int(np.reciprocal(agent_speed)) prediction = np.zeros(shape=(self.max_depth + 1, 5)) prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0] + new_direction = _agent_initial_direction + new_position = _agent_initial_position visited = set() for index in range(1, self.max_depth + 1): # if we're at the target, stop moving... @@ -140,12 +144,10 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): # Take shortest possible path cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) - new_position = None - new_direction = None - if np.sum(cell_transitions) == 1: + if np.sum(cell_transitions) == 1 and index % times_per_cell == 0: new_direction = np.argmax(cell_transitions) new_position = get_new_position(agent.position, new_direction) - elif np.sum(cell_transitions) > 1: + elif np.sum(cell_transitions) > 1 and index % times_per_cell == 0: min_dist = np.inf no_dist_found = True for direction in range(4): @@ -157,7 +159,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): new_direction = direction no_dist_found = False new_position = get_new_position(agent.position, new_direction) - else: + elif index % times_per_cell == 0: raise Exception("No transition possible {}".format(cell_transitions)) # update the agent's position and direction