diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 0d5387dc90dc5d5dceadd50d12e909434bc1c123..3fda7378771bf2b5cd20c7d065c9454d3e5629a5 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -31,7 +31,7 @@ class DummyPredictorForRailEnv(PredictionBuilder): Returns ------- np.array - Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1) x 5 elements: + Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements: - time_offset - position axis 0 - position axis 1 @@ -101,7 +101,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): Returns ------- np.array - Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1) x 5 elements: + Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements: - time_offset - position axis 0 - position axis 1 diff --git a/tests/test_env_prediction_builder.py b/tests/test_env_prediction_builder.py index 4d4078c38c9d98baf9e286c9cb1f2ff8920020e4..f34829d7f5a94230473e6394bbafc0c8c9ae78c1 100644 --- a/tests/test_env_prediction_builder.py +++ b/tests/test_env_prediction_builder.py @@ -166,7 +166,7 @@ def test_shortest_path_predictor(rendering=False): if rendering: renderer = RenderTool(env, gl="PILSVG") renderer.renderEnv(show=True, show_observations=False) - # input("Continue?") + input("Continue?") agent = env.agents[0] assert agent.position == (5, 6)