diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index ceedc90a95f8a434be67a7533873d3eb00154537..916e50b20b10a02c43c5b1da8bc0728930b8c535 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -8,7 +8,7 @@ from flatland.utils.rendertools import RenderTool
 
 np.random.seed(1)
 
-# Use the complex_rail_generator to generate feasible network configurations with corresponding tasks
+# Use the new sparse_rail_generator to generate feasible network configurations with corresponding tasks
 # Training on simple small tasks is the best way to get familiar with the environment
 # Use a the malfunction generator to break agents from time to time
 
@@ -22,7 +22,7 @@ TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictor
 env = RailEnv(width=20,
               height=20,
               rail_generator=sparse_rail_generator(num_cities=2,  # Number of cities in map (where train stations are)
-                                                   num_intersections=1,  # Number of interesections (no start / target)
+                                                   num_intersections=1,  # Number of intersections (no start / target)
                                                    num_trainstations=15,  # Number of possible start/targets on map
                                                    min_node_dist=3,  # Minimal distance of nodes
                                                    node_radius=3,  # Proximity of stations to city center
@@ -32,16 +32,14 @@ env = RailEnv(width=20,
                                                    enhance_intersection=True
                                                    ),
               number_of_agents=5,
-              stochastic_data=stochastic_data,  # Malfunction generator data
+              stochastic_data=stochastic_data,  # Malfunction data generator
               obs_builder_object=TreeObservation)
 
 env_renderer = RenderTool(env, gl="PILSVG", )
 
 
 # Import your own Agent or use RLlib to train agents on Flatland
-# As an example we use a random agent here
-
-
+# As an example we use a random agent instead
 class RandomAgent:
 
     def __init__(self, state_size, action_size):
@@ -76,48 +74,46 @@ class RandomAgent:
 # Initialize the agent with the parameters corresponding to the environment and observation_builder
 # Set action space to 4 to remove stop action
 agent = RandomAgent(218, 4)
-n_trials = 1
+
 # Empty dictionary for all agent action
 action_dict = dict()
 
-print("Starting Training...")
-
-for trials in range(1, n_trials + 1):
-
-    # Reset environment and get initial observations for all agents
-    obs = env.reset()
-    for idx in range(env.get_num_agents()):
-        tmp_agent = env.agents[idx]
-        speed = (idx % 5) + 1
-        tmp_agent.speed_data["speed"] = 1 / speed
-    env_renderer.reset()
-    # Here you can also further enhance the provided observation by means of normalization
-    # See training navigation example in the baseline repository
-
-    score = 0
-    # Run episode
-    frame_step = 0
-    for step in range(500):
-        # Chose an action for each agent in the environment
-        for a in range(env.get_num_agents()):
-            action = agent.act(obs[a])
-            action_dict.update({a: action})
-
-        # Environment step which returns the observations for all agents, their corresponding
-        # reward and whether their are done
-        next_obs, all_rewards, done, _ = env.step(action_dict)
-        env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
-        try:
-            env_renderer.gl.save_image("./../rendering/flatland_2_0_frame_{:04d}.bmp".format(frame_step))
-        except:
-            print("Path not found: ./../rendering/")
-        frame_step += 1
-        # Update replay buffer and train agent
-        for a in range(env.get_num_agents()):
-            agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
-            score += all_rewards[a]
-
-        obs = next_obs.copy()
-        if done['__all__']:
-            break
-    print('Episode Nr. {}\t Score = {}'.format(trials, score))
+
+print("Start episode...")
+# Reset environment and get initial observations for all agents
+obs = env.reset()
+# Update/Set agent's speed
+for idx in range(env.get_num_agents()):
+    speed = 1.0 / ((idx % 5) + 1.0)
+    env.agents[idx].speed_data["speed"] = speed
+
+# Reset the rendering system
+env_renderer.reset()
+
+# Here you can also further enhance the provided observation by means of normalization
+# See training navigation example in the baseline repository
+
+score = 0
+# Run episode
+frame_step = 0
+for step in range(500):
+    # Choose an action for each agent in the environment
+    for a in range(env.get_num_agents()):
+        action = agent.act(obs[a])
+        action_dict.update({a: action})
+
+    # Environment step which returns the observations for all agents, their corresponding
+    # reward and whether they are done
+    next_obs, all_rewards, done, _ = env.step(action_dict)
+    env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
+    frame_step += 1
+    # Update replay buffer and train agent
+    for a in range(env.get_num_agents()):
+        agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
+        score += all_rewards[a]
+
+    obs = next_obs.copy()
+    if done['__all__']:
+        break
+
+print('Episode: Steps {}\t Score = {}'.format(step, score))
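
Note on the new speed assignment: the added loop replaces the old integer pattern with explicit fractional speeds, cycling agents through five speed profiles (1, 1/2, 1/3, 1/4, 1/5 of full speed). The standalone sketch below only re-uses the expression from the diff to print the resulting cycle, for illustration:

# Illustration only: reproduces the speed formula added in the diff above.
# Agent index idx is mapped to one of five fractional speed profiles, repeating every 5 agents.
for idx in range(10):
    speed = 1.0 / ((idx % 5) + 1.0)
    print("agent {:2d} -> speed {:.3f}".format(idx, speed))
# Agents 0..4 get 1.000, 0.500, 0.333, 0.250, 0.200; the pattern then repeats.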