diff --git a/tests/simple_rail.py b/tests/simple_rail.py
index 894864acc8d7435681636f2b0c60da238265194e..fb739f29942d0552b41b602046030c76031e7539 100644
--- a/tests/simple_rail.py
+++ b/tests/simple_rail.py
@@ -1,10 +1,11 @@
 import numpy as np
+from typing import Tuple
 
 from flatland.core.grid.grid4 import Grid4Transitions
 from flatland.core.transition_map import GridTransitionMap
 
 
-def make_simple_rail():
+def make_simple_rail() -> Tuple[GridTransitionMap, np.ndarray]:
     # We instantiate a very simple rail network on a 7x10 grid:
     #        |
     #        |
diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py
index 79f4bab164312f757ad584bd3708a7af3fb7a97e..742e841c1a699f849a19c1f27f1b084b440a155a 100644
--- a/tests/test_distance_map.py
+++ b/tests/test_distance_map.py
@@ -53,4 +53,4 @@ def test_walker():
     print(obs_builder.distance_map[(0, *[0, 1], 1)])
     assert obs_builder.distance_map[(0, *[0, 1], 1)] == 3
     print(obs_builder.distance_map[(0, *[0, 2], 3)])
-    assert obs_builder.distance_map[(0, *[0, 2], 1)] == 2  # does not work yet, Erik's proposal.
+    assert obs_builder.distance_map[(0, *[0, 2], 1)] == 2
diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py
index fb45e5fa162d1323533a62349922b569a415395f..4eb8d633c5a246c5a755a922cae02484bb1ac696 100644
--- a/tests/test_flatland_envs_observations.py
+++ b/tests/test_flatland_envs_observations.py
@@ -3,9 +3,13 @@
 
 import numpy as np
 
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.generators import rail_from_GridTransitionMap_generator
-from flatland.envs.observations import GlobalObsForRailEnv
-from flatland.envs.rail_env import RailEnv
+from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.utils.rendertools import RenderTool
 from tests.simple_rail import make_simple_rail
 
 """Tests for `flatland` package."""
@@ -35,3 +39,253 @@ def test_global_obs():
     # If this assertion is wrong, it means that the observation returned
     # places the agent on an empty cell
     assert (np.sum(rail_map * global_obs[0][1][:, :, :4].sum(2)) > 0)
