Skip to content
Snippets Groups Projects
Commit dc738d78 authored by Erik Nygren's avatar Erik Nygren
Browse files

updated prediction length in shortestpathpredictor

parent e4c43e71
No related branches found
No related tags found
No related merge requests found
...@@ -17,11 +17,14 @@ from utils.observation_utils import normalize_observation ...@@ -17,11 +17,14 @@ from utils.observation_utils import normalize_observation
random.seed(3) random.seed(3)
np.random.seed(2) np.random.seed(2)
tree_depth = 3
observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv(10))
file_name = "./railway/simple_avoid.pkl" file_name = "./railway/simple_avoid.pkl"
env = RailEnv(width=10, env = RailEnv(width=10,
height=20, height=20,
rail_generator=rail_from_file(file_name), rail_generator=rail_from_file(file_name),
obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())) obs_builder_object=observation_helper)
x_dim = env.width x_dim = env.width
y_dim = env.height y_dim = env.height
...@@ -38,13 +41,12 @@ env = RailEnv(width=x_dim, ...@@ -38,13 +41,12 @@ env = RailEnv(width=x_dim,
rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist,
max_dist=99999, max_dist=99999,
seed=0), seed=0),
obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()), obs_builder_object=observation_helper,
number_of_agents=n_agents) number_of_agents=n_agents)
env.reset(True, True) env.reset(True, True)
""" """
tree_depth = 3
observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv())
env_renderer = RenderTool(env, gl="PILSVG", ) env_renderer = RenderTool(env, gl="PILSVG", )
handle = env.get_agent_handles() handle = env.get_agent_handles()
num_features_per_node = env.obs_builder.observation_dim num_features_per_node = env.obs_builder.observation_dim
......
...@@ -8,7 +8,6 @@ from flatland.core.env_prediction_builder import PredictionBuilder ...@@ -8,7 +8,6 @@ from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.rail_env import RailEnvActions from flatland.envs.rail_env import RailEnvActions
class ShortestPathPredictorForRailEnv(PredictionBuilder): class ShortestPathPredictorForRailEnv(PredictionBuilder):
""" """
ShortestPathPredictorForRailEnv object. ShortestPathPredictorForRailEnv object.
...@@ -17,6 +16,9 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): ...@@ -17,6 +16,9 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
The prediction acts as if no other agent is in the environment and always takes the forward action. The prediction acts as if no other agent is in the environment and always takes the forward action.
""" """
def __init__(self, max_depth):
self.max_depth = max_depth
def get(self, custom_args=None, handle=None): def get(self, custom_args=None, handle=None):
""" """
Called whenever get_many in the observation build is called. Called whenever get_many in the observation build is called.
...@@ -54,6 +56,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): ...@@ -54,6 +56,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
prediction = np.zeros(shape=(self.max_depth + 1, 5)) prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0] prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
visited = set() visited = set()
for index in range(1, self.max_depth + 1): for index in range(1, self.max_depth + 1):
# if we're at the target, stop moving... # if we're at the target, stop moving...
if agent.position == agent.target: if agent.position == agent.target:
...@@ -100,4 +103,5 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): ...@@ -100,4 +103,5 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
# cleanup: reset initial position # cleanup: reset initial position
agent.position = _agent_initial_position agent.position = _agent_initial_position
agent.direction = _agent_initial_direction agent.direction = _agent_initial_direction
return prediction_dict return prediction_dict
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment