diff --git a/examples/introduction_flatland_2_1_1.py b/examples/introduction_flatland_2_1_1.py
index 99fd45a71d2b04166f54154fb096d8bdb876ae7f..2e21e1bbb3495a7cb0eb29eeb1e95233e5d0eadc 100644
--- a/examples/introduction_flatland_2_1_1.py
+++ b/examples/introduction_flatland_2_1_1.py
@@ -1,5 +1,3 @@
-import time
-
 # In Flatland you can use custom observation builders and predicitors
 # Observation builders generate the observation needed by the controller
 # Preditctors can be used to do short time prediction which can help in avoiding conflicts in the network
@@ -94,7 +92,7 @@ env_renderer = RenderTool(env, gl="PILSVG",
 # We first look at the map we have created
 # env_renderer.render_env(show=True)
-#time.sleep(2)
+# time.sleep(2)

 # Import your own Agent or use RLlib to train agents on Flatland
 # As an example we use a random agent instead
 class RandomAgent:
@@ -127,6 +125,7 @@ class RandomAgent:
         # Load a policy
         return

+
 # Initialize the agent with the parameters corresponding to the environment and observation_builder
 controller = RandomAgent(218, env.action_space[0])

@@ -154,7 +153,7 @@ for agent_idx, agent in enumerate(env.agents):
 # Let's check if there are any agents with the same start location
 agents_with_same_start = []
 print("\n The following agents have the same initial position:")
-print("============================")
+print("=====================================================")
 for agent_idx, agent in enumerate(env.agents):
     for agent_2_idx, agent2 in enumerate(env.agents):
         if agent_idx != agent_2_idx and agent.initial_position == agent2.initial_position:
@@ -162,16 +161,10 @@ for agent_idx, agent in enumerate(env.agents):
             agents_with_same_start.append(agent_idx)

 # Lets try to enter with all of these agents at the same time
-action_dict = {}
+action_dict = dict()
 for agent_id in agents_with_same_start:
-    action_dict[agent_id] = 1  # Set agents to moving
-
-print("\n This happened when all tried to enter at the same time:")
-print("========================================================")
-for agent_id in agents_with_same_start:
-    print("Agent {} status is: {} with its current position being {}".format(agent_id, str(env.agents[agent_id].status),
-                                                                             str(env.agents[agent_id].position)))
+    action_dict[agent_id] = 1  # Try to move with the agents

 # Do a step in the environment to see what agents entered:
 env.step(action_dict)

@@ -181,20 +174,62 @@ print("\n This happened when all tried to enter at the same time:")
 print("========================================================")
 for agent_id in agents_with_same_start:
     print(
-        "Agent {} status is: {} with its current position being {} which is the same as the start position {} and orientation {}".format(
+        "Agent {} status is: {} with the current position being {}.".format(
             agent_id, str(env.agents[agent_id].status),
-            str(env.agents[agent_id].position), env.agents[agent_id].initial_position, env.agents[agent_id].direction))
-# Empty dictionary for all agent action
-action_dict = dict()
+            str(env.agents[agent_id].position)))
+
+# As you can see, only the agents with lower indices moved. As soon as the cell is free again, the agents
+# can attempt to start again.
+
+# You will also notice that the agents move at different speeds once they are on the rail.
+# The agents will always move at full speed when moving, never at a speed in between.
+# The fastest an agent can go is 1, meaning that it moves to the next cell at every time step.
+# All slower speeds indicate the fraction of a cell that is covered at each time step.
+# Let's look at the current speed data of the agents:
+
+print("\n The speed information of the agents is:")
+print("=======================================")
+
+for agent_idx, agent in enumerate(env.agents):
+    print(
+        "Agent {} speed is: {:.2f} with the current fractional position being {}".format(
+            agent_idx, agent.speed_data['speed'], agent.speed_data['position_fraction']))
+
+# Next, the agents can also suffer stochastic malfunctions, which will leave them unable to move
+# for a certain number of time steps. The malfunction data of the agents can easily be accessed as follows:
+print("\n The malfunction data of the agents is:")
+print("======================================")
+
+for agent_idx, agent in enumerate(env.agents):
+    print(
+        "Agent {} will malfunction = {} at a rate of {}, the next malfunction will occur in {} steps. Agent OK = {}".format(
+            agent_idx, agent.malfunction_data['malfunction_rate'] > 0, agent.malfunction_data['malfunction_rate'],
+            agent.malfunction_data['next_malfunction'], agent.malfunction_data['malfunction'] < 1))
+
+# Now that you have seen these new concepts, you will realize that agents don't need to take an action at
+# every time step, as an action only changes the outcome when it is chosen at cell entry.
+# Therefore the environment provides information about which agents need to provide an action in the next step.
+# You can access this in the following way.
+
+# Choose an action for each agent
+for a in range(env.get_num_agents()):
+    action = controller.act(0)
+    action_dict.update({a: action})
+
+# Do the environment step
+observations, rewards, dones, information = env.step(action_dict)
+print("\n The following agents can register an action:")
+print("=============================================")
+print(information['action_required'])
+
+# We recommend that you monitor the malfunction data and the action-required flags in order to optimize
+# your training and control code.
+
+# Let us now look at an episode playing out with random actions

 print("Start episode...")
-# Reset environment and get initial observations for all agents
-start_reset = time.time()
-obs, info = env.reset()
-end_reset = time.time()
-print(end_reset - start_reset)
-print(env.get_num_agents(), )
-# Reset the rendering sytem
+
+# Reset the rendering system
 env_renderer.reset()

 # Here you can also further enhance the provided observation by means of normalization
@@ -206,7 +241,7 @@ frame_step = 0
 for step in range(500):
     # Chose an action for each agent in the environment
     for a in range(env.get_num_agents()):
-        action = controller.act(obs[a])
+        action = controller.act(observations[a])
         action_dict.update({a: action})

     # Environment step which returns the observations for all agents, their corresponding
@@ -216,11 +251,11 @@ for step in range(500):
     frame_step += 1
     # Update replay buffer and train agent
     for a in range(env.get_num_agents()):
-        controller.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
+        controller.step((observations[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
         score += all_rewards[a]

-    obs = next_obs.copy()
+    observations = next_obs.copy()
     if done['__all__']:
         break

-print('Episode: Steps {}\t Score = {}'.format(step, score))
+    print('Episode: Steps {}\t Score = {}'.format(step, score))
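
The new `action_required` output that the patch prints is worth acting on, not just inspecting. Below is a minimal sketch of the episode loop that queries the controller only for agents flagged in `info['action_required']`. It assumes the `env` and `controller` objects constructed in the example above and the Flatland 2.1 signatures shown in the diff (`obs, info = env.reset()`, and `env.step` returning observation, reward, done, and info dicts keyed by agent handle); it is an illustration, not part of the patch.

# Illustrative sketch (not part of the patch): drive the episode using
# info['action_required'], assuming `env` and `controller` exist as built above.
observations, info = env.reset()
score = 0

for step in range(500):
    action_dict = dict()
    for a in range(env.get_num_agents()):
        # Per the comments in the patch, an action only changes the outcome
        # when an agent enters a new cell, so skip agents that don't need one.
        if info['action_required'][a]:
            action_dict[a] = controller.act(observations[a])

    # Rewards and dones come back as dicts keyed by agent handle
    observations, all_rewards, done, info = env.step(action_dict)
    score += sum(all_rewards.values())

    if done['__all__']:
        break

print('Episode: Steps {}\t Score = {}'.format(step, score))

Skipping the unneeded controller calls makes little difference for this random agent, but it saves real time once the controller is an expensive policy network.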