Skip to content
Snippets Groups Projects
Commit 75dd03bb authored by u214892
Browse files

#64 bugfix prediction should respect target

parent 6d48dcd1
No related branches found
No related tags found
No related merge requests found
...@@ -45,10 +45,16 @@ class DummyPredictorForRailEnv(PredictionBuilder): ...@@ -45,10 +45,16 @@ class DummyPredictorForRailEnv(PredictionBuilder):
action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT] action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]
_agent_initial_position = agent.position _agent_initial_position = agent.position
_agent_initial_direction = agent.direction _agent_initial_direction = agent.direction
prediction = np.zeros(shape=(self.max_depth, 5)) prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0] prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0]
for index in range(1, self.max_depth): for index in range(1, self.max_depth + 1):
action_done = False action_done = False
# if we're at the target, stop moving...
if agent.position == agent.target:
prediction[index] = [index, agent.target[0], agent.target[1], agent.direction,
RailEnvActions.STOP_MOVING]
continue
for action in action_priorities: for action in action_priorities:
cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
self.env._check_action_on_agent(action, agent) self.env._check_action_on_agent(action, agent)
...@@ -61,7 +67,7 @@ class DummyPredictorForRailEnv(PredictionBuilder): ...@@ -61,7 +67,7 @@ class DummyPredictorForRailEnv(PredictionBuilder):
action_done = True action_done = True
break break
if not action_done: if not action_done:
print("Cannot move further.") raise Exception("Cannot move further. Something is wrong")
prediction_dict[agent.handle] = prediction prediction_dict[agent.handle] = prediction
agent.position = _agent_initial_position agent.position = _agent_initial_position
agent.direction = _agent_initial_direction agent.direction = _agent_initial_direction
......
...@@ -323,7 +323,8 @@ class Controller(object): ...@@ -323,7 +323,8 @@ class Controller(object):
def restartAgents(self, event): def restartAgents(self, event):
self.log("Restart Agents - nAgents:", self.view.wRegenNAgents.value) self.log("Restart Agents - nAgents:", self.view.wRegenNAgents.value)
if self.model.init_agents_static is not None: if self.model.init_agents_static is not None:
self.model.env.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in self.model.init_agents_static] self.model.env.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in
self.model.init_agents_static]
self.model.env.agents = None self.model.env.agents = None
self.model.init_agents_static = None self.model.init_agents_static = None
self.player = None self.player = None
......
...@@ -65,7 +65,7 @@ def test_predictions(): ...@@ -65,7 +65,7 @@ def test_predictions():
rail_generator=rail_from_GridTransitionMap_generator(rail), rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=1, number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv(), obs_builder_object=GlobalObsForRailEnv(),
prediction_builder_object=DummyPredictorForRailEnv(max_depth=20) prediction_builder_object=DummyPredictorForRailEnv(max_depth=10)
) )
env.reset() env.reset()
...@@ -73,6 +73,7 @@ def test_predictions(): ...@@ -73,6 +73,7 @@ def test_predictions():
# set initial position and direction for testing... # set initial position and direction for testing...
env.agents[0].position = (5, 6) env.agents[0].position = (5, 6)
env.agents[0].direction = 0 env.agents[0].direction = 0
env.agents[0].target = (3., 0.)
predictions = env.predict() predictions = env.predict()
positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0]))) positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0])))
...@@ -89,18 +90,11 @@ def test_predictions(): ...@@ -89,18 +90,11 @@ def test_predictions():
[3., 3.], [3., 3.],
[3., 2.], [3., 2.],
[3., 1.], [3., 1.],
# at target (3,0): stay in this position from here on
[3., 0.], [3., 0.],
[3., 1.], [3., 0.],
[3., 2.], [3., 0.],
[3., 3.], ])
[3., 4.],
[3., 5.],
[3., 6.],
[3., 7.],
[3., 8.],
[3., 9.],
[3., 8.],
[3., 7.]])
expected_directions = np.array([[0.], expected_directions = np.array([[0.],
[0.], [0.],
[0.], [0.],
...@@ -109,18 +103,11 @@ def test_predictions(): ...@@ -109,18 +103,11 @@ def test_predictions():
[3.], [3.],
[3.], [3.],
[3.], [3.],
# at target (3,0): stay in this position from here on
[3.], [3.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[3.], [3.],
[3.]]) [3.]
])
expected_time_offsets = np.array([[0.], expected_time_offsets = np.array([[0.],
[1.], [1.],
[2.], [2.],
...@@ -132,15 +119,7 @@ def test_predictions(): ...@@ -132,15 +119,7 @@ def test_predictions():
[8.], [8.],
[9.], [9.],
[10.], [10.],
[11.], ])
[12.],
[13.],
[14.],
[15.],
[16.],
[17.],
[18.],
[19.]])
expected_actions = np.array([[0.], expected_actions = np.array([[0.],
[2.], [2.],
[2.], [2.],
...@@ -149,18 +128,12 @@ def test_predictions(): ...@@ -149,18 +128,12 @@ def test_predictions():
[2.], [2.],
[2.], [2.],
[2.], [2.],
# reaching the target by moving straight ahead
[2.], [2.],
[2.], # at target: stopped moving
[2.], [4.],
[2.], [4.],
[2.], ])
[2.],
[2.],
[2.],
[2.],
[2.],
[2.],
[2.]])
assert np.array_equal(positions, expected_positions) assert np.array_equal(positions, expected_positions)
assert np.array_equal(directions, expected_directions) assert np.array_equal(directions, expected_directions)
assert np.array_equal(time_offsets, expected_time_offsets) assert np.array_equal(time_offsets, expected_time_offsets)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment