diff --git a/RLLib_training/train_experiment.py b/RLLib_training/train_experiment.py
index 223b9059945db484b94b0916924aa739f32270e4..48407f645a07d3947053365a86398a90aa4e41d7 100644
--- a/RLLib_training/train_experiment.py
+++ b/RLLib_training/train_experiment.py
@@ -77,8 +77,7 @@ def train(config, reporter):
         obs_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(105,))
         preprocessor = "tree_obs_prep"
 
-    elif isinstance(config["obs_builder"], GlobalObsForRailEnv) or \
-         isinstance(config["obs_builder"], GlobalObsForRailEnvDirectionDependent):
+    elif isinstance(config["obs_builder"], GlobalObsForRailEnv):
         obs_space = gym.spaces.Tuple((
             gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 16)),
             gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 8)),
@@ -88,6 +87,16 @@ def train(config, reporter):
         else:
             preprocessor = "global_obs_prep"
 
+    elif isinstance(config["obs_builder"], GlobalObsForRailEnvDirectionDependent):
+        obs_space = gym.spaces.Tuple((
+            gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 16)),
+            gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 5)),
+            gym.spaces.Box(low=0, high=1, shape=(config['map_height'], config['map_width'], 2))))
+        if config['conv_model']:
+            preprocessor = "conv_obs_prep"
+        else:
+            preprocessor = "global_obs_prep"
+
     elif isinstance(config["obs_builder"], LocalObsForRailEnv):
         view_radius = config["obs_builder"].view_radius
         obs_space = gym.spaces.Tuple((