diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index e2ea4bfea061fc42e21978a580e25e8e8139449b..5cb6ba1d633e8997b255447521a2938270d6d173 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -267,8 +267,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                 if info['action_required'][agent_handle]:
                     update_values[agent_handle] = True
                     action = policy.act(agent_handle, agent_obs[agent_handle], eps=eps_start)
-                    action_count[map_action(action, get_action_size())] += 1
-                    actions_taken.append(map_action(action, get_action_size()))
+                    action_count[map_action(action)] += 1
+                    actions_taken.append(map_action(action))
                 else:
                     # An action is not required if the train hasn't joined the railway network,
                     # if it already reached its target, or if is currently malfunctioning.
@@ -280,7 +280,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
             # Environment step
             step_timer.start()
-            next_obs, all_rewards, done, info = train_env.step(map_actions(action_dict, get_action_size()))
+            next_obs, all_rewards, done, info = train_env.step(map_actions(action_dict))
 
-            # Reward shaping .Dead-lock .NotMoving .NotStarted
+            # Reward shaping: Dead-lock, NotMoving, NotStarted
             if False:
@@ -288,7 +288,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
                 for agent_handle in train_env.get_agent_handles():
                     agent = train_env.agents[agent_handle]
                     act = action_dict.get(agent_handle, RailEnvActions.DO_NOTHING)
-                    act = map_action(act, get_action_size())
+                    act = map_action(act)
                     if agent.status == RailAgentStatus.ACTIVE:
                         all_rewards[agent_handle] = 0.0
                         if done[agent_handle] == False:
@@ -494,7 +494,7 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
                         action = policy.act(agent, agent_obs[agent], eps=0.0)
                 action_dict.update({agent: action})
             policy.end_step(train=False)
-            obs, all_rewards, done, info = env.step(map_actions(action_dict, get_action_size()))
+            obs, all_rewards, done, info = env.step(map_actions(action_dict))
 
             for agent in env.get_agent_handles():
                 score += all_rewards[agent]
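
For context on the call-site changes above: map_action() now looks up the configured
action size itself instead of taking it as a parameter. Below is a minimal sketch of
what the mapping does at such a call site, assuming the reduced 4-action space simply
drops DO_NOTHING and that flatland's RailEnvActions keeps its standard ordering
(DO_NOTHING=0 .. STOP_MOVING=4); the numeric action is a made-up stand-in for the
output of policy.act(...), not taken from the patch.

    # Sketch only, not part of the patch; assumes get_action_size() == 4 (reduced space).
    from utils.agent_action_config import map_action

    policy_action = 2                       # hypothetical policy.act(...) output, in the reduced 0..3 range
    env_action = map_action(policy_action)  # 2 + 1 == 3, i.e. MOVE_RIGHT; DO_NOTHING (0) is never emitted
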
diff --git a/run.py b/run.py
index 998add015fca7d6b57b492131249076bd2382367..8e0535b3ebebd14b81e98ab23d5d9228c0b6fea8 100644
--- a/run.py
+++ b/run.py
@@ -214,7 +214,7 @@ while True:
                 time_taken_by_controller.append(agent_time)
 
                 time_start = time.time()
-                _, all_rewards, done, info = remote_client.env_step(map_actions(action_dict, get_action_size))
+                _, all_rewards, done, info = remote_client.env_step(map_actions(action_dict))
                 step_time = time.time() - time_start
                 time_taken_per_step.append(step_time)
 
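
One detail worth flagging in the run.py hunk: the removed line passed get_action_size
without parentheses, i.e. the function object rather than its value, so the size check
inside the old map_actions(actions, action_size) compared a function to an integer and
could never take the "full action space" early return. Reading the size inside the
helper removes that whole class of mistake. A tiny stand-alone illustration of the
pitfall follows; the two defs below only mirror the helpers' return values and are not
the project's code.

    # Illustration of the old pitfall, independent of the repository.
    def get_action_size():
        return 4

    def get_flatland_full_action_size():
        return 5

    action_size = get_action_size                           # old run.py style: missing ()
    print(action_size == get_flatland_full_action_size())   # False: a function object never equals 5
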
diff --git a/utils/agent_action_config.py b/utils/agent_action_config.py
index dceba553310e2d4e47a5554da267c11dfa338ee1..3a84875297cacce023998d21f61ea0abcc76c9d6 100644
--- a/utils/agent_action_config.py
+++ b/utils/agent_action_config.py
@@ -1,4 +1,3 @@
-
 def get_flatland_full_action_size():
     # The action space of flatland is 5 discrete actions
     return 5
@@ -9,9 +8,9 @@ def get_action_size():
     return 4
 
 
-def map_actions(actions, action_size):
+def map_actions(actions):
-    # Map the
+    # Map the agents' reduced policy actions onto flatland's full action space
-    if action_size == get_flatland_full_action_size():
+    if get_action_size() == get_flatland_full_action_size():
         return actions
     for key in actions:
         value = actions.get(key, 0)
@@ -19,7 +18,7 @@ def map_actions(actions, action_size):
     return actions
 
 
-def map_action(action, action_size):
-    if action_size == get_flatland_full_action_size():
+def map_action(action):
+    if get_action_size() == get_flatland_full_action_size():
         return action
     return action + 1
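
Putting the utils/agent_action_config.py hunks together, the module after this patch
reads roughly as sketched below. The write-back inside the map_actions loop sits between
the two hunks and is not visible in the diff, so the line marked "assumed" is only an
assumption that it mirrors map_action's +1 shift.

    # Post-patch sketch of utils/agent_action_config.py (loop write-back assumed, see note above).
    def get_flatland_full_action_size():
        # The action space of flatland is 5 discrete actions
        return 5

    def get_action_size():
        # The policies act in a reduced 4-action space
        return 4

    def map_actions(actions):
        # Map the agents' reduced policy actions onto flatland's full action space
        if get_action_size() == get_flatland_full_action_size():
            return actions
        for key in actions:
            value = actions.get(key, 0)
            actions.update({key: value + 1})  # assumed: mirrors map_action's +1 shift
        return actions

    def map_action(action):
        if get_action_size() == get_flatland_full_action_size():
            return action
        return action + 1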