diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index 6c0b92a7c6d45187dae5158fb0a81c9fab2d7280..cb09a628136c5907c93024970757df7e5980842f 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -336,8 +336,4 @@ class GridTransitionMap(TransitionMap):
 
         return True
 
-# TODO: GIACOMO: is it better to provide those methods with lists of cell_ids
-# (most general implementation) or to make Grid-class specific methods for
-# slicing over the 3 dimensions?  I'd say both perhaps.
-
-# TODO: override __getitem__ and __setitem__ (cell contents, not transitions?)
+# TODO: improvement override __getitem__ and __setitem__ (cell contents, not transitions?)
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index a10c58e6f5eac8ebd04990505894a917c2212b3f..1f02d518a9714de91d8910b8cb1408f25eb3fe88 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -23,6 +23,7 @@ class TreeObsForRailEnv(ObservationBuilder):
     observation_dim = 9
 
     def __init__(self, max_depth, predictor=None):
+        super().__init__()
         self.max_depth = max_depth
 
         # Compute the size of the returned observation vector
@@ -41,15 +42,14 @@ class TreeObsForRailEnv(ObservationBuilder):
 
     def reset(self):
         agents = self.env.agents
-        nAgents = len(agents)
+        nb_agents = len(agents)
 
         compute_distance_map = True
-        if self.agents_previous_reset is not None:
-            if nAgents == len(self.agents_previous_reset):
-                compute_distance_map = False
-                for i in range(nAgents):
-                    if agents[i].target != self.agents_previous_reset[i].target:
-                        compute_distance_map = True
+        if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset):
+            compute_distance_map = False
+            for i in range(nb_agents):
+                if agents[i].target != self.agents_previous_reset[i].target:
+                    compute_distance_map = True
         self.agents_previous_reset = agents
 
         if compute_distance_map:
@@ -57,12 +57,12 @@ class TreeObsForRailEnv(ObservationBuilder):
 
     def _compute_distance_map(self):
         agents = self.env.agents
-        nAgents = len(agents)
-        self.distance_map = np.inf * np.ones(shape=(nAgents,  # self.env.number_of_agents,
+        nb_agents = len(agents)
+        self.distance_map = np.inf * np.ones(shape=(nb_agents,
                                                     self.env.height,
                                                     self.env.width,
                                                     4))
-        self.max_dist = np.zeros(nAgents)
+        self.max_dist = np.zeros(nb_agents)
         self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
         # Update local lookup table for all agents' target locations
         self.location_has_target = {tuple(agent.target): 1 for agent in agents}
@@ -83,10 +83,8 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         # BFS from target `position' to all the reachable nodes in the grid
         # Stop the search if the target position is re-visited, in any direction
-        visited = set([(position[0], position[1], 0),
-                       (position[0], position[1], 1),
-                       (position[0], position[1], 2),
-                       (position[0], position[1], 3)])
+        visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2),
+                   (position[0], position[1], 3)}
 
         max_distance = 0
 
@@ -133,10 +131,10 @@ class TreeObsForRailEnv(ObservationBuilder):
                 # Check all possible transitions in new_cell
                 for agent_orientation in range(4):
                     # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
-                    isValid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
-                                                           desired_movement_from_new_cell)
+                    is_valid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
+                                                            desired_movement_from_new_cell)
 
-                    if isValid:
+                    if is_valid:
                         """
                         # TODO: check that it works with deadends! -- still bugged!
                         movement = desired_movement_from_new_cell
@@ -163,12 +161,14 @@ class TreeObsForRailEnv(ObservationBuilder):
         elif movement == Grid4TransitionsEnum.WEST:
             return (position[0], position[1] - 1)
 
-    def get_many(self, handles=[]):
+    def get_many(self, handles=None):
         """
         Called whenever an observation has to be computed for the `env' environment, for each agent with handle
         in the `handles' list.
         """
 
+        if handles is None:
+            handles = []
         if self.predictor:
             self.predicted_pos = {}
             self.predicted_dir = {}
@@ -259,7 +259,6 @@ class TreeObsForRailEnv(ObservationBuilder):
         # Root node - current position
         observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0]
 
-        root_observation = observation[:]
         visited = set()
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
@@ -273,7 +272,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             if possible_transitions[branch_direction]:
                 new_cell = self._new_position(agent.position, branch_direction)
                 branch_observation, branch_visited = \
-                    self._explore_branch(handle, new_cell, branch_direction, root_observation, 1, 1)
+                    self._explore_branch(handle, new_cell, branch_direction, 1, 1)
                 observation = observation + branch_observation
                 visited = visited.union(branch_visited)
             else:
@@ -291,7 +290,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             pow4 *= 4
         return num_observations * self.observation_dim
 
-    def _explore_branch(self, handle, position, direction, root_observation, tot_dist, depth):
+    def _explore_branch(self, handle, position, direction, tot_dist, depth):
         """
         Utility function to compute tree-based observations.
         We walk along the branch and collect the information documented in the get() function.
@@ -305,10 +304,10 @@ class TreeObsForRailEnv(ObservationBuilder):
         # until no transitions are possible along the current direction (i.e., dead-ends)
         # We treat dead-ends as nodes, instead of going back, to avoid loops
         exploring = True
