Commit 73886ef8 authored by Erik Nygren

updated introduction file

parent 940b8132
import time
# In Flatland you can use custom observation builders and predictors
# Observation builders generate the observations needed by the controller
# Predictors can be used to make short-term predictions, which can help in avoiding conflicts in the network
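# A minimal sketch of combining the two (assuming this version's TreeObsForRailEnv and
# ShortestPathPredictorForRailEnv classes; max_depth=2 and the prediction depth of 20 are illustrative values):
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv

tree_observation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(20))
# Such a builder is passed to the RailEnv constructor via its obs_builder_object parameter.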
# We first look at the map we have created
# env_renderer.render_env(show=True)
# time.sleep(2)
# Import your own Agent or use RLlib to train agents on Flatland
# As an example, we use a random agent instead
class RandomAgent:
        # Load a policy
        return
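# For orientation, a random controller of this kind can be sketched as follows
# (a hypothetical sketch; the 5-action Flatland action space and np.random are assumptions):
#     def act(self, state):
#         return np.random.choice(np.arange(5))  # pick one of the 5 Flatland actions at random
#     def step(self, memories):
#         return  # a purely random agent does not learn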
# Initialize the agent with the parameters corresponding to the environment and observation_builder
controller = RandomAgent(218, env.action_space[0])
# Let's check if there are any agents with the same start location
agents_with_same_start = []
print("\n The following agents have the same initial position:")
print("============================")
print("=====================================================")
for agent_idx, agent in enumerate(env.agents):
    for agent_2_idx, agent2 in enumerate(env.agents):
        if agent_idx != agent_2_idx and agent.initial_position == agent2.initial_position:
            agents_with_same_start.append(agent_idx)
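# The same check can also be sketched more compactly with collections.Counter
# (an illustrative alternative, not part of the original tutorial):
from collections import Counter

start_counts = Counter(agent.initial_position for agent in env.agents)
print("Start positions used more than once:", {pos for pos, count in start_counts.items() if count > 1})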
# Let's try to enter with all of these agents at the same time
action_dict = dict()
for agent_id in agents_with_same_start:
    action_dict[agent_id] = 1  # Set agents to moving
print("\n This happened when all tried to enter at the same time:")
print("========================================================")
for agent_id in agents_with_same_start:
print("Agent {} status is: {} with its current position being {}".format(agent_id, str(env.agents[agent_id].status),
str(env.agents[agent_id].position)))
action_dict[agent_id] = 1 # Try to move with the agents
# Do a step in the environment to see what agents entered:
env.step(action_dict)
print("\n This happened when all tried to enter at the same time:")
print("========================================================")
for agent_id in agents_with_same_start:
    print(
        "Agent {} status is: {} with the current position being {}.".format(
            agent_id, str(env.agents[agent_id].status),
            str(env.agents[agent_id].position)))
# As you can see, only the agents with lower indices moved. As soon as the cell is free again,
# the remaining agents can attempt to start.
# You will also notice that the agents move at different speeds once they are on the rail.
# When moving, an agent always travels at its full assigned speed, never at a speed in between.
# The fastest an agent can go is 1, meaning that it moves to the next cell at every time step
# All slower speeds indicate the fraction of a cell that is traversed at each time step
# Let's look at the current speed data of the agents:
print("\n The speed information of the agents are:")
print("=========================================")
for agent_idx, agent in enumerate(env.agents):
    print(
        "Agent {} speed is: {:.2f} with the current fractional position being {}".format(
            agent_idx, agent.speed_data['speed'], agent.speed_data['position_fraction']))
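# A small illustration of what the speed fraction means: an agent with speed s needs
# ceil(1 / s) time steps to traverse one cell (derived from the fractional-progress rule above).
import math

for agent_idx, agent in enumerate(env.agents):
    print("Agent {} needs {} time step(s) per cell".format(
        agent_idx, int(math.ceil(1.0 / agent.speed_data['speed']))))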
# Next, the agents can also have stochastic malfunctions, which will leave them unable to move
# for a certain number of time steps. The malfunction data of the agents can easily be accessed as follows
print("\n The malfunction data of the agents are:")
print("========================================")
for agent_idx, agent in enumerate(env.agents):
    print(
        "Agent {} will malfunction = {} at a rate of {}, the next malfunction will occur in {} steps. Agent OK = {}".format(
            agent_idx, agent.malfunction_data['malfunction_rate'] > 0, agent.malfunction_data['malfunction_rate'],
            agent.malfunction_data['next_malfunction'], agent.malfunction_data['malfunction'] < 1))
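# A minimal sketch of how such malfunctions are typically configured
# (assumption: the RailEnv constructor of this version accepts a stochastic_data dict;
# the keys and values below are illustrative, in the style of the contemporaneous examples):
stochastic_data = {'prop_malfunction': 0.3,  # fraction of agents that can malfunction
                   'malfunction_rate': 30,   # average number of steps between malfunctions
                   'min_duration': 3,        # minimum number of steps a malfunction lasts
                   'max_duration': 20}       # maximum number of steps a malfunction lasts
# e.g. env = RailEnv(..., stochastic_data=stochastic_data)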
# Now that you have seen these novel concepts, you will realize that agents don't need to take
# an action at every time step, as an action only has an effect when it is chosen at cell entry.
# Therefore, the environment provides information about which agents need to provide an action in the next step.
# You can access this in the following way.
# Choose an action for each agent
for a in range(env.get_num_agents()):
    action = controller.act(0)
    action_dict.update({a: action})
# Do the environment step
observations, rewards, dones, information = env.step(action_dict)
print("\n Thefollowing agents can register an action:")
print("========================================")
print(information['action_required'])
# We recommend that you monitor the malfunction data and the action_required information
# in order to optimize your training and control code.
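# A minimal sketch of using these flags: only query the controller for agents that
# actually require an action this step (uses the information dict from the step above):
for a in range(env.get_num_agents()):
    if information['action_required'][a]:
        action_dict.update({a: controller.act(observations[a])})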
# Let us now look at an episode playing out with random actions performed
print("Start episode...")
# Reset environment and get initial observations for all agents
start_reset = time.time()
observations, info = env.reset()
end_reset = time.time()
print("Reset took {:.4f} seconds".format(end_reset - start_reset))
print("Number of agents: {}".format(env.get_num_agents()))
# Reset the rendering system
env_renderer.reset()
# Here you can also further enhance the provided observation by means of normalization
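# A sketch of such a normalization step (normalize_observation_vector is a hypothetical
# helper; a tree observation would first have to be flattened into a numeric vector):
import numpy as np

def normalize_observation_vector(obs_vector):
    obs_vector = np.asarray(obs_vector, dtype=float)
    # scale to unit norm so all features end up on a comparable range
    return obs_vector / (np.linalg.norm(obs_vector) + 1e-8)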
score = 0
frame_step = 0
for step in range(500):
    # Choose an action for each agent in the environment
    for a in range(env.get_num_agents()):
        action = controller.act(observations[a])
        action_dict.update({a: action})
    # Environment step which returns the observations for all agents, their corresponding
    # rewards and done flags
    next_obs, all_rewards, done, _ = env.step(action_dict)
    frame_step += 1
    # Update replay buffer and train agent
    for a in range(env.get_num_agents()):
        controller.step((observations[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
        score += all_rewards[a]

    observations = next_obs.copy()
    if done['__all__']:
        break
    print('Episode: Steps {}\t Score = {}'.format(step, score))