diff --git a/flatland/core/env.py b/flatland/core/env.py
index 3c25beea236945b1728959e02ea07f6c0ba7a6ac..32691f507f4cb5586f10b5645cc22ece718edc21 100644
--- a/flatland/core/env.py
+++ b/flatland/core/env.py
@@ -84,6 +84,21 @@ class Environment:
         """
         raise NotImplementedError()
 
+    def predict(self):
+        """
+        Predictions step.
+
+        Returns predictions for the agents.
+        The returns are dicts mapping from agent_id strings to values.
+
+        Returns
+        -------
+        predictions : dict
+            New predictions for each ready agent.
+
+        """
+        raise NotImplementedError()
+
     def render(self):
         """
         Perform rendering of the environment.
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index 3cef545c1658e6bfe2a292ee26c3e665ce6a5abc..f85afee4b625e59374c6cce266bf55b21e7fdb84 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -19,7 +19,6 @@ class ObservationBuilder:
 
     def __init__(self):
         self.observation_space = ()
-        pass
 
     def _set_env(self, env):
         self.env = env
diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f5e4dc5033ba3a789313b47d94b033251cb8276
--- /dev/null
+++ b/flatland/core/env_prediction_builder.py
@@ -0,0 +1,44 @@
+"""
+PredictionBuilder objects are objects that can be passed to environments designed for customizability.
+The PredictionBuilder-derived custom classes implement 2 functions, reset() and get([handle]).
+If predictions are not required in every step or not for all agents, then
+
++ Reset() is called after each environment reset, to allow for pre-computing relevant data.
+
++ Get() is called whenever an step has to be computed, potentially for each agent independently in
+case of multi-agent environments.
+"""
+
+
+class PredictionBuilder:
+    """
+    PredictionBuilder base class.
+    """
+
+    def __init__(self, max_depth: int = 20):
+        self.max_depth = max_depth
+
+    def _set_env(self, env):
+        self.env = env
+
+    def reset(self):
+        """
+        Called after each environment reset.
+        """
+        pass
+
+    def get(self, handle=0):
+        """
+        Called whenever step_prediction is called on the environment.
+
+        Parameters
+        -------
+        handle : int (optional)
+            Handle of the agent for which to compute the observation vector.
+
+        Returns
+        -------
+        function
+            A prediction structure, specific to the corresponding environment.
+        """
+        raise NotImplementedError()
diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
new file mode 100644
index 0000000000000000000000000000000000000000..95c1a984c4151a9a873deeeb29438b290bb4f77e
--- /dev/null
+++ b/flatland/envs/predictions.py
@@ -0,0 +1,72 @@
+"""
+Collection of environment-specific PredictionBuilder.
+"""
+
+import numpy as np
+
+from flatland.core.env_prediction_builder import PredictionBuilder
+
+
+class DummyPredictorForRailEnv(PredictionBuilder):
+    """
+    DummyPredictorForRailEnv object.
+
+    This object returns predictions for agents in the RailEnv environment.
+    The prediction acts as if no other agent is in the environment and always takes the forward action.
+    """
+
+    def get(self, handle=None):
+        """
+        Called whenever step_prediction is called on the environment.
+
+        Parameters
+        -------
+        handle : int (optional)
+            Handle of the agent for which to compute the observation vector.
+
+        Returns
+        -------
+        function
+            Returns a dictionary index by the agent handle and for each agent a vector of 5 elements:
+            - time_offset
+            - position axis 0
+            - position axis 1
+            - direction
+            - action taken to come here
+        """
+        agents = self.env.agents
+        if handle:
+            agents = [self.env.agents[handle]]
+
+        prediction_dict = {}
+
+        for agent in agents:
+
+            # 0: do nothing
+            # 1: turn left and move to the next cell
+            # 2: move to the next cell in front of the agent
+            # 3: turn right and move to the next cell
+            action_priorities = [2, 1, 3]
+            _agent_initial_position = agent.position
+            _agent_initial_direction = agent.direction
+            prediction = np.zeros(shape=(self.max_depth, 5))
+            prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0]
+            for index in range(1, self.max_depth):
+                action_done = False
+                for action in action_priorities:
+                    cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = self.env._check_action_on_agent(action,
+                                                                                                                                     agent)
+                    if all([new_cell_isValid, transition_isValid]):
+                        # move and change direction to face the new_direction that was
+                        # performed
+                        agent.position = new_position
+                        agent.direction = new_direction
+                        prediction[index] = [index, new_position[0], new_position[1], new_direction, action]
+                        action_done = True
+                        break
+                if not action_done:
+                    print("Cannot move further.")
+            prediction_dict[agent.handle] = prediction
+            agent.position = _agent_initial_position
+            agent.direction = _agent_initial_direction
+        return prediction_dict
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index a23b6ac71651e8e424bb90a23cbcfd472ce89a12..190a14eae6ae2c529915c29fbc22d4d12ab4c955 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -51,7 +51,9 @@ class RailEnv(Environment):
                  height,
                  rail_generator=random_rail_generator(),
                  number_of_agents=1,
-                 obs_builder_object=TreeObsForRailEnv(max_depth=2)):
+                 obs_builder_object=TreeObsForRailEnv(max_depth=2),
+                 prediction_builder_object=None
+                 ):
         """
         Environment init.
 
@@ -94,6 +96,11 @@ class RailEnv(Environment):
         self.obs_builder = obs_builder_object
         self.obs_builder._set_env(self)
 
+        self.prediction_builder = prediction_builder_object
+        if self.prediction_builder:
+            self.prediction_builder._set_env(self)
+
+
         self.action_space = [1]
         self.observation_space = self.obs_builder.observation_space  # updated on resets?
 
@@ -212,51 +219,9 @@ class RailEnv(Environment):
                 return
 
             if action > 0:
-                # pos = agent.position #  self.agents_position[i]
-                # direction = agent.direction # self.agents_direction[i]
-
-                # compute number of possible transitions in the current
-                # cell used to check for invalid actions
-
-                new_direction, transition_isValid = self.check_action(agent, action)
-
-                new_position = get_new_position(agent.position, new_direction)
-                # Is it a legal move?
-                # 1) transition allows the new_direction in the cell,
-                # 2) the new cell is not empty (case 0),
-                # 3) the cell is free, i.e., no agent is currently in that cell
-
-                # if (
-                #        new_position[1] >= self.width or
-                #        new_position[0] >= self.height or
-                #        new_position[0] < 0 or new_position[1] < 0):
-                #    new_cell_isValid = False
-
-                # if self.rail.get_transitions(new_position) == 0:
-                #     new_cell_isValid = False
-
-                new_cell_isValid = (
-                    np.array_equal(  # Check the new position is still in the grid
-                        new_position,
-                        np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
-                    and  # check the new position has some transitions (ie is not an empty cell)
-                    self.rail.get_transitions(new_position) > 0)
-
-                # If transition validity hasn't been checked yet.
-                if transition_isValid is None:
-                    transition_isValid = self.rail.get_transition(
-                        (*agent.position, agent.direction),
-                        new_direction)
-
-                # cell_isFree = True
-                # for j in range(self.number_of_agents):
-                #    if self.agents_position[j] == new_position:
-                #        cell_isFree = False
-                #        break
-                # Check the new position is not the same as any of the existing agent positions
-                # (including itself, for simplicity, since it is moving)
-                cell_isFree = not np.any(
-                    np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
+                cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = self._check_action_on_agent(action,
+                                                                                                                             agent,
+                                                                                                                             transition_isValid)
 
                 if all([new_cell_isValid, transition_isValid, cell_isFree]):
                     # move and change direction to face the new_direction that was
@@ -296,6 +261,52 @@ class RailEnv(Environment):
         self.actions = [0] * self.get_num_agents()
         return self._get_observations(), self.rewards_dict, self.dones, {}
 
+    def _check_action_on_agent(self, action, agent):
+        # pos = agent.position #  self.agents_position[i]
+        # direction = agent.direction # self.agents_direction[i]
+        # compute number of possible transitions in the current
+        # cell used to check for invalid actions
+        new_direction, transition_isValid = self.check_action(agent, action)
+        new_position = get_new_position(agent.position, new_direction)
+        # Is it a legal move?
+        # 1) transition allows the new_direction in the cell,
+        # 2) the new cell is not empty (case 0),
+        # 3) the cell is free, i.e., no agent is currently in that cell
+        # if (
+        #        new_position[1] >= self.width or
+        #        new_position[0] >= self.height or
+        #        new_position[0] < 0 or new_position[1] < 0):
+        #    new_cell_isValid = False
+        # if self.rail.get_transitions(new_position) == 0:
+        #     new_cell_isValid = False
+        new_cell_isValid = (
+            np.array_equal(  # Check the new position is still in the grid
+                new_position,
+                np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
+            and  # check the new position has some transitions (ie is not an empty cell)
+            self.rail.get_transitions(new_position) > 0)
+        # If transition validity hasn't been checked yet.
+        if transition_isValid is None:
+            transition_isValid = self.rail.get_transition(
+                (*agent.position, agent.direction),
+                new_direction)
+        # cell_isFree = True
+        # for j in range(self.number_of_agents):
+        #    if self.agents_position[j] == new_position:
+        #        cell_isFree = False
+        #        break
+        # Check the new position is not the same as any of the existing agent positions
+        # (including itself, for simplicity, since it is moving)
+        cell_isFree = not np.any(
+            np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
+        return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid
+
+    def predict(self):
+        if not self.prediction_builder:
+            return {}
+        return  self.prediction_builder.get()
+
+
     def check_action(self, agent, action):
         transition_isValid = None
         possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
@@ -332,6 +343,11 @@ class RailEnv(Environment):
             self.obs_dict[iAgent] = self.obs_builder.get(iAgent)
         return self.obs_dict
 
+    def _get_predictions(self):
+        if not self.prediction_builder:
+            return {}
+        return {}
+
     def render(self):
         # TODO:
         pass
diff --git a/tests/test_env_prediction_builder.py b/tests/test_env_prediction_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..35a6a27b970ce54e1cabd3cf8c80d30a34800a25
--- /dev/null
+++ b/tests/test_env_prediction_builder.py
@@ -0,0 +1,175 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import numpy as np
+
+from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
+from flatland.envs.generators import rail_from_GridTransitionMap_generator
+from flatland.envs.observations import GlobalObsForRailEnv
+from flatland.envs.predictions import DummyPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv
+
+"""Test predictions for `flatland` package."""
+
+
+def test_predictions():
+    # We instantiate a very simple rail network on a 7x10 grid:
+    #        |
+    #        |
+    #        |
+    # _ _ _ /_\ _ _  _  _ _ _
+    #               \ /
+    #                |
+    #                |
+    #                |
+
+    cells = [int('0000000000000000', 2),  # empty cell - Case 0
+             int('1000000000100000', 2),  # Case 1 - straight
+             int('1001001000100000', 2),  # Case 2 - simple switch
+             int('1000010000100001', 2),  # Case 3 - diamond drossing
+             int('1001011000100001', 2),  # Case 4 - single slip switch
+             int('1100110000110011', 2),  # Case 5 - double slip switch
+             int('0101001000000010', 2),  # Case 6 - symmetrical switch
+             int('0010000000000000', 2)]  # Case 7 - dead end
+
+    transitions = Grid4Transitions([])
+    empty = cells[0]
+
+    dead_end_from_south = cells[7]
+    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
+    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
+    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
+
+    vertical_straight = cells[1]
+    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
+
+    double_switch_south_horizontal_straight = horizontal_straight + cells[6]
+    double_switch_north_horizontal_straight = transitions.rotate_transition(
+        double_switch_south_horizontal_straight, 180)
+
+    rail_map = np.array(
+        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
+        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
+        [[dead_end_from_east] + [horizontal_straight] * 2 +
+         [double_switch_north_horizontal_straight] +
+         [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
+         [horizontal_straight] * 2 + [dead_end_from_west]] +
+        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
+        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
+
+    rail = GridTransitionMap(width=rail_map.shape[1],
+                             height=rail_map.shape[0], transitions=transitions)
+    rail.grid = rail_map
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_GridTransitionMap_generator(rail),
+                  number_of_agents=1,
+                  obs_builder_object=GlobalObsForRailEnv(),
+                  prediction_builder_object=DummyPredictorForRailEnv(max_depth=20)
+                  )
+
+    env.reset()
+
+    # set initial position and direction for testing...
+    env.agents[0].position = (5, 6)
+    env.agents[0].direction = 0
+
+    predictions = env.predict()
+    positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0])))
+    directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
+    time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
+    actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0])))
+
+    # compare against expected values
+    expected_positions = np.array([[5., 6.],
+                                   [4., 6.],
+                                   [3., 6.],
+                                   [3., 5.],
+                                   [3., 4.],
+                                   [3., 3.],
+                                   [3., 2.],
+                                   [3., 1.],
+                                   [3., 0.],
+                                   [3., 1.],
+                                   [3., 2.],
+                                   [3., 3.],
+                                   [3., 4.],
+                                   [3., 5.],
+                                   [3., 6.],
+                                   [3., 7.],
+                                   [3., 8.],
+                                   [3., 9.],
+                                   [3., 8.],
+                                   [3., 7.]])
+    expected_directions = np.array([[0.],
+                                    [0.],
+                                    [0.],
+                                    [3.],
+                                    [3.],
+                                    [3.],
+                                    [3.],
+                                    [3.],
+                                    [3.],
+                                    [1.],
+                                    [1.],
+                                    [1.],
+                                    [1.],
+                                    [1.],
+                                    [1.],
+                                    [1.],
+                                    [1.],
+                                    [1.],
+                                    [3.],
+                                    [3.]])
+    expected_time_offsets = np.array([[0.],
+                                      [1.],
+                                      [2.],
+                                      [3.],
+                                      [4.],
+                                      [5.],
+                                      [6.],
+                                      [7.],
+                                      [8.],
+                                      [9.],
+                                      [10.],
+                                      [11.],
+                                      [12.],
+                                      [13.],
+                                      [14.],
+                                      [15.],
+                                      [16.],
+                                      [17.],
+                                      [18.],
+                                      [19.]])
+    expected_actions = np.array([[0.],
+                                 [2.],
+                                 [2.],
+                                 [1.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.],
+                                 [2.]])
+    assert np.array_equal(positions, expected_positions)
+    assert np.array_equal(directions, expected_directions)
+    assert np.array_equal(time_offsets, expected_time_offsets)
+    assert np.array_equal(actions, expected_actions)
+
+
+def main():
+    test_predictions()
+
+
+if __name__ == "__main__":
+    main()