-        last_isSwitch = False
-        last_isDeadEnd = False
-        last_isTerminal = False  # wrong cell OR cycle;  either way, we don't want the agent to land here
-        last_isTarget = False
+        last_is_switch = False
+        last_is_dead_end = False
+        last_is_terminal = False  # wrong cell OR cycle;  either way, we don't want the agent to land here
+        last_is_target = False
 
         visited = set()
         agent = self.env.agents[handle]
@@ -369,21 +368,19 @@ class TreeObsForRailEnv(ObservationBuilder):
                 if tot_dist < other_target_encountered:
                     other_target_encountered = tot_dist
 
-            if position == agent.target:
-                if tot_dist < own_target_encountered:
-                    own_target_encountered = tot_dist
+            if position == agent.target and tot_dist < own_target_encountered:
+                own_target_encountered = tot_dist
 
             # #############################
             # #############################
-
             if (position[0], position[1], direction) in visited:
-                last_isTerminal = True
+                last_is_terminal = True
                 break
             visited.add((position[0], position[1], direction))
 
             # If the target node is encountered, pick that as node. Also, no further branching is possible.
             if np.array_equal(position, self.env.agents[handle].target):
-                last_isTarget = True
+                last_is_target = True
                 break
 
             cell_transitions = self.env.rail.get_transitions((*position, direction))
@@ -403,9 +400,9 @@ class TreeObsForRailEnv(ObservationBuilder):
                     tmp = tmp >> 1
                 if nbits == 1:
                     # Dead-end!
-                    last_isDeadEnd = True
+                    last_is_dead_end = True
 
-                if not last_isDeadEnd:
+                if not last_is_dead_end:
                     # Keep walking through the tree along `direction'
                     exploring = True
                     # convert one-hot encoding to 0,1,2,3
@@ -415,14 +412,14 @@ class TreeObsForRailEnv(ObservationBuilder):
                     tot_dist += 1
             elif num_transitions > 0:
                 # Switch detected
-                last_isSwitch = True
+                last_is_switch = True
                 break
 
             elif num_transitions == 0:
                 # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
                 print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
                       position[1], direction)
-                last_isTerminal = True
+                last_is_terminal = True
                 break
 
         # `position' is either a terminal node or a switch
@@ -433,7 +430,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         # #############################
         # Modify here to append new / different features for each visited cell!
 
-        if last_isTarget:
+        if last_is_target:
             observation = [own_target_encountered,
                            other_target_encountered,
                            other_agent_encountered,
@@ -445,7 +442,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                            other_agent_opposite_direction
                            ]
 
-        elif last_isTerminal:
+        elif last_is_terminal:
             observation = [own_target_encountered,
                            other_target_encountered,
                            other_agent_encountered,
@@ -469,32 +466,30 @@ class TreeObsForRailEnv(ObservationBuilder):
                            ]
         # #############################
         # #############################
-
-        new_root_observation = observation[:]
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
         # Get the possible transitions
         possible_transitions = self.env.rail.get_transitions((*position, direction))
         for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
-            if last_isDeadEnd and self.env.rail.get_transition((*position, direction),
-                                                               (branch_direction + 2) % 4):
+            if last_is_dead_end and self.env.rail.get_transition((*position, direction),
+                                                                 (branch_direction + 2) % 4):
                 # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                 # it back
                 new_cell = self._new_position(position, (branch_direction + 2) % 4)
                 branch_observation, branch_visited = self._explore_branch(handle,
                                                                           new_cell,
                                                                           (branch_direction + 2) % 4,
-                                                                          new_root_observation, tot_dist + 1,
+                                                                          tot_dist + 1,
                                                                           depth + 1)
                 observation = observation + branch_observation
                 if len(branch_visited) != 0:
                     visited = visited.union(branch_visited)
-            elif last_isSwitch and possible_transitions[branch_direction]:
+            elif last_is_switch and possible_transitions[branch_direction]:
                 new_cell = self._new_position(position, branch_direction)
                 branch_observation, branch_visited = self._explore_branch(handle,
                                                                           new_cell,
                                                                           branch_direction,
-                                                                          new_root_observation, tot_dist + 1,
+                                                                          tot_dist + 1,
                                                                           depth + 1)
                 observation = observation + branch_observation
                 if len(branch_visited) != 0:
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 7952f29bf935414388c72d534de1c0719781464f..f5f46408875e6235493682014d8bc4313ad5ea34 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -165,8 +165,8 @@ class RailEnv(Environment):
 
         self.restart_agents()
 
-        for i_agemt in range(self.get_num_agents()):
-            agent = self.agents[i_agemt]
+        for i_agent in range(self.get_num_agents()):
+            agent = self.agents[i_agent]
             agent.speed_data['position_fraction'] = 0.0
 
         self.num_resets += 1
@@ -195,31 +195,31 @@ class RailEnv(Environment):
 
         # Reset the step rewards
         self.rewards_dict = dict()
-        for i_agemt in range(self.get_num_agents()):
-            self.rewards_dict[i_agemt] = 0
+        for i_agent in range(self.get_num_agents()):
+            self.rewards_dict[i_agent] = 0
 
         if self.dones["__all__"]:
             self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()}
             return self._get_observations(), self.rewards_dict, self.dones, {}
 
         # for i in range(len(self.agents_handles)):
-        for i_agemt in range(self.get_num_agents()):
-            agent = self.agents[i_agemt]
+        for i_agent in range(self.get_num_agents()):
+            agent = self.agents[i_agent]
             agent.old_direction = agent.direction
             agent.old_position = agent.position
-            if self.dones[i_agemt]:  # this agent has already completed...
+            if self.dones[i_agent]:  # this agent has already completed...
                 continue
 
-            if i_agemt not in action_dict:  # no action has been supplied for this agent
-                action_dict[i_agemt] = RailEnvActions.DO_NOTHING
+            if i_agent not in action_dict:  # no action has been supplied for this agent
+                action_dict[i_agent] = RailEnvActions.DO_NOTHING
 
-            if action_dict[i_agemt] < 0 or action_dict[i_agemt] > len(RailEnvActions):
-                print('ERROR: illegal action=', action_dict[i_agemt],
-                      'for agent with index=', i_agemt,
+            if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions):
+                print('ERROR: illegal action=', action_dict[i_agent],
+                      'for agent with index=', i_agent,
                       '"DO NOTHING" will be executed instead')
-                action_dict[i_agemt] = RailEnvActions.DO_NOTHING
+                action_dict[i_agent] = RailEnvActions.DO_NOTHING
 
-            action = action_dict[i_agemt]
+            action = action_dict[i_agent]
 
             if action == RailEnvActions.DO_NOTHING and agent.moving:
                 # Keep moving
@@ -228,12 +228,12 @@ class RailEnv(Environment):
             if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.:
                 # Only allow halting an agent on entering new cells.
                 agent.moving = False
-                self.rewards_dict[i_agemt] += stop_penalty
+                self.rewards_dict[i_agent] += stop_penalty
 
             if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
                 # Allow agent to start with any forward or direction action
                 agent.moving = True
-                self.rewards_dict[i_agemt] += start_penalty
+                self.rewards_dict[i_agent] += start_penalty
 
             # Now perform a movement.
             # If the agent is in an initial position within a new cell (agent.speed_data['position_fraction']<eps)
@@ -269,18 +269,18 @@ class RailEnv(Environment):
 
                             else:
                                 # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
-                                self.rewards_dict[i_agemt] += invalid_action_penalty
-                                self.rewards_dict[i_agemt] += step_penalty * agent.speed_data['speed']
+                                self.rewards_dict[i_agent] += invalid_action_penalty
+                                self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
                                 agent.moving = False
-                                self.rewards_dict[i_agemt] += stop_penalty
+                                self.rewards_dict[i_agent] += stop_penalty
 
                                 continue
                         else:
                             # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
-                            self.rewards_dict[i_agemt] += invalid_action_penalty
-                            self.rewards_dict[i_agemt] += step_penalty * agent.speed_data['speed']
+                            self.rewards_dict[i_agent] += invalid_action_penalty
+                            self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
                             agent.moving = False
-                            self.rewards_dict[i_agemt] += stop_penalty
+                            self.rewards_dict[i_agent] += stop_penalty
 
                             continue
 
@@ -302,9 +302,9 @@ class RailEnv(Environment):
                     agent.speed_data['position_fraction'] = 0.0
 
             if np.equal(agent.position, agent.target).all():
-                self.dones[i_agemt] = True
+                self.dones[i_agent] = True
             else:
-                self.rewards_dict[i_agemt] += step_penalty * agent.speed_data['speed']
+                self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
 
         # Check for end of episode + add global reward to all rewards!
         if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py
index 5117b12af72be948bc806940169e330d254fdb9b..8013e912d93ea1125f4f7bdc622f2f4e3d4b2333 100644
--- a/tests/test_flatland_core_transition_map.py
+++ b/tests/test_flatland_core_transition_map.py
@@ -3,7 +3,7 @@ from flatland.core.grid.grid8 import Grid8Transitions, Grid8TransitionsEnum
 from flatland.core.transition_map import GridTransitionMap
 
 
-def test_grid4_set_transitions():
+def test_grid4_get_transitions():
     grid4_map = GridTransitionMap(2, 2, Grid4Transitions([]))
     assert grid4_map.get_transitions((0, 0, Grid4TransitionsEnum.NORTH)) == (0, 0, 0, 0)
     grid4_map.set_transition((0, 0, Grid4TransitionsEnum.NORTH), Grid4TransitionsEnum.NORTH, 1)
@@ -19,3 +19,5 @@ def test_grid8_set_transitions():
     assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (1, 0, 0, 0, 0, 0, 0, 0)
     grid8_map.set_transition((0, 0, Grid8TransitionsEnum.NORTH), Grid8TransitionsEnum.NORTH, 0)
     assert grid8_map.get_transitions((0, 0, Grid8TransitionsEnum.NORTH)) == (0, 0, 0, 0, 0, 0, 0, 0)
+
+# TODO GridTransitionMap