From fe26db7d95b144d437120662437409981c659e1f Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Mon, 5 Aug 2019 16:26:38 -0400
Subject: [PATCH] Updated predictor to respect differential speed in env.

---
 examples/training_example.py |  2 +-
 flatland/envs/predictions.py | 12 +++++++-----
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/examples/training_example.py b/examples/training_example.py
index cfed6c92..521cca8c 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -84,7 +84,7 @@ for trials in range(1, n_trials + 1):
         # Environment step which returns the observations for all agents, their corresponding
         # reward and whether their are done
         next_obs, all_rewards, done, _ = env.step(action_dict)
-        env_renderer.render_env(show=True, show_observations=True, show_predictions=True)
+        env_renderer.render_env(show=True, show_observations=False, show_predictions=True)
 
         # Update replay buffer and train agent
         for a in range(env.get_num_agents()):
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 88a79ea7..4718ad99 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -124,8 +124,12 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
         for agent in agents:
             _agent_initial_position = agent.position
             _agent_initial_direction = agent.direction
+            agent_speed = agent.speed_data["speed"]
+            times_per_cell = int(np.reciprocal(agent_speed))
             prediction = np.zeros(shape=(self.max_depth + 1, 5))
             prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
+            new_direction = _agent_initial_direction
+            new_position = _agent_initial_position
             visited = set()
             for index in range(1, self.max_depth + 1):
                 # if we're at the target, stop moving...
@@ -140,12 +144,10 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                 # Take shortest possible path
                 cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
 
-                new_position = None
-                new_direction = None
-                if np.sum(cell_transitions) == 1:
+                if np.sum(cell_transitions) == 1 and index % times_per_cell == 0:
                     new_direction = np.argmax(cell_transitions)
                     new_position = get_new_position(agent.position, new_direction)
-                elif np.sum(cell_transitions) > 1:
+                elif np.sum(cell_transitions) > 1 and index % times_per_cell == 0:
                     min_dist = np.inf
                     no_dist_found = True
                     for direction in range(4):
@@ -157,7 +159,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
                                 new_direction = direction
                                 no_dist_found = False
                     new_position = get_new_position(agent.position, new_direction)
-                else:
+                elif index % times_per_cell == 0:
                     raise Exception("No transition possible {}".format(cell_transitions))
 
                 # update the agent's position and direction
-- 
GitLab