Skip to content
Snippets Groups Projects
Commit dc738d78 authored by Erik Nygren's avatar Erik Nygren
Browse files

updated prediction length in shortestpathpredictor

parent e4c43e71
No related branches found
No related tags found
No related merge requests found
......@@ -17,11 +17,14 @@ from utils.observation_utils import normalize_observation
random.seed(3)
np.random.seed(2)
tree_depth = 3
observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv(10))
file_name = "./railway/simple_avoid.pkl"
env = RailEnv(width=10,
height=20,
rail_generator=rail_from_file(file_name),
obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))
obs_builder_object=observation_helper)
x_dim = env.width
y_dim = env.height
......@@ -38,13 +41,12 @@ env = RailEnv(width=x_dim,
rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist,
max_dist=99999,
seed=0),
obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()),
obs_builder_object=observation_helper,
number_of_agents=n_agents)
env.reset(True, True)
"""
tree_depth = 3
observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv())
env_renderer = RenderTool(env, gl="PILSVG", )
handle = env.get_agent_handles()
num_features_per_node = env.obs_builder.observation_dim
......
......@@ -8,7 +8,6 @@ from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.rail_env import RailEnvActions
class ShortestPathPredictorForRailEnv(PredictionBuilder):
"""
ShortestPathPredictorForRailEnv object.
......@@ -17,6 +16,9 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
The prediction acts as if no other agent is in the environment and always takes the forward action.
"""
def __init__(self, max_depth):
self.max_depth = max_depth
def get(self, custom_args=None, handle=None):
"""
Called whenever get_many in the observation build is called.
......@@ -54,6 +56,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
visited = set()
for index in range(1, self.max_depth + 1):
# if we're at the target, stop moving...
if agent.position == agent.target:
......@@ -100,4 +103,5 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
# cleanup: reset initial position
agent.position = _agent_initial_position
agent.direction = _agent_initial_direction
return prediction_dict
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment