diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index e2ea4bfea061fc42e21978a580e25e8e8139449b..5cb6ba1d633e8997b255447521a2938270d6d173 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -267,8 +267,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             if info['action_required'][agent_handle]:
                 update_values[agent_handle] = True
                 action = policy.act(agent_handle, agent_obs[agent_handle], eps=eps_start)
-                action_count[map_action(action, get_action_size())] += 1
-                actions_taken.append(map_action(action, get_action_size()))
+                action_count[map_action(action)] += 1
+                actions_taken.append(map_action(action))
             else:
                 # An action is not required if the train hasn't joined the railway network,
                 # if it already reached its target, or if is currently malfunctioning.
@@ -280,7 +280,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
 
         # Environment step
         step_timer.start()
-        next_obs, all_rewards, done, info = train_env.step(map_actions(action_dict, get_action_size()))
+        next_obs, all_rewards, done, info = train_env.step(map_actions(action_dict))
 
         # Reward shaping .Dead-lock .NotMoving .NotStarted
         if False:
@@ -288,7 +288,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             for agent_handle in train_env.get_agent_handles():
                 agent = train_env.agents[agent_handle]
                 act = action_dict.get(agent_handle, RailEnvActions.DO_NOTHING)
-                act = map_action(act, get_action_size())
+                act = map_action(act)
                 if agent.status == RailAgentStatus.ACTIVE:
                     all_rewards[agent_handle] = 0.0
                     if done[agent_handle] == False:
@@ -494,7 +494,7 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
                 action = policy.act(agent, agent_obs[agent], eps=0.0)
                 action_dict.update({agent: action})
             policy.end_step(train=False)
-            obs, all_rewards, done, info = env.step(map_actions(action_dict, get_action_size()))
+            obs, all_rewards, done, info = env.step(map_actions(action_dict))
 
             for agent in env.get_agent_handles():
                 score += all_rewards[agent]
diff --git a/run.py b/run.py
index 998add015fca7d6b57b492131249076bd2382367..8e0535b3ebebd14b81e98ab23d5d9228c0b6fea8 100644
--- a/run.py
+++ b/run.py
@@ -214,7 +214,7 @@ while True:
             time_taken_by_controller.append(agent_time)
 
             time_start = time.time()
-            _, all_rewards, done, info = remote_client.env_step(map_actions(action_dict, get_action_size))
+            _, all_rewards, done, info = remote_client.env_step(map_actions(action_dict))
             step_time = time.time() - time_start
             time_taken_per_step.append(step_time)
 
diff --git a/utils/agent_action_config.py b/utils/agent_action_config.py
index dceba553310e2d4e47a5554da267c11dfa338ee1..3a84875297cacce023998d21f61ea0abcc76c9d6 100644
--- a/utils/agent_action_config.py
+++ b/utils/agent_action_config.py
@@ -1,4 +1,3 @@
-
 def get_flatland_full_action_size():
     # The action space of flatland is 5 discrete actions
     return 5
@@ -9,9 +8,9 @@ def get_action_size():
     return 4
 
 
-def map_actions(actions, action_size):
+def map_actions(actions):
     # Map the
-    if action_size == get_flatland_full_action_size():
+    if get_action_size() == get_flatland_full_action_size():
         return actions
     for key in actions:
         value = actions.get(key, 0)
@@ -19,7 +18,7 @@ def map_actions(actions, action_size):
     return actions
 
 
-def map_action(action, action_size):
-    if action_size == get_flatland_full_action_size():
+def map_action(action):
+    if get_action_size() == get_flatland_full_action_size():
         return action
     return action + 1
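For context, a minimal usage sketch of the refactored helpers (not part of the patch). It assumes the patched `utils/agent_action_config.py` is importable and that `get_action_size()` is hard-coded to 4, as in the diff above. In the reduced action space, `map_action` shifts each agent action by +1 so that Flatland's index 0 (`DO_NOTHING`) is skipped, and `map_actions` applies the same shift to a whole handle-to-action dictionary, which is why the call sites no longer pass the action size themselves.

```python
# Minimal sketch, not part of the patch. Assumes the patched
# utils/agent_action_config.py is on the import path and that
# get_action_size() returns 4 (the reduced action space).
from utils.agent_action_config import get_action_size, map_action, map_actions

# Policy outputs live in the reduced space 0..3; Flatland's full space is
# 0..4, with index 0 reserved for DO_NOTHING.
agent_actions = {0: 0, 1: 2, 2: 3}  # handle -> action chosen by the policy

if get_action_size() != 5:
    assert map_action(2) == 3                                       # single action shifted by +1
    assert map_actions(dict(agent_actions)) == {0: 1, 1: 3, 2: 4}   # whole dict shifted
else:
    # With the full 5-action space both helpers are pass-throughs.
    assert map_actions(dict(agent_actions)) == agent_actions
```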