Commit cf80f503 authored by Egli Adrian (IT-SCI-API-PFI)

refactored and added new agent

parent 1c60b970
@@ -267,8 +267,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             if info['action_required'][agent_handle]:
                 update_values[agent_handle] = True
                 action = policy.act(agent_handle, agent_obs[agent_handle], eps=eps_start)
-                action_count[map_action(action, get_action_size())] += 1
-                actions_taken.append(map_action(action, get_action_size()))
+                action_count[map_action(action)] += 1
+                actions_taken.append(map_action(action))
             else:
                 # An action is not required if the train hasn't joined the railway network,
                 # if it already reached its target, or if is currently malfunctioning.
@@ -280,7 +280,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
         # Environment step
         step_timer.start()
-        next_obs, all_rewards, done, info = train_env.step(map_actions(action_dict, get_action_size()))
+        next_obs, all_rewards, done, info = train_env.step(map_actions(action_dict))
         # Reward shaping .Dead-lock .NotMoving .NotStarted
         if False:
@@ -288,7 +288,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params):
             for agent_handle in train_env.get_agent_handles():
                 agent = train_env.agents[agent_handle]
                 act = action_dict.get(agent_handle, RailEnvActions.DO_NOTHING)
-                act = map_action(act, get_action_size())
+                act = map_action(act)
                 if agent.status == RailAgentStatus.ACTIVE:
                     all_rewards[agent_handle] = 0.0
                     if done[agent_handle] == False:
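
The reward-shaping block above is guarded by "if False:" and only its first lines are visible in this hunk; the comment names dead-lock, not-moving, and not-started penalties, but the penalty logic itself is not shown. Purely as an illustration of that idea, a minimal sketch follows. The helper name shape_rewards, the deadlocked argument, the penalty constants, and the stalled-train check are assumptions, not code from this commit; imports assume flatland 2.x, and map_action is the repo helper shown further below.

from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import RailEnvActions

# Assumed penalty values -- placeholders only, not taken from the commit.
DEADLOCK_PENALTY = -5.0
NOT_MOVING_PENALTY = -1.0

def shape_rewards(train_env, action_dict, all_rewards, done, deadlocked):
    # Hypothetical helper: overwrite the environment reward for active agents
    # that are not done yet, mirroring the zeroing seen in the hunk above.
    for agent_handle in train_env.get_agent_handles():
        agent = train_env.agents[agent_handle]
        act = map_action(action_dict.get(agent_handle, RailEnvActions.DO_NOTHING))
        if agent.status == RailAgentStatus.ACTIVE and not done[agent_handle]:
            all_rewards[agent_handle] = 0.0
            if deadlocked.get(agent_handle, False):
                all_rewards[agent_handle] += DEADLOCK_PENALTY
            elif agent.position == agent.old_position and act != RailEnvActions.STOP_MOVING:
                all_rewards[agent_handle] += NOT_MOVING_PENALTY
    return all_rewards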
@@ -494,7 +494,7 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
             action = policy.act(agent, agent_obs[agent], eps=0.0)
             action_dict.update({agent: action})
         policy.end_step(train=False)
-        obs, all_rewards, done, info = env.step(map_actions(action_dict, get_action_size()))
+        obs, all_rewards, done, info = env.step(map_actions(action_dict))
         for agent in env.get_agent_handles():
             score += all_rewards[agent]
...
@@ -214,7 +214,7 @@ while True:
     time_taken_by_controller.append(agent_time)
     time_start = time.time()
-    _, all_rewards, done, info = remote_client.env_step(map_actions(action_dict, get_action_size))
+    _, all_rewards, done, info = remote_client.env_step(map_actions(action_dict))
     step_time = time.time() - time_start
     time_taken_per_step.append(step_time)
...
 def get_flatland_full_action_size():
     # The action space of flatland is 5 discrete actions
     return 5
@@ -9,9 +8,9 @@ def get_action_size():
     return 4
 
-def map_actions(actions, action_size):
+def map_actions(actions):
     # Map the
-    if action_size == get_flatland_full_action_size():
+    if get_action_size() == get_flatland_full_action_size():
         return actions
     for key in actions:
         value = actions.get(key, 0)
@@ -19,7 +18,7 @@ def map_actions(actions, action_size):
         return actions
 
-def map_action(action, action_size):
-    if action_size == get_flatland_full_action_size():
+def map_action(action):
+    if get_action_size() == get_flatland_full_action_size():
         return action
     return action + 1
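
For readability, here is a self-contained sketch of the action-size helpers after this refactoring, reconstructed from the hunks above. The part of map_actions between "value = actions.get(key, 0)" and the final "return actions" is not visible in this diff, so the per-entry +1 shift and the dict comprehension used to apply it are assumptions based on map_action; everything else mirrors the new (right-hand) side of the hunks.

def get_flatland_full_action_size():
    # The action space of flatland is 5 discrete actions.
    return 5

def get_action_size():
    # The policy works on a reduced space of 4 actions; the +1 shift in
    # map_action suggests flatland's DO_NOTHING (0) is the action left out.
    return 4

def map_action(action):
    # Identity when the policy already uses the full flatland action space.
    if get_action_size() == get_flatland_full_action_size():
        return action
    # Otherwise shift the reduced action (0..3) onto flatland's (1..4).
    return action + 1

def map_actions(actions):
    # Map a whole {agent_handle: action} dict before passing it to env.step().
    if get_action_size() == get_flatland_full_action_size():
        return actions
    # Assumed body: apply the same shift to every entry (the commit mutates the
    # dict in place; a new dict is returned here for clarity).
    return {key: map_action(actions.get(key, 0)) for key in actions}

Example usage, assuming two agents whose policy picked reduced actions 2 and 0:

actions = {0: 2, 1: 0}
map_actions(actions)   # -> {0: 3, 1: 1} with the reduced 4-action space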