From 75dd03bb4073e38403289e425b306d8d1b7ae7cc Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Wed, 12 Jun 2019 18:53:36 +0200 Subject: [PATCH] #64 bugfix prediction should respect target --- flatland/envs/predictions.py | 12 ++++-- flatland/utils/editor.py | 3 +- tests/test_env_prediction_builder.py | 57 ++++++++-------------------- 3 files changed, 26 insertions(+), 46 deletions(-) diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 3338e68..31c279f 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -45,10 +45,16 @@ class DummyPredictorForRailEnv(PredictionBuilder): action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT] _agent_initial_position = agent.position _agent_initial_direction = agent.direction - prediction = np.zeros(shape=(self.max_depth, 5)) + prediction = np.zeros(shape=(self.max_depth + 1, 5)) prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0] - for index in range(1, self.max_depth): + for index in range(1, self.max_depth + 1): action_done = False + # if we're at the target, stop moving... + if agent.position == agent.target: + prediction[index] = [index, agent.target[0], agent.target[1], agent.direction, + RailEnvActions.STOP_MOVING] + + continue for action in action_priorities: cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ self.env._check_action_on_agent(action, agent) @@ -61,7 +67,7 @@ class DummyPredictorForRailEnv(PredictionBuilder): action_done = True break if not action_done: - print("Cannot move further.") + raise Exception("Cannot move further. Something is wrong") prediction_dict[agent.handle] = prediction agent.position = _agent_initial_position agent.direction = _agent_initial_direction diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index d4e5c38..81565d6 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -323,7 +323,8 @@ class Controller(object): def restartAgents(self, event): self.log("Restart Agents - nAgents:", self.view.wRegenNAgents.value) if self.model.init_agents_static is not None: - self.model.env.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in self.model.init_agents_static] + self.model.env.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in + self.model.init_agents_static] self.model.env.agents = None self.model.init_agents_static = None self.player = None diff --git a/tests/test_env_prediction_builder.py b/tests/test_env_prediction_builder.py index 35a6a27..5b0e830 100644 --- a/tests/test_env_prediction_builder.py +++ b/tests/test_env_prediction_builder.py @@ -65,7 +65,7 @@ def test_predictions(): rail_generator=rail_from_GridTransitionMap_generator(rail), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv(), - prediction_builder_object=DummyPredictorForRailEnv(max_depth=20) + prediction_builder_object=DummyPredictorForRailEnv(max_depth=10) ) env.reset() @@ -73,6 +73,7 @@ def test_predictions(): # set initial position and direction for testing... env.agents[0].position = (5, 6) env.agents[0].direction = 0 + env.agents[0].target = (3., 0.) predictions = env.predict() positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0]))) @@ -89,18 +90,11 @@ def test_predictions(): [3., 3.], [3., 2.], [3., 1.], + # at target (3,0): stay in this position from here on [3., 0.], - [3., 1.], - [3., 2.], - [3., 3.], - [3., 4.], - [3., 5.], - [3., 6.], - [3., 7.], - [3., 8.], - [3., 9.], - [3., 8.], - [3., 7.]]) + [3., 0.], + [3., 0.], + ]) expected_directions = np.array([[0.], [0.], [0.], @@ -109,18 +103,11 @@ def test_predictions(): [3.], [3.], [3.], + # at target (3,0): stay in this position from here on [3.], - [1.], - [1.], - [1.], - [1.], - [1.], - [1.], - [1.], - [1.], - [1.], [3.], - [3.]]) + [3.] + ]) expected_time_offsets = np.array([[0.], [1.], [2.], @@ -132,15 +119,7 @@ def test_predictions(): [8.], [9.], [10.], - [11.], - [12.], - [13.], - [14.], - [15.], - [16.], - [17.], - [18.], - [19.]]) + ]) expected_actions = np.array([[0.], [2.], [2.], @@ -149,18 +128,12 @@ def test_predictions(): [2.], [2.], [2.], + # reaching target by straight [2.], - [2.], - [2.], - [2.], - [2.], - [2.], - [2.], - [2.], - [2.], - [2.], - [2.], - [2.]]) + # at target: stopped moving + [4.], + [4.], + ]) assert np.array_equal(positions, expected_positions) assert np.array_equal(directions, expected_directions) assert np.array_equal(time_offsets, expected_time_offsets) -- GitLab