Skip to content
Snippets Groups Projects
Commit 0ecec475 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

Initial commit for speed tests and multi-speed initialization; waiting for other merges first

parent 637e7ef1
No related branches found
No related tags found
No related merge requests found
......@@ -26,7 +26,7 @@ class RandomAgent:
:param state: input is the observation of the agent
:return: returns an action
"""
return np.random.choice(np.arange(self.action_size))
return np.random.choice([1, 2, 3])
def step(self, memories):
"""
......@@ -50,34 +50,40 @@ class RandomAgent:
agent = RandomAgent(218, 4)
n_trials = 5
# Empty dictionary for all agent action
action_dict = dict()
# Set all the different speeds
def test_multi_speed_init():
    """Verify that agents with fractional speeds only move on their own ticks.

    Agent ``i`` is assigned speed ``1 / (i + 1)``, so it is expected to
    advance at most once every ``i + 1`` environment steps.  Between those
    ticks its recorded position must stay unchanged, which is asserted
    inside the episode loop.

    NOTE(review): relies on the module-level ``env``, ``agent`` and
    ``action_dict`` defined earlier in this file; assumes
    ``RailEnv.step`` only moves an agent when its speed counter completes
    a full cell — confirm against the environment implementation.
    """
    # Reset environment; observations are not needed by this test
    env.reset()

    # Deterministic speeds instead of random ones, so the expected movement
    # schedule of every agent is known exactly
    old_pos = []
    for i_agent in range(env.get_num_agents()):
        env.agents[i_agent].speed_data['speed'] = 1. / (i_agent + 1)
        old_pos.append(env.agents[i_agent].position)

    # Run episode
    for step in range(100):
        # Choose an action for each agent in the environment
        for a in range(env.get_num_agents()):
            action = agent.act(0)
            action_dict.update({a: action})
            # Check that the agent did not move in between its speed updates
            assert old_pos[a] == env.agents[a].position

        # Environment step; observations, rewards and done flags are unused here
        _, _, _, _ = env.step(action_dict)

        # Record the new position of every agent that was due to move this step
        for i_agent in range(env.get_num_agents()):
            if (step + 1) % (i_agent + 1) == 0:
                # BUGFIX: original printed env.agents[a] (stale inner-loop
                # index) instead of the agent being updated
                print(step, i_agent, env.agents[i_agent].position)
                old_pos[i_agent] = env.agents[i_agent].position
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment