diff --git a/examples/training_example.py b/examples/training_example.py index cfed6c92cc74c45445c436a65d15c9eb8292fe32..521cca8c786ca28c40460da95a6379a5962705b2 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -84,7 +84,7 @@ for trials in range(1, n_trials + 1): # Environment step which returns the observations for all agents, their corresponding # reward and whether their are done next_obs, all_rewards, done, _ = env.step(action_dict) - env_renderer.render_env(show=True, show_observations=True, show_predictions=True) + env_renderer.render_env(show=True, show_observations=False, show_predictions=True) # Update replay buffer and train agent for a in range(env.get_num_agents()): diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 88a79ea72606e7c1b46f92f6d73429979d67a4e6..4718ad9906db9b479123b53e9e9df0ff4db3b462 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -124,8 +124,12 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): for agent in agents: _agent_initial_position = agent.position _agent_initial_direction = agent.direction + agent_speed = agent.speed_data["speed"] + times_per_cell = int(np.reciprocal(agent_speed)) prediction = np.zeros(shape=(self.max_depth + 1, 5)) prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0] + new_direction = _agent_initial_direction + new_position = _agent_initial_position visited = set() for index in range(1, self.max_depth + 1): # if we're at the target, stop moving... @@ -140,12 +144,10 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): # Take shortest possible path cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) - new_position = None - new_direction = None - if np.sum(cell_transitions) == 1: + if np.sum(cell_transitions) == 1 and index % times_per_cell == 0: new_direction = np.argmax(cell_transitions) new_position = get_new_position(agent.position, new_direction) - elif np.sum(cell_transitions) > 1: + elif np.sum(cell_transitions) > 1 and index % times_per_cell == 0: min_dist = np.inf no_dist_found = True for direction in range(4): @@ -157,7 +159,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): new_direction = direction no_dist_found = False new_position = get_new_position(agent.position, new_direction) - else: + elif index % times_per_cell == 0: raise Exception("No transition possible {}".format(cell_transitions)) # update the agent's position and direction