From cf80f503ab12eec9603e4f2b46fd240b694877eb Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Wed, 16 Dec 2020 15:05:54 +0100
Subject: [PATCH] refactored and added new agent

---
 reinforcement_learning/multi_agent_training.py | 10 +++++-----
 run.py                                         |  2 +-
 utils/agent_action_config.py                   |  9 ++++-----
 3 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index e2ea4bf..5cb6ba1 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -267,8 +267,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             if info['action_required'][agent_handle]:
                 update_values[agent_handle] = True
                 action = policy.act(agent_handle, agent_obs[agent_handle], eps=eps_start)
-                action_count[map_action(action, get_action_size())] += 1
-                actions_taken.append(map_action(action, get_action_size()))
+                action_count[map_action(action)] += 1
+                actions_taken.append(map_action(action))
             else:
                 # An action is not required if the train hasn't joined the railway network,
                 # if it already reached its target, or if is currently malfunctioning.
@@ -280,7 +280,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):

         # Environment step
         step_timer.start()
-        next_obs, all_rewards, done, info = train_env.step(map_actions(action_dict, get_action_size()))
+        next_obs, all_rewards, done, info = train_env.step(map_actions(action_dict))

         # Reward shaping .Dead-lock .NotMoving .NotStarted
         if False:
@@ -288,7 +288,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             for agent_handle in train_env.get_agent_handles():
                 agent = train_env.agents[agent_handle]
                 act = action_dict.get(agent_handle, RailEnvActions.DO_NOTHING)
-                act = map_action(act, get_action_size())
+                act = map_action(act)
                 if agent.status == RailAgentStatus.ACTIVE:
                     all_rewards[agent_handle] = 0.0
                     if done[agent_handle] == False:
@@ -494,7 +494,7 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
                 action = policy.act(agent, agent_obs[agent], eps=0.0)
                 action_dict.update({agent: action})
             policy.end_step(train=False)
-            obs, all_rewards, done, info = env.step(map_actions(action_dict, get_action_size()))
+            obs, all_rewards, done, info = env.step(map_actions(action_dict))

             for agent in env.get_agent_handles():
                 score += all_rewards[agent]
diff --git a/run.py b/run.py
index 998add0..8e0535b 100644
--- a/run.py
+++ b/run.py
@@ -214,7 +214,7 @@ while True:
             time_taken_by_controller.append(agent_time)

             time_start = time.time()
-            _, all_rewards, done, info = remote_client.env_step(map_actions(action_dict, get_action_size))
+            _, all_rewards, done, info = remote_client.env_step(map_actions(action_dict))
             step_time = time.time() - time_start
             time_taken_per_step.append(step_time)

diff --git a/utils/agent_action_config.py b/utils/agent_action_config.py
index dceba55..3a84875 100644
--- a/utils/agent_action_config.py
+++ b/utils/agent_action_config.py
@@ -1,4 +1,3 @@
-
 def get_flatland_full_action_size():
     # The action space of flatland is 5 discrete actions
     return 5
@@ -9,9 +8,9 @@ def get_action_size():
     return 4


-def map_actions(actions, action_size):
+def map_actions(actions):
     # Map the
-    if action_size == get_flatland_full_action_size():
+    if get_action_size() == get_flatland_full_action_size():
         return actions
     for key in actions:
         value = actions.get(key, 0)
@@ -19,7 +18,7 @@ def map_actions(actions, action_size):
     return actions


-def map_action(action, action_size):
-    if action_size == get_flatland_full_action_size():
+def map_action(action):
+    if get_action_size() == get_flatland_full_action_size():
         return action
     return action + 1
-- 
GitLab
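
For reference, a minimal usage sketch of the call convention this patch introduces (not part of the patch itself): map_action and map_actions now look up the configured action size via get_action_size() internally instead of taking it as an argument. The import path assumes the repository layout shown in the diff, and the concrete values below are illustrative placeholders only.

    # Hypothetical usage sketch, assuming utils/agent_action_config.py is importable
    # as laid out in the diff above; the action values are made up for illustration.
    from utils.agent_action_config import get_action_size, map_action, map_actions

    # After this patch the helpers consult get_action_size() themselves,
    # so callers no longer thread the action size through every call site.
    agent_action = 2                        # e.g. a value returned by policy.act(...)
    env_action = map_action(agent_action)   # 3 while the reduced 4-action space is active,
                                            # returned unchanged with flatland's full 5-action space

    action_dict = {0: 1, 1: 3}              # agent handle -> policy action
    env_actions = map_actions(action_dict)  # per-agent mapping applied before env.step(...)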