diff --git a/examples/demo.py b/examples/demo.py
index 2afc7828d43805d42ffe5333ecaf8f47bf271a49..6066219b6e96744344c4c91197552924f7aab907 100644
--- a/examples/demo.py
+++ b/examples/demo.py
@@ -1,8 +1,8 @@
 import os
 import time
 import random
-
 import numpy as np
+from datetime import datetime
 
 from flatland.envs.generators import complex_rail_generator
 # from flatland.envs.generators import rail_from_list_of_saved_GridTransitionMap_generator
@@ -125,12 +125,16 @@ class Demo:
         self.env = env
         self.create_renderer()
         self.action_size = 4
+        self.max_frame_rate = 60
 
     def create_renderer(self):
         self.renderer = RenderTool(self.env, gl="PILSVG")
         handle = self.env.get_agent_handles()
         return handle
 
+    def set_max_framerate(self,max_frame_rate):
+        self.max_frame_rate = max_frame_rate
+
     def run_demo(self, max_nbr_of_steps=30):
         action_dict = dict()
 
@@ -141,7 +145,7 @@ class Demo:
 
         for step in range(max_nbr_of_steps):
 
-            # time.sleep(.1)
+            begin_frame_time_stamp = datetime.now()
 
             # Action
             for iAgent in range(self.env.get_num_agents()):
@@ -173,6 +177,17 @@ class Demo:
 
             if done['__all__']:
                 break
+
+
+            # ensure that the rendering is not faster then the maximal allowed frame rate
+            end_frame_time_stamp = datetime.now()
+            frame_exe_time = end_frame_time_stamp - begin_frame_time_stamp
+            max_time = 1/self.max_frame_rate
+            delta = (max_time - frame_exe_time.total_seconds())
+            if delta > 0.0:
+                time.sleep(delta)
+
+
         self.renderer.close_window()