+
+
+def _step_along_shortest_path(env, obs_builder, rail):
+    actions = {}
+    expected_next_position = {}
+    for agent in env.agents:
+        agent: EnvAgent
+        shortest_distance = np.inf
+
+        for exit_direction in range(4):
+            neighbour = obs_builder._new_position(agent.position, exit_direction)
+
+            if neighbour[0] >= 0 and neighbour[0] < env.height and neighbour[1] >= 0 and neighbour[1] < env.width:
+                desired_movement_from_new_cell = (exit_direction + 2) % 4
+
+                # Check all possible transitions in new_cell
+                for agent_orientation in range(4):
+                    # Is a transition along `desired_movement_from_new_cell' from the neighbour possible?
+                    is_valid = obs_builder.env.rail.get_transition((neighbour[0], neighbour[1], agent_orientation),
+                                                                   desired_movement_from_new_cell)
+                    if is_valid:
+                        distance_to_target = obs_builder.distance_map[
+                            (agent.handle, *agent.position, exit_direction)]
+                        print("agent {} at {} facing {} taking {} distance {}".format(agent.handle, agent.position,
+                                                                                      agent.direction,
+                                                                                      exit_direction,
+                                                                                      distance_to_target))
+
+                        if distance_to_target < shortest_distance:
+                            shortest_distance = distance_to_target
+                            actions_to_be_taken_when_facing_north = {
+                                Grid4TransitionsEnum.NORTH: RailEnvActions.MOVE_FORWARD,
+                                Grid4TransitionsEnum.EAST: RailEnvActions.MOVE_RIGHT,
+                                Grid4TransitionsEnum.WEST: RailEnvActions.MOVE_LEFT,
+                                Grid4TransitionsEnum.SOUTH: RailEnvActions.DO_NOTHING,
+                            }
+                            print("   improved (direction) -> {}".format(exit_direction))
+
+                            actions[agent.handle] = actions_to_be_taken_when_facing_north[
+                                (exit_direction - agent.direction) % len(rail.transitions.get_direction_enum())]
+                            expected_next_position[agent.handle] = neighbour
+                            print("   improved (action) -> {}".format(actions[agent.handle]))
+    _, rewards, dones, _ = env.step(actions)
+    return rewards
+
+
+def test_reward_function_conflict(rendering=False):
+    rail, rail_map = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_GridTransitionMap_generator(rail),
+                  number_of_agents=2,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+    obs_builder: TreeObsForRailEnv = env.obs_builder
+    # initialize agents_static
+    env.reset()
+
+    # set the initial position
+    agent = env.agents_static[0]
+    agent.position = (5, 6)  # south dead-end
+    agent.direction = 0  # north
+    agent.target = (3, 9)  # east dead-end
+    agent.moving = True
+
+    agent = env.agents_static[1]
+    agent.position = (3, 8)  # east dead-end
+    agent.direction = 3  # west
+    agent.target = (6, 6)  # south dead-end
+    agent.moving = True
+
+    # reset to set agents from agents_static
+    env.reset(False, False)
+
+    if rendering:
+        renderer = RenderTool(env, gl="PILSVG")
+        renderer.renderEnv(show=True, show_observations=True)
+
+    iteration = -1
+    expected_positions = {
+        0: {
+            0: (5, 6),
+            1: (3, 8)
+        },
+        # both can move
+        1: {
+            0: (4, 6),
+            1: (3, 7)
+        },
+        # first can move, second stuck
+        2: {
+            0: (3, 6),
+            1: (3, 7)
+        },
+        # both stuck from now on
+        3: {
+            0: (3, 6),
+            1: (3, 7)
+        },
+        4: {
+            0: (3, 6),
+            1: (3, 7)
+        },
+        5: {
+            0: (3, 6),
+            1: (3, 7)
+        },
+    }
+    while not env.dones["__all__"] and iteration + 1 < 5:
+        iteration += 1
+        rewards = _step_along_shortest_path(env, obs_builder, rail)
+
+        for agent in env.agents:
+            assert rewards[agent.handle] == -1
+            expected_position = expected_positions[iteration + 1][agent.handle]
+            assert agent.position == expected_position, "[{}] agent {} at {}, expected {}".format(iteration + 1,
+                                                                                                  agent.handle,
+                                                                                                  agent.position,
+                                                                                                  expected_position)
+        if rendering:
+            renderer.renderEnv(show=True, show_observations=True)
+
+
+def test_reward_function_waiting(rendering=False):
+    rail, rail_map = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_GridTransitionMap_generator(rail),
+                  number_of_agents=2,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+    obs_builder: TreeObsForRailEnv = env.obs_builder
+    # initialize agents_static
+    env.reset()
+
+    # set the initial position
+    agent = env.agents_static[0]
+    agent.position = (3, 8)  # east dead-end
+    agent.direction = 3  # west
+    agent.target = (3, 1)  # west dead-end
+    agent.moving = True
+
+    agent = env.agents_static[1]
+    agent.position = (5, 6)  # south dead-end
+    agent.direction = 0  # north
+    agent.target = (3, 8)  # east dead-end
+    agent.moving = True
+
+    # reset to set agents from agents_static
+    env.reset(False, False)
+
+    if rendering:
+        renderer = RenderTool(env, gl="PILSVG")
+        renderer.renderEnv(show=True, show_observations=True)
+
+    iteration = -1
+    expectations = {
+        0: {
+            'positions': {
+                0: (3, 8),
+                1: (5, 6),
+            },
+            'rewards': [-1, -1],
+        },
+        1: {
+            'positions': {
+                0: (3, 7),
+                1: (4, 6),
+            },
+            'rewards': [-1, -1],
+        },
+        # second agent has to wait for first, first can continue
+        2: {
+            'positions': {
+                0: (3, 6),
+                1: (4, 6),
+            },
+            'rewards': [-1, -1],
+        },
+        # both can move again
+        3: {
+            'positions': {
+                0: (3, 5),
+                1: (3, 6),
+            },
+            'rewards': [-1, -1],
+        },
+        4: {
+            'positions': {
+                0: (3, 4),
+                1: (3, 7),
+            },
+            'rewards': [-1, -1],
+        },
+        # second reached target
+        5: {
+            'positions': {
+                0: (3, 3),
+                1: (3, 8),
+            },
+            'rewards': [-1, 0],
+        },
+        6: {
+            'positions': {
+                0: (3, 2),
+                1: (3, 8),
+            },
+            'rewards': [-1, 0],
+        },
+    # first reaches target, too
+        7: {
+            'positions': {
+                0: (3, 1),
+                1: (5, 6),
+            },
+            'rewards': [1, 1],
+        },
+        8: {
+            'positions': {
+                0: (3, 1),
+                1: (5, 6),
+            },
+            'rewards': [1, 1],
+        },
+    }
+    while not env.dones["__all__"] and iteration + 1 < 5:
+        iteration += 1
+        rewards = _step_along_shortest_path(env, obs_builder, rail)
+
+        if rendering:
+            renderer.renderEnv(show=True, show_observations=True)
+
+        print(env.dones["__all__"])
+        for agent in env.agents:
+            agent: EnvAgent
+            print("[{}] agent {} at {}, target {} ".format(iteration + 1, agent.handle, agent.position, agent.target))
+        print(np.all([np.array_equal(agent2.position, agent2.target) for agent2 in env.agents]))
+        for agent in env.agents:
+            expected_position = expectations[iteration + 1]['positions'][agent.handle]
+            assert agent.position == expected_position, \
+                "[{}] agent {} at {}, expected {}".format(iteration + 1,
+                                                          agent.handle,
+                                                          agent.position,
+                                                          expected_position)
+            expected_reward = expectations[iteration + 1]['rewards'][agent.handle]
+            actual_reward = rewards[agent.handle]
+            assert expected_reward == actual_reward, "[{}] agent {} reward {}, expected {}".format(iteration + 1,
+                                                                                                   agent.handle,
+                                                                                                   actual_reward,
+                                                                                                   expected_reward)