Commit 0ecec475 authored by Erik Nygren

Initial commit for speed tests and multi-speed initialization; waiting for other merges first.

parent 637e7ef1
@@ -26,7 +26,7 @@ class RandomAgent:
         :param state: input is the observation of the agent
         :return: returns an action
         """
-        return np.random.choice(np.arange(self.action_size))
+        return np.random.choice([1, 2, 3])
 
     def step(self, memories):
         """
@@ -50,34 +50,40 @@ class RandomAgent:
 agent = RandomAgent(218, 4)
 n_trials = 5
 # Empty dictionary for all agent actions
 action_dict = dict()
+# Set all the different speeds
 
 def test_multi_speed_init():
     # Reset environment and get initial observations for all agents
-    obs = env.reset()
+    env.reset()
     # Here you can also further enhance the provided observation by means of normalization
     # See training navigation example in the baseline repository
+    old_pos = []
     for i_agent in range(env.get_num_agents()):
-        env.agents[i_agent].speed_data['speed'] = 1. / np.random.randint(1, 10)
+        env.agents[i_agent].speed_data['speed'] = 1. / (i_agent + 1)
+        old_pos.append(env.agents[i_agent].position)
 
     score = 0
     # Run episode
     for step in range(100):
         # Choose an action for each agent in the environment
         for a in range(env.get_num_agents()):
-            action = agent.act(obs[a])
+            action = agent.act(0)
             action_dict.update({a: action})
+            # Check that the agent did not move in between its speed updates
+            assert old_pos[a] == env.agents[a].position
 
         # Environment step which returns the observations for all agents, their corresponding
         # rewards and whether they are done
-        next_obs, all_rewards, done, _ = env.step(action_dict)
+        _, _, _, _ = env.step(action_dict)
 
-        # Update replay buffer and train agent
-        for a in range(env.get_num_agents()):
-            agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
-            score += all_rewards[a]
-        obs = next_obs.copy()
-        if done['__all__']:
-            break
+        # Update old positions whenever an agent is due to have completed a move
+        for i_agent in range(env.get_num_agents()):
+            if (step + 1) % (i_agent + 1) == 0:
+                print(step, i_agent, env.agents[i_agent].position)
+                old_pos[i_agent] = env.agents[i_agent].position
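The test leans on Flatland's fractional-speed stepping: an agent with speed 1/n only completes a cell transition every n environment steps, so its position must stay unchanged in between, which is what the assert above checks. The following standalone sketch (plain Python, no Flatland dependency; the names simulate_positions, speed, fraction and position are illustrative, not the Flatland API) shows why the check (step + 1) % (i_agent + 1) == 0 marks exactly the steps on which agent i_agent is due to have moved.

# Minimal sketch of the fractional-speed behaviour the test relies on, under the
# assumption that an agent with speed 1/n accumulates 1/n of a cell per step and
# only advances to the next cell once a full cell transition is complete.

def simulate_positions(speed, n_steps):
    position = 0
    fraction = 0.0
    positions = []
    for _ in range(n_steps):
        fraction += speed
        if fraction >= 1.0 - 1e-9:  # a full cell transition has completed
            position += 1
            fraction -= 1.0
        positions.append(position)
    return positions

# Agent i_agent = 2 gets speed 1/3 in the test, so it should only change cell on
# steps where (step + 1) % 3 == 0, i.e. steps 2, 5, 8 (0-indexed):
assert simulate_positions(1.0 / 3, 9) == [0, 0, 1, 1, 1, 2, 2, 2, 3]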