diff --git a/flatland/contrib/wrappers/flatland_wrappers.py b/flatland/contrib/wrappers/flatland_wrappers.py
index d07fd06ae867bedce34237815f88a17280cf0651..a60b186f3781e69116b03d2c3b61d238c0fda546 100644
--- a/flatland/contrib/wrappers/flatland_wrappers.py
+++ b/flatland/contrib/wrappers/flatland_wrappers.py
@@ -181,110 +181,89 @@ def find_all_cells_where_agent_can_choose(env: RailEnv):
     decision_cells = switches + switches_neighbors
     return tuple(map(set, (switches, switches_neighbors, decision_cells)))
 
+
 
-class NoChoiceCellsSkipper:
+
+class SkipNoChoiceCellsWrapper(RailEnvWrapper):
+
+    # env can be a real RailEnv, or anything that shares the same interface
+    # e.g. obs, rewards, dones, info = env.step(action_dict) and obs, info = env.reset(), and so on.
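+    #
+    # A minimal usage sketch (illustrative values; assumes the standard Flatland RailEnv interface).
+    # The env should already be initialized, i.e. reset() at least once, before it is wrapped,
+    # because the constructor inspects the current rail grid:
+    #
+    #   env = RailEnv(...)
+    #   env.reset()
+    #   env = SkipNoChoiceCellsWrapper(env, accumulate_skipped_rewards=True, discounting=0.99)
+    #   obs, info = env.reset()
+    #   obs, rewards, dones, info = env.step(action_dict)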
     def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
-      self.env = env
+      super().__init__(env)
+      # Save these so they can be inspected more easily.
+      self.accumulate_skipped_rewards = accumulate_skipped_rewards
+      self.discounting = discounting
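+      # Cached switch / decision cell sets; (re)computed from the current rail grid in reset_cells().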
       self.switches = None
       self.switches_neighbors = None
       self.decision_cells = None
-      self.accumulate_skipped_rewards = accumulate_skipped_rewards
-      self.discounting = discounting
       self.skipped_rewards = defaultdict(list)
 
-      # env.reset() can change the rail grid layout, so the switches, etc. will change! --> need to do this in reset() as well.
-      #self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
-
-      # compute and initialize value for switches, switches_neighbors, and decision_cells.
+      # Compute the initial switches, switches_neighbors and decision_cells.
       self.reset_cells()
 
+
     def on_decision_cell(self, agent: EnvAgent) -> bool:
-        return agent.position is None or agent.position == agent.initial_position or agent.position in self.decision_cells
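+      # True if the agent is not yet placed on the grid, is still on its initial cell,
+      # or is on a decision cell (a switch or a cell adjacent to a switch).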
+      return agent.position is None or agent.position == agent.initial_position or agent.position in self.decision_cells
 
     def on_switch(self, agent: EnvAgent) -> bool:
-        return agent.position in self.switches
+      return agent.position in self.switches
 
     def next_to_switch(self, agent: EnvAgent) -> bool:
-        return agent.position in self.switches_neighbors
-
-    def no_choice_skip_step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
-        o, r, d, i = {}, {}, {}, {}
-      
-        # NEED TO INITIALIZE i["..."]
-        # as we will access i["..."][agent_id]
-        i["action_required"] = dict()
-        i["malfunction"] = dict()
-        i["speed"] = dict()
-        i["state"] = dict()
-
-        while len(o) == 0:
-            obs, reward, done, info = self.env.step(action_dict)
-
-            for agent_id, agent_obs in obs.items():
-                if done[agent_id] or self.on_decision_cell(self.env.agents[agent_id]):
-                    
-                    o[agent_id] = agent_obs
-                    r[agent_id] = reward[agent_id]
-                    d[agent_id] = done[agent_id]
-
-            
-                    i["action_required"][agent_id] = info["action_required"][agent_id] 
-                    i["malfunction"][agent_id] = info["malfunction"][agent_id]
-                    i["speed"][agent_id] = info["speed"][agent_id]
-                    i["state"][agent_id] = info["state"][agent_id]
-                                                                  
-                    if self.accumulate_skipped_rewards:
-                        discounted_skipped_reward = r[agent_id]
-                        for skipped_reward in reversed(self.skipped_rewards[agent_id]):
-                            discounted_skipped_reward = self.discounting * discounted_skipped_reward + skipped_reward
-                        r[agent_id] = discounted_skipped_reward
-                        self.skipped_rewards[agent_id] = []
-
-                elif self.accumulate_skipped_rewards:
-                    self.skipped_rewards[agent_id].append(reward[agent_id])
-                # end of for-loop
-
-            d['__all__'] = done['__all__']
-            action_dict = {}
-            # end of while-loop
-
-        return o, r, d, i
-
-    
-    def reset_cells(self) -> None:
-        self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
-
+      return agent.position in self.switches_neighbors
 
-# IMPORTANT: rail env should be reset() / initialized before put into this one!
-class SkipNoChoiceCellsWrapper(RailEnvWrapper):
-  
-    # env can be a real RailEnv, or anything that shares the same interface
-    # e.g. obs, rewards, dones, info = env.step(action_dict) and obs, info = env.reset(), and so on.
-    def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
-        super().__init__(env)
-        # save these so they can be inspected easier.
-        self.accumulate_skipped_rewards = accumulate_skipped_rewards
-        self.discounting = discounting
-        self.skipper = NoChoiceCellsSkipper(env=self.env, accumulate_skipped_rewards=self.accumulate_skipped_rewards, discounting=self.discounting)
 
-        self.skipper.reset_cells()
+    def reset_cells(self) -> None:
+      self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
 
-        self.switches = self.skipper.switches
-        self.switches_neighbors = self.skipper.switches_neighbors
-        self.decision_cells = self.skipper.decision_cells
-        self.skipped_rewards = self.skipper.skipped_rewards
 
-  
     def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
-        obs, rewards, dones, info = self.skipper.no_choice_skip_step(action_dict=action_dict)
-        return obs, rewards, dones, info
+      o, r, d, i = {}, {}, {}, {}
+    
+      # The per-key info dicts must be initialized up front,
+      # because we index i[key][agent_id] below.
+      i["action_required"] = dict()
+      i["malfunction"] = dict()
+      i["speed"] = dict()
+      i["state"] = dict()
+
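+      # Keep stepping the wrapped env until at least one agent needs a decision (or is done).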
+      while len(o) == 0:
+        obs, reward, done, info = self.env.step(action_dict)
+
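+        # Surface results only for agents that are done or standing on a decision cell.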
+        for agent_id, agent_obs in obs.items():
+          if done[agent_id] or self.on_decision_cell(self.env.agents[agent_id]):
+            o[agent_id] = agent_obs
+            r[agent_id] = reward[agent_id]
+            d[agent_id] = done[agent_id]
+
+            i["action_required"][agent_id] = info["action_required"][agent_id] 
+            i["malfunction"][agent_id] = info["malfunction"][agent_id]
+            i["speed"][agent_id] = info["speed"][agent_id]
+            i["state"][agent_id] = info["state"][agent_id]
+
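+            # Fold the rewards collected while this agent was skipped into the current step's reward:
+            # the result is the discounted return from the first skipped step up to this decision step.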
+            if self.accumulate_skipped_rewards:
+              discounted_skipped_reward = r[agent_id]
+
+              for skipped_reward in reversed(self.skipped_rewards[agent_id]):
+                discounted_skipped_reward = self.discounting * discounted_skipped_reward + skipped_reward
+
+              r[agent_id] = discounted_skipped_reward
+              self.skipped_rewards[agent_id] = []
+
+          elif self.accumulate_skipped_rewards:
+            self.skipped_rewards[agent_id].append(reward[agent_id])
+          # end of for-loop
+
+        d['__all__'] = done['__all__']
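+        # Subsequent inner steps use an empty action dict, so no new actions are issued while skipping.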
+        action_dict = {}
+        # end of while-loop
+
+      return o, r, d, i
         
 
-    
-    # arguments from RailEnv.reset() are: self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False, random_seed: bool = None
+
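+    # **kwargs are forwarded to the wrapped env's reset(); for a RailEnv these include
+    # regenerate_rail, regenerate_schedule, activate_agents and random_seed.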
     def reset(self, **kwargs) -> Tuple[Dict, Dict]:
-        obs, info = self.env.reset(**kwargs)
-        # resets decision cells, switches, etc. These can change with an env.reset(...)!
-        # needs to be done after env.reset().
-        self.skipper.reset_cells()
-        return obs, info
+      obs, info = self.env.reset(**kwargs)
+      # Recompute decision cells, switches, etc.; these can change with env.reset(...),
+      # so this must happen after the env has been reset.
+      self.reset_cells()
+      return obs, info
\ No newline at end